feat: Prompt progress when streaming (#6503)

* feat: Prompt progress when streaming

- BE changes:
    - Add a `return_progress` flag to `chatCompletionRequest` and a corresponding `prompt_progress` payload to `chatCompletionChunk`. Introduce a `chatCompletionPromptProgress` interface capturing cached, processed, and total prompt-token counts plus elapsed time (see the sketch after this list).
    - Update the llama.cpp extension to always request progress data when streaming, so UI components can display real-time prompt-processing progress via llama.cpp's built-in progress reporting.
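
A rough sketch of the consumer-side math (not part of the diff): the field names mirror the new `chatCompletionPromptProgress` interface, and the example payload comes from the llama.cpp extension comment further down. The local `PromptProgress` alias is only a stand-in for the real interface.

```ts
// Sketch only: field names mirror chatCompletionPromptProgress from this diff.
type PromptProgress = {
  cache: number // prompt tokens already present in the KV cache
  processed: number // prompt tokens processed so far
  time_ms: number // elapsed prompt-processing time
  total: number // total prompt tokens
}

// Convert a chunk's prompt_progress payload into a 0-100 percentage,
// or undefined when no usable progress data is present.
function promptProgressPercent(progress?: PromptProgress): number | undefined {
  if (!progress || progress.total <= 0) return undefined
  return Math.round((progress.processed / progress.total) * 100)
}

// Example payload from the llama.cpp server comment below -> 100
console.log(
  promptProgressPercent({ total: 36, cache: 0, processed: 36, time_ms: 5706760300 })
)
```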

* Make return_progress optional

* chore: update ui prompt progress before streaming content

* chore: remove log

* chore: remove progress when percentage >= 100

* chore: set timeout prompt progress

* chore: move prompt progress outside streaming content

* fix: tests

---------

Co-authored-by: Faisal Amir <urmauur@gmail.com>
Co-authored-by: Louis <louis@jan.ai>
Akarshan Biswas 2025-09-22 20:37:27 +05:30 committed by GitHub
parent e1294cdc30
commit bf7f176741
10 changed files with 190 additions and 10 deletions

View File

@@ -54,6 +54,7 @@ export type ToolChoice = 'none' | 'auto' | 'required' | ToolCallSpec
export interface chatCompletionRequest {
model: string // Model ID, though for local it might be implicit via sessionInfo
messages: chatCompletionRequestMessage[]
return_progress?: boolean
tools?: Tool[]
tool_choice?: ToolChoice
// Core sampling parameters
@@ -119,6 +120,13 @@ export interface chatCompletionChunkChoice {
finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null
}
export interface chatCompletionPromptProgress {
cache: number
processed: number
time_ms: number
total: number
}
export interface chatCompletionChunk {
id: string
object: 'chat.completion.chunk'
@@ -126,6 +134,7 @@ export interface chatCompletionChunk {
model: string
choices: chatCompletionChunkChoice[]
system_fingerprint?: string
prompt_progress?: chatCompletionPromptProgress
}
export interface chatCompletionChoice {

View File

@@ -1802,6 +1802,13 @@ export default class llamacpp_extension extends AIEngine {
'Content-Type': 'application/json',
'Authorization': `Bearer ${sessionInfo.api_key}`,
}
// always enable prompt progress return if stream is true
// Requires llamacpp version > b6399
// Example json returned from server
// {"choices":[{"finish_reason":null,"index":0,"delta":{"role":"assistant","content":null}}],"created":1758113912,"id":"chatcmpl-UwZwgxQKyJMo7WzMzXlsi90YTUK2BJro","model":"qwen","system_fingerprint":"b1-e4912fc","object":"chat.completion.chunk","prompt_progress":{"total":36,"cache":0,"processed":36,"time_ms":5706760300}}
// (chunk.prompt_progress?.processed / chunk.prompt_progress?.total) * 100
// chunk.prompt_progress?.cache is for past tokens already in kv cache
opts.return_progress = true
const body = JSON.stringify(opts)
if (opts.stream) {

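For orientation (not part of the diff): once `return_progress` is set on a streaming request, each parsed SSE chunk can carry `prompt_progress` next to the usual `choices` deltas. The object below is the same payload as the server-side JSON in the comment above, written out as a TypeScript literal whose shape follows the `chatCompletionChunk` and `chatCompletionPromptProgress` interfaces added in this diff.

```ts
// Sketch only: same data as the example JSON in the llama.cpp comment above.
const exampleChunk = {
  id: 'chatcmpl-UwZwgxQKyJMo7WzMzXlsi90YTUK2BJro',
  object: 'chat.completion.chunk',
  created: 1758113912,
  model: 'qwen',
  system_fingerprint: 'b1-e4912fc',
  choices: [
    { index: 0, finish_reason: null, delta: { role: 'assistant', content: null } },
  ],
  // Servers or providers without this feature omit prompt_progress entirely,
  // so consumers should treat the field as optional.
  prompt_progress: { total: 36, cache: 0, processed: 36, time_ms: 5706760300 },
}
```

The `useChat` hook later in this diff guards on exactly this field (`'prompt_progress' in part && part.prompt_progress`) before pushing the payload into the app state store.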
View File

@@ -0,0 +1,27 @@
import { useAppState } from '@/hooks/useAppState'
export function PromptProgress() {
const promptProgress = useAppState((state) => state.promptProgress)
const percentage =
promptProgress && promptProgress.total > 0
? Math.round((promptProgress.processed / promptProgress.total) * 100)
: 0
// Show progress only when promptProgress exists and has valid data, and not completed
if (
!promptProgress ||
!promptProgress.total ||
promptProgress.total <= 0 ||
percentage >= 100
) {
return null
}
return (
<div className="flex items-center gap-2 text-sm text-muted-foreground mb-2">
<div className="animate-spin rounded-full h-4 w-4 border-b-2 border-primary"></div>
<span>Reading: {percentage}%</span>
</div>
)
}

View File

@@ -0,0 +1,71 @@
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { render, screen } from '@testing-library/react'
import { PromptProgress } from '../PromptProgress'
import { useAppState } from '@/hooks/useAppState'
// Mock the useAppState hook
vi.mock('@/hooks/useAppState', () => ({
useAppState: vi.fn(),
}))
const mockUseAppState = useAppState as ReturnType<typeof vi.fn>
describe('PromptProgress', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should not render when promptProgress is undefined', () => {
mockUseAppState.mockReturnValue(undefined)
const { container } = render(<PromptProgress />)
expect(container.firstChild).toBeNull()
})
it('should render progress when promptProgress is available', () => {
const mockProgress = {
cache: 0,
processed: 50,
time_ms: 1000,
total: 100,
}
mockUseAppState.mockReturnValue(mockProgress)
render(<PromptProgress />)
expect(screen.getByText('Reading: 50%')).toBeInTheDocument()
expect(document.querySelector('.animate-spin')).toBeInTheDocument()
})
it('should calculate percentage correctly', () => {
const mockProgress = {
cache: 0,
processed: 75,
time_ms: 1500,
total: 150,
}
mockUseAppState.mockReturnValue(mockProgress)
render(<PromptProgress />)
expect(screen.getByText('Reading: 50%')).toBeInTheDocument()
})
it('should handle zero total gracefully', () => {
const mockProgress = {
cache: 0,
processed: 0,
time_ms: 0,
total: 0,
}
mockUseAppState.mockReturnValue(mockProgress)
const { container } = render(<PromptProgress />)
// Component should not render when total is 0
expect(container.firstChild).toBeNull()
})
})

View File

@@ -72,7 +72,11 @@ export const ThreadContent = memo(
streamTools?: any
contextOverflowModal?: React.ReactNode | null
updateMessage?: (item: ThreadMessage, message: string, imageUrls?: string[]) => void
updateMessage?: (
item: ThreadMessage,
message: string,
imageUrls?: string[]
) => void
}
) => {
const { t } = useTranslation()
@@ -281,7 +285,12 @@
item.content?.find((c) => c.type === 'text')?.text?.value ||
''
}
imageUrls={item.content?.filter((c) => c.type === 'image_url' && c.image_url?.url).map((c) => c.image_url!.url).filter((url): url is string => url !== undefined) || []}
imageUrls={
item.content
?.filter((c) => c.type === 'image_url' && c.image_url?.url)
.map((c) => c.image_url!.url)
.filter((url): url is string => url !== undefined) || []
}
onSave={(message, imageUrls) => {
if (item.updateMessage) {
item.updateMessage(item, message, imageUrls)
@@ -397,7 +406,9 @@
</div>
<TokenSpeedIndicator
streaming={Boolean(item.isLastMessage && isStreamingThisThread)}
streaming={Boolean(
item.isLastMessage && isStreamingThisThread
)}
metadata={item.metadata}
/>
</div>

View File

@@ -35,6 +35,7 @@ vi.mock('../../hooks/useAppState', () => ({
resetTokenSpeed: vi.fn(),
updateTools: vi.fn(),
updateStreamingContent: vi.fn(),
updatePromptProgress: vi.fn(),
updateLoadingModel: vi.fn(),
setAbortController: vi.fn(),
}
@@ -106,7 +107,11 @@ vi.mock('../../hooks/useMessages', () => ({
vi.mock('../../hooks/useToolApproval', () => ({
useToolApproval: (selector: any) => {
const state = { approvedTools: [], showApprovalModal: vi.fn(), allowAllMCPPermissions: false }
const state = {
approvedTools: [],
showApprovalModal: vi.fn(),
allowAllMCPPermissions: false,
}
return selector ? selector(state) : state
},
}))
@@ -132,14 +137,24 @@ vi.mock('@tanstack/react-router', () => ({
vi.mock('@/lib/completion', () => ({
emptyThreadContent: { thread_id: 'test-thread', content: '' },
extractToolCall: vi.fn(),
newUserThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'user message' })),
newAssistantThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'assistant message' })),
sendCompletion: vi.fn(() => Promise.resolve({ choices: [{ message: { content: '' } }] })),
newUserThreadContent: vi.fn(() => ({
thread_id: 'test-thread',
content: 'user message',
})),
newAssistantThreadContent: vi.fn(() => ({
thread_id: 'test-thread',
content: 'assistant message',
})),
sendCompletion: vi.fn(() =>
Promise.resolve({ choices: [{ message: { content: '' } }] })
),
postMessageProcessing: vi.fn(),
isCompletionResponse: vi.fn(() => true),
}))
vi.mock('@/services/mcp', () => ({ getTools: vi.fn(() => Promise.resolve([])) }))
vi.mock('@/services/mcp', () => ({
getTools: vi.fn(() => Promise.resolve([])),
}))
vi.mock('@/services/models', () => ({
startModel: vi.fn(() => Promise.resolve()),
@@ -147,9 +162,13 @@ vi.mock('@/services/models', () => ({
stopAllModels: vi.fn(() => Promise.resolve()),
}))
vi.mock('@/services/providers', () => ({ updateSettings: vi.fn(() => Promise.resolve()) }))
vi.mock('@/services/providers', () => ({
updateSettings: vi.fn(() => Promise.resolve()),
}))
vi.mock('@tauri-apps/api/event', () => ({ listen: vi.fn(() => Promise.resolve(vi.fn())) }))
vi.mock('@tauri-apps/api/event', () => ({
listen: vi.fn(() => Promise.resolve(vi.fn())),
}))
vi.mock('@/hooks/useServiceHub', () => ({
useServiceHub: () => ({

View File

@@ -25,6 +25,7 @@ vi.mock('../useAppState', () => ({
resetTokenSpeed: vi.fn(),
updateTools: vi.fn(),
updateStreamingContent: vi.fn(),
updatePromptProgress: vi.fn(),
updateLoadingModel: vi.fn(),
setAbortController: vi.fn(),
}

View File

@@ -4,6 +4,13 @@ import { MCPTool } from '@/types/completion'
import { useAssistant } from './useAssistant'
import { ChatCompletionMessageToolCall } from 'openai/resources'
type PromptProgress = {
cache: number
processed: number
time_ms: number
total: number
}
type AppErrorMessage = {
message?: string
title?: string
@@ -20,6 +27,7 @@ type AppState = {
currentToolCall?: ChatCompletionMessageToolCall
showOutOfContextDialog?: boolean
errorMessage?: AppErrorMessage
promptProgress?: PromptProgress
cancelToolCall?: () => void
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
updateStreamingContent: (content: ThreadMessage | undefined) => void
@@ -34,6 +42,7 @@ type AppState = {
setOutOfContextDialog: (show: boolean) => void
setCancelToolCall: (cancel: (() => void) | undefined) => void
setErrorMessage: (error: AppErrorMessage | undefined) => void
updatePromptProgress: (progress: PromptProgress | undefined) => void
}
export const useAppState = create<AppState>()((set) => ({
@@ -44,6 +53,7 @@ export const useAppState = create<AppState>()((set) => ({
abortControllers: {},
tokenSpeed: undefined,
currentToolCall: undefined,
promptProgress: undefined,
cancelToolCall: undefined,
updateStreamingContent: (content: ThreadMessage | undefined) => {
const assistants = useAssistant.getState().assistants
@@ -133,4 +143,9 @@ export const useAppState = create<AppState>()((set) => ({
errorMessage: error,
}))
},
updatePromptProgress: (progress) => {
set(() => ({
promptProgress: progress,
}))
},
}))

View File

@@ -1,4 +1,5 @@
import { useCallback, useMemo } from 'react'
import { flushSync } from 'react-dom'
import { usePrompt } from './usePrompt'
import { useModelProvider } from './useModelProvider'
import { useThreads } from './useThreads'
@@ -49,6 +50,9 @@ export const useChat = () => {
state.setAbortController,
])
)
const updatePromptProgress = useAppState(
(state) => state.updatePromptProgress
)
const updateProvider = useModelProvider((state) => state.updateProvider)
const serviceHub = useServiceHub()
@@ -229,6 +233,7 @@ export const useChat = () => {
const abortController = new AbortController()
setAbortController(activeThread.id, abortController)
updateStreamingContent(emptyThreadContent)
updatePromptProgress(undefined)
// Do not add new message on retry
if (troubleshooting)
addMessage(newUserThreadContent(activeThread.id, message, attachments))
@@ -397,6 +402,16 @@ export const useChat = () => {
break
}
// Handle prompt progress if available
if ('prompt_progress' in part && part.prompt_progress) {
// Force immediate state update to ensure we see intermediate values
flushSync(() => {
updatePromptProgress(part.prompt_progress)
})
// Add a small delay to make progress visible
await new Promise((resolve) => setTimeout(resolve, 100))
}
// Error message
if (!part.choices) {
throw new Error(
@@ -513,6 +528,7 @@ export const useChat = () => {
)
addMessage(updatedMessage ?? finalContent)
updateStreamingContent(emptyThreadContent)
updatePromptProgress(undefined)
updateThreadTimestamp(activeThread.id)
isCompleted = !toolCalls.length
@@ -534,6 +550,7 @@ export const useChat = () => {
} finally {
updateLoadingModel(false)
updateStreamingContent(undefined)
updatePromptProgress(undefined)
}
},
[
@@ -543,6 +560,7 @@ export const useChat = () => {
getMessages,
setAbortController,
updateStreamingContent,
updatePromptProgress,
addMessage,
updateThreadTimestamp,
updateLoadingModel,

View File

@@ -20,6 +20,7 @@ import { useSmallScreen } from '@/hooks/useMediaQuery'
import { PlatformFeatures } from '@/lib/platform/const'
import { PlatformFeature } from '@/lib/platform/types'
import ScrollToBottom from '@/containers/ScrollToBottom'
import { PromptProgress } from '@/components/PromptProgress'
// as route.threadsDetail
export const Route = createFileRoute('/threads/$threadId')({
@@ -170,6 +171,7 @@ function ThreadDetail() {
</div>
)
})}
<PromptProgress />
<StreamingContent
threadId={threadId}
data-test-id="thread-content-text"