// jan/web-app/src/hooks/useChat.ts

import { useCallback, useMemo } from 'react'
import { flushSync } from 'react-dom'
import { usePrompt } from './usePrompt'
import { useModelProvider } from './useModelProvider'
import { useThreads } from './useThreads'
import { useAppState } from './useAppState'
import { useMessages } from './useMessages'
import { useRouter } from '@tanstack/react-router'
import { defaultModel } from '@/lib/models'
import { route } from '@/constants/routes'
import {
emptyThreadContent,
extractToolCall,
isCompletionResponse,
newAssistantThreadContent,
newUserThreadContent,
postMessageProcessing,
sendCompletion,
} from '@/lib/completion'
import { CompletionMessagesBuilder } from '@/lib/messages'
import { renderInstructions } from '@/lib/instructionTemplate'
import { ChatCompletionMessageToolCall } from 'openai/resources'
import { useServiceHub } from '@/hooks/useServiceHub'
import { useToolApproval } from '@/hooks/useToolApproval'
import { useToolAvailable } from '@/hooks/useToolAvailable'
import { OUT_OF_CONTEXT_SIZE } from '@/utils/error'
import { useContextSizeApproval } from './useModelContextApproval'
import { useModelLoad } from './useModelLoad'
import {
ReasoningProcessor,
extractReasoningFromMessage,
} from '@/utils/reasoning'
import { useAssistant } from './useAssistant'
import { useShallow } from 'zustand/shallow'
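/**
 * useChat orchestrates a chat turn end to end: it resolves or creates the
 * active thread, streams completion chunks into the UI, batches renders on
 * animation frames, and loops while the model keeps emitting tool calls.
 *
 * Illustrative usage (component wiring assumed, not part of this file):
 *   const sendMessage = useChat()
 *   await sendMessage('Summarize this file', true, attachments, projectId)
 */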
export const useChat = () => {
const [
updateTokenSpeed,
resetTokenSpeed,
updateStreamingContent,
updateLoadingModel,
setAbortController,
] = useAppState(
useShallow((state) => [
state.updateTokenSpeed,
state.resetTokenSpeed,
state.updateStreamingContent,
state.updateLoadingModel,
state.setAbortController,
])
)
const updatePromptProgress = useAppState(
(state) => state.updatePromptProgress
)
const updateProvider = useModelProvider((state) => state.updateProvider)
const serviceHub = useServiceHub()
const showApprovalModal = useToolApproval((state) => state.showApprovalModal)
const allowAllMCPPermissions = useToolApproval(
(state) => state.allowAllMCPPermissions
)
const showIncreaseContextSizeModal = useContextSizeApproval(
(state) => state.showApprovalModal
)
const getDisabledToolsForThread = useToolAvailable(
(state) => state.getDisabledToolsForThread
)
const getProviderByName = useModelProvider((state) => state.getProviderByName)
const [createThread, retrieveThread, updateThreadTimestamp] = useThreads(
useShallow((state) => [
state.createThread,
state.getCurrentThread,
state.updateThreadTimestamp,
])
)
const getMessages = useMessages((state) => state.getMessages)
const addMessage = useMessages((state) => state.addMessage)
const setModelLoadError = useModelLoad((state) => state.setModelLoadError)
const router = useRouter()
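// Resolves the current thread; when none exists, creates one from the current
// prompt/assistant/model selection (plus optional project metadata) and
// navigates to its detail route.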
const getCurrentThread = useCallback(async (projectId?: string) => {
let currentThread = retrieveThread()
if (!currentThread) {
// Get prompt directly from store when needed
const currentPrompt = usePrompt.getState().prompt
const currentAssistant = useAssistant.getState().currentAssistant
const assistants = useAssistant.getState().assistants
const selectedModel = useModelProvider.getState().selectedModel
const selectedProvider = useModelProvider.getState().selectedProvider
// Get project metadata if projectId is provided
let projectMetadata: { id: string; name: string; updated_at: number } | undefined
if (projectId) {
const project = await serviceHub.projects().getProjectById(projectId)
if (project) {
projectMetadata = {
id: project.id,
name: project.name,
updated_at: project.updated_at,
}
}
}
currentThread = await createThread(
{
id: selectedModel?.id ?? defaultModel(selectedProvider),
provider: selectedProvider,
},
currentPrompt,
assistants.find((a) => a.id === currentAssistant?.id) || assistants[0],
projectMetadata,
)
router.navigate({
to: route.threadsDetail,
params: { threadId: currentThread.id },
})
}
return currentThread
}, [createThread, retrieveThread, router, serviceHub])
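// Stops all running models, then restarts the given model. The one-second
// pauses give the backend time to release and re-acquire resources between
// the stop and the start.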
const restartModel = useCallback(
async (provider: ProviderObject, modelId: string) => {
await serviceHub.models().stopAllModels()
await new Promise((resolve) => setTimeout(resolve, 1000))
updateLoadingModel(true)
await serviceHub
.models()
.startModel(provider, modelId)
.catch(console.error)
updateLoadingModel(false)
await new Promise((resolve) => setTimeout(resolve, 1000))
},
[updateLoadingModel, serviceHub]
)
const increaseModelContextSize = useCallback(
async (modelId: string, provider: ProviderObject) => {
/**
 * Doubles the model's context size. If ctx_len is unset or below 16384,
 * it is floored at 16384 before doubling.
 */
const model = provider.models.find((m) => m.id === modelId)
if (!model) return undefined
const rawCtxLen = model.settings?.ctx_len?.controller_props.value
const ctxSize = Math.max(
rawCtxLen
? typeof rawCtxLen === 'string'
? parseInt(rawCtxLen, 10)
: (rawCtxLen as number)
: 16384,
16384
)
const updatedModel = {
...model,
settings: {
...model.settings,
ctx_len: {
...(model.settings?.ctx_len != null ? model.settings?.ctx_len : {}),
controller_props: {
...(model.settings?.ctx_len?.controller_props ?? {}),
value: ctxSize * 2,
},
},
},
}
// Find the model index in the provider's models array
const modelIndex = provider.models.findIndex((m) => m.id === model.id)
if (modelIndex !== -1) {
// Create a copy of the provider's models array
const updatedModels = [...provider.models]
// Update the specific model in the array
updatedModels[modelIndex] = updatedModel as Model
// Update the provider with the new models array
updateProvider(provider.provider, {
models: updatedModels,
})
}
const updatedProvider = getProviderByName(provider.provider)
if (updatedProvider) await restartModel(updatedProvider, model.id)
return updatedProvider
},
[getProviderByName, restartModel, updateProvider]
)
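// Turns on the provider's `ctx_shift` setting, persists it, and restarts the
// model so the setting takes effect. Returns the refreshed provider.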
const toggleOnContextShifting = useCallback(
async (modelId: string, provider: ProviderObject) => {
const providerName = provider.provider
const newSettings = [...provider.settings]
const settingKey = 'ctx_shift'
const settingIndex = provider.settings.findIndex(
(s) => s.key === settingKey
)
// Guard against a provider without a ctx_shift setting before mutating it;
// the assertion bypasses the setting value's union type
if (settingIndex !== -1) {
;(
newSettings[settingIndex].controller_props as {
value: string | boolean | number
}
).value = true
}
// Create update object with updated settings
const updateObj: Partial<ModelProvider> = {
settings: newSettings,
}
await serviceHub
.providers()
.updateSettings(providerName, updateObj.settings ?? [])
updateProvider(providerName, {
...provider,
...updateObj,
})
const updatedProvider = getProviderByName(providerName)
if (updatedProvider) await restartModel(updatedProvider, modelId)
return updatedProvider
},
[updateProvider, getProviderByName, restartModel, serviceHub]
)
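/**
 * Sends a user message (or retries the last one) on the active thread.
 *
 * @param message         the user prompt text
 * @param troubleshooting when false this is a retry: the message is not
 *                        re-added to the thread or the completion builder
 * @param attachments     optional base64-encoded file attachments
 * @param projectId       optional project to attach a newly created thread to
 *
 * Illustrative retry call (names assumed by this sketch):
 *   await sendMessage(lastUserMessage.text, false)
 */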
const sendMessage = useCallback(
async (
message: string,
troubleshooting = true,
attachments?: Array<{
name: string
type: string
size: number
base64: string
dataUrl: string
}>,
projectId?: string
) => {
const activeThread = await getCurrentThread(projectId)
const selectedProvider = useModelProvider.getState().selectedProvider
let activeProvider = getProviderByName(selectedProvider)
resetTokenSpeed()
if (!activeThread || !activeProvider) return
const messages = getMessages(activeThread.id)
const abortController = new AbortController()
setAbortController(activeThread.id, abortController)
updateStreamingContent(emptyThreadContent)
updatePromptProgress(undefined)
// Do not add new message on retry
if (troubleshooting)
addMessage(newUserThreadContent(activeThread.id, message, attachments))
updateThreadTimestamp(activeThread.id)
usePrompt.getState().setPrompt('')
const selectedModel = useModelProvider.getState().selectedModel
try {
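// Make sure the selected model is loaded before requesting a completion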
if (selectedModel?.id) {
updateLoadingModel(true)
await serviceHub.models().startModel(activeProvider, selectedModel.id)
updateLoadingModel(false)
}
const currentAssistant = useAssistant.getState().currentAssistant
const builder = new CompletionMessagesBuilder(
messages,
currentAssistant
? renderInstructions(currentAssistant.instructions)
: undefined
)
if (troubleshooting) builder.addUserMessage(message, attachments)
let isCompleted = false
// Filter tools based on model capabilities and available tools for this thread
let availableTools = selectedModel?.capabilities?.includes('tools')
? useAppState.getState().tools.filter((tool) => {
const disabledTools = getDisabledToolsForThread(activeThread.id)
return !disabledTools.includes(tool.name)
})
: []
let assistantLoopSteps = 0
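// Agent loop: keep requesting completions until the model stops emitting
// tool calls, the request is aborted, or the step limit below is reached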
while (
!isCompleted &&
!abortController.signal.aborted &&
activeProvider
) {
const modelConfig = activeProvider.models.find(
(m) => m.id === selectedModel?.id
)
assistantLoopSteps += 1
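// Flatten model settings into plain key/value pairs for the request,
// dropping load-time options (ctx_len, ngl) and empty values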
const modelSettings = modelConfig?.settings
? Object.fromEntries(
Object.entries(modelConfig.settings)
.filter(
([key, value]) =>
key !== 'ctx_len' &&
key !== 'ngl' &&
value.controller_props?.value !== undefined &&
value.controller_props?.value !== null &&
value.controller_props?.value !== ''
)
.map(([key, value]) => [key, value.controller_props?.value])
)
: undefined
const completion = await sendCompletion(
activeThread,
activeProvider,
builder.getMessages(),
abortController,
availableTools,
currentAssistant?.parameters?.stream !== false,
{
...modelSettings,
...(currentAssistant?.parameters || {}),
} as unknown as Record<string, object>
)
if (!completion) throw new Error('No completion received')
let accumulatedText = ''
const currentCall: ChatCompletionMessageToolCall | null = null
const toolCalls: ChatCompletionMessageToolCall[] = []
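// The completion is either a single response object (non-streaming) or an
// async iterable of chunks (streaming); both shapes are handled below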
try {
if (isCompletionResponse(completion)) {
const message = completion.choices[0]?.message
accumulatedText = (message?.content as string) || ''
// Handle reasoning field if there is one
const reasoning = extractReasoningFromMessage(message)
if (reasoning) {
accumulatedText =
`<think>${reasoning}</think>` + accumulatedText
}
if (message?.tool_calls) {
toolCalls.push(...message.tool_calls)
}
} else {
// High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
let rafScheduled = false
let rafHandle: number | undefined
let pendingDeltaCount = 0
const reasoningProcessor = new ReasoningProcessor()
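// scheduleFlush coalesces many stream deltas into at most one store
// update per animation frame (or per macrotask outside the browser)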
const scheduleFlush = () => {
if (rafScheduled || abortController.signal.aborted) return
rafScheduled = true
const doSchedule = (cb: () => void) => {
if (typeof requestAnimationFrame !== 'undefined') {
rafHandle = requestAnimationFrame(() => cb())
} else {
// Fallback for non-browser test environments
const t = setTimeout(() => cb(), 0) as unknown as number
rafHandle = t
}
}
doSchedule(() => {
// Check abort status before executing the scheduled callback
if (abortController.signal.aborted) {
rafScheduled = false
return
}
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
{
tool_calls: toolCalls.map((e) => ({
...e,
state: 'pending',
})),
}
)
updateStreamingContent(currentContent)
if (pendingDeltaCount > 0) {
updateTokenSpeed(currentContent, pendingDeltaCount)
}
pendingDeltaCount = 0
rafScheduled = false
})
}
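// flushIfPending cancels the scheduled frame and pushes whatever is
// buffered to the store immediately (used once the stream finishes)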
const flushIfPending = () => {
if (!rafScheduled) return
if (
typeof cancelAnimationFrame !== 'undefined' &&
rafHandle !== undefined
) {
cancelAnimationFrame(rafHandle)
} else if (rafHandle !== undefined) {
clearTimeout(rafHandle)
}
// Do an immediate flush
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
{
tool_calls: toolCalls.map((e) => ({
...e,
state: 'pending',
})),
}
)
updateStreamingContent(currentContent)
if (pendingDeltaCount > 0) {
updateTokenSpeed(currentContent, pendingDeltaCount)
}
pendingDeltaCount = 0
rafScheduled = false
}
try {
for await (const part of completion) {
// Check if aborted before processing each part
if (abortController.signal.aborted) {
break
}
// Handle prompt progress if available
if ('prompt_progress' in part && part.prompt_progress) {
// Force immediate state update to ensure we see intermediate values
flushSync(() => {
updatePromptProgress(part.prompt_progress)
})
// Add a small delay to make progress visible
await new Promise((resolve) => setTimeout(resolve, 100))
}
// Error message
if (!part.choices) {
throw new Error(
'message' in part
? (part.message as string)
: (JSON.stringify(part) ?? '')
)
}
if (part.choices[0]?.delta?.tool_calls) {
extractToolCall(part, currentCall, toolCalls)
// Schedule a flush to reflect tool update
scheduleFlush()
}
const deltaReasoning =
reasoningProcessor.processReasoningChunk(part)
if (deltaReasoning) {
accumulatedText += deltaReasoning
pendingDeltaCount += 1
// Schedule flush for reasoning updates
scheduleFlush()
}
const deltaContent = part.choices[0]?.delta?.content || ''
if (deltaContent) {
accumulatedText += deltaContent
pendingDeltaCount += 1
// Batch UI update on next animation frame
scheduleFlush()
}
}
} finally {
// Always clean up scheduled RAF when stream ends (either normally or via abort)
if (rafHandle !== undefined) {
if (typeof cancelAnimationFrame !== 'undefined') {
cancelAnimationFrame(rafHandle)
} else {
clearTimeout(rafHandle)
}
rafHandle = undefined
rafScheduled = false
}
// Only finalize and flush if not aborted
if (!abortController.signal.aborted) {
// Finalize reasoning (close any open think tags)
accumulatedText += reasoningProcessor.finalize()
// Ensure any pending buffered content is rendered at the end
flushIfPending()
}
}
}
} catch (error) {
const errorMessage =
error && typeof error === 'object' && 'message' in error
? error.message
: error
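// Out-of-context errors are recoverable: ask the user whether to grow
// ctx_len or enable context shifting, then retry this loop iteration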
if (
typeof errorMessage === 'string' &&
errorMessage.includes(OUT_OF_CONTEXT_SIZE) &&
selectedModel
) {
const method = await showIncreaseContextSizeModal()
if (method === 'ctx_len') {
// Increase context size
activeProvider = await increaseModelContextSize(
selectedModel.id,
activeProvider
)
continue
} else if (method === 'context_shift' && selectedModel?.id) {
// Enable context_shift
activeProvider = await toggleOnContextShifting(
selectedModel?.id,
activeProvider
)
continue
} else throw error
} else {
throw error
}
}
// TODO: Remove this check when integrating new llama.cpp extension
if (
accumulatedText.length === 0 &&
toolCalls.length === 0 &&
activeThread.model?.id &&
activeProvider?.provider === 'llamacpp'
) {
await serviceHub
.models()
.stopModel(activeThread.model.id, 'llamacpp')
throw new Error('No response received from the model')
}
// Create a final content object for adding to the thread
const finalContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
{
tokenSpeed: useAppState.getState().tokenSpeed,
assistant: currentAssistant,
}
)
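// Record the assistant turn in the builder so later loop iterations see it,
// then run tool calls through the approval/post-processing pipeline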
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
const updatedMessage = await postMessageProcessing(
toolCalls,
builder,
finalContent,
abortController,
useToolApproval.getState().approvedTools,
allowAllMCPPermissions ? undefined : showApprovalModal,
allowAllMCPPermissions
)
addMessage(updatedMessage ?? finalContent)
updateStreamingContent(emptyThreadContent)
updatePromptProgress(undefined)
updateThreadTimestamp(activeThread.id)
isCompleted = !toolCalls.length
// Do not create agent loop if there is no need for it
// Check if assistant loop steps are within limits
if (assistantLoopSteps >= (currentAssistant?.tool_steps ?? 20)) {
// Stop the assistant tool call if it exceeds the maximum steps
availableTools = []
}
}
} catch (error) {
if (!abortController.signal.aborted) {
if (error && typeof error === 'object' && 'message' in error) {
setModelLoadError(error as ErrorObject)
} else {
setModelLoadError(`${error}`)
}
}
} finally {
updateLoadingModel(false)
updateStreamingContent(undefined)
updatePromptProgress(undefined)
}
},
[
getCurrentThread,
resetTokenSpeed,
getProviderByName,
getMessages,
setAbortController,
updateStreamingContent,
updatePromptProgress,
addMessage,
updateThreadTimestamp,
updateLoadingModel,
getDisabledToolsForThread,
allowAllMCPPermissions,
showApprovalModal,
updateTokenSpeed,
showIncreaseContextSizeModal,
increaseModelContextSize,
toggleOnContextShifting,
setModelLoadError,
serviceHub,
]
)
// useCallback already memoizes sendMessage, so return it directly
return sendMessage
}