refactor: introduce inference tools (#2493)
This commit is contained in:
parent 14a67463dc
commit 8e8dfd4b37
@@ -36,7 +36,7 @@ export abstract class AIEngine extends BaseExtension {
    * Registers AI Engines
    */
   registerEngine() {
-    EngineManager.instance()?.register(this)
+    EngineManager.instance().register(this)
   }
 
   /**
@@ -23,7 +23,10 @@ export class EngineManager {
     return this.engines.get(provider) as T | undefined
   }
 
-  static instance(): EngineManager | undefined {
-    return window.core?.engineManager as EngineManager
+  /**
+   * The instance of the engine manager.
+   */
+  static instance(): EngineManager {
+    return (window.core?.engineManager as EngineManager) ?? new EngineManager()
   }
 }
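Note: with instance() now always returning an EngineManager (falling back to a fresh instance when window.core is not populated), call sites can drop the optional chaining on the manager itself, while the per-provider lookup stays optional. A minimal usage sketch, assuming an engine registered under the 'openai' provider key:

    // instance() is non-nullable now; only the provider lookup can miss
    const engine = EngineManager.instance().get('openai')
    engine?.inference(messageRequest)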
@@ -106,6 +106,7 @@ export abstract class OAIEngine extends AIEngine {
           return
         }
         message.status = MessageStatus.Error
+        message.error_code = err.code
         events.emit(MessageEvent.OnMessageUpdate, message)
       },
     })
@@ -1,5 +1,5 @@
 import { Observable } from 'rxjs'
-import { ModelRuntimeParams } from '../../../../types'
+import { ErrorCode, ModelRuntimeParams } from '../../../../types'
 /**
  * Sends a request to the inference server to generate a response based on the recent messages.
  * @param recentMessages - An array of recent messages to use as context for the inference.
@@ -34,6 +34,16 @@ export function requestInference(
       signal: controller?.signal,
     })
       .then(async (response) => {
+        if (!response.ok) {
+          const data = await response.json()
+          const error = {
+            message: data.error?.message ?? 'Error occurred.',
+            code: data.error?.code ?? ErrorCode.Unknown,
+          }
+          subscriber.error(error)
+          subscriber.complete()
+          return
+        }
         if (model.parameters.stream === false) {
           const data = await response.json()
           subscriber.next(data.choices[0]?.message?.content ?? '')
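For consumers, a failed HTTP response now arrives through the observable's error channel as a structured { message, code } object instead of a thrown exception. A minimal subscriber sketch, assuming the same arguments requestInference already takes (remaining parameters elided), mirroring the OAIEngine handler above:

    requestInference(recentMessages, model).subscribe({
      next: (content) => {
        /* append streamed or complete content to the message */
      },
      error: (err) => {
        // err.message and err.code come from the server's error payload,
        // defaulting to 'Error occurred.' and ErrorCode.Unknown
        message.status = MessageStatus.Error
        message.error_code = err.code
        events.emit(MessageEvent.OnMessageUpdate, message)
      },
    })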
@@ -27,3 +27,9 @@ export * from './extension'
  * @module
  */
 export * from './extensions'
+
+/**
+ * Export all base tools.
+ * @module
+ */
+export * from './tools'
core/src/browser/tools/index.ts (new file, 2 lines)
@@ -0,0 +1,2 @@
+export * from './manager'
+export * from './tool'
core/src/browser/tools/manager.ts (new file, 47 lines)
@@ -0,0 +1,47 @@
+import { AssistantTool, MessageRequest } from '../../types'
+import { InferenceTool } from './tool'
+
+/**
+ * Manages the registration and retrieval of inference tools.
+ */
+export class ToolManager {
+  public tools = new Map<string, InferenceTool>()
+
+  /**
+   * Registers a tool.
+   * @param tool - The tool to register.
+   */
+  register<T extends InferenceTool>(tool: T) {
+    this.tools.set(tool.name, tool)
+  }
+
+  /**
+   * Retrieves a tool by its name.
+   * @param name - The name of the tool to retrieve.
+   * @returns The tool, if found.
+   */
+  get<T extends InferenceTool>(name: string): T | undefined {
+    return this.tools.get(name) as T | undefined
+  }
+
+  /*
+  ** Process the message request with the tools.
+  */
+  process(request: MessageRequest, tools: AssistantTool[]): Promise<MessageRequest> {
+    return tools.reduce((prevPromise, currentTool) => {
+      return prevPromise.then((prevResult) => {
+        return currentTool.enabled
+          ? this.get(currentTool.type)?.process(prevResult, currentTool) ??
+              Promise.resolve(prevResult)
+          : Promise.resolve(prevResult)
+      })
+    }, Promise.resolve(request))
+  }
+
+  /**
+   * The instance of the tool manager.
+   */
+  static instance(): ToolManager {
+    return (window.core?.toolManager as ToolManager) ?? new ToolManager()
+  }
+}
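As a usage sketch: process() folds the enabled tools over the request left to right, so each tool sees the previous tool's output, while a disabled or unregistered tool type passes the request through unchanged. Lookup is keyed by the tool's name matching the AssistantTool type field (the retrieval example below mirrors the extension code later in this diff):

    // Once, e.g. in an extension's onLoad():
    ToolManager.instance().register(new RetrievalTool())

    // Per message, before handing the request to an engine:
    const processed = await ToolManager.instance().process(
      request, // MessageRequest
      assistantTools // AssistantTool[], e.g. [{ type: 'retrieval', enabled: true, ... }]
    )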
core/src/browser/tools/tool.ts (new file, 12 lines)
@@ -0,0 +1,12 @@
+import { AssistantTool, MessageRequest } from '../../types'
+
+/**
+ * Represents a base inference tool.
+ */
+export abstract class InferenceTool {
+  abstract name: string
+  /*
+  ** Process a message request and return the processed message request.
+  */
+  abstract process(request: MessageRequest, tool?: AssistantTool): Promise<MessageRequest>
+}
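A minimal concrete tool, to illustrate the contract (this uppercase tool is hypothetical, not part of the commit):

    import { AssistantTool, InferenceTool, MessageRequest } from '@janhq/core'

    export class UppercaseTool extends InferenceTool {
      name = 'uppercase'

      // Upper-cases plain-string message content; anything else passes through.
      async process(request: MessageRequest, tool?: AssistantTool): Promise<MessageRequest> {
        request.messages = request.messages?.map((m) =>
          typeof m.content === 'string' ? { ...m, content: m.content.toUpperCase() } : m
        )
        return request
      }
    }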
@@ -7,7 +7,6 @@ export type ModelInfo = {
   settings: ModelSettingParams
   parameters: ModelRuntimeParams
   engine?: InferenceEngine
-  proxy_model?: InferenceEngine
 }
 
 /**
@@ -21,8 +20,6 @@ export enum InferenceEngine {
   groq = 'groq',
   triton_trtllm = 'triton_trtllm',
   nitro_tensorrt_llm = 'nitro-tensorrt-llm',
-
-  tool_retrieval_enabled = 'tool_retrieval_enabled',
 }
 
 export type ModelArtifact = {
@@ -94,8 +91,6 @@ export type Model = {
    * The model engine.
    */
   engine: InferenceEngine
-
-  proxy_model?: InferenceEngine
 }
 
 export type ModelMetadata = {
@@ -1,26 +1,21 @@
 import {
   fs,
   Assistant,
-  MessageRequest,
   events,
-  InferenceEngine,
-  MessageEvent,
-  InferenceEvent,
   joinPath,
-  executeOnMain,
   AssistantExtension,
   AssistantEvent,
+  ToolManager,
 } from '@janhq/core'
+import { RetrievalTool } from './tools/retrieval'
 
 export default class JanAssistantExtension extends AssistantExtension {
   private static readonly _homeDir = 'file://assistants'
-  private static readonly _threadDir = 'file://threads'
-
-  controller = new AbortController()
-  isCancelled = false
-  retrievalThreadId: string | undefined = undefined
 
   async onLoad() {
+    // Register the retrieval tool
+    ToolManager.instance().register(new RetrievalTool())
+
     // making the assistant directory
     const assistantDirExist = await fs.existsSync(
       JanAssistantExtension._homeDir
@@ -38,140 +33,6 @@ export default class JanAssistantExtension extends AssistantExtension {
     // Update the assistant list
     events.emit(AssistantEvent.OnAssistantsUpdate, {})
   }
-
-    // Events subscription
-    events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
-      JanAssistantExtension.handleMessageRequest(data, this)
-    )
-
-    events.on(InferenceEvent.OnInferenceStopped, () => {
-      JanAssistantExtension.handleInferenceStopped(this)
-    })
-  }
-
-  private static async handleInferenceStopped(instance: JanAssistantExtension) {
-    instance.isCancelled = true
-    instance.controller?.abort()
-  }
-
-  private static async handleMessageRequest(
-    data: MessageRequest,
-    instance: JanAssistantExtension
-  ) {
-    instance.isCancelled = false
-    instance.controller = new AbortController()
-
-    if (
-      data.model?.engine !== InferenceEngine.tool_retrieval_enabled ||
-      !data.messages ||
-      // TODO: Since the engine is defined, its unsafe to assume that assistant tools are defined
-      // That could lead to an issue where thread stuck at generating response
-      !data.thread?.assistants[0]?.tools
-    ) {
-      return
-    }
-
-    const latestMessage = data.messages[data.messages.length - 1]
-
-    // 1. Ingest the document if needed
-    if (
-      latestMessage &&
-      latestMessage.content &&
-      typeof latestMessage.content !== 'string' &&
-      latestMessage.content.length > 1
-    ) {
-      const docFile = latestMessage.content[1]?.doc_url?.url
-      if (docFile) {
-        await executeOnMain(
-          NODE,
-          'toolRetrievalIngestNewDocument',
-          docFile,
-          data.model?.proxy_model
-        )
-      }
-    } else if (
-      // Check whether we need to ingest document or not
-      // Otherwise wrong context will be sent
-      !(await fs.existsSync(
-        await joinPath([
-          JanAssistantExtension._threadDir,
-          data.threadId,
-          'memory',
-        ])
-      ))
-    ) {
-      // No document ingested, reroute the result to inference engine
-      const output = {
-        ...data,
-        model: {
-          ...data.model,
-          engine: data.model.proxy_model,
-        },
-      }
-      events.emit(MessageEvent.OnMessageSent, output)
-      return
-    }
-    // 2. Load agent on thread changed
-    if (instance.retrievalThreadId !== data.threadId) {
-      await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
-
-      instance.retrievalThreadId = data.threadId
-
-      // Update the text splitter
-      await executeOnMain(
-        NODE,
-        'toolRetrievalUpdateTextSplitter',
-        data.thread.assistants[0].tools[0]?.settings?.chunk_size ?? 4000,
-        data.thread.assistants[0].tools[0]?.settings?.chunk_overlap ?? 200
-      )
-    }
-
-    // 3. Using the retrieval template with the result and query
-    if (latestMessage.content) {
-      const prompt =
-        typeof latestMessage.content === 'string'
-          ? latestMessage.content
-          : latestMessage.content[0].text
-      // Retrieve the result
-      const retrievalResult = await executeOnMain(
-        NODE,
-        'toolRetrievalQueryResult',
-        prompt
-      )
-      console.debug('toolRetrievalQueryResult', retrievalResult)
-
-      // Update message content
-      if (data.thread?.assistants[0]?.tools && retrievalResult)
-        data.messages[data.messages.length - 1].content =
-          data.thread.assistants[0].tools[0].settings?.retrieval_template
-            ?.replace('{CONTEXT}', retrievalResult)
-            .replace('{QUESTION}', prompt)
-    }
-
-    // Filter out all the messages that are not text
-    data.messages = data.messages.map((message) => {
-      if (
-        message.content &&
-        typeof message.content !== 'string' &&
-        (message.content.length ?? 0) > 0
-      ) {
-        return {
-          ...message,
-          content: [message.content[0]],
-        }
-      }
-      return message
-    })
-
-    // 4. Reroute the result to inference engine
-    const output = {
-      ...data,
-      model: {
-        ...data.model,
-        engine: data.model.proxy_model,
-      },
-    }
-    events.emit(MessageEvent.OnMessageSent, output)
 }
 
 /**
extensions/assistant-extension/src/tools/retrieval.ts (new file, 108 lines)
@@ -0,0 +1,108 @@
+import {
+  AssistantTool,
+  executeOnMain,
+  fs,
+  InferenceTool,
+  joinPath,
+  MessageRequest,
+} from '@janhq/core'
+
+export class RetrievalTool extends InferenceTool {
+  private _threadDir = 'file://threads'
+  private retrievalThreadId: string | undefined = undefined
+
+  name: string = 'retrieval'
+
+  async process(
+    data: MessageRequest,
+    tool?: AssistantTool
+  ): Promise<MessageRequest> {
+    if (!data.model || !data.messages) {
+      return Promise.resolve(data)
+    }
+
+    const latestMessage = data.messages[data.messages.length - 1]
+
+    // 1. Ingest the document if needed
+    if (
+      latestMessage &&
+      latestMessage.content &&
+      typeof latestMessage.content !== 'string' &&
+      latestMessage.content.length > 1
+    ) {
+      const docFile = latestMessage.content[1]?.doc_url?.url
+      if (docFile) {
+        await executeOnMain(
+          NODE,
+          'toolRetrievalIngestNewDocument',
+          docFile,
+          data.model?.engine
+        )
+      }
+    } else if (
+      // Check whether we need to ingest document or not
+      // Otherwise wrong context will be sent
+      !(await fs.existsSync(
+        await joinPath([this._threadDir, data.threadId, 'memory'])
+      ))
+    ) {
+      // No document ingested, reroute the result to inference engine
+
+      return Promise.resolve(data)
+    }
+    // 2. Load agent on thread changed
+    if (this.retrievalThreadId !== data.threadId) {
+      await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
+
+      this.retrievalThreadId = data.threadId
+
+      // Update the text splitter
+      await executeOnMain(
+        NODE,
+        'toolRetrievalUpdateTextSplitter',
+        tool?.settings?.chunk_size ?? 4000,
+        tool?.settings?.chunk_overlap ?? 200
+      )
+    }
+
+    // 3. Using the retrieval template with the result and query
+    if (latestMessage.content) {
+      const prompt =
+        typeof latestMessage.content === 'string'
+          ? latestMessage.content
+          : latestMessage.content[0].text
+      // Retrieve the result
+      const retrievalResult = await executeOnMain(
+        NODE,
+        'toolRetrievalQueryResult',
+        prompt
+      )
+      console.debug('toolRetrievalQueryResult', retrievalResult)
+
+      // Update message content
+      if (retrievalResult)
+        data.messages[data.messages.length - 1].content =
+          tool?.settings?.retrieval_template
+            ?.replace('{CONTEXT}', retrievalResult)
+            .replace('{QUESTION}', prompt)
+    }
+
+    // Filter out all the messages that are not text
+    data.messages = data.messages.map((message) => {
+      if (
+        message.content &&
+        typeof message.content !== 'string' &&
+        (message.content.length ?? 0) > 0
+      ) {
+        return {
+          ...message,
+          content: [message.content[0]],
+        }
+      }
+      return message
+    })
+
+    // 4. Reroute the result to inference engine
+    return Promise.resolve(data)
+  }
+}
@@ -230,7 +230,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
 
       // 2. Update the title with the result of the inference
       setTimeout(() => {
-        const engine = EngineManager.instance()?.get(
+        const engine = EngineManager.instance().get(
          messageRequest.model?.engine ?? activeModelRef.current?.engine ?? ''
        )
        engine?.inference(messageRequest)
@@ -78,7 +78,7 @@ export function useActiveModel() {
     }
 
     localStorage.setItem(LAST_USED_MODEL_ID, model.id)
-    const engine = EngineManager.instance()?.get(model.engine)
+    const engine = EngineManager.instance().get(model.engine)
     return engine
       ?.loadModel(model)
       .then(() => {
@@ -95,7 +95,6 @@ export function useActiveModel() {
         })
       })
       .catch((error) => {
-        console.error('Failed to load model: ', error)
         setStateModel(() => ({
           state: 'start',
           loading: false,
@@ -108,13 +107,14 @@ export function useActiveModel() {
           type: 'success',
         })
         setLoadModelError(error)
+        return Promise.reject(error)
       })
   }
 
   const stopModel = useCallback(async () => {
     if (activeModel) {
       setStateModel({ state: 'stop', loading: true, model: activeModel.id })
-      const engine = EngineManager.instance()?.get(activeModel.engine)
+      const engine = EngineManager.instance().get(activeModel.engine)
       await engine
         ?.unloadModel(activeModel)
         .catch()
@@ -9,9 +9,8 @@ import {
   ThreadMessage,
   Model,
   ConversationalExtension,
-  InferenceEngine,
-  AssistantTool,
   EngineManager,
+  ToolManager,
 } from '@janhq/core'
 import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
 
@@ -111,7 +110,10 @@ export default function useSendChatMessage() {
       activeThreadRef.current.assistants[0].model.id
 
     if (modelRef.current?.id !== modelId) {
-      await startModel(modelId)
+      const error = await startModel(modelId).catch((error: Error) => error)
+      if (error) {
+        return
+      }
     }
 
     setIsGeneratingResponse(true)
@@ -128,10 +130,18 @@ export default function useSendChatMessage() {
         )
       }
     }
-    const engine = EngineManager.instance()?.get(
-      requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? ''
+    // Process message request with Assistants tools
+    const request = await ToolManager.instance().process(
+      requestBuilder.build(),
+      activeThreadRef.current.assistants?.flatMap(
+        (assistant) => assistant.tools ?? []
+      ) ?? []
     )
-    engine?.inference(requestBuilder.build())
+
+    const engine =
+      requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? ''
+
+    EngineManager.instance().get(engine)?.inference(request)
   }
 
   // Define interface extending Array prototype
@@ -149,8 +159,9 @@ export default function useSendChatMessage() {
     const runtimeParams = toRuntimeParams(activeModelParams)
     const settingParams = toSettingParams(activeModelParams)
 
-    updateThreadWaiting(activeThreadRef.current.id, true)
     const prompt = message.trim()
+
+    updateThreadWaiting(activeThreadRef.current.id, true)
     setCurrentPrompt('')
     setEditPrompt('')
 
@@ -158,17 +169,12 @@ export default function useSendChatMessage() {
       ? await getBase64(fileUpload[0].file)
       : undefined
 
-    const fileContentType = fileUpload[0]?.type
-
-    const isDocumentInput = base64Blob && fileContentType === 'pdf'
-    const isImageInput = base64Blob && fileContentType === 'image'
-
-    if (isImageInput && base64Blob) {
+    if (base64Blob && fileUpload[0]?.type === 'image') {
       // Compress image
       base64Blob = await compressImage(base64Blob, 512)
     }
 
-    let modelRequest =
+    const modelRequest =
       selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model
 
     // Fallback support for previous broken threads
@@ -182,23 +188,6 @@ export default function useSendChatMessage() {
     if (runtimeParams.stream == null) {
       runtimeParams.stream = true
     }
-    // Add middleware to the model request with tool retrieval enabled
-    if (
-      activeThreadRef.current.assistants[0].tools?.some(
-        (tool: AssistantTool) => tool.type === 'retrieval' && tool.enabled
-      )
-    ) {
-      modelRequest = {
-        ...modelRequest,
-        // Tool retrieval support document input only for now
-        ...(isDocumentInput
-          ? {
-              engine: InferenceEngine.tool_retrieval_enabled,
-              proxy_model: modelRequest.engine,
-            }
-          : {}),
-      }
-    }
 
     // Build Message Request
     const requestBuilder = new MessageRequestBuilder(
@@ -247,15 +236,27 @@ export default function useSendChatMessage() {
 
     if (modelRef.current?.id !== modelId) {
       setQueuedMessage(true)
-      await startModel(modelId)
+      const error = await startModel(modelId).catch((error: Error) => error)
       setQueuedMessage(false)
+      if (error) {
+        updateThreadWaiting(activeThreadRef.current.id, false)
+        return
+      }
     }
     setIsGeneratingResponse(true)
 
-    const engine = EngineManager.instance()?.get(
-      requestBuilder.model?.engine ?? modelRequest.engine ?? ''
+    // Process message request with Assistants tools
+    const request = await ToolManager.instance().process(
+      requestBuilder.build(),
+      activeThreadRef.current.assistants?.flatMap(
+        (assistant) => assistant.tools ?? []
+      ) ?? []
     )
-    engine?.inference(requestBuilder.build())
+
+    // Request for inference
+    EngineManager.instance()
+      .get(requestBuilder.model?.engine ?? modelRequest.engine ?? '')
+      ?.inference(request)
 
     // Reset states
     setReloadModel(false)
@@ -1,4 +1,4 @@
-import { EngineManager } from '@janhq/core'
+import { EngineManager, ToolManager } from '@janhq/core'
 
 import { appService } from './appService'
 import { EventEmitter } from './eventsService'
@@ -15,6 +15,7 @@ export const setupCoreServices = () => {
   window.core = {
     events: new EventEmitter(),
     engineManager: new EngineManager(),
+    toolManager: new ToolManager(),
     api: {
       ...(window.electronAPI ? window.electronAPI : restAPI),
       ...appService,
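Since ToolManager.instance() resolves through window.core (mirroring EngineManager), every renderer-side caller now shares the manager created here; a quick sanity sketch, assuming a browser context after setup:

    setupCoreServices()
    // Both references point at the same tool registry:
    console.assert(ToolManager.instance() === window.core.toolManager)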