From 3dd80841c2ea2a4f9e1dfa23249af4e3dc0f2c1a Mon Sep 17 00:00:00 2001 From: Louis Date: Sun, 30 Mar 2025 17:56:39 +0700 Subject: [PATCH 1/7] feat: Jan Tool Use - MCP frontend implementation --- .../browser/extensions/engines/AIEngine.ts | 5 +- .../browser/extensions/engines/OAIEngine.ts | 3 +- core/src/types/message/messageEntity.ts | 26 ++++ web/hooks/useSendChatMessage.ts | 147 ++++++++++++++++-- web/package.json | 1 + web/utils/messageRequestBuilder.ts | 22 ++- 6 files changed, 186 insertions(+), 18 deletions(-) diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 2d1bdb3c2..6616208ad 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -40,12 +40,13 @@ export abstract class AIEngine extends BaseExtension { * Stops the model. */ async unloadModel(model?: Model): Promise { - if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve() + if (model?.engine && model.engine.toString() !== this.provider) + return Promise.resolve() events.emit(ModelEvent.OnModelStopped, model ?? {}) return Promise.resolve() } - /* + /** * Inference request */ inference(data: MessageRequest) {} diff --git a/core/src/browser/extensions/engines/OAIEngine.ts b/core/src/browser/extensions/engines/OAIEngine.ts index 61032357c..50940a97b 100644 --- a/core/src/browser/extensions/engines/OAIEngine.ts +++ b/core/src/browser/extensions/engines/OAIEngine.ts @@ -76,7 +76,7 @@ export abstract class OAIEngine extends AIEngine { const timestamp = Date.now() / 1000 const message: ThreadMessage = { id: ulid(), - thread_id: data.threadId, + thread_id: data.thread?.id ?? data.threadId, type: data.type, assistant_id: data.assistantId, role: ChatCompletionRole.Assistant, @@ -104,6 +104,7 @@ export abstract class OAIEngine extends AIEngine { messages: data.messages ?? [], model: model.id, stream: true, + tools: data.tools, ...model.parameters, } if (this.transformPayload) { diff --git a/core/src/types/message/messageEntity.ts b/core/src/types/message/messageEntity.ts index edd253a57..280ce75a3 100644 --- a/core/src/types/message/messageEntity.ts +++ b/core/src/types/message/messageEntity.ts @@ -43,6 +43,9 @@ export type ThreadMessage = { * @data_transfer_object */ export type MessageRequest = { + /** + * The id of the message request. + */ id?: string /** @@ -71,6 +74,11 @@ export type MessageRequest = { // TODO: deprecate threadId field thread?: Thread + /** + * ChatCompletion tools + */ + tools?: MessageTool[] + /** Engine name to process */ engine?: string @@ -78,6 +86,24 @@ export type MessageRequest = { type?: string } +/** + * ChatCompletion Tool parameters + */ +export type MessageTool = { + type: string + function: MessageFunction +} + +/** + * ChatCompletion Tool's function parameters + */ +export type MessageFunction = { + name: string + description?: string + parameters?: Record + strict?: boolean +} + /** * The status of the message. * @data_transfer_object diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index f42dbdc57..759a3ebb7 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -1,19 +1,31 @@ import { useEffect, useRef } from 'react' import { - ChatCompletionRole, MessageRequestType, ExtensionTypeEnum, Thread, ThreadMessage, Model, ConversationalExtension, - EngineManager, ThreadAssistantInfo, - InferenceEngine, + events, + MessageEvent, + ContentType, } from '@janhq/core' import { extractInferenceParams, extractModelLoadParams } from '@janhq/core' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' +import { OpenAI } from 'openai' + +import { + ChatCompletionMessage, + ChatCompletionMessageParam, + ChatCompletionRole, + ChatCompletionTool, +} from 'openai/resources/chat' + +import { Tool } from 'openai/resources/responses/responses' + +import { ulid } from 'ulidx' import { modelDropdownStateAtom } from '@/containers/ModelDropdown' import { @@ -99,7 +111,7 @@ export default function useSendChatMessage() { const newConvoData = Array.from(currentMessages) let toSendMessage = newConvoData.pop() - while (toSendMessage && toSendMessage?.role !== ChatCompletionRole.User) { + while (toSendMessage && toSendMessage?.role !== 'user') { await extensionManager .get(ExtensionTypeEnum.Conversational) ?.deleteMessage(toSendMessage.thread_id, toSendMessage.id) @@ -172,7 +184,16 @@ export default function useSendChatMessage() { parameters: runtimeParams, }, activeThreadRef.current, - messages ?? currentMessages + messages ?? currentMessages, + (await window.core.api.getTools())?.map((tool) => ({ + type: 'function' as const, + function: { + name: tool.name, + description: tool.description?.slice(0, 1024), + parameters: tool.inputSchema, + strict: false, + }, + })) ).addSystemMessage(activeAssistantRef.current?.instructions) requestBuilder.pushMessage(prompt, base64Blob, fileUpload) @@ -228,10 +249,118 @@ export default function useSendChatMessage() { } setIsGeneratingResponse(true) - // Request for inference - EngineManager.instance() - .get(InferenceEngine.cortex) - ?.inference(requestBuilder.build()) + let isDone = false + const openai = new OpenAI({ + apiKey: await window.core.api.appToken(), + baseURL: `${API_BASE_URL}/v1`, + dangerouslyAllowBrowser: true, + }) + while (!isDone) { + const data = requestBuilder.build() + const response = await openai.chat.completions.create({ + messages: (data.messages ?? []).map((e) => { + return { + role: e.role as ChatCompletionRole, + content: e.content, + } + }) as ChatCompletionMessageParam[], + model: data.model?.id ?? '', + tools: data.tools as ChatCompletionTool[], + stream: false, + }) + if (response.choices[0]?.message.content) { + const newMessage: ThreadMessage = { + id: ulid(), + object: 'message', + thread_id: activeThreadRef.current.id, + assistant_id: activeAssistantRef.current.assistant_id, + attachments: [], + role: response.choices[0].message.role as any, + content: [ + { + type: ContentType.Text, + text: { + value: response.choices[0].message.content + ? (response.choices[0].message.content as any) + : '', + annotations: [], + }, + }, + ], + status: 'ready' as any, + created_at: Date.now(), + completed_at: Date.now(), + } + requestBuilder.pushAssistantMessage( + (response.choices[0].message.content as any) ?? '' + ) + events.emit(MessageEvent.OnMessageUpdate, newMessage) + } + + if (response.choices[0]?.message.tool_calls) { + for (const toolCall of response.choices[0].message.tool_calls) { + const id = ulid() + const toolMessage: ThreadMessage = { + id: id, + object: 'message', + thread_id: activeThreadRef.current.id, + assistant_id: activeAssistantRef.current.assistant_id, + attachments: [], + role: 'assistant' as any, + content: [ + { + type: ContentType.Text, + text: { + value: `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`, + annotations: [], + }, + }, + ], + status: 'pending' as any, + created_at: Date.now(), + completed_at: Date.now(), + } + events.emit(MessageEvent.OnMessageUpdate, toolMessage) + const result = await window.core.api.callTool({ + toolName: toolCall.function.name, + arguments: JSON.parse(toolCall.function.arguments), + }) + if (result.error) { + console.error(result.error) + break + } + const message: ThreadMessage = { + id: id, + object: 'message', + thread_id: activeThreadRef.current.id, + assistant_id: activeAssistantRef.current.assistant_id, + attachments: [], + role: 'assistant' as any, + content: [ + { + type: ContentType.Text, + text: { + value: + `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}` + + (result.content[0]?.text ?? ''), + annotations: [], + }, + }, + ], + status: 'ready' as any, + created_at: Date.now(), + completed_at: Date.now(), + } + requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '') + requestBuilder.pushMessage('Go for the next step') + events.emit(MessageEvent.OnMessageUpdate, message) + } + } + + isDone = + !response.choices[0]?.message.tool_calls || + !response.choices[0]?.message.tool_calls.length + } // Reset states setReloadModel(false) diff --git a/web/package.json b/web/package.json index 21ce50cd0..369a4ead0 100644 --- a/web/package.json +++ b/web/package.json @@ -36,6 +36,7 @@ "marked": "^9.1.2", "next": "14.2.3", "next-themes": "^0.2.1", + "openai": "^4.90.0", "postcss": "8.4.31", "postcss-url": "10.1.3", "posthog-js": "^1.194.6", diff --git a/web/utils/messageRequestBuilder.ts b/web/utils/messageRequestBuilder.ts index c3da9cbd8..8733aff2c 100644 --- a/web/utils/messageRequestBuilder.ts +++ b/web/utils/messageRequestBuilder.ts @@ -6,6 +6,7 @@ import { ChatCompletionRole, MessageRequest, MessageRequestType, + MessageTool, ModelInfo, Thread, ThreadMessage, @@ -22,12 +23,14 @@ export class MessageRequestBuilder { messages: ChatCompletionMessage[] model: ModelInfo thread: Thread + tools?: MessageTool[] constructor( type: MessageRequestType, model: ModelInfo, thread: Thread, - messages: ThreadMessage[] + messages: ThreadMessage[], + tools?: MessageTool[] ) { this.msgId = ulid() this.type = type @@ -39,14 +42,20 @@ export class MessageRequestBuilder { role: msg.role, content: msg.content[0]?.text?.value ?? '.', })) + this.tools = tools } + pushAssistantMessage(message: string) { + this.messages = [ + ...this.messages, + { + role: ChatCompletionRole.Assistant, + content: message, + }, + ] + } // Chainable - pushMessage( - message: string, - base64Blob: string | undefined, - fileInfo?: FileInfo - ) { + pushMessage(message: string, base64Blob?: string, fileInfo?: FileInfo) { if (base64Blob && fileInfo?.type === 'pdf') return this.addDocMessage(message, fileInfo?.name) else if (base64Blob && fileInfo?.type === 'image') { @@ -167,6 +176,7 @@ export class MessageRequestBuilder { messages: this.normalizeMessages(this.messages), model: this.model, thread: this.thread, + tools: this.tools, } } } From e4658ce98c7101a83fab6806414eb6f092a3c214 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 31 Mar 2025 19:41:15 +0700 Subject: [PATCH 2/7] chore: tool type --- web/hooks/useSendChatMessage.ts | 5 +++-- web/types/model.d.ts | 6 ++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index 759a3ebb7..1c6f77905 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -58,6 +58,7 @@ import { updateThreadAtom, updateThreadWaitingForResponseAtom, } from '@/helpers/atoms/Thread.atom' +import { ModelTool } from '@/types/model' export const reloadModelAtom = atom(false) @@ -185,7 +186,7 @@ export default function useSendChatMessage() { }, activeThreadRef.current, messages ?? currentMessages, - (await window.core.api.getTools())?.map((tool) => ({ + (await window.core.api.getTools())?.map((tool: ModelTool) => ({ type: 'function' as const, function: { name: tool.name, @@ -311,7 +312,7 @@ export default function useSendChatMessage() { { type: ContentType.Text, text: { - value: `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`, + value: `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`, annotations: [], }, }, diff --git a/web/types/model.d.ts b/web/types/model.d.ts index bbe9d2cc6..5da2b7972 100644 --- a/web/types/model.d.ts +++ b/web/types/model.d.ts @@ -2,3 +2,9 @@ * ModelParams types */ export type ModelParams = ModelRuntimeParams | ModelSettingParams + +export type ModelTool = { + name: string + description: string + inputSchema: string +} From c335caeb42ba9441b359b624b177be83cce91aa7 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 2 Apr 2025 10:26:02 +0700 Subject: [PATCH 3/7] refactor: remove lagecy tools --- core/src/browser/tools/index.test.ts | 5 - core/src/browser/tools/index.ts | 2 - core/src/browser/tools/manager.ts | 47 ----- core/src/browser/tools/tool.test.ts | 63 ------ core/src/browser/tools/tool.ts | 12 -- extensions/assistant-extension/src/index.ts | 9 +- .../assistant-extension/src/node/index.ts | 45 ---- .../assistant-extension/src/node/retrieval.ts | 121 ----------- .../src/tools/retrieval.ts | 118 ----------- web/hooks/useSendChatMessage.ts | 196 +++++++++--------- web/services/coreService.ts | 3 +- 11 files changed, 104 insertions(+), 517 deletions(-) delete mode 100644 core/src/browser/tools/index.test.ts delete mode 100644 core/src/browser/tools/index.ts delete mode 100644 core/src/browser/tools/manager.ts delete mode 100644 core/src/browser/tools/tool.test.ts delete mode 100644 core/src/browser/tools/tool.ts delete mode 100644 extensions/assistant-extension/src/node/index.ts delete mode 100644 extensions/assistant-extension/src/node/retrieval.ts delete mode 100644 extensions/assistant-extension/src/tools/retrieval.ts diff --git a/core/src/browser/tools/index.test.ts b/core/src/browser/tools/index.test.ts deleted file mode 100644 index 8a24d3bb6..000000000 --- a/core/src/browser/tools/index.test.ts +++ /dev/null @@ -1,5 +0,0 @@ - - -it('should not throw any errors when imported', () => { - expect(() => require('./index')).not.toThrow(); -}) diff --git a/core/src/browser/tools/index.ts b/core/src/browser/tools/index.ts deleted file mode 100644 index 24cd12780..000000000 --- a/core/src/browser/tools/index.ts +++ /dev/null @@ -1,2 +0,0 @@ -export * from './manager' -export * from './tool' diff --git a/core/src/browser/tools/manager.ts b/core/src/browser/tools/manager.ts deleted file mode 100644 index b323ad7ce..000000000 --- a/core/src/browser/tools/manager.ts +++ /dev/null @@ -1,47 +0,0 @@ -import { AssistantTool, MessageRequest } from '../../types' -import { InferenceTool } from './tool' - -/** - * Manages the registration and retrieval of inference tools. - */ -export class ToolManager { - public tools = new Map() - - /** - * Registers a tool. - * @param tool - The tool to register. - */ - register(tool: T) { - this.tools.set(tool.name, tool) - } - - /** - * Retrieves a tool by it's name. - * @param name - The name of the tool to retrieve. - * @returns The tool, if found. - */ - get(name: string): T | undefined { - return this.tools.get(name) as T | undefined - } - - /* - ** Process the message request with the tools. - */ - process(request: MessageRequest, tools: AssistantTool[]): Promise { - return tools.reduce((prevPromise, currentTool) => { - return prevPromise.then((prevResult) => { - return currentTool.enabled - ? this.get(currentTool.type)?.process(prevResult, currentTool) ?? - Promise.resolve(prevResult) - : Promise.resolve(prevResult) - }) - }, Promise.resolve(request)) - } - - /** - * The instance of the tool manager. - */ - static instance(): ToolManager { - return (window.core?.toolManager as ToolManager) ?? new ToolManager() - } -} diff --git a/core/src/browser/tools/tool.test.ts b/core/src/browser/tools/tool.test.ts deleted file mode 100644 index dcb478478..000000000 --- a/core/src/browser/tools/tool.test.ts +++ /dev/null @@ -1,63 +0,0 @@ -import { ToolManager } from '../../browser/tools/manager' -import { InferenceTool } from '../../browser/tools/tool' -import { AssistantTool, MessageRequest } from '../../types' - -class MockInferenceTool implements InferenceTool { - name = 'mockTool' - process(request: MessageRequest, tool: AssistantTool): Promise { - return Promise.resolve(request) - } -} - -it('should register a tool', () => { - const manager = new ToolManager() - const tool = new MockInferenceTool() - manager.register(tool) - expect(manager.get(tool.name)).toBe(tool) -}) - -it('should retrieve a tool by its name', () => { - const manager = new ToolManager() - const tool = new MockInferenceTool() - manager.register(tool) - const retrievedTool = manager.get(tool.name) - expect(retrievedTool).toBe(tool) -}) - -it('should return undefined for a non-existent tool', () => { - const manager = new ToolManager() - const retrievedTool = manager.get('nonExistentTool') - expect(retrievedTool).toBeUndefined() -}) - -it('should process the message request with enabled tools', async () => { - const manager = new ToolManager() - const tool = new MockInferenceTool() - manager.register(tool) - - const request: MessageRequest = { message: 'test' } as any - const tools: AssistantTool[] = [{ type: 'mockTool', enabled: true }] as any - - const result = await manager.process(request, tools) - expect(result).toBe(request) -}) - -it('should skip processing for disabled tools', async () => { - const manager = new ToolManager() - const tool = new MockInferenceTool() - manager.register(tool) - - const request: MessageRequest = { message: 'test' } as any - const tools: AssistantTool[] = [{ type: 'mockTool', enabled: false }] as any - - const result = await manager.process(request, tools) - expect(result).toBe(request) -}) - -it('should throw an error when process is called without implementation', () => { - class TestTool extends InferenceTool { - name = 'testTool' - } - const tool = new TestTool() - expect(() => tool.process({} as MessageRequest)).toThrowError() -}) diff --git a/core/src/browser/tools/tool.ts b/core/src/browser/tools/tool.ts deleted file mode 100644 index 0fd342933..000000000 --- a/core/src/browser/tools/tool.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { AssistantTool, MessageRequest } from '../../types' - -/** - * Represents a base inference tool. - */ -export abstract class InferenceTool { - abstract name: string - /* - ** Process a message request and return the processed message request. - */ - abstract process(request: MessageRequest, tool?: AssistantTool): Promise -} diff --git a/extensions/assistant-extension/src/index.ts b/extensions/assistant-extension/src/index.ts index 0f9b52808..bb253bd7f 100644 --- a/extensions/assistant-extension/src/index.ts +++ b/extensions/assistant-extension/src/index.ts @@ -1,12 +1,7 @@ -import { Assistant, AssistantExtension, ToolManager } from '@janhq/core' -import { RetrievalTool } from './tools/retrieval' +import { Assistant, AssistantExtension } from '@janhq/core' export default class JanAssistantExtension extends AssistantExtension { - - async onLoad() { - // Register the retrieval tool - ToolManager.instance().register(new RetrievalTool()) - } + async onLoad() {} /** * Called when the extension is unloaded. diff --git a/extensions/assistant-extension/src/node/index.ts b/extensions/assistant-extension/src/node/index.ts deleted file mode 100644 index 731890b34..000000000 --- a/extensions/assistant-extension/src/node/index.ts +++ /dev/null @@ -1,45 +0,0 @@ -import { getJanDataFolderPath } from '@janhq/core/node' -import { retrieval } from './retrieval' -import path from 'path' - -export function toolRetrievalUpdateTextSplitter( - chunkSize: number, - chunkOverlap: number -) { - retrieval.updateTextSplitter(chunkSize, chunkOverlap) -} -export async function toolRetrievalIngestNewDocument( - thread: string, - file: string, - model: string, - engine: string, - useTimeWeighted: boolean -) { - const threadPath = path.join(getJanDataFolderPath(), 'threads', thread) - const filePath = path.join(getJanDataFolderPath(), 'files', file) - retrieval.updateEmbeddingEngine(model, engine) - return retrieval - .ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted) - .catch((err) => { - console.error(err) - }) -} - -export async function toolRetrievalLoadThreadMemory(threadId: string) { - return retrieval - .loadRetrievalAgent( - path.join(getJanDataFolderPath(), 'threads', threadId, 'memory') - ) - .catch((err) => { - console.error(err) - }) -} - -export async function toolRetrievalQueryResult( - query: string, - useTimeWeighted: boolean = false -) { - return retrieval.generateResult(query, useTimeWeighted).catch((err) => { - console.error(err) - }) -} diff --git a/extensions/assistant-extension/src/node/retrieval.ts b/extensions/assistant-extension/src/node/retrieval.ts deleted file mode 100644 index 0e80fd2d7..000000000 --- a/extensions/assistant-extension/src/node/retrieval.ts +++ /dev/null @@ -1,121 +0,0 @@ -import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter' -import { formatDocumentsAsString } from 'langchain/util/document' -import { PDFLoader } from 'langchain/document_loaders/fs/pdf' - -import { TimeWeightedVectorStoreRetriever } from 'langchain/retrievers/time_weighted' -import { MemoryVectorStore } from 'langchain/vectorstores/memory' - -import { HNSWLib } from 'langchain/vectorstores/hnswlib' - -import { OpenAIEmbeddings } from 'langchain/embeddings/openai' - -export class Retrieval { - public chunkSize: number = 100 - public chunkOverlap?: number = 0 - private retriever: any - - private embeddingModel?: OpenAIEmbeddings = undefined - private textSplitter?: RecursiveCharacterTextSplitter - - // to support time-weighted retrieval - private timeWeightedVectorStore: MemoryVectorStore - private timeWeightedretriever: any | TimeWeightedVectorStoreRetriever - - constructor(chunkSize: number = 4000, chunkOverlap: number = 200) { - this.updateTextSplitter(chunkSize, chunkOverlap) - this.initialize() - } - - private async initialize() { - const apiKey = await window.core?.api.appToken() - - // declare time-weighted retriever and storage - this.timeWeightedVectorStore = new MemoryVectorStore( - new OpenAIEmbeddings( - { openAIApiKey: apiKey }, - { basePath: `${CORTEX_API_URL}/v1` } - ) - ) - this.timeWeightedretriever = new TimeWeightedVectorStoreRetriever({ - vectorStore: this.timeWeightedVectorStore, - memoryStream: [], - searchKwargs: 2, - }) - } - - public updateTextSplitter(chunkSize: number, chunkOverlap: number): void { - this.chunkSize = chunkSize - this.chunkOverlap = chunkOverlap - this.textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize: chunkSize, - chunkOverlap: chunkOverlap, - }) - } - - public async updateEmbeddingEngine(model: string, engine: string) { - const apiKey = await window.core?.api.appToken() - this.embeddingModel = new OpenAIEmbeddings( - { openAIApiKey: apiKey, model }, - // TODO: Raw settings - { basePath: `${CORTEX_API_URL}/v1` } - ) - - // update time-weighted embedding model - this.timeWeightedVectorStore.embeddings = this.embeddingModel - } - - public ingestAgentKnowledge = async ( - filePath: string, - memoryPath: string, - useTimeWeighted: boolean - ): Promise => { - const loader = new PDFLoader(filePath, { - splitPages: true, - }) - if (!this.embeddingModel) return Promise.reject() - const doc = await loader.load() - const docs = await this.textSplitter!.splitDocuments(doc) - const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel) - - // add documents with metadata by using the time-weighted retriever in order to support time-weighted retrieval - if (useTimeWeighted && this.timeWeightedretriever) { - await ( - this.timeWeightedretriever as TimeWeightedVectorStoreRetriever - ).addDocuments(docs) - } - return vectorStore.save(memoryPath) - } - - public loadRetrievalAgent = async (memoryPath: string): Promise => { - if (!this.embeddingModel) return Promise.reject() - const vectorStore = await HNSWLib.load(memoryPath, this.embeddingModel) - this.retriever = vectorStore.asRetriever(2) - return Promise.resolve() - } - - public generateResult = async ( - query: string, - useTimeWeighted: boolean - ): Promise => { - if (useTimeWeighted) { - if (!this.timeWeightedretriever) { - return Promise.resolve(' ') - } - // use invoke because getRelevantDocuments is deprecated - const relevantDocs = await this.timeWeightedretriever.invoke(query) - const serializedDoc = formatDocumentsAsString(relevantDocs) - return Promise.resolve(serializedDoc) - } - - if (!this.retriever) { - return Promise.resolve(' ') - } - - // should use invoke(query) because getRelevantDocuments is deprecated - const relevantDocs = await this.retriever.getRelevantDocuments(query) - const serializedDoc = formatDocumentsAsString(relevantDocs) - return Promise.resolve(serializedDoc) - } -} - -export const retrieval = new Retrieval() diff --git a/extensions/assistant-extension/src/tools/retrieval.ts b/extensions/assistant-extension/src/tools/retrieval.ts deleted file mode 100644 index b1a0c3cba..000000000 --- a/extensions/assistant-extension/src/tools/retrieval.ts +++ /dev/null @@ -1,118 +0,0 @@ -import { - AssistantTool, - executeOnMain, - fs, - InferenceTool, - joinPath, - MessageRequest, -} from '@janhq/core' - -export class RetrievalTool extends InferenceTool { - private _threadDir = 'file://threads' - private retrievalThreadId: string | undefined = undefined - - name: string = 'retrieval' - - async process( - data: MessageRequest, - tool?: AssistantTool - ): Promise { - if (!data.model || !data.messages) { - return Promise.resolve(data) - } - - const latestMessage = data.messages[data.messages.length - 1] - - // 1. Ingest the document if needed - if ( - latestMessage && - latestMessage.content && - typeof latestMessage.content !== 'string' && - latestMessage.content.length > 1 - ) { - const docFile = latestMessage.content[1]?.doc_url?.url - if (docFile) { - await executeOnMain( - NODE, - 'toolRetrievalIngestNewDocument', - data.thread?.id, - docFile, - data.model?.id, - data.model?.engine, - tool?.useTimeWeightedRetriever ?? false - ) - } else { - return Promise.resolve(data) - } - } else if ( - // Check whether we need to ingest document or not - // Otherwise wrong context will be sent - !(await fs.existsSync( - await joinPath([this._threadDir, data.threadId, 'memory']) - )) - ) { - // No document ingested, reroute the result to inference engine - - return Promise.resolve(data) - } - // 2. Load agent on thread changed - if (this.retrievalThreadId !== data.threadId) { - await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId) - - this.retrievalThreadId = data.threadId - - // Update the text splitter - await executeOnMain( - NODE, - 'toolRetrievalUpdateTextSplitter', - tool?.settings?.chunk_size ?? 4000, - tool?.settings?.chunk_overlap ?? 200 - ) - } - - // 3. Using the retrieval template with the result and query - if (latestMessage.content) { - const prompt = - typeof latestMessage.content === 'string' - ? latestMessage.content - : latestMessage.content[0].text - // Retrieve the result - const retrievalResult = await executeOnMain( - NODE, - 'toolRetrievalQueryResult', - prompt, - tool?.useTimeWeightedRetriever ?? false - ) - console.debug('toolRetrievalQueryResult', retrievalResult) - - // Update message content - if (retrievalResult) - data.messages[data.messages.length - 1].content = - tool?.settings?.retrieval_template - ?.replace('{CONTEXT}', retrievalResult) - .replace('{QUESTION}', prompt) - } - - // 4. Reroute the result to inference engine - return Promise.resolve(this.normalize(data)) - } - - // Filter out all the messages that are not text - // TODO: Remove it until engines can handle multiple content types - normalize(request: MessageRequest): MessageRequest { - request.messages = request.messages?.map((message) => { - if ( - message.content && - typeof message.content !== 'string' && - (message.content.length ?? 0) > 0 - ) { - return { - ...message, - content: [message.content[0]], - } - } - return message - }) - return request - } -} diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index 1c6f77905..1c9334fb5 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -11,20 +11,19 @@ import { events, MessageEvent, ContentType, + EngineManager, + InferenceEngine, } from '@janhq/core' import { extractInferenceParams, extractModelLoadParams } from '@janhq/core' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { OpenAI } from 'openai' import { - ChatCompletionMessage, ChatCompletionMessageParam, ChatCompletionRole, ChatCompletionTool, } from 'openai/resources/chat' -import { Tool } from 'openai/resources/responses/responses' - import { ulid } from 'ulidx' import { modelDropdownStateAtom } from '@/containers/ModelDropdown' @@ -250,100 +249,41 @@ export default function useSendChatMessage() { } setIsGeneratingResponse(true) - let isDone = false - const openai = new OpenAI({ - apiKey: await window.core.api.appToken(), - baseURL: `${API_BASE_URL}/v1`, - dangerouslyAllowBrowser: true, - }) - while (!isDone) { - const data = requestBuilder.build() - const response = await openai.chat.completions.create({ - messages: (data.messages ?? []).map((e) => { - return { - role: e.role as ChatCompletionRole, - content: e.content, - } - }) as ChatCompletionMessageParam[], - model: data.model?.id ?? '', - tools: data.tools as ChatCompletionTool[], - stream: false, + if (requestBuilder.tools && requestBuilder.tools.length) { + let isDone = false + const openai = new OpenAI({ + apiKey: await window.core.api.appToken(), + baseURL: `${API_BASE_URL}/v1`, + dangerouslyAllowBrowser: true, }) - if (response.choices[0]?.message.content) { - const newMessage: ThreadMessage = { - id: ulid(), - object: 'message', - thread_id: activeThreadRef.current.id, - assistant_id: activeAssistantRef.current.assistant_id, - attachments: [], - role: response.choices[0].message.role as any, - content: [ - { - type: ContentType.Text, - text: { - value: response.choices[0].message.content - ? (response.choices[0].message.content as any) - : '', - annotations: [], - }, - }, - ], - status: 'ready' as any, - created_at: Date.now(), - completed_at: Date.now(), - } - requestBuilder.pushAssistantMessage( - (response.choices[0].message.content as any) ?? '' - ) - events.emit(MessageEvent.OnMessageUpdate, newMessage) - } - - if (response.choices[0]?.message.tool_calls) { - for (const toolCall of response.choices[0].message.tool_calls) { - const id = ulid() - const toolMessage: ThreadMessage = { - id: id, + while (!isDone) { + const data = requestBuilder.build() + const response = await openai.chat.completions.create({ + messages: (data.messages ?? []).map((e) => { + return { + role: e.role as ChatCompletionRole, + content: e.content, + } + }) as ChatCompletionMessageParam[], + model: data.model?.id ?? '', + tools: data.tools as ChatCompletionTool[], + stream: false, + }) + if (response.choices[0]?.message.content) { + const newMessage: ThreadMessage = { + id: ulid(), object: 'message', thread_id: activeThreadRef.current.id, assistant_id: activeAssistantRef.current.assistant_id, attachments: [], - role: 'assistant' as any, + role: response.choices[0].message.role as any, content: [ { type: ContentType.Text, text: { - value: `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`, - annotations: [], - }, - }, - ], - status: 'pending' as any, - created_at: Date.now(), - completed_at: Date.now(), - } - events.emit(MessageEvent.OnMessageUpdate, toolMessage) - const result = await window.core.api.callTool({ - toolName: toolCall.function.name, - arguments: JSON.parse(toolCall.function.arguments), - }) - if (result.error) { - console.error(result.error) - break - } - const message: ThreadMessage = { - id: id, - object: 'message', - thread_id: activeThreadRef.current.id, - assistant_id: activeAssistantRef.current.assistant_id, - attachments: [], - role: 'assistant' as any, - content: [ - { - type: ContentType.Text, - text: { - value: - `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}` + - (result.content[0]?.text ?? ''), + value: response.choices[0].message.content + ? (response.choices[0].message.content as any) + : '', annotations: [], }, }, @@ -352,15 +292,81 @@ export default function useSendChatMessage() { created_at: Date.now(), completed_at: Date.now(), } - requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '') - requestBuilder.pushMessage('Go for the next step') - events.emit(MessageEvent.OnMessageUpdate, message) + requestBuilder.pushAssistantMessage( + (response.choices[0].message.content as any) ?? '' + ) + events.emit(MessageEvent.OnMessageUpdate, newMessage) } - } - isDone = - !response.choices[0]?.message.tool_calls || - !response.choices[0]?.message.tool_calls.length + if (response.choices[0]?.message.tool_calls) { + for (const toolCall of response.choices[0].message.tool_calls) { + const id = ulid() + const toolMessage: ThreadMessage = { + id: id, + object: 'message', + thread_id: activeThreadRef.current.id, + assistant_id: activeAssistantRef.current.assistant_id, + attachments: [], + role: 'assistant' as any, + content: [ + { + type: ContentType.Text, + text: { + value: `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`, + annotations: [], + }, + }, + ], + status: 'pending' as any, + created_at: Date.now(), + completed_at: Date.now(), + } + events.emit(MessageEvent.OnMessageUpdate, toolMessage) + const result = await window.core.api.callTool({ + toolName: toolCall.function.name, + arguments: JSON.parse(toolCall.function.arguments), + }) + if (result.error) { + console.error(result.error) + break + } + const message: ThreadMessage = { + id: id, + object: 'message', + thread_id: activeThreadRef.current.id, + assistant_id: activeAssistantRef.current.assistant_id, + attachments: [], + role: 'assistant' as any, + content: [ + { + type: ContentType.Text, + text: { + value: + `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}` + + (result.content[0]?.text ?? ''), + annotations: [], + }, + }, + ], + status: 'ready' as any, + created_at: Date.now(), + completed_at: Date.now(), + } + requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '') + requestBuilder.pushMessage('Go for the next step') + events.emit(MessageEvent.OnMessageUpdate, message) + } + } + + isDone = + !response.choices[0]?.message.tool_calls || + !response.choices[0]?.message.tool_calls.length + } + } else { + // Request for inference + EngineManager.instance() + .get(InferenceEngine.cortex) + ?.inference(requestBuilder.build()) } // Reset states diff --git a/web/services/coreService.ts b/web/services/coreService.ts index cda83ceaf..4143facf3 100644 --- a/web/services/coreService.ts +++ b/web/services/coreService.ts @@ -1,4 +1,4 @@ -import { EngineManager, ToolManager } from '@janhq/core' +import { EngineManager } from '@janhq/core' import { appService } from './appService' import { EventEmitter } from './eventsService' @@ -16,7 +16,6 @@ export const setupCoreServices = () => { window.core = { events: new EventEmitter(), engineManager: new EngineManager(), - toolManager: new ToolManager(), api: { ...(window.electronAPI ?? (IS_TAURI ? tauriAPI : restAPI)), ...appService, From 8b1709c14f66ba538ad4edee8b172b82cd66e621 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 2 Apr 2025 10:32:42 +0700 Subject: [PATCH 4/7] refactor: clean up assistant extension to remove node modules --- core/src/browser/index.ts | 6 ------ extensions/assistant-extension/package.json | 19 +++---------------- .../assistant-extension/rolldown.config.mjs | 19 +------------------ 3 files changed, 4 insertions(+), 40 deletions(-) diff --git a/core/src/browser/index.ts b/core/src/browser/index.ts index a6ce187ca..5912d8c3b 100644 --- a/core/src/browser/index.ts +++ b/core/src/browser/index.ts @@ -28,12 +28,6 @@ export * from './extension' */ export * from './extensions' -/** - * Export all base tools. - * @module - */ -export * from './tools' - /** * Export all base models. * @module diff --git a/extensions/assistant-extension/package.json b/extensions/assistant-extension/package.json index 08ccb3b3d..4761aa900 100644 --- a/extensions/assistant-extension/package.json +++ b/extensions/assistant-extension/package.json @@ -8,17 +8,10 @@ "author": "Jan ", "license": "AGPL-3.0", "scripts": { - "clean:modules": "rimraf node_modules/pdf-parse/test && cd node_modules/pdf-parse/lib/pdf.js && rimraf v1.9.426 v1.10.88 v2.0.550", - "build-universal-hnswlib": "[ \"$IS_TEST\" = \"true\" ] && echo \"Skip universal build\" || (cd node_modules/hnswlib-node && arch -x86_64 npx node-gyp rebuild --arch=x64 && mv build/Release/addon.node ./addon-amd64.node && node-gyp rebuild --arch=arm64 && mv build/Release/addon.node ./addon-arm64.node && lipo -create -output build/Release/addon.node ./addon-arm64.node ./addon-amd64.node && rm ./addon-arm64.node && rm ./addon-amd64.node)", - "build": "yarn clean:modules && rolldown -c rolldown.config.mjs", - "build:publish:linux": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install", - "build:publish:darwin": "rimraf *.tgz --glob || true && yarn build-universal-hnswlib && yarn build && ../../.github/scripts/auto-sign.sh && npm pack && cpx *.tgz ../../pre-install", - "build:publish:win32": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install", - "build:publish": "run-script-os", - "build:dev": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install" + "build": "rolldown -c rolldown.config.mjs", + "build:publish": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install" }, "devDependencies": { - "@types/pdf-parse": "^1.1.4", "cpx": "^1.5.0", "rimraf": "^3.0.2", "rolldown": "1.0.0-beta.1", @@ -27,11 +20,6 @@ }, "dependencies": { "@janhq/core": "../../core/package.tgz", - "@langchain/community": "0.0.13", - "hnswlib-node": "^1.4.2", - "langchain": "^0.0.214", - "node-gyp": "^11.0.0", - "pdf-parse": "^1.1.1", "ts-loader": "^9.5.0" }, "files": [ @@ -40,8 +28,7 @@ "README.md" ], "bundleDependencies": [ - "@janhq/core", - "hnswlib-node" + "@janhq/core" ], "installConfig": { "hoistingLimits": "workspaces" diff --git a/extensions/assistant-extension/rolldown.config.mjs b/extensions/assistant-extension/rolldown.config.mjs index e549ea7d9..436de93a8 100644 --- a/extensions/assistant-extension/rolldown.config.mjs +++ b/extensions/assistant-extension/rolldown.config.mjs @@ -13,22 +13,5 @@ export default defineConfig([ NODE: JSON.stringify(`${pkgJson.name}/${pkgJson.node}`), VERSION: JSON.stringify(pkgJson.version), }, - }, - { - input: 'src/node/index.ts', - external: ['@janhq/core/node', 'path', 'hnswlib-node'], - output: { - format: 'cjs', - file: 'dist/node/index.js', - sourcemap: false, - inlineDynamicImports: true, - }, - resolve: { - extensions: ['.js', '.ts'], - }, - define: { - CORTEX_API_URL: JSON.stringify(`http://127.0.0.1:${process.env.CORTEX_API_PORT ?? "39291"}`), - }, - platform: 'node', - }, + } ]) From 7392b2f92b2a1db7bacca6b4f056c30c4061ec6b Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 2 Apr 2025 23:53:54 +0700 Subject: [PATCH 5/7] chore: app updater --- src-tauri/Cargo.toml | 3 +++ src-tauri/src/core/mcp.rs | 17 ++++++++++------- src-tauri/src/core/mod.rs | 4 ++-- src-tauri/src/core/state.rs | 2 +- src-tauri/tauri.conf.json | 9 +++++++++ .../ModalAppUpdaterChangelog/index.tsx | 19 ++++++++++++++++--- web/package.json | 1 + 7 files changed, 42 insertions(+), 13 deletions(-) diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 312392eb5..20df85561 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -43,3 +43,6 @@ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "mai "transport-child-process", "tower", ] } + +[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] +tauri-plugin-updater = "2" diff --git a/src-tauri/src/core/mcp.rs b/src-tauri/src/core/mcp.rs index c016effd7..98bf0ad3b 100644 --- a/src-tauri/src/core/mcp.rs +++ b/src-tauri/src/core/mcp.rs @@ -30,16 +30,19 @@ pub async fn run_mcp_commands( if let Some(server_map) = mcp_servers.get("mcpServers").and_then(Value::as_object) { println!("MCP Servers: {server_map:#?}"); - + for (name, config) in server_map { if let Some((command, args)) = extract_command_args(config) { let mut cmd = Command::new(command); - args.iter().filter_map(Value::as_str).for_each(|arg| { cmd.arg(arg); }); - - let service = ().serve(TokioChildProcess::new(&mut cmd).map_err(|e| e.to_string())?) - .await - .map_err(|e| e.to_string())?; - + args.iter().filter_map(Value::as_str).for_each(|arg| { + cmd.arg(arg); + }); + + let service = + ().serve(TokioChildProcess::new(&mut cmd).map_err(|e| e.to_string())?) + .await + .map_err(|e| e.to_string())?; + servers_state.lock().await.insert(name.clone(), service); } } diff --git a/src-tauri/src/core/mod.rs b/src-tauri/src/core/mod.rs index baa8c2834..e4f0ee6c4 100644 --- a/src-tauri/src/core/mod.rs +++ b/src-tauri/src/core/mod.rs @@ -1,6 +1,6 @@ pub mod cmd; pub mod fs; +pub mod mcp; +pub mod server; pub mod setup; pub mod state; -pub mod server; -pub mod mcp; \ No newline at end of file diff --git a/src-tauri/src/core/state.rs b/src-tauri/src/core/state.rs index 93d770bc2..925030085 100644 --- a/src-tauri/src/core/state.rs +++ b/src-tauri/src/core/state.rs @@ -7,7 +7,7 @@ use tokio::sync::Mutex; #[derive(Default)] pub struct AppState { pub app_token: Option, - pub mcp_servers: Arc>>> + pub mcp_servers: Arc>>>, } pub fn generate_app_token() -> String { rand::thread_rng() diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json index d59949f8b..3b44dc079 100644 --- a/src-tauri/tauri.conf.json +++ b/src-tauri/tauri.conf.json @@ -40,9 +40,18 @@ } } }, + "plugins": { + "updater": { + "pubkey": "", + "endpoints": [ + "https://github.com/menloresearch/jan/releases/latest/download/latest.json" + ] + } + }, "bundle": { "active": true, "targets": "all", + "createUpdaterArtifacts": true, "icon": [ "icons/32x32.png", "icons/128x128.png", diff --git a/web/containers/ModalAppUpdaterChangelog/index.tsx b/web/containers/ModalAppUpdaterChangelog/index.tsx index 705623a90..fa519780c 100644 --- a/web/containers/ModalAppUpdaterChangelog/index.tsx +++ b/web/containers/ModalAppUpdaterChangelog/index.tsx @@ -1,7 +1,8 @@ -import React, { useEffect, useState } from 'react' +import React, { useEffect, useRef, useState } from 'react' import { Button, Modal } from '@janhq/joi' +import { check, Update } from '@tauri-apps/plugin-updater' import { useAtom } from 'jotai' import { useGetLatestRelease } from '@/hooks/useGetLatestRelease' @@ -16,6 +17,7 @@ const ModalAppUpdaterChangelog = () => { const [appUpdateAvailable, setAppUpdateAvailable] = useAtom( appUpdateAvailableAtom ) + const updaterRef = useRef(null) const [open, setOpen] = useState(appUpdateAvailable) @@ -26,6 +28,17 @@ const ModalAppUpdaterChangelog = () => { const beta = VERSION.includes('beta') const nightly = VERSION.includes('-') + const checkForUpdate = async () => { + const update = await check() + if (update) { + setAppUpdateAvailable(true) + updaterRef.current = update + } + } + useEffect(() => { + checkForUpdate() + }, []) + const { release } = useGetLatestRelease(beta ? true : false) return ( @@ -73,8 +86,8 @@ const ModalAppUpdaterChangelog = () => {