jan/web-app/src/hooks/useAppState.ts
Akarshan Biswas bf7f176741
feat: Prompt progress when streaming (#6503)
* feat: Prompt progress when streaming

- Backend changes:
    - Add a `return_progress` flag to `chatCompletionRequest` and a corresponding `prompt_progress` payload in `chatCompletionChunk`. Introduce a `chatCompletionPromptProgress` interface to capture cache, processed, time, and total token counts (see the sketch after this list).
    - Update the Llamacpp extension to always request progress data when streaming, enabling UI components to display real-time generation progress and leverage llama.cpp's built-in progress reporting.
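
  A rough sketch of the shapes this describes, for orientation. The field names follow the commit text above; the exact placement inside the request and chunk types is an assumption:

      // Assumed shape; it mirrors the PromptProgress type in the hook below.
      interface chatCompletionPromptProgress {
        cache: number      // prompt tokens reused from the cache
        processed: number  // prompt tokens evaluated so far
        time_ms: number    // elapsed prompt-processing time (ms)
        total: number      // total prompt tokens to evaluate
      }

      // chatCompletionRequest gains: return_progress?: boolean (optional, see below)
      // chatCompletionChunk gains:   prompt_progress?: chatCompletionPromptProgress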

* Make return_progress optional

* chore: update ui prompt progress before streaming content

* chore: remove log

* chore: remove progress when percentage >= 100

* chore: set timeout prompt progress

* chore: move prompt progress outside streaming content

* fix: tests

---------

Co-authored-by: Faisal Amir <urmauur@gmail.com>
Co-authored-by: Louis <louis@jan.ai>
2025-09-22 20:37:27 +05:30

import { create } from 'zustand'
import { ThreadMessage } from '@janhq/core'
import { MCPTool } from '@/types/completion'
import { useAssistant } from './useAssistant'
import { ChatCompletionMessageToolCall } from 'openai/resources'

// Not imported or declared in the original snippet; shape inferred from
// updateTokenSpeed below.
type TokenSpeed = {
  lastTimestamp: number
  tokenSpeed: number
  tokenCount: number
  message: string
}

// Prompt-evaluation progress reported by the backend while the prompt is
// being processed (see the commit message above).
type PromptProgress = {
  cache: number
  processed: number
  time_ms: number
  total: number
}

type AppErrorMessage = {
  message?: string
  title?: string
  subtitle: string
}

type AppState = {
  streamingContent?: ThreadMessage
  loadingModel?: boolean
  tools: MCPTool[]
  serverStatus: 'running' | 'stopped' | 'pending'
  abortControllers: Record<string, AbortController>
  tokenSpeed?: TokenSpeed
  currentToolCall?: ChatCompletionMessageToolCall
  showOutOfContextDialog?: boolean
  errorMessage?: AppErrorMessage
  promptProgress?: PromptProgress
  cancelToolCall?: () => void
  setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
  updateStreamingContent: (content: ThreadMessage | undefined) => void
  updateCurrentToolCall: (
    toolCall: ChatCompletionMessageToolCall | undefined
  ) => void
  updateLoadingModel: (loading: boolean) => void
  updateTools: (tools: MCPTool[]) => void
  setAbortController: (threadId: string, controller: AbortController) => void
  updateTokenSpeed: (message: ThreadMessage, increment?: number) => void
  resetTokenSpeed: () => void
  setOutOfContextDialog: (show: boolean) => void
  setCancelToolCall: (cancel: (() => void) | undefined) => void
  setErrorMessage: (error: AppErrorMessage | undefined) => void
  updatePromptProgress: (progress: PromptProgress | undefined) => void
}

export const useAppState = create<AppState>()((set) => ({
  streamingContent: undefined,
  loadingModel: false,
  tools: [],
  serverStatus: 'stopped',
  abortControllers: {},
  tokenSpeed: undefined,
  currentToolCall: undefined,
  promptProgress: undefined,
  cancelToolCall: undefined,
  // Stamp the in-flight message with the active assistant (falling back to
  // the first assistant) and a creation timestamp, so the UI can render
  // attribution while tokens stream in.
  updateStreamingContent: (content: ThreadMessage | undefined) => {
    const assistants = useAssistant.getState().assistants
    const currentAssistant = useAssistant.getState().currentAssistant
    const selectedAssistant =
      assistants.find((a) => a.id === currentAssistant?.id) || assistants[0]
    set(() => ({
      streamingContent: content
        ? {
            ...content,
            created_at: content.created_at || Date.now(),
            metadata: {
              ...content.metadata,
              assistant: selectedAssistant,
            },
          }
        : undefined,
    }))
  },
  updateCurrentToolCall: (toolCall) => {
    set(() => ({
      currentToolCall: toolCall,
    }))
  },
  updateLoadingModel: (loading) => {
    set({ loadingModel: loading })
  },
  updateTools: (tools) => {
    set({ tools })
  },
  setServerStatus: (value) => set({ serverStatus: value }),
  setAbortController: (threadId, controller) => {
    set((state) => ({
      abortControllers: {
        ...state.abortControllers,
        [threadId]: controller,
      },
    }))
  },
  updateTokenSpeed: (message, increment = 1) =>
    set((state) => {
      const currentTimestamp = new Date().getTime() // current time in milliseconds
      if (!state.tokenSpeed) {
        // First update for this stream: record the timestamp and seed the count
        return {
          tokenSpeed: {
            lastTimestamp: currentTimestamp,
            tokenSpeed: 0,
            tokenCount: increment,
            message: message.id,
          },
        }
      }
      const timeDiffInSeconds =
        (currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000
      const totalTokenCount = state.tokenSpeed.tokenCount + increment
      // Average speed over the whole stream so far; guard against a zero
      // time delta on sub-millisecond updates
      const averageTokenSpeed =
        totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1)
      return {
        tokenSpeed: {
          ...state.tokenSpeed,
          tokenSpeed: averageTokenSpeed,
          tokenCount: totalTokenCount,
          message: message.id,
        },
      }
    }),
  resetTokenSpeed: () =>
    set({
      tokenSpeed: undefined,
    }),
  setOutOfContextDialog: (show) => {
    set(() => ({
      showOutOfContextDialog: show,
    }))
  },
  setCancelToolCall: (cancel) => {
    set(() => ({
      cancelToolCall: cancel,
    }))
  },
  setErrorMessage: (error) => {
    set(() => ({
      errorMessage: error,
    }))
  },
  // Latest prompt-evaluation progress from the backend; set before streamed
  // content arrives and cleared (undefined) once processing completes.
  updatePromptProgress: (progress) => {
    set(() => ({
      promptProgress: progress,
    }))
  },
}))
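
// Usage sketch (hypothetical consumer, not part of this file): a component
// might derive a percentage from promptProgress and hide the indicator once
// it reaches 100, matching the "remove progress when percentage >= 100"
// commit above.
//
//   const promptProgress = useAppState((state) => state.promptProgress)
//   const percentage = promptProgress
//     ? (promptProgress.processed / promptProgress.total) * 100
//     : undefined
//   if (percentage !== undefined && percentage < 100) {
//     // ...render a progress bar driven by `percentage`
//   }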