chore: Refactor chat flow – remove loop, centralize tool handling, add step limit

* Move the assistant‑loop logic out of `useChat` and into `postMessageProcessing`.
* Eliminate the while‑loop that drove repeated completions; now a single completion is sent and subsequent tool calls are processed recursively.
* Introduce early‑abort checks and guard against missing provider before proceeding.
* Add `ReasoningProcessor` import and use it consistently for streaming reasoning chunks.
* Add `ToolCallEntry` type and a global `toolStepCounter` to track and cap total tool steps (default 20) to prevent infinite loops.
* Extend `postMessageProcessing` signature to accept thread, provider, tools, UI update callback, and max tool steps.
* Update UI‑update logic to use a single `updateStreamingUI` callback and ensure requestAnimationFrame (RAF) scheduling is cleaned up reliably.
* Refactor token‑speed / progress handling, improve error handling for context‑window‑overflow (out‑of‑context) errors, and tidy up code formatting.
* Minor clean‑ups: const‑ify `availableTools`, remove unused variables, improve readability.
This commit is contained in:
Akarshan 2025-10-20 21:06:23 +05:30
parent 2f00ae0d33
commit c129757097
No known key found for this signature in database
GPG Key ID: D75C9634A870665F
2 changed files with 450 additions and 352 deletions

View File

@ -425,10 +425,8 @@ export const useChat = () => {
// Using addUserMessage to respect legacy code. Should be using the userContent above. // Using addUserMessage to respect legacy code. Should be using the userContent above.
if (troubleshooting) builder.addUserMessage(userContent) if (troubleshooting) builder.addUserMessage(userContent)
let isCompleted = false
// Filter tools based on model capabilities and available tools for this thread // Filter tools based on model capabilities and available tools for this thread
let availableTools = selectedModel?.capabilities?.includes('tools') const availableTools = selectedModel?.capabilities?.includes('tools')
? useAppState.getState().tools.filter((tool) => { ? useAppState.getState().tools.filter((tool) => {
const disabledTools = getDisabledToolsForThread(activeThread.id) const disabledTools = getDisabledToolsForThread(activeThread.id)
return !disabledTools.includes(tool.name) return !disabledTools.includes(tool.name)
@ -436,13 +434,21 @@ export const useChat = () => {
: [] : []
// Check if proactive mode is enabled // Check if proactive mode is enabled
const isProactiveMode = selectedModel?.capabilities?.includes('proactive') ?? false const isProactiveMode =
selectedModel?.capabilities?.includes('proactive') ?? false
// Proactive mode: Capture initial screenshot/snapshot before first LLM call // Proactive mode: Capture initial screenshot/snapshot before first LLM call
if (isProactiveMode && availableTools.length > 0 && !abortController.signal.aborted) { if (
console.log('Proactive mode: Capturing initial screenshots before LLM call') isProactiveMode &&
availableTools.length > 0 &&
!abortController.signal.aborted
) {
console.log(
'Proactive mode: Capturing initial screenshots before LLM call'
)
try { try {
const initialScreenshots = await captureProactiveScreenshots(abortController) const initialScreenshots =
await captureProactiveScreenshots(abortController)
// Add initial screenshots to builder // Add initial screenshots to builder
for (const screenshot of initialScreenshots) { for (const screenshot of initialScreenshots) {
@ -456,17 +462,14 @@ export const useChat = () => {
} }
} }
let assistantLoopSteps = 0 // The agent logic is now self-contained within postMessageProcessing.
// We no longer need a `while` loop here.
if (abortController.signal.aborted || !activeProvider) return
while (
!isCompleted &&
!abortController.signal.aborted &&
activeProvider
) {
const modelConfig = activeProvider.models.find( const modelConfig = activeProvider.models.find(
(m) => m.id === selectedModel?.id (m) => m.id === selectedModel?.id
) )
assistantLoopSteps += 1
const modelSettings = modelConfig?.settings const modelSettings = modelConfig?.settings
? Object.fromEntries( ? Object.fromEntries(
@ -510,8 +513,7 @@ export const useChat = () => {
// Handle reasoning field if there is one // Handle reasoning field if there is one
const reasoning = extractReasoningFromMessage(message) const reasoning = extractReasoningFromMessage(message)
if (reasoning) { if (reasoning) {
accumulatedText = accumulatedText = `<think>${reasoning}</think>` + accumulatedText
`<think>${reasoning}</think>` + accumulatedText
} }
if (message?.tool_calls) { if (message?.tool_calls) {
@ -694,14 +696,17 @@ export const useChat = () => {
selectedModel.id, selectedModel.id,
activeProvider activeProvider
) )
continue // NOTE: This will exit and not retry. A more robust solution might re-call sendMessage.
// For this change, we keep the existing behavior.
return
} else if (method === 'context_shift' && selectedModel?.id) { } else if (method === 'context_shift' && selectedModel?.id) {
/// Enable context_shift /// Enable context_shift
activeProvider = await toggleOnContextShifting( activeProvider = await toggleOnContextShifting(
selectedModel?.id, selectedModel?.id,
activeProvider activeProvider
) )
continue // NOTE: See above comment about retry.
return
} else throw error } else throw error
} else { } else {
throw error throw error
@ -714,15 +719,12 @@ export const useChat = () => {
activeThread.model?.id && activeThread.model?.id &&
activeProvider?.provider === 'llamacpp' activeProvider?.provider === 'llamacpp'
) { ) {
await serviceHub await serviceHub.models().stopModel(activeThread.model.id, 'llamacpp')
.models()
.stopModel(activeThread.model.id, 'llamacpp')
throw new Error('No response received from the model') throw new Error('No response received from the model')
} }
const totalThinkingTime = Date.now() - startTime // Calculate total elapsed time const totalThinkingTime = Date.now() - startTime // Calculate total elapsed time
// Create a final content object for adding to the thread
const messageMetadata: Record<string, any> = { const messageMetadata: Record<string, any> = {
tokenSpeed: useAppState.getState().tokenSpeed, tokenSpeed: useAppState.getState().tokenSpeed,
assistant: currentAssistant, assistant: currentAssistant,
@ -732,6 +734,7 @@ export const useChat = () => {
messageMetadata.totalThinkingTime = totalThinkingTime messageMetadata.totalThinkingTime = totalThinkingTime
} }
// This is the message object that will be built upon by postMessageProcessing
const finalContent = newAssistantThreadContent( const finalContent = newAssistantThreadContent(
activeThread.id, activeThread.id,
accumulatedText, accumulatedText,
@ -739,10 +742,7 @@ export const useChat = () => {
) )
builder.addAssistantMessage(accumulatedText, undefined, toolCalls) builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
// All subsequent tool calls and follow-up completions will modify `finalContent`.
// Check if proactive mode is enabled for this model
const isProactiveMode = selectedModel?.capabilities?.includes('proactive') ?? false
const updatedMessage = await postMessageProcessing( const updatedMessage = await postMessageProcessing(
toolCalls, toolCalls,
builder, builder,
@ -751,6 +751,11 @@ export const useChat = () => {
useToolApproval.getState().approvedTools, useToolApproval.getState().approvedTools,
allowAllMCPPermissions ? undefined : showApprovalModal, allowAllMCPPermissions ? undefined : showApprovalModal,
allowAllMCPPermissions, allowAllMCPPermissions,
activeThread,
activeProvider,
availableTools,
updateStreamingContent, // Pass the callback to update UI
currentAssistant?.tool_steps,
isProactiveMode isProactiveMode
) )
@ -761,19 +766,11 @@ export const useChat = () => {
} }
} }
// Add the single, final, composite message to the store.
addMessage(updatedMessage ?? finalContent) addMessage(updatedMessage ?? finalContent)
updateStreamingContent(emptyThreadContent) updateStreamingContent(emptyThreadContent)
updatePromptProgress(undefined) updatePromptProgress(undefined)
updateThreadTimestamp(activeThread.id) updateThreadTimestamp(activeThread.id)
isCompleted = !toolCalls.length
// Do not create agent loop if there is no need for it
// Check if assistant loop steps are within limits
if (assistantLoopSteps >= (currentAssistant?.tool_steps ?? 20)) {
// Stop the assistant tool call if it exceeds the maximum steps
availableTools = []
}
}
} catch (error) { } catch (error) {
if (!abortController.signal.aborted) { if (!abortController.signal.aborted) {
if (error && typeof error === 'object' && 'message' in error) { if (error && typeof error === 'object' && 'message' in error) {

View File

@ -41,6 +41,7 @@ import { useAppState } from '@/hooks/useAppState'
import { injectFilesIntoPrompt } from './fileMetadata' import { injectFilesIntoPrompt } from './fileMetadata'
import { Attachment } from '@/types/attachment' import { Attachment } from '@/types/attachment'
import { ModelCapabilities } from '@/types/models' import { ModelCapabilities } from '@/types/models'
import { ReasoningProcessor } from '@/utils/reasoning'
export type ChatCompletionResponse = export type ChatCompletionResponse =
| chatCompletion | chatCompletion
@ -48,6 +49,12 @@ export type ChatCompletionResponse =
| StreamCompletionResponse | StreamCompletionResponse
| CompletionResponse | CompletionResponse
type ToolCallEntry = {
tool: object
response: any
state: 'pending' | 'ready'
}
/** /**
* @fileoverview Helper functions for creating thread content. * @fileoverview Helper functions for creating thread content.
* These functions are used to create thread content objects * These functions are used to create thread content objects
@ -73,11 +80,14 @@ export const newUserThreadContent = (
name: doc.name, name: doc.name,
type: doc.fileType, type: doc.fileType,
size: typeof doc.size === 'number' ? doc.size : undefined, size: typeof doc.size === 'number' ? doc.size : undefined,
chunkCount: typeof doc.chunkCount === 'number' ? doc.chunkCount : undefined, chunkCount:
typeof doc.chunkCount === 'number' ? doc.chunkCount : undefined,
})) }))
const textWithFiles = const textWithFiles =
docMetadata.length > 0 ? injectFilesIntoPrompt(content, docMetadata) : content docMetadata.length > 0
? injectFilesIntoPrompt(content, docMetadata)
: content
const contentParts = [ const contentParts = [
{ {
@ -238,10 +248,8 @@ export const sendCompletion = async (
const providerModelConfig = provider.models?.find( const providerModelConfig = provider.models?.find(
(model) => model.id === thread.model?.id || model.model === thread.model?.id (model) => model.id === thread.model?.id || model.model === thread.model?.id
) )
const effectiveCapabilities = Array.isArray( const effectiveCapabilities = Array.isArray(providerModelConfig?.capabilities)
providerModelConfig?.capabilities ? (providerModelConfig?.capabilities ?? [])
)
? providerModelConfig?.capabilities ?? []
: getModelCapabilities(provider.provider, thread.model.id) : getModelCapabilities(provider.provider, thread.model.id)
const modelSupportsTools = effectiveCapabilities.includes( const modelSupportsTools = effectiveCapabilities.includes(
ModelCapabilities.TOOLS ModelCapabilities.TOOLS
@ -254,7 +262,10 @@ export const sendCompletion = async (
PlatformFeatures[PlatformFeature.ATTACHMENTS] && PlatformFeatures[PlatformFeature.ATTACHMENTS] &&
modelSupportsTools modelSupportsTools
) { ) {
const ragTools = await getServiceHub().rag().getTools().catch(() => []) const ragTools = await getServiceHub()
.rag()
.getTools()
.catch(() => [])
if (Array.isArray(ragTools) && ragTools.length) { if (Array.isArray(ragTools) && ragTools.length) {
usableTools = [...tools, ...ragTools] usableTools = [...tools, ...ragTools]
} }
@ -396,6 +407,9 @@ export const extractToolCall = (
return calls return calls
} }
// Keep track of total tool steps to prevent infinite loops
let toolStepCounter = 0
/** /**
* Helper function to check if a tool call is a browser MCP tool * Helper function to check if a tool call is a browser MCP tool
* @param toolName - The name of the tool * @param toolName - The name of the tool
@ -533,10 +547,22 @@ export const postMessageProcessing = async (
toolParameters?: object toolParameters?: object
) => Promise<boolean>, ) => Promise<boolean>,
allowAllMCPPermissions: boolean = false, allowAllMCPPermissions: boolean = false,
thread?: Thread,
provider?: ModelProvider,
tools: MCPTool[] = [],
updateStreamingUI?: (content: ThreadMessage) => void,
maxToolSteps: number = 20,
isProactiveMode: boolean = false isProactiveMode: boolean = false
) => { ): Promise<ThreadMessage> => {
// Reset counter at the start of a new message processing chain
if (toolStepCounter === 0) {
toolStepCounter = 0
}
// Handle completed tool calls // Handle completed tool calls
if (calls.length) { if (calls.length > 0) {
toolStepCounter++
// Fetch RAG tool names from RAG service // Fetch RAG tool names from RAG service
let ragToolNames = new Set<string>() let ragToolNames = new Set<string>()
try { try {
@ -546,43 +572,41 @@ export const postMessageProcessing = async (
console.error('Failed to load RAG tool names:', e) console.error('Failed to load RAG tool names:', e)
} }
const ragFeatureAvailable = const ragFeatureAvailable =
useAttachments.getState().enabled && PlatformFeatures[PlatformFeature.ATTACHMENTS] useAttachments.getState().enabled &&
PlatformFeatures[PlatformFeature.ATTACHMENTS]
const currentToolCalls =
message.metadata?.tool_calls && Array.isArray(message.metadata.tool_calls)
? [...message.metadata.tool_calls]
: []
for (const toolCall of calls) { for (const toolCall of calls) {
if (abortController.signal.aborted) break if (abortController.signal.aborted) break
const toolId = ulid() const toolId = ulid()
const toolCallsMetadata =
message.metadata?.tool_calls && const toolCallEntry: ToolCallEntry = {
Array.isArray(message.metadata?.tool_calls)
? message.metadata?.tool_calls
: []
message.metadata = {
...(message.metadata ?? {}),
tool_calls: [
...toolCallsMetadata,
{
tool: { tool: {
...(toolCall as object), ...(toolCall as object),
id: toolId, id: toolId,
}, },
response: undefined, response: undefined,
state: 'pending', state: 'pending' as 'pending' | 'ready',
},
],
} }
currentToolCalls.push(toolCallEntry)
message.metadata = {
...(message.metadata ?? {}),
tool_calls: currentToolCalls,
}
if (updateStreamingUI) updateStreamingUI({ ...message }) // Show pending call
// Check if tool is approved or show modal for approval // Check if tool is approved or show modal for approval
let toolParameters = {} let toolParameters = {}
if (toolCall.function.arguments.length) { if (toolCall.function.arguments.length) {
try { try {
console.log('Raw tool arguments:', toolCall.function.arguments)
toolParameters = JSON.parse(toolCall.function.arguments) toolParameters = JSON.parse(toolCall.function.arguments)
console.log('Parsed tool parameters:', toolParameters)
} catch (error) { } catch (error) {
console.error('Failed to parse tool arguments:', error) console.error('Failed to parse tool arguments:', error)
console.error(
'Raw arguments that failed:',
toolCall.function.arguments
)
} }
} }
@ -591,7 +615,6 @@ export const postMessageProcessing = async (
const isRagTool = ragToolNames.has(toolName) const isRagTool = ragToolNames.has(toolName)
const isBrowserTool = isBrowserMCPTool(toolName) const isBrowserTool = isBrowserMCPTool(toolName)
// Auto-approve RAG tools (local/safe operations), require permission for MCP tools
const approved = isRagTool const approved = isRagTool
? true ? true
: allowAllMCPPermissions || : allowAllMCPPermissions ||
@ -607,7 +630,11 @@ export const postMessageProcessing = async (
const { promise, cancel } = isRagTool const { promise, cancel } = isRagTool
? ragFeatureAvailable ? ragFeatureAvailable
? { ? {
promise: getServiceHub().rag().callTool({ toolName, arguments: toolArgs, threadId: message.thread_id }), promise: getServiceHub().rag().callTool({
toolName,
arguments: toolArgs,
threadId: message.thread_id,
}),
cancel: async () => {}, cancel: async () => {},
} }
: { : {
@ -630,9 +657,7 @@ export const postMessageProcessing = async (
useAppState.getState().setCancelToolCall(cancel) useAppState.getState().setCancelToolCall(cancel)
let result = approved let result = approved
? await promise.catch((e) => { ? await promise.catch((e) => ({
console.error('Tool call failed:', e)
return {
content: [ content: [
{ {
type: 'text', type: 'text',
@ -640,8 +665,7 @@ export const postMessageProcessing = async (
}, },
], ],
error: String(e?.message ?? e ?? 'Tool call failed'), error: String(e?.message ?? e ?? 'Tool call failed'),
} }))
})
: { : {
content: [ content: [
{ {
@ -654,30 +678,15 @@ export const postMessageProcessing = async (
if (typeof result === 'string') { if (typeof result === 'string') {
result = { result = {
content: [ content: [{ type: 'text', text: result }],
{
type: 'text',
text: result,
},
],
error: '', error: '',
} }
} }
message.metadata = { // Update the entry in the metadata array
...(message.metadata ?? {}), toolCallEntry.response = result
tool_calls: [ toolCallEntry.state = 'ready'
...toolCallsMetadata, if (updateStreamingUI) updateStreamingUI({ ...message }) // Show result
{
tool: {
...toolCall,
id: toolId,
},
response: result,
state: 'ready',
},
],
}
builder.addToolMessage(result as ToolResult, toolCall.id) builder.addToolMessage(result as ToolResult, toolCall.id)
// Proactive mode: Capture screenshot/snapshot after browser tool execution // Proactive mode: Capture screenshot/snapshot after browser tool execution
@ -702,6 +711,98 @@ export const postMessageProcessing = async (
// update message metadata // update message metadata
} }
return message
if (
thread &&
provider &&
!abortController.signal.aborted &&
toolStepCounter < maxToolSteps
) {
try {
const messagesWithToolResults = builder.getMessages()
const followUpCompletion = await sendCompletion(
thread,
provider,
messagesWithToolResults,
abortController,
tools,
true,
{}
)
if (followUpCompletion) {
let followUpText = ''
const newToolCalls: ChatCompletionMessageToolCall[] = []
const textContent = message.content.find(
(c) => c.type === ContentType.Text
)
if (isCompletionResponse(followUpCompletion)) {
const choice = followUpCompletion.choices[0]
const content = choice?.message?.content
if (content) followUpText = content as string
if (choice?.message?.tool_calls) {
newToolCalls.push(...choice.message.tool_calls)
} }
if (textContent?.text) textContent.text.value += followUpText
if (updateStreamingUI) updateStreamingUI({ ...message })
} else {
const reasoningProcessor = new ReasoningProcessor()
for await (const chunk of followUpCompletion) {
if (abortController.signal.aborted) break
const deltaReasoning =
reasoningProcessor.processReasoningChunk(chunk)
const deltaContent = chunk.choices[0]?.delta?.content || ''
if (textContent?.text) {
if (deltaReasoning) textContent.text.value += deltaReasoning
if (deltaContent) textContent.text.value += deltaContent
}
if (deltaContent) followUpText += deltaContent
if (chunk.choices[0]?.delta?.tool_calls) {
extractToolCall(chunk, null, newToolCalls)
}
if (updateStreamingUI) updateStreamingUI({ ...message })
}
if (textContent?.text) {
textContent.text.value += reasoningProcessor.finalize()
if (updateStreamingUI) updateStreamingUI({ ...message })
}
}
if (newToolCalls.length > 0) {
builder.addAssistantMessage(followUpText, undefined, newToolCalls)
await postMessageProcessing(
newToolCalls,
builder,
message,
abortController,
approvedTools,
showModal,
allowAllMCPPermissions,
thread,
provider,
tools,
updateStreamingUI,
maxToolSteps,
isProactiveMode
)
}
}
} catch (error) {
console.error(
'Failed to get follow-up completion after tool execution:',
String(error)
)
}
}
}
// Reset counter when the chain is fully resolved
toolStepCounter = 0
return message
} }