diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 7a223e468..3236994b2 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -54,6 +54,7 @@ export type ToolChoice = 'none' | 'auto' | 'required' | ToolCallSpec export interface chatCompletionRequest { model: string // Model ID, though for local it might be implicit via sessionInfo messages: chatCompletionRequestMessage[] + return_progress?: boolean tools?: Tool[] tool_choice?: ToolChoice // Core sampling parameters @@ -119,6 +120,13 @@ export interface chatCompletionChunkChoice { finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null } +export interface chatCompletionPromptProgress { + cache: number + processed: number + time_ms: number + total: number +} + export interface chatCompletionChunk { id: string object: 'chat.completion.chunk' @@ -126,6 +134,7 @@ export interface chatCompletionChunk { model: string choices: chatCompletionChunkChoice[] system_fingerprint?: string + prompt_progress?: chatCompletionPromptProgress } export interface chatCompletionChoice { diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts index b2ca7b9c7..c296e06af 100644 --- a/extensions/llamacpp-extension/src/index.ts +++ b/extensions/llamacpp-extension/src/index.ts @@ -1802,6 +1802,13 @@ export default class llamacpp_extension extends AIEngine { 'Content-Type': 'application/json', 'Authorization': `Bearer ${sessionInfo.api_key}`, } + // always enable prompt progress return if stream is true + // Requires llamacpp version > b6399 + // Example json returned from server + // {"choices":[{"finish_reason":null,"index":0,"delta":{"role":"assistant","content":null}}],"created":1758113912,"id":"chatcmpl-UwZwgxQKyJMo7WzMzXlsi90YTUK2BJro","model":"qwen","system_fingerprint":"b1-e4912fc","object":"chat.completion.chunk","prompt_progress":{"total":36,"cache":0,"processed":36,"time_ms":5706760300}} + // (chunk.prompt_progress?.processed / chunk.prompt_progress?.total) * 100 + // chunk.prompt_progress?.cache is for past tokens already in kv cache + opts.return_progress = true const body = JSON.stringify(opts) if (opts.stream) { diff --git a/web-app/src/components/PromptProgress.tsx b/web-app/src/components/PromptProgress.tsx new file mode 100644 index 000000000..3d25f6a05 --- /dev/null +++ b/web-app/src/components/PromptProgress.tsx @@ -0,0 +1,27 @@ +import { useAppState } from '@/hooks/useAppState' + +export function PromptProgress() { + const promptProgress = useAppState((state) => state.promptProgress) + + const percentage = + promptProgress && promptProgress.total > 0 + ? Math.round((promptProgress.processed / promptProgress.total) * 100) + : 0 + + // Show progress only when promptProgress exists and has valid data, and not completed + if ( + !promptProgress || + !promptProgress.total || + promptProgress.total <= 0 || + percentage >= 100 + ) { + return null + } + + return ( +
+
+ Reading: {percentage}% +
+ ) +} diff --git a/web-app/src/components/__tests__/PromptProgress.test.tsx b/web-app/src/components/__tests__/PromptProgress.test.tsx new file mode 100644 index 000000000..829656d42 --- /dev/null +++ b/web-app/src/components/__tests__/PromptProgress.test.tsx @@ -0,0 +1,71 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { render, screen } from '@testing-library/react' +import { PromptProgress } from '../PromptProgress' +import { useAppState } from '@/hooks/useAppState' + +// Mock the useAppState hook +vi.mock('@/hooks/useAppState', () => ({ + useAppState: vi.fn(), +})) + +const mockUseAppState = useAppState as ReturnType + +describe('PromptProgress', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should not render when promptProgress is undefined', () => { + mockUseAppState.mockReturnValue(undefined) + + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('should render progress when promptProgress is available', () => { + const mockProgress = { + cache: 0, + processed: 50, + time_ms: 1000, + total: 100, + } + + mockUseAppState.mockReturnValue(mockProgress) + + render() + + expect(screen.getByText('Reading: 50%')).toBeInTheDocument() + expect(document.querySelector('.animate-spin')).toBeInTheDocument() + }) + + it('should calculate percentage correctly', () => { + const mockProgress = { + cache: 0, + processed: 75, + time_ms: 1500, + total: 150, + } + + mockUseAppState.mockReturnValue(mockProgress) + + render() + + expect(screen.getByText('Reading: 50%')).toBeInTheDocument() + }) + + it('should handle zero total gracefully', () => { + const mockProgress = { + cache: 0, + processed: 0, + time_ms: 0, + total: 0, + } + + mockUseAppState.mockReturnValue(mockProgress) + + const { container } = render() + + // Component should not render when total is 0 + expect(container.firstChild).toBeNull() + }) +}) diff --git a/web-app/src/containers/ThreadContent.tsx b/web-app/src/containers/ThreadContent.tsx index e5ceebabb..2f83ad513 100644 --- a/web-app/src/containers/ThreadContent.tsx +++ b/web-app/src/containers/ThreadContent.tsx @@ -72,7 +72,11 @@ export const ThreadContent = memo( streamTools?: any contextOverflowModal?: React.ReactNode | null - updateMessage?: (item: ThreadMessage, message: string, imageUrls?: string[]) => void + updateMessage?: ( + item: ThreadMessage, + message: string, + imageUrls?: string[] + ) => void } ) => { const { t } = useTranslation() @@ -281,7 +285,12 @@ export const ThreadContent = memo( item.content?.find((c) => c.type === 'text')?.text?.value || '' } - imageUrls={item.content?.filter((c) => c.type === 'image_url' && c.image_url?.url).map((c) => c.image_url!.url).filter((url): url is string => url !== undefined) || []} + imageUrls={ + item.content + ?.filter((c) => c.type === 'image_url' && c.image_url?.url) + .map((c) => c.image_url!.url) + .filter((url): url is string => url !== undefined) || [] + } onSave={(message, imageUrls) => { if (item.updateMessage) { item.updateMessage(item, message, imageUrls) @@ -397,7 +406,9 @@ export const ThreadContent = memo( diff --git a/web-app/src/hooks/__tests__/useChat.instructions.test.ts b/web-app/src/hooks/__tests__/useChat.instructions.test.ts index a9a022752..6c58bb1bc 100644 --- a/web-app/src/hooks/__tests__/useChat.instructions.test.ts +++ b/web-app/src/hooks/__tests__/useChat.instructions.test.ts @@ -35,6 +35,7 @@ vi.mock('../../hooks/useAppState', () => ({ resetTokenSpeed: vi.fn(), updateTools: vi.fn(), updateStreamingContent: vi.fn(), + updatePromptProgress: vi.fn(), updateLoadingModel: vi.fn(), setAbortController: vi.fn(), } @@ -106,7 +107,11 @@ vi.mock('../../hooks/useMessages', () => ({ vi.mock('../../hooks/useToolApproval', () => ({ useToolApproval: (selector: any) => { - const state = { approvedTools: [], showApprovalModal: vi.fn(), allowAllMCPPermissions: false } + const state = { + approvedTools: [], + showApprovalModal: vi.fn(), + allowAllMCPPermissions: false, + } return selector ? selector(state) : state }, })) @@ -132,14 +137,24 @@ vi.mock('@tanstack/react-router', () => ({ vi.mock('@/lib/completion', () => ({ emptyThreadContent: { thread_id: 'test-thread', content: '' }, extractToolCall: vi.fn(), - newUserThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'user message' })), - newAssistantThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'assistant message' })), - sendCompletion: vi.fn(() => Promise.resolve({ choices: [{ message: { content: '' } }] })), + newUserThreadContent: vi.fn(() => ({ + thread_id: 'test-thread', + content: 'user message', + })), + newAssistantThreadContent: vi.fn(() => ({ + thread_id: 'test-thread', + content: 'assistant message', + })), + sendCompletion: vi.fn(() => + Promise.resolve({ choices: [{ message: { content: '' } }] }) + ), postMessageProcessing: vi.fn(), isCompletionResponse: vi.fn(() => true), })) -vi.mock('@/services/mcp', () => ({ getTools: vi.fn(() => Promise.resolve([])) })) +vi.mock('@/services/mcp', () => ({ + getTools: vi.fn(() => Promise.resolve([])), +})) vi.mock('@/services/models', () => ({ startModel: vi.fn(() => Promise.resolve()), @@ -147,9 +162,13 @@ vi.mock('@/services/models', () => ({ stopAllModels: vi.fn(() => Promise.resolve()), })) -vi.mock('@/services/providers', () => ({ updateSettings: vi.fn(() => Promise.resolve()) })) +vi.mock('@/services/providers', () => ({ + updateSettings: vi.fn(() => Promise.resolve()), +})) -vi.mock('@tauri-apps/api/event', () => ({ listen: vi.fn(() => Promise.resolve(vi.fn())) })) +vi.mock('@tauri-apps/api/event', () => ({ + listen: vi.fn(() => Promise.resolve(vi.fn())), +})) vi.mock('@/hooks/useServiceHub', () => ({ useServiceHub: () => ({ diff --git a/web-app/src/hooks/__tests__/useChat.test.ts b/web-app/src/hooks/__tests__/useChat.test.ts index 362c557c0..45d46eb53 100644 --- a/web-app/src/hooks/__tests__/useChat.test.ts +++ b/web-app/src/hooks/__tests__/useChat.test.ts @@ -25,6 +25,7 @@ vi.mock('../useAppState', () => ({ resetTokenSpeed: vi.fn(), updateTools: vi.fn(), updateStreamingContent: vi.fn(), + updatePromptProgress: vi.fn(), updateLoadingModel: vi.fn(), setAbortController: vi.fn(), } diff --git a/web-app/src/hooks/useAppState.ts b/web-app/src/hooks/useAppState.ts index 837ed8c38..0ed9491d7 100644 --- a/web-app/src/hooks/useAppState.ts +++ b/web-app/src/hooks/useAppState.ts @@ -4,6 +4,13 @@ import { MCPTool } from '@/types/completion' import { useAssistant } from './useAssistant' import { ChatCompletionMessageToolCall } from 'openai/resources' +type PromptProgress = { + cache: number + processed: number + time_ms: number + total: number +} + type AppErrorMessage = { message?: string title?: string @@ -20,6 +27,7 @@ type AppState = { currentToolCall?: ChatCompletionMessageToolCall showOutOfContextDialog?: boolean errorMessage?: AppErrorMessage + promptProgress?: PromptProgress cancelToolCall?: () => void setServerStatus: (value: 'running' | 'stopped' | 'pending') => void updateStreamingContent: (content: ThreadMessage | undefined) => void @@ -34,6 +42,7 @@ type AppState = { setOutOfContextDialog: (show: boolean) => void setCancelToolCall: (cancel: (() => void) | undefined) => void setErrorMessage: (error: AppErrorMessage | undefined) => void + updatePromptProgress: (progress: PromptProgress | undefined) => void } export const useAppState = create()((set) => ({ @@ -44,6 +53,7 @@ export const useAppState = create()((set) => ({ abortControllers: {}, tokenSpeed: undefined, currentToolCall: undefined, + promptProgress: undefined, cancelToolCall: undefined, updateStreamingContent: (content: ThreadMessage | undefined) => { const assistants = useAssistant.getState().assistants @@ -133,4 +143,9 @@ export const useAppState = create()((set) => ({ errorMessage: error, })) }, + updatePromptProgress: (progress) => { + set(() => ({ + promptProgress: progress, + })) + }, })) diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index 5f913c340..516a61b20 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -1,4 +1,5 @@ import { useCallback, useMemo } from 'react' +import { flushSync } from 'react-dom' import { usePrompt } from './usePrompt' import { useModelProvider } from './useModelProvider' import { useThreads } from './useThreads' @@ -49,6 +50,9 @@ export const useChat = () => { state.setAbortController, ]) ) + const updatePromptProgress = useAppState( + (state) => state.updatePromptProgress + ) const updateProvider = useModelProvider((state) => state.updateProvider) const serviceHub = useServiceHub() @@ -229,6 +233,7 @@ export const useChat = () => { const abortController = new AbortController() setAbortController(activeThread.id, abortController) updateStreamingContent(emptyThreadContent) + updatePromptProgress(undefined) // Do not add new message on retry if (troubleshooting) addMessage(newUserThreadContent(activeThread.id, message, attachments)) @@ -397,6 +402,16 @@ export const useChat = () => { break } + // Handle prompt progress if available + if ('prompt_progress' in part && part.prompt_progress) { + // Force immediate state update to ensure we see intermediate values + flushSync(() => { + updatePromptProgress(part.prompt_progress) + }) + // Add a small delay to make progress visible + await new Promise((resolve) => setTimeout(resolve, 100)) + } + // Error message if (!part.choices) { throw new Error( @@ -513,6 +528,7 @@ export const useChat = () => { ) addMessage(updatedMessage ?? finalContent) updateStreamingContent(emptyThreadContent) + updatePromptProgress(undefined) updateThreadTimestamp(activeThread.id) isCompleted = !toolCalls.length @@ -534,6 +550,7 @@ export const useChat = () => { } finally { updateLoadingModel(false) updateStreamingContent(undefined) + updatePromptProgress(undefined) } }, [ @@ -543,6 +560,7 @@ export const useChat = () => { getMessages, setAbortController, updateStreamingContent, + updatePromptProgress, addMessage, updateThreadTimestamp, updateLoadingModel, diff --git a/web-app/src/routes/threads/$threadId.tsx b/web-app/src/routes/threads/$threadId.tsx index f301bac62..49b1e20e6 100644 --- a/web-app/src/routes/threads/$threadId.tsx +++ b/web-app/src/routes/threads/$threadId.tsx @@ -20,6 +20,7 @@ import { useSmallScreen } from '@/hooks/useMediaQuery' import { PlatformFeatures } from '@/lib/platform/const' import { PlatformFeature } from '@/lib/platform/types' import ScrollToBottom from '@/containers/ScrollToBottom' +import { PromptProgress } from '@/components/PromptProgress' // as route.threadsDetail export const Route = createFileRoute('/threads/$threadId')({ @@ -170,6 +171,7 @@ function ThreadDetail() { ) })} +