feat: Prompt progress when streaming (#6503)

* feat: Prompt progress when streaming

  BE changes:
  - Add a `return_progress` flag to `chatCompletionRequest` and a corresponding `prompt_progress` payload in `chatCompletionChunk`. Introduce a `chatCompletionPromptProgress` interface to capture cache, processed, time, and total token counts (see the usage sketch below).
  - Update the llamacpp extension to always request progress data when streaming, enabling UI components to display real-time generation progress and to leverage llama.cpp's built-in progress reporting.

* Make `return_progress` optional
* chore: update the UI prompt progress before streaming content
* chore: remove log
* chore: hide the progress indicator when the percentage reaches 100
* chore: set a timeout for prompt progress
* chore: move prompt progress outside the streaming content
* fix: tests

---------

Co-authored-by: Faisal Amir <urmauur@gmail.com>
Co-authored-by: Louis <louis@jan.ai>
parent e1294cdc30
commit bf7f176741
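Before the diffs, a minimal hypothetical sketch of how a client could exercise the new fields end to end: set `return_progress: true` on a streaming request and read `prompt_progress` off each chunk. This is not code from this PR; the import path, base URL, model id, and the hand-rolled SSE parsing are illustrative assumptions. Only the `return_progress` flag, the `prompt_progress` shape, and the llama.cpp version requirement (> b6399) come from the changes below.

```ts
// Sketch only. Assumes a llama.cpp server (> b6399) exposing the
// OpenAI-compatible /v1/chat/completions endpoint; './chatCompletion'
// is a hypothetical module path for the types added in this PR.
import type { chatCompletionRequest, chatCompletionChunk } from './chatCompletion'

async function streamWithProgress(baseUrl: string, apiKey: string) {
  const req: chatCompletionRequest = {
    model: 'qwen', // illustrative model id
    messages: [{ role: 'user', content: 'Summarize this long document' }],
    stream: true, // assumed to already exist on the request type
    return_progress: true, // new flag: ask the server to emit prompt_progress chunks
  }

  const res = await fetch(`${baseUrl}/v1/chat/completions`, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      Authorization: `Bearer ${apiKey}`,
    },
    body: JSON.stringify(req),
  })

  // Naive SSE reader: split the byte stream into "data: ..." lines.
  const reader = res.body!.getReader()
  const decoder = new TextDecoder()
  let buffer = ''
  for (;;) {
    const { value, done } = await reader.read()
    if (done) break
    buffer += decoder.decode(value, { stream: true })
    let nl: number
    while ((nl = buffer.indexOf('\n')) >= 0) {
      const line = buffer.slice(0, nl).trim()
      buffer = buffer.slice(nl + 1)
      if (!line.startsWith('data: ') || line === 'data: [DONE]') continue
      const chunk = JSON.parse(line.slice('data: '.length)) as chatCompletionChunk
      if (chunk.prompt_progress) {
        const { processed, total } = chunk.prompt_progress
        // prompt_progress.cache (not used here) counts prompt tokens
        // already present in the KV cache.
        console.log(`prompt processing: ${Math.round((processed / total) * 100)}%`)
      }
    }
  }
}
```

The percentage math mirrors the comment in the llamacpp extension diff: `(processed / total) * 100`.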
@@ -54,6 +54,7 @@ export type ToolChoice = 'none' | 'auto' | 'required' | ToolCallSpec
 export interface chatCompletionRequest {
   model: string // Model ID, though for local it might be implicit via sessionInfo
   messages: chatCompletionRequestMessage[]
+  return_progress?: boolean
   tools?: Tool[]
   tool_choice?: ToolChoice
   // Core sampling parameters
@@ -119,6 +120,13 @@ export interface chatCompletionChunkChoice {
   finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null
 }

+export interface chatCompletionPromptProgress {
+  cache: number
+  processed: number
+  time_ms: number
+  total: number
+}
+
 export interface chatCompletionChunk {
   id: string
   object: 'chat.completion.chunk'
@@ -126,6 +134,7 @@ export interface chatCompletionChunk {
   model: string
   choices: chatCompletionChunkChoice[]
   system_fingerprint?: string
+  prompt_progress?: chatCompletionPromptProgress
 }

 export interface chatCompletionChoice {
@@ -1802,6 +1802,13 @@ export default class llamacpp_extension extends AIEngine {
       'Content-Type': 'application/json',
       'Authorization': `Bearer ${sessionInfo.api_key}`,
     }
+    // always enable prompt progress return if stream is true
+    // Requires llamacpp version > b6399
+    // Example json returned from server
+    // {"choices":[{"finish_reason":null,"index":0,"delta":{"role":"assistant","content":null}}],"created":1758113912,"id":"chatcmpl-UwZwgxQKyJMo7WzMzXlsi90YTUK2BJro","model":"qwen","system_fingerprint":"b1-e4912fc","object":"chat.completion.chunk","prompt_progress":{"total":36,"cache":0,"processed":36,"time_ms":5706760300}}
+    // (chunk.prompt_progress?.processed / chunk.prompt_progress?.total) * 100
+    // chunk.prompt_progress?.cache is for past tokens already in kv cache
+    opts.return_progress = true

    const body = JSON.stringify(opts)
    if (opts.stream) {
web-app/src/components/PromptProgress.tsx (new file, 27 lines)
@@ -0,0 +1,27 @@
+import { useAppState } from '@/hooks/useAppState'
+
+export function PromptProgress() {
+  const promptProgress = useAppState((state) => state.promptProgress)
+
+  const percentage =
+    promptProgress && promptProgress.total > 0
+      ? Math.round((promptProgress.processed / promptProgress.total) * 100)
+      : 0
+
+  // Show progress only when promptProgress exists and has valid data, and not completed
+  if (
+    !promptProgress ||
+    !promptProgress.total ||
+    promptProgress.total <= 0 ||
+    percentage >= 100
+  ) {
+    return null
+  }
+
+  return (
+    <div className="flex items-center gap-2 text-sm text-muted-foreground mb-2">
+      <div className="animate-spin rounded-full h-4 w-4 border-b-2 border-primary"></div>
+      <span>Reading: {percentage}%</span>
+    </div>
+  )
+}
web-app/src/components/__tests__/PromptProgress.test.tsx (new file, 71 lines)
@@ -0,0 +1,71 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest'
+import { render, screen } from '@testing-library/react'
+import { PromptProgress } from '../PromptProgress'
+import { useAppState } from '@/hooks/useAppState'
+
+// Mock the useAppState hook
+vi.mock('@/hooks/useAppState', () => ({
+  useAppState: vi.fn(),
+}))
+
+const mockUseAppState = useAppState as ReturnType<typeof vi.fn>
+
+describe('PromptProgress', () => {
+  beforeEach(() => {
+    vi.clearAllMocks()
+  })
+
+  it('should not render when promptProgress is undefined', () => {
+    mockUseAppState.mockReturnValue(undefined)
+
+    const { container } = render(<PromptProgress />)
+    expect(container.firstChild).toBeNull()
+  })
+
+  it('should render progress when promptProgress is available', () => {
+    const mockProgress = {
+      cache: 0,
+      processed: 50,
+      time_ms: 1000,
+      total: 100,
+    }
+
+    mockUseAppState.mockReturnValue(mockProgress)
+
+    render(<PromptProgress />)
+
+    expect(screen.getByText('Reading: 50%')).toBeInTheDocument()
+    expect(document.querySelector('.animate-spin')).toBeInTheDocument()
+  })
+
+  it('should calculate percentage correctly', () => {
+    const mockProgress = {
+      cache: 0,
+      processed: 75,
+      time_ms: 1500,
+      total: 150,
+    }
+
+    mockUseAppState.mockReturnValue(mockProgress)
+
+    render(<PromptProgress />)
+
+    expect(screen.getByText('Reading: 50%')).toBeInTheDocument()
+  })
+
+  it('should handle zero total gracefully', () => {
+    const mockProgress = {
+      cache: 0,
+      processed: 0,
+      time_ms: 0,
+      total: 0,
+    }
+
+    mockUseAppState.mockReturnValue(mockProgress)
+
+    const { container } = render(<PromptProgress />)
+
+    // Component should not render when total is 0
+    expect(container.firstChild).toBeNull()
+  })
+})
@@ -72,7 +72,11 @@ export const ThreadContent = memo(

     streamTools?: any
     contextOverflowModal?: React.ReactNode | null
-    updateMessage?: (item: ThreadMessage, message: string, imageUrls?: string[]) => void
+    updateMessage?: (
+      item: ThreadMessage,
+      message: string,
+      imageUrls?: string[]
+    ) => void
   }
 ) => {
   const { t } = useTranslation()
@@ -281,7 +285,12 @@ export const ThreadContent = memo(
             item.content?.find((c) => c.type === 'text')?.text?.value ||
             ''
           }
-          imageUrls={item.content?.filter((c) => c.type === 'image_url' && c.image_url?.url).map((c) => c.image_url!.url).filter((url): url is string => url !== undefined) || []}
+          imageUrls={
+            item.content
+              ?.filter((c) => c.type === 'image_url' && c.image_url?.url)
+              .map((c) => c.image_url!.url)
+              .filter((url): url is string => url !== undefined) || []
+          }
           onSave={(message, imageUrls) => {
             if (item.updateMessage) {
               item.updateMessage(item, message, imageUrls)
@@ -397,7 +406,9 @@ export const ThreadContent = memo(
               </div>

               <TokenSpeedIndicator
-                streaming={Boolean(item.isLastMessage && isStreamingThisThread)}
+                streaming={Boolean(
+                  item.isLastMessage && isStreamingThisThread
+                )}
                 metadata={item.metadata}
               />
             </div>
@@ -35,6 +35,7 @@ vi.mock('../../hooks/useAppState', () => ({
     resetTokenSpeed: vi.fn(),
     updateTools: vi.fn(),
     updateStreamingContent: vi.fn(),
+    updatePromptProgress: vi.fn(),
     updateLoadingModel: vi.fn(),
     setAbortController: vi.fn(),
   }
@@ -106,7 +107,11 @@ vi.mock('../../hooks/useMessages', () => ({

 vi.mock('../../hooks/useToolApproval', () => ({
   useToolApproval: (selector: any) => {
-    const state = { approvedTools: [], showApprovalModal: vi.fn(), allowAllMCPPermissions: false }
+    const state = {
+      approvedTools: [],
+      showApprovalModal: vi.fn(),
+      allowAllMCPPermissions: false,
+    }
     return selector ? selector(state) : state
   },
 }))
@@ -132,14 +137,24 @@ vi.mock('@tanstack/react-router', () => ({
 vi.mock('@/lib/completion', () => ({
   emptyThreadContent: { thread_id: 'test-thread', content: '' },
   extractToolCall: vi.fn(),
-  newUserThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'user message' })),
-  newAssistantThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'assistant message' })),
-  sendCompletion: vi.fn(() => Promise.resolve({ choices: [{ message: { content: '' } }] })),
+  newUserThreadContent: vi.fn(() => ({
+    thread_id: 'test-thread',
+    content: 'user message',
+  })),
+  newAssistantThreadContent: vi.fn(() => ({
+    thread_id: 'test-thread',
+    content: 'assistant message',
+  })),
+  sendCompletion: vi.fn(() =>
+    Promise.resolve({ choices: [{ message: { content: '' } }] })
+  ),
   postMessageProcessing: vi.fn(),
   isCompletionResponse: vi.fn(() => true),
 }))

-vi.mock('@/services/mcp', () => ({ getTools: vi.fn(() => Promise.resolve([])) }))
+vi.mock('@/services/mcp', () => ({
+  getTools: vi.fn(() => Promise.resolve([])),
+}))

 vi.mock('@/services/models', () => ({
   startModel: vi.fn(() => Promise.resolve()),
@@ -147,9 +162,13 @@ vi.mock('@/services/models', () => ({
   stopAllModels: vi.fn(() => Promise.resolve()),
 }))

-vi.mock('@/services/providers', () => ({ updateSettings: vi.fn(() => Promise.resolve()) }))
+vi.mock('@/services/providers', () => ({
+  updateSettings: vi.fn(() => Promise.resolve()),
+}))

-vi.mock('@tauri-apps/api/event', () => ({ listen: vi.fn(() => Promise.resolve(vi.fn())) }))
+vi.mock('@tauri-apps/api/event', () => ({
+  listen: vi.fn(() => Promise.resolve(vi.fn())),
+}))

 vi.mock('@/hooks/useServiceHub', () => ({
   useServiceHub: () => ({
@@ -25,6 +25,7 @@ vi.mock('../useAppState', () => ({
     resetTokenSpeed: vi.fn(),
     updateTools: vi.fn(),
     updateStreamingContent: vi.fn(),
+    updatePromptProgress: vi.fn(),
     updateLoadingModel: vi.fn(),
     setAbortController: vi.fn(),
   }
@@ -4,6 +4,13 @@ import { MCPTool } from '@/types/completion'
 import { useAssistant } from './useAssistant'
 import { ChatCompletionMessageToolCall } from 'openai/resources'

+type PromptProgress = {
+  cache: number
+  processed: number
+  time_ms: number
+  total: number
+}
+
 type AppErrorMessage = {
   message?: string
   title?: string
@@ -20,6 +27,7 @@ type AppState = {
   currentToolCall?: ChatCompletionMessageToolCall
   showOutOfContextDialog?: boolean
   errorMessage?: AppErrorMessage
+  promptProgress?: PromptProgress
   cancelToolCall?: () => void
   setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
   updateStreamingContent: (content: ThreadMessage | undefined) => void
@@ -34,6 +42,7 @@ type AppState = {
   setOutOfContextDialog: (show: boolean) => void
   setCancelToolCall: (cancel: (() => void) | undefined) => void
   setErrorMessage: (error: AppErrorMessage | undefined) => void
+  updatePromptProgress: (progress: PromptProgress | undefined) => void
 }

 export const useAppState = create<AppState>()((set) => ({
@@ -44,6 +53,7 @@ export const useAppState = create<AppState>()((set) => ({
   abortControllers: {},
   tokenSpeed: undefined,
   currentToolCall: undefined,
+  promptProgress: undefined,
   cancelToolCall: undefined,
   updateStreamingContent: (content: ThreadMessage | undefined) => {
     const assistants = useAssistant.getState().assistants
@@ -133,4 +143,9 @@ export const useAppState = create<AppState>()((set) => ({
       errorMessage: error,
     }))
   },
+  updatePromptProgress: (progress) => {
+    set(() => ({
+      promptProgress: progress,
+    }))
+  },
 }))
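For orientation, a small hypothetical sketch of how this new slice can be driven from outside a component via zustand's `getState()`. The progress values are made up; only the field names and the `updatePromptProgress` signature come from the diff above.

```ts
// Sketch only: exercising the new promptProgress slice imperatively,
// e.g. from a streaming callback that lives outside React.
import { useAppState } from '@/hooks/useAppState'

// Publish illustrative progress (cache = prompt tokens already in the KV cache).
useAppState.getState().updatePromptProgress({
  cache: 0,
  processed: 18,
  time_ms: 1200,
  total: 36,
})

// Components subscribe via a selector (as PromptProgress does above);
// imperative reads work too:
const progress = useAppState.getState().promptProgress
if (progress && progress.total > 0) {
  console.log(`prompt: ${Math.round((progress.processed / progress.total) * 100)}%`)
}

// Clear the slice once content streaming starts or the request settles,
// mirroring what useChat does in the next diff.
useAppState.getState().updatePromptProgress(undefined)
```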
@@ -1,4 +1,5 @@
 import { useCallback, useMemo } from 'react'
+import { flushSync } from 'react-dom'
 import { usePrompt } from './usePrompt'
 import { useModelProvider } from './useModelProvider'
 import { useThreads } from './useThreads'
@@ -49,6 +50,9 @@ export const useChat = () => {
       state.setAbortController,
     ])
   )
+  const updatePromptProgress = useAppState(
+    (state) => state.updatePromptProgress
+  )

   const updateProvider = useModelProvider((state) => state.updateProvider)
   const serviceHub = useServiceHub()
@@ -229,6 +233,7 @@ export const useChat = () => {
       const abortController = new AbortController()
       setAbortController(activeThread.id, abortController)
       updateStreamingContent(emptyThreadContent)
+      updatePromptProgress(undefined)
       // Do not add new message on retry
       if (troubleshooting)
         addMessage(newUserThreadContent(activeThread.id, message, attachments))
@@ -397,6 +402,16 @@ export const useChat = () => {
             break
           }

+          // Handle prompt progress if available
+          if ('prompt_progress' in part && part.prompt_progress) {
+            // Force immediate state update to ensure we see intermediate values
+            flushSync(() => {
+              updatePromptProgress(part.prompt_progress)
+            })
+            // Add a small delay to make progress visible
+            await new Promise((resolve) => setTimeout(resolve, 100))
+          }
+
           // Error message
           if (!part.choices) {
             throw new Error(
@@ -513,6 +528,7 @@ export const useChat = () => {
       )
       addMessage(updatedMessage ?? finalContent)
       updateStreamingContent(emptyThreadContent)
+      updatePromptProgress(undefined)
       updateThreadTimestamp(activeThread.id)

       isCompleted = !toolCalls.length
@@ -534,6 +550,7 @@ export const useChat = () => {
     } finally {
       updateLoadingModel(false)
       updateStreamingContent(undefined)
+      updatePromptProgress(undefined)
     }
   },
   [
@@ -543,6 +560,7 @@ export const useChat = () => {
     getMessages,
     setAbortController,
     updateStreamingContent,
+    updatePromptProgress,
     addMessage,
     updateThreadTimestamp,
     updateLoadingModel,
@@ -20,6 +20,7 @@ import { useSmallScreen } from '@/hooks/useMediaQuery'
 import { PlatformFeatures } from '@/lib/platform/const'
 import { PlatformFeature } from '@/lib/platform/types'
 import ScrollToBottom from '@/containers/ScrollToBottom'
+import { PromptProgress } from '@/components/PromptProgress'

 // as route.threadsDetail
 export const Route = createFileRoute('/threads/$threadId')({
@@ -170,6 +171,7 @@ function ThreadDetail() {
                 </div>
               )
             })}
+            <PromptProgress />
             <StreamingContent
               threadId={threadId}
               data-test-id="thread-content-text"