feat: tool use
This commit is contained in:
parent
66a4ac420b
commit
95f90f601d
@ -22,8 +22,10 @@ import { useGeneralSetting } from '@/hooks/useGeneralSetting'
|
||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||
import {
|
||||
emptyThreadContent,
|
||||
extractToolCall,
|
||||
newAssistantThreadContent,
|
||||
newUserThreadContent,
|
||||
postMessageProcessing,
|
||||
sendCompletion,
|
||||
startModel,
|
||||
} from '@/lib/completion'
|
||||
@ -37,6 +39,8 @@ import { MovingBorder } from './MovingBorder'
|
||||
import { MCPTool } from '@/types/completion'
|
||||
import { listen } from '@tauri-apps/api/event'
|
||||
import { SystemEvent } from '@/types/events'
|
||||
import { CompletionMessagesBuilder } from '@/lib/messages'
|
||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
|
||||
type ChatInputProps = {
|
||||
className?: string
|
||||
@ -57,12 +61,10 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
|
||||
useModelProvider()
|
||||
|
||||
const { getCurrentThread: retrieveThread, createThread } = useThreads()
|
||||
const { streamingContent, updateStreamingContent } = useAppState()
|
||||
|
||||
const { streamingContent, updateStreamingContent, updateLoadingModel } =
|
||||
useAppState()
|
||||
const { addMessage } = useMessages()
|
||||
|
||||
const router = useRouter()
|
||||
const { updateLoadingModel } = useAppState()
|
||||
|
||||
const provider = useMemo(() => {
|
||||
return getProviderByName(selectedProvider)
|
||||
@ -104,9 +106,7 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
|
||||
// Unsubscribe from the event when the component unmounts
|
||||
unsubscribe = unsub
|
||||
})
|
||||
return () => {
|
||||
unsubscribe()
|
||||
}
|
||||
return unsubscribe
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
@ -146,7 +146,6 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
|
||||
if (!activeThread || !provider) return
|
||||
|
||||
updateStreamingContent(emptyThreadContent)
|
||||
|
||||
addMessage(newUserThreadContent(activeThread.id, prompt))
|
||||
setPrompt('')
|
||||
try {
|
||||
@ -158,18 +157,30 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
|
||||
updateLoadingModel(false)
|
||||
}
|
||||
|
||||
const completion = await sendCompletion(
|
||||
activeThread,
|
||||
provider,
|
||||
prompt,
|
||||
tools
|
||||
)
|
||||
const builder = new CompletionMessagesBuilder()
|
||||
// REMARK: Would it possible to not attach the entire message history to the request?
|
||||
// TODO: If not amend messages history here
|
||||
builder.addUserMessage(prompt)
|
||||
|
||||
if (!completion) throw new Error('No completion received')
|
||||
let accumulatedText = ''
|
||||
try {
|
||||
let isCompleted = false
|
||||
|
||||
while (!isCompleted) {
|
||||
const completion = await sendCompletion(
|
||||
activeThread,
|
||||
provider,
|
||||
builder.getMessages(),
|
||||
tools
|
||||
)
|
||||
|
||||
if (!completion) throw new Error('No completion received')
|
||||
let accumulatedText = ''
|
||||
const currentCall: ChatCompletionMessageToolCall | null = null
|
||||
const toolCalls: ChatCompletionMessageToolCall[] = []
|
||||
for await (const part of completion) {
|
||||
const delta = part.choices[0]?.delta?.content || ''
|
||||
if (part.choices[0]?.delta?.tool_calls) {
|
||||
extractToolCall(part, currentCall, toolCalls)
|
||||
}
|
||||
if (delta) {
|
||||
accumulatedText += delta
|
||||
// Create a new object each time to avoid reference issues
|
||||
@ -182,17 +193,17 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error during streaming:', error)
|
||||
} finally {
|
||||
// Create a final content object for adding to the thread
|
||||
if (accumulatedText) {
|
||||
const finalContent = newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
accumulatedText
|
||||
)
|
||||
addMessage(finalContent)
|
||||
}
|
||||
const finalContent = newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
accumulatedText
|
||||
)
|
||||
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
|
||||
const updatedMessage = await postMessageProcessing(toolCalls, builder, finalContent)
|
||||
console.log(updatedMessage)
|
||||
addMessage(updatedMessage ?? finalContent)
|
||||
|
||||
isCompleted = !toolCalls.length
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error sending message:', error)
|
||||
|
||||
@ -8,7 +8,9 @@ import {
|
||||
} from '@janhq/core'
|
||||
import { invoke } from '@tauri-apps/api/core'
|
||||
import {
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionTool,
|
||||
CompletionResponseChunk,
|
||||
models,
|
||||
StreamCompletionResponse,
|
||||
TokenJS,
|
||||
@ -16,6 +18,9 @@ import {
|
||||
import { ulid } from 'ulidx'
|
||||
import { normalizeProvider } from './models'
|
||||
import { MCPTool } from '@/types/completion'
|
||||
import { CompletionMessagesBuilder } from './messages'
|
||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
|
||||
/**
|
||||
* @fileoverview Helper functions for creating thread content.
|
||||
* These functions are used to create thread content objects
|
||||
@ -97,13 +102,13 @@ export const emptyThreadContent: ThreadMessage = {
|
||||
* @fileoverview Helper function to send a completion request to the model provider.
|
||||
* @param thread
|
||||
* @param provider
|
||||
* @param prompt
|
||||
* @param messages
|
||||
* @returns
|
||||
*/
|
||||
export const sendCompletion = async (
|
||||
thread: Thread,
|
||||
provider: ModelProvider,
|
||||
prompt: string,
|
||||
messages: ChatCompletionMessageParam[],
|
||||
tools: MCPTool[] = []
|
||||
): Promise<StreamCompletionResponse | undefined> => {
|
||||
if (!thread?.model?.id || !provider) return undefined
|
||||
@ -124,13 +129,9 @@ export const sendCompletion = async (
|
||||
stream: true,
|
||||
provider: providerName,
|
||||
model: thread.model?.id,
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: prompt,
|
||||
},
|
||||
],
|
||||
messages,
|
||||
tools: normalizeTools(tools),
|
||||
tool_choice: tools.length ? 'auto' : undefined,
|
||||
})
|
||||
return completion
|
||||
}
|
||||
@ -138,6 +139,8 @@ export const sendCompletion = async (
|
||||
/**
|
||||
* @fileoverview Helper function to start a model.
|
||||
* This function loads the model from the provider.
|
||||
* @deprecated This function is deprecated and will be removed in the future.
|
||||
* Provider's chat function will handle loading the model.
|
||||
* @param provider
|
||||
* @param model
|
||||
* @returns
|
||||
@ -170,8 +173,8 @@ export const stopModel = async (
|
||||
/**
|
||||
* @fileoverview Helper function to normalize tools for the chat completion request.
|
||||
* This function converts the MCPTool objects to ChatCompletionTool objects.
|
||||
* @param tools
|
||||
* @returns
|
||||
* @param tools
|
||||
* @returns
|
||||
*/
|
||||
export const normalizeTools = (tools: MCPTool[]): ChatCompletionTool[] => {
|
||||
return tools.map((tool) => ({
|
||||
@ -184,3 +187,114 @@ export const normalizeTools = (tools: MCPTool[]): ChatCompletionTool[] => {
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* @fileoverview Helper function to extract tool calls from the completion response.
|
||||
* @param part
|
||||
* @param calls
|
||||
*/
|
||||
export const extractToolCall = (
|
||||
part: CompletionResponseChunk,
|
||||
currentCall: ChatCompletionMessageToolCall | null,
|
||||
calls: ChatCompletionMessageToolCall[]
|
||||
) => {
|
||||
const deltaToolCalls = part.choices[0].delta.tool_calls
|
||||
// Handle the beginning of a new tool call
|
||||
if (deltaToolCalls?.[0]?.index !== undefined && deltaToolCalls[0]?.function) {
|
||||
const index = deltaToolCalls[0].index
|
||||
|
||||
// Create new tool call if this is the first chunk for it
|
||||
if (!calls[index]) {
|
||||
calls[index] = {
|
||||
id: deltaToolCalls[0]?.id || '',
|
||||
function: {
|
||||
name: deltaToolCalls[0]?.function?.name || '',
|
||||
arguments: deltaToolCalls[0]?.function?.arguments || '',
|
||||
},
|
||||
type: 'function',
|
||||
}
|
||||
currentCall = calls[index]
|
||||
} else {
|
||||
// Continuation of existing tool call
|
||||
currentCall = calls[index]
|
||||
|
||||
// Append to function name or arguments if they exist in this chunk
|
||||
if (deltaToolCalls[0]?.function?.name) {
|
||||
currentCall!.function.name += deltaToolCalls[0].function.name
|
||||
}
|
||||
|
||||
if (deltaToolCalls[0]?.function?.arguments) {
|
||||
currentCall!.function.arguments +=
|
||||
deltaToolCalls[0].function.arguments
|
||||
}
|
||||
}
|
||||
}
|
||||
return calls
|
||||
}
|
||||
|
||||
/**
|
||||
* @fileoverview Helper function to process the completion response.
|
||||
* @param calls
|
||||
* @param builder
|
||||
* @param message
|
||||
* @param content
|
||||
*/
|
||||
export const postMessageProcessing = async (
|
||||
calls: ChatCompletionMessageToolCall[],
|
||||
builder: CompletionMessagesBuilder,
|
||||
message: ThreadMessage
|
||||
) => {
|
||||
// Handle completed tool calls
|
||||
if (calls.length) {
|
||||
for (const toolCall of calls) {
|
||||
const toolId = ulid()
|
||||
const toolCallsMetadata =
|
||||
message.metadata?.tool_calls &&
|
||||
Array.isArray(message.metadata?.tool_calls)
|
||||
? message.metadata?.tool_calls
|
||||
: []
|
||||
message.metadata = {
|
||||
...(message.metadata ?? {}),
|
||||
tool_calls: [
|
||||
...toolCallsMetadata,
|
||||
{
|
||||
tool: {
|
||||
...(toolCall as object),
|
||||
id: toolId,
|
||||
},
|
||||
response: undefined,
|
||||
state: 'pending',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = await window.core.api.callTool({
|
||||
toolName: toolCall.function.name,
|
||||
arguments: toolCall.function.arguments.length
|
||||
? JSON.parse(toolCall.function.arguments)
|
||||
: {},
|
||||
})
|
||||
|
||||
if (result.error) break
|
||||
|
||||
message.metadata = {
|
||||
...(message.metadata ?? {}),
|
||||
tool_calls: [
|
||||
...toolCallsMetadata,
|
||||
{
|
||||
tool: {
|
||||
...toolCall,
|
||||
id: toolId,
|
||||
},
|
||||
response: result,
|
||||
state: 'ready',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
builder.addToolMessage(result.content[0]?.text ?? '', toolCall.id)
|
||||
// update message metadata
|
||||
return message
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
36
web-app/src/lib/messages.ts
Normal file
36
web-app/src/lib/messages.ts
Normal file
@ -0,0 +1,36 @@
|
||||
import { ChatCompletionMessageParam } from 'token.js'
|
||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
|
||||
export class CompletionMessagesBuilder {
|
||||
private messages: ChatCompletionMessageParam[] = []
|
||||
|
||||
constructor() {}
|
||||
|
||||
addUserMessage(content: string) {
|
||||
this.messages.push({
|
||||
role: 'user',
|
||||
content: content,
|
||||
})
|
||||
}
|
||||
|
||||
addAssistantMessage(content: string, refusal?: string, calls?: ChatCompletionMessageToolCall[]) {
|
||||
this.messages.push({
|
||||
role: 'assistant',
|
||||
content: content,
|
||||
refusal: refusal,
|
||||
tool_calls: calls
|
||||
})
|
||||
}
|
||||
|
||||
addToolMessage(content: string, toolCallId: string) {
|
||||
this.messages.push({
|
||||
role: 'tool',
|
||||
content: content,
|
||||
tool_call_id: toolCallId,
|
||||
})
|
||||
}
|
||||
|
||||
getMessages(): ChatCompletionMessageParam[] {
|
||||
return this.messages
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user