jan/web-app/src/hooks/useChat.ts
Louis 76bcf33f80
fix: generate response button disappears on tool call (#5988)
* fix: the "generate a response" button should appear when an incomplete tool call message is present

* fix: wording

* fix: do not send duplicate messages on regenerating

* fix: tests
2025-07-30 21:04:12 +07:00

473 lines
16 KiB
TypeScript

import { useCallback, useEffect, useMemo } from 'react'
import { usePrompt } from './usePrompt'
import { useModelProvider } from './useModelProvider'
import { useThreads } from './useThreads'
import { useAppState } from './useAppState'
import { useMessages } from './useMessages'
import { useRouter } from '@tanstack/react-router'
import { defaultModel } from '@/lib/models'
import { route } from '@/constants/routes'
import {
emptyThreadContent,
extractToolCall,
isCompletionResponse,
newAssistantThreadContent,
newUserThreadContent,
postMessageProcessing,
sendCompletion,
} from '@/lib/completion'
import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'
import { useAssistant } from './useAssistant'
import { getTools } from '@/services/mcp'
import { MCPTool } from '@/types/completion'
import { listen } from '@tauri-apps/api/event'
import { SystemEvent } from '@/types/events'
import { stopModel, startModel, stopAllModels } from '@/services/models'
import { useToolApproval } from '@/hooks/useToolApproval'
import { useToolAvailable } from '@/hooks/useToolAvailable'
import { OUT_OF_CONTEXT_SIZE } from '@/utils/error'
import { updateSettings } from '@/services/providers'
import { useContextSizeApproval } from './useModelContextApproval'
import { useModelLoad } from './useModelLoad'
import { useGeneralSetting } from './useGeneralSetting'
export const useChat = () => {
const { prompt, setPrompt } = usePrompt()
const { experimentalFeatures } = useGeneralSetting()
const {
tools,
updateTokenSpeed,
resetTokenSpeed,
updateTools,
updateStreamingContent,
updateLoadingModel,
setAbortController,
} = useAppState()
const { assistants, currentAssistant } = useAssistant()
const { updateProvider } = useModelProvider()
const { approvedTools, showApprovalModal, allowAllMCPPermissions } =
useToolApproval()
const { showApprovalModal: showIncreaseContextSizeModal } =
useContextSizeApproval()
const { getDisabledToolsForThread } = useToolAvailable()
const { getProviderByName, selectedModel, selectedProvider } =
useModelProvider()
const {
getCurrentThread: retrieveThread,
createThread,
updateThreadTimestamp,
} = useThreads()
const { getMessages, addMessage } = useMessages()
const { setModelLoadError } = useModelLoad()
const router = useRouter()
const provider = useMemo(() => {
return getProviderByName(selectedProvider)
}, [selectedProvider, getProviderByName])
const currentProviderId = useMemo(() => {
return provider?.provider || selectedProvider
}, [provider, selectedProvider])
const selectedAssistant =
assistants.find((a) => a.id === currentAssistant.id) || assistants[0]
useEffect(() => {
  // Fetch the current MCP tool list and push it into app state.
  function setTools() {
    getTools().then((data: MCPTool[]) => {
      updateTools(data)
    })
  }
  setTools()
  let disposed = false
  let unsubscribe = () => {}
  listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
    // If cleanup already ran before listener registration resolved,
    // detach immediately — otherwise the listener would leak, since the
    // cleanup below captured only the no-op unsubscribe.
    if (disposed) {
      unsub()
    } else {
      unsubscribe = unsub
    }
  })
  // Unsubscribe from the event when the component unmounts
  return () => {
    disposed = true
    unsubscribe()
  }
}, [updateTools])
const getCurrentThread = useCallback(async () => {
  // Reuse the active thread when one exists.
  const existing = retrieveThread()
  if (existing) return existing

  // Otherwise create a fresh thread for the selected model/provider
  // (falling back to the provider's default model) and navigate to it.
  const thread = await createThread(
    {
      id: selectedModel?.id ?? defaultModel(selectedProvider),
      provider: selectedProvider,
    },
    prompt,
    selectedAssistant
  )
  router.navigate({
    to: route.threadsDetail,
    params: { threadId: thread.id },
  })
  return thread
}, [
  createThread,
  prompt,
  retrieveThread,
  router,
  selectedModel?.id,
  selectedProvider,
  selectedAssistant,
])
const restartModel = useCallback(
  async (provider: ProviderObject, modelId: string) => {
    // Stop every running model, give the backend a moment to settle,
    // then bring the requested model back up (errors are logged, not thrown).
    const settle = (ms: number) =>
      new Promise((resolve) => setTimeout(resolve, ms))
    await stopAllModels()
    await settle(1000)
    updateLoadingModel(true)
    await startModel(provider, modelId).catch(console.error)
    updateLoadingModel(false)
    await settle(1000)
  },
  [updateLoadingModel]
)
const increaseModelContextSize = useCallback(
  async (modelId: string, provider: ProviderObject) => {
    /**
     * Doubles the model's context length and restarts it so the new
     * setting takes effect. The current ctx_len (or 16384 when unset or
     * lower) is the base, so the result is always at least 32768.
     * Returns the refreshed provider, or undefined when the model id is
     * not found.
     */
    const modelIndex = provider.models.findIndex((m) => m.id === modelId)
    if (modelIndex === -1) return undefined
    const model = provider.models[modelIndex]
    const ctxSize = Math.max(
      model.settings?.ctx_len?.controller_props.value
        ? typeof model.settings.ctx_len.controller_props.value === 'string'
          ? parseInt(model.settings.ctx_len.controller_props.value as string)
          : (model.settings.ctx_len.controller_props.value as number)
        : 16384,
      16384
    )
    const updatedModel = {
      ...model,
      settings: {
        ...model.settings,
        ctx_len: {
          ...(model.settings?.ctx_len != null ? model.settings?.ctx_len : {}),
          controller_props: {
            ...(model.settings?.ctx_len?.controller_props ?? {}),
            value: ctxSize * 2,
          },
        },
      },
    }
    // Swap the updated model into a copy of the provider's model list
    // (reuses the index found above instead of a second lookup).
    const updatedModels = [...provider.models]
    updatedModels[modelIndex] = updatedModel as Model
    updateProvider(provider.provider, {
      models: updatedModels,
    })
    // Re-read the provider so the restart uses the persisted settings.
    const updatedProvider = getProviderByName(provider.provider)
    if (updatedProvider) await restartModel(updatedProvider, model.id)
    return updatedProvider
  },
  [getProviderByName, restartModel, updateProvider]
)
const toggleOnContextShifting = useCallback(
  async (modelId: string, provider: ProviderObject) => {
    /**
     * Enables the provider's `ctx_shift` setting, persists it, and
     * restarts the model so the change takes effect. Returns the
     * refreshed provider.
     */
    const providerName = provider.provider
    const settingKey = 'ctx_shift'
    const settingIndex = provider.settings.findIndex(
      (s) => s.key === settingKey
    )
    // Guard: without a ctx_shift setting there is nothing to enable.
    // (Previously an unguarded newSettings[-1] access threw a TypeError.)
    if (settingIndex === -1) return getProviderByName(providerName)
    const newSettings = [...provider.settings]
    // Handle different value types by forcing the type
    // Use type assertion to bypass type checking
    ;(
      newSettings[settingIndex].controller_props as {
        value: string | boolean | number
      }
    ).value = true
    // Create update object with updated settings
    const updateObj: Partial<ModelProvider> = {
      settings: newSettings,
    }
    await updateSettings(providerName, updateObj.settings ?? [])
    updateProvider(providerName, {
      ...provider,
      ...updateObj,
    })
    const updatedProvider = getProviderByName(providerName)
    if (updatedProvider) await restartModel(updatedProvider, modelId)
    return updatedProvider
  },
  [updateProvider, getProviderByName, restartModel]
)
/**
 * Sends `message` on the current thread (creating the thread first when
 * needed) and drives the completion / tool-call loop until the model
 * produces a final answer or the request is aborted.
 *
 * @param message          the user's prompt text
 * @param troubleshooting  when false (regeneration/retry), the user
 *                         message is NOT re-added to the thread or the
 *                         completion builder, and the out-of-context
 *                         recovery modal is skipped
 */
const sendMessage = useCallback(
  async (message: string, troubleshooting = true) => {
    const activeThread = await getCurrentThread()
    resetTokenSpeed()
    // Resolve the provider fresh by name so later settings updates are
    // picked up; fall back to the memoized provider object.
    let activeProvider = currentProviderId
      ? getProviderByName(currentProviderId)
      : provider
    if (!activeThread || !activeProvider) return
    const messages = getMessages(activeThread.id)
    const abortController = new AbortController()
    setAbortController(activeThread.id, abortController)
    updateStreamingContent(emptyThreadContent)
    // Do not add a new user message on retry/regeneration
    if (troubleshooting)
      addMessage(newUserThreadContent(activeThread.id, message))
    updateThreadTimestamp(activeThread.id)
    setPrompt('')
    try {
      if (selectedModel?.id) {
        updateLoadingModel(true)
        await startModel(activeProvider, selectedModel.id)
        updateLoadingModel(false)
      }
      const builder = new CompletionMessagesBuilder(
        messages,
        currentAssistant?.instructions
      )
      if (troubleshooting) builder.addUserMessage(message)
      let isCompleted = false
      // Filter tools based on model capabilities and the tools enabled
      // for this thread (tool use is gated behind experimental features)
      let availableTools =
        experimentalFeatures && selectedModel?.capabilities?.includes('tools')
          ? tools.filter((tool) => {
              const disabledTools = getDisabledToolsForThread(activeThread.id)
              return !disabledTools.includes(tool.name)
            })
          : []
      // TODO: Later replaced by Agent setup?
      const followUpWithToolUse = true
      // Agent loop: keep requesting completions until the model stops
      // emitting tool calls, or the user aborts.
      while (
        !isCompleted &&
        !abortController.signal.aborted &&
        activeProvider
      ) {
        const modelConfig = activeProvider.models.find(
          (m) => m.id === selectedModel?.id
        )
        // Flatten controller_props values into plain completion params,
        // dropping ctx_len/ngl (load-time options) and empty values.
        const modelSettings = modelConfig?.settings
          ? Object.fromEntries(
              Object.entries(modelConfig.settings)
                .filter(
                  ([key, value]) =>
                    key !== 'ctx_len' &&
                    key !== 'ngl' &&
                    value.controller_props?.value !== undefined &&
                    value.controller_props?.value !== null &&
                    value.controller_props?.value !== ''
                )
                .map(([key, value]) => [key, value.controller_props?.value])
            )
          : undefined
        // Assistant parameters take precedence over model settings.
        const completion = await sendCompletion(
          activeThread,
          activeProvider,
          builder.getMessages(),
          abortController,
          availableTools,
          currentAssistant.parameters?.stream === false ? false : true,
          {
            ...modelSettings,
            ...currentAssistant.parameters,
          } as unknown as Record<string, object>
        )
        if (!completion) throw new Error('No completion received')
        let accumulatedText = ''
        const currentCall: ChatCompletionMessageToolCall | null = null
        const toolCalls: ChatCompletionMessageToolCall[] = []
        try {
          if (isCompletionResponse(completion)) {
            // Non-streaming response: read content/tool calls directly.
            accumulatedText =
              (completion.choices[0]?.message?.content as string) || ''
            if (completion.choices[0]?.message?.tool_calls) {
              toolCalls.push(...completion.choices[0].message.tool_calls)
            }
          } else {
            // Streaming response: accumulate text deltas and partial
            // tool-call chunks as they arrive.
            for await (const part of completion) {
              // A chunk without choices is an error payload.
              if (!part.choices) {
                throw new Error(
                  'message' in part
                    ? (part.message as string)
                    : (JSON.stringify(part) ?? '')
                )
              }
              const delta = part.choices[0]?.delta?.content || ''
              if (part.choices[0]?.delta?.tool_calls) {
                // Merge this chunk into the accumulated tool calls and
                // surface them in the UI as pending.
                const calls = extractToolCall(part, currentCall, toolCalls)
                const currentContent = newAssistantThreadContent(
                  activeThread.id,
                  accumulatedText,
                  {
                    tool_calls: calls.map((e) => ({
                      ...e,
                      state: 'pending',
                    })),
                  }
                )
                updateStreamingContent(currentContent)
                await new Promise((resolve) => setTimeout(resolve, 0))
              }
              if (delta) {
                accumulatedText += delta
                // Create a new object each time to avoid reference issues
                // Use a timeout to prevent React from batching updates too quickly
                const currentContent = newAssistantThreadContent(
                  activeThread.id,
                  accumulatedText,
                  {
                    tool_calls: toolCalls.map((e) => ({
                      ...e,
                      state: 'pending',
                    })),
                  }
                )
                updateStreamingContent(currentContent)
                updateTokenSpeed(currentContent)
                await new Promise((resolve) => setTimeout(resolve, 0))
              }
            }
          }
        } catch (error) {
          const errorMessage =
            error && typeof error === 'object' && 'message' in error
              ? error.message
              : error
          // Out-of-context-window recovery: ask the user whether to grow
          // ctx_len or enable context shifting, then retry this loop
          // iteration with the refreshed provider.
          if (
            typeof errorMessage === 'string' &&
            errorMessage.includes(OUT_OF_CONTEXT_SIZE) &&
            selectedModel &&
            troubleshooting
          ) {
            const method = await showIncreaseContextSizeModal()
            if (method === 'ctx_len') {
              /// Increase context size
              activeProvider = await increaseModelContextSize(
                selectedModel.id,
                activeProvider
              )
              continue
            } else if (method === 'context_shift' && selectedModel?.id) {
              /// Enable context_shift
              activeProvider = await toggleOnContextShifting(
                selectedModel?.id,
                activeProvider
              )
              continue
            } else throw error
          } else {
            throw error
          }
        }
        // TODO: Remove this check when integrating new llama.cpp extension
        if (
          accumulatedText.length === 0 &&
          toolCalls.length === 0 &&
          activeThread.model?.id &&
          provider?.provider === 'llamacpp'
        ) {
          await stopModel(activeThread.model.id, 'llamacpp')
          throw new Error('No response received from the model')
        }
        // Create a final content object for adding to the thread
        const finalContent = newAssistantThreadContent(
          activeThread.id,
          accumulatedText,
          {
            tokenSpeed: useAppState.getState().tokenSpeed,
            assistant: currentAssistant,
          }
        )
        builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
        // Execute any tool calls (with the approval flow unless all MCP
        // permissions are granted) and fold results into the message.
        const updatedMessage = await postMessageProcessing(
          toolCalls,
          builder,
          finalContent,
          abortController,
          approvedTools,
          allowAllMCPPermissions ? undefined : showApprovalModal,
          allowAllMCPPermissions
        )
        addMessage(updatedMessage ?? finalContent)
        updateStreamingContent(emptyThreadContent)
        updateThreadTimestamp(activeThread.id)
        // Loop again only when the model requested tool calls.
        isCompleted = !toolCalls.length
        // Do not create agent loop if there is no need for it
        if (!followUpWithToolUse) availableTools = []
      }
    } catch (error) {
      // Aborts are user-initiated; only surface real failures.
      if (!abortController.signal.aborted) {
        const errorMessage =
          error && typeof error === 'object' && 'message' in error
            ? error.message
            : error
        setModelLoadError(`${errorMessage}`)
      }
    } finally {
      updateLoadingModel(false)
      updateStreamingContent(undefined)
    }
  },
  [
    getCurrentThread,
    resetTokenSpeed,
    currentProviderId,
    getProviderByName,
    provider,
    getMessages,
    setAbortController,
    updateStreamingContent,
    addMessage,
    updateThreadTimestamp,
    setPrompt,
    selectedModel,
    currentAssistant,
    tools,
    updateLoadingModel,
    getDisabledToolsForThread,
    approvedTools,
    allowAllMCPPermissions,
    showApprovalModal,
    updateTokenSpeed,
    showIncreaseContextSizeModal,
    increaseModelContextSize,
    toggleOnContextShifting,
    setModelLoadError,
  ]
)
return { sendMessage }
}