Merge pull request #5004 from menloresearch/feat/tool-use

feat: tool use render UI
Faisal Amir 2025-05-16 22:20:46 +07:00 committed by GitHub
commit 589de63328
7 changed files with 346 additions and 70 deletions

View File

@@ -22,8 +22,10 @@ import { useGeneralSetting } from '@/hooks/useGeneralSetting'
import { useModelProvider } from '@/hooks/useModelProvider'
import {
emptyThreadContent,
extractToolCall,
newAssistantThreadContent,
newUserThreadContent,
postMessageProcessing,
sendCompletion,
startModel,
} from '@/lib/completion'
@@ -37,6 +39,8 @@ import { MovingBorder } from './MovingBorder'
import { MCPTool } from '@/types/completion'
import { listen } from '@tauri-apps/api/event'
import { SystemEvent } from '@/types/events'
import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'
type ChatInputProps = {
className?: string
@@ -57,12 +61,10 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
useModelProvider()
const { getCurrentThread: retrieveThread, createThread } = useThreads()
const { streamingContent, updateStreamingContent } = useAppState()
const { streamingContent, updateStreamingContent, updateLoadingModel } =
useAppState()
const { addMessage } = useMessages()
const router = useRouter()
const { updateLoadingModel } = useAppState()
const provider = useMemo(() => {
return getProviderByName(selectedProvider)
@@ -104,9 +106,7 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
// Unsubscribe from the event when the component unmounts
unsubscribe = unsub
})
return () => {
unsubscribe()
}
return unsubscribe
}, [])
useEffect(() => {
@@ -146,7 +146,6 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
if (!activeThread || !provider) return
updateStreamingContent(emptyThreadContent)
addMessage(newUserThreadContent(activeThread.id, prompt))
setPrompt('')
try {
@@ -158,18 +157,30 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
updateLoadingModel(false)
}
const completion = await sendCompletion(
activeThread,
provider,
prompt,
tools
)
const builder = new CompletionMessagesBuilder()
// REMARK: Would it be possible to not attach the entire message history to the request?
// TODO: If not, amend the messages history here
builder.addUserMessage(prompt)
if (!completion) throw new Error('No completion received')
let accumulatedText = ''
try {
let isCompleted = false
while (!isCompleted) {
const completion = await sendCompletion(
activeThread,
provider,
builder.getMessages(),
tools
)
if (!completion) throw new Error('No completion received')
let accumulatedText = ''
const currentCall: ChatCompletionMessageToolCall | null = null
const toolCalls: ChatCompletionMessageToolCall[] = []
for await (const part of completion) {
const delta = part.choices[0]?.delta?.content || ''
if (part.choices[0]?.delta?.tool_calls) {
extractToolCall(part, currentCall, toolCalls)
}
if (delta) {
accumulatedText += delta
// Create a new object each time to avoid reference issues
@@ -182,17 +193,17 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
await new Promise((resolve) => setTimeout(resolve, 0))
}
}
} catch (error) {
console.error('Error during streaming:', error)
} finally {
// Create a final content object for adding to the thread
if (accumulatedText) {
const finalContent = newAssistantThreadContent(
activeThread.id,
accumulatedText
)
addMessage(finalContent)
}
const finalContent = newAssistantThreadContent(
activeThread.id,
accumulatedText
)
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
const updatedMessage = await postMessageProcessing(toolCalls, builder, finalContent)
console.log(updatedMessage)
addMessage(updatedMessage ?? finalContent)
isCompleted = !toolCalls.length
}
} catch (error) {
console.error('Error sending message:', error)
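
For readers piecing the hunks above back together: the new send flow is an agentic loop — build the message history, stream a completion, collect any tool-call deltas, execute the tools, append their results, and repeat until the model stops requesting tools. A minimal sketch of that loop, assuming only the helpers this PR exports from '@/lib/completion' and '@/lib/messages'; runToolLoop and its parameters are illustrative, and UI concerns (streaming text into the view, addMessage, loading state) are omitted:

import {
  extractToolCall,
  newAssistantThreadContent,
  postMessageProcessing,
  sendCompletion,
} from '@/lib/completion'
import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'

// Sketch only: mirrors the hunks above, not the full component.
async function runToolLoop(
  activeThread: Parameters<typeof sendCompletion>[0],
  provider: Parameters<typeof sendCompletion>[1],
  prompt: string,
  tools: Parameters<typeof sendCompletion>[3]
) {
  const builder = new CompletionMessagesBuilder()
  builder.addUserMessage(prompt)

  let isCompleted = false
  while (!isCompleted) {
    const completion = await sendCompletion(
      activeThread,
      provider,
      builder.getMessages(),
      tools
    )
    if (!completion) throw new Error('No completion received')

    let accumulatedText = ''
    const toolCalls: ChatCompletionMessageToolCall[] = []
    for await (const part of completion) {
      if (part.choices[0]?.delta?.tool_calls) {
        extractToolCall(part, null, toolCalls)
      }
      accumulatedText += part.choices[0]?.delta?.content || ''
    }

    // Record the assistant turn (text plus any tool calls), execute the tools,
    // and keep looping only while the model keeps requesting tools.
    builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
    const finalContent = newAssistantThreadContent(activeThread.id, accumulatedText)
    await postMessageProcessing(toolCalls, builder, finalContent)
    isCompleted = !toolCalls.length
  }
}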

View File

@@ -9,9 +9,10 @@ import {
IconPencil,
} from '@tabler/icons-react'
import { useAppState } from '@/hooks/useAppState'
import ThinkingBlock from './ThinkingBlock'
import { cn } from '@/lib/utils'
import { useMessages } from '@/hooks/useMessages'
import ThinkingBlock from '@/containers/ThinkingBlock'
import ToolCallBlock from '@/containers/ToolCallBlock'
const CopyButton = ({ text }: { text: string }) => {
const [copied, setCopied] = useState(false)
@@ -81,6 +82,12 @@ export const ThreadContent = memo(
const { deleteMessage } = useMessages()
const isToolCalls =
item.metadata &&
'tool_calls' in item.metadata &&
Array.isArray(item.metadata.tool_calls) &&
item.metadata.tool_calls.length
return (
<Fragment>
{item.content?.[0]?.text && item.role === 'user' && (
@@ -124,41 +131,59 @@
text={reasoningSegment}
/>
)}
<RenderMarkdown content={textSegment} components={linkComponents} />
<div className="flex items-center gap-2 mt-2 text-main-view-fg/60 text-xs">
<div
className={cn(
'flex items-center gap-2',
item.isLastMessage &&
streamingContent &&
'opacity-0 invisible pointer-events-none'
)}
>
<CopyButton text={item.content?.[0]?.text.value || ''} />
<button
className="flex items-center gap-1 hover:text-accent transition-colors cursor-pointer group relative"
onClick={() => {
deleteMessage(item.thread_id, item.id)
}}
{isToolCalls && item.metadata?.tool_calls ? (
<>
{(item.metadata.tool_calls as ToolCall[]).map((toolCall) => (
<ToolCallBlock
id={toolCall.tool?.id ?? 0}
name={toolCall.tool?.function?.name ?? ''}
key={toolCall.tool?.id}
result={JSON.stringify(toolCall.response)}
loading={toolCall.state === 'pending'}
/>
))}
</>
) : null}
{!isToolCalls && (
<div className="flex items-center gap-2 mt-2 text-main-view-fg/60 text-xs">
<div
className={cn(
'flex items-center gap-2',
item.isLastMessage &&
streamingContent &&
'opacity-0 invisible pointer-events-none'
)}
>
<IconTrash size={16} />
<span className="opacity-0 w-0 overflow-hidden whitespace-nowrap group-hover:w-auto group-hover:opacity-100 transition-all duration-300 ease-in-out">
Delete
</span>
</button>
<button
className="flex items-center gap-1 hover:text-accent transition-colors cursor-pointer group relative"
onClick={() => {
console.log('Regenerate clicked')
}}
>
<IconRefresh size={16} />
<span className="opacity-0 w-0 overflow-hidden whitespace-nowrap group-hover:w-auto group-hover:opacity-100 transition-all duration-300 ease-in-out">
Regenerate
</span>
</button>
<CopyButton text={item.content?.[0]?.text.value || ''} />
<button
className="flex items-center gap-1 hover:text-accent transition-colors cursor-pointer group relative"
onClick={() => {
deleteMessage(item.thread_id, item.id)
}}
>
<IconTrash size={16} />
<span className="opacity-0 w-0 overflow-hidden whitespace-nowrap group-hover:w-auto group-hover:opacity-100 transition-all duration-300 ease-in-out">
Delete
</span>
</button>
<button
className="flex items-center gap-1 hover:text-accent transition-colors cursor-pointer group relative"
onClick={() => {
console.log('Regenerate clicked')
}}
>
<IconRefresh size={16} />
<span className="opacity-0 w-0 overflow-hidden whitespace-nowrap group-hover:w-auto group-hover:opacity-100 transition-all duration-300 ease-in-out">
Regenerate
</span>
</button>
</div>
</div>
</div>
)}
</>
)}
{item.type === 'image_url' && image && (

View File

@@ -0,0 +1,79 @@
import { ChevronDown, ChevronUp, Loader } from 'lucide-react'
import { cn } from '@/lib/utils'
import { create } from 'zustand'
import { RenderMarkdown } from './RenderMarkdown'
interface Props {
result: string
name: string
id: number
loading: boolean
}
type ToolCallBlockState = {
collapseState: { [id: number]: boolean }
setCollapseState: (id: number, expanded: boolean) => void
}
const useToolCallBlockStore = create<ToolCallBlockState>((set) => ({
collapseState: {},
setCollapseState: (id, expanded) =>
set((state) => ({
collapseState: {
...state.collapseState,
[id]: expanded,
},
})),
}))
const ToolCallBlock = ({ id, name, result, loading }: Props) => {
const { collapseState, setCollapseState } = useToolCallBlockStore()
const isExpanded = collapseState[id] ?? false
const handleClick = () => {
const newExpandedState = !isExpanded
setCollapseState(id, newExpandedState)
}
return (
<div className="mx-auto w-full cursor-pointer mt-4" onClick={handleClick}>
<div className="mb-4 rounded-lg bg-main-view-fg/4 border border-dashed border-main-view-fg/10">
<div className="flex items-center gap-3 p-2">
{loading && (
<Loader className="size-4 animate-spin text-main-view-fg/60" />
)}
<button className="flex items-center gap-2 focus:outline-none">
{isExpanded ? (
<ChevronUp className="h-4 w-4" />
) : (
<ChevronDown className="h-4 w-4" />
)}
<span className="font-medium text-main-view-fg/80">
View result from{' '}
<span className="font-medium text-main-view-fg">{name}</span>
</span>
</button>
</div>
<div
className={cn(
'h-fit w-full overflow-auto transition-all duration-300 px-2',
isExpanded ? '' : 'max-h-0 overflow-hidden'
)}
>
<div className="mt-2 text-main-view-fg/60">
<RenderMarkdown
content={
'```json\n' +
JSON.stringify(result ? JSON.parse(result) : null, null, 2) +
'\n```'
}
/>
</div>
</div>
</div>
</div>
)
}
export default ToolCallBlock
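
A hedged usage sketch: ThreadContent above renders one block per tool call recorded in the message metadata, roughly like this (the tool name and result payload are invented for illustration):

<ToolCallBlock
  id={1}
  name="fetch_weather"
  result={JSON.stringify({ temperature_c: 31 })}
  loading={false}
/>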

View File

@@ -8,7 +8,9 @@ import {
} from '@janhq/core'
import { invoke } from '@tauri-apps/api/core'
import {
ChatCompletionMessageParam,
ChatCompletionTool,
CompletionResponseChunk,
models,
StreamCompletionResponse,
TokenJS,
@@ -16,6 +18,9 @@
import { ulid } from 'ulidx'
import { normalizeProvider } from './models'
import { MCPTool } from '@/types/completion'
import { CompletionMessagesBuilder } from './messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'
/**
* @fileoverview Helper functions for creating thread content.
* These functions are used to create thread content objects
@@ -97,13 +102,13 @@ export const emptyThreadContent: ThreadMessage = {
* @fileoverview Helper function to send a completion request to the model provider.
* @param thread
* @param provider
* @param prompt
* @param messages
* @returns
*/
export const sendCompletion = async (
thread: Thread,
provider: ModelProvider,
prompt: string,
messages: ChatCompletionMessageParam[],
tools: MCPTool[] = []
): Promise<StreamCompletionResponse | undefined> => {
if (!thread?.model?.id || !provider) return undefined
@@ -124,13 +129,9 @@ export const sendCompletion = async (
stream: true,
provider: providerName,
model: thread.model?.id,
messages: [
{
role: 'user',
content: prompt,
},
],
messages,
tools: normalizeTools(tools),
tool_choice: tools.length ? 'auto' : undefined,
})
return completion
}
@@ -138,6 +139,8 @@ export const sendCompletion = async (
/**
* @fileoverview Helper function to start a model.
* This function loads the model from the provider.
* @deprecated This function is deprecated and will be removed in the future.
* Provider's chat function will handle loading the model.
* @param provider
* @param model
* @returns
@@ -170,8 +173,8 @@
/**
* @fileoverview Helper function to normalize tools for the chat completion request.
* This function converts the MCPTool objects to ChatCompletionTool objects.
* @param tools
* @returns
* @param tools
* @returns
*/
export const normalizeTools = (tools: MCPTool[]): ChatCompletionTool[] => {
return tools.map((tool) => ({
@@ -184,3 +187,114 @@ export const normalizeTools = (tools: MCPTool[]): ChatCompletionTool[] => {
},
}))
}
/**
* @fileoverview Helper function to extract tool calls from the completion response.
* @param part
* @param currentCall
* @param calls
*/
export const extractToolCall = (
part: CompletionResponseChunk,
currentCall: ChatCompletionMessageToolCall | null,
calls: ChatCompletionMessageToolCall[]
) => {
const deltaToolCalls = part.choices[0].delta.tool_calls
// Handle the beginning of a new tool call
if (deltaToolCalls?.[0]?.index !== undefined && deltaToolCalls[0]?.function) {
const index = deltaToolCalls[0].index
// Create new tool call if this is the first chunk for it
if (!calls[index]) {
calls[index] = {
id: deltaToolCalls[0]?.id || '',
function: {
name: deltaToolCalls[0]?.function?.name || '',
arguments: deltaToolCalls[0]?.function?.arguments || '',
},
type: 'function',
}
currentCall = calls[index]
} else {
// Continuation of existing tool call
currentCall = calls[index]
// Append to function name or arguments if they exist in this chunk
if (deltaToolCalls[0]?.function?.name) {
currentCall!.function.name += deltaToolCalls[0].function.name
}
if (deltaToolCalls[0]?.function?.arguments) {
currentCall!.function.arguments +=
deltaToolCalls[0].function.arguments
}
}
}
return calls
}
/**
* @fileoverview Helper function to execute completed tool calls and attach their results to the assistant message metadata.
* @param calls
* @param builder
* @param message
*/
export const postMessageProcessing = async (
calls: ChatCompletionMessageToolCall[],
builder: CompletionMessagesBuilder,
message: ThreadMessage
) => {
// Handle completed tool calls
if (calls.length) {
for (const toolCall of calls) {
const toolId = ulid()
const toolCallsMetadata =
message.metadata?.tool_calls &&
Array.isArray(message.metadata?.tool_calls)
? message.metadata?.tool_calls
: []
message.metadata = {
...(message.metadata ?? {}),
tool_calls: [
...toolCallsMetadata,
{
tool: {
...(toolCall as object),
id: toolId,
},
response: undefined,
state: 'pending',
},
],
}
const result = await window.core.api.callTool({
toolName: toolCall.function.name,
arguments: toolCall.function.arguments.length
? JSON.parse(toolCall.function.arguments)
: {},
})
if (result.error) break
message.metadata = {
...(message.metadata ?? {}),
tool_calls: [
...toolCallsMetadata,
{
tool: {
...toolCall,
id: toolId,
},
response: result,
state: 'ready',
},
],
}
builder.addToolMessage(result.content[0]?.text ?? '', toolCall.id)
// update message metadata
return message
}
}
}
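
To make the delta accumulation above concrete, here is a sketch of how extractToolCall merges streamed chunks. The chunk shapes follow the OpenAI streaming format; the values ('call_1', 'fetch_weather', the arguments JSON) are invented for illustration, and the cast stands in for the CompletionResponseChunk type used in this file:

import { extractToolCall } from '@/lib/completion'
import { ChatCompletionMessageToolCall } from 'openai/resources'

// The first delta carries the call id and function name, later deltas stream
// the JSON arguments in pieces.
const chunks = [
  {
    choices: [
      {
        delta: {
          tool_calls: [
            {
              index: 0,
              id: 'call_1',
              type: 'function',
              function: { name: 'fetch_weather', arguments: '' },
            },
          ],
        },
      },
    ],
  },
  {
    choices: [
      { delta: { tool_calls: [{ index: 0, function: { arguments: '{"city":' } }] } },
    ],
  },
  {
    choices: [
      { delta: { tool_calls: [{ index: 0, function: { arguments: '"Hanoi"}' } }] } },
    ],
  },
]

const calls: ChatCompletionMessageToolCall[] = []
for (const part of chunks) {
  extractToolCall(part as unknown as Parameters<typeof extractToolCall>[0], null, calls)
}
// calls[0].function ends up as { name: 'fetch_weather', arguments: '{"city":"Hanoi"}' },
// which postMessageProcessing can then parse and execute.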

View File

@@ -0,0 +1,36 @@
import { ChatCompletionMessageParam } from 'token.js'
import { ChatCompletionMessageToolCall } from 'openai/resources'
export class CompletionMessagesBuilder {
private messages: ChatCompletionMessageParam[] = []
constructor() {}
addUserMessage(content: string) {
this.messages.push({
role: 'user',
content: content,
})
}
addAssistantMessage(content: string, refusal?: string, calls?: ChatCompletionMessageToolCall[]) {
this.messages.push({
role: 'assistant',
content: content,
refusal: refusal,
tool_calls: calls
})
}
addToolMessage(content: string, toolCallId: string) {
this.messages.push({
role: 'tool',
content: content,
tool_call_id: toolCallId,
})
}
getMessages(): ChatCompletionMessageParam[] {
return this.messages
}
}
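
A small sketch of how the builder is meant to be used across one tool round trip; the message strings, tool name, and call id are invented for illustration:

import { CompletionMessagesBuilder } from '@/lib/messages'

const builder = new CompletionMessagesBuilder()
builder.addUserMessage('What is the weather in Hanoi?')

// Assistant turn that asked for a tool instead of answering directly.
builder.addAssistantMessage('', undefined, [
  {
    id: 'call_1',
    type: 'function',
    function: { name: 'fetch_weather', arguments: '{"city":"Hanoi"}' },
  },
])

// The tool output is appended with the matching tool_call_id, and the full
// history goes back out on the next sendCompletion round.
builder.addToolMessage('{"temperature_c":31}', 'call_1')
const messages = builder.getMessages()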

View File

@@ -177,6 +177,7 @@ function ThreadDetail() {
messages.map((item, index) => {
// Only pass isLastMessage to the last message in the array
const isLastMessage = index === messages.length - 1
console.log(messages, 'messages')
return (
<div
key={item.id}

web-app/src/types/message.d.ts (vendored, new file, 10 additions)
View File

@@ -0,0 +1,10 @@
type ToolCall = {
tool: {
id?: number
function?: {
name?: string
}
}
response?: unknown
state?: string
}
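
For orientation, a metadata entry written by postMessageProcessing above roughly matches this shape; the values are invented, and note that postMessageProcessing actually stores a ulid string as the tool id even though this type declares a number:

const exampleCall: ToolCall = {
  tool: {
    id: 1, // declared as number here; the runtime value is a ulid string
    function: { name: 'fetch_weather' },
  },
  response: { content: [{ text: '{"temperature_c":31}' }] },
  state: 'ready', // 'pending' while the tool is still running
}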