chore: Refactor chat flow – remove loop, centralize tool handling, add step limit

* Move the assistant‑loop logic out of `useChat` and into `postMessageProcessing`.
* Eliminate the while‑loop that drove repeated completions; now a single completion is sent and subsequent tool calls are processed recursively.
* Introduce early‑abort checks and guard against missing provider before proceeding.
* Add `ReasoningProcessor` import and use it consistently for streaming reasoning chunks.
* Add `ToolCallEntry` type and a global `toolStepCounter` that tracks the total number of tool steps and caps them (default 20), preventing infinite loops.
* Extend `postMessageProcessing` signature to accept thread, provider, tools, UI update callback, and max tool steps.
* Update UI‑update logic to use a single `updateStreamingUI` callback and ensure requestAnimationFrame (RAF) scheduling is cleaned up reliably.
* Refactor token‑speed / progress handling, improve error handling for out‑of‑context situations, and tidy up code formatting.
* Minor clean‑ups: const‑ify `availableTools`, remove unused variables, improve readability.
This commit is contained in:
Akarshan 2025-10-20 21:06:23 +05:30
parent 2f00ae0d33
commit c129757097
No known key found for this signature in database
GPG Key ID: D75C9634A870665F
2 changed files with 450 additions and 352 deletions

View File

@ -425,10 +425,8 @@ export const useChat = () => {
// Using addUserMessage to respect legacy code. Should be using the userContent above.
if (troubleshooting) builder.addUserMessage(userContent)
let isCompleted = false
// Filter tools based on model capabilities and available tools for this thread
let availableTools = selectedModel?.capabilities?.includes('tools')
const availableTools = selectedModel?.capabilities?.includes('tools')
? useAppState.getState().tools.filter((tool) => {
const disabledTools = getDisabledToolsForThread(activeThread.id)
return !disabledTools.includes(tool.name)
@ -436,13 +434,21 @@ export const useChat = () => {
: []
// Check if proactive mode is enabled
const isProactiveMode = selectedModel?.capabilities?.includes('proactive') ?? false
const isProactiveMode =
selectedModel?.capabilities?.includes('proactive') ?? false
// Proactive mode: Capture initial screenshot/snapshot before first LLM call
if (isProactiveMode && availableTools.length > 0 && !abortController.signal.aborted) {
console.log('Proactive mode: Capturing initial screenshots before LLM call')
if (
isProactiveMode &&
availableTools.length > 0 &&
!abortController.signal.aborted
) {
console.log(
'Proactive mode: Capturing initial screenshots before LLM call'
)
try {
const initialScreenshots = await captureProactiveScreenshots(abortController)
const initialScreenshots =
await captureProactiveScreenshots(abortController)
// Add initial screenshots to builder
for (const screenshot of initialScreenshots) {
@ -456,17 +462,14 @@ export const useChat = () => {
}
}
let assistantLoopSteps = 0
// The agent logic is now self-contained within postMessageProcessing.
// We no longer need a `while` loop here.
if (abortController.signal.aborted || !activeProvider) return
while (
!isCompleted &&
!abortController.signal.aborted &&
activeProvider
) {
const modelConfig = activeProvider.models.find(
(m) => m.id === selectedModel?.id
)
assistantLoopSteps += 1
const modelSettings = modelConfig?.settings
? Object.fromEntries(
@ -510,8 +513,7 @@ export const useChat = () => {
// Handle reasoning field if there is one
const reasoning = extractReasoningFromMessage(message)
if (reasoning) {
accumulatedText =
`<think>${reasoning}</think>` + accumulatedText
accumulatedText = `<think>${reasoning}</think>` + accumulatedText
}
if (message?.tool_calls) {
@ -694,14 +696,17 @@ export const useChat = () => {
selectedModel.id,
activeProvider
)
continue
// NOTE: This will exit and not retry. A more robust solution might re-call sendMessage.
// For this change, we keep the existing behavior.
return
} else if (method === 'context_shift' && selectedModel?.id) {
/// Enable context_shift
activeProvider = await toggleOnContextShifting(
selectedModel?.id,
activeProvider
)
continue
// NOTE: See above comment about retry.
return
} else throw error
} else {
throw error
@ -714,15 +719,12 @@ export const useChat = () => {
activeThread.model?.id &&
activeProvider?.provider === 'llamacpp'
) {
await serviceHub
.models()
.stopModel(activeThread.model.id, 'llamacpp')
await serviceHub.models().stopModel(activeThread.model.id, 'llamacpp')
throw new Error('No response received from the model')
}
const totalThinkingTime = Date.now() - startTime // Calculate total elapsed time
// Create a final content object for adding to the thread
const messageMetadata: Record<string, any> = {
tokenSpeed: useAppState.getState().tokenSpeed,
assistant: currentAssistant,
@ -732,6 +734,7 @@ export const useChat = () => {
messageMetadata.totalThinkingTime = totalThinkingTime
}
// This is the message object that will be built upon by postMessageProcessing
const finalContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
@ -739,10 +742,7 @@ export const useChat = () => {
)
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
// Check if proactive mode is enabled for this model
const isProactiveMode = selectedModel?.capabilities?.includes('proactive') ?? false
// All subsequent tool calls and follow-up completions will modify `finalContent`.
const updatedMessage = await postMessageProcessing(
toolCalls,
builder,
@ -751,6 +751,11 @@ export const useChat = () => {
useToolApproval.getState().approvedTools,
allowAllMCPPermissions ? undefined : showApprovalModal,
allowAllMCPPermissions,
activeThread,
activeProvider,
availableTools,
updateStreamingContent, // Pass the callback to update UI
currentAssistant?.tool_steps,
isProactiveMode
)
@ -761,19 +766,11 @@ export const useChat = () => {
}
}
// Add the single, final, composite message to the store.
addMessage(updatedMessage ?? finalContent)
updateStreamingContent(emptyThreadContent)
updatePromptProgress(undefined)
updateThreadTimestamp(activeThread.id)
isCompleted = !toolCalls.length
// Do not create agent loop if there is no need for it
// Check if assistant loop steps are within limits
if (assistantLoopSteps >= (currentAssistant?.tool_steps ?? 20)) {
// Stop the assistant tool call if it exceeds the maximum steps
availableTools = []
}
}
} catch (error) {
if (!abortController.signal.aborted) {
if (error && typeof error === 'object' && 'message' in error) {

View File

@ -41,6 +41,7 @@ import { useAppState } from '@/hooks/useAppState'
import { injectFilesIntoPrompt } from './fileMetadata'
import { Attachment } from '@/types/attachment'
import { ModelCapabilities } from '@/types/models'
import { ReasoningProcessor } from '@/utils/reasoning'
export type ChatCompletionResponse =
| chatCompletion
@ -48,6 +49,12 @@ export type ChatCompletionResponse =
| StreamCompletionResponse
| CompletionResponse
type ToolCallEntry = {
tool: object
response: any
state: 'pending' | 'ready'
}
/**
* @fileoverview Helper functions for creating thread content.
* These functions are used to create thread content objects
@ -73,11 +80,14 @@ export const newUserThreadContent = (
name: doc.name,
type: doc.fileType,
size: typeof doc.size === 'number' ? doc.size : undefined,
chunkCount: typeof doc.chunkCount === 'number' ? doc.chunkCount : undefined,
chunkCount:
typeof doc.chunkCount === 'number' ? doc.chunkCount : undefined,
}))
const textWithFiles =
docMetadata.length > 0 ? injectFilesIntoPrompt(content, docMetadata) : content
docMetadata.length > 0
? injectFilesIntoPrompt(content, docMetadata)
: content
const contentParts = [
{
@ -238,10 +248,8 @@ export const sendCompletion = async (
const providerModelConfig = provider.models?.find(
(model) => model.id === thread.model?.id || model.model === thread.model?.id
)
const effectiveCapabilities = Array.isArray(
providerModelConfig?.capabilities
)
? providerModelConfig?.capabilities ?? []
const effectiveCapabilities = Array.isArray(providerModelConfig?.capabilities)
? (providerModelConfig?.capabilities ?? [])
: getModelCapabilities(provider.provider, thread.model.id)
const modelSupportsTools = effectiveCapabilities.includes(
ModelCapabilities.TOOLS
@ -254,7 +262,10 @@ export const sendCompletion = async (
PlatformFeatures[PlatformFeature.ATTACHMENTS] &&
modelSupportsTools
) {
const ragTools = await getServiceHub().rag().getTools().catch(() => [])
const ragTools = await getServiceHub()
.rag()
.getTools()
.catch(() => [])
if (Array.isArray(ragTools) && ragTools.length) {
usableTools = [...tools, ...ragTools]
}
@ -396,6 +407,9 @@ export const extractToolCall = (
return calls
}
// Keep track of total tool steps to prevent infinite loops
let toolStepCounter = 0
/**
* Helper function to check if a tool call is a browser MCP tool
* @param toolName - The name of the tool
@ -533,10 +547,22 @@ export const postMessageProcessing = async (
toolParameters?: object
) => Promise<boolean>,
allowAllMCPPermissions: boolean = false,
thread?: Thread,
provider?: ModelProvider,
tools: MCPTool[] = [],
updateStreamingUI?: (content: ThreadMessage) => void,
maxToolSteps: number = 20,
isProactiveMode: boolean = false
) => {
): Promise<ThreadMessage> => {
// Reset counter at the start of a new message processing chain
if (toolStepCounter === 0) {
toolStepCounter = 0
}
// Handle completed tool calls
if (calls.length) {
if (calls.length > 0) {
toolStepCounter++
// Fetch RAG tool names from RAG service
let ragToolNames = new Set<string>()
try {
@ -546,43 +572,41 @@ export const postMessageProcessing = async (
console.error('Failed to load RAG tool names:', e)
}
const ragFeatureAvailable =
useAttachments.getState().enabled && PlatformFeatures[PlatformFeature.ATTACHMENTS]
useAttachments.getState().enabled &&
PlatformFeatures[PlatformFeature.ATTACHMENTS]
const currentToolCalls =
message.metadata?.tool_calls && Array.isArray(message.metadata.tool_calls)
? [...message.metadata.tool_calls]
: []
for (const toolCall of calls) {
if (abortController.signal.aborted) break
const toolId = ulid()
const toolCallsMetadata =
message.metadata?.tool_calls &&
Array.isArray(message.metadata?.tool_calls)
? message.metadata?.tool_calls
: []
message.metadata = {
...(message.metadata ?? {}),
tool_calls: [
...toolCallsMetadata,
{
const toolCallEntry: ToolCallEntry = {
tool: {
...(toolCall as object),
id: toolId,
},
response: undefined,
state: 'pending',
},
],
state: 'pending' as 'pending' | 'ready',
}
currentToolCalls.push(toolCallEntry)
message.metadata = {
...(message.metadata ?? {}),
tool_calls: currentToolCalls,
}
if (updateStreamingUI) updateStreamingUI({ ...message }) // Show pending call
// Check if tool is approved or show modal for approval
let toolParameters = {}
if (toolCall.function.arguments.length) {
try {
console.log('Raw tool arguments:', toolCall.function.arguments)
toolParameters = JSON.parse(toolCall.function.arguments)
console.log('Parsed tool parameters:', toolParameters)
} catch (error) {
console.error('Failed to parse tool arguments:', error)
console.error(
'Raw arguments that failed:',
toolCall.function.arguments
)
}
}
@ -591,7 +615,6 @@ export const postMessageProcessing = async (
const isRagTool = ragToolNames.has(toolName)
const isBrowserTool = isBrowserMCPTool(toolName)
// Auto-approve RAG tools (local/safe operations), require permission for MCP tools
const approved = isRagTool
? true
: allowAllMCPPermissions ||
@ -607,7 +630,11 @@ export const postMessageProcessing = async (
const { promise, cancel } = isRagTool
? ragFeatureAvailable
? {
promise: getServiceHub().rag().callTool({ toolName, arguments: toolArgs, threadId: message.thread_id }),
promise: getServiceHub().rag().callTool({
toolName,
arguments: toolArgs,
threadId: message.thread_id,
}),
cancel: async () => {},
}
: {
@ -630,9 +657,7 @@ export const postMessageProcessing = async (
useAppState.getState().setCancelToolCall(cancel)
let result = approved
? await promise.catch((e) => {
console.error('Tool call failed:', e)
return {
? await promise.catch((e) => ({
content: [
{
type: 'text',
@ -640,8 +665,7 @@ export const postMessageProcessing = async (
},
],
error: String(e?.message ?? e ?? 'Tool call failed'),
}
})
}))
: {
content: [
{
@ -654,30 +678,15 @@ export const postMessageProcessing = async (
if (typeof result === 'string') {
result = {
content: [
{
type: 'text',
text: result,
},
],
content: [{ type: 'text', text: result }],
error: '',
}
}
message.metadata = {
...(message.metadata ?? {}),
tool_calls: [
...toolCallsMetadata,
{
tool: {
...toolCall,
id: toolId,
},
response: result,
state: 'ready',
},
],
}
// Update the entry in the metadata array
toolCallEntry.response = result
toolCallEntry.state = 'ready'
if (updateStreamingUI) updateStreamingUI({ ...message }) // Show result
builder.addToolMessage(result as ToolResult, toolCall.id)
// Proactive mode: Capture screenshot/snapshot after browser tool execution
@ -702,6 +711,98 @@ export const postMessageProcessing = async (
// update message metadata
}
if (
thread &&
provider &&
!abortController.signal.aborted &&
toolStepCounter < maxToolSteps
) {
try {
const messagesWithToolResults = builder.getMessages()
const followUpCompletion = await sendCompletion(
thread,
provider,
messagesWithToolResults,
abortController,
tools,
true,
{}
)
if (followUpCompletion) {
let followUpText = ''
const newToolCalls: ChatCompletionMessageToolCall[] = []
const textContent = message.content.find(
(c) => c.type === ContentType.Text
)
if (isCompletionResponse(followUpCompletion)) {
const choice = followUpCompletion.choices[0]
const content = choice?.message?.content
if (content) followUpText = content as string
if (choice?.message?.tool_calls) {
newToolCalls.push(...choice.message.tool_calls)
}
if (textContent?.text) textContent.text.value += followUpText
if (updateStreamingUI) updateStreamingUI({ ...message })
} else {
const reasoningProcessor = new ReasoningProcessor()
for await (const chunk of followUpCompletion) {
if (abortController.signal.aborted) break
const deltaReasoning =
reasoningProcessor.processReasoningChunk(chunk)
const deltaContent = chunk.choices[0]?.delta?.content || ''
if (textContent?.text) {
if (deltaReasoning) textContent.text.value += deltaReasoning
if (deltaContent) textContent.text.value += deltaContent
}
if (deltaContent) followUpText += deltaContent
if (chunk.choices[0]?.delta?.tool_calls) {
extractToolCall(chunk, null, newToolCalls)
}
if (updateStreamingUI) updateStreamingUI({ ...message })
}
if (textContent?.text) {
textContent.text.value += reasoningProcessor.finalize()
if (updateStreamingUI) updateStreamingUI({ ...message })
}
}
if (newToolCalls.length > 0) {
builder.addAssistantMessage(followUpText, undefined, newToolCalls)
await postMessageProcessing(
newToolCalls,
builder,
message,
abortController,
approvedTools,
showModal,
allowAllMCPPermissions,
thread,
provider,
tools,
updateStreamingUI,
maxToolSteps,
isProactiveMode
)
}
}
} catch (error) {
console.error(
'Failed to get follow-up completion after tool execution:',
String(error)
)
}
}
}
// Reset counter when the chain is fully resolved
toolStepCounter = 0
return message
}
}