From c335caeb42ba9441b359b624b177be83cce91aa7 Mon Sep 17 00:00:00 2001
From: Louis
Date: Wed, 2 Apr 2025 10:26:02 +0700
Subject: [PATCH] refactor: remove legacy tools

---
 core/src/browser/tools/index.test.ts          |   5 -
 core/src/browser/tools/index.ts               |   2 -
 core/src/browser/tools/manager.ts             |  47 -----
 core/src/browser/tools/tool.test.ts           |  63 ------
 core/src/browser/tools/tool.ts                |  12 --
 extensions/assistant-extension/src/index.ts   |   9 +-
 .../assistant-extension/src/node/index.ts     |  45 ----
 .../assistant-extension/src/node/retrieval.ts | 121 -----------
 .../src/tools/retrieval.ts                    | 118 -----------
 web/hooks/useSendChatMessage.ts               | 196 +++++++++---------
 web/services/coreService.ts                   |   3 +-
 11 files changed, 104 insertions(+), 517 deletions(-)
 delete mode 100644 core/src/browser/tools/index.test.ts
 delete mode 100644 core/src/browser/tools/index.ts
 delete mode 100644 core/src/browser/tools/manager.ts
 delete mode 100644 core/src/browser/tools/tool.test.ts
 delete mode 100644 core/src/browser/tools/tool.ts
 delete mode 100644 extensions/assistant-extension/src/node/index.ts
 delete mode 100644 extensions/assistant-extension/src/node/retrieval.ts
 delete mode 100644 extensions/assistant-extension/src/tools/retrieval.ts

diff --git a/core/src/browser/tools/index.test.ts b/core/src/browser/tools/index.test.ts
deleted file mode 100644
index 8a24d3bb6..000000000
--- a/core/src/browser/tools/index.test.ts
+++ /dev/null
@@ -1,5 +0,0 @@
-
-
-it('should not throw any errors when imported', () => {
-  expect(() => require('./index')).not.toThrow();
-})
diff --git a/core/src/browser/tools/index.ts b/core/src/browser/tools/index.ts
deleted file mode 100644
index 24cd12780..000000000
--- a/core/src/browser/tools/index.ts
+++ /dev/null
@@ -1,2 +0,0 @@
-export * from './manager'
-export * from './tool'
diff --git a/core/src/browser/tools/manager.ts b/core/src/browser/tools/manager.ts
deleted file mode 100644
index b323ad7ce..000000000
--- a/core/src/browser/tools/manager.ts
+++ /dev/null
@@ -1,47 +0,0 @@
-import { AssistantTool, MessageRequest } from '../../types'
-import { InferenceTool } from './tool'
-
-/**
- * Manages the registration and retrieval of inference tools.
- */
-export class ToolManager {
-  public tools = new Map()
-
-  /**
-   * Registers a tool.
-   * @param tool - The tool to register.
-   */
-  register(tool: T) {
-    this.tools.set(tool.name, tool)
-  }
-
-  /**
-   * Retrieves a tool by it's name.
-   * @param name - The name of the tool to retrieve.
-   * @returns The tool, if found.
-   */
-  get(name: string): T | undefined {
-    return this.tools.get(name) as T | undefined
-  }
-
-  /*
-   ** Process the message request with the tools.
-   */
-  process(request: MessageRequest, tools: AssistantTool[]): Promise {
-    return tools.reduce((prevPromise, currentTool) => {
-      return prevPromise.then((prevResult) => {
-        return currentTool.enabled
-          ? this.get(currentTool.type)?.process(prevResult, currentTool) ??
-            Promise.resolve(prevResult)
-          : Promise.resolve(prevResult)
-      })
-    }, Promise.resolve(request))
-  }
-
-  /**
-   * The instance of the tool manager.
-   */
-  static instance(): ToolManager {
-    return (window.core?.toolManager as ToolManager) ?? new ToolManager()
-  }
-}
diff --git a/core/src/browser/tools/tool.test.ts b/core/src/browser/tools/tool.test.ts
deleted file mode 100644
index dcb478478..000000000
--- a/core/src/browser/tools/tool.test.ts
+++ /dev/null
@@ -1,63 +0,0 @@
-import { ToolManager } from '../../browser/tools/manager'
-import { InferenceTool } from '../../browser/tools/tool'
-import { AssistantTool, MessageRequest } from '../../types'
-
-class MockInferenceTool implements InferenceTool {
-  name = 'mockTool'
-  process(request: MessageRequest, tool: AssistantTool): Promise {
-    return Promise.resolve(request)
-  }
-}
-
-it('should register a tool', () => {
-  const manager = new ToolManager()
-  const tool = new MockInferenceTool()
-  manager.register(tool)
-  expect(manager.get(tool.name)).toBe(tool)
-})
-
-it('should retrieve a tool by its name', () => {
-  const manager = new ToolManager()
-  const tool = new MockInferenceTool()
-  manager.register(tool)
-  const retrievedTool = manager.get(tool.name)
-  expect(retrievedTool).toBe(tool)
-})
-
-it('should return undefined for a non-existent tool', () => {
-  const manager = new ToolManager()
-  const retrievedTool = manager.get('nonExistentTool')
-  expect(retrievedTool).toBeUndefined()
-})
-
-it('should process the message request with enabled tools', async () => {
-  const manager = new ToolManager()
-  const tool = new MockInferenceTool()
-  manager.register(tool)
-
-  const request: MessageRequest = { message: 'test' } as any
-  const tools: AssistantTool[] = [{ type: 'mockTool', enabled: true }] as any
-
-  const result = await manager.process(request, tools)
-  expect(result).toBe(request)
-})
-
-it('should skip processing for disabled tools', async () => {
-  const manager = new ToolManager()
-  const tool = new MockInferenceTool()
-  manager.register(tool)
-
-  const request: MessageRequest = { message: 'test' } as any
-  const tools: AssistantTool[] = [{ type: 'mockTool', enabled: false }] as any
-
-  const result = await manager.process(request, tools)
-  expect(result).toBe(request)
-})
-
-it('should throw an error when process is called without implementation', () => {
-  class TestTool extends InferenceTool {
-    name = 'testTool'
-  }
-  const tool = new TestTool()
-  expect(() => tool.process({} as MessageRequest)).toThrowError()
-})
diff --git a/core/src/browser/tools/tool.ts b/core/src/browser/tools/tool.ts
deleted file mode 100644
index 0fd342933..000000000
--- a/core/src/browser/tools/tool.ts
+++ /dev/null
@@ -1,12 +0,0 @@
-import { AssistantTool, MessageRequest } from '../../types'
-
-/**
- * Represents a base inference tool.
- */
-export abstract class InferenceTool {
-  abstract name: string
-  /*
-   ** Process a message request and return the processed message request.
-   */
-  abstract process(request: MessageRequest, tool?: AssistantTool): Promise
-}
diff --git a/extensions/assistant-extension/src/index.ts b/extensions/assistant-extension/src/index.ts
index 0f9b52808..bb253bd7f 100644
--- a/extensions/assistant-extension/src/index.ts
+++ b/extensions/assistant-extension/src/index.ts
@@ -1,12 +1,7 @@
-import { Assistant, AssistantExtension, ToolManager } from '@janhq/core'
-import { RetrievalTool } from './tools/retrieval'
+import { Assistant, AssistantExtension } from '@janhq/core'
 
 export default class JanAssistantExtension extends AssistantExtension {
-
-  async onLoad() {
-    // Register the retrieval tool
-    ToolManager.instance().register(new RetrievalTool())
-  }
+  async onLoad() {}
 
   /**
    * Called when the extension is unloaded.
diff --git a/extensions/assistant-extension/src/node/index.ts b/extensions/assistant-extension/src/node/index.ts deleted file mode 100644 index 731890b34..000000000 --- a/extensions/assistant-extension/src/node/index.ts +++ /dev/null @@ -1,45 +0,0 @@ -import { getJanDataFolderPath } from '@janhq/core/node' -import { retrieval } from './retrieval' -import path from 'path' - -export function toolRetrievalUpdateTextSplitter( - chunkSize: number, - chunkOverlap: number -) { - retrieval.updateTextSplitter(chunkSize, chunkOverlap) -} -export async function toolRetrievalIngestNewDocument( - thread: string, - file: string, - model: string, - engine: string, - useTimeWeighted: boolean -) { - const threadPath = path.join(getJanDataFolderPath(), 'threads', thread) - const filePath = path.join(getJanDataFolderPath(), 'files', file) - retrieval.updateEmbeddingEngine(model, engine) - return retrieval - .ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted) - .catch((err) => { - console.error(err) - }) -} - -export async function toolRetrievalLoadThreadMemory(threadId: string) { - return retrieval - .loadRetrievalAgent( - path.join(getJanDataFolderPath(), 'threads', threadId, 'memory') - ) - .catch((err) => { - console.error(err) - }) -} - -export async function toolRetrievalQueryResult( - query: string, - useTimeWeighted: boolean = false -) { - return retrieval.generateResult(query, useTimeWeighted).catch((err) => { - console.error(err) - }) -} diff --git a/extensions/assistant-extension/src/node/retrieval.ts b/extensions/assistant-extension/src/node/retrieval.ts deleted file mode 100644 index 0e80fd2d7..000000000 --- a/extensions/assistant-extension/src/node/retrieval.ts +++ /dev/null @@ -1,121 +0,0 @@ -import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter' -import { formatDocumentsAsString } from 'langchain/util/document' -import { PDFLoader } from 'langchain/document_loaders/fs/pdf' - -import { TimeWeightedVectorStoreRetriever } from 'langchain/retrievers/time_weighted' -import { MemoryVectorStore } from 'langchain/vectorstores/memory' - -import { HNSWLib } from 'langchain/vectorstores/hnswlib' - -import { OpenAIEmbeddings } from 'langchain/embeddings/openai' - -export class Retrieval { - public chunkSize: number = 100 - public chunkOverlap?: number = 0 - private retriever: any - - private embeddingModel?: OpenAIEmbeddings = undefined - private textSplitter?: RecursiveCharacterTextSplitter - - // to support time-weighted retrieval - private timeWeightedVectorStore: MemoryVectorStore - private timeWeightedretriever: any | TimeWeightedVectorStoreRetriever - - constructor(chunkSize: number = 4000, chunkOverlap: number = 200) { - this.updateTextSplitter(chunkSize, chunkOverlap) - this.initialize() - } - - private async initialize() { - const apiKey = await window.core?.api.appToken() - - // declare time-weighted retriever and storage - this.timeWeightedVectorStore = new MemoryVectorStore( - new OpenAIEmbeddings( - { openAIApiKey: apiKey }, - { basePath: `${CORTEX_API_URL}/v1` } - ) - ) - this.timeWeightedretriever = new TimeWeightedVectorStoreRetriever({ - vectorStore: this.timeWeightedVectorStore, - memoryStream: [], - searchKwargs: 2, - }) - } - - public updateTextSplitter(chunkSize: number, chunkOverlap: number): void { - this.chunkSize = chunkSize - this.chunkOverlap = chunkOverlap - this.textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize: chunkSize, - chunkOverlap: chunkOverlap, - }) - } - - public async updateEmbeddingEngine(model: string, engine: 
string) { - const apiKey = await window.core?.api.appToken() - this.embeddingModel = new OpenAIEmbeddings( - { openAIApiKey: apiKey, model }, - // TODO: Raw settings - { basePath: `${CORTEX_API_URL}/v1` } - ) - - // update time-weighted embedding model - this.timeWeightedVectorStore.embeddings = this.embeddingModel - } - - public ingestAgentKnowledge = async ( - filePath: string, - memoryPath: string, - useTimeWeighted: boolean - ): Promise => { - const loader = new PDFLoader(filePath, { - splitPages: true, - }) - if (!this.embeddingModel) return Promise.reject() - const doc = await loader.load() - const docs = await this.textSplitter!.splitDocuments(doc) - const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel) - - // add documents with metadata by using the time-weighted retriever in order to support time-weighted retrieval - if (useTimeWeighted && this.timeWeightedretriever) { - await ( - this.timeWeightedretriever as TimeWeightedVectorStoreRetriever - ).addDocuments(docs) - } - return vectorStore.save(memoryPath) - } - - public loadRetrievalAgent = async (memoryPath: string): Promise => { - if (!this.embeddingModel) return Promise.reject() - const vectorStore = await HNSWLib.load(memoryPath, this.embeddingModel) - this.retriever = vectorStore.asRetriever(2) - return Promise.resolve() - } - - public generateResult = async ( - query: string, - useTimeWeighted: boolean - ): Promise => { - if (useTimeWeighted) { - if (!this.timeWeightedretriever) { - return Promise.resolve(' ') - } - // use invoke because getRelevantDocuments is deprecated - const relevantDocs = await this.timeWeightedretriever.invoke(query) - const serializedDoc = formatDocumentsAsString(relevantDocs) - return Promise.resolve(serializedDoc) - } - - if (!this.retriever) { - return Promise.resolve(' ') - } - - // should use invoke(query) because getRelevantDocuments is deprecated - const relevantDocs = await this.retriever.getRelevantDocuments(query) - const serializedDoc = formatDocumentsAsString(relevantDocs) - return Promise.resolve(serializedDoc) - } -} - -export const retrieval = new Retrieval() diff --git a/extensions/assistant-extension/src/tools/retrieval.ts b/extensions/assistant-extension/src/tools/retrieval.ts deleted file mode 100644 index b1a0c3cba..000000000 --- a/extensions/assistant-extension/src/tools/retrieval.ts +++ /dev/null @@ -1,118 +0,0 @@ -import { - AssistantTool, - executeOnMain, - fs, - InferenceTool, - joinPath, - MessageRequest, -} from '@janhq/core' - -export class RetrievalTool extends InferenceTool { - private _threadDir = 'file://threads' - private retrievalThreadId: string | undefined = undefined - - name: string = 'retrieval' - - async process( - data: MessageRequest, - tool?: AssistantTool - ): Promise { - if (!data.model || !data.messages) { - return Promise.resolve(data) - } - - const latestMessage = data.messages[data.messages.length - 1] - - // 1. Ingest the document if needed - if ( - latestMessage && - latestMessage.content && - typeof latestMessage.content !== 'string' && - latestMessage.content.length > 1 - ) { - const docFile = latestMessage.content[1]?.doc_url?.url - if (docFile) { - await executeOnMain( - NODE, - 'toolRetrievalIngestNewDocument', - data.thread?.id, - docFile, - data.model?.id, - data.model?.engine, - tool?.useTimeWeightedRetriever ?? 
false - ) - } else { - return Promise.resolve(data) - } - } else if ( - // Check whether we need to ingest document or not - // Otherwise wrong context will be sent - !(await fs.existsSync( - await joinPath([this._threadDir, data.threadId, 'memory']) - )) - ) { - // No document ingested, reroute the result to inference engine - - return Promise.resolve(data) - } - // 2. Load agent on thread changed - if (this.retrievalThreadId !== data.threadId) { - await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId) - - this.retrievalThreadId = data.threadId - - // Update the text splitter - await executeOnMain( - NODE, - 'toolRetrievalUpdateTextSplitter', - tool?.settings?.chunk_size ?? 4000, - tool?.settings?.chunk_overlap ?? 200 - ) - } - - // 3. Using the retrieval template with the result and query - if (latestMessage.content) { - const prompt = - typeof latestMessage.content === 'string' - ? latestMessage.content - : latestMessage.content[0].text - // Retrieve the result - const retrievalResult = await executeOnMain( - NODE, - 'toolRetrievalQueryResult', - prompt, - tool?.useTimeWeightedRetriever ?? false - ) - console.debug('toolRetrievalQueryResult', retrievalResult) - - // Update message content - if (retrievalResult) - data.messages[data.messages.length - 1].content = - tool?.settings?.retrieval_template - ?.replace('{CONTEXT}', retrievalResult) - .replace('{QUESTION}', prompt) - } - - // 4. Reroute the result to inference engine - return Promise.resolve(this.normalize(data)) - } - - // Filter out all the messages that are not text - // TODO: Remove it until engines can handle multiple content types - normalize(request: MessageRequest): MessageRequest { - request.messages = request.messages?.map((message) => { - if ( - message.content && - typeof message.content !== 'string' && - (message.content.length ?? 0) > 0 - ) { - return { - ...message, - content: [message.content[0]], - } - } - return message - }) - return request - } -} diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index 1c6f77905..1c9334fb5 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -11,20 +11,19 @@ import { events, MessageEvent, ContentType, + EngineManager, + InferenceEngine, } from '@janhq/core' import { extractInferenceParams, extractModelLoadParams } from '@janhq/core' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { OpenAI } from 'openai' import { - ChatCompletionMessage, ChatCompletionMessageParam, ChatCompletionRole, ChatCompletionTool, } from 'openai/resources/chat' -import { Tool } from 'openai/resources/responses/responses' - import { ulid } from 'ulidx' import { modelDropdownStateAtom } from '@/containers/ModelDropdown' @@ -250,100 +249,41 @@ export default function useSendChatMessage() { } setIsGeneratingResponse(true) - let isDone = false - const openai = new OpenAI({ - apiKey: await window.core.api.appToken(), - baseURL: `${API_BASE_URL}/v1`, - dangerouslyAllowBrowser: true, - }) - while (!isDone) { - const data = requestBuilder.build() - const response = await openai.chat.completions.create({ - messages: (data.messages ?? []).map((e) => { - return { - role: e.role as ChatCompletionRole, - content: e.content, - } - }) as ChatCompletionMessageParam[], - model: data.model?.id ?? 
'', - tools: data.tools as ChatCompletionTool[], - stream: false, + if (requestBuilder.tools && requestBuilder.tools.length) { + let isDone = false + const openai = new OpenAI({ + apiKey: await window.core.api.appToken(), + baseURL: `${API_BASE_URL}/v1`, + dangerouslyAllowBrowser: true, }) - if (response.choices[0]?.message.content) { - const newMessage: ThreadMessage = { - id: ulid(), - object: 'message', - thread_id: activeThreadRef.current.id, - assistant_id: activeAssistantRef.current.assistant_id, - attachments: [], - role: response.choices[0].message.role as any, - content: [ - { - type: ContentType.Text, - text: { - value: response.choices[0].message.content - ? (response.choices[0].message.content as any) - : '', - annotations: [], - }, - }, - ], - status: 'ready' as any, - created_at: Date.now(), - completed_at: Date.now(), - } - requestBuilder.pushAssistantMessage( - (response.choices[0].message.content as any) ?? '' - ) - events.emit(MessageEvent.OnMessageUpdate, newMessage) - } - - if (response.choices[0]?.message.tool_calls) { - for (const toolCall of response.choices[0].message.tool_calls) { - const id = ulid() - const toolMessage: ThreadMessage = { - id: id, + while (!isDone) { + const data = requestBuilder.build() + const response = await openai.chat.completions.create({ + messages: (data.messages ?? []).map((e) => { + return { + role: e.role as ChatCompletionRole, + content: e.content, + } + }) as ChatCompletionMessageParam[], + model: data.model?.id ?? '', + tools: data.tools as ChatCompletionTool[], + stream: false, + }) + if (response.choices[0]?.message.content) { + const newMessage: ThreadMessage = { + id: ulid(), object: 'message', thread_id: activeThreadRef.current.id, assistant_id: activeAssistantRef.current.assistant_id, attachments: [], - role: 'assistant' as any, + role: response.choices[0].message.role as any, content: [ { type: ContentType.Text, text: { - value: `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`, - annotations: [], - }, - }, - ], - status: 'pending' as any, - created_at: Date.now(), - completed_at: Date.now(), - } - events.emit(MessageEvent.OnMessageUpdate, toolMessage) - const result = await window.core.api.callTool({ - toolName: toolCall.function.name, - arguments: JSON.parse(toolCall.function.arguments), - }) - if (result.error) { - console.error(result.error) - break - } - const message: ThreadMessage = { - id: id, - object: 'message', - thread_id: activeThreadRef.current.id, - assistant_id: activeAssistantRef.current.assistant_id, - attachments: [], - role: 'assistant' as any, - content: [ - { - type: ContentType.Text, - text: { - value: - `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}` + - (result.content[0]?.text ?? ''), + value: response.choices[0].message.content + ? (response.choices[0].message.content as any) + : '', annotations: [], }, }, @@ -352,15 +292,81 @@ export default function useSendChatMessage() { created_at: Date.now(), completed_at: Date.now(), } - requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '') - requestBuilder.pushMessage('Go for the next step') - events.emit(MessageEvent.OnMessageUpdate, message) + requestBuilder.pushAssistantMessage( + (response.choices[0].message.content as any) ?? 
'' + ) + events.emit(MessageEvent.OnMessageUpdate, newMessage) } - } - isDone = - !response.choices[0]?.message.tool_calls || - !response.choices[0]?.message.tool_calls.length + if (response.choices[0]?.message.tool_calls) { + for (const toolCall of response.choices[0].message.tool_calls) { + const id = ulid() + const toolMessage: ThreadMessage = { + id: id, + object: 'message', + thread_id: activeThreadRef.current.id, + assistant_id: activeAssistantRef.current.assistant_id, + attachments: [], + role: 'assistant' as any, + content: [ + { + type: ContentType.Text, + text: { + value: `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`, + annotations: [], + }, + }, + ], + status: 'pending' as any, + created_at: Date.now(), + completed_at: Date.now(), + } + events.emit(MessageEvent.OnMessageUpdate, toolMessage) + const result = await window.core.api.callTool({ + toolName: toolCall.function.name, + arguments: JSON.parse(toolCall.function.arguments), + }) + if (result.error) { + console.error(result.error) + break + } + const message: ThreadMessage = { + id: id, + object: 'message', + thread_id: activeThreadRef.current.id, + assistant_id: activeAssistantRef.current.assistant_id, + attachments: [], + role: 'assistant' as any, + content: [ + { + type: ContentType.Text, + text: { + value: + `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}` + + (result.content[0]?.text ?? ''), + annotations: [], + }, + }, + ], + status: 'ready' as any, + created_at: Date.now(), + completed_at: Date.now(), + } + requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '') + requestBuilder.pushMessage('Go for the next step') + events.emit(MessageEvent.OnMessageUpdate, message) + } + } + + isDone = + !response.choices[0]?.message.tool_calls || + !response.choices[0]?.message.tool_calls.length + } + } else { + // Request for inference + EngineManager.instance() + .get(InferenceEngine.cortex) + ?.inference(requestBuilder.build()) } // Reset states diff --git a/web/services/coreService.ts b/web/services/coreService.ts index cda83ceaf..4143facf3 100644 --- a/web/services/coreService.ts +++ b/web/services/coreService.ts @@ -1,4 +1,4 @@ -import { EngineManager, ToolManager } from '@janhq/core' +import { EngineManager } from '@janhq/core' import { appService } from './appService' import { EventEmitter } from './eventsService' @@ -16,7 +16,6 @@ export const setupCoreServices = () => { window.core = { events: new EventEmitter(), engineManager: new EngineManager(), - toolManager: new ToolManager(), api: { ...(window.electronAPI ?? (IS_TAURI ? tauriAPI : restAPI)), ...appService,