jan/web/hooks/useSendChatMessage.ts
Thien Tran dc23cc2716
Use token.js for non-tools calls (#4973)
* deprecate inference()

* fix tool_choice. only startModel for Cortex

* appease linter

* remove sse

* add stopInferencing support. temporarily with OpenAI

* use abortSignal in token.js

* bump token.js version
2025-05-15 17:11:19 +07:00

import 'openai/shims/web'
import { useEffect, useRef } from 'react'
import {
MessageRequestType,
ExtensionTypeEnum,
Thread,
ThreadMessage,
Model,
ConversationalExtension,
ThreadAssistantInfo,
events,
MessageEvent,
EngineManager,
InferenceEngine,
MessageStatus,
} from '@janhq/core'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import {
ChatCompletionMessageParam,
ChatCompletionTool,
ChatCompletionMessageToolCall,
} from 'openai/resources/chat'
import {
CompletionResponse,
StreamCompletionResponse,
TokenJS,
} from 'token.js'
import { ulid } from 'ulidx'
import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
import {
currentPromptAtom,
editPromptAtom,
fileUploadAtom,
} from '@/containers/Providers/Jotai'
import { compressImage, getBase64 } from '@/utils/base64'
import {
createMessage,
createMessageContent,
emptyMessageContent,
} from '@/utils/createMessage'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
import { useActiveModel } from './useActiveModel'
import {
convertBuiltInEngine,
extendBuiltInEngineModels,
useGetEngines,
} from './useEngineManagement'
import { extensionManager } from '@/extension/ExtensionManager'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import {
addNewMessageAtom,
deleteMessageAtom,
getCurrentChatMessagesAtom,
tokenSpeedAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
activeThreadAtom,
approvedThreadToolsAtom,
disabledThreadToolsAtom,
engineParamsUpdateAtom,
getActiveThreadModelParamsAtom,
isGeneratingResponseAtom,
updateThreadAtom,
updateThreadWaitingForResponseAtom,
} from '@/helpers/atoms/Thread.atom'
import { ModelTool } from '@/types/model'
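// Signals that the model should be reloaded because engine parameters were updated for the active thread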
export const reloadModelAtom = atom(false)
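/**
 * Hook that exposes sendChatMessage and resendChatMessage for the active thread.
 * It builds the request (including enabled tools), sends it through token.js,
 * and emits MessageEvent updates as responses and tool results come back.
 * The optional showModal callback is used to ask for approval before a tool call runs.
 */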
export default function useSendChatMessage(
showModal?: (toolName: string, threadId: string) => Promise<unknown>
) {
const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const addNewMessage = useSetAtom(addNewMessageAtom)
const updateThread = useSetAtom(updateThreadAtom)
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const setCurrentPrompt = useSetAtom(currentPromptAtom)
const deleteMessage = useSetAtom(deleteMessageAtom)
const setEditPrompt = useSetAtom(editPromptAtom)
const approvedTools = useAtomValue(approvedThreadToolsAtom)
const disabledTools = useAtomValue(disabledThreadToolsAtom)
const currentMessages = useAtomValue(getCurrentChatMessagesAtom)
const selectedModel = useAtomValue(selectedModelAtom)
const { activeModel, startModel } = useActiveModel()
const modelRef = useRef<Model | undefined>()
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const engineParamsUpdate = useAtomValue(engineParamsUpdateAtom)
const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom)
const setReloadModel = useSetAtom(reloadModelAtom)
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const activeThreadRef = useRef<Thread | undefined>()
const activeAssistantRef = useRef<ThreadAssistantInfo | undefined>()
const setTokenSpeed = useSetAtom(tokenSpeedAtom)
const setModelDropdownState = useSetAtom(modelDropdownStateAtom)
const selectedModelRef = useRef<Model | undefined>()
const { engines } = useGetEngines()
useEffect(() => {
modelRef.current = activeModel
}, [activeModel])
useEffect(() => {
activeThreadRef.current = activeThread
}, [activeThread])
useEffect(() => {
selectedModelRef.current = selectedModel
}, [selectedModel])
useEffect(() => {
activeAssistantRef.current = activeAssistant
}, [activeAssistant])
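/**
 * Regenerates the latest response: trailing non-user messages are deleted from
 * storage and state, then the most recent user message is sent again.
 */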
const resendChatMessage = async () => {
// Remove trailing responses until the last user message is found, then resend it
const newConvoData = Array.from(currentMessages)
let toSendMessage = newConvoData.pop()
while (toSendMessage && toSendMessage?.role !== 'user') {
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.deleteMessage(toSendMessage.thread_id, toSendMessage.id)
.catch(console.error)
deleteMessage(toSendMessage.id ?? '')
toSendMessage = newConvoData.pop()
}
if (toSendMessage?.content[0]?.text?.value)
sendChatMessage(toSendMessage.content[0].text.value, true, newConvoData)
}
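/**
 * Sends a chat message for the active thread.
 * @param message the user prompt to send
 * @param isResend true when called from resendChatMessage (the user message is not persisted again)
 * @param messages optional message history used to build the request instead of the current thread messages
 */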
const sendChatMessage = async (
message: string,
isResend: boolean = false,
messages?: ThreadMessage[]
) => {
if (!message || message.trim().length === 0) return
const activeThread = activeThreadRef.current
const activeAssistant = activeAssistantRef.current
const currentModel = selectedModelRef.current
if (!activeThread || !activeAssistant) {
console.error('No active thread or assistant')
return
}
if (!currentModel?.id) {
setModelDropdownState(true)
return
}
if (engineParamsUpdate) setReloadModel(true)
setTokenSpeed(undefined)
const runtimeParams = extractInferenceParams(activeModelParams)
const settingParams = extractModelLoadParams(activeModelParams)
const prompt = message.trim()
updateThreadWaiting(activeThread.id, true)
setCurrentPrompt('')
setEditPrompt('')
try {
let base64Blob = fileUpload ? await getBase64(fileUpload.file) : undefined
if (base64Blob && fileUpload?.type === 'image') {
// Compress image
base64Blob = await compressImage(base64Blob, 512)
}
const modelRequest = selectedModel ?? activeAssistant.model
// Fallback support for previous broken threads
if (activeAssistant.model?.id === '*') {
activeAssistant.model = {
id: currentModel.id,
settings: currentModel.settings,
parameters: currentModel.parameters,
}
}
if (runtimeParams.stream == null) {
runtimeParams.stream = true
}
// Build Message Request
// TODO: detect if model supports tools
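// Fetch available tools, drop the ones disabled for this thread, and map them to the OpenAI function-tool shape (descriptions capped at 1024 characters)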
const tools = (await window.core.api.getTools())
?.filter((tool: ModelTool) => !disabledTools.includes(tool.name))
.map((tool: ModelTool) => ({
type: 'function' as const,
function: {
name: tool.name,
description: tool.description?.slice(0, 1024),
parameters: tool.inputSchema,
strict: false,
},
}))
const requestBuilder = new MessageRequestBuilder(
MessageRequestType.Thread,
{
...modelRequest,
settings: settingParams,
parameters: runtimeParams,
},
activeThread,
messages ?? currentMessages,
(tools && tools.length) ? tools : undefined,
).addSystemMessage(activeAssistant.instructions)
requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
// Build Thread Message to persist
const threadMessageBuilder = new ThreadMessageBuilder(
requestBuilder
).pushMessage(prompt, base64Blob, fileUpload)
const newMessage = threadMessageBuilder.build()
// Update thread state
const updatedThread: Thread = {
...activeThread,
updated: newMessage.created_at,
metadata: {
...activeThread.metadata,
lastMessage: prompt,
},
}
updateThread(updatedThread)
if (
!isResend &&
(newMessage.content.length || newMessage.attachments?.length)
) {
// Add message
const createdMessage = await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.createMessage(newMessage)
.catch(() => undefined)
if (!createdMessage) return
// Push to states
addNewMessage(createdMessage)
}
// Start Model if not started
const isCortex = modelRequest.engine === InferenceEngine.cortex ||
modelRequest.engine === InferenceEngine.cortex_llamacpp
const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id
if (base64Blob) {
setFileUpload(undefined)
}
if (modelRef.current?.id !== modelId && modelId && isCortex) {
const error = await startModel(modelId).catch((error: Error) => error)
if (error) {
updateThreadWaiting(activeThread.id, false)
return
}
}
setIsGeneratingResponse(true)
let isDone = false
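// Resolve the engine configured for this model: with a provider API key, call the remote provider directly through token.js; otherwise route through the local Jan API server using the app token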
const engine =
engines?.[requestBuilder.model.engine as InferenceEngine]?.[0]
const apiKey = engine?.api_key
const provider = convertBuiltInEngine(engine?.engine)
const tokenJS = new TokenJS({
apiKey: apiKey ?? (await window.core.api.appToken()),
baseURL: apiKey ? undefined : `${API_BASE_URL}/v1`,
})
extendBuiltInEngineModels(tokenJS, provider, modelId)
// llama.cpp currently does not support streaming when tools are used.
const useStream = (requestBuilder.tools && isCortex) ?
false :
modelRequest.parameters?.stream
let parentMessageId: string | undefined
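// Keep prompting the model until a response arrives with no tool calls; each round may execute tools and feed their results back into the request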
while (!isDone) {
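// The first message in this loop becomes the parent; follow-up messages generated after tool calls reference it via metadata.parent_id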
let messageId = ulid()
if (!parentMessageId) {
parentMessageId = ulid()
messageId = parentMessageId
}
const data = requestBuilder.build()
const message: ThreadMessage = createMessage({
id: messageId,
thread_id: activeThread.id,
assistant_id: activeAssistant.assistant_id,
metadata: {
...(messageId !== parentMessageId
? { parent_id: parentMessageId }
: {}),
},
})
events.emit(MessageEvent.OnMessageResponse, message)
// token.js types streaming and non-streaming calls differently, so the two cases are kept as explicit branches to appease the linter
const controller = new AbortController()
EngineManager.instance().controller = controller
if (useStream) {
const response = await tokenJS.chat.completions.create(
{
stream: true,
provider,
messages: requestBuilder.messages as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
tool_choice: data.tools ? 'auto' : undefined,
},
{
signal: controller.signal,
}
)
// Seed empty content so the message renders while chunks stream in
if (!message.content.length) {
message.content = emptyMessageContent
}
isDone = await processStreamingResponse(
response,
requestBuilder,
message
)
} else {
const response = await tokenJS.chat.completions.create(
{
stream: false,
provider,
messages: requestBuilder.messages as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
tool_choice: data.tools ? 'auto' : undefined,
},
{
signal: controller.signal,
}
)
// Seed empty content before the full response is processed
if (!message.content.length) {
message.content = emptyMessageContent
}
isDone = await processNonStreamingResponse(
response,
requestBuilder,
message
)
}
message.status = MessageStatus.Ready
events.emit(MessageEvent.OnMessageUpdate, message)
}
} catch (error) {
setIsGeneratingResponse(false)
updateThreadWaiting(activeThread.id, false)
const errorMessage: ThreadMessage = createMessage({
thread_id: activeThread.id,
assistant_id: activeAssistant.assistant_id,
content: createMessageContent(
typeof error === 'object' && error && 'message' in error
? (error as { message: string }).message
: JSON.stringify(error)
),
})
events.emit(MessageEvent.OnMessageResponse, errorMessage)
errorMessage.status = MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, errorMessage)
}
// Reset states
setReloadModel(false)
setEngineParamsUpdate(false)
}
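/**
 * Handles a non-streaming completion: publishes the assistant content, runs any
 * tool calls via postMessageProcessing, and returns true when no tool calls
 * remain (i.e. the loop in sendChatMessage can stop).
 */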
const processNonStreamingResponse = async (
response: CompletionResponse,
requestBuilder: MessageRequestBuilder,
message: ThreadMessage
): Promise<boolean> => {
// Handle tool calls in the response
const toolCalls: ChatCompletionMessageToolCall[] =
response.choices[0]?.message?.tool_calls ?? []
const content = response.choices[0]?.message?.content
message.content = createMessageContent(content ?? '')
events.emit(MessageEvent.OnMessageUpdate, message)
await postMessageProcessing(
toolCalls ?? [],
requestBuilder,
message,
content ?? ''
)
return !toolCalls || !toolCalls.length
}
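/**
 * Consumes a streaming completion chunk by chunk, accumulating text content and
 * partial tool calls, then runs postMessageProcessing. Returns true when no
 * tool calls remain.
 */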
const processStreamingResponse = async (
response: StreamCompletionResponse,
requestBuilder: MessageRequestBuilder,
message: ThreadMessage
): Promise<boolean> => {
// Variables to track and accumulate streaming content
let currentToolCall: {
id: string
function: { name: string; arguments: string }
} | null = null
let accumulatedContent = ''
const toolCalls: ChatCompletionMessageToolCall[] = []
// Process the streaming chunks
for await (const chunk of response) {
// Handle tool calls in the chunk
if (chunk.choices[0]?.delta?.tool_calls) {
const deltaToolCalls = chunk.choices[0].delta.tool_calls
// Handle the beginning of a new tool call
if (
deltaToolCalls[0]?.index !== undefined &&
deltaToolCalls[0]?.function
) {
const index = deltaToolCalls[0].index
// Create new tool call if this is the first chunk for it
if (!toolCalls[index]) {
toolCalls[index] = {
id: deltaToolCalls[0]?.id || '',
function: {
name: deltaToolCalls[0]?.function?.name || '',
arguments: deltaToolCalls[0]?.function?.arguments || '',
},
type: 'function',
}
currentToolCall = toolCalls[index]
} else {
// Continuation of existing tool call
currentToolCall = toolCalls[index]
// Append to function name or arguments if they exist in this chunk
if (deltaToolCalls[0]?.function?.name) {
currentToolCall!.function.name += deltaToolCalls[0].function.name
}
if (deltaToolCalls[0]?.function?.arguments) {
currentToolCall!.function.arguments +=
deltaToolCalls[0].function.arguments
}
}
}
}
// Handle regular content in the chunk
if (chunk.choices[0]?.delta?.content) {
const content = chunk.choices[0].delta.content
accumulatedContent += content
message.content = createMessageContent(accumulatedContent)
events.emit(MessageEvent.OnMessageUpdate, message)
}
}
await postMessageProcessing(
toolCalls ?? [],
requestBuilder,
message,
accumulatedContent ?? ''
)
return !toolCalls || !toolCalls.length
}
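/**
 * Appends the assistant turn to the request and executes each tool call,
 * subject to user approval. Tool call state (pending/ready) and responses are
 * recorded in the message metadata, and tool results are pushed back into the
 * request for the next round.
 */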
const postMessageProcessing = async (
toolCalls: ChatCompletionMessageToolCall[],
requestBuilder: MessageRequestBuilder,
message: ThreadMessage,
content: string
) => {
requestBuilder.pushAssistantMessage({
content,
role: 'assistant',
refusal: null,
tool_calls: toolCalls,
})
// Handle completed tool calls
if (toolCalls.length > 0) {
for (const toolCall of toolCalls) {
const toolId = ulid()
const toolCallsMetadata =
message.metadata?.tool_calls &&
Array.isArray(message.metadata?.tool_calls)
? message.metadata?.tool_calls
: []
message.metadata = {
...(message.metadata ?? {}),
tool_calls: [
...toolCallsMetadata,
{
tool: {
...toolCall,
id: toolId,
},
response: undefined,
state: 'pending',
},
],
}
events.emit(MessageEvent.OnMessageUpdate, message)
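// Run the tool if it was already approved for this thread or the user approves it now; without a modal handler, tool calls are auto-approved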
const approved =
approvedTools[message.thread_id]?.includes(toolCall.function.name) ||
(showModal
? await showModal(toolCall.function.name, message.thread_id)
: true)
const result = approved
? await window.core.api.callTool({
toolName: toolCall.function.name,
arguments: JSON.parse(toolCall.function.arguments),
})
: {
content: [
{
type: 'text',
text: 'The user has chosen to disallow the tool call.',
},
],
}
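// Skip the remaining tool calls if this one returned an error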
if (result.error) break
message.metadata = {
...(message.metadata ?? {}),
tool_calls: [
...toolCallsMetadata,
{
tool: {
...toolCall,
id: toolId,
},
response: result,
state: 'ready',
},
],
}
requestBuilder.pushToolMessage(
result.content[0]?.text ?? '',
toolCall.id
)
events.emit(MessageEvent.OnMessageUpdate, message)
}
}
}
return {
sendChatMessage,
resendChatMessage,
}
}