refactor: introduce inference tools (#2493)

Louis 2024-03-25 23:26:05 +07:00 committed by GitHub
parent 14a67463dc
commit 8e8dfd4b37
15 changed files with 240 additions and 193 deletions

View File

@@ -36,7 +36,7 @@ export abstract class AIEngine extends BaseExtension {
* Registers AI Engines
*/
registerEngine() {
EngineManager.instance()?.register(this)
EngineManager.instance().register(this)
}
/**

View File

@@ -23,7 +23,10 @@ export class EngineManager {
return this.engines.get(provider) as T | undefined
}
static instance(): EngineManager | undefined {
return window.core?.engineManager as EngineManager
/**
* The instance of the engine manager.
*/
static instance(): EngineManager {
return window.core?.engineManager as EngineManager ?? new EngineManager()
}
}
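
For context (not part of the commit): a minimal sketch of what the tightened instance() signature means at call sites. The engine value and the provider key below are placeholders, and the exports are assumed to come from @janhq/core as used elsewhere in this diff.

import { AIEngine, EngineManager } from '@janhq/core'

// Placeholder for a concrete engine implementation registered by an extension.
declare const myEngine: AIEngine

// Before: instance() was typed `EngineManager | undefined`, so callers needed `?.`:
//   EngineManager.instance()?.register(myEngine)
// After: instance() returns window.core?.engineManager, or falls back to a new manager,
// so call sites can be direct:
EngineManager.instance().register(myEngine)
const engine = EngineManager.instance().get('nitro') // provider key is an assumption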

View File

@@ -106,6 +106,7 @@ export abstract class OAIEngine extends AIEngine {
return
}
message.status = MessageStatus.Error
message.error_code = err.code
events.emit(MessageEvent.OnMessageUpdate, message)
},
})

View File

@@ -1,5 +1,5 @@
import { Observable } from 'rxjs'
import { ModelRuntimeParams } from '../../../../types'
import { ErrorCode, ModelRuntimeParams } from '../../../../types'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
@@ -34,6 +34,16 @@ export function requestInference(
signal: controller?.signal,
})
.then(async (response) => {
if (!response.ok) {
const data = await response.json()
const error = {
message: data.error?.message ?? 'Error occurred.',
code: data.error?.code ?? ErrorCode.Unknown,
}
subscriber.error(error)
subscriber.complete()
return
}
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')
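
For orientation (not part of the commit): a hedged sketch of how a subscriber sees the new error path. The observable variable name is illustrative, and ErrorCode is assumed to be re-exported from @janhq/core.

import { Observable } from 'rxjs'
import { ErrorCode } from '@janhq/core'

// Stands in for the Observable<string> returned by requestInference(...).
declare const inference$: Observable<string>

inference$.subscribe({
  next: (content) => console.log(content), // streamed token or full completion text
  error: (err: { message: string; code: ErrorCode }) => {
    // Non-OK responses now emit { message, code }; OAIEngine copies the code onto
    // the chat message as error_code and marks its status as Error.
    console.error(err.code, err.message)
  },
})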

View File

@@ -27,3 +27,9 @@ export * from './extension'
* @module
*/
export * from './extensions'
/**
* Export all base tools.
* @module
*/
export * from './tools'

View File

@@ -0,0 +1,2 @@
export * from './manager'
export * from './tool'

View File

@@ -0,0 +1,47 @@
import { AssistantTool, MessageRequest } from '../../types'
import { InferenceTool } from './tool'
/**
* Manages the registration and retrieval of inference tools.
*/
export class ToolManager {
public tools = new Map<string, InferenceTool>()
/**
* Registers a tool.
* @param tool - The tool to register.
*/
register<T extends InferenceTool>(tool: T) {
this.tools.set(tool.name, tool)
}
/**
* Retrieves a tool by its name.
* @param name - The name of the tool to retrieve.
* @returns The tool, if found.
*/
get<T extends InferenceTool>(name: string): T | undefined {
return this.tools.get(name) as T | undefined
}
/**
 * Process the message request with the tools.
 */
process(request: MessageRequest, tools: AssistantTool[]): Promise<MessageRequest> {
return tools.reduce((prevPromise, currentTool) => {
return prevPromise.then((prevResult) => {
return currentTool.enabled
? this.get(currentTool.type)?.process(prevResult, currentTool) ??
Promise.resolve(prevResult)
: Promise.resolve(prevResult)
})
}, Promise.resolve(request))
}
/**
* The instance of the tool manager.
*/
static instance(): ToolManager {
return (window.core?.toolManager as ToolManager) ?? new ToolManager()
}
}
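
A hedged usage sketch of the new manager (the request and tool list are placeholders; in this commit the real caller is useSendChatMessage, shown further below).

import { AssistantTool, MessageRequest, ToolManager } from '@janhq/core'

declare const request: MessageRequest // e.g. requestBuilder.build()
declare const tools: AssistantTool[]  // e.g. thread.assistants[0].tools ?? []

async function applyTools(): Promise<MessageRequest> {
  // Tools are looked up by AssistantTool.type and chained sequentially;
  // a disabled or unregistered tool type passes the request through unchanged.
  return ToolManager.instance().process(request, tools)
}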

View File

@@ -0,0 +1,12 @@
import { AssistantTool, MessageRequest } from '../../types'
/**
* Represents a base inference tool.
*/
export abstract class InferenceTool {
abstract name: string
/**
 * Process a message request and return the processed message request.
 */
abstract process(request: MessageRequest, tool?: AssistantTool): Promise<MessageRequest>
}
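
A minimal sketch of a custom tool built on this base class (hypothetical; the only concrete implementation added in this commit is RetrievalTool below).

import { AssistantTool, InferenceTool, MessageRequest, ToolManager } from '@janhq/core'

export class PassthroughTool extends InferenceTool {
  name = 'passthrough' // must match the AssistantTool.type it should handle

  async process(
    request: MessageRequest,
    tool?: AssistantTool
  ): Promise<MessageRequest> {
    // A real tool would read tool?.settings and rewrite request.messages here.
    return request
  }
}

// Registering makes the tool discoverable by ToolManager.process().
ToolManager.instance().register(new PassthroughTool())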

View File

@@ -7,7 +7,6 @@ export type ModelInfo = {
settings: ModelSettingParams
parameters: ModelRuntimeParams
engine?: InferenceEngine
proxy_model?: InferenceEngine
}
/**
@@ -21,8 +20,6 @@ export enum InferenceEngine {
groq = 'groq',
triton_trtllm = 'triton_trtllm',
nitro_tensorrt_llm = 'nitro-tensorrt-llm',
tool_retrieval_enabled = 'tool_retrieval_enabled',
}
export type ModelArtifact = {
@@ -94,8 +91,6 @@ export type Model = {
* The model engine.
*/
engine: InferenceEngine
proxy_model?: InferenceEngine
}
export type ModelMetadata = {

View File

@@ -1,26 +1,21 @@
import {
fs,
Assistant,
MessageRequest,
events,
InferenceEngine,
MessageEvent,
InferenceEvent,
joinPath,
executeOnMain,
AssistantExtension,
AssistantEvent,
ToolManager,
} from '@janhq/core'
import { RetrievalTool } from './tools/retrieval'
export default class JanAssistantExtension extends AssistantExtension {
private static readonly _homeDir = 'file://assistants'
private static readonly _threadDir = 'file://threads'
controller = new AbortController()
isCancelled = false
retrievalThreadId: string | undefined = undefined
async onLoad() {
// Register the retrieval tool
ToolManager.instance().register(new RetrievalTool())
// making the assistant directory
const assistantDirExist = await fs.existsSync(
JanAssistantExtension._homeDir
@@ -38,140 +33,6 @@ export default class JanAssistantExtension extends AssistantExtension {
// Update the assistant list
events.emit(AssistantEvent.OnAssistantsUpdate, {})
}
// Events subscription
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
JanAssistantExtension.handleMessageRequest(data, this)
)
events.on(InferenceEvent.OnInferenceStopped, () => {
JanAssistantExtension.handleInferenceStopped(this)
})
}
private static async handleInferenceStopped(instance: JanAssistantExtension) {
instance.isCancelled = true
instance.controller?.abort()
}
private static async handleMessageRequest(
data: MessageRequest,
instance: JanAssistantExtension
) {
instance.isCancelled = false
instance.controller = new AbortController()
if (
data.model?.engine !== InferenceEngine.tool_retrieval_enabled ||
!data.messages ||
// TODO: Since the engine is defined, it's unsafe to assume that assistant tools are defined
// That could leave the thread stuck at generating a response
!data.thread?.assistants[0]?.tools
) {
return
}
const latestMessage = data.messages[data.messages.length - 1]
// 1. Ingest the document if needed
if (
latestMessage &&
latestMessage.content &&
typeof latestMessage.content !== 'string' &&
latestMessage.content.length > 1
) {
const docFile = latestMessage.content[1]?.doc_url?.url
if (docFile) {
await executeOnMain(
NODE,
'toolRetrievalIngestNewDocument',
docFile,
data.model?.proxy_model
)
}
} else if (
// Check whether we need to ingest document or not
// Otherwise wrong context will be sent
!(await fs.existsSync(
await joinPath([
JanAssistantExtension._threadDir,
data.threadId,
'memory',
])
))
) {
// No document ingested, reroute the result to inference engine
const output = {
...data,
model: {
...data.model,
engine: data.model.proxy_model,
},
}
events.emit(MessageEvent.OnMessageSent, output)
return
}
// 2. Load agent on thread changed
if (instance.retrievalThreadId !== data.threadId) {
await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
instance.retrievalThreadId = data.threadId
// Update the text splitter
await executeOnMain(
NODE,
'toolRetrievalUpdateTextSplitter',
data.thread.assistants[0].tools[0]?.settings?.chunk_size ?? 4000,
data.thread.assistants[0].tools[0]?.settings?.chunk_overlap ?? 200
)
}
// 3. Using the retrieval template with the result and query
if (latestMessage.content) {
const prompt =
typeof latestMessage.content === 'string'
? latestMessage.content
: latestMessage.content[0].text
// Retrieve the result
const retrievalResult = await executeOnMain(
NODE,
'toolRetrievalQueryResult',
prompt
)
console.debug('toolRetrievalQueryResult', retrievalResult)
// Update message content
if (data.thread?.assistants[0]?.tools && retrievalResult)
data.messages[data.messages.length - 1].content =
data.thread.assistants[0].tools[0].settings?.retrieval_template
?.replace('{CONTEXT}', retrievalResult)
.replace('{QUESTION}', prompt)
}
// Filter out all the messages that are not text
data.messages = data.messages.map((message) => {
if (
message.content &&
typeof message.content !== 'string' &&
(message.content.length ?? 0) > 0
) {
return {
...message,
content: [message.content[0]],
}
}
return message
})
// 4. Reroute the result to inference engine
const output = {
...data,
model: {
...data.model,
engine: data.model.proxy_model,
},
}
events.emit(MessageEvent.OnMessageSent, output)
}
/**

View File

@@ -0,0 +1,108 @@
import {
AssistantTool,
executeOnMain,
fs,
InferenceTool,
joinPath,
MessageRequest,
} from '@janhq/core'
export class RetrievalTool extends InferenceTool {
private _threadDir = 'file://threads'
private retrievalThreadId: string | undefined = undefined
name: string = 'retrieval'
async process(
data: MessageRequest,
tool?: AssistantTool
): Promise<MessageRequest> {
if (!data.model || !data.messages) {
return Promise.resolve(data)
}
const latestMessage = data.messages[data.messages.length - 1]
// 1. Ingest the document if needed
if (
latestMessage &&
latestMessage.content &&
typeof latestMessage.content !== 'string' &&
latestMessage.content.length > 1
) {
const docFile = latestMessage.content[1]?.doc_url?.url
if (docFile) {
await executeOnMain(
NODE,
'toolRetrievalIngestNewDocument',
docFile,
data.model?.engine
)
}
} else if (
// Check whether we need to ingest the document or not
// Otherwise the wrong context will be sent
!(await fs.existsSync(
await joinPath([this._threadDir, data.threadId, 'memory'])
))
) {
// No document ingested, return the request unchanged for the inference engine
return Promise.resolve(data)
}
// 2. Load the agent when the thread changes
if (this.retrievalThreadId !== data.threadId) {
await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
this.retrievalThreadId = data.threadId
// Update the text splitter
await executeOnMain(
NODE,
'toolRetrievalUpdateTextSplitter',
tool?.settings?.chunk_size ?? 4000,
tool?.settings?.chunk_overlap ?? 200
)
}
// 3. Apply the retrieval template to the retrieved result and the query
if (latestMessage.content) {
const prompt =
typeof latestMessage.content === 'string'
? latestMessage.content
: latestMessage.content[0].text
// Retrieve the result
const retrievalResult = await executeOnMain(
NODE,
'toolRetrievalQueryResult',
prompt
)
console.debug('toolRetrievalQueryResult', retrievalResult)
// Update message content
if (retrievalResult)
data.messages[data.messages.length - 1].content =
tool?.settings?.retrieval_template
?.replace('{CONTEXT}', retrievalResult)
.replace('{QUESTION}', prompt)
}
// Keep only the first content item of each message (drop non-text parts)
data.messages = data.messages.map((message) => {
if (
message.content &&
typeof message.content !== 'string' &&
(message.content.length ?? 0) > 0
) {
return {
...message,
content: [message.content[0]],
}
}
return message
})
// 4. Return the processed request; the caller routes it to the inference engine
return Promise.resolve(data)
}
}
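
For reference, a hedged sketch of the thread-side tool entry this class consumes; the field names follow the settings read above, while the full AssistantTool shape is an assumption.

// Illustrative configuration; RetrievalTool.process() reads these via tool?.settings,
// and ToolManager matches `type` against the tool's registered name.
const retrievalToolEntry = {
  type: 'retrieval',
  enabled: true,
  settings: {
    chunk_size: 4000,
    chunk_overlap: 200,
    // {CONTEXT} and {QUESTION} are replaced with the retrieval result and the prompt.
    retrieval_template: 'Use this context: {CONTEXT}\n\nQuestion: {QUESTION}',
  },
}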

View File

@@ -230,7 +230,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
// 2. Update the title with the result of the inference
setTimeout(() => {
const engine = EngineManager.instance()?.get(
const engine = EngineManager.instance().get(
messageRequest.model?.engine ?? activeModelRef.current?.engine ?? ''
)
engine?.inference(messageRequest)

View File

@@ -78,7 +78,7 @@ export function useActiveModel() {
}
localStorage.setItem(LAST_USED_MODEL_ID, model.id)
const engine = EngineManager.instance()?.get(model.engine)
const engine = EngineManager.instance().get(model.engine)
return engine
?.loadModel(model)
.then(() => {
@@ -95,7 +95,6 @@ export function useActiveModel() {
})
})
.catch((error) => {
console.error('Failed to load model: ', error)
setStateModel(() => ({
state: 'start',
loading: false,
@@ -108,13 +107,14 @@
type: 'success',
})
setLoadModelError(error)
return Promise.reject(error)
})
}
const stopModel = useCallback(async () => {
if (activeModel) {
setStateModel({ state: 'stop', loading: true, model: activeModel.id })
const engine = EngineManager.instance()?.get(activeModel.engine)
const engine = EngineManager.instance().get(activeModel.engine)
await engine
?.unloadModel(activeModel)
.catch()

View File

@@ -9,9 +9,8 @@ import {
ThreadMessage,
Model,
ConversationalExtension,
InferenceEngine,
AssistantTool,
EngineManager,
ToolManager,
} from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
@@ -111,7 +110,10 @@ export default function useSendChatMessage() {
activeThreadRef.current.assistants[0].model.id
if (modelRef.current?.id !== modelId) {
await startModel(modelId)
const error = await startModel(modelId).catch((error: Error) => error)
if (error) {
return
}
}
setIsGeneratingResponse(true)
@@ -128,10 +130,18 @@
)
}
}
const engine = EngineManager.instance()?.get(
requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? ''
// Process the message request with the assistant's tools
const request = await ToolManager.instance().process(
requestBuilder.build(),
activeThreadRef.current.assistants?.flatMap(
(assistant) => assistant.tools ?? []
) ?? []
)
engine?.inference(requestBuilder.build())
const engine =
requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? ''
EngineManager.instance().get(engine)?.inference(request)
}
// Define interface extending Array prototype
@@ -149,8 +159,9 @@ export default function useSendChatMessage() {
const runtimeParams = toRuntimeParams(activeModelParams)
const settingParams = toSettingParams(activeModelParams)
updateThreadWaiting(activeThreadRef.current.id, true)
const prompt = message.trim()
updateThreadWaiting(activeThreadRef.current.id, true)
setCurrentPrompt('')
setEditPrompt('')
@@ -158,17 +169,12 @@
? await getBase64(fileUpload[0].file)
: undefined
const fileContentType = fileUpload[0]?.type
const isDocumentInput = base64Blob && fileContentType === 'pdf'
const isImageInput = base64Blob && fileContentType === 'image'
if (isImageInput && base64Blob) {
if (base64Blob && fileUpload[0]?.type === 'image') {
// Compress image
base64Blob = await compressImage(base64Blob, 512)
}
let modelRequest =
const modelRequest =
selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model
// Fallback support for previous broken threads
@@ -182,23 +188,6 @@
if (runtimeParams.stream == null) {
runtimeParams.stream = true
}
// Add middleware to the model request with tool retrieval enabled
if (
activeThreadRef.current.assistants[0].tools?.some(
(tool: AssistantTool) => tool.type === 'retrieval' && tool.enabled
)
) {
modelRequest = {
...modelRequest,
// Tool retrieval support document input only for now
...(isDocumentInput
? {
engine: InferenceEngine.tool_retrieval_enabled,
proxy_model: modelRequest.engine,
}
: {}),
}
}
// Build Message Request
const requestBuilder = new MessageRequestBuilder(
@@ -247,15 +236,27 @@
if (modelRef.current?.id !== modelId) {
setQueuedMessage(true)
await startModel(modelId)
const error = await startModel(modelId).catch((error: Error) => error)
setQueuedMessage(false)
if (error) {
updateThreadWaiting(activeThreadRef.current.id, false)
return
}
}
setIsGeneratingResponse(true)
const engine = EngineManager.instance()?.get(
requestBuilder.model?.engine ?? modelRequest.engine ?? ''
// Process the message request with the assistant's tools
const request = await ToolManager.instance().process(
requestBuilder.build(),
activeThreadRef.current.assistants?.flatMap(
(assistant) => assistant.tools ?? []
) ?? []
)
engine?.inference(requestBuilder.build())
// Request for inference
EngineManager.instance()
.get(requestBuilder.model?.engine ?? modelRequest.engine ?? '')
?.inference(request)
// Reset states
setReloadModel(false)

View File

@@ -1,4 +1,4 @@
import { EngineManager } from '@janhq/core'
import { EngineManager, ToolManager } from '@janhq/core'
import { appService } from './appService'
import { EventEmitter } from './eventsService'
@@ -15,6 +15,7 @@ export const setupCoreServices = () => {
window.core = {
events: new EventEmitter(),
engineManager: new EngineManager(),
toolManager: new ToolManager(),
api: {
...(window.electronAPI ? window.electronAPI : restAPI),
...appService,