diff --git a/core/src/types/inference/inferenceEntity.ts b/core/src/types/inference/inferenceEntity.ts
index c37e3b079..ac2e48d32 100644
--- a/core/src/types/inference/inferenceEntity.ts
+++ b/core/src/types/inference/inferenceEntity.ts
@@ -7,6 +7,7 @@ export enum ChatCompletionRole {
   System = 'system',
   Assistant = 'assistant',
   User = 'user',
+  Tool = 'tool',
 }
 
 /**
@@ -18,6 +19,9 @@ export type ChatCompletionMessage = {
   content?: ChatCompletionMessageContent
   /** The role of the author of this message. **/
   role: ChatCompletionRole
+  type?: string
+  output?: string
+  tool_call_id?: string
 }
 
 export type ChatCompletionMessageContent =
diff --git a/core/src/types/message/messageEntity.ts b/core/src/types/message/messageEntity.ts
index 280ce75a3..20979c68e 100644
--- a/core/src/types/message/messageEntity.ts
+++ b/core/src/types/message/messageEntity.ts
@@ -36,6 +36,8 @@ export type ThreadMessage = {
   type?: string
   /** The error code which explain what error type. Used in conjunction with MessageStatus.Error */
   error_code?: ErrorCode
+
+  tool_call_id?: string
 }
 
 /**
diff --git a/web/containers/Providers/ModelHandler.tsx b/web/containers/Providers/ModelHandler.tsx
index 786dbd4f0..a1fffa011 100644
--- a/web/containers/Providers/ModelHandler.tsx
+++ b/web/containers/Providers/ModelHandler.tsx
@@ -114,7 +114,7 @@ export default function ModelHandler() {
 
   const onNewMessageResponse = useCallback(
     async (message: ThreadMessage) => {
-      if (message.type === MessageRequestType.Thread) {
+      if (message.type !== MessageRequestType.Summary) {
         addNewMessage(message)
       }
     },
@@ -129,35 +129,20 @@ export default function ModelHandler() {
   const updateThreadTitle = useCallback(
     (message: ThreadMessage) => {
       // Update only when it's finished
-      if (message.status !== MessageStatus.Ready) {
-        return
-      }
+      if (message.status !== MessageStatus.Ready) return
 
       const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
-      if (!thread) {
-        console.warn(
-          `Failed to update title for thread ${message.thread_id}: Thread not found!`
-        )
-        return
-      }
-
       let messageContent = message.content[0]?.text?.value
-      if (!messageContent) {
-        console.warn(
-          `Failed to update title for thread ${message.thread_id}: Responded content is null!`
-        )
-        return
-      }
+      if (!thread || !messageContent) return
 
       // No new line character is presented in the title
       // And non-alphanumeric characters should be removed
-      if (messageContent.includes('\n')) {
+      if (messageContent.includes('\n'))
         messageContent = messageContent.replace(/\n/g, ' ')
-      }
+
       const match = messageContent.match(/<\/think>(.*)$/)
-      if (match) {
-        messageContent = match[1]
-      }
+      if (match) messageContent = match[1]
+
       // Remove non-alphanumeric characters
       const cleanedMessageContent = messageContent
         .replace(/[^\p{L}\s]+/gu, '')
@@ -193,18 +178,13 @@ export default function ModelHandler() {
 
   const updateThreadMessage = useCallback(
     (message: ThreadMessage) => {
-      if (
-        messageGenerationSubscriber.current &&
-        message.thread_id === activeThreadRef.current?.id &&
-        !messageGenerationSubscriber.current!.thread_id
-      ) {
-        updateMessage(
-          message.id,
-          message.thread_id,
-          message.content,
-          message.status
-        )
-      }
+      updateMessage(
+        message.id,
+        message.thread_id,
+        message.content,
+        message.metadata,
+        message.status
+      )
 
       if (message.status === MessageStatus.Pending) {
         if (message.content.length) {
@@ -243,16 +223,19 @@
         engines &&
         isLocalEngine(engines, activeModelRef.current.engine)
       ) {
-        ;(async () => {
-          if (
-            !(await extensionManager
-              .get<ModelExtension>(ExtensionTypeEnum.Model)
-              ?.isModelLoaded(activeModelRef.current?.id as string))
-          ) {
-            setActiveModel(undefined)
-            setStateModel({ state: 'start', loading: false, model: undefined })
-          }
-        })()
+        extensionManager
+          .get<ModelExtension>(ExtensionTypeEnum.Model)
+          ?.isModelLoaded(activeModelRef.current?.id as string)
+          .then((isLoaded) => {
+            if (!isLoaded) {
+              setActiveModel(undefined)
+              setStateModel({
+                state: 'start',
+                loading: false,
+                model: undefined,
+              })
+            }
+          })
       }
       // Mark the thread as not waiting for response
       updateThreadWaiting(message.thread_id, false)
@@ -296,19 +279,10 @@
           error_code: message.error_code,
         }
       }
-      ;(async () => {
-        const updatedMessage = await extensionManager
-          .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
-          ?.createMessage(message)
-          .catch(() => undefined)
-        if (updatedMessage) {
-          deleteMessage(message.id)
-          addNewMessage(updatedMessage)
-          setTokenSpeed((prev) =>
-            prev ? { ...prev, message: updatedMessage.id } : undefined
-          )
-        }
-      })()
+
+      extensionManager
+        .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
+        ?.createMessage(message)
 
       // Attempt to generate the title of the Thread when needed
       generateThreadTitle(message, thread)
@@ -319,25 +293,21 @@ export default function ModelHandler() {
 
   const onMessageResponseUpdate = useCallback(
     (message: ThreadMessage) => {
-      switch (message.type) {
-        case MessageRequestType.Summary:
-          updateThreadTitle(message)
-          break
-        default:
-          updateThreadMessage(message)
-          break
-      }
+      if (message.type === MessageRequestType.Summary)
+        updateThreadTitle(message)
+      else updateThreadMessage(message)
     },
     [updateThreadMessage, updateThreadTitle]
  )
 
   const generateThreadTitle = (message: ThreadMessage, thread: Thread) => {
     // If this is the first ever prompt in the thread
-    if ((thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle)
+    if (
+      !activeModelRef.current ||
+      (thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle
+    )
       return
 
-    if (!activeModelRef.current) return
-
     // Check model engine; we don't want to generate a title when it's not a local engine; remote models use the first prompt
     if (
       activeModelRef.current?.engine !== InferenceEngine.cortex &&
diff --git a/web/helpers/atoms/ChatMessage.atom.ts b/web/helpers/atoms/ChatMessage.atom.ts
index 1847aa422..faae6e298 100644
--- a/web/helpers/atoms/ChatMessage.atom.ts
+++ b/web/helpers/atoms/ChatMessage.atom.ts
@@ -165,6 +165,7 @@ export const updateMessageAtom = atom(
     id: string,
     conversationId: string,
     text: ThreadContent[],
+    metadata: Record<string, unknown> | undefined,
     status: MessageStatus
   ) => {
     const messages = get(chatMessages)[conversationId] ?? []
@@ -172,6 +173,7 @@
     if (message) {
       message.content = text
       message.status = status
+      message.metadata = metadata
       const updatedMessages = [...messages]
 
       const newData: Record<string, ThreadMessage[]> = {
@@ -192,6 +194,7 @@
         created_at: Date.now() / 1000,
         completed_at: Date.now() / 1000,
         object: 'thread.message',
+        metadata: metadata,
       })
     }
   }
diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts
index 3242b085c..a76e01325 100644
--- a/web/hooks/useSendChatMessage.ts
+++ b/web/hooks/useSendChatMessage.ts
@@ -26,6 +26,7 @@ import {
   ChatCompletionTool,
 } from 'openai/resources/chat'
 
+import { Stream } from 'openai/streaming'
 import { ulid } from 'ulidx'
 
 import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
@@ -258,111 +259,63 @@
         baseURL: `${API_BASE_URL}/v1`,
         dangerouslyAllowBrowser: true,
       })
+      let parentMessageId: string | undefined
 
       while (!isDone) {
+        let messageId = ulid()
+        if (!parentMessageId) {
+          parentMessageId = ulid()
+          messageId = parentMessageId
+        }
         const data = requestBuilder.build()
+        const message: ThreadMessage = {
+          id: messageId,
+          object: 'message',
+          thread_id: activeThreadRef.current.id,
+          assistant_id: activeAssistantRef.current.assistant_id,
+          role: ChatCompletionRole.Assistant,
+          content: [],
+          metadata: {
+            ...(messageId !== parentMessageId
+              ? { parent_id: parentMessageId }
+              : {}),
+          },
+          status: MessageStatus.Pending,
+          created_at: Date.now() / 1000,
+          completed_at: Date.now() / 1000,
+        }
+        events.emit(MessageEvent.OnMessageResponse, message)
+
         const response = await openai.chat.completions.create({
-          messages: (data.messages ?? []).map((e) => {
-            return {
-              role: e.role as OpenAIChatCompletionRole,
-              content: e.content,
-            }
-          }) as ChatCompletionMessageParam[],
+          messages: requestBuilder.messages as ChatCompletionMessageParam[],
           model: data.model?.id ?? '',
           tools: data.tools as ChatCompletionTool[],
-          stream: false,
+          stream: data.model?.parameters?.stream ?? false,
+          tool_choice: 'auto',
         })
 
-        if (response.choices[0]?.message.content) {
-          const newMessage: ThreadMessage = {
-            id: ulid(),
-            object: 'message',
-            thread_id: activeThreadRef.current.id,
-            assistant_id: activeAssistantRef.current.assistant_id,
-            attachments: [],
-            role: response.choices[0].message.role as ChatCompletionRole,
-            content: [
-              {
-                type: ContentType.Text,
-                text: {
-                  value: response.choices[0].message.content
-                    ? response.choices[0].message.content
-                    : '',
-                  annotations: [],
-                },
-              },
-            ],
-            status: MessageStatus.Ready,
-            created_at: Date.now(),
-            completed_at: Date.now(),
-          }
-          requestBuilder.pushAssistantMessage(
-            response.choices[0].message.content ?? ''
-          )
-          events.emit(MessageEvent.OnMessageUpdate, newMessage)
-        }
-
-        if (response.choices[0]?.message.tool_calls) {
-          for (const toolCall of response.choices[0].message.tool_calls) {
-            const id = ulid()
-            const toolMessage: ThreadMessage = {
-              id: id,
-              object: 'message',
-              thread_id: activeThreadRef.current.id,
-              assistant_id: activeAssistantRef.current.assistant_id,
-              attachments: [],
-              role: ChatCompletionRole.Assistant,
-              content: [
-                {
-                  type: ContentType.Text,
-                  text: {
-                    value: `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`,
-                    annotations: [],
-                  },
-                },
-              ],
-              status: MessageStatus.Pending,
-              created_at: Date.now(),
-              completed_at: Date.now(),
-            }
-            events.emit(MessageEvent.OnMessageUpdate, toolMessage)
-            const result = await window.core.api.callTool({
-              toolName: toolCall.function.name,
-              arguments: JSON.parse(toolCall.function.arguments),
-            })
-            if (result.error) {
-              console.error(result.error)
-              break
-            }
-            const message: ThreadMessage = {
-              id: id,
-              object: 'message',
-              thread_id: activeThreadRef.current.id,
-              assistant_id: activeAssistantRef.current.assistant_id,
-              attachments: [],
-              role: ChatCompletionRole.Assistant,
-              content: [
-                {
-                  type: ContentType.Text,
-                  text: {
-                    value:
-                      `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}` +
-                      (result.content[0]?.text ?? ''),
-                    annotations: [],
-                  },
-                },
-              ],
-              status: MessageStatus.Ready,
-              created_at: Date.now(),
-              completed_at: Date.now(),
-            }
-            requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '')
-            requestBuilder.pushMessage('Go for the next step')
-            events.emit(MessageEvent.OnMessageUpdate, message)
-          }
-        }
-
-        isDone =
-          !response.choices[0]?.message.tool_calls ||
-          !response.choices[0]?.message.tool_calls.length
+        // Seed an empty text block so streaming updates have a slot to fill
+        if (!message.content.length) {
+          message.content = [
+            {
+              type: ContentType.Text,
+              text: {
+                value: '',
+                annotations: [],
+              },
+            },
+          ]
+        }
+        if (data.model?.parameters?.stream)
+          isDone = await processStreamingResponse(
+            response as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
+            requestBuilder,
+            message
+          )
+        else {
+          isDone = await processNonStreamingResponse(
+            response as OpenAI.Chat.Completions.ChatCompletion,
+            requestBuilder,
+            message
+          )
+        }
       }
     } else {
       // Request for inference
@@ -376,6 +329,184 @@ export default function useSendChatMessage() {
     setEngineParamsUpdate(false)
   }
 
+  const processNonStreamingResponse = async (
+    response: OpenAI.Chat.Completions.ChatCompletion,
+    requestBuilder: MessageRequestBuilder,
+    message: ThreadMessage
+  ): Promise<boolean> => {
+    // Handle tool calls in the response
+    const toolCalls: ChatCompletionMessageToolCall[] =
+      response.choices[0]?.message?.tool_calls ?? []
+    const content = response.choices[0].message?.content
+    message.content = [
+      {
+        type: ContentType.Text,
+        text: {
+          value: content ?? '',
+          annotations: [],
+        },
+      },
+    ]
+    events.emit(MessageEvent.OnMessageUpdate, message)
+    await postMessageProcessing(
+      toolCalls,
+      requestBuilder,
+      message,
+      content ?? ''
+    )
+    return !toolCalls.length
+  }
+
+  const processStreamingResponse = async (
+    response: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
+    requestBuilder: MessageRequestBuilder,
+    message: ThreadMessage
+  ): Promise<boolean> => {
+    // Variables to track and accumulate streaming content
+    let currentToolCall: {
+      id: string
+      function: { name: string; arguments: string }
+    } | null = null
+    let accumulatedContent = ''
+    const toolCalls: ChatCompletionMessageToolCall[] = []
+
+    // Process the streaming chunks
+    for await (const chunk of response) {
+      // Handle tool calls in the chunk
+      if (chunk.choices[0]?.delta?.tool_calls) {
+        const deltaToolCalls = chunk.choices[0].delta.tool_calls
+
+        // Handle the beginning of a new tool call
+        if (
+          deltaToolCalls[0]?.index !== undefined &&
+          deltaToolCalls[0]?.function
+        ) {
+          const index = deltaToolCalls[0].index
+
+          // Create a new tool call if this is the first chunk for it
+          if (!toolCalls[index]) {
+            toolCalls[index] = {
+              id: deltaToolCalls[0]?.id || '',
+              function: {
+                name: deltaToolCalls[0]?.function?.name || '',
+                arguments: deltaToolCalls[0]?.function?.arguments || '',
+              },
+              type: 'function',
+            }
+            currentToolCall = toolCalls[index]
+          } else {
+            // Continuation of an existing tool call
+            currentToolCall = toolCalls[index]
+
+            // Append to the function name or arguments if they exist in this chunk
+            if (deltaToolCalls[0]?.function?.name) {
+              currentToolCall!.function.name += deltaToolCalls[0].function.name
+            }
+
+            if (deltaToolCalls[0]?.function?.arguments) {
+              currentToolCall!.function.arguments +=
+                deltaToolCalls[0].function.arguments
+            }
+          }
+        }
+      }
+
+      // Handle regular content in the chunk
+      if (chunk.choices[0]?.delta?.content) {
+        const content = chunk.choices[0].delta.content
+        accumulatedContent += content
+
+        message.content = [
+          {
+            type: ContentType.Text,
+            text: {
+              value: accumulatedContent,
+              annotations: [],
+            },
+          },
+        ]
+        events.emit(MessageEvent.OnMessageUpdate, message)
+      }
+    }
+
+    await postMessageProcessing(
+      toolCalls,
+      requestBuilder,
+      message,
+      accumulatedContent
+    )
+    return !toolCalls.length
+  }
+
+  const postMessageProcessing = async (
+    toolCalls: ChatCompletionMessageToolCall[],
+    requestBuilder: MessageRequestBuilder,
+    message: ThreadMessage,
+    content: string
+  ) => {
+    requestBuilder.pushAssistantMessage({
+      content,
+      role: 'assistant',
+      refusal: null,
+      tool_calls: toolCalls,
+    })
+
+    // Handle completed tool calls
+    if (toolCalls.length > 0) {
+      for (const toolCall of toolCalls) {
+        const toolId = ulid()
+        const toolCallsMetadata =
+          message.metadata?.tool_calls &&
+          Array.isArray(message.metadata?.tool_calls)
+            ? message.metadata?.tool_calls
+            : []
+        // Mark the tool call as pending while it executes
+        message.metadata = {
+          ...(message.metadata ?? {}),
+          tool_calls: [
+            ...toolCallsMetadata,
+            {
+              tool: {
+                ...toolCall,
+                id: toolId,
+              },
+              response: undefined,
+              state: 'pending',
+            },
+          ],
+        }
+        events.emit(MessageEvent.OnMessageUpdate, message)
+
+        const result = await window.core.api.callTool({
+          toolName: toolCall.function.name,
+          arguments: JSON.parse(toolCall.function.arguments),
+        })
+        if (result.error) break
+
+        message.metadata = {
+          ...(message.metadata ?? {}),
+          tool_calls: [
+            ...toolCallsMetadata,
+            {
+              tool: {
+                ...toolCall,
+                id: toolId,
+              },
+              response: result,
+              state: 'ready',
+            },
+          ],
+        }
+
+        requestBuilder.pushToolMessage(
+          result.content[0]?.text ?? '',
+          toolCall.id
+        )
+        events.emit(MessageEvent.OnMessageUpdate, message)
+      }
+    }
+    message.status = MessageStatus.Ready
+    events.emit(MessageEvent.OnMessageUpdate, message)
+  }
+
   return {
     sendChatMessage,
     resendChatMessage,
diff --git a/web/screens/Hub/ModelFilter/ModelSize/index.tsx b/web/screens/Hub/ModelFilter/ModelSize/index.tsx
index b95d57f8b..a8d411e33 100644
--- a/web/screens/Hub/ModelFilter/ModelSize/index.tsx
+++ b/web/screens/Hub/ModelFilter/ModelSize/index.tsx
@@ -1,9 +1,8 @@
-import { useRef, useState } from 'react'
+import { useState } from 'react'
 
-import { Slider, Input, Tooltip } from '@janhq/joi'
+import { Slider, Input } from '@janhq/joi'
 import { atom, useAtom } from 'jotai'
-import { InfoIcon } from 'lucide-react'
 
 export const hubModelSizeMinAtom = atom(0)
 export const hubModelSizeMaxAtom = atom(100)
diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx
index 44145330f..015289df2 100644
--- a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx
@@ -9,6 +9,9 @@ import {
   Tooltip,
   useClickOutside,
   badgeVariants,
+  Badge,
+  Modal,
+  ModalClose,
 } from '@janhq/joi'
 import { useAtom, useAtomValue } from 'jotai'
 import {
@@ -19,6 +22,7 @@ import {
   SettingsIcon,
   ChevronUpIcon,
   Settings2Icon,
+  WrenchIcon,
 } from 'lucide-react'
 
 import { twMerge } from 'tailwind-merge'
@@ -51,6 +55,7 @@ import {
   isBlockingSendAtom,
 } from '@/helpers/atoms/Thread.atom'
 import { activeTabThreadRightPanelAtom } from '@/helpers/atoms/ThreadRightPanel.atom'
+import { ModelTool } from '@/types/model'
 
 const ChatInput = () => {
   const activeThread = useAtomValue(activeThreadAtom)
@@ -75,6 +80,8 @@ const ChatInput = () => {
   const isBlockingSend = useAtomValue(isBlockingSendAtom)
   const activeAssistant = useAtomValue(activeAssistantAtom)
   const { stopInference } = useActiveModel()
+  const [tools, setTools] = useState<ModelTool[]>([])
+  const [showToolsModal, setShowToolsModal] = useState(false)
 
   const upload = uploader()
   const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
@@ -98,6 +105,12 @@ const ChatInput = () => {
     }
   }, [activeSettingInputBox, selectedModel, setActiveSettingInputBox])
 
+  useEffect(() => {
+    window.core?.api?.getTools().then((data: ModelTool[]) => {
+      setTools(data)
+    })
+  }, [])
+
   const onStopInferenceClick = async () => {
     stopInference()
   }
@@ -392,6 +407,62 @@ const ChatInput = () => {
                     className="flex-shrink-0 cursor-pointer text-[hsla(var(--text-secondary))]"
                   />
+                  {tools && tools.length > 0 && (
+                    <>
+                      <Badge
+                        className="flex cursor-pointer items-center gap-1"
+                        onClick={() => setShowToolsModal(true)}
+                      >
+                        <WrenchIcon size={16} />
+                        {tools.length}
+                      </Badge>
+                      <Modal
+                        open={showToolsModal}
+                        onOpenChange={setShowToolsModal}
+                        title="Available Tools"
+                        content={
+                          <div>
+                            <p className="text-[hsla(var(--text-secondary))]">
+                              Jan can use tools provided by specialized servers using
+                              Model Context Protocol.{' '}
+                              <a
+                                href="https://modelcontextprotocol.io"
+                                target="_blank"
+                                rel="noopener noreferrer"
+                                className="underline"
+                              >
+                                Learn more about MCP
+                              </a>
+                            </p>
+                            <div className="mt-4 flex flex-col gap-2">
+                              {tools.map((tool: any) => (
+                                <div
+                                  key={tool.name}
+                                  className="rounded-lg border border-[hsla(var(--app-border))] p-2"
+                                >
+                                  <div className="font-medium">{tool.name}</div>
+                                  <div className="text-[hsla(var(--text-secondary))]">
+                                    {tool.description}
+                                  </div>
+                                </div>
+                              ))}
+                            </div>
+                            <ModalClose asChild>
+                              <button className="mt-4">Close</button>
+                            </ModalClose>
+                          </div>
+                        }
+                      />
+                    </>
+                  )}
                   {selectedModel && (
diff --git a/web/screens/Thread/ThreadCenterPanel/TextMessage/ToolCallBlock/index.tsx b/web/screens/Thread/ThreadCenterPanel/TextMessage/ToolCallBlock/index.tsx
new file mode 100644
--- /dev/null
+++ b/web/screens/Thread/ThreadCenterPanel/TextMessage/ToolCallBlock/index.tsx
@@ -0,0 +1,29 @@
+import { useState } from 'react'
+
+type Props = {
+  name: string
+  result: string
+}
+
+const ToolCallBlock = ({ name, result }: Props) => {
+  const [isExpanded, setIsExpanded] = useState(false)
+
+  return (
+    <div className="mt-2 rounded-lg border border-[hsla(var(--app-border))] p-2">
+      <button
+        className="flex w-full items-center justify-between"
+        onClick={() => setIsExpanded(!isExpanded)}
+      >
+        <span>Tool call: {name}</span>
+        <span>{isExpanded ? 'Hide result' : 'Show result'}</span>
+      </button>
+      {isExpanded && (
+        <pre className="mt-2 whitespace-pre-wrap break-words text-[hsla(var(--text-secondary))]">
+          {result.trim()}
+        </pre>
+      )}
+    </div>
+  )
+}
+
+export default ToolCallBlock
diff --git a/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx b/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx
index 7103f7914..522b27a0d 100644
--- a/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx
@@ -18,6 +18,8 @@
 import ImageMessage from './ImageMessage'
 import { MarkdownTextMessage } from './MarkdownTextMessage'
 import ThinkingBlock from './ThinkingBlock'
+
+import ToolCallBlock from './ToolCallBlock'
 
 import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
 import {
   editMessageAtom,
@@ -65,57 +67,62 @@ const MessageContainer: React.FC<
     [props.content]
   )
 
-  const attachedFile = useMemo(() => 'attachments' in props, [props])
+  const attachedFile = useMemo(
+    () => 'attachments' in props && props.attachments?.length,
+    [props]
+  )
 
   return (
     <div className="group relative mx-auto p-4">
-      <div className="flex items-center gap-x-2">
-        {!isUser && !isSystem && <LogoMark width={28} height={28} />}
-
-        <div>
-          {!isUser && (
-            <>
-              {props.metadata && 'model' in props.metadata
-                ? (props.metadata?.model as string)
-                : props.isCurrentMessage
-                  ? selectedModel?.name
-                  : (activeAssistant?.assistant_name ?? props.role)}
-            </>
-          )}
-        </div>
-
-        <p>
-          {props.created_at &&
-            displayDate(props.created_at ?? Date.now() / 1000)}
-        </p>
-      </div>
+      {!(props.metadata && 'parent_id' in props.metadata) && (
+        <div className="flex items-center gap-x-2">
+          {!isUser && !isSystem && <LogoMark width={28} height={28} />}
+
+          <div>
+            {!isUser && (
+              <>
+                {props.metadata && 'model' in props.metadata
+                  ? (props.metadata?.model as string)
+                  : props.isCurrentMessage
+                    ? selectedModel?.name
+                    : (activeAssistant?.assistant_name ?? props.role)}
+              </>
+            )}
+          </div>
+
+          <p>
+            {props.created_at &&
+              displayDate(props.created_at ?? Date.now() / 1000)}
+          </p>
+        </div>
+      )}
 
@@ -179,6 +186,22 @@ const MessageContainer: React.FC<
             />
           )}
+          {props.metadata &&
+            'tool_calls' in props.metadata &&
+            Array.isArray(props.metadata.tool_calls) &&
+            props.metadata.tool_calls.length > 0 && (
+              <>
+                {props.metadata.tool_calls.map((toolCall) => (
+                  <ToolCallBlock
+                    key={toolCall.tool?.id}
+                    name={toolCall.tool?.function?.name ?? ''}
+                    result={toolCall.response?.content?.[0]?.text ?? ''}
+                  />
+                ))}
+              </>
+            )}
diff --git a/web/utils/messageRequestBuilder.ts b/web/utils/messageRequestBuilder.ts
index 546370f10..5a0e50aff 100644
--- a/web/utils/messageRequestBuilder.ts
+++ b/web/utils/messageRequestBuilder.ts
@@ -11,6 +11,7 @@ import {
   Thread,
   ThreadMessage,
 } from '@janhq/core'
+import { ChatCompletionMessage as OAIChatCompletionMessage } from 'openai/resources/chat'
 import { ulid } from 'ulidx'
 
 import { Stack } from '@/utils/Stack'
@@ -45,12 +46,26 @@ export class MessageRequestBuilder {
     this.tools = tools
   }
 
-  pushAssistantMessage(message: string) {
+  pushAssistantMessage(message: OAIChatCompletionMessage) {
+    const { content, refusal, ...rest } = message
+    const normalizedMessage = {
+      ...rest,
+      ...(content ? { content } : {}),
+      ...(refusal ? { refusal } : {}),
+    }
+    this.messages = [
+      ...this.messages,
+      normalizedMessage as ChatCompletionMessage,
+    ]
+  }
+
+  pushToolMessage(message: string, toolCallId: string) {
     this.messages = [
       ...this.messages,
       {
-        role: ChatCompletionRole.Assistant,
+        role: ChatCompletionRole.Tool,
         content: message,
+        tool_call_id: toolCallId,
       },
     ]
   }
@@ -194,7 +209,7 @@
       type: this.type,
       attachments: [],
       threadId: this.thread.id,
-      messages: this.normalizeMessages(this.messages),
+      messages: this.messages,
       model: this.model,
       thread: this.thread,
       tools: this.tools,
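
Taken together, the hook and builder changes above implement the standard OpenAI tool-call round trip. The sketch below restates that loop in a minimal, self-contained form; it is not the patch itself: executeTool is a local stub standing in for window.core.api.callTool, only the non-streaming path is shown, and the streaming accumulation and ThreadMessage bookkeeping are omitted.

import OpenAI from 'openai'
import type {
  ChatCompletionMessageParam,
  ChatCompletionMessageToolCall,
  ChatCompletionTool,
} from 'openai/resources/chat'

// Local stub for the MCP bridge; the app routes this through
// window.core.api.callTool({ toolName, arguments }).
async function executeTool(call: ChatCompletionMessageToolCall) {
  return { content: [{ text: `result of ${call.function.name}` }] }
}

async function runToolLoop(
  openai: OpenAI,
  model: string,
  tools: ChatCompletionTool[],
  messages: ChatCompletionMessageParam[]
): Promise<ChatCompletionMessageParam[]> {
  let isDone = false
  while (!isDone) {
    const response = await openai.chat.completions.create({
      model,
      messages,
      tools,
      tool_choice: 'auto',
      stream: false,
    })
    const choice = response.choices[0]?.message
    const toolCalls = choice?.tool_calls ?? []

    // Echo the assistant turn (tool_calls included) back into the
    // transcript, as requestBuilder.pushAssistantMessage does above.
    messages.push({
      role: 'assistant',
      content: choice?.content ?? '',
      ...(toolCalls.length ? { tool_calls: toolCalls } : {}),
    })

    // Answer every tool call with a role:'tool' message keyed by
    // tool_call_id, as requestBuilder.pushToolMessage does above.
    for (const call of toolCalls) {
      const result = await executeTool(call)
      messages.push({
        role: 'tool',
        content: result.content[0]?.text ?? '',
        tool_call_id: call.id,
      })
    }

    // Stop once the model answers without requesting another tool.
    isDone = toolCalls.length === 0
  }
  return messages
}

Because each iteration replays the full messages transcript, the model always sees its earlier tool calls and their results, which is what lets the while (!isDone) loop in useSendChatMessage keep building on previous steps until the model stops requesting tools.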