feat: Jan Tool Use - MCP frontend implementation

This commit is contained in:
Louis 2025-03-30 17:56:39 +07:00
parent 94b77db294
commit 3dd80841c2
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
6 changed files with 186 additions and 18 deletions

View File

@ -40,12 +40,13 @@ export abstract class AIEngine extends BaseExtension {
* Stops the model. * Stops the model.
*/ */
async unloadModel(model?: Model): Promise<any> { async unloadModel(model?: Model): Promise<any> {
if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve() if (model?.engine && model.engine.toString() !== this.provider)
return Promise.resolve()
events.emit(ModelEvent.OnModelStopped, model ?? {}) events.emit(ModelEvent.OnModelStopped, model ?? {})
return Promise.resolve() return Promise.resolve()
} }
/* /**
* Inference request * Inference request
*/ */
inference(data: MessageRequest) {} inference(data: MessageRequest) {}

View File

@ -76,7 +76,7 @@ export abstract class OAIEngine extends AIEngine {
const timestamp = Date.now() / 1000 const timestamp = Date.now() / 1000
const message: ThreadMessage = { const message: ThreadMessage = {
id: ulid(), id: ulid(),
thread_id: data.threadId, thread_id: data.thread?.id ?? data.threadId,
type: data.type, type: data.type,
assistant_id: data.assistantId, assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant, role: ChatCompletionRole.Assistant,
@ -104,6 +104,7 @@ export abstract class OAIEngine extends AIEngine {
messages: data.messages ?? [], messages: data.messages ?? [],
model: model.id, model: model.id,
stream: true, stream: true,
tools: data.tools,
...model.parameters, ...model.parameters,
} }
if (this.transformPayload) { if (this.transformPayload) {

View File

@ -43,6 +43,9 @@ export type ThreadMessage = {
* @data_transfer_object * @data_transfer_object
*/ */
export type MessageRequest = { export type MessageRequest = {
/**
* The id of the message request.
*/
id?: string id?: string
/** /**
@ -71,6 +74,11 @@ export type MessageRequest = {
// TODO: deprecate threadId field // TODO: deprecate threadId field
thread?: Thread thread?: Thread
/**
* ChatCompletion tools
*/
tools?: MessageTool[]
/** Engine name to process */ /** Engine name to process */
engine?: string engine?: string
@ -78,6 +86,24 @@ export type MessageRequest = {
type?: string type?: string
} }
/**
* ChatCompletion Tool parameters
*/
export type MessageTool = {
type: string
function: MessageFunction
}
/**
* ChatCompletion Tool's function parameters
*/
export type MessageFunction = {
name: string
description?: string
parameters?: Record<string, unknown>
strict?: boolean
}
/** /**
* The status of the message. * The status of the message.
* @data_transfer_object * @data_transfer_object

View File

@ -1,19 +1,31 @@
import { useEffect, useRef } from 'react' import { useEffect, useRef } from 'react'
import { import {
ChatCompletionRole,
MessageRequestType, MessageRequestType,
ExtensionTypeEnum, ExtensionTypeEnum,
Thread, Thread,
ThreadMessage, ThreadMessage,
Model, Model,
ConversationalExtension, ConversationalExtension,
EngineManager,
ThreadAssistantInfo, ThreadAssistantInfo,
InferenceEngine, events,
MessageEvent,
ContentType,
} from '@janhq/core' } from '@janhq/core'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core' import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { OpenAI } from 'openai'
import {
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionRole,
ChatCompletionTool,
} from 'openai/resources/chat'
import { Tool } from 'openai/resources/responses/responses'
import { ulid } from 'ulidx'
import { modelDropdownStateAtom } from '@/containers/ModelDropdown' import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
import { import {
@ -99,7 +111,7 @@ export default function useSendChatMessage() {
const newConvoData = Array.from(currentMessages) const newConvoData = Array.from(currentMessages)
let toSendMessage = newConvoData.pop() let toSendMessage = newConvoData.pop()
while (toSendMessage && toSendMessage?.role !== ChatCompletionRole.User) { while (toSendMessage && toSendMessage?.role !== 'user') {
await extensionManager await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.deleteMessage(toSendMessage.thread_id, toSendMessage.id) ?.deleteMessage(toSendMessage.thread_id, toSendMessage.id)
@ -172,7 +184,16 @@ export default function useSendChatMessage() {
parameters: runtimeParams, parameters: runtimeParams,
}, },
activeThreadRef.current, activeThreadRef.current,
messages ?? currentMessages messages ?? currentMessages,
(await window.core.api.getTools())?.map((tool) => ({
type: 'function' as const,
function: {
name: tool.name,
description: tool.description?.slice(0, 1024),
parameters: tool.inputSchema,
strict: false,
},
}))
).addSystemMessage(activeAssistantRef.current?.instructions) ).addSystemMessage(activeAssistantRef.current?.instructions)
requestBuilder.pushMessage(prompt, base64Blob, fileUpload) requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
@ -228,10 +249,118 @@ export default function useSendChatMessage() {
} }
setIsGeneratingResponse(true) setIsGeneratingResponse(true)
// Request for inference let isDone = false
EngineManager.instance() const openai = new OpenAI({
.get(InferenceEngine.cortex) apiKey: await window.core.api.appToken(),
?.inference(requestBuilder.build()) baseURL: `${API_BASE_URL}/v1`,
dangerouslyAllowBrowser: true,
})
while (!isDone) {
const data = requestBuilder.build()
const response = await openai.chat.completions.create({
messages: (data.messages ?? []).map((e) => {
return {
role: e.role as ChatCompletionRole,
content: e.content,
}
}) as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
stream: false,
})
if (response.choices[0]?.message.content) {
const newMessage: ThreadMessage = {
id: ulid(),
object: 'message',
thread_id: activeThreadRef.current.id,
assistant_id: activeAssistantRef.current.assistant_id,
attachments: [],
role: response.choices[0].message.role as any,
content: [
{
type: ContentType.Text,
text: {
value: response.choices[0].message.content
? (response.choices[0].message.content as any)
: '',
annotations: [],
},
},
],
status: 'ready' as any,
created_at: Date.now(),
completed_at: Date.now(),
}
requestBuilder.pushAssistantMessage(
(response.choices[0].message.content as any) ?? ''
)
events.emit(MessageEvent.OnMessageUpdate, newMessage)
}
if (response.choices[0]?.message.tool_calls) {
for (const toolCall of response.choices[0].message.tool_calls) {
const id = ulid()
const toolMessage: ThreadMessage = {
id: id,
object: 'message',
thread_id: activeThreadRef.current.id,
assistant_id: activeAssistantRef.current.assistant_id,
attachments: [],
role: 'assistant' as any,
content: [
{
type: ContentType.Text,
text: {
value: `<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}</think>`,
annotations: [],
},
},
],
status: 'pending' as any,
created_at: Date.now(),
completed_at: Date.now(),
}
events.emit(MessageEvent.OnMessageUpdate, toolMessage)
const result = await window.core.api.callTool({
toolName: toolCall.function.name,
arguments: JSON.parse(toolCall.function.arguments),
})
if (result.error) {
console.error(result.error)
break
}
const message: ThreadMessage = {
id: id,
object: 'message',
thread_id: activeThreadRef.current.id,
assistant_id: activeAssistantRef.current.assistant_id,
attachments: [],
role: 'assistant' as any,
content: [
{
type: ContentType.Text,
text: {
value:
`<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}</think>` +
(result.content[0]?.text ?? ''),
annotations: [],
},
},
],
status: 'ready' as any,
created_at: Date.now(),
completed_at: Date.now(),
}
requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '')
requestBuilder.pushMessage('Go for the next step')
events.emit(MessageEvent.OnMessageUpdate, message)
}
}
isDone =
!response.choices[0]?.message.tool_calls ||
!response.choices[0]?.message.tool_calls.length
}
// Reset states // Reset states
setReloadModel(false) setReloadModel(false)

View File

@ -36,6 +36,7 @@
"marked": "^9.1.2", "marked": "^9.1.2",
"next": "14.2.3", "next": "14.2.3",
"next-themes": "^0.2.1", "next-themes": "^0.2.1",
"openai": "^4.90.0",
"postcss": "8.4.31", "postcss": "8.4.31",
"postcss-url": "10.1.3", "postcss-url": "10.1.3",
"posthog-js": "^1.194.6", "posthog-js": "^1.194.6",

View File

@ -6,6 +6,7 @@ import {
ChatCompletionRole, ChatCompletionRole,
MessageRequest, MessageRequest,
MessageRequestType, MessageRequestType,
MessageTool,
ModelInfo, ModelInfo,
Thread, Thread,
ThreadMessage, ThreadMessage,
@ -22,12 +23,14 @@ export class MessageRequestBuilder {
messages: ChatCompletionMessage[] messages: ChatCompletionMessage[]
model: ModelInfo model: ModelInfo
thread: Thread thread: Thread
tools?: MessageTool[]
constructor( constructor(
type: MessageRequestType, type: MessageRequestType,
model: ModelInfo, model: ModelInfo,
thread: Thread, thread: Thread,
messages: ThreadMessage[] messages: ThreadMessage[],
tools?: MessageTool[]
) { ) {
this.msgId = ulid() this.msgId = ulid()
this.type = type this.type = type
@ -39,14 +42,20 @@ export class MessageRequestBuilder {
role: msg.role, role: msg.role,
content: msg.content[0]?.text?.value ?? '.', content: msg.content[0]?.text?.value ?? '.',
})) }))
this.tools = tools
} }
pushAssistantMessage(message: string) {
this.messages = [
...this.messages,
{
role: ChatCompletionRole.Assistant,
content: message,
},
]
}
// Chainable // Chainable
pushMessage( pushMessage(message: string, base64Blob?: string, fileInfo?: FileInfo) {
message: string,
base64Blob: string | undefined,
fileInfo?: FileInfo
) {
if (base64Blob && fileInfo?.type === 'pdf') if (base64Blob && fileInfo?.type === 'pdf')
return this.addDocMessage(message, fileInfo?.name) return this.addDocMessage(message, fileInfo?.name)
else if (base64Blob && fileInfo?.type === 'image') { else if (base64Blob && fileInfo?.type === 'image') {
@ -167,6 +176,7 @@ export class MessageRequestBuilder {
messages: this.normalizeMessages(this.messages), messages: this.normalizeMessages(this.messages),
model: this.model, model: this.model,
thread: this.thread, thread: this.thread,
tools: this.tools,
} }
} }
} }