From c1297570972c99934df13f6a36357ea4b9769410 Mon Sep 17 00:00:00 2001 From: Akarshan Date: Mon, 20 Oct 2025 21:06:23 +0530 Subject: [PATCH] =?UTF-8?q?chore:=20Refactor=20chat=20flow=20=E2=80=93=20r?= =?UTF-8?q?emove=20loop,=20centralize=20tool=20handling,=20add=20step=20li?= =?UTF-8?q?mit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Move the assistant‑loop logic out of `useChat` and into `postMessageProcessing`. * Eliminate the while‑loop that drove repeated completions; now a single completion is sent and subsequent tool calls are processed recursively. * Introduce early‑abort checks and guard against missing provider before proceeding. * Add `ReasoningProcessor` import and use it consistently for streaming reasoning chunks. * Add `ToolCallEntry` type and a global `toolStepCounter` to track and cap total tool steps (default 20) to prevent infinite loops. * Extend `postMessageProcessing` signature to accept thread, provider, tools, UI update callback, and max tool steps. * Update UI‑update logic to use a single `updateStreamingUI` callback and ensure RAF scheduling is cleaned up reliably. * Refactor token‑speed / progress handling, improve error handling for out‑of‑context situations, and tidy up code formatting. * Minor clean‑ups: const‑ify `availableTools`, remove unused variables, improve readability. --- web-app/src/hooks/useChat.ts | 567 +++++++++++++++++----------------- web-app/src/lib/completion.ts | 235 ++++++++++---- 2 files changed, 450 insertions(+), 352 deletions(-) diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index a312ea061..60541a1a1 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -425,10 +425,8 @@ export const useChat = () => { // Using addUserMessage to respect legacy code. Should be using the userContent above. if (troubleshooting) builder.addUserMessage(userContent) - let isCompleted = false - // Filter tools based on model capabilities and available tools for this thread - let availableTools = selectedModel?.capabilities?.includes('tools') + const availableTools = selectedModel?.capabilities?.includes('tools') ? useAppState.getState().tools.filter((tool) => { const disabledTools = getDisabledToolsForThread(activeThread.id) return !disabledTools.includes(tool.name) @@ -436,13 +434,21 @@ export const useChat = () => { : [] // Check if proactive mode is enabled - const isProactiveMode = selectedModel?.capabilities?.includes('proactive') ?? false + const isProactiveMode = + selectedModel?.capabilities?.includes('proactive') ?? false // Proactive mode: Capture initial screenshot/snapshot before first LLM call - if (isProactiveMode && availableTools.length > 0 && !abortController.signal.aborted) { - console.log('Proactive mode: Capturing initial screenshots before LLM call') + if ( + isProactiveMode && + availableTools.length > 0 && + !abortController.signal.aborted + ) { + console.log( + 'Proactive mode: Capturing initial screenshots before LLM call' + ) try { - const initialScreenshots = await captureProactiveScreenshots(abortController) + const initialScreenshots = + await captureProactiveScreenshots(abortController) // Add initial screenshots to builder for (const screenshot of initialScreenshots) { @@ -456,131 +462,91 @@ export const useChat = () => { } } - let assistantLoopSteps = 0 + // The agent logic is now self-contained within postMessageProcessing. + // We no longer need a `while` loop here. 
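+      // Tool calls returned by the single completion below are executed by
+      // postMessageProcessing, which sends follow-up completions and recurses
+      // until no tool calls remain or the step cap (tool_steps, default 20)
+      // is reached.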
-      while (
-        !isCompleted &&
-        !abortController.signal.aborted &&
-        activeProvider
-      ) {
-        const modelConfig = activeProvider.models.find(
-          (m) => m.id === selectedModel?.id
-        )
-        assistantLoopSteps += 1
+      if (abortController.signal.aborted || !activeProvider) return

-        const modelSettings = modelConfig?.settings
-          ? Object.fromEntries(
-              Object.entries(modelConfig.settings)
-                .filter(
-                  ([key, value]) =>
-                    key !== 'ctx_len' &&
-                    key !== 'ngl' &&
-                    value.controller_props?.value !== undefined &&
-                    value.controller_props?.value !== null &&
-                    value.controller_props?.value !== ''
-                )
-                .map(([key, value]) => [key, value.controller_props?.value])
-            )
-          : undefined
+      const modelConfig = activeProvider.models.find(
+        (m) => m.id === selectedModel?.id
+      )

-        const completion = await sendCompletion(
-          activeThread,
-          activeProvider,
-          builder.getMessages(),
-          abortController,
-          availableTools,
-          currentAssistant?.parameters?.stream === false ? false : true,
-          {
-            ...modelSettings,
-            ...(currentAssistant?.parameters || {}),
-          } as unknown as Record<string, unknown>
-        )
+      const modelSettings = modelConfig?.settings
+        ? Object.fromEntries(
+            Object.entries(modelConfig.settings)
+              .filter(
+                ([key, value]) =>
+                  key !== 'ctx_len' &&
+                  key !== 'ngl' &&
+                  value.controller_props?.value !== undefined &&
+                  value.controller_props?.value !== null &&
+                  value.controller_props?.value !== ''
+              )
+              .map(([key, value]) => [key, value.controller_props?.value])
+          )
+        : undefined

-        if (!completion) throw new Error('No completion received')
-        let accumulatedText = ''
-        const currentCall: ChatCompletionMessageToolCall | null = null
-        const toolCalls: ChatCompletionMessageToolCall[] = []
-        const timeToFirstToken = Date.now()
-        let tokenUsage: CompletionUsage | undefined = undefined
-        try {
-          if (isCompletionResponse(completion)) {
-            const message = completion.choices[0]?.message
-            accumulatedText = (message?.content as string) || ''
+      const completion = await sendCompletion(
+        activeThread,
+        activeProvider,
+        builder.getMessages(),
+        abortController,
+        availableTools,
+        currentAssistant?.parameters?.stream === false ? 
false : true,
+        {
+          ...modelSettings,
+          ...(currentAssistant?.parameters || {}),
+        } as unknown as Record<string, unknown>
+      )

-            // Handle reasoning field if there is one
-            const reasoning = extractReasoningFromMessage(message)
-            if (reasoning) {
-              accumulatedText =
-                `<think>${reasoning}</think>` + accumulatedText
-            }
+      if (!completion) throw new Error('No completion received')
+      let accumulatedText = ''
+      const currentCall: ChatCompletionMessageToolCall | null = null
+      const toolCalls: ChatCompletionMessageToolCall[] = []
+      const timeToFirstToken = Date.now()
+      let tokenUsage: CompletionUsage | undefined = undefined
+      try {
+        if (isCompletionResponse(completion)) {
+          const message = completion.choices[0]?.message
+          accumulatedText = (message?.content as string) || ''

-            if (message?.tool_calls) {
-              toolCalls.push(...message.tool_calls)
-            }
-            if ('usage' in completion) {
-              tokenUsage = completion.usage
-            }
-          } else {
-            // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
-            let rafScheduled = false
-            let rafHandle: number | undefined
-            let pendingDeltaCount = 0
-            const reasoningProcessor = new ReasoningProcessor()
-            const scheduleFlush = () => {
-              if (rafScheduled || abortController.signal.aborted) return
-              rafScheduled = true
-              const doSchedule = (cb: () => void) => {
-                if (typeof requestAnimationFrame !== 'undefined') {
-                  rafHandle = requestAnimationFrame(() => cb())
-                } else {
-                  // Fallback for non-browser test environments
-                  const t = setTimeout(() => cb(), 0) as unknown as number
-                  rafHandle = t
-                }
+          // Handle reasoning field if there is one
+          const reasoning = extractReasoningFromMessage(message)
+          if (reasoning) {
+            accumulatedText = `<think>${reasoning}</think>` + accumulatedText
+          }
+
+          if (message?.tool_calls) {
+            toolCalls.push(...message.tool_calls)
+          }
+          if ('usage' in completion) {
+            tokenUsage = completion.usage
+          }
+        } else {
+          // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
+          let rafScheduled = false
+          let rafHandle: number | undefined
+          let pendingDeltaCount = 0
+          const reasoningProcessor = new ReasoningProcessor()
+          const scheduleFlush = () => {
+            if (rafScheduled || abortController.signal.aborted) return
+            rafScheduled = true
+            const doSchedule = (cb: () => void) => {
+              if (typeof requestAnimationFrame !== 'undefined') {
+                rafHandle = requestAnimationFrame(() => cb())
+              } else {
+                // Fallback for non-browser test environments
+                const t = setTimeout(() => cb(), 0) as unknown as number
+                rafHandle = t
               }
-              doSchedule(() => {
-                // Check abort status before executing the scheduled callback
-                if (abortController.signal.aborted) {
-                  rafScheduled = false
-                  return
-                }
-
-                const currentContent = newAssistantThreadContent(
-                  activeThread.id,
-                  accumulatedText,
-                  {
-                    tool_calls: toolCalls.map((e) => ({
-                      ...e,
-                      state: 'pending',
-                    })),
-                  }
-                )
-                updateStreamingContent(currentContent)
-                if (tokenUsage) {
-                  setTokenSpeed(
-                    currentContent,
-                    tokenUsage.completion_tokens /
-                      Math.max((Date.now() - timeToFirstToken) / 1000, 1),
-                    tokenUsage.completion_tokens
-                  )
-                } else if (pendingDeltaCount > 0) {
-                  updateTokenSpeed(currentContent, pendingDeltaCount)
-                }
-                pendingDeltaCount = 0
+            }
+            doSchedule(() => {
+              // Check abort status before executing the scheduled callback
+              if (abortController.signal.aborted) {
                 rafScheduled = false
-              })
-            }
-            const flushIfPending = () => {
-              if (!rafScheduled) return
-              if (
-                typeof cancelAnimationFrame !== 'undefined' &&
-                rafHandle !== undefined
-              ) {
-                cancelAnimationFrame(rafHandle)
-              } else if (rafHandle !== undefined) {
-                
clearTimeout(rafHandle) + return } - // Do an immediate flush + const currentContent = newAssistantThreadContent( activeThread.id, accumulatedText, @@ -604,176 +570,207 @@ export const useChat = () => { } pendingDeltaCount = 0 rafScheduled = false - } - try { - for await (const part of completion) { - // Check if aborted before processing each part - if (abortController.signal.aborted) { - break - } - - // Handle prompt progress if available - if ('prompt_progress' in part && part.prompt_progress) { - // Force immediate state update to ensure we see intermediate values - flushSync(() => { - updatePromptProgress(part.prompt_progress) - }) - // Add a small delay to make progress visible - await new Promise((resolve) => setTimeout(resolve, 100)) - } - - // Error message - if (!part.choices) { - throw new Error( - 'message' in part - ? (part.message as string) - : (JSON.stringify(part) ?? '') - ) - } - - if ('usage' in part && part.usage) { - tokenUsage = part.usage - } - - if (part.choices[0]?.delta?.tool_calls) { - extractToolCall(part, currentCall, toolCalls) - // Schedule a flush to reflect tool update - scheduleFlush() - } - const deltaReasoning = - reasoningProcessor.processReasoningChunk(part) - if (deltaReasoning) { - accumulatedText += deltaReasoning - pendingDeltaCount += 1 - // Schedule flush for reasoning updates - scheduleFlush() - } - const deltaContent = part.choices[0]?.delta?.content || '' - if (deltaContent) { - accumulatedText += deltaContent - pendingDeltaCount += 1 - // Batch UI update on next animation frame - scheduleFlush() - } - } - } finally { - // Always clean up scheduled RAF when stream ends (either normally or via abort) - if (rafHandle !== undefined) { - if (typeof cancelAnimationFrame !== 'undefined') { - cancelAnimationFrame(rafHandle) - } else { - clearTimeout(rafHandle) - } - rafHandle = undefined - rafScheduled = false - } - - // Only finalize and flush if not aborted - if (!abortController.signal.aborted) { - // Finalize reasoning (close any open think tags) - accumulatedText += reasoningProcessor.finalize() - // Ensure any pending buffered content is rendered at the end - flushIfPending() - } - } + }) } - } catch (error) { - const errorMessage = - error && typeof error === 'object' && 'message' in error - ? 
error.message - : error - if ( - typeof errorMessage === 'string' && - errorMessage.includes(OUT_OF_CONTEXT_SIZE) && - selectedModel - ) { - const method = await showIncreaseContextSizeModal() - if (method === 'ctx_len') { - /// Increase context size - activeProvider = await increaseModelContextSize( - selectedModel.id, - activeProvider + const flushIfPending = () => { + if (!rafScheduled) return + if ( + typeof cancelAnimationFrame !== 'undefined' && + rafHandle !== undefined + ) { + cancelAnimationFrame(rafHandle) + } else if (rafHandle !== undefined) { + clearTimeout(rafHandle) + } + // Do an immediate flush + const currentContent = newAssistantThreadContent( + activeThread.id, + accumulatedText, + { + tool_calls: toolCalls.map((e) => ({ + ...e, + state: 'pending', + })), + } + ) + updateStreamingContent(currentContent) + if (tokenUsage) { + setTokenSpeed( + currentContent, + tokenUsage.completion_tokens / + Math.max((Date.now() - timeToFirstToken) / 1000, 1), + tokenUsage.completion_tokens ) - continue - } else if (method === 'context_shift' && selectedModel?.id) { - /// Enable context_shift - activeProvider = await toggleOnContextShifting( - selectedModel?.id, - activeProvider - ) - continue - } else throw error - } else { - throw error + } else if (pendingDeltaCount > 0) { + updateTokenSpeed(currentContent, pendingDeltaCount) + } + pendingDeltaCount = 0 + rafScheduled = false + } + try { + for await (const part of completion) { + // Check if aborted before processing each part + if (abortController.signal.aborted) { + break + } + + // Handle prompt progress if available + if ('prompt_progress' in part && part.prompt_progress) { + // Force immediate state update to ensure we see intermediate values + flushSync(() => { + updatePromptProgress(part.prompt_progress) + }) + // Add a small delay to make progress visible + await new Promise((resolve) => setTimeout(resolve, 100)) + } + + // Error message + if (!part.choices) { + throw new Error( + 'message' in part + ? (part.message as string) + : (JSON.stringify(part) ?? '') + ) + } + + if ('usage' in part && part.usage) { + tokenUsage = part.usage + } + + if (part.choices[0]?.delta?.tool_calls) { + extractToolCall(part, currentCall, toolCalls) + // Schedule a flush to reflect tool update + scheduleFlush() + } + const deltaReasoning = + reasoningProcessor.processReasoningChunk(part) + if (deltaReasoning) { + accumulatedText += deltaReasoning + pendingDeltaCount += 1 + // Schedule flush for reasoning updates + scheduleFlush() + } + const deltaContent = part.choices[0]?.delta?.content || '' + if (deltaContent) { + accumulatedText += deltaContent + pendingDeltaCount += 1 + // Batch UI update on next animation frame + scheduleFlush() + } + } + } finally { + // Always clean up scheduled RAF when stream ends (either normally or via abort) + if (rafHandle !== undefined) { + if (typeof cancelAnimationFrame !== 'undefined') { + cancelAnimationFrame(rafHandle) + } else { + clearTimeout(rafHandle) + } + rafHandle = undefined + rafScheduled = false + } + + // Only finalize and flush if not aborted + if (!abortController.signal.aborted) { + // Finalize reasoning (close any open think tags) + accumulatedText += reasoningProcessor.finalize() + // Ensure any pending buffered content is rendered at the end + flushIfPending() + } } } - // TODO: Remove this check when integrating new llama.cpp extension + } catch (error) { + const errorMessage = + error && typeof error === 'object' && 'message' in error + ? 
error.message
+          : error
+        if (
+          typeof errorMessage === 'string' &&
+          errorMessage.includes(OUT_OF_CONTEXT_SIZE) &&
+          selectedModel
+        ) {
-        // TODO: Remove this check when integrating new llama.cpp extension
-        if (
-          accumulatedText.length === 0 &&
-          toolCalls.length === 0 &&
-          activeThread.model?.id &&
-          activeProvider?.provider === 'llamacpp'
-        ) {
-          await serviceHub
-            .models()
-            .stopModel(activeThread.model.id, 'llamacpp')
-          throw new Error('No response received from the model')
-        }
-
-        const totalThinkingTime = Date.now() - startTime // Calculate total elapsed time
-
-        // Create a final content object for adding to the thread
-        const messageMetadata: Record<string, unknown> = {
-          tokenSpeed: useAppState.getState().tokenSpeed,
-          assistant: currentAssistant,
-        }
-
-        if (accumulatedText.includes('<think>') || toolCalls.length > 0) {
-          messageMetadata.totalThinkingTime = totalThinkingTime
-        }
-
-        const finalContent = newAssistantThreadContent(
-          activeThread.id,
-          accumulatedText,
-          messageMetadata
-        )
-
-        builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
-
-        // Check if proactive mode is enabled for this model
-        const isProactiveMode = selectedModel?.capabilities?.includes('proactive') ?? false
-
-        const updatedMessage = await postMessageProcessing(
-          toolCalls,
-          builder,
-          finalContent,
-          abortController,
-          useToolApproval.getState().approvedTools,
-          allowAllMCPPermissions ? undefined : showApprovalModal,
-          allowAllMCPPermissions,
-          isProactiveMode
-        )
-
-        if (updatedMessage && updatedMessage.metadata) {
-          if (finalContent.metadata?.totalThinkingTime !== undefined) {
-            updatedMessage.metadata.totalThinkingTime =
-              finalContent.metadata.totalThinkingTime
-          }
-        }
-
-        addMessage(updatedMessage ?? finalContent)
-        updateStreamingContent(emptyThreadContent)
-        updatePromptProgress(undefined)
-        updateThreadTimestamp(activeThread.id)
-
-        isCompleted = !toolCalls.length
-        // Do not create agent loop if there is no need for it
-        // Check if assistant loop steps are within limits
-        if (assistantLoopSteps >= (currentAssistant?.tool_steps ?? 20)) {
-          // Stop the assistant tool call if it exceeds the maximum steps
-          availableTools = []
+          const method = await showIncreaseContextSizeModal()
+          if (method === 'ctx_len') {
+            /// Increase context size
+            activeProvider = await increaseModelContextSize(
+              selectedModel.id,
+              activeProvider
+            )
+            // NOTE: This will exit and not retry. A more robust solution might re-call sendMessage.
+            // For this change, we keep the existing behavior.
+            return
+          } else if (method === 'context_shift' && selectedModel?.id) {
+            /// Enable context_shift
+            activeProvider = await toggleOnContextShifting(
+              selectedModel?.id,
+              activeProvider
+            )
+            // NOTE: See above comment about retry. 
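+            // The updated provider settings take effect on the next send.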
+            return
+          } else throw error
+        } else {
+          throw error
+        }
+      }
+      // TODO: Remove this check when integrating new llama.cpp extension
+      if (
+        accumulatedText.length === 0 &&
+        toolCalls.length === 0 &&
+        activeThread.model?.id &&
+        activeProvider?.provider === 'llamacpp'
+      ) {
+        await serviceHub.models().stopModel(activeThread.model.id, 'llamacpp')
+        throw new Error('No response received from the model')
+      }
+
+      const totalThinkingTime = Date.now() - startTime // Calculate total elapsed time
+
+      const messageMetadata: Record<string, unknown> = {
+        tokenSpeed: useAppState.getState().tokenSpeed,
+        assistant: currentAssistant,
+      }
+
+      if (accumulatedText.includes('<think>') || toolCalls.length > 0) {
+        messageMetadata.totalThinkingTime = totalThinkingTime
+      }
+
+      // This is the message object that will be built upon by postMessageProcessing
+      const finalContent = newAssistantThreadContent(
+        activeThread.id,
+        accumulatedText,
+        messageMetadata
+      )
+
+      builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
+      // All subsequent tool calls and follow-up completions will modify `finalContent`.
+      const updatedMessage = await postMessageProcessing(
+        toolCalls,
+        builder,
+        finalContent,
+        abortController,
+        useToolApproval.getState().approvedTools,
+        allowAllMCPPermissions ? undefined : showApprovalModal,
+        allowAllMCPPermissions,
+        activeThread,
+        activeProvider,
+        availableTools,
+        updateStreamingContent, // Pass the callback to update UI
+        currentAssistant?.tool_steps,
+        isProactiveMode
+      )
+
+      if (updatedMessage && updatedMessage.metadata) {
+        if (finalContent.metadata?.totalThinkingTime !== undefined) {
+          updatedMessage.metadata.totalThinkingTime =
+            finalContent.metadata.totalThinkingTime
+        }
+      }
+
+      // Add the single, final, composite message to the store.
+      addMessage(updatedMessage ?? finalContent)
+      updateStreamingContent(emptyThreadContent)
+      updatePromptProgress(undefined)
+      updateThreadTimestamp(activeThread.id)
     } catch (error) {
       if (!abortController.signal.aborted) {
         if (error && typeof error === 'object' && 'message' in error) {
diff --git a/web-app/src/lib/completion.ts b/web-app/src/lib/completion.ts
index 4a90982de..14f4ff148 100644
--- a/web-app/src/lib/completion.ts
+++ b/web-app/src/lib/completion.ts
@@ -41,6 +41,7 @@ import { useAppState } from '@/hooks/useAppState'
 import { injectFilesIntoPrompt } from './fileMetadata'
 import { Attachment } from '@/types/attachment'
 import { ModelCapabilities } from '@/types/models'
+import { ReasoningProcessor } from '@/utils/reasoning'
 
 export type ChatCompletionResponse =
   | chatCompletion
@@ -48,6 +49,12 @@ export type ChatCompletionResponse =
   | StreamCompletionResponse
   | CompletionResponse
 
+type ToolCallEntry = {
+  tool: object
+  response: any
+  state: 'pending' | 'ready'
+}
+
 /**
  * @fileoverview Helper functions for creating thread content.
  * These functions are used to create thread content objects
@@ -73,11 +80,14 @@ export const newUserThreadContent = (
     name: doc.name,
     type: doc.fileType,
     size: typeof doc.size === 'number' ? doc.size : undefined,
-    chunkCount: typeof doc.chunkCount === 'number' ? doc.chunkCount : undefined,
+    chunkCount:
+      typeof doc.chunkCount === 'number' ? doc.chunkCount : undefined,
   }))
 
   const textWithFiles =
-    docMetadata.length > 0 ? injectFilesIntoPrompt(content, docMetadata) : content
+    docMetadata.length > 0
+      ? 
injectFilesIntoPrompt(content, docMetadata)
+      : content
 
   const contentParts = [
     {
@@ -238,10 +248,8 @@ export const sendCompletion = async (
   const providerModelConfig = provider.models?.find(
     (model) => model.id === thread.model?.id || model.model === thread.model?.id
   )
-  const effectiveCapabilities = Array.isArray(
-    providerModelConfig?.capabilities
-  )
-    ? providerModelConfig?.capabilities ?? []
+  const effectiveCapabilities = Array.isArray(providerModelConfig?.capabilities)
+    ? (providerModelConfig?.capabilities ?? [])
     : getModelCapabilities(provider.provider, thread.model.id)
   const modelSupportsTools = effectiveCapabilities.includes(
     ModelCapabilities.TOOLS
@@ -254,7 +262,10 @@
     PlatformFeatures[PlatformFeature.ATTACHMENTS] &&
     modelSupportsTools
   ) {
-    const ragTools = await getServiceHub().rag().getTools().catch(() => [])
+    const ragTools = await getServiceHub()
+      .rag()
+      .getTools()
+      .catch(() => [])
     if (Array.isArray(ragTools) && ragTools.length) {
       usableTools = [...tools, ...ragTools]
     }
@@ -396,6 +407,9 @@
   return calls
 }
 
+// Keep track of total tool steps to prevent infinite loops
+let toolStepCounter = 0
+
 /**
  * Helper function to check if a tool call is a browser MCP tool
 * @param toolName - The name of the tool
@@ -533,10 +547,22 @@
     toolParameters?: object
   ) => Promise<boolean>,
   allowAllMCPPermissions: boolean = false,
+  thread?: Thread,
+  provider?: ModelProvider,
+  tools: MCPTool[] = [],
+  updateStreamingUI?: (content: ThreadMessage) => void,
+  maxToolSteps: number = 20,
   isProactiveMode: boolean = false
-) => {
+): Promise<ThreadMessage> => {
+  // toolStepCounter is module-scoped; it is reset once the whole tool-call
+  // chain has resolved (see the reset before the final return below).
+
   // Handle completed tool calls
-  if (calls.length) {
+  if (calls.length > 0) {
+    toolStepCounter++
+
     // Fetch RAG tool names from RAG service
     let ragToolNames = new Set<string>()
     try {
@@ -546,43 +572,41 @@
      console.error('Failed to load RAG tool names:', e)
     }
     const ragFeatureAvailable =
-      useAttachments.getState().enabled && PlatformFeatures[PlatformFeature.ATTACHMENTS]
+      useAttachments.getState().enabled &&
+      PlatformFeatures[PlatformFeature.ATTACHMENTS]
+
+    const currentToolCalls =
+      message.metadata?.tool_calls && Array.isArray(message.metadata.tool_calls)
+        ? [...message.metadata.tool_calls]
+        : []
+
     for (const toolCall of calls) {
       if (abortController.signal.aborted) break
       const toolId = ulid()
-      const toolCallsMetadata =
-        message.metadata?.tool_calls &&
-        Array.isArray(message.metadata?.tool_calls)
-          ? message.metadata?.tool_calls
-          : []
+
+      const toolCallEntry: ToolCallEntry = {
+        tool: {
+          ...(toolCall as object),
+          id: toolId,
+        },
+        response: undefined,
+        state: 'pending' as 'pending' | 'ready',
+      }
+      currentToolCalls.push(toolCallEntry)
+
      message.metadata = {
        ...(message.metadata ?? 
{}), - tool_calls: [ - ...toolCallsMetadata, - { - tool: { - ...(toolCall as object), - id: toolId, - }, - response: undefined, - state: 'pending', - }, - ], + tool_calls: currentToolCalls, } + if (updateStreamingUI) updateStreamingUI({ ...message }) // Show pending call // Check if tool is approved or show modal for approval let toolParameters = {} if (toolCall.function.arguments.length) { try { - console.log('Raw tool arguments:', toolCall.function.arguments) toolParameters = JSON.parse(toolCall.function.arguments) - console.log('Parsed tool parameters:', toolParameters) } catch (error) { console.error('Failed to parse tool arguments:', error) - console.error( - 'Raw arguments that failed:', - toolCall.function.arguments - ) } } @@ -591,7 +615,6 @@ export const postMessageProcessing = async ( const isRagTool = ragToolNames.has(toolName) const isBrowserTool = isBrowserMCPTool(toolName) - // Auto-approve RAG tools (local/safe operations), require permission for MCP tools const approved = isRagTool ? true : allowAllMCPPermissions || @@ -607,7 +630,11 @@ export const postMessageProcessing = async ( const { promise, cancel } = isRagTool ? ragFeatureAvailable ? { - promise: getServiceHub().rag().callTool({ toolName, arguments: toolArgs, threadId: message.thread_id }), + promise: getServiceHub().rag().callTool({ + toolName, + arguments: toolArgs, + threadId: message.thread_id, + }), cancel: async () => {}, } : { @@ -630,18 +657,15 @@ export const postMessageProcessing = async ( useAppState.getState().setCancelToolCall(cancel) let result = approved - ? await promise.catch((e) => { - console.error('Tool call failed:', e) - return { - content: [ - { - type: 'text', - text: `Error calling tool ${toolCall.function.name}: ${e.message ?? e}`, - }, - ], - error: String(e?.message ?? e ?? 'Tool call failed'), - } - }) + ? await promise.catch((e) => ({ + content: [ + { + type: 'text', + text: `Error calling tool ${toolCall.function.name}: ${e.message ?? e}`, + }, + ], + error: String(e?.message ?? e ?? 'Tool call failed'), + })) : { content: [ { @@ -654,30 +678,15 @@ export const postMessageProcessing = async ( if (typeof result === 'string') { result = { - content: [ - { - type: 'text', - text: result, - }, - ], + content: [{ type: 'text', text: result }], error: '', } } - message.metadata = { - ...(message.metadata ?? 
{}), - tool_calls: [ - ...toolCallsMetadata, - { - tool: { - ...toolCall, - id: toolId, - }, - response: result, - state: 'ready', - }, - ], - } + // Update the entry in the metadata array + toolCallEntry.response = result + toolCallEntry.state = 'ready' + if (updateStreamingUI) updateStreamingUI({ ...message }) // Show result builder.addToolMessage(result as ToolResult, toolCall.id) // Proactive mode: Capture screenshot/snapshot after browser tool execution @@ -702,6 +711,98 @@ export const postMessageProcessing = async ( // update message metadata } - return message + + if ( + thread && + provider && + !abortController.signal.aborted && + toolStepCounter < maxToolSteps + ) { + try { + const messagesWithToolResults = builder.getMessages() + + const followUpCompletion = await sendCompletion( + thread, + provider, + messagesWithToolResults, + abortController, + tools, + true, + {} + ) + + if (followUpCompletion) { + let followUpText = '' + const newToolCalls: ChatCompletionMessageToolCall[] = [] + const textContent = message.content.find( + (c) => c.type === ContentType.Text + ) + + if (isCompletionResponse(followUpCompletion)) { + const choice = followUpCompletion.choices[0] + const content = choice?.message?.content + if (content) followUpText = content as string + if (choice?.message?.tool_calls) { + newToolCalls.push(...choice.message.tool_calls) + } + if (textContent?.text) textContent.text.value += followUpText + if (updateStreamingUI) updateStreamingUI({ ...message }) + } else { + const reasoningProcessor = new ReasoningProcessor() + for await (const chunk of followUpCompletion) { + if (abortController.signal.aborted) break + + const deltaReasoning = + reasoningProcessor.processReasoningChunk(chunk) + const deltaContent = chunk.choices[0]?.delta?.content || '' + + if (textContent?.text) { + if (deltaReasoning) textContent.text.value += deltaReasoning + if (deltaContent) textContent.text.value += deltaContent + } + if (deltaContent) followUpText += deltaContent + + if (chunk.choices[0]?.delta?.tool_calls) { + extractToolCall(chunk, null, newToolCalls) + } + + if (updateStreamingUI) updateStreamingUI({ ...message }) + } + if (textContent?.text) { + textContent.text.value += reasoningProcessor.finalize() + if (updateStreamingUI) updateStreamingUI({ ...message }) + } + } + + if (newToolCalls.length > 0) { + builder.addAssistantMessage(followUpText, undefined, newToolCalls) + await postMessageProcessing( + newToolCalls, + builder, + message, + abortController, + approvedTools, + showModal, + allowAllMCPPermissions, + thread, + provider, + tools, + updateStreamingUI, + maxToolSteps, + isProactiveMode + ) + } + } + } catch (error) { + console.error( + 'Failed to get follow-up completion after tool execution:', + String(error) + ) + } + } } + + // Reset counter when the chain is fully resolved + toolStepCounter = 0 + return message }