diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 2d1bdb3c2..6616208ad 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -40,12 +40,13 @@ export abstract class AIEngine extends BaseExtension { * Stops the model. */ async unloadModel(model?: Model): Promise { - if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve() + if (model?.engine && model.engine.toString() !== this.provider) + return Promise.resolve() events.emit(ModelEvent.OnModelStopped, model ?? {}) return Promise.resolve() } - /* + /** * Inference request */ inference(data: MessageRequest) {} diff --git a/core/src/browser/extensions/engines/OAIEngine.ts b/core/src/browser/extensions/engines/OAIEngine.ts index 61032357c..50940a97b 100644 --- a/core/src/browser/extensions/engines/OAIEngine.ts +++ b/core/src/browser/extensions/engines/OAIEngine.ts @@ -76,7 +76,7 @@ export abstract class OAIEngine extends AIEngine { const timestamp = Date.now() / 1000 const message: ThreadMessage = { id: ulid(), - thread_id: data.threadId, + thread_id: data.thread?.id ?? data.threadId, type: data.type, assistant_id: data.assistantId, role: ChatCompletionRole.Assistant, @@ -104,6 +104,7 @@ export abstract class OAIEngine extends AIEngine { messages: data.messages ?? [], model: model.id, stream: true, + tools: data.tools, ...model.parameters, } if (this.transformPayload) { diff --git a/core/src/types/message/messageEntity.ts b/core/src/types/message/messageEntity.ts index edd253a57..280ce75a3 100644 --- a/core/src/types/message/messageEntity.ts +++ b/core/src/types/message/messageEntity.ts @@ -43,6 +43,9 @@ export type ThreadMessage = { * @data_transfer_object */ export type MessageRequest = { + /** + * The id of the message request. + */ id?: string /** @@ -71,6 +74,11 @@ export type MessageRequest = { // TODO: deprecate threadId field thread?: Thread + /** + * ChatCompletion tools + */ + tools?: MessageTool[] + /** Engine name to process */ engine?: string @@ -78,6 +86,24 @@ export type MessageRequest = { type?: string } +/** + * ChatCompletion Tool parameters + */ +export type MessageTool = { + type: string + function: MessageFunction +} + +/** + * ChatCompletion Tool's function parameters + */ +export type MessageFunction = { + name: string + description?: string + parameters?: Record + strict?: boolean +} + /** * The status of the message. * @data_transfer_object diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index f42dbdc57..759a3ebb7 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -1,19 +1,31 @@ import { useEffect, useRef } from 'react' import { - ChatCompletionRole, MessageRequestType, ExtensionTypeEnum, Thread, ThreadMessage, Model, ConversationalExtension, - EngineManager, ThreadAssistantInfo, - InferenceEngine, + events, + MessageEvent, + ContentType, } from '@janhq/core' import { extractInferenceParams, extractModelLoadParams } from '@janhq/core' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' +import { OpenAI } from 'openai' + +import { + ChatCompletionMessage, + ChatCompletionMessageParam, + ChatCompletionRole, + ChatCompletionTool, +} from 'openai/resources/chat' + +import { Tool } from 'openai/resources/responses/responses' + +import { ulid } from 'ulidx' import { modelDropdownStateAtom } from '@/containers/ModelDropdown' import { @@ -99,7 +111,7 @@ export default function useSendChatMessage() { const newConvoData = Array.from(currentMessages) let toSendMessage = newConvoData.pop() - while (toSendMessage && toSendMessage?.role !== ChatCompletionRole.User) { + while (toSendMessage && toSendMessage?.role !== 'user') { await extensionManager .get(ExtensionTypeEnum.Conversational) ?.deleteMessage(toSendMessage.thread_id, toSendMessage.id) @@ -172,7 +184,16 @@ export default function useSendChatMessage() { parameters: runtimeParams, }, activeThreadRef.current, - messages ?? currentMessages + messages ?? currentMessages, + (await window.core.api.getTools())?.map((tool) => ({ + type: 'function' as const, + function: { + name: tool.name, + description: tool.description?.slice(0, 1024), + parameters: tool.inputSchema, + strict: false, + }, + })) ).addSystemMessage(activeAssistantRef.current?.instructions) requestBuilder.pushMessage(prompt, base64Blob, fileUpload) @@ -228,10 +249,118 @@ export default function useSendChatMessage() { } setIsGeneratingResponse(true) - // Request for inference - EngineManager.instance() - .get(InferenceEngine.cortex) - ?.inference(requestBuilder.build()) + let isDone = false + const openai = new OpenAI({ + apiKey: await window.core.api.appToken(), + baseURL: `${API_BASE_URL}/v1`, + dangerouslyAllowBrowser: true, + }) + while (!isDone) { + const data = requestBuilder.build() + const response = await openai.chat.completions.create({ + messages: (data.messages ?? []).map((e) => { + return { + role: e.role as ChatCompletionRole, + content: e.content, + } + }) as ChatCompletionMessageParam[], + model: data.model?.id ?? '', + tools: data.tools as ChatCompletionTool[], + stream: false, + }) + if (response.choices[0]?.message.content) { + const newMessage: ThreadMessage = { + id: ulid(), + object: 'message', + thread_id: activeThreadRef.current.id, + assistant_id: activeAssistantRef.current.assistant_id, + attachments: [], + role: response.choices[0].message.role as any, + content: [ + { + type: ContentType.Text, + text: { + value: response.choices[0].message.content + ? (response.choices[0].message.content as any) + : '', + annotations: [], + }, + }, + ], + status: 'ready' as any, + created_at: Date.now(), + completed_at: Date.now(), + } + requestBuilder.pushAssistantMessage( + (response.choices[0].message.content as any) ?? '' + ) + events.emit(MessageEvent.OnMessageUpdate, newMessage) + } + + if (response.choices[0]?.message.tool_calls) { + for (const toolCall of response.choices[0].message.tool_calls) { + const id = ulid() + const toolMessage: ThreadMessage = { + id: id, + object: 'message', + thread_id: activeThreadRef.current.id, + assistant_id: activeAssistantRef.current.assistant_id, + attachments: [], + role: 'assistant' as any, + content: [ + { + type: ContentType.Text, + text: { + value: `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`, + annotations: [], + }, + }, + ], + status: 'pending' as any, + created_at: Date.now(), + completed_at: Date.now(), + } + events.emit(MessageEvent.OnMessageUpdate, toolMessage) + const result = await window.core.api.callTool({ + toolName: toolCall.function.name, + arguments: JSON.parse(toolCall.function.arguments), + }) + if (result.error) { + console.error(result.error) + break + } + const message: ThreadMessage = { + id: id, + object: 'message', + thread_id: activeThreadRef.current.id, + assistant_id: activeAssistantRef.current.assistant_id, + attachments: [], + role: 'assistant' as any, + content: [ + { + type: ContentType.Text, + text: { + value: + `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}` + + (result.content[0]?.text ?? ''), + annotations: [], + }, + }, + ], + status: 'ready' as any, + created_at: Date.now(), + completed_at: Date.now(), + } + requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '') + requestBuilder.pushMessage('Go for the next step') + events.emit(MessageEvent.OnMessageUpdate, message) + } + } + + isDone = + !response.choices[0]?.message.tool_calls || + !response.choices[0]?.message.tool_calls.length + } // Reset states setReloadModel(false) diff --git a/web/package.json b/web/package.json index 21ce50cd0..369a4ead0 100644 --- a/web/package.json +++ b/web/package.json @@ -36,6 +36,7 @@ "marked": "^9.1.2", "next": "14.2.3", "next-themes": "^0.2.1", + "openai": "^4.90.0", "postcss": "8.4.31", "postcss-url": "10.1.3", "posthog-js": "^1.194.6", diff --git a/web/utils/messageRequestBuilder.ts b/web/utils/messageRequestBuilder.ts index 536abde0a..546370f10 100644 --- a/web/utils/messageRequestBuilder.ts +++ b/web/utils/messageRequestBuilder.ts @@ -6,6 +6,7 @@ import { ChatCompletionRole, MessageRequest, MessageRequestType, + MessageTool, ModelInfo, Thread, ThreadMessage, @@ -22,12 +23,14 @@ export class MessageRequestBuilder { messages: ChatCompletionMessage[] model: ModelInfo thread: Thread + tools?: MessageTool[] constructor( type: MessageRequestType, model: ModelInfo, thread: Thread, - messages: ThreadMessage[] + messages: ThreadMessage[], + tools?: MessageTool[] ) { this.msgId = ulid() this.type = type @@ -39,14 +42,20 @@ export class MessageRequestBuilder { role: msg.role, content: msg.content[0]?.text?.value ?? '.', })) + this.tools = tools } + pushAssistantMessage(message: string) { + this.messages = [ + ...this.messages, + { + role: ChatCompletionRole.Assistant, + content: message, + }, + ] + } // Chainable - pushMessage( - message: string, - base64Blob: string | undefined, - fileInfo?: FileInfo - ) { + pushMessage(message: string, base64Blob?: string, fileInfo?: FileInfo) { if (base64Blob && fileInfo?.type === 'pdf') return this.addDocMessage(message, fileInfo?.name) else if (base64Blob && fileInfo?.type === 'image') { @@ -188,6 +197,7 @@ export class MessageRequestBuilder { messages: this.normalizeMessages(this.messages), model: this.model, thread: this.thread, + tools: this.tools, } } }