chore: Refactor chat flow – remove loop, centralize tool handling, add step limit
* Move the assistant-loop logic out of `useChat` and into `postMessageProcessing`.
* Eliminate the while-loop that drove repeated completions; a single completion is now sent and subsequent tool calls are processed recursively.
* Introduce early-abort checks and guard against a missing provider before proceeding.
* Add the `ReasoningProcessor` import and use it consistently for streaming reasoning chunks.
* Add a `ToolCallEntry` type and a module-level `toolStepCounter` to track and cap total tool steps (default 20), preventing infinite loops.
* Extend the `postMessageProcessing` signature to accept the thread, provider, tools, a UI-update callback, and a maximum tool-step count.
* Route UI updates through a single `updateStreamingUI` callback and ensure RAF scheduling is cleaned up reliably.
* Refactor token-speed and progress handling, improve error handling for out-of-context situations, and tidy up code formatting.
* Minor clean-ups: const-ify `availableTools`, remove unused variables, improve readability.
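The shape of the change, as a minimal self-contained sketch. The names here (`ToolCall`, `CompletionResult`, `requestCompletion`, the string-based history) are illustrative stand-ins, not the app's actual types or API:

```ts
// Step-capped recursive tool handling: one completion is sent, and any tool
// calls it returns are executed and fed back through a follow-up completion,
// recursively, until no tool calls remain or the step cap is hit.
type ToolCall = { name: string; args: string }

interface CompletionResult {
  text: string
  toolCalls: ToolCall[]
}

// Module-level counter, mirroring the `toolStepCounter` added in this commit.
let toolStepCounter = 0

async function processToolCalls(
  calls: ToolCall[],
  requestCompletion: (history: string[]) => Promise<CompletionResult>,
  history: string[],
  maxToolSteps = 20
): Promise<void> {
  if (calls.length === 0) {
    toolStepCounter = 0 // chain fully resolved; reset for the next message
    return
  }
  toolStepCounter++
  for (const call of calls) {
    // Execute the tool and append its result to the conversation history.
    history.push(`tool:${call.name} -> ok`)
  }
  if (toolStepCounter >= maxToolSteps) {
    toolStepCounter = 0
    return // cap reached: stop recursing instead of looping forever
  }
  // One follow-up completion; any new tool calls recurse through this function.
  const followUp = await requestCompletion(history)
  history.push(followUp.text)
  await processToolCalls(followUp.toolCalls, requestCompletion, history, maxToolSteps)
}
```

The real `postMessageProcessing` additionally threads the builder, provider, and the `updateStreamingUI` callback through each recursive call, as the diff below shows.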
parent 2f00ae0d33
commit c129757097
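On the UI side, the diff batches streaming deltas on `requestAnimationFrame` with a `setTimeout` fallback. A stripped-down sketch of that scheduler; `render` stands in for the real `updateStreamingContent` callback, and the plain-text buffer is an assumption made for brevity:

```ts
// Batches many small stream deltas into at most one render per animation frame.
function makeStreamFlusher(render: (text: string) => void) {
  let buffer = ''
  let scheduled = false
  let handle: number | undefined

  const flush = () => {
    scheduled = false
    handle = undefined
    render(buffer)
  }

  return {
    push(delta: string) {
      buffer += delta
      if (scheduled) return
      scheduled = true
      if (typeof requestAnimationFrame !== 'undefined') {
        handle = requestAnimationFrame(flush)
      } else {
        // Fallback for non-browser test environments, as in the diff.
        handle = setTimeout(flush, 0) as unknown as number
      }
    },
    // Cancel any pending callback when the stream ends or is aborted.
    dispose() {
      if (handle === undefined) return
      if (typeof cancelAnimationFrame !== 'undefined') cancelAnimationFrame(handle)
      else clearTimeout(handle)
      handle = undefined
      scheduled = false
    },
  }
}
```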
@@ -425,10 +425,8 @@ export const useChat = () => {
// Using addUserMessage to respect legacy code. Should be using the userContent above.
if (troubleshooting) builder.addUserMessage(userContent)

- let isCompleted = false
-
// Filter tools based on model capabilities and available tools for this thread
- let availableTools = selectedModel?.capabilities?.includes('tools')
+ const availableTools = selectedModel?.capabilities?.includes('tools')
? useAppState.getState().tools.filter((tool) => {
const disabledTools = getDisabledToolsForThread(activeThread.id)
return !disabledTools.includes(tool.name)
@@ -436,13 +434,21 @@ export const useChat = () => {
: []

// Check if proactive mode is enabled
- const isProactiveMode = selectedModel?.capabilities?.includes('proactive') ?? false
+ const isProactiveMode =
+ selectedModel?.capabilities?.includes('proactive') ?? false

// Proactive mode: Capture initial screenshot/snapshot before first LLM call
- if (isProactiveMode && availableTools.length > 0 && !abortController.signal.aborted) {
- console.log('Proactive mode: Capturing initial screenshots before LLM call')
+ if (
+ isProactiveMode &&
+ availableTools.length > 0 &&
+ !abortController.signal.aborted
+ ) {
+ console.log(
+ 'Proactive mode: Capturing initial screenshots before LLM call'
+ )
try {
- const initialScreenshots = await captureProactiveScreenshots(abortController)
+ const initialScreenshots =
+ await captureProactiveScreenshots(abortController)

// Add initial screenshots to builder
for (const screenshot of initialScreenshots) {
@@ -456,131 +462,91 @@ export const useChat = () => {
}
}

- let assistantLoopSteps = 0
+ // The agent logic is now self-contained within postMessageProcessing.
+ // We no longer need a `while` loop here.

- while (
- !isCompleted &&
- !abortController.signal.aborted &&
- activeProvider
- ) {
- const modelConfig = activeProvider.models.find(
- (m) => m.id === selectedModel?.id
- )
- assistantLoopSteps += 1
+ if (abortController.signal.aborted || !activeProvider) return

- const modelSettings = modelConfig?.settings
- ? Object.fromEntries(
- Object.entries(modelConfig.settings)
- .filter(
- ([key, value]) =>
- key !== 'ctx_len' &&
- key !== 'ngl' &&
- value.controller_props?.value !== undefined &&
- value.controller_props?.value !== null &&
- value.controller_props?.value !== ''
- )
- .map(([key, value]) => [key, value.controller_props?.value])
- )
- : undefined
+ const modelConfig = activeProvider.models.find(
+ (m) => m.id === selectedModel?.id
+ )

- const completion = await sendCompletion(
- activeThread,
- activeProvider,
- builder.getMessages(),
- abortController,
- availableTools,
- currentAssistant?.parameters?.stream === false ? false : true,
- {
- ...modelSettings,
- ...(currentAssistant?.parameters || {}),
- } as unknown as Record<string, object>
- )
+ const modelSettings = modelConfig?.settings
+ ? Object.fromEntries(
+ Object.entries(modelConfig.settings)
+ .filter(
+ ([key, value]) =>
+ key !== 'ctx_len' &&
+ key !== 'ngl' &&
+ value.controller_props?.value !== undefined &&
+ value.controller_props?.value !== null &&
+ value.controller_props?.value !== ''
+ )
+ .map(([key, value]) => [key, value.controller_props?.value])
+ )
+ : undefined

- if (!completion) throw new Error('No completion received')
- let accumulatedText = ''
- const currentCall: ChatCompletionMessageToolCall | null = null
- const toolCalls: ChatCompletionMessageToolCall[] = []
- const timeToFirstToken = Date.now()
- let tokenUsage: CompletionUsage | undefined = undefined
- try {
- if (isCompletionResponse(completion)) {
- const message = completion.choices[0]?.message
- accumulatedText = (message?.content as string) || ''
+ const completion = await sendCompletion(
+ activeThread,
+ activeProvider,
+ builder.getMessages(),
+ abortController,
+ availableTools,
+ currentAssistant?.parameters?.stream === false ? false : true,
+ {
+ ...modelSettings,
+ ...(currentAssistant?.parameters || {}),
+ } as unknown as Record<string, object>
+ )

- // Handle reasoning field if there is one
- const reasoning = extractReasoningFromMessage(message)
- if (reasoning) {
- accumulatedText =
- `<think>${reasoning}</think>` + accumulatedText
- }
+ if (!completion) throw new Error('No completion received')
+ let accumulatedText = ''
+ const currentCall: ChatCompletionMessageToolCall | null = null
+ const toolCalls: ChatCompletionMessageToolCall[] = []
+ const timeToFirstToken = Date.now()
+ let tokenUsage: CompletionUsage | undefined = undefined
+ try {
+ if (isCompletionResponse(completion)) {
+ const message = completion.choices[0]?.message
+ accumulatedText = (message?.content as string) || ''

- if (message?.tool_calls) {
- toolCalls.push(...message.tool_calls)
- }
- if ('usage' in completion) {
- tokenUsage = completion.usage
- }
- } else {
- // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
- let rafScheduled = false
- let rafHandle: number | undefined
- let pendingDeltaCount = 0
- const reasoningProcessor = new ReasoningProcessor()
- const scheduleFlush = () => {
- if (rafScheduled || abortController.signal.aborted) return
- rafScheduled = true
- const doSchedule = (cb: () => void) => {
- if (typeof requestAnimationFrame !== 'undefined') {
- rafHandle = requestAnimationFrame(() => cb())
- } else {
- // Fallback for non-browser test environments
- const t = setTimeout(() => cb(), 0) as unknown as number
- rafHandle = t
- }
+ // Handle reasoning field if there is one
+ const reasoning = extractReasoningFromMessage(message)
+ if (reasoning) {
+ accumulatedText = `<think>${reasoning}</think>` + accumulatedText
+ }

+ if (message?.tool_calls) {
+ toolCalls.push(...message.tool_calls)
+ }
+ if ('usage' in completion) {
+ tokenUsage = completion.usage
+ }
+ } else {
+ // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
+ let rafScheduled = false
+ let rafHandle: number | undefined
+ let pendingDeltaCount = 0
+ const reasoningProcessor = new ReasoningProcessor()
+ const scheduleFlush = () => {
+ if (rafScheduled || abortController.signal.aborted) return
+ rafScheduled = true
+ const doSchedule = (cb: () => void) => {
+ if (typeof requestAnimationFrame !== 'undefined') {
+ rafHandle = requestAnimationFrame(() => cb())
+ } else {
+ // Fallback for non-browser test environments
+ const t = setTimeout(() => cb(), 0) as unknown as number
+ rafHandle = t
+ }
- doSchedule(() => {
- // Check abort status before executing the scheduled callback
- if (abortController.signal.aborted) {
- rafScheduled = false
- return
- }
-
- const currentContent = newAssistantThreadContent(
- activeThread.id,
- accumulatedText,
- {
- tool_calls: toolCalls.map((e) => ({
- ...e,
- state: 'pending',
- })),
- }
- )
- updateStreamingContent(currentContent)
- if (tokenUsage) {
- setTokenSpeed(
- currentContent,
- tokenUsage.completion_tokens /
- Math.max((Date.now() - timeToFirstToken) / 1000, 1),
- tokenUsage.completion_tokens
- )
- } else if (pendingDeltaCount > 0) {
- updateTokenSpeed(currentContent, pendingDeltaCount)
- }
- pendingDeltaCount = 0
- }
+ doSchedule(() => {
+ // Check abort status before executing the scheduled callback
+ if (abortController.signal.aborted) {
+ rafScheduled = false
- })
- }
- const flushIfPending = () => {
- if (!rafScheduled) return
- if (
- typeof cancelAnimationFrame !== 'undefined' &&
- rafHandle !== undefined
- ) {
- cancelAnimationFrame(rafHandle)
- } else if (rafHandle !== undefined) {
- clearTimeout(rafHandle)
- return
- }
- // Do an immediate flush
-
- const currentContent = newAssistantThreadContent(
- activeThread.id,
- accumulatedText,
@@ -604,176 +570,207 @@ export const useChat = () => {
- }
- pendingDeltaCount = 0
- rafScheduled = false
- }
- try {
- for await (const part of completion) {
- // Check if aborted before processing each part
- if (abortController.signal.aborted) {
- break
- }
-
- // Handle prompt progress if available
- if ('prompt_progress' in part && part.prompt_progress) {
- // Force immediate state update to ensure we see intermediate values
- flushSync(() => {
- updatePromptProgress(part.prompt_progress)
- })
- // Add a small delay to make progress visible
- await new Promise((resolve) => setTimeout(resolve, 100))
- }
-
- // Error message
- if (!part.choices) {
- throw new Error(
- 'message' in part
- ? (part.message as string)
- : (JSON.stringify(part) ?? '')
- )
- }
-
- if ('usage' in part && part.usage) {
- tokenUsage = part.usage
- }
-
- if (part.choices[0]?.delta?.tool_calls) {
- extractToolCall(part, currentCall, toolCalls)
- // Schedule a flush to reflect tool update
- scheduleFlush()
- }
- const deltaReasoning =
- reasoningProcessor.processReasoningChunk(part)
- if (deltaReasoning) {
- accumulatedText += deltaReasoning
- pendingDeltaCount += 1
- // Schedule flush for reasoning updates
- scheduleFlush()
- }
- const deltaContent = part.choices[0]?.delta?.content || ''
- if (deltaContent) {
- accumulatedText += deltaContent
- pendingDeltaCount += 1
- // Batch UI update on next animation frame
- scheduleFlush()
- }
- }
- } finally {
- // Always clean up scheduled RAF when stream ends (either normally or via abort)
- if (rafHandle !== undefined) {
- if (typeof cancelAnimationFrame !== 'undefined') {
- cancelAnimationFrame(rafHandle)
- } else {
- clearTimeout(rafHandle)
- }
- rafHandle = undefined
- rafScheduled = false
- }
-
- // Only finalize and flush if not aborted
- if (!abortController.signal.aborted) {
- // Finalize reasoning (close any open think tags)
- accumulatedText += reasoningProcessor.finalize()
- // Ensure any pending buffered content is rendered at the end
- flushIfPending()
- }
- }
- })
- }
- } catch (error) {
- const errorMessage =
- error && typeof error === 'object' && 'message' in error
- ? error.message
- : error
- if (
- typeof errorMessage === 'string' &&
- errorMessage.includes(OUT_OF_CONTEXT_SIZE) &&
- selectedModel
- ) {
- const method = await showIncreaseContextSizeModal()
- if (method === 'ctx_len') {
- /// Increase context size
- activeProvider = await increaseModelContextSize(
- selectedModel.id,
- activeProvider
- )
- continue
- } else if (method === 'context_shift' && selectedModel?.id) {
- /// Enable context_shift
- activeProvider = await toggleOnContextShifting(
- selectedModel?.id,
- activeProvider
- )
- continue
- } else throw error
- } else {
- throw error
- }
- // TODO: Remove this check when integrating new llama.cpp extension
- if (
- accumulatedText.length === 0 &&
- toolCalls.length === 0 &&
- activeThread.model?.id &&
- activeProvider?.provider === 'llamacpp'
- ) {
- await serviceHub
- .models()
- .stopModel(activeThread.model.id, 'llamacpp')
- throw new Error('No response received from the model')
- }
-
- const totalThinkingTime = Date.now() - startTime // Calculate total elapsed time
-
- // Create a final content object for adding to the thread
- const messageMetadata: Record<string, any> = {
- tokenSpeed: useAppState.getState().tokenSpeed,
- assistant: currentAssistant,
- }
-
- if (accumulatedText.includes('<think>') || toolCalls.length > 0) {
- messageMetadata.totalThinkingTime = totalThinkingTime
- }
-
- const finalContent = newAssistantThreadContent(
- activeThread.id,
- accumulatedText,
- messageMetadata
- )
-
- builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
-
- // Check if proactive mode is enabled for this model
- const isProactiveMode = selectedModel?.capabilities?.includes('proactive') ?? false
-
- const updatedMessage = await postMessageProcessing(
- toolCalls,
- builder,
- finalContent,
- abortController,
- useToolApproval.getState().approvedTools,
- allowAllMCPPermissions ? undefined : showApprovalModal,
- allowAllMCPPermissions,
- isProactiveMode
- )
-
- if (updatedMessage && updatedMessage.metadata) {
- if (finalContent.metadata?.totalThinkingTime !== undefined) {
- updatedMessage.metadata.totalThinkingTime =
- finalContent.metadata.totalThinkingTime
- }
- }
-
- addMessage(updatedMessage ?? finalContent)
- updateStreamingContent(emptyThreadContent)
- updatePromptProgress(undefined)
- updateThreadTimestamp(activeThread.id)
-
- isCompleted = !toolCalls.length
- // Do not create agent loop if there is no need for it
- // Check if assistant loop steps are within limits
- if (assistantLoopSteps >= (currentAssistant?.tool_steps ?? 20)) {
- // Stop the assistant tool call if it exceeds the maximum steps
- availableTools = []
+ const flushIfPending = () => {
+ if (!rafScheduled) return
+ if (
+ typeof cancelAnimationFrame !== 'undefined' &&
+ rafHandle !== undefined
+ ) {
+ cancelAnimationFrame(rafHandle)
+ } else if (rafHandle !== undefined) {
+ clearTimeout(rafHandle)
+ }
+ // Do an immediate flush
+ const currentContent = newAssistantThreadContent(
+ activeThread.id,
+ accumulatedText,
+ {
+ tool_calls: toolCalls.map((e) => ({
+ ...e,
+ state: 'pending',
+ })),
+ }
+ )
+ updateStreamingContent(currentContent)
+ if (tokenUsage) {
+ setTokenSpeed(
+ currentContent,
+ tokenUsage.completion_tokens /
+ Math.max((Date.now() - timeToFirstToken) / 1000, 1),
+ tokenUsage.completion_tokens
+ )
+ } else if (pendingDeltaCount > 0) {
+ updateTokenSpeed(currentContent, pendingDeltaCount)
+ }
+ pendingDeltaCount = 0
+ rafScheduled = false
+ }
+ try {
+ for await (const part of completion) {
+ // Check if aborted before processing each part
+ if (abortController.signal.aborted) {
+ break
+ }

+ // Handle prompt progress if available
+ if ('prompt_progress' in part && part.prompt_progress) {
+ // Force immediate state update to ensure we see intermediate values
+ flushSync(() => {
+ updatePromptProgress(part.prompt_progress)
+ })
+ // Add a small delay to make progress visible
+ await new Promise((resolve) => setTimeout(resolve, 100))
+ }

+ // Error message
+ if (!part.choices) {
+ throw new Error(
+ 'message' in part
+ ? (part.message as string)
+ : (JSON.stringify(part) ?? '')
+ )
+ }

+ if ('usage' in part && part.usage) {
+ tokenUsage = part.usage
+ }

+ if (part.choices[0]?.delta?.tool_calls) {
+ extractToolCall(part, currentCall, toolCalls)
+ // Schedule a flush to reflect tool update
+ scheduleFlush()
+ }
+ const deltaReasoning =
+ reasoningProcessor.processReasoningChunk(part)
+ if (deltaReasoning) {
+ accumulatedText += deltaReasoning
+ pendingDeltaCount += 1
+ // Schedule flush for reasoning updates
+ scheduleFlush()
+ }
+ const deltaContent = part.choices[0]?.delta?.content || ''
+ if (deltaContent) {
+ accumulatedText += deltaContent
+ pendingDeltaCount += 1
+ // Batch UI update on next animation frame
+ scheduleFlush()
+ }
+ }
+ } finally {
+ // Always clean up scheduled RAF when stream ends (either normally or via abort)
+ if (rafHandle !== undefined) {
+ if (typeof cancelAnimationFrame !== 'undefined') {
+ cancelAnimationFrame(rafHandle)
+ } else {
+ clearTimeout(rafHandle)
+ }
+ rafHandle = undefined
+ rafScheduled = false
+ }

+ // Only finalize and flush if not aborted
+ if (!abortController.signal.aborted) {
+ // Finalize reasoning (close any open think tags)
+ accumulatedText += reasoningProcessor.finalize()
+ // Ensure any pending buffered content is rendered at the end
+ flushIfPending()
+ }
+ }
+ }
+ } catch (error) {
+ const errorMessage =
+ error && typeof error === 'object' && 'message' in error
+ ? error.message
+ : error
+ if (
+ typeof errorMessage === 'string' &&
+ errorMessage.includes(OUT_OF_CONTEXT_SIZE) &&
+ selectedModel
+ ) {
+ const method = await showIncreaseContextSizeModal()
+ if (method === 'ctx_len') {
+ /// Increase context size
+ activeProvider = await increaseModelContextSize(
+ selectedModel.id,
+ activeProvider
+ )
+ // NOTE: This will exit and not retry. A more robust solution might re-call sendMessage.
+ // For this change, we keep the existing behavior.
+ return
+ } else if (method === 'context_shift' && selectedModel?.id) {
+ /// Enable context_shift
+ activeProvider = await toggleOnContextShifting(
+ selectedModel?.id,
+ activeProvider
+ )
+ // NOTE: See above comment about retry.
+ return
+ } else throw error
+ } else {
+ throw error
+ }
+ }
+ // TODO: Remove this check when integrating new llama.cpp extension
+ if (
+ accumulatedText.length === 0 &&
+ toolCalls.length === 0 &&
+ activeThread.model?.id &&
+ activeProvider?.provider === 'llamacpp'
+ ) {
+ await serviceHub.models().stopModel(activeThread.model.id, 'llamacpp')
+ throw new Error('No response received from the model')
+ }

+ const totalThinkingTime = Date.now() - startTime // Calculate total elapsed time

+ const messageMetadata: Record<string, any> = {
+ tokenSpeed: useAppState.getState().tokenSpeed,
+ assistant: currentAssistant,
+ }

+ if (accumulatedText.includes('<think>') || toolCalls.length > 0) {
+ messageMetadata.totalThinkingTime = totalThinkingTime
+ }

+ // This is the message object that will be built upon by postMessageProcessing
+ const finalContent = newAssistantThreadContent(
+ activeThread.id,
+ accumulatedText,
+ messageMetadata
+ )

+ builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
+ // All subsequent tool calls and follow-up completions will modify `finalContent`.
+ const updatedMessage = await postMessageProcessing(
+ toolCalls,
+ builder,
+ finalContent,
+ abortController,
+ useToolApproval.getState().approvedTools,
+ allowAllMCPPermissions ? undefined : showApprovalModal,
+ allowAllMCPPermissions,
+ activeThread,
+ activeProvider,
+ availableTools,
+ updateStreamingContent, // Pass the callback to update UI
+ currentAssistant?.tool_steps,
+ isProactiveMode
+ )

+ if (updatedMessage && updatedMessage.metadata) {
+ if (finalContent.metadata?.totalThinkingTime !== undefined) {
+ updatedMessage.metadata.totalThinkingTime =
+ finalContent.metadata.totalThinkingTime
+ }
+ }

+ // Add the single, final, composite message to the store.
+ addMessage(updatedMessage ?? finalContent)
+ updateStreamingContent(emptyThreadContent)
+ updatePromptProgress(undefined)
+ updateThreadTimestamp(activeThread.id)
} catch (error) {
if (!abortController.signal.aborted) {
if (error && typeof error === 'object' && 'message' in error) {
@@ -41,6 +41,7 @@ import { useAppState } from '@/hooks/useAppState'
import { injectFilesIntoPrompt } from './fileMetadata'
import { Attachment } from '@/types/attachment'
import { ModelCapabilities } from '@/types/models'
+ import { ReasoningProcessor } from '@/utils/reasoning'

export type ChatCompletionResponse =
| chatCompletion
@@ -48,6 +49,12 @@ export type ChatCompletionResponse =
| StreamCompletionResponse
| CompletionResponse

+ type ToolCallEntry = {
+ tool: object
+ response: any
+ state: 'pending' | 'ready'
+ }
+
/**
* @fileoverview Helper functions for creating thread content.
* These functions are used to create thread content objects
@@ -73,11 +80,14 @@ export const newUserThreadContent = (
name: doc.name,
type: doc.fileType,
size: typeof doc.size === 'number' ? doc.size : undefined,
- chunkCount: typeof doc.chunkCount === 'number' ? doc.chunkCount : undefined,
+ chunkCount:
+ typeof doc.chunkCount === 'number' ? doc.chunkCount : undefined,
}))

const textWithFiles =
- docMetadata.length > 0 ? injectFilesIntoPrompt(content, docMetadata) : content
+ docMetadata.length > 0
+ ? injectFilesIntoPrompt(content, docMetadata)
+ : content

const contentParts = [
{
@@ -238,10 +248,8 @@ export const sendCompletion = async (
const providerModelConfig = provider.models?.find(
(model) => model.id === thread.model?.id || model.model === thread.model?.id
)
- const effectiveCapabilities = Array.isArray(
- providerModelConfig?.capabilities
- )
- ? providerModelConfig?.capabilities ?? []
+ const effectiveCapabilities = Array.isArray(providerModelConfig?.capabilities)
+ ? (providerModelConfig?.capabilities ?? [])
: getModelCapabilities(provider.provider, thread.model.id)
const modelSupportsTools = effectiveCapabilities.includes(
ModelCapabilities.TOOLS
@@ -254,7 +262,10 @@ export const sendCompletion = async (
PlatformFeatures[PlatformFeature.ATTACHMENTS] &&
modelSupportsTools
) {
- const ragTools = await getServiceHub().rag().getTools().catch(() => [])
+ const ragTools = await getServiceHub()
+ .rag()
+ .getTools()
+ .catch(() => [])
if (Array.isArray(ragTools) && ragTools.length) {
usableTools = [...tools, ...ragTools]
}
@@ -396,6 +407,9 @@ export const extractToolCall = (
return calls
}

+ // Keep track of total tool steps to prevent infinite loops
+ let toolStepCounter = 0
+
/**
* Helper function to check if a tool call is a browser MCP tool
* @param toolName - The name of the tool
@@ -533,10 +547,22 @@ export const postMessageProcessing = async (
toolParameters?: object
) => Promise<boolean>,
allowAllMCPPermissions: boolean = false,
+ thread?: Thread,
+ provider?: ModelProvider,
+ tools: MCPTool[] = [],
+ updateStreamingUI?: (content: ThreadMessage) => void,
+ maxToolSteps: number = 20,
isProactiveMode: boolean = false
- ) => {
+ ): Promise<ThreadMessage> => {
+ // Reset counter at the start of a new message processing chain
+ if (toolStepCounter === 0) {
+ toolStepCounter = 0
+ }
+
// Handle completed tool calls
- if (calls.length) {
+ if (calls.length > 0) {
+ toolStepCounter++
+
// Fetch RAG tool names from RAG service
let ragToolNames = new Set<string>()
try {
@@ -546,43 +572,41 @@ export const postMessageProcessing = async (
console.error('Failed to load RAG tool names:', e)
}
const ragFeatureAvailable =
- useAttachments.getState().enabled && PlatformFeatures[PlatformFeature.ATTACHMENTS]
+ useAttachments.getState().enabled &&
+ PlatformFeatures[PlatformFeature.ATTACHMENTS]

+ const currentToolCalls =
+ message.metadata?.tool_calls && Array.isArray(message.metadata.tool_calls)
+ ? [...message.metadata.tool_calls]
+ : []
+
for (const toolCall of calls) {
if (abortController.signal.aborted) break
const toolId = ulid()
- const toolCallsMetadata =
- message.metadata?.tool_calls &&
- Array.isArray(message.metadata?.tool_calls)
- ? message.metadata?.tool_calls
- : []
-
+ const toolCallEntry: ToolCallEntry = {
+ tool: {
+ ...(toolCall as object),
+ id: toolId,
+ },
+ response: undefined,
+ state: 'pending' as 'pending' | 'ready',
+ }
+ currentToolCalls.push(toolCallEntry)
+
message.metadata = {
...(message.metadata ?? {}),
- tool_calls: [
- ...toolCallsMetadata,
- {
- tool: {
- ...(toolCall as object),
- id: toolId,
- },
- response: undefined,
- state: 'pending',
- },
- ],
+ tool_calls: currentToolCalls,
}
+ if (updateStreamingUI) updateStreamingUI({ ...message }) // Show pending call

// Check if tool is approved or show modal for approval
let toolParameters = {}
if (toolCall.function.arguments.length) {
try {
console.log('Raw tool arguments:', toolCall.function.arguments)
toolParameters = JSON.parse(toolCall.function.arguments)
console.log('Parsed tool parameters:', toolParameters)
} catch (error) {
console.error('Failed to parse tool arguments:', error)
+ console.error(
+ 'Raw arguments that failed:',
+ toolCall.function.arguments
+ )
}
}
@@ -591,7 +615,6 @@ export const postMessageProcessing = async (
const isRagTool = ragToolNames.has(toolName)
const isBrowserTool = isBrowserMCPTool(toolName)

- // Auto-approve RAG tools (local/safe operations), require permission for MCP tools
const approved = isRagTool
? true
: allowAllMCPPermissions ||
@@ -607,7 +630,11 @@ export const postMessageProcessing = async (
const { promise, cancel } = isRagTool
? ragFeatureAvailable
? {
- promise: getServiceHub().rag().callTool({ toolName, arguments: toolArgs, threadId: message.thread_id }),
+ promise: getServiceHub().rag().callTool({
+ toolName,
+ arguments: toolArgs,
+ threadId: message.thread_id,
+ }),
cancel: async () => {},
}
: {
@@ -630,18 +657,15 @@ export const postMessageProcessing = async (
useAppState.getState().setCancelToolCall(cancel)

let result = approved
- ? await promise.catch((e) => {
- console.error('Tool call failed:', e)
- return {
- content: [
- {
- type: 'text',
- text: `Error calling tool ${toolCall.function.name}: ${e.message ?? e}`,
- },
- ],
- error: String(e?.message ?? e ?? 'Tool call failed'),
- }
- })
+ ? await promise.catch((e) => ({
+ content: [
+ {
+ type: 'text',
+ text: `Error calling tool ${toolCall.function.name}: ${e.message ?? e}`,
+ },
+ ],
+ error: String(e?.message ?? e ?? 'Tool call failed'),
+ }))
: {
content: [
{
@@ -654,30 +678,15 @@ export const postMessageProcessing = async (

if (typeof result === 'string') {
result = {
- content: [
- {
- type: 'text',
- text: result,
- },
- ],
+ content: [{ type: 'text', text: result }],
error: '',
}
}

- message.metadata = {
- ...(message.metadata ?? {}),
- tool_calls: [
- ...toolCallsMetadata,
- {
- tool: {
- ...toolCall,
- id: toolId,
- },
- response: result,
- state: 'ready',
- },
- ],
- }
+ // Update the entry in the metadata array
+ toolCallEntry.response = result
+ toolCallEntry.state = 'ready'
+ if (updateStreamingUI) updateStreamingUI({ ...message }) // Show result
builder.addToolMessage(result as ToolResult, toolCall.id)

// Proactive mode: Capture screenshot/snapshot after browser tool execution
@@ -702,6 +711,98 @@ export const postMessageProcessing = async (

// update message metadata
}
- return message

+ if (
+ thread &&
+ provider &&
+ !abortController.signal.aborted &&
+ toolStepCounter < maxToolSteps
+ ) {
+ try {
+ const messagesWithToolResults = builder.getMessages()
+
+ const followUpCompletion = await sendCompletion(
+ thread,
+ provider,
+ messagesWithToolResults,
+ abortController,
+ tools,
+ true,
+ {}
+ )
+
+ if (followUpCompletion) {
+ let followUpText = ''
+ const newToolCalls: ChatCompletionMessageToolCall[] = []
+ const textContent = message.content.find(
+ (c) => c.type === ContentType.Text
+ )
+
+ if (isCompletionResponse(followUpCompletion)) {
+ const choice = followUpCompletion.choices[0]
+ const content = choice?.message?.content
+ if (content) followUpText = content as string
+ if (choice?.message?.tool_calls) {
+ newToolCalls.push(...choice.message.tool_calls)
+ }
+ if (textContent?.text) textContent.text.value += followUpText
+ if (updateStreamingUI) updateStreamingUI({ ...message })
+ } else {
+ const reasoningProcessor = new ReasoningProcessor()
+ for await (const chunk of followUpCompletion) {
+ if (abortController.signal.aborted) break
+
+ const deltaReasoning =
+ reasoningProcessor.processReasoningChunk(chunk)
+ const deltaContent = chunk.choices[0]?.delta?.content || ''
+
+ if (textContent?.text) {
+ if (deltaReasoning) textContent.text.value += deltaReasoning
+ if (deltaContent) textContent.text.value += deltaContent
+ }
+ if (deltaContent) followUpText += deltaContent
+
+ if (chunk.choices[0]?.delta?.tool_calls) {
+ extractToolCall(chunk, null, newToolCalls)
+ }
+
+ if (updateStreamingUI) updateStreamingUI({ ...message })
+ }
+ if (textContent?.text) {
+ textContent.text.value += reasoningProcessor.finalize()
+ if (updateStreamingUI) updateStreamingUI({ ...message })
+ }
+ }
+
+ if (newToolCalls.length > 0) {
+ builder.addAssistantMessage(followUpText, undefined, newToolCalls)
+ await postMessageProcessing(
+ newToolCalls,
+ builder,
+ message,
+ abortController,
+ approvedTools,
+ showModal,
+ allowAllMCPPermissions,
+ thread,
+ provider,
+ tools,
+ updateStreamingUI,
+ maxToolSteps,
+ isProactiveMode
+ )
+ }
+ }
+ } catch (error) {
+ console.error(
+ 'Failed to get follow-up completion after tool execution:',
+ String(error)
+ )
+ }
+ }
}

+ // Reset counter when the chain is fully resolved
+ toolStepCounter = 0
+ return message
}