* feat: Prompt progress when streaming
- BE changes:
- Add a `return_progress` flag to `chatCompletionRequest` and a corresponding `prompt_progress` payload in `chatCompletionChunk`. Introduce `chatCompletionPromptProgress` interface to capture cache, processed, time, and total token counts.
- Update the Llamacpp extension to always request progress data when streaming, enabling UI components to display real‑time generation progress and leverage llama.cpp’s built‑in progress reporting.
* Make return_progress optional
* chore: update ui prompt progress before streaming content
* chore: remove log
* chore: remove progress when percentage >= 100
* chore: set timeout prompt progress
* chore: move prompt progress outside streaming content
* fix: tests
---------
Co-authored-by: Faisal Amir <urmauur@gmail.com>
Co-authored-by: Louis <louis@jan.ai>
152 lines
4.3 KiB
TypeScript
import { create } from 'zustand'
|
|
import { ThreadMessage } from '@janhq/core'
|
|
import { MCPTool } from '@/types/completion'
|
|
import { useAssistant } from './useAssistant'
|
|
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
|
|
|
/**
 * Prompt-processing progress reported by the backend while a completion
 * request's prompt is being evaluated (the `prompt_progress` payload of a
 * chat-completion chunk). All counts are in tokens.
 */
type PromptProgress = {
  // Prompt tokens served from the KV cache (not re-processed).
  cache: number
  // Prompt tokens processed so far.
  processed: number
  // Elapsed prompt-processing time in milliseconds.
  time_ms: number
  // Total prompt tokens to process.
  total: number
}
|
/**
 * User-facing error descriptor rendered by the app's error UI.
 *
 * NOTE(review): `subtitle` is required while `message` and `title` are
 * optional — confirm this asymmetry is intentional.
 */
type AppErrorMessage = {
  message?: string
  title?: string
  subtitle: string
}
|
/**
 * Shape of the global application store created by `useAppState`.
 * Holds transient UI/runtime state for chat streaming, model loading,
 * MCP tools, the local server, and error/dialog presentation.
 */
type AppState = {
  // Message currently being streamed from the model, if any.
  streamingContent?: ThreadMessage
  // True while a model is being loaded.
  loadingModel?: boolean
  // MCP tools currently available to the assistant.
  tools: MCPTool[]
  // Local inference server lifecycle state.
  serverStatus: 'running' | 'stopped' | 'pending'
  // One AbortController per thread id, used to cancel in-flight requests.
  abortControllers: Record<string, AbortController>
  // Rolling token-throughput metrics for the active stream.
  // NOTE(review): `TokenSpeed` is not imported in this file's visible
  // imports — confirm it is declared globally or in an ambient types file.
  tokenSpeed?: TokenSpeed
  // Tool call currently being executed, if any.
  currentToolCall?: ChatCompletionMessageToolCall
  // Whether the "out of context window" dialog is visible.
  showOutOfContextDialog?: boolean
  // Current app-level error to surface to the user.
  errorMessage?: AppErrorMessage
  // Prompt-evaluation progress received before content streaming starts.
  promptProgress?: PromptProgress
  // Callback to cancel the in-progress tool call, if one is running.
  cancelToolCall?: () => void
  setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
  // Pass `undefined` to clear the streaming message.
  updateStreamingContent: (content: ThreadMessage | undefined) => void
  updateCurrentToolCall: (
    toolCall: ChatCompletionMessageToolCall | undefined
  ) => void
  updateLoadingModel: (loading: boolean) => void
  updateTools: (tools: MCPTool[]) => void
  setAbortController: (threadId: string, controller: AbortController) => void
  // `increment` is the number of tokens received since the last call.
  updateTokenSpeed: (message: ThreadMessage, increment?: number) => void
  resetTokenSpeed: () => void
  setOutOfContextDialog: (show: boolean) => void
  setCancelToolCall: (cancel: (() => void) | undefined) => void
  setErrorMessage: (error: AppErrorMessage | undefined) => void
  updatePromptProgress: (progress: PromptProgress | undefined) => void
}
|
export const useAppState = create<AppState>()((set) => ({
|
|
streamingContent: undefined,
|
|
loadingModel: false,
|
|
tools: [],
|
|
serverStatus: 'stopped',
|
|
abortControllers: {},
|
|
tokenSpeed: undefined,
|
|
currentToolCall: undefined,
|
|
promptProgress: undefined,
|
|
cancelToolCall: undefined,
|
|
updateStreamingContent: (content: ThreadMessage | undefined) => {
|
|
const assistants = useAssistant.getState().assistants
|
|
const currentAssistant = useAssistant.getState().currentAssistant
|
|
|
|
const selectedAssistant =
|
|
assistants.find((a) => a.id === currentAssistant?.id) || assistants[0]
|
|
|
|
set(() => ({
|
|
streamingContent: content
|
|
? {
|
|
...content,
|
|
created_at: content.created_at || Date.now(),
|
|
metadata: {
|
|
...content.metadata,
|
|
assistant: selectedAssistant,
|
|
},
|
|
}
|
|
: undefined,
|
|
}))
|
|
},
|
|
updateCurrentToolCall: (toolCall) => {
|
|
set(() => ({
|
|
currentToolCall: toolCall,
|
|
}))
|
|
},
|
|
updateLoadingModel: (loading) => {
|
|
set({ loadingModel: loading })
|
|
},
|
|
updateTools: (tools) => {
|
|
set({ tools })
|
|
},
|
|
setServerStatus: (value) => set({ serverStatus: value }),
|
|
setAbortController: (threadId, controller) => {
|
|
set((state) => ({
|
|
abortControllers: {
|
|
...state.abortControllers,
|
|
[threadId]: controller,
|
|
},
|
|
}))
|
|
},
|
|
updateTokenSpeed: (message, increment = 1) =>
|
|
set((state) => {
|
|
const currentTimestamp = new Date().getTime() // Get current time in milliseconds
|
|
if (!state.tokenSpeed) {
|
|
// If this is the first update, just set the lastTimestamp and return
|
|
return {
|
|
tokenSpeed: {
|
|
lastTimestamp: currentTimestamp,
|
|
tokenSpeed: 0,
|
|
tokenCount: increment,
|
|
message: message.id,
|
|
},
|
|
}
|
|
}
|
|
|
|
const timeDiffInSeconds =
|
|
(currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000 // Time difference in seconds
|
|
const totalTokenCount = state.tokenSpeed.tokenCount + increment
|
|
const averageTokenSpeed =
|
|
totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1) // Calculate average token speed
|
|
return {
|
|
tokenSpeed: {
|
|
...state.tokenSpeed,
|
|
tokenSpeed: averageTokenSpeed,
|
|
tokenCount: totalTokenCount,
|
|
message: message.id,
|
|
},
|
|
}
|
|
}),
|
|
resetTokenSpeed: () =>
|
|
set({
|
|
tokenSpeed: undefined,
|
|
}),
|
|
setOutOfContextDialog: (show) => {
|
|
set(() => ({
|
|
showOutOfContextDialog: show,
|
|
}))
|
|
},
|
|
setCancelToolCall: (cancel) => {
|
|
set(() => ({
|
|
cancelToolCall: cancel,
|
|
}))
|
|
},
|
|
setErrorMessage: (error) => {
|
|
set(() => ({
|
|
errorMessage: error,
|
|
}))
|
|
},
|
|
updatePromptProgress: (progress) => {
|
|
set(() => ({
|
|
promptProgress: progress,
|
|
}))
|
|
},
|
|
}))
|