refactor: introduce inference tools (#2493)
This commit is contained in:
parent
14a67463dc
commit
8e8dfd4b37
@ -36,7 +36,7 @@ export abstract class AIEngine extends BaseExtension {
|
||||
* Registers AI Engines
|
||||
*/
|
||||
registerEngine() {
|
||||
EngineManager.instance()?.register(this)
|
||||
EngineManager.instance().register(this)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -23,7 +23,10 @@ export class EngineManager {
|
||||
return this.engines.get(provider) as T | undefined
|
||||
}
|
||||
|
||||
static instance(): EngineManager | undefined {
|
||||
return window.core?.engineManager as EngineManager
|
||||
/**
 * Manager handed out when `window.core.engineManager` is not available.
 * Cached so that repeated `instance()` calls return the same object and
 * engines registered through it are not silently discarded.
 */
private static fallbackInstance: EngineManager | undefined

/**
 * The instance of the engine manager.
 *
 * Prefers the manager stored on `window.core.engineManager`. When that is
 * missing (or `window` itself is not declared), a single lazily created
 * fallback manager is returned instead of a fresh one on every call.
 *
 * @returns The shared engine manager.
 */
static instance(): EngineManager {
  // `typeof` is used so an undeclared `window` does not raise a ReferenceError.
  const shared =
    typeof window !== 'undefined'
      ? (window.core?.engineManager as EngineManager | undefined)
      : undefined
  if (shared != null) {
    return shared
  }

  if (!EngineManager.fallbackInstance) {
    EngineManager.fallbackInstance = new EngineManager()
  }
  return EngineManager.fallbackInstance
}
|
||||
}
|
||||
|
||||
@ -106,6 +106,7 @@ export abstract class OAIEngine extends AIEngine {
|
||||
return
|
||||
}
|
||||
message.status = MessageStatus.Error
|
||||
message.error_code = err.code
|
||||
events.emit(MessageEvent.OnMessageUpdate, message)
|
||||
},
|
||||
})
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { Observable } from 'rxjs'
|
||||
import { ModelRuntimeParams } from '../../../../types'
|
||||
import { ErrorCode, ModelRuntimeParams } from '../../../../types'
|
||||
/**
|
||||
* Sends a request to the inference server to generate a response based on the recent messages.
|
||||
* @param recentMessages - An array of recent messages to use as context for the inference.
|
||||
@ -34,6 +34,16 @@ export function requestInference(
|
||||
signal: controller?.signal,
|
||||
})
|
||||
.then(async (response) => {
|
||||
if (!response.ok) {
|
||||
const data = await response.json()
|
||||
const error = {
|
||||
message: data.error?.message ?? 'Error occurred.',
|
||||
code: data.error?.code ?? ErrorCode.Unknown,
|
||||
}
|
||||
subscriber.error(error)
|
||||
subscriber.complete()
|
||||
return
|
||||
}
|
||||
if (model.parameters.stream === false) {
|
||||
const data = await response.json()
|
||||
subscriber.next(data.choices[0]?.message?.content ?? '')
|
||||
|
||||
@ -27,3 +27,9 @@ export * from './extension'
|
||||
* @module
|
||||
*/
|
||||
export * from './extensions'
|
||||
|
||||
/**
|
||||
* Export all base tools.
|
||||
* @module
|
||||
*/
|
||||
export * from './tools'
|
||||
|
||||
2
core/src/browser/tools/index.ts
Normal file
2
core/src/browser/tools/index.ts
Normal file
@ -0,0 +1,2 @@
|
||||
export * from './manager'
|
||||
export * from './tool'
|
||||
47
core/src/browser/tools/manager.ts
Normal file
47
core/src/browser/tools/manager.ts
Normal file
@ -0,0 +1,47 @@
|
||||
import { AssistantTool, MessageRequest } from '../../types'
|
||||
import { InferenceTool } from './tool'
|
||||
|
||||
/**
|
||||
* Manages the registration and retrieval of inference tools.
|
||||
*/
|
||||
export class ToolManager {
  /** Registered tools, keyed by each tool's `name`. */
  public tools = new Map<string, InferenceTool>()

  /**
   * Manager handed out when `window.core.toolManager` is not available.
   * Cached so that repeated `instance()` calls return the same object and
   * tools registered through it are not silently discarded.
   */
  private static fallbackInstance: ToolManager | undefined

  /**
   * Registers a tool. A tool registered later under the same name replaces
   * the earlier one.
   * @param tool - The tool to register.
   */
  register<T extends InferenceTool>(tool: T) {
    this.tools.set(tool.name, tool)
  }

  /**
   * Retrieves a tool by its name.
   * @param name - The name of the tool to retrieve.
   * @returns The tool, if found.
   */
  get<T extends InferenceTool>(name: string): T | undefined {
    return this.tools.get(name) as T | undefined
  }

  /**
   * Processes the message request with the given assistant tools.
   *
   * Tools are applied one after another, in array order; each enabled tool
   * whose `type` matches a registered tool name receives the request produced
   * by the previous one. Disabled tools and tools without a registered
   * handler leave the request untouched.
   *
   * @param request - The message request to process.
   * @param tools - The assistant tool configurations to apply.
   * @returns The request after every applicable tool has processed it.
   */
  async process(request: MessageRequest, tools: AssistantTool[]): Promise<MessageRequest> {
    let current = request
    for (const assistantTool of tools) {
      if (!assistantTool.enabled) {
        continue
      }
      const handler = this.get(assistantTool.type)
      if (handler) {
        current = await handler.process(current, assistantTool)
      }
    }
    return current
  }

  /**
   * The instance of the tool manager.
   *
   * Prefers the manager stored on `window.core.toolManager`. When that is
   * missing (or `window` itself is not declared), a single lazily created
   * fallback manager is returned instead of a fresh one on every call.
   *
   * @returns The shared tool manager.
   */
  static instance(): ToolManager {
    // `typeof` is used so an undeclared `window` does not raise a ReferenceError.
    const shared =
      typeof window !== 'undefined'
        ? (window.core?.toolManager as ToolManager | undefined)
        : undefined
    if (shared != null) {
      return shared
    }

    if (!ToolManager.fallbackInstance) {
      ToolManager.fallbackInstance = new ToolManager()
    }
    return ToolManager.fallbackInstance
  }
}
|
||||
12
core/src/browser/tools/tool.ts
Normal file
12
core/src/browser/tools/tool.ts
Normal file
@ -0,0 +1,12 @@
|
||||
import { AssistantTool, MessageRequest } from '../../types'
|
||||
|
||||
/**
|
||||
* Represents a base inference tool.
|
||||
*/
|
||||
/**
 * Base contract for an inference tool: a named unit that can transform a
 * message request.
 */
export abstract class InferenceTool {
  /** Name identifying the tool. */
  public abstract name: string

  /**
   * Processes a message request and returns the processed message request.
   *
   * @param request - The message request to process.
   * @param tool - Optional assistant tool configuration accompanying the request.
   * @returns A promise resolving to the processed message request.
   */
  public abstract process(
    request: MessageRequest,
    tool?: AssistantTool
  ): Promise<MessageRequest>
}
|
||||
@ -7,7 +7,6 @@ export type ModelInfo = {
|
||||
settings: ModelSettingParams
|
||||
parameters: ModelRuntimeParams
|
||||
engine?: InferenceEngine
|
||||
proxy_model?: InferenceEngine
|
||||
}
|
||||
|
||||
/**
|
||||
@ -21,8 +20,6 @@ export enum InferenceEngine {
|
||||
groq = 'groq',
|
||||
triton_trtllm = 'triton_trtllm',
|
||||
nitro_tensorrt_llm = 'nitro-tensorrt-llm',
|
||||
|
||||
tool_retrieval_enabled = 'tool_retrieval_enabled',
|
||||
}
|
||||
|
||||
export type ModelArtifact = {
|
||||
@ -94,8 +91,6 @@ export type Model = {
|
||||
* The model engine.
|
||||
*/
|
||||
engine: InferenceEngine
|
||||
|
||||
proxy_model?: InferenceEngine
|
||||
}
|
||||
|
||||
export type ModelMetadata = {
|
||||
|
||||
@ -1,26 +1,21 @@
|
||||
import {
|
||||
fs,
|
||||
Assistant,
|
||||
MessageRequest,
|
||||
events,
|
||||
InferenceEngine,
|
||||
MessageEvent,
|
||||
InferenceEvent,
|
||||
joinPath,
|
||||
executeOnMain,
|
||||
AssistantExtension,
|
||||
AssistantEvent,
|
||||
ToolManager,
|
||||
} from '@janhq/core'
|
||||
import { RetrievalTool } from './tools/retrieval'
|
||||
|
||||
export default class JanAssistantExtension extends AssistantExtension {
|
||||
private static readonly _homeDir = 'file://assistants'
|
||||
private static readonly _threadDir = 'file://threads'
|
||||
|
||||
controller = new AbortController()
|
||||
isCancelled = false
|
||||
retrievalThreadId: string | undefined = undefined
|
||||
|
||||
async onLoad() {
|
||||
// Register the retrieval tool
|
||||
ToolManager.instance().register(new RetrievalTool())
|
||||
|
||||
// making the assistant directory
|
||||
const assistantDirExist = await fs.existsSync(
|
||||
JanAssistantExtension._homeDir
|
||||
@ -38,140 +33,6 @@ export default class JanAssistantExtension extends AssistantExtension {
|
||||
// Update the assistant list
|
||||
events.emit(AssistantEvent.OnAssistantsUpdate, {})
|
||||
}
|
||||
|
||||
// Events subscription
|
||||
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
|
||||
JanAssistantExtension.handleMessageRequest(data, this)
|
||||
)
|
||||
|
||||
events.on(InferenceEvent.OnInferenceStopped, () => {
|
||||
JanAssistantExtension.handleInferenceStopped(this)
|
||||
})
|
||||
}
|
||||
|
||||
private static async handleInferenceStopped(instance: JanAssistantExtension) {
|
||||
instance.isCancelled = true
|
||||
instance.controller?.abort()
|
||||
}
|
||||
|
||||
private static async handleMessageRequest(
|
||||
data: MessageRequest,
|
||||
instance: JanAssistantExtension
|
||||
) {
|
||||
instance.isCancelled = false
|
||||
instance.controller = new AbortController()
|
||||
|
||||
if (
|
||||
data.model?.engine !== InferenceEngine.tool_retrieval_enabled ||
|
||||
!data.messages ||
|
||||
// TODO: Since the engine is defined, it's unsafe to assume that assistant tools are defined
|
||||
// That could lead to an issue where thread stuck at generating response
|
||||
!data.thread?.assistants[0]?.tools
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
const latestMessage = data.messages[data.messages.length - 1]
|
||||
|
||||
// 1. Ingest the document if needed
|
||||
if (
|
||||
latestMessage &&
|
||||
latestMessage.content &&
|
||||
typeof latestMessage.content !== 'string' &&
|
||||
latestMessage.content.length > 1
|
||||
) {
|
||||
const docFile = latestMessage.content[1]?.doc_url?.url
|
||||
if (docFile) {
|
||||
await executeOnMain(
|
||||
NODE,
|
||||
'toolRetrievalIngestNewDocument',
|
||||
docFile,
|
||||
data.model?.proxy_model
|
||||
)
|
||||
}
|
||||
} else if (
|
||||
// Check whether we need to ingest document or not
|
||||
// Otherwise wrong context will be sent
|
||||
!(await fs.existsSync(
|
||||
await joinPath([
|
||||
JanAssistantExtension._threadDir,
|
||||
data.threadId,
|
||||
'memory',
|
||||
])
|
||||
))
|
||||
) {
|
||||
// No document ingested, reroute the result to inference engine
|
||||
const output = {
|
||||
...data,
|
||||
model: {
|
||||
...data.model,
|
||||
engine: data.model.proxy_model,
|
||||
},
|
||||
}
|
||||
events.emit(MessageEvent.OnMessageSent, output)
|
||||
return
|
||||
}
|
||||
// 2. Load agent on thread changed
|
||||
if (instance.retrievalThreadId !== data.threadId) {
|
||||
await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
|
||||
|
||||
instance.retrievalThreadId = data.threadId
|
||||
|
||||
// Update the text splitter
|
||||
await executeOnMain(
|
||||
NODE,
|
||||
'toolRetrievalUpdateTextSplitter',
|
||||
data.thread.assistants[0].tools[0]?.settings?.chunk_size ?? 4000,
|
||||
data.thread.assistants[0].tools[0]?.settings?.chunk_overlap ?? 200
|
||||
)
|
||||
}
|
||||
|
||||
// 3. Using the retrieval template with the result and query
|
||||
if (latestMessage.content) {
|
||||
const prompt =
|
||||
typeof latestMessage.content === 'string'
|
||||
? latestMessage.content
|
||||
: latestMessage.content[0].text
|
||||
// Retrieve the result
|
||||
const retrievalResult = await executeOnMain(
|
||||
NODE,
|
||||
'toolRetrievalQueryResult',
|
||||
prompt
|
||||
)
|
||||
console.debug('toolRetrievalQueryResult', retrievalResult)
|
||||
|
||||
// Update message content
|
||||
if (data.thread?.assistants[0]?.tools && retrievalResult)
|
||||
data.messages[data.messages.length - 1].content =
|
||||
data.thread.assistants[0].tools[0].settings?.retrieval_template
|
||||
?.replace('{CONTEXT}', retrievalResult)
|
||||
.replace('{QUESTION}', prompt)
|
||||
}
|
||||
|
||||
// Filter out all the messages that are not text
|
||||
data.messages = data.messages.map((message) => {
|
||||
if (
|
||||
message.content &&
|
||||
typeof message.content !== 'string' &&
|
||||
(message.content.length ?? 0) > 0
|
||||
) {
|
||||
return {
|
||||
...message,
|
||||
content: [message.content[0]],
|
||||
}
|
||||
}
|
||||
return message
|
||||
})
|
||||
|
||||
// 4. Reroute the result to inference engine
|
||||
const output = {
|
||||
...data,
|
||||
model: {
|
||||
...data.model,
|
||||
engine: data.model.proxy_model,
|
||||
},
|
||||
}
|
||||
events.emit(MessageEvent.OnMessageSent, output)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
108
extensions/assistant-extension/src/tools/retrieval.ts
Normal file
108
extensions/assistant-extension/src/tools/retrieval.ts
Normal file
@ -0,0 +1,108 @@
|
||||
import {
|
||||
AssistantTool,
|
||||
executeOnMain,
|
||||
fs,
|
||||
InferenceTool,
|
||||
joinPath,
|
||||
MessageRequest,
|
||||
} from '@janhq/core'
|
||||
|
||||
/**
 * Inference tool registered under the name `retrieval`.
 *
 * Before a request is sent for inference it optionally ingests an attached
 * document, queries the retrieval backend with the user's prompt and rewrites
 * the latest message using the tool's `retrieval_template`.
 */
export class RetrievalTool extends InferenceTool {
  private _threadDir = 'file://threads'
  /** Thread id for which `toolRetrievalLoadThreadMemory` was last invoked. */
  private retrievalThreadId: string | undefined = undefined

  name: string = 'retrieval'

  /**
   * Processes a message request through the retrieval pipeline.
   *
   * The request is returned unchanged when it has no model or messages, or
   * when the latest message carries no second content part and the thread
   * has no `memory` directory.
   *
   * @param data - The message request to process. Its `messages` are updated in place.
   * @param tool - Assistant tool configuration; `settings.chunk_size`,
   *   `settings.chunk_overlap` and `settings.retrieval_template` are read from it.
   * @returns The (possibly rewritten) message request.
   */
  async process(
    data: MessageRequest,
    tool?: AssistantTool
  ): Promise<MessageRequest> {
    if (!data.model || !data.messages) {
      return data
    }

    const latestMessage = data.messages[data.messages.length - 1]
    // `latestMessage` is undefined for an empty message list, hence the `?.`.
    const latestContent = latestMessage?.content

    // 1. Ingest the document if needed
    if (
      latestContent &&
      typeof latestContent !== 'string' &&
      latestContent.length > 1
    ) {
      const docFile = latestContent[1]?.doc_url?.url
      if (docFile) {
        await executeOnMain(
          NODE,
          'toolRetrievalIngestNewDocument',
          docFile,
          data.model?.engine
        )
      }
    } else if (
      // Check whether we need to ingest document or not
      // Otherwise wrong context will be sent
      !(await fs.existsSync(
        await joinPath([this._threadDir, data.threadId, 'memory'])
      ))
    ) {
      // No document ingested, hand the request back untouched
      return data
    }

    // 2. Load agent on thread changed
    if (this.retrievalThreadId !== data.threadId) {
      await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)

      this.retrievalThreadId = data.threadId

      // Update the text splitter
      await executeOnMain(
        NODE,
        'toolRetrievalUpdateTextSplitter',
        tool?.settings?.chunk_size ?? 4000,
        tool?.settings?.chunk_overlap ?? 200
      )
    }

    // 3. Using the retrieval template with the result and query
    if (latestContent) {
      // An empty content array has no first part; `?.` avoids a TypeError.
      const prompt =
        typeof latestContent === 'string'
          ? latestContent
          : latestContent[0]?.text

      if (typeof prompt === 'string') {
        // Retrieve the result
        const retrievalResult = await executeOnMain(
          NODE,
          'toolRetrievalQueryResult',
          prompt
        )
        console.debug('toolRetrievalQueryResult', retrievalResult)

        // Update message content. Without a template the message is left
        // untouched rather than being overwritten with `undefined`.
        const template = tool?.settings?.retrieval_template
        if (retrievalResult && typeof template === 'string') {
          data.messages[data.messages.length - 1].content = this.fillTemplate(
            template,
            retrievalResult,
            prompt
          )
        }
      }
    }

    // Filter out all the messages that are not text
    data.messages = data.messages.map((message) => {
      if (
        message.content &&
        typeof message.content !== 'string' &&
        (message.content.length ?? 0) > 0
      ) {
        return {
          ...message,
          content: [message.content[0]],
        }
      }
      return message
    })

    // 4. Hand the processed request back for inference
    return data
  }

  /**
   * Substitutes the first `{CONTEXT}` and `{QUESTION}` placeholders of a template.
   *
   * Replacer functions are used so that `$&`, `$'`, `$1` and similar sequences
   * occurring in the retrieved text or the prompt are inserted literally
   * instead of being interpreted as replacement patterns.
   */
  private fillTemplate(
    template: string,
    context: unknown,
    question: string
  ): string {
    return template
      .replace('{CONTEXT}', () => String(context))
      .replace('{QUESTION}', () => question)
  }
}
|
||||
@ -230,7 +230,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
||||
|
||||
// 2. Update the title with the result of the inference
|
||||
setTimeout(() => {
|
||||
const engine = EngineManager.instance()?.get(
|
||||
const engine = EngineManager.instance().get(
|
||||
messageRequest.model?.engine ?? activeModelRef.current?.engine ?? ''
|
||||
)
|
||||
engine?.inference(messageRequest)
|
||||
|
||||
@ -78,7 +78,7 @@ export function useActiveModel() {
|
||||
}
|
||||
|
||||
localStorage.setItem(LAST_USED_MODEL_ID, model.id)
|
||||
const engine = EngineManager.instance()?.get(model.engine)
|
||||
const engine = EngineManager.instance().get(model.engine)
|
||||
return engine
|
||||
?.loadModel(model)
|
||||
.then(() => {
|
||||
@ -95,7 +95,6 @@ export function useActiveModel() {
|
||||
})
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to load model: ', error)
|
||||
setStateModel(() => ({
|
||||
state: 'start',
|
||||
loading: false,
|
||||
@ -108,13 +107,14 @@ export function useActiveModel() {
|
||||
type: 'success',
|
||||
})
|
||||
setLoadModelError(error)
|
||||
return Promise.reject(error)
|
||||
})
|
||||
}
|
||||
|
||||
const stopModel = useCallback(async () => {
|
||||
if (activeModel) {
|
||||
setStateModel({ state: 'stop', loading: true, model: activeModel.id })
|
||||
const engine = EngineManager.instance()?.get(activeModel.engine)
|
||||
const engine = EngineManager.instance().get(activeModel.engine)
|
||||
await engine
|
||||
?.unloadModel(activeModel)
|
||||
.catch()
|
||||
|
||||
@ -9,9 +9,8 @@ import {
|
||||
ThreadMessage,
|
||||
Model,
|
||||
ConversationalExtension,
|
||||
InferenceEngine,
|
||||
AssistantTool,
|
||||
EngineManager,
|
||||
ToolManager,
|
||||
} from '@janhq/core'
|
||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
@ -111,7 +110,10 @@ export default function useSendChatMessage() {
|
||||
activeThreadRef.current.assistants[0].model.id
|
||||
|
||||
if (modelRef.current?.id !== modelId) {
|
||||
await startModel(modelId)
|
||||
const error = await startModel(modelId).catch((error: Error) => error)
|
||||
if (error) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
setIsGeneratingResponse(true)
|
||||
@ -128,10 +130,18 @@ export default function useSendChatMessage() {
|
||||
)
|
||||
}
|
||||
}
|
||||
const engine = EngineManager.instance()?.get(
|
||||
requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? ''
|
||||
// Process message request with Assistants tools
|
||||
const request = await ToolManager.instance().process(
|
||||
requestBuilder.build(),
|
||||
activeThreadRef.current.assistants?.flatMap(
|
||||
(assistant) => assistant.tools ?? []
|
||||
) ?? []
|
||||
)
|
||||
engine?.inference(requestBuilder.build())
|
||||
|
||||
const engine =
|
||||
requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? ''
|
||||
|
||||
EngineManager.instance().get(engine)?.inference(request)
|
||||
}
|
||||
|
||||
// Define interface extending Array prototype
|
||||
@ -149,8 +159,9 @@ export default function useSendChatMessage() {
|
||||
const runtimeParams = toRuntimeParams(activeModelParams)
|
||||
const settingParams = toSettingParams(activeModelParams)
|
||||
|
||||
updateThreadWaiting(activeThreadRef.current.id, true)
|
||||
const prompt = message.trim()
|
||||
|
||||
updateThreadWaiting(activeThreadRef.current.id, true)
|
||||
setCurrentPrompt('')
|
||||
setEditPrompt('')
|
||||
|
||||
@ -158,17 +169,12 @@ export default function useSendChatMessage() {
|
||||
? await getBase64(fileUpload[0].file)
|
||||
: undefined
|
||||
|
||||
const fileContentType = fileUpload[0]?.type
|
||||
|
||||
const isDocumentInput = base64Blob && fileContentType === 'pdf'
|
||||
const isImageInput = base64Blob && fileContentType === 'image'
|
||||
|
||||
if (isImageInput && base64Blob) {
|
||||
if (base64Blob && fileUpload[0]?.type === 'image') {
|
||||
// Compress image
|
||||
base64Blob = await compressImage(base64Blob, 512)
|
||||
}
|
||||
|
||||
let modelRequest =
|
||||
const modelRequest =
|
||||
selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model
|
||||
|
||||
// Fallback support for previous broken threads
|
||||
@ -182,23 +188,6 @@ export default function useSendChatMessage() {
|
||||
if (runtimeParams.stream == null) {
|
||||
runtimeParams.stream = true
|
||||
}
|
||||
// Add middleware to the model request with tool retrieval enabled
|
||||
if (
|
||||
activeThreadRef.current.assistants[0].tools?.some(
|
||||
(tool: AssistantTool) => tool.type === 'retrieval' && tool.enabled
|
||||
)
|
||||
) {
|
||||
modelRequest = {
|
||||
...modelRequest,
|
||||
// Tool retrieval support document input only for now
|
||||
...(isDocumentInput
|
||||
? {
|
||||
engine: InferenceEngine.tool_retrieval_enabled,
|
||||
proxy_model: modelRequest.engine,
|
||||
}
|
||||
: {}),
|
||||
}
|
||||
}
|
||||
|
||||
// Build Message Request
|
||||
const requestBuilder = new MessageRequestBuilder(
|
||||
@ -247,15 +236,27 @@ export default function useSendChatMessage() {
|
||||
|
||||
if (modelRef.current?.id !== modelId) {
|
||||
setQueuedMessage(true)
|
||||
await startModel(modelId)
|
||||
const error = await startModel(modelId).catch((error: Error) => error)
|
||||
setQueuedMessage(false)
|
||||
if (error) {
|
||||
updateThreadWaiting(activeThreadRef.current.id, false)
|
||||
return
|
||||
}
|
||||
}
|
||||
setIsGeneratingResponse(true)
|
||||
|
||||
const engine = EngineManager.instance()?.get(
|
||||
requestBuilder.model?.engine ?? modelRequest.engine ?? ''
|
||||
// Process message request with Assistants tools
|
||||
const request = await ToolManager.instance().process(
|
||||
requestBuilder.build(),
|
||||
activeThreadRef.current.assistants?.flatMap(
|
||||
(assistant) => assistant.tools ?? []
|
||||
) ?? []
|
||||
)
|
||||
engine?.inference(requestBuilder.build())
|
||||
|
||||
// Request for inference
|
||||
EngineManager.instance()
|
||||
.get(requestBuilder.model?.engine ?? modelRequest.engine ?? '')
|
||||
?.inference(request)
|
||||
|
||||
// Reset states
|
||||
setReloadModel(false)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { EngineManager } from '@janhq/core'
|
||||
import { EngineManager, ToolManager } from '@janhq/core'
|
||||
|
||||
import { appService } from './appService'
|
||||
import { EventEmitter } from './eventsService'
|
||||
@ -15,6 +15,7 @@ export const setupCoreServices = () => {
|
||||
window.core = {
|
||||
events: new EventEmitter(),
|
||||
engineManager: new EngineManager(),
|
||||
toolManager: new ToolManager(),
|
||||
api: {
|
||||
...(window.electronAPI ? window.electronAPI : restAPI),
|
||||
...appService,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user