jan/web-app/src/hooks/useChat.ts
Louis 035cc0f79c
Sync Release/v0.6.0 into dev (#5293)
* chore: enable shortcut zoom (#5261)

* chore: enable shortcut zoom

* chore: update shortcut setting

* fix: thinking block (#5263)

* Merge pull request #5262 from menloresearch/chore/sync-new-hub-data

chore: sync new hub data

* enhancement: model run improvement (#5268)

* fix: mcp tool error handling

* fix: error message

* fix: trigger download from recommended model

* fix: can't scroll hub

* fix: show progress

* enhancement: prompt users to increase context size

* enhancement: rearrange action buttons for a better UX

* 🔧chore: clean up logic

---------

Co-authored-by: Faisal Amir <urmauur@gmail.com>

* fix: download glitch in onboarding (#5269)

* enhancement: Model sources should not be hard-coded in the frontend (#5270)

* 🐛fix: default onboarding model should use recommended quantizations (#5273)

* 🐛fix: default onboarding model should use recommended quantizations

* enhancement: show context shift option in provider settings

* 🔧chore: wording

* 🔧 config: add to gitignore

* 🐛fix: Jan-nano repo name changed (#5274)

* 🚧 wip: disable showSpeedToken in ChatInput

* 🐛 fix: commented out the wrong import

* fix: masking value MCP env field (#5276)

* feat: add token speed to each message that persists

* ♻️ refactor: to follow prettier convention

* 🐛 fix: exclude deleted field

* 🧹 clean: all the missed console.log

* enhancement: out of context troubleshooting (#5275)

* enhancement: out of context troubleshooting

* 🔧refactor: clean up

* enhancement: add setting chat width container (#5289)

* enhancement: add setting conversation width

* enhancement: clean up logs and improve accessibility

* enhancement: move beta version const

* 🐛fix: optional additional_information gpu (#5291)

* 🐛fix: showing release notes for beta and prod (#5292)

* 🐛fix: showing release notes for beta and prod

* ♻️refactor: make a utils env

* ♻️refactor: hide MCP for production

* ♻️refactor: simplify the boolean expression for fetching release notes

---------

Co-authored-by: Faisal Amir <urmauur@gmail.com>
Co-authored-by: LazyYuuki <huy2840@gmail.com>
Co-authored-by: Bui Quang Huy <34532913+LazyYuuki@users.noreply.github.com>
2025-06-16 17:27:42 +07:00

470 lines
15 KiB
TypeScript

import { useCallback, useEffect, useMemo } from 'react'
import { usePrompt } from './usePrompt'
import { useModelProvider } from './useModelProvider'
import { useThreads } from './useThreads'
import { useAppState } from './useAppState'
import { useMessages } from './useMessages'
import { useRouter } from '@tanstack/react-router'
import { defaultModel } from '@/lib/models'
import { route } from '@/constants/routes'
import {
  emptyThreadContent,
  extractToolCall,
  isCompletionResponse,
  newAssistantThreadContent,
  newUserThreadContent,
  postMessageProcessing,
  sendCompletion,
} from '@/lib/completion'
import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'
import { useAssistant } from './useAssistant'
import { toast } from 'sonner'
import { getTools } from '@/services/mcp'
import { MCPTool } from '@/types/completion'
import { listen } from '@tauri-apps/api/event'
import { SystemEvent } from '@/types/events'
import { stopModel, startModel, stopAllModels } from '@/services/models'
import { useToolApproval } from '@/hooks/useToolApproval'
import { useToolAvailable } from '@/hooks/useToolAvailable'
import { OUT_OF_CONTEXT_SIZE } from '@/utils/error'
import { updateSettings } from '@/services/providers'
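/**
 * Chat orchestration hook: builds completion requests, streams assistant
 * responses, handles MCP tool calls, and recovers from out-of-context errors.
 *
 * A minimal usage sketch (the calling component is illustrative, not part of
 * this file):
 *
 *   const { sendMessage } = useChat()
 *   await sendMessage('Hello!')
 */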
export const useChat = () => {
  const { prompt, setPrompt } = usePrompt()
  const {
    tools,
    updateTokenSpeed,
    resetTokenSpeed,
    updateTools,
    updateStreamingContent,
    updateLoadingModel,
    setAbortController,
  } = useAppState()
  const { currentAssistant } = useAssistant()
  const { updateProvider } = useModelProvider()
  const { approvedTools, showApprovalModal, allowAllMCPPermissions } =
    useToolApproval()
  const { getDisabledToolsForThread } = useToolAvailable()
  const { getProviderByName, selectedModel, selectedProvider } =
    useModelProvider()
  const {
    getCurrentThread: retrieveThread,
    createThread,
    updateThreadTimestamp,
  } = useThreads()
  const { getMessages, addMessage } = useMessages()
  const router = useRouter()
  const provider = useMemo(() => {
    return getProviderByName(selectedProvider)
  }, [selectedProvider, getProviderByName])
  const currentProviderId = useMemo(() => {
    return provider?.provider || selectedProvider
  }, [provider, selectedProvider])
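  // Keep the MCP tool list in sync: fetch once on mount, then refresh
  // whenever the backend emits an MCP_UPDATE event.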
  useEffect(() => {
    function setTools() {
      getTools().then((data: MCPTool[]) => {
        updateTools(data)
      })
    }
    setTools()
    let unsubscribe = () => {}
    listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
      // Store the unlisten function so cleanup can unsubscribe on unmount
      unsubscribe = unsub
    })
    // Wrap in a closure so the reassigned `unsubscribe` is the one invoked;
    // returning the variable directly would always return the initial no-op,
    // since `listen` resolves after the effect body has returned.
    return () => unsubscribe()
  }, [updateTools])
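  // Resolve the active thread, creating one (and navigating to its detail
  // route) when the user sends a message outside an existing thread.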
  const getCurrentThread = useCallback(async () => {
    let currentThread = retrieveThread()
    if (!currentThread) {
      currentThread = await createThread(
        {
          id: selectedModel?.id ?? defaultModel(selectedProvider),
          provider: selectedProvider,
        },
        prompt,
        currentAssistant
      )
      router.navigate({
        to: route.threadsDetail,
        params: { threadId: currentThread.id },
      })
    }
    return currentThread
  }, [
    createThread,
    prompt,
    retrieveThread,
    router,
    selectedModel?.id,
    selectedProvider,
    currentAssistant,
  ])
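  // Restart a model after a settings change. The fixed 1s waits are
  // presumably there to let the backend settle between stop and start.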
  const restartModel = useCallback(
    async (
      provider: ProviderObject,
      modelId: string,
      abortController: AbortController
    ) => {
      await stopAllModels()
      await new Promise((resolve) => setTimeout(resolve, 1000))
      updateLoadingModel(true)
      await startModel(provider, modelId, abortController).catch(console.error)
      updateLoadingModel(false)
      await new Promise((resolve) => setTimeout(resolve, 1000))
    },
    [updateLoadingModel]
  )
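  // Double the model's context window, persist the change on the provider,
  // and restart the model so the new ctx_len takes effect.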
  const increaseModelContextSize = useCallback(
    async (
      modelId: string,
      provider: ProviderObject,
      controller: AbortController
    ) => {
      /**
       * Doubles the model's context size.
       * If the context size is unset or below 16384, it is treated as 16384
       * before doubling.
       */
      const model = provider.models.find((m) => m.id === modelId)
      if (!model) return undefined
      const ctxSize = Math.max(
        model.settings?.ctx_len?.controller_props.value
          ? typeof model.settings.ctx_len.controller_props.value === 'string'
            ? parseInt(model.settings.ctx_len.controller_props.value as string)
            : (model.settings.ctx_len.controller_props.value as number)
          : 16384,
        16384
      )
      const updatedModel = {
        ...model,
        settings: {
          ...model.settings,
          ctx_len: {
            ...(model.settings?.ctx_len != null ? model.settings?.ctx_len : {}),
            controller_props: {
              ...(model.settings?.ctx_len?.controller_props ?? {}),
              value: ctxSize * 2,
            },
          },
        },
      }
      // Find the model index in the provider's models array
      const modelIndex = provider.models.findIndex((m) => m.id === model.id)
      if (modelIndex !== -1) {
        // Create a copy of the provider's models array
        const updatedModels = [...provider.models]
        // Update the specific model in the array
        updatedModels[modelIndex] = updatedModel as Model
        // Update the provider with the new models array
        updateProvider(provider.provider, {
          models: updatedModels,
        })
      }
      const updatedProvider = getProviderByName(provider.provider)
      if (updatedProvider)
        await restartModel(updatedProvider, model.id, controller)
      return updatedProvider
    },
    [getProviderByName, restartModel, updateProvider]
  )
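  // Turn on the provider-level `context_shift` setting so the runtime can
  // drop the oldest tokens instead of erroring once the context window fills.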
  const toggleOnContextShifting = useCallback(
    async (
      modelId: string,
      provider: ProviderObject,
      controller: AbortController
    ) => {
      const providerName = provider.provider
      const newSettings = [...provider.settings]
      const settingKey = 'context_shift'
      const settingIndex = provider.settings.findIndex(
        (s) => s.key === settingKey
      )
      // Guard against providers that do not expose a context_shift setting
      if (settingIndex === -1) return undefined
      // Handle different value types by forcing the type
      // Use type assertion to bypass type checking
      ;(
        newSettings[settingIndex].controller_props as {
          value: string | boolean | number
        }
      ).value = true
      // Create update object with updated settings
      const updateObj: Partial<ModelProvider> = {
        settings: newSettings,
      }
      await updateSettings(providerName, updateObj.settings ?? [])
      updateProvider(providerName, {
        ...provider,
        ...updateObj,
      })
      const updatedProvider = getProviderByName(providerName)
      if (updatedProvider)
        await restartModel(updatedProvider, modelId, controller)
      return updatedProvider
    },
    [updateProvider, getProviderByName, restartModel]
  )
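  /**
   * Send a user message on the current thread and stream the reply.
   * `showModal` is invoked on out-of-context errors so the user can pick a
   * recovery strategy ('ctx_len' or 'context_shift'); `troubleshooting` is
   * false on retries so the user message is not appended twice.
   */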
  const sendMessage = useCallback(
    async (
      message: string,
      showModal?: () => Promise<unknown>,
      troubleshooting = true
    ) => {
      const activeThread = await getCurrentThread()
      resetTokenSpeed()
      let activeProvider = currentProviderId
        ? getProviderByName(currentProviderId)
        : provider
      if (!activeThread || !activeProvider) return
      const messages = getMessages(activeThread.id)
      const abortController = new AbortController()
      setAbortController(activeThread.id, abortController)
      updateStreamingContent(emptyThreadContent)
      // Do not add new message on retry
      if (troubleshooting)
        addMessage(newUserThreadContent(activeThread.id, message))
      updateThreadTimestamp(activeThread.id)
      setPrompt('')
      try {
        if (selectedModel?.id) {
          updateLoadingModel(true)
          await startModel(
            activeProvider,
            selectedModel.id,
            abortController
          ).catch(console.error)
          updateLoadingModel(false)
        }
        const builder = new CompletionMessagesBuilder(
          messages,
          currentAssistant?.instructions
        )
        builder.addUserMessage(message)
        let isCompleted = false
        // Filter tools based on model capabilities and available tools for this thread
        let availableTools = selectedModel?.capabilities?.includes('tools')
          ? tools.filter((tool) => {
              const disabledTools = getDisabledToolsForThread(activeThread.id)
              return !disabledTools.includes(tool.name)
            })
          : []
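        // Agent loop: keep requesting completions until the model stops
        // emitting tool calls (or the request is aborted). Tool execution and
        // result handling happen in postMessageProcessing below.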
        // TODO: Later replaced by Agent setup?
        const followUpWithToolUse = true
        while (
          !isCompleted &&
          !abortController.signal.aborted &&
          activeProvider
        ) {
          const completion = await sendCompletion(
            activeThread,
            activeProvider,
            builder.getMessages(),
            abortController,
            availableTools,
            currentAssistant?.parameters?.stream === false ? false : true,
            currentAssistant?.parameters as unknown as Record<string, object>
            // TODO: replace it with according provider setting later on
            // selectedProvider === 'llama.cpp' && availableTools.length > 0
            //   ? false
            //   : true
          )
          if (!completion) throw new Error('No completion received')
          let accumulatedText = ''
          const currentCall: ChatCompletionMessageToolCall | null = null
          const toolCalls: ChatCompletionMessageToolCall[] = []
          try {
            if (isCompletionResponse(completion)) {
              accumulatedText = completion.choices[0]?.message?.content || ''
              if (completion.choices[0]?.message?.tool_calls) {
                toolCalls.push(...completion.choices[0].message.tool_calls)
              }
            } else {
              for await (const part of completion) {
                // Error message
                if (!part.choices) {
                  throw new Error(
                    'message' in part
                      ? (part.message as string)
                      : (JSON.stringify(part) ?? '')
                  )
                }
                const delta = part.choices[0]?.delta?.content || ''
                if (part.choices[0]?.delta?.tool_calls) {
                  const calls = extractToolCall(part, currentCall, toolCalls)
                  const currentContent = newAssistantThreadContent(
                    activeThread.id,
                    accumulatedText,
                    {
                      tool_calls: calls.map((e) => ({
                        ...e,
                        state: 'pending',
                      })),
                    }
                  )
                  updateStreamingContent(currentContent)
                  await new Promise((resolve) => setTimeout(resolve, 0))
                }
                if (delta) {
                  accumulatedText += delta
                  // Create a new object each time to avoid reference issues
                  // Use a timeout to prevent React from batching updates too quickly
                  const currentContent = newAssistantThreadContent(
                    activeThread.id,
                    accumulatedText,
                    {
                      tool_calls: toolCalls.map((e) => ({
                        ...e,
                        state: 'pending',
                      })),
                    }
                  )
                  updateStreamingContent(currentContent)
                  updateTokenSpeed(currentContent)
                  await new Promise((resolve) => setTimeout(resolve, 0))
                }
              }
            }
          } catch (error) {
            const errorMessage =
              error && typeof error === 'object' && 'message' in error
                ? error.message
                : error
            if (
              typeof errorMessage === 'string' &&
              errorMessage.includes(OUT_OF_CONTEXT_SIZE) &&
              selectedModel &&
              troubleshooting
            ) {
              const method = await showModal?.()
              if (method === 'ctx_len') {
                // Increase context size
                activeProvider = await increaseModelContextSize(
                  selectedModel.id,
                  activeProvider,
                  abortController
                )
                continue
              } else if (method === 'context_shift' && selectedModel?.id) {
                // Enable context_shift
                activeProvider = await toggleOnContextShifting(
                  selectedModel?.id,
                  activeProvider,
                  abortController
                )
                continue
              } else throw error
            } else {
              throw error
            }
          }
          // TODO: Remove this check when integrating new llama.cpp extension
          if (
            accumulatedText.length === 0 &&
            toolCalls.length === 0 &&
            activeThread.model?.id &&
            activeProvider.provider === 'llama.cpp'
          ) {
            await stopModel(activeThread.model.id, 'cortex')
            throw new Error('No response received from the model')
          }
          // Create a final content object for adding to the thread
          const finalContent = newAssistantThreadContent(
            activeThread.id,
            accumulatedText,
            {
              tokenSpeed: useAppState.getState().tokenSpeed,
            }
          )
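          // Execute any pending tool calls (prompting for approval unless the
          // user allowed all MCP permissions) and fold the results into the
          // message before persisting it.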
          builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
          const updatedMessage = await postMessageProcessing(
            toolCalls,
            builder,
            finalContent,
            abortController,
            approvedTools,
            allowAllMCPPermissions ? undefined : showApprovalModal,
            allowAllMCPPermissions
          )
          addMessage(updatedMessage ?? finalContent)
          updateStreamingContent(emptyThreadContent)
          updateThreadTimestamp(activeThread.id)
          isCompleted = !toolCalls.length
          // Do not create agent loop if there is no need for it
          if (!followUpWithToolUse) availableTools = []
        }
      } catch (error) {
        const errorMessage =
          error && typeof error === 'object' && 'message' in error
            ? error.message
            : error
        toast.error(`Error sending message: ${errorMessage}`)
        console.error('Error sending message:', error)
      } finally {
        updateLoadingModel(false)
        updateStreamingContent(undefined)
      }
    },
    [
      getCurrentThread,
      resetTokenSpeed,
      currentProviderId,
      getProviderByName,
      provider,
      getMessages,
      setAbortController,
      updateStreamingContent,
      addMessage,
      updateThreadTimestamp,
      setPrompt,
      selectedModel,
      currentAssistant?.instructions,
      currentAssistant?.parameters,
      tools,
      updateLoadingModel,
      getDisabledToolsForThread,
      approvedTools,
      allowAllMCPPermissions,
      showApprovalModal,
      updateTokenSpeed,
      increaseModelContextSize,
      toggleOnContextShifting,
    ]
  )
  return { sendMessage }
}