From 95f90f601d0d6501f8eb54f606c9c94347fa61a9 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 16 May 2025 00:27:20 +0700 Subject: [PATCH 1/2] feat: tool use --- web-app/src/containers/ChatInput.tsx | 65 +++++++------ web-app/src/lib/completion.ts | 134 +++++++++++++++++++++++++-- web-app/src/lib/messages.ts | 36 +++++++ 3 files changed, 198 insertions(+), 37 deletions(-) create mode 100644 web-app/src/lib/messages.ts diff --git a/web-app/src/containers/ChatInput.tsx b/web-app/src/containers/ChatInput.tsx index e5c469573..dcb298e49 100644 --- a/web-app/src/containers/ChatInput.tsx +++ b/web-app/src/containers/ChatInput.tsx @@ -22,8 +22,10 @@ import { useGeneralSetting } from '@/hooks/useGeneralSetting' import { useModelProvider } from '@/hooks/useModelProvider' import { emptyThreadContent, + extractToolCall, newAssistantThreadContent, newUserThreadContent, + postMessageProcessing, sendCompletion, startModel, } from '@/lib/completion' @@ -37,6 +39,8 @@ import { MovingBorder } from './MovingBorder' import { MCPTool } from '@/types/completion' import { listen } from '@tauri-apps/api/event' import { SystemEvent } from '@/types/events' +import { CompletionMessagesBuilder } from '@/lib/messages' +import { ChatCompletionMessageToolCall } from 'openai/resources' type ChatInputProps = { className?: string @@ -57,12 +61,10 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => { useModelProvider() const { getCurrentThread: retrieveThread, createThread } = useThreads() - const { streamingContent, updateStreamingContent } = useAppState() - + const { streamingContent, updateStreamingContent, updateLoadingModel } = + useAppState() const { addMessage } = useMessages() - const router = useRouter() - const { updateLoadingModel } = useAppState() const provider = useMemo(() => { return getProviderByName(selectedProvider) @@ -104,9 +106,7 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => { // Unsubscribe from the event when the component unmounts unsubscribe = unsub }) - return () => { - unsubscribe() - } + return unsubscribe }, []) useEffect(() => { @@ -146,7 +146,6 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => { if (!activeThread || !provider) return updateStreamingContent(emptyThreadContent) - addMessage(newUserThreadContent(activeThread.id, prompt)) setPrompt('') try { @@ -158,18 +157,30 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => { updateLoadingModel(false) } - const completion = await sendCompletion( - activeThread, - provider, - prompt, - tools - ) + const builder = new CompletionMessagesBuilder() + // REMARK: Would it possible to not attach the entire message history to the request? + // TODO: If not amend messages history here + builder.addUserMessage(prompt) - if (!completion) throw new Error('No completion received') - let accumulatedText = '' - try { + let isCompleted = false + + while (!isCompleted) { + const completion = await sendCompletion( + activeThread, + provider, + builder.getMessages(), + tools + ) + + if (!completion) throw new Error('No completion received') + let accumulatedText = '' + const currentCall: ChatCompletionMessageToolCall | null = null + const toolCalls: ChatCompletionMessageToolCall[] = [] for await (const part of completion) { const delta = part.choices[0]?.delta?.content || '' + if (part.choices[0]?.delta?.tool_calls) { + extractToolCall(part, currentCall, toolCalls) + } if (delta) { accumulatedText += delta // Create a new object each time to avoid reference issues @@ -182,17 +193,17 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => { await new Promise((resolve) => setTimeout(resolve, 0)) } } - } catch (error) { - console.error('Error during streaming:', error) - } finally { // Create a final content object for adding to the thread - if (accumulatedText) { - const finalContent = newAssistantThreadContent( - activeThread.id, - accumulatedText - ) - addMessage(finalContent) - } + const finalContent = newAssistantThreadContent( + activeThread.id, + accumulatedText + ) + builder.addAssistantMessage(accumulatedText, undefined, toolCalls) + const updatedMessage = await postMessageProcessing(toolCalls, builder, finalContent) + console.log(updatedMessage) + addMessage(updatedMessage ?? finalContent) + + isCompleted = !toolCalls.length } } catch (error) { console.error('Error sending message:', error) diff --git a/web-app/src/lib/completion.ts b/web-app/src/lib/completion.ts index 1cb58f487..b771a6d08 100644 --- a/web-app/src/lib/completion.ts +++ b/web-app/src/lib/completion.ts @@ -8,7 +8,9 @@ import { } from '@janhq/core' import { invoke } from '@tauri-apps/api/core' import { + ChatCompletionMessageParam, ChatCompletionTool, + CompletionResponseChunk, models, StreamCompletionResponse, TokenJS, @@ -16,6 +18,9 @@ import { import { ulid } from 'ulidx' import { normalizeProvider } from './models' import { MCPTool } from '@/types/completion' +import { CompletionMessagesBuilder } from './messages' +import { ChatCompletionMessageToolCall } from 'openai/resources' + /** * @fileoverview Helper functions for creating thread content. * These functions are used to create thread content objects @@ -97,13 +102,13 @@ export const emptyThreadContent: ThreadMessage = { * @fileoverview Helper function to send a completion request to the model provider. * @param thread * @param provider - * @param prompt + * @param messages * @returns */ export const sendCompletion = async ( thread: Thread, provider: ModelProvider, - prompt: string, + messages: ChatCompletionMessageParam[], tools: MCPTool[] = [] ): Promise => { if (!thread?.model?.id || !provider) return undefined @@ -124,13 +129,9 @@ export const sendCompletion = async ( stream: true, provider: providerName, model: thread.model?.id, - messages: [ - { - role: 'user', - content: prompt, - }, - ], + messages, tools: normalizeTools(tools), + tool_choice: tools.length ? 'auto' : undefined, }) return completion } @@ -138,6 +139,8 @@ export const sendCompletion = async ( /** * @fileoverview Helper function to start a model. * This function loads the model from the provider. + * @deprecated This function is deprecated and will be removed in the future. + * Provider's chat function will handle loading the model. * @param provider * @param model * @returns @@ -170,8 +173,8 @@ export const stopModel = async ( /** * @fileoverview Helper function to normalize tools for the chat completion request. * This function converts the MCPTool objects to ChatCompletionTool objects. - * @param tools - * @returns + * @param tools + * @returns */ export const normalizeTools = (tools: MCPTool[]): ChatCompletionTool[] => { return tools.map((tool) => ({ @@ -184,3 +187,114 @@ export const normalizeTools = (tools: MCPTool[]): ChatCompletionTool[] => { }, })) } + +/** + * @fileoverview Helper function to extract tool calls from the completion response. + * @param part + * @param calls + */ +export const extractToolCall = ( + part: CompletionResponseChunk, + currentCall: ChatCompletionMessageToolCall | null, + calls: ChatCompletionMessageToolCall[] +) => { + const deltaToolCalls = part.choices[0].delta.tool_calls + // Handle the beginning of a new tool call + if (deltaToolCalls?.[0]?.index !== undefined && deltaToolCalls[0]?.function) { + const index = deltaToolCalls[0].index + + // Create new tool call if this is the first chunk for it + if (!calls[index]) { + calls[index] = { + id: deltaToolCalls[0]?.id || '', + function: { + name: deltaToolCalls[0]?.function?.name || '', + arguments: deltaToolCalls[0]?.function?.arguments || '', + }, + type: 'function', + } + currentCall = calls[index] + } else { + // Continuation of existing tool call + currentCall = calls[index] + + // Append to function name or arguments if they exist in this chunk + if (deltaToolCalls[0]?.function?.name) { + currentCall!.function.name += deltaToolCalls[0].function.name + } + + if (deltaToolCalls[0]?.function?.arguments) { + currentCall!.function.arguments += + deltaToolCalls[0].function.arguments + } + } + } + return calls +} + +/** + * @fileoverview Helper function to process the completion response. + * @param calls + * @param builder + * @param message + * @param content + */ +export const postMessageProcessing = async ( + calls: ChatCompletionMessageToolCall[], + builder: CompletionMessagesBuilder, + message: ThreadMessage +) => { + // Handle completed tool calls + if (calls.length) { + for (const toolCall of calls) { + const toolId = ulid() + const toolCallsMetadata = + message.metadata?.tool_calls && + Array.isArray(message.metadata?.tool_calls) + ? message.metadata?.tool_calls + : [] + message.metadata = { + ...(message.metadata ?? {}), + tool_calls: [ + ...toolCallsMetadata, + { + tool: { + ...(toolCall as object), + id: toolId, + }, + response: undefined, + state: 'pending', + }, + ], + } + + const result = await window.core.api.callTool({ + toolName: toolCall.function.name, + arguments: toolCall.function.arguments.length + ? JSON.parse(toolCall.function.arguments) + : {}, + }) + + if (result.error) break + + message.metadata = { + ...(message.metadata ?? {}), + tool_calls: [ + ...toolCallsMetadata, + { + tool: { + ...toolCall, + id: toolId, + }, + response: result, + state: 'ready', + }, + ], + } + + builder.addToolMessage(result.content[0]?.text ?? '', toolCall.id) + // update message metadata + return message + } + } +} diff --git a/web-app/src/lib/messages.ts b/web-app/src/lib/messages.ts new file mode 100644 index 000000000..1175a6549 --- /dev/null +++ b/web-app/src/lib/messages.ts @@ -0,0 +1,36 @@ +import { ChatCompletionMessageParam } from 'token.js' +import { ChatCompletionMessageToolCall } from 'openai/resources' + +export class CompletionMessagesBuilder { + private messages: ChatCompletionMessageParam[] = [] + + constructor() {} + + addUserMessage(content: string) { + this.messages.push({ + role: 'user', + content: content, + }) + } + + addAssistantMessage(content: string, refusal?: string, calls?: ChatCompletionMessageToolCall[]) { + this.messages.push({ + role: 'assistant', + content: content, + refusal: refusal, + tool_calls: calls + }) + } + + addToolMessage(content: string, toolCallId: string) { + this.messages.push({ + role: 'tool', + content: content, + tool_call_id: toolCallId, + }) + } + + getMessages(): ChatCompletionMessageParam[] { + return this.messages + } +} From 05ce85d9b11ac00e3a67396c98acd356183f782c Mon Sep 17 00:00:00 2001 From: Faisal Amir Date: Fri, 16 May 2025 22:09:43 +0700 Subject: [PATCH 2/2] chore: styling tool call funtion render UI --- web-app/src/containers/ThreadContent.tsx | 91 +++++++++++++++--------- web-app/src/containers/ToolCallBlock.tsx | 79 ++++++++++++++++++++ web-app/src/routes/threads/$threadId.tsx | 1 + web-app/src/types/message.d.ts | 10 +++ 4 files changed, 148 insertions(+), 33 deletions(-) create mode 100644 web-app/src/containers/ToolCallBlock.tsx create mode 100644 web-app/src/types/message.d.ts diff --git a/web-app/src/containers/ThreadContent.tsx b/web-app/src/containers/ThreadContent.tsx index 4ca6e47e0..f7ae1bec4 100644 --- a/web-app/src/containers/ThreadContent.tsx +++ b/web-app/src/containers/ThreadContent.tsx @@ -9,9 +9,10 @@ import { IconPencil, } from '@tabler/icons-react' import { useAppState } from '@/hooks/useAppState' -import ThinkingBlock from './ThinkingBlock' import { cn } from '@/lib/utils' import { useMessages } from '@/hooks/useMessages' +import ThinkingBlock from '@/containers/ThinkingBlock' +import ToolCallBlock from '@/containers/ToolCallBlock' const CopyButton = ({ text }: { text: string }) => { const [copied, setCopied] = useState(false) @@ -81,6 +82,12 @@ export const ThreadContent = memo( const { deleteMessage } = useMessages() + const isToolCalls = + item.metadata && + 'tool_calls' in item.metadata && + Array.isArray(item.metadata.tool_calls) && + item.metadata.tool_calls.length + return ( {item.content?.[0]?.text && item.role === 'user' && ( @@ -124,41 +131,59 @@ export const ThreadContent = memo( text={reasoningSegment} /> )} + -
-
- - - + + + +
- + )} )} {item.type === 'image_url' && image && ( diff --git a/web-app/src/containers/ToolCallBlock.tsx b/web-app/src/containers/ToolCallBlock.tsx new file mode 100644 index 000000000..f9cb1f1db --- /dev/null +++ b/web-app/src/containers/ToolCallBlock.tsx @@ -0,0 +1,79 @@ +import { ChevronDown, ChevronUp, Loader } from 'lucide-react' +import { cn } from '@/lib/utils' +import { create } from 'zustand' +import { RenderMarkdown } from './RenderMarkdown' + +interface Props { + result: string + name: string + id: number + loading: boolean +} + +type ToolCallBlockState = { + collapseState: { [id: number]: boolean } + setCollapseState: (id: number, expanded: boolean) => void +} + +const useToolCallBlockStore = create((set) => ({ + collapseState: {}, + setCollapseState: (id, expanded) => + set((state) => ({ + collapseState: { + ...state.collapseState, + [id]: expanded, + }, + })), +})) + +const ToolCallBlock = ({ id, name, result, loading }: Props) => { + const { collapseState, setCollapseState } = useToolCallBlockStore() + const isExpanded = collapseState[id] ?? false + + const handleClick = () => { + const newExpandedState = !isExpanded + setCollapseState(id, newExpandedState) + } + + return ( +
+
+
+ {loading && ( + + )} + +
+ +
+
+ +
+
+
+
+ ) +} + +export default ToolCallBlock diff --git a/web-app/src/routes/threads/$threadId.tsx b/web-app/src/routes/threads/$threadId.tsx index b5708ec78..12fee157f 100644 --- a/web-app/src/routes/threads/$threadId.tsx +++ b/web-app/src/routes/threads/$threadId.tsx @@ -177,6 +177,7 @@ function ThreadDetail() { messages.map((item, index) => { // Only pass isLastMessage to the last message in the array const isLastMessage = index === messages.length - 1 + console.log(messages, 'messages') return (