refactor: remove legacy tools

Louis 2025-04-02 10:26:02 +07:00
parent cc90c1e86e
commit 1027059a6b
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
11 changed files with 104 additions and 517 deletions

View File

@@ -1,5 +0,0 @@
it('should not throw any errors when imported', () => {
expect(() => require('./index')).not.toThrow();
})

View File

@@ -1,2 +0,0 @@
export * from './manager'
export * from './tool'

View File

@@ -1,47 +0,0 @@
import { AssistantTool, MessageRequest } from '../../types'
import { InferenceTool } from './tool'
/**
* Manages the registration and retrieval of inference tools.
*/
export class ToolManager {
public tools = new Map<string, InferenceTool>()
/**
* Registers a tool.
* @param tool - The tool to register.
*/
register<T extends InferenceTool>(tool: T) {
this.tools.set(tool.name, tool)
}
/**
* Retrieves a tool by its name.
* @param name - The name of the tool to retrieve.
* @returns The tool, if found.
*/
get<T extends InferenceTool>(name: string): T | undefined {
return this.tools.get(name) as T | undefined
}
/**
* Processes the message request by applying each enabled tool in sequence.
*/
process(request: MessageRequest, tools: AssistantTool[]): Promise<MessageRequest> {
return tools.reduce((prevPromise, currentTool) => {
return prevPromise.then((prevResult) => {
return currentTool.enabled
? this.get(currentTool.type)?.process(prevResult, currentTool) ??
Promise.resolve(prevResult)
: Promise.resolve(prevResult)
})
}, Promise.resolve(request))
}
/**
* Returns the shared instance of the tool manager.
*/
static instance(): ToolManager {
return (window.core?.toolManager as ToolManager) ?? new ToolManager()
}
}
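For reference, the removed manager applied tools one after another via the promise reduce in process() above. A minimal usage sketch, assuming a hypothetical EchoTool subclass (not part of the codebase):

import { ToolManager } from './manager'
import { InferenceTool } from './tool'
import { AssistantTool, MessageRequest } from '../../types'

// Hypothetical pass-through tool, for illustration only.
class EchoTool extends InferenceTool {
  name = 'echo'
  async process(request: MessageRequest): Promise<MessageRequest> {
    return request
  }
}

const manager = ToolManager.instance()
manager.register(new EchoTool())

// Each enabled AssistantTool whose `type` matches a registered tool's
// `name` is applied in order; disabled or unknown types fall through.
const request = { messages: [] } as unknown as MessageRequest
const processed = await manager.process(request, [
  { type: 'echo', enabled: true } as AssistantTool,
])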

View File

@@ -1,63 +0,0 @@
import { ToolManager } from '../../browser/tools/manager'
import { InferenceTool } from '../../browser/tools/tool'
import { AssistantTool, MessageRequest } from '../../types'
class MockInferenceTool implements InferenceTool {
name = 'mockTool'
process(request: MessageRequest, tool: AssistantTool): Promise<MessageRequest> {
return Promise.resolve(request)
}
}
it('should register a tool', () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
expect(manager.get(tool.name)).toBe(tool)
})
it('should retrieve a tool by its name', () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
const retrievedTool = manager.get(tool.name)
expect(retrievedTool).toBe(tool)
})
it('should return undefined for a non-existent tool', () => {
const manager = new ToolManager()
const retrievedTool = manager.get('nonExistentTool')
expect(retrievedTool).toBeUndefined()
})
it('should process the message request with enabled tools', async () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
const request: MessageRequest = { message: 'test' } as any
const tools: AssistantTool[] = [{ type: 'mockTool', enabled: true }] as any
const result = await manager.process(request, tools)
expect(result).toBe(request)
})
it('should skip processing for disabled tools', async () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
const request: MessageRequest = { message: 'test' } as any
const tools: AssistantTool[] = [{ type: 'mockTool', enabled: false }] as any
const result = await manager.process(request, tools)
expect(result).toBe(request)
})
it('should throw an error when process is called without implementation', () => {
class TestTool extends InferenceTool {
name = 'testTool'
}
const tool = new TestTool()
expect(() => tool.process({} as MessageRequest)).toThrowError()
})

View File

@@ -1,12 +0,0 @@
import { AssistantTool, MessageRequest } from '../../types'
/**
* Represents a base inference tool.
*/
export abstract class InferenceTool {
abstract name: string
/**
* Processes a message request and returns the processed message request.
*/
abstract process(request: MessageRequest, tool?: AssistantTool): Promise<MessageRequest>
}

View File

@@ -1,12 +1,7 @@
-import { Assistant, AssistantExtension, ToolManager } from '@janhq/core'
-import { RetrievalTool } from './tools/retrieval'
+import { Assistant, AssistantExtension } from '@janhq/core'
 
 export default class JanAssistantExtension extends AssistantExtension {
-  async onLoad() {
-    // Register the retrieval tool
-    ToolManager.instance().register(new RetrievalTool())
-  }
+  async onLoad() {}
 
   /**
    * Called when the extension is unloaded.
View File

@@ -1,45 +0,0 @@
import { getJanDataFolderPath } from '@janhq/core/node'
import { retrieval } from './retrieval'
import path from 'path'
export function toolRetrievalUpdateTextSplitter(
chunkSize: number,
chunkOverlap: number
) {
retrieval.updateTextSplitter(chunkSize, chunkOverlap)
}
export async function toolRetrievalIngestNewDocument(
thread: string,
file: string,
model: string,
engine: string,
useTimeWeighted: boolean
) {
const threadPath = path.join(getJanDataFolderPath(), 'threads', thread)
const filePath = path.join(getJanDataFolderPath(), 'files', file)
retrieval.updateEmbeddingEngine(model, engine)
return retrieval
.ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted)
.catch((err) => {
console.error(err)
})
}
export async function toolRetrievalLoadThreadMemory(threadId: string) {
return retrieval
.loadRetrievalAgent(
path.join(getJanDataFolderPath(), 'threads', threadId, 'memory')
)
.catch((err) => {
console.error(err)
})
}
export async function toolRetrievalQueryResult(
query: string,
useTimeWeighted: boolean = false
) {
return retrieval.generateResult(query, useTimeWeighted).catch((err) => {
console.error(err)
})
}
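These helpers ran in the Node process; the renderer-side RetrievalTool (later in this commit) reached them through executeOnMain. A sketch of that call pattern, assuming threadId, fileName, modelId, engineId, and query are in scope and NODE is the extension's node entry module (a build-time constant in this codebase):

import { executeOnMain } from '@janhq/core'

// Configure chunking, ingest a document, then query it.
await executeOnMain(NODE, 'toolRetrievalUpdateTextSplitter', 4000, 200)
await executeOnMain(
  NODE,
  'toolRetrievalIngestNewDocument',
  threadId, // folder under <jan-data>/threads
  fileName, // file under <jan-data>/files
  modelId,
  engineId,
  false // useTimeWeighted
)
const context = await executeOnMain(
  NODE,
  'toolRetrievalQueryResult',
  query,
  false
)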

View File

@@ -1,121 +0,0 @@
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'
import { formatDocumentsAsString } from 'langchain/util/document'
import { PDFLoader } from 'langchain/document_loaders/fs/pdf'
import { TimeWeightedVectorStoreRetriever } from 'langchain/retrievers/time_weighted'
import { MemoryVectorStore } from 'langchain/vectorstores/memory'
import { HNSWLib } from 'langchain/vectorstores/hnswlib'
import { OpenAIEmbeddings } from 'langchain/embeddings/openai'
export class Retrieval {
public chunkSize: number = 100
public chunkOverlap?: number = 0
private retriever: any
private embeddingModel?: OpenAIEmbeddings = undefined
private textSplitter?: RecursiveCharacterTextSplitter
// to support time-weighted retrieval
private timeWeightedVectorStore: MemoryVectorStore
private timeWeightedretriever: any | TimeWeightedVectorStoreRetriever
constructor(chunkSize: number = 4000, chunkOverlap: number = 200) {
this.updateTextSplitter(chunkSize, chunkOverlap)
this.initialize()
}
private async initialize() {
const apiKey = await window.core?.api.appToken()
// declare time-weighted retriever and storage
this.timeWeightedVectorStore = new MemoryVectorStore(
new OpenAIEmbeddings(
{ openAIApiKey: apiKey },
{ basePath: `${CORTEX_API_URL}/v1` }
)
)
this.timeWeightedretriever = new TimeWeightedVectorStoreRetriever({
vectorStore: this.timeWeightedVectorStore,
memoryStream: [],
searchKwargs: 2,
})
}
public updateTextSplitter(chunkSize: number, chunkOverlap: number): void {
this.chunkSize = chunkSize
this.chunkOverlap = chunkOverlap
this.textSplitter = new RecursiveCharacterTextSplitter({
chunkSize: chunkSize,
chunkOverlap: chunkOverlap,
})
}
public async updateEmbeddingEngine(model: string, engine: string) {
const apiKey = await window.core?.api.appToken()
this.embeddingModel = new OpenAIEmbeddings(
{ openAIApiKey: apiKey, model },
// TODO: Raw settings
{ basePath: `${CORTEX_API_URL}/v1` }
)
// update time-weighted embedding model
this.timeWeightedVectorStore.embeddings = this.embeddingModel
}
public ingestAgentKnowledge = async (
filePath: string,
memoryPath: string,
useTimeWeighted: boolean
): Promise<any> => {
const loader = new PDFLoader(filePath, {
splitPages: true,
})
if (!this.embeddingModel) return Promise.reject()
const doc = await loader.load()
const docs = await this.textSplitter!.splitDocuments(doc)
const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel)
// Add documents through the time-weighted retriever so their metadata supports time-weighted retrieval
if (useTimeWeighted && this.timeWeightedretriever) {
await (
this.timeWeightedretriever as TimeWeightedVectorStoreRetriever
).addDocuments(docs)
}
return vectorStore.save(memoryPath)
}
public loadRetrievalAgent = async (memoryPath: string): Promise<void> => {
if (!this.embeddingModel) return Promise.reject()
const vectorStore = await HNSWLib.load(memoryPath, this.embeddingModel)
this.retriever = vectorStore.asRetriever(2)
return Promise.resolve()
}
public generateResult = async (
query: string,
useTimeWeighted: boolean
): Promise<string> => {
if (useTimeWeighted) {
if (!this.timeWeightedretriever) {
return Promise.resolve(' ')
}
// use invoke because getRelevantDocuments is deprecated
const relevantDocs = await this.timeWeightedretriever.invoke(query)
const serializedDoc = formatDocumentsAsString(relevantDocs)
return Promise.resolve(serializedDoc)
}
if (!this.retriever) {
return Promise.resolve(' ')
}
// TODO: use invoke(query); getRelevantDocuments is deprecated
const relevantDocs = await this.retriever.getRelevantDocuments(query)
const serializedDoc = formatDocumentsAsString(relevantDocs)
return Promise.resolve(serializedDoc)
}
}
export const retrieval = new Retrieval()
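Pieced together from its callers, the singleton's lifecycle was: configure the splitter, point the embeddings at the local server, ingest a PDF, load the saved index, then query. A minimal sketch, assuming a reachable embedding endpoint; the model id, engine id, and paths below are illustrative:

import { retrieval } from './retrieval'

retrieval.updateTextSplitter(4000, 200)
await retrieval.updateEmbeddingEngine('embedding-model-id', 'engine-id')
await retrieval.ingestAgentKnowledge('/tmp/doc.pdf', '/tmp/memory', false)
await retrieval.loadRetrievalAgent('/tmp/memory')
const context = await retrieval.generateResult('What is the document about?', false)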

View File

@@ -1,118 +0,0 @@
import {
AssistantTool,
executeOnMain,
fs,
InferenceTool,
joinPath,
MessageRequest,
} from '@janhq/core'
export class RetrievalTool extends InferenceTool {
private _threadDir = 'file://threads'
private retrievalThreadId: string | undefined = undefined
name: string = 'retrieval'
async process(
data: MessageRequest,
tool?: AssistantTool
): Promise<MessageRequest> {
if (!data.model || !data.messages) {
return Promise.resolve(data)
}
const latestMessage = data.messages[data.messages.length - 1]
// 1. Ingest the document if needed
if (
latestMessage &&
latestMessage.content &&
typeof latestMessage.content !== 'string' &&
latestMessage.content.length > 1
) {
const docFile = latestMessage.content[1]?.doc_url?.url
if (docFile) {
await executeOnMain(
NODE,
'toolRetrievalIngestNewDocument',
data.thread?.id,
docFile,
data.model?.id,
data.model?.engine,
tool?.useTimeWeightedRetriever ?? false
)
} else {
return Promise.resolve(data)
}
} else if (
// Check whether a document has already been ingested for this thread;
// otherwise the wrong context would be sent
!(await fs.existsSync(
await joinPath([this._threadDir, data.threadId, 'memory'])
))
) {
// No document ingested, reroute the result to inference engine
return Promise.resolve(data)
}
// 2. Load agent on thread changed
if (this.retrievalThreadId !== data.threadId) {
await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
this.retrievalThreadId = data.threadId
// Update the text splitter
await executeOnMain(
NODE,
'toolRetrievalUpdateTextSplitter',
tool?.settings?.chunk_size ?? 4000,
tool?.settings?.chunk_overlap ?? 200
)
}
// 3. Using the retrieval template with the result and query
if (latestMessage.content) {
const prompt =
typeof latestMessage.content === 'string'
? latestMessage.content
: latestMessage.content[0].text
// Retrieve the result
const retrievalResult = await executeOnMain(
NODE,
'toolRetrievalQueryResult',
prompt,
tool?.useTimeWeightedRetriever ?? false
)
console.debug('toolRetrievalQueryResult', retrievalResult)
// Update message content
if (retrievalResult)
data.messages[data.messages.length - 1].content =
tool?.settings?.retrieval_template
?.replace('{CONTEXT}', retrievalResult)
.replace('{QUESTION}', prompt)
}
// 4. Reroute the result to inference engine
return Promise.resolve(this.normalize(data))
}
// Filter out all the messages that are not text
// TODO: Remove this once engines can handle multiple content types
normalize(request: MessageRequest): MessageRequest {
request.messages = request.messages?.map((message) => {
if (
message.content &&
typeof message.content !== 'string' &&
(message.content.length ?? 0) > 0
) {
return {
...message,
content: [message.content[0]],
}
}
return message
})
return request
}
}
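Step 3 above rewrites the last user message through the assistant's retrieval_template setting. A sketch of that substitution with a hypothetical template (the real value comes from tool?.settings?.retrieval_template):

// Hypothetical template and inputs, for illustration only.
const template =
  'Answer using this context.\nCONTEXT: {CONTEXT}\nQUESTION: {QUESTION}'
const retrievalResult = '...retrieved chunks...' // illustrative
const prompt = 'What is covered in chapter 2?' // illustrative
const finalPrompt = template
  .replace('{CONTEXT}', retrievalResult)
  .replace('{QUESTION}', prompt)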

View File

@@ -11,20 +11,19 @@ import {
   events,
   MessageEvent,
   ContentType,
+  EngineManager,
+  InferenceEngine,
 } from '@janhq/core'
 import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
 import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
 import { OpenAI } from 'openai'
 import {
-  ChatCompletionMessage,
   ChatCompletionMessageParam,
   ChatCompletionRole,
   ChatCompletionTool,
 } from 'openai/resources/chat'
-import { Tool } from 'openai/resources/responses/responses'
 import { ulid } from 'ulidx'
 
 import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
@@ -250,100 +249,41 @@ export default function useSendChatMessage() {
     }
     setIsGeneratingResponse(true)
-    let isDone = false
-    const openai = new OpenAI({
-      apiKey: await window.core.api.appToken(),
-      baseURL: `${API_BASE_URL}/v1`,
-      dangerouslyAllowBrowser: true,
-    })
-    while (!isDone) {
-      const data = requestBuilder.build()
-      const response = await openai.chat.completions.create({
-        messages: (data.messages ?? []).map((e) => {
-          return {
-            role: e.role as ChatCompletionRole,
-            content: e.content,
-          }
-        }) as ChatCompletionMessageParam[],
-        model: data.model?.id ?? '',
-        tools: data.tools as ChatCompletionTool[],
-        stream: false,
-      })
-      if (response.choices[0]?.message.content) {
-        const newMessage: ThreadMessage = {
-          id: ulid(),
-          object: 'message',
-          thread_id: activeThreadRef.current.id,
-          assistant_id: activeAssistantRef.current.assistant_id,
-          attachments: [],
-          role: response.choices[0].message.role as any,
-          content: [
-            {
-              type: ContentType.Text,
-              text: {
-                value: response.choices[0].message.content
-                  ? (response.choices[0].message.content as any)
-                  : '',
-                annotations: [],
-              },
-            },
-          ],
-          status: 'ready' as any,
-          created_at: Date.now(),
-          completed_at: Date.now(),
-        }
-        requestBuilder.pushAssistantMessage(
-          (response.choices[0].message.content as any) ?? ''
-        )
-        events.emit(MessageEvent.OnMessageUpdate, newMessage)
-      }
-      if (response.choices[0]?.message.tool_calls) {
-        for (const toolCall of response.choices[0].message.tool_calls) {
-          const id = ulid()
-          const toolMessage: ThreadMessage = {
-            id: id,
-            object: 'message',
-            thread_id: activeThreadRef.current.id,
-            assistant_id: activeAssistantRef.current.assistant_id,
-            attachments: [],
-            role: 'assistant' as any,
-            content: [
-              {
-                type: ContentType.Text,
-                text: {
-                  value: `<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`,
-                  annotations: [],
-                },
-              },
-            ],
-            status: 'pending' as any,
-            created_at: Date.now(),
-            completed_at: Date.now(),
-          }
-          events.emit(MessageEvent.OnMessageUpdate, toolMessage)
-          const result = await window.core.api.callTool({
-            toolName: toolCall.function.name,
-            arguments: JSON.parse(toolCall.function.arguments),
-          })
-          if (result.error) {
-            console.error(result.error)
-            break
-          }
-          const message: ThreadMessage = {
-            id: id,
-            object: 'message',
-            thread_id: activeThreadRef.current.id,
-            assistant_id: activeAssistantRef.current.assistant_id,
-            attachments: [],
-            role: 'assistant' as any,
-            content: [
-              {
-                type: ContentType.Text,
-                text: {
-                  value:
-                    `<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}</think>` +
-                    (result.content[0]?.text ?? ''),
+    if (requestBuilder.tools && requestBuilder.tools.length) {
+      let isDone = false
+      const openai = new OpenAI({
+        apiKey: await window.core.api.appToken(),
+        baseURL: `${API_BASE_URL}/v1`,
+        dangerouslyAllowBrowser: true,
+      })
+      while (!isDone) {
+        const data = requestBuilder.build()
+        const response = await openai.chat.completions.create({
+          messages: (data.messages ?? []).map((e) => {
+            return {
+              role: e.role as ChatCompletionRole,
+              content: e.content,
+            }
+          }) as ChatCompletionMessageParam[],
+          model: data.model?.id ?? '',
+          tools: data.tools as ChatCompletionTool[],
+          stream: false,
+        })
+        if (response.choices[0]?.message.content) {
+          const newMessage: ThreadMessage = {
+            id: ulid(),
+            object: 'message',
+            thread_id: activeThreadRef.current.id,
+            assistant_id: activeAssistantRef.current.assistant_id,
+            attachments: [],
+            role: response.choices[0].message.role as any,
+            content: [
+              {
+                type: ContentType.Text,
+                text: {
+                  value: response.choices[0].message.content
+                    ? (response.choices[0].message.content as any)
+                    : '',
                   annotations: [],
                 },
               },
@@ -352,15 +292,81 @@ export default function useSendChatMessage() {
             created_at: Date.now(),
             completed_at: Date.now(),
           }
-          requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '')
-          requestBuilder.pushMessage('Go for the next step')
-          events.emit(MessageEvent.OnMessageUpdate, message)
-        }
-      }
-      isDone =
-        !response.choices[0]?.message.tool_calls ||
-        !response.choices[0]?.message.tool_calls.length
+          requestBuilder.pushAssistantMessage(
+            (response.choices[0].message.content as any) ?? ''
+          )
+          events.emit(MessageEvent.OnMessageUpdate, newMessage)
+        }
+        if (response.choices[0]?.message.tool_calls) {
+          for (const toolCall of response.choices[0].message.tool_calls) {
+            const id = ulid()
+            const toolMessage: ThreadMessage = {
+              id: id,
+              object: 'message',
+              thread_id: activeThreadRef.current.id,
+              assistant_id: activeAssistantRef.current.assistant_id,
+              attachments: [],
+              role: 'assistant' as any,
+              content: [
+                {
+                  type: ContentType.Text,
+                  text: {
+                    value: `<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`,
+                    annotations: [],
+                  },
+                },
+              ],
+              status: 'pending' as any,
+              created_at: Date.now(),
+              completed_at: Date.now(),
+            }
+            events.emit(MessageEvent.OnMessageUpdate, toolMessage)
+            const result = await window.core.api.callTool({
+              toolName: toolCall.function.name,
+              arguments: JSON.parse(toolCall.function.arguments),
+            })
+            if (result.error) {
+              console.error(result.error)
+              break
+            }
+            const message: ThreadMessage = {
+              id: id,
+              object: 'message',
+              thread_id: activeThreadRef.current.id,
+              assistant_id: activeAssistantRef.current.assistant_id,
+              attachments: [],
+              role: 'assistant' as any,
+              content: [
+                {
+                  type: ContentType.Text,
+                  text: {
+                    value:
+                      `<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}</think>` +
+                      (result.content[0]?.text ?? ''),
+                    annotations: [],
+                  },
+                },
+              ],
+              status: 'ready' as any,
+              created_at: Date.now(),
+              completed_at: Date.now(),
+            }
+            requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '')
+            requestBuilder.pushMessage('Go for the next step')
+            events.emit(MessageEvent.OnMessageUpdate, message)
+          }
+        }
+        isDone =
+          !response.choices[0]?.message.tool_calls ||
+          !response.choices[0]?.message.tool_calls.length
+      }
+    } else {
+      // Request for inference
+      EngineManager.instance()
+        .get(InferenceEngine.cortex)
+        ?.inference(requestBuilder.build())
    }
 
     // Reset states
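Taken together, the new branch is a small agentic loop: call the chat endpoint with tools, execute any returned tool calls through window.core.api.callTool, feed the results back, and stop once a response carries no tool calls; requests without tools fall through to the Cortex engine. A condensed sketch of that control flow, with message emission and error handling omitted:

// Condensed sketch of the loop above, not the literal implementation.
let done = false
while (!done) {
  const response = await openai.chat.completions.create({
    /* messages, model, and tools from requestBuilder.build() */
  } as any)
  const msg = response.choices[0]?.message
  if (msg?.content) requestBuilder.pushAssistantMessage(msg.content as any)
  for (const call of msg?.tool_calls ?? []) {
    const result = await window.core.api.callTool({
      toolName: call.function.name,
      arguments: JSON.parse(call.function.arguments),
    })
    if (result.error) break
    requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '')
    requestBuilder.pushMessage('Go for the next step')
  }
  done = !msg?.tool_calls?.length
}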

View File

@@ -1,4 +1,4 @@
-import { EngineManager, ToolManager } from '@janhq/core'
+import { EngineManager } from '@janhq/core'
 import { appService } from './appService'
 import { EventEmitter } from './eventsService'
@@ -16,7 +16,6 @@ export const setupCoreServices = () => {
   window.core = {
     events: new EventEmitter(),
     engineManager: new EngineManager(),
-    toolManager: new ToolManager(),
     api: {
       ...(window.electronAPI ?? (IS_TAURI ? tauriAPI : restAPI)),
       ...appService,