refactor: remove legacy tools

This commit is contained in:
Louis 2025-04-02 10:26:02 +07:00
parent cc90c1e86e
commit 1027059a6b
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
11 changed files with 104 additions and 517 deletions

View File

@ -1,5 +0,0 @@
// Smoke test: importing the package entry point must not throw.
it('should not throw any errors when imported', () => {
  const load = () => require('./index')
  expect(load).not.toThrow()
})

View File

@ -1,2 +0,0 @@
// Barrel module: re-export the tool manager and the base tool class.
export * from './manager'
export * from './tool'

View File

@ -1,47 +0,0 @@
import { AssistantTool, MessageRequest } from '../../types'
import { InferenceTool } from './tool'
/**
 * Manages the registration and retrieval of inference tools.
 */
export class ToolManager {
  public tools = new Map<string, InferenceTool>()

  /**
   * Registers a tool under the name it exposes.
   * @param tool - The tool to register.
   */
  register<T extends InferenceTool>(tool: T) {
    this.tools.set(tool.name, tool)
  }

  /**
   * Retrieves a tool by its name.
   * @param name - The name of the tool to retrieve.
   * @returns The tool, if found.
   */
  get<T extends InferenceTool>(name: string): T | undefined {
    return this.tools.get(name) as T | undefined
  }

  /**
   * Runs every enabled tool over the request in order, feeding each
   * tool's output into the next; disabled or unregistered tools are
   * skipped and leave the request untouched.
   * @param request - The message request to process.
   * @param tools - The assistant tool configurations to apply.
   * @returns The request after all enabled tools have processed it.
   */
  async process(request: MessageRequest, tools: AssistantTool[]): Promise<MessageRequest> {
    let result = request
    for (const assistantTool of tools) {
      if (!assistantTool.enabled) continue
      const inferenceTool = this.get(assistantTool.type)
      if (inferenceTool) result = await inferenceTool.process(result, assistantTool)
    }
    return result
  }

  /**
   * The instance of the tool manager, falling back to a fresh instance
   * when none is registered on window.core.
   */
  static instance(): ToolManager {
    return (window.core?.toolManager as ToolManager) ?? new ToolManager()
  }
}

View File

@ -1,63 +0,0 @@
import { ToolManager } from '../../browser/tools/manager'
import { InferenceTool } from '../../browser/tools/tool'
import { AssistantTool, MessageRequest } from '../../types'
// Test double: an InferenceTool whose process() resolves with the
// request unchanged.
class MockInferenceTool implements InferenceTool {
  name = 'mockTool'
  process(request: MessageRequest, tool: AssistantTool): Promise<MessageRequest> {
    return Promise.resolve(request)
  }
}
// A registered tool is retrievable under the name it exposes.
it('should register a tool', () => {
  const manager = new ToolManager()
  const mock = new MockInferenceTool()
  manager.register(mock)
  expect(manager.get(mock.name)).toBe(mock)
})

// get() returns the exact instance that was registered.
it('should retrieve a tool by its name', () => {
  const manager = new ToolManager()
  const mock = new MockInferenceTool()
  manager.register(mock)
  const found = manager.get(mock.name)
  expect(found).toBe(mock)
})

// Looking up an unknown name yields undefined rather than throwing.
it('should return undefined for a non-existent tool', () => {
  const manager = new ToolManager()
  expect(manager.get('nonExistentTool')).toBeUndefined()
})
// An enabled tool runs over the request; the mock echoes it back.
it('should process the message request with enabled tools', async () => {
  const manager = new ToolManager()
  manager.register(new MockInferenceTool())
  const request = { message: 'test' } as any as MessageRequest
  const enabled = [{ type: 'mockTool', enabled: true }] as any as AssistantTool[]
  expect(await manager.process(request, enabled)).toBe(request)
})

// A disabled tool is skipped and the request passes through untouched.
it('should skip processing for disabled tools', async () => {
  const manager = new ToolManager()
  manager.register(new MockInferenceTool())
  const request = { message: 'test' } as any as MessageRequest
  const disabled = [{ type: 'mockTool', enabled: false }] as any as AssistantTool[]
  expect(await manager.process(request, disabled)).toBe(request)
})
// Verifies the runtime failure mode when process() is never implemented.
// A plain `class extends InferenceTool` omitting the abstract process()
// member is a compile-time error (TS2515), so the base is widened to
// `any` to construct the deliberately broken subclass.
it('should throw an error when process is called without implementation', () => {
  const BrokenTool = class extends (InferenceTool as any) {
    name = 'testTool'
  }
  const tool = new BrokenTool()
  // process is absent on the prototype, so invoking it throws a TypeError.
  expect(() => tool.process({} as MessageRequest)).toThrowError()
})

View File

@ -1,12 +0,0 @@
import { AssistantTool, MessageRequest } from '../../types'
/**
 * Represents a base inference tool.
 *
 * Subclasses provide a unique `name` and a `process` implementation that
 * the ToolManager invokes while pre-processing a message request.
 */
export abstract class InferenceTool {
  // Unique identifier; used as the registration key in ToolManager.
  abstract name: string

  /**
   * Process a message request and return the processed message request.
   * @param request - The request to transform.
   * @param tool - Optional per-assistant tool configuration.
   */
  abstract process(request: MessageRequest, tool?: AssistantTool): Promise<MessageRequest>
}

View File

@ -1,12 +1,7 @@
import { Assistant, AssistantExtension, ToolManager } from '@janhq/core'
import { RetrievalTool } from './tools/retrieval'
import { Assistant, AssistantExtension } from '@janhq/core'
export default class JanAssistantExtension extends AssistantExtension {
async onLoad() {
// Register the retrieval tool
ToolManager.instance().register(new RetrievalTool())
}
async onLoad() {}
/**
* Called when the extension is unloaded.

View File

@ -1,45 +0,0 @@
import { getJanDataFolderPath } from '@janhq/core/node'
import { retrieval } from './retrieval'
import path from 'path'
/**
 * Reconfigures the shared retrieval text splitter.
 * @param chunkSize - New splitter chunk size.
 * @param chunkOverlap - New overlap between adjacent chunks.
 */
export function toolRetrievalUpdateTextSplitter(
  chunkSize: number,
  chunkOverlap: number
) {
  retrieval.updateTextSplitter(chunkSize, chunkOverlap)
}
/**
 * Embeds and ingests a document attached to a thread into that thread's
 * retrieval memory. Errors are logged and swallowed (best-effort).
 * @param thread - Id of the thread owning the document.
 * @param file - File name under the Jan data folder's `files` directory.
 * @param model - Embedding model id.
 * @param engine - Engine serving the embedding model.
 * @param useTimeWeighted - Also index into the time-weighted store.
 */
export async function toolRetrievalIngestNewDocument(
  thread: string,
  file: string,
  model: string,
  engine: string,
  useTimeWeighted: boolean
) {
  const threadPath = path.join(getJanDataFolderPath(), 'threads', thread)
  const filePath = path.join(getJanDataFolderPath(), 'files', file)
  // Await the engine switch: updateEmbeddingEngine is async, and
  // ingestAgentKnowledge rejects when no embedding model is set yet —
  // without the await the rejection is silently swallowed below.
  await retrieval.updateEmbeddingEngine(model, engine)
  // path.join (not string interpolation) keeps the memory path
  // separator-correct across platforms, matching the joins above.
  const memoryPath = path.join(threadPath, 'memory')
  return retrieval
    .ingestAgentKnowledge(filePath, memoryPath, useTimeWeighted)
    .catch((err) => {
      console.error(err)
    })
}
/**
 * Loads the persisted retrieval memory for the given thread.
 * Errors are logged and swallowed so a missing memory never crashes.
 * @param threadId - Identifier of the thread whose memory to load.
 */
export async function toolRetrievalLoadThreadMemory(threadId: string) {
  const memoryPath = path.join(
    getJanDataFolderPath(),
    'threads',
    threadId,
    'memory'
  )
  try {
    return await retrieval.loadRetrievalAgent(memoryPath)
  } catch (err) {
    console.error(err)
  }
}
/**
 * Queries the retrieval store for context relevant to the query.
 * Errors are logged and swallowed, resolving undefined on failure.
 * @param query - The prompt to retrieve context for.
 * @param useTimeWeighted - Use the time-weighted retriever instead.
 */
export async function toolRetrievalQueryResult(
  query: string,
  useTimeWeighted: boolean = false
) {
  try {
    return await retrieval.generateResult(query, useTimeWeighted)
  } catch (err) {
    console.error(err)
  }
}

View File

@ -1,121 +0,0 @@
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'
import { formatDocumentsAsString } from 'langchain/util/document'
import { PDFLoader } from 'langchain/document_loaders/fs/pdf'
import { TimeWeightedVectorStoreRetriever } from 'langchain/retrievers/time_weighted'
import { MemoryVectorStore } from 'langchain/vectorstores/memory'
import { HNSWLib } from 'langchain/vectorstores/hnswlib'
import { OpenAIEmbeddings } from 'langchain/embeddings/openai'
/**
 * Document retrieval built on LangChain: plain similarity retrieval
 * persisted via HNSWLib, plus an optional time-weighted path kept in an
 * in-memory vector store. Embeddings are served by the local Cortex API.
 */
export class Retrieval {
  public chunkSize: number = 100
  public chunkOverlap?: number = 0
  private retriever: any
  private embeddingModel?: OpenAIEmbeddings = undefined
  private textSplitter?: RecursiveCharacterTextSplitter
  // to support time-weighted retrieval
  private timeWeightedVectorStore: MemoryVectorStore
  private timeWeightedretriever: any | TimeWeightedVectorStoreRetriever

  constructor(chunkSize: number = 4000, chunkOverlap: number = 200) {
    this.updateTextSplitter(chunkSize, chunkOverlap)
    // NOTE(review): initialize() is async and deliberately not awaited;
    // callers hitting the time-weighted path immediately after
    // construction may race it — confirm acceptable.
    this.initialize()
  }

  /** Sets up the time-weighted vector store and its retriever. */
  private async initialize() {
    const apiKey = await window.core?.api.appToken()
    // declare time-weighted retriever and storage
    this.timeWeightedVectorStore = new MemoryVectorStore(
      new OpenAIEmbeddings(
        { openAIApiKey: apiKey },
        { basePath: `${CORTEX_API_URL}/v1` }
      )
    )
    this.timeWeightedretriever = new TimeWeightedVectorStoreRetriever({
      vectorStore: this.timeWeightedVectorStore,
      memoryStream: [],
      searchKwargs: 2,
    })
  }

  /**
   * Replaces the text splitter with new chunking parameters.
   * @param chunkSize - Splitter chunk size.
   * @param chunkOverlap - Overlap between adjacent chunks.
   */
  public updateTextSplitter(chunkSize: number, chunkOverlap: number): void {
    this.chunkSize = chunkSize
    this.chunkOverlap = chunkOverlap
    this.textSplitter = new RecursiveCharacterTextSplitter({
      chunkSize: chunkSize,
      chunkOverlap: chunkOverlap,
    })
  }

  /**
   * Switches the embedding model (including the time-weighted store's).
   * @param model - Embedding model id.
   * @param engine - Engine name; NOTE(review): currently unused here —
   *   confirm whether it should influence the base path or model config.
   */
  public async updateEmbeddingEngine(model: string, engine: string) {
    const apiKey = await window.core?.api.appToken()
    this.embeddingModel = new OpenAIEmbeddings(
      { openAIApiKey: apiKey, model },
      // TODO: Raw settings
      { basePath: `${CORTEX_API_URL}/v1` }
    )
    // update time-weighted embedding model
    this.timeWeightedVectorStore.embeddings = this.embeddingModel
  }

  /**
   * Loads a PDF, splits it, embeds it, and persists the vector store to
   * memoryPath. Rejects when no embedding model has been configured.
   */
  public ingestAgentKnowledge = async (
    filePath: string,
    memoryPath: string,
    useTimeWeighted: boolean
  ): Promise<any> => {
    const loader = new PDFLoader(filePath, {
      splitPages: true,
    })
    if (!this.embeddingModel) return Promise.reject()
    const doc = await loader.load()
    const docs = await this.textSplitter!.splitDocuments(doc)
    const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel)
    // add documents with metadata by using the time-weighted retriever in order to support time-weighted retrieval
    if (useTimeWeighted && this.timeWeightedretriever) {
      await (
        this.timeWeightedretriever as TimeWeightedVectorStoreRetriever
      ).addDocuments(docs)
    }
    return vectorStore.save(memoryPath)
  }

  /** Loads a persisted vector store and builds a top-2 retriever. */
  public loadRetrievalAgent = async (memoryPath: string): Promise<void> => {
    if (!this.embeddingModel) return Promise.reject()
    const vectorStore = await HNSWLib.load(memoryPath, this.embeddingModel)
    this.retriever = vectorStore.asRetriever(2)
    return Promise.resolve()
  }

  /**
   * Returns the retrieved documents for the query serialized to a string,
   * or a blank string when the requested retriever is not ready.
   */
  public generateResult = async (
    query: string,
    useTimeWeighted: boolean
  ): Promise<string> => {
    if (useTimeWeighted) {
      if (!this.timeWeightedretriever) {
        return Promise.resolve(' ')
      }
      // use invoke because getRelevantDocuments is deprecated
      const relevantDocs = await this.timeWeightedretriever.invoke(query)
      const serializedDoc = formatDocumentsAsString(relevantDocs)
      return Promise.resolve(serializedDoc)
    }
    if (!this.retriever) {
      return Promise.resolve(' ')
    }
    // use invoke because getRelevantDocuments is deprecated (the old code
    // still called getRelevantDocuments here despite its own comment)
    const relevantDocs = await this.retriever.invoke(query)
    const serializedDoc = formatDocumentsAsString(relevantDocs)
    return Promise.resolve(serializedDoc)
  }
}
export const retrieval = new Retrieval()

View File

@ -1,118 +0,0 @@
import {
AssistantTool,
executeOnMain,
fs,
InferenceTool,
joinPath,
MessageRequest,
} from '@janhq/core'
/**
 * Inference tool that augments a message request with context retrieved
 * from documents attached to the thread. Delegates the heavy lifting to
 * the node-side `toolRetrieval*` handlers via executeOnMain.
 */
export class RetrievalTool extends InferenceTool {
  // Root of the on-disk thread storage used for the "has memory" check below.
  private _threadDir = 'file://threads'
  // Thread whose memory is currently loaded; triggers a reload on switch.
  private retrievalThreadId: string | undefined = undefined
  name: string = 'retrieval'

  /**
   * Runs the retrieval pipeline: ingest any newly attached document, load
   * the thread's memory, query it with the user's prompt, and splice the
   * retrieved context into the retrieval template.
   * @param data - The outgoing message request.
   * @param tool - Assistant tool config (settings, time-weighted flag).
   * @returns The (possibly rewritten) request to forward to inference.
   */
  async process(
    data: MessageRequest,
    tool?: AssistantTool
  ): Promise<MessageRequest> {
    // Nothing to retrieve against without a model and message history.
    if (!data.model || !data.messages) {
      return Promise.resolve(data)
    }
    const latestMessage = data.messages[data.messages.length - 1]
    // 1. Ingest the document if needed
    // A non-string content array longer than 1 means the last message
    // carries an attachment alongside the text prompt.
    if (
      latestMessage &&
      latestMessage.content &&
      typeof latestMessage.content !== 'string' &&
      latestMessage.content.length > 1
    ) {
      const docFile = latestMessage.content[1]?.doc_url?.url
      if (docFile) {
        // NODE presumably resolves to this extension's node module id —
        // TODO confirm (global, not visible in this file)
        await executeOnMain(
          NODE,
          'toolRetrievalIngestNewDocument',
          data.thread?.id,
          docFile,
          data.model?.id,
          data.model?.engine,
          tool?.useTimeWeightedRetriever ?? false
        )
      } else {
        // Attachment slot present but no document URL — skip retrieval.
        return Promise.resolve(data)
      }
    } else if (
      // Check whether we need to ingest document or not
      // Otherwise wrong context will be sent
      !(await fs.existsSync(
        await joinPath([this._threadDir, data.threadId, 'memory'])
      ))
    ) {
      // No document ingested, reroute the result to inference engine
      return Promise.resolve(data)
    }
    // 2. Load agent on thread changed
    if (this.retrievalThreadId !== data.threadId) {
      await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
      this.retrievalThreadId = data.threadId
      // Update the text splitter
      await executeOnMain(
        NODE,
        'toolRetrievalUpdateTextSplitter',
        tool?.settings?.chunk_size ?? 4000,
        tool?.settings?.chunk_overlap ?? 200
      )
    }
    // 3. Using the retrieval template with the result and query
    if (latestMessage.content) {
      const prompt =
        typeof latestMessage.content === 'string'
          ? latestMessage.content
          : latestMessage.content[0].text
      // Retrieve the result
      const retrievalResult = await executeOnMain(
        NODE,
        'toolRetrievalQueryResult',
        prompt,
        tool?.useTimeWeightedRetriever ?? false
      )
      console.debug('toolRetrievalQueryResult', retrievalResult)
      // Update message content
      // NOTE(review): if retrieval_template is unset this assigns
      // undefined to the message content — confirm that is intended.
      if (retrievalResult)
        data.messages[data.messages.length - 1].content =
          tool?.settings?.retrieval_template
            ?.replace('{CONTEXT}', retrievalResult)
            .replace('{QUESTION}', prompt)
    }
    // 4. Reroute the result to inference engine
    return Promise.resolve(this.normalize(data))
  }

  // Filter out all the messages that are not text
  // TODO: Remove it until engines can handle multiple content types
  /**
   * Keeps only the first content part of each multi-part message so that
   * engines expecting a single content item do not break.
   */
  normalize(request: MessageRequest): MessageRequest {
    request.messages = request.messages?.map((message) => {
      if (
        message.content &&
        typeof message.content !== 'string' &&
        (message.content.length ?? 0) > 0
      ) {
        return {
          ...message,
          content: [message.content[0]],
        }
      }
      return message
    })
    return request
  }
}

View File

@ -11,20 +11,19 @@ import {
events,
MessageEvent,
ContentType,
EngineManager,
InferenceEngine,
} from '@janhq/core'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { OpenAI } from 'openai'
import {
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionRole,
ChatCompletionTool,
} from 'openai/resources/chat'
import { Tool } from 'openai/resources/responses/responses'
import { ulid } from 'ulidx'
import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
@ -250,6 +249,7 @@ export default function useSendChatMessage() {
}
setIsGeneratingResponse(true)
if (requestBuilder.tools && requestBuilder.tools.length) {
let isDone = false
const openai = new OpenAI({
apiKey: await window.core.api.appToken(),
@ -362,6 +362,12 @@ export default function useSendChatMessage() {
!response.choices[0]?.message.tool_calls ||
!response.choices[0]?.message.tool_calls.length
}
} else {
// Request for inference
EngineManager.instance()
.get(InferenceEngine.cortex)
?.inference(requestBuilder.build())
}
// Reset states
setReloadModel(false)

View File

@ -1,4 +1,4 @@
import { EngineManager, ToolManager } from '@janhq/core'
import { EngineManager } from '@janhq/core'
import { appService } from './appService'
import { EventEmitter } from './eventsService'
@ -16,7 +16,6 @@ export const setupCoreServices = () => {
window.core = {
events: new EventEmitter(),
engineManager: new EngineManager(),
toolManager: new ToolManager(),
api: {
...(window.electronAPI ?? (IS_TAURI ? tauriAPI : restAPI)),
...appService,