Merge pull request #2470 from janhq/chore/load-unload-model-sync

This commit is contained in:
Louis 2024-03-26 00:31:47 +07:00 committed by GitHub
commit 66f7d3dae3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 741 additions and 573 deletions

View File

@ -1,4 +1,4 @@
-import { DownloadRequest, FileStat, NetworkConfig, SystemInformation } from './types'
+import { DownloadRequest, FileStat, NetworkConfig, SystemInformation } from '../types'
/**
* Execute a extension module function in main process

View File

@ -1,4 +1,4 @@
-import { Assistant, AssistantInterface } from '../index'
+import { Assistant, AssistantInterface } from '../../types'
import { BaseExtension, ExtensionTypeEnum } from '../extension'
/**

View File

@ -1,4 +1,4 @@
-import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../index'
+import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../../types'
import { BaseExtension, ExtensionTypeEnum } from '../extension'
/**

View File

@ -2,7 +2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { fs } from '../../fs'
-import { Model, ModelEvent } from '../../types'
+import { MessageRequest, Model, ModelEvent } from '../../../types'
+import { EngineManager } from './EngineManager'
/**
* Base AIEngine
@ -11,30 +12,71 @@ import { Model, ModelEvent } from '../../types'
export abstract class AIEngine extends BaseExtension {
// The inference engine
abstract provider: string
// The model folder
modelFolder: string = 'models'
/**
* On extension load, subscribe to events.
*/
override onLoad() {
this.registerEngine()
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
this.prePopulateModels()
}
/**
* Defines models
*/
models(): Promise<Model[]> {
return Promise.resolve([])
}
/**
-* On extension load, subscribe to events.
+* Registers AI Engines
*/
-onLoad() {
-this.prePopulateModels()
+registerEngine() {
+EngineManager.instance().register(this)
}
/**
* Loads the model.
*/
async loadModel(model: Model): Promise<any> {
if (model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
/**
* Stops the model.
*/
async unloadModel(model?: Model): Promise<any> {
if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelStopped, model ?? {})
return Promise.resolve()
}
/*
* Inference request
*/
inference(data: MessageRequest) {}
/**
* Stop inference
*/
stopInference() {}
/**
* Pre-populate models to App Data Folder
*/
prePopulateModels(): Promise<void> {
+const modelFolder = 'models'
return this.models().then((models) => {
const prePoluateOperations = models.map((model) =>
getJanDataFolderPath()
.then((janDataFolder) =>
// Attempt to create the model folder
-joinPath([janDataFolder, this.modelFolder, model.id]).then((path) =>
+joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
fs
.mkdir(path)
.catch()

View File

@ -0,0 +1,32 @@
import { AIEngine } from './AIEngine'
/**
* Manages the registration and retrieval of inference engines.
*/
export class EngineManager {
public engines = new Map<string, AIEngine>()
/**
* Registers an engine.
* @param engine - The engine to register.
*/
register<T extends AIEngine>(engine: T) {
this.engines.set(engine.provider, engine)
}
/**
* Retrieves a engine by provider.
* @param provider - The name of the engine to retrieve.
* @returns The engine, if found.
*/
get<T extends AIEngine>(provider: string): T | undefined {
return this.engines.get(provider) as T | undefined
}
/**
* The instance of the engine manager.
*/
static instance(): EngineManager {
return window.core?.engineManager as EngineManager ?? new EngineManager()
}
}
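A rough usage sketch of the new manager (engines register themselves through AIEngine.registerEngine() on load; the fallback provider string here is illustrative):

import { EngineManager, MessageRequest } from '@janhq/core'

// Route a request to whichever engine registered the request's provider id.
function routeToEngine(request: MessageRequest) {
  const engine = EngineManager.instance().get(request.model?.engine?.toString() ?? '')
  engine?.inference(request)
}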

View File

@ -1,6 +1,6 @@
import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core'
import { events } from '../../events'
-import { Model, ModelEvent } from '../../types'
+import { Model, ModelEvent } from '../../../types'
import { OAIEngine } from './OAIEngine'
/**
@ -16,7 +16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* On extension load, subscribe to events.
*/
-onLoad() {
+override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
@ -26,10 +26,10 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* Load the model.
*/
-async loadModel(model: Model) {
+override async loadModel(model: Model): Promise<void> {
if (model.engine.toString() !== this.provider) return
+const modelFolderName = 'models'
-const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
+const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,
@ -42,24 +42,22 @@ export abstract class LocalOAIEngine extends OAIEngine {
)
if (res?.error) {
-events.emit(ModelEvent.OnModelFail, {
-...model,
-error: res.error,
-})
-return
+events.emit(ModelEvent.OnModelFail, { error: res.error })
+return Promise.reject(res.error)
} else {
this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model)
+return Promise.resolve()
}
}
/**
* Stops the model.
*/
-unloadModel(model: Model) {
-if (model.engine && model.engine?.toString() !== this.provider) return
-this.loadedModel = undefined
-executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
+override async unloadModel(model?: Model): Promise<void> {
+if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()
+this.loadedModel = undefined
+return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {})
})
}

View File

@ -13,7 +13,7 @@ import {
ModelInfo,
ThreadContent,
ThreadMessage,
-} from '../../types'
+} from '../../../types'
import { events } from '../../events'
/**
@ -34,7 +34,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension load, subscribe to events.
*/
-onLoad() {
+override onLoad() {
super.onLoad()
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
@ -43,12 +43,12 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension unload
*/
-onUnload(): void {}
+override onUnload(): void {}
/*
* Inference request
*/
-inference(data: MessageRequest) {
+override inference(data: MessageRequest) {
if (data.model?.engine?.toString() !== this.provider) return
const timestamp = Date.now()
@ -106,6 +106,7 @@ export abstract class OAIEngine extends AIEngine {
return
}
message.status = MessageStatus.Error
+message.error_code = err.code
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
@ -114,7 +115,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* Stops the inference.
*/
-stopInference() {
+override stopInference() {
this.isCancelled = true
this.controller?.abort()
}

View File

@ -0,0 +1,26 @@
import { OAIEngine } from './OAIEngine'
/**
* Base OAI Remote Inference Provider
* Added the implementation of loading and unloading model (applicable to local inference providers)
*/
export abstract class RemoteOAIEngine extends OAIEngine {
// The inference engine
abstract apiKey: string
/**
* On extension load, subscribe to events.
*/
override onLoad() {
super.onLoad()
}
/**
* Headers for the inference request
*/
override headers(): HeadersInit {
return {
'Authorization': `Bearer ${this.apiKey}`,
'api-key': `${this.apiKey}`,
}
}
}
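A hypothetical concrete subclass, for illustration only: this diff shows provider and apiKey, and a real engine may also need to implement other abstract members inherited from OAIEngine (such as the inference endpoint) that are not part of this changeset.

import { RemoteOAIEngine } from '@janhq/core'

// Illustrative remote provider; the provider id and API key are placeholders.
export class ExampleRemoteEngine extends RemoteOAIEngine {
  provider = 'example-remote'
  apiKey = 'sk-placeholder'
}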

View File

@ -1,5 +1,5 @@
import { Observable } from 'rxjs'
-import { ModelRuntimeParams } from '../../../types'
+import { ErrorCode, ModelRuntimeParams } from '../../../../types'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
@ -34,6 +34,16 @@ export function requestInference(
signal: controller?.signal,
})
.then(async (response) => {
if (!response.ok) {
const data = await response.json()
const error = {
message: data.error?.message ?? 'Error occurred.',
code: data.error?.code ?? ErrorCode.Unknown,
}
subscriber.error(error)
subscriber.complete()
return
}
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')

View File

@ -2,3 +2,4 @@ export * from './AIEngine'
export * from './OAIEngine'
export * from './LocalOAIEngine'
export * from './RemoteOAIEngine'
+export * from './EngineManager'

View File

@ -1,6 +1,6 @@
import { BaseExtension, ExtensionTypeEnum } from '../extension'
-import { HuggingFaceInterface, HuggingFaceRepoData, Quantization } from '../types/huggingface'
-import { Model } from '../types/model'
+import { HuggingFaceInterface, HuggingFaceRepoData, Quantization } from '../../types/huggingface'
+import { Model } from '../../types/model'
/**
* Hugging Face extension for converting HF models to GGUF.

View File

@ -32,4 +32,4 @@ export { HuggingFaceExtension } from './huggingface'
/**
* Base AI Engines.
*/
-export * from './ai-engines'
+export * from './engines'

View File

@ -1,4 +1,4 @@
-import { InferenceInterface, MessageRequest, ThreadMessage } from '../index'
+import { InferenceInterface, MessageRequest, ThreadMessage } from '../../types'
import { BaseExtension, ExtensionTypeEnum } from '../extension'
/**

View File

@ -1,5 +1,5 @@
import { BaseExtension, ExtensionTypeEnum } from '../extension'
-import { GpuSetting, ImportingModel, Model, ModelInterface, OptionType } from '../index'
+import { GpuSetting, ImportingModel, Model, ModelInterface, OptionType } from '../../types'
/**
* Model extension for managing models.

View File

@ -1,5 +1,5 @@
import { BaseExtension, ExtensionTypeEnum } from '../extension'
-import { GpuSetting, MonitoringInterface, OperatingSystemInfo } from '../index'
+import { GpuSetting, MonitoringInterface, OperatingSystemInfo } from '../../types'
/**
* Monitoring extension for system monitoring.

View File

@ -1,4 +1,4 @@
-import { FileStat } from './types'
+import { FileStat } from '../types'
/**
* Writes data to a file at the specified path.

core/src/browser/index.ts (new file, 35 lines)
View File

@ -0,0 +1,35 @@
/**
* Export Core module
* @module
*/
export * from './core'
/**
* Export Event module.
* @module
*/
export * from './events'
/**
* Export Filesystem module.
* @module
*/
export * from './fs'
/**
* Export Extension module.
* @module
*/
export * from './extension'
/**
* Export all base extensions.
* @module
*/
export * from './extensions'
/**
* Export all base tools.
* @module
*/
export * from './tools'

View File

@ -0,0 +1,2 @@
export * from './manager'
export * from './tool'

View File

@ -0,0 +1,47 @@
import { AssistantTool, MessageRequest } from '../../types'
import { InferenceTool } from './tool'
/**
* Manages the registration and retrieval of inference tools.
*/
export class ToolManager {
public tools = new Map<string, InferenceTool>()
/**
* Registers a tool.
* @param tool - The tool to register.
*/
register<T extends InferenceTool>(tool: T) {
this.tools.set(tool.name, tool)
}
/**
* Retrieves a tool by it's name.
* @param name - The name of the tool to retrieve.
* @returns The tool, if found.
*/
get<T extends InferenceTool>(name: string): T | undefined {
return this.tools.get(name) as T | undefined
}
/*
** Process the message request with the tools.
*/
process(request: MessageRequest, tools: AssistantTool[]): Promise<MessageRequest> {
return tools.reduce((prevPromise, currentTool) => {
return prevPromise.then((prevResult) => {
return currentTool.enabled
? this.get(currentTool.type)?.process(prevResult, currentTool) ??
Promise.resolve(prevResult)
: Promise.resolve(prevResult)
})
}, Promise.resolve(request))
}
/**
* The instance of the tool manager.
*/
static instance(): ToolManager {
return (window.core?.toolManager as ToolManager) ?? new ToolManager()
}
}
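A rough usage sketch, mirroring how useSendChatMessage calls the manager later in this PR: each enabled tool gets a chance to rewrite the request before inference.

import { AssistantTool, MessageRequest, ToolManager } from '@janhq/core'

// Run all enabled assistant tools over a request, in order, and return the result.
async function applyTools(request: MessageRequest, tools: AssistantTool[]): Promise<MessageRequest> {
  return ToolManager.instance().process(request, tools)
}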

View File

@ -0,0 +1,12 @@
import { AssistantTool, MessageRequest } from '../../types'
/**
* Represents a base inference tool.
*/
export abstract class InferenceTool {
abstract name: string
/*
** Process a message request and return the processed message request.
*/
abstract process(request: MessageRequest, tool?: AssistantTool): Promise<MessageRequest>
}
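A minimal sketch of a concrete tool (hypothetical name; the real implementation added by this PR is the RetrievalTool below):

import { AssistantTool, InferenceTool, MessageRequest } from '@janhq/core'

// Pass-through tool: returns the request unchanged. A real tool rewrites
// request.messages before handing it back to ToolManager.process().
export class PassthroughTool extends InferenceTool {
  name = 'passthrough'
  async process(request: MessageRequest, _tool?: AssistantTool): Promise<MessageRequest> {
    return request
  }
}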

View File

@ -1,46 +0,0 @@
import { events } from '../../events'
import { Model, ModelEvent } from '../../types'
import { OAIEngine } from './OAIEngine'
/**
* Base OAI Remote Inference Provider
* Added the implementation of loading and unloading model (applicable to local inference providers)
*/
export abstract class RemoteOAIEngine extends OAIEngine {
// The inference engine
abstract apiKey: string
/**
* On extension load, subscribe to events.
*/
onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}
/**
* Load the model.
*/
async loadModel(model: Model) {
if (model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelReady, model)
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelStopped, {})
}
/**
* Headers for the inference request
*/
override headers(): HeadersInit {
return {
'Authorization': `Bearer ${this.apiKey}`,
'api-key': `${this.apiKey}`,
}
}
}

View File

@ -2,42 +2,13 @@
* Export all types.
* @module
*/
-export * from './types/index'
+export * from './types'
/**
-* Export all routes
-*/
-export * from './api'
-/**
-* Export Core module
+* Export browser module
* @module
*/
-export * from './core'
-/**
-* Export Event module.
-* @module
-*/
-export * from './events'
-/**
-* Export Filesystem module.
-* @module
-*/
-export * from './fs'
-/**
-* Export Extension module.
-* @module
-*/
-export * from './extension'
-/**
-* Export all base extensions.
-* @module
-*/
-export * from './extensions/index'
+export * from './browser'
/**
* Declare global object

View File

@ -4,7 +4,7 @@ import {
ExtensionRoute,
FileManagerRoute,
FileSystemRoute,
-} from '../../../api'
+} from '../../../types/api'
import { Downloader } from '../processors/download'
import { FileSystem } from '../processors/fs'
import { Extension } from '../processors/extension'

View File

@ -1,4 +1,4 @@
-import { CoreRoutes } from '../../../api'
+import { CoreRoutes } from '../../../types/api'
import { RequestAdapter } from './adapter'
export type Handler = (route: string, args: any) => any

View File

@ -1,5 +1,5 @@
import { resolve, sep } from 'path'
-import { DownloadEvent } from '../../../api'
+import { DownloadEvent } from '../../../types/api'
import { normalizeFilePath } from '../../helper/path'
import { getJanDataFolderPath } from '../../helper'
import { DownloadManager } from '../../helper/download'

View File

@ -1,4 +1,4 @@
-import { DownloadRoute } from '../../../../api'
+import { DownloadRoute } from '../../../../types/api'
import { DownloadManager } from '../../../helper/download'
import { HttpServer } from '../../HttpServer'

View File

@ -5,4 +5,4 @@ export * from './extension/store'
export * from './api'
export * from './helper'
export * from './../types'
-export * from './../api'
+export * from '../types/api'

View File

@ -8,3 +8,4 @@ export * from './file'
export * from './config'
export * from './huggingface'
export * from './miscellaneous'
+export * from './api'

View File

@ -7,7 +7,6 @@ export type ModelInfo = {
settings: ModelSettingParams
parameters: ModelRuntimeParams
engine?: InferenceEngine
-proxy_model?: InferenceEngine
}
/**
@ -21,8 +20,6 @@ export enum InferenceEngine {
groq = 'groq',
triton_trtllm = 'triton_trtllm',
nitro_tensorrt_llm = 'nitro-tensorrt-llm',
-tool_retrieval_enabled = 'tool_retrieval_enabled',
}
export type ModelArtifact = {
@ -94,8 +91,6 @@ export type Model = {
* The model engine.
*/
engine: InferenceEngine
-proxy_model?: InferenceEngine
}
export type ModelMetadata = {

View File

@ -1,26 +1,21 @@
import {
fs,
Assistant,
-MessageRequest,
events,
-InferenceEngine,
-MessageEvent,
-InferenceEvent,
joinPath,
-executeOnMain,
AssistantExtension,
AssistantEvent,
+ToolManager,
} from '@janhq/core'
+import { RetrievalTool } from './tools/retrieval'
export default class JanAssistantExtension extends AssistantExtension {
private static readonly _homeDir = 'file://assistants'
-private static readonly _threadDir = 'file://threads'
-controller = new AbortController()
-isCancelled = false
-retrievalThreadId: string | undefined = undefined
async onLoad() {
+// Register the retrieval tool
+ToolManager.instance().register(new RetrievalTool())
// making the assistant directory
const assistantDirExist = await fs.existsSync(
JanAssistantExtension._homeDir
@ -38,140 +33,6 @@ export default class JanAssistantExtension extends AssistantExtension {
// Update the assistant list
events.emit(AssistantEvent.OnAssistantsUpdate, {})
}
// Events subscription
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
JanAssistantExtension.handleMessageRequest(data, this)
)
events.on(InferenceEvent.OnInferenceStopped, () => {
JanAssistantExtension.handleInferenceStopped(this)
})
}
private static async handleInferenceStopped(instance: JanAssistantExtension) {
instance.isCancelled = true
instance.controller?.abort()
}
private static async handleMessageRequest(
data: MessageRequest,
instance: JanAssistantExtension
) {
instance.isCancelled = false
instance.controller = new AbortController()
if (
data.model?.engine !== InferenceEngine.tool_retrieval_enabled ||
!data.messages ||
// TODO: Since the engine is defined, its unsafe to assume that assistant tools are defined
// That could lead to an issue where thread stuck at generating response
!data.thread?.assistants[0]?.tools
) {
return
}
const latestMessage = data.messages[data.messages.length - 1]
// 1. Ingest the document if needed
if (
latestMessage &&
latestMessage.content &&
typeof latestMessage.content !== 'string' &&
latestMessage.content.length > 1
) {
const docFile = latestMessage.content[1]?.doc_url?.url
if (docFile) {
await executeOnMain(
NODE,
'toolRetrievalIngestNewDocument',
docFile,
data.model?.proxy_model
)
}
} else if (
// Check whether we need to ingest document or not
// Otherwise wrong context will be sent
!(await fs.existsSync(
await joinPath([
JanAssistantExtension._threadDir,
data.threadId,
'memory',
])
))
) {
// No document ingested, reroute the result to inference engine
const output = {
...data,
model: {
...data.model,
engine: data.model.proxy_model,
},
}
events.emit(MessageEvent.OnMessageSent, output)
return
}
// 2. Load agent on thread changed
if (instance.retrievalThreadId !== data.threadId) {
await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
instance.retrievalThreadId = data.threadId
// Update the text splitter
await executeOnMain(
NODE,
'toolRetrievalUpdateTextSplitter',
data.thread.assistants[0].tools[0]?.settings?.chunk_size ?? 4000,
data.thread.assistants[0].tools[0]?.settings?.chunk_overlap ?? 200
)
}
// 3. Using the retrieval template with the result and query
if (latestMessage.content) {
const prompt =
typeof latestMessage.content === 'string'
? latestMessage.content
: latestMessage.content[0].text
// Retrieve the result
const retrievalResult = await executeOnMain(
NODE,
'toolRetrievalQueryResult',
prompt
)
console.debug('toolRetrievalQueryResult', retrievalResult)
// Update message content
if (data.thread?.assistants[0]?.tools && retrievalResult)
data.messages[data.messages.length - 1].content =
data.thread.assistants[0].tools[0].settings?.retrieval_template
?.replace('{CONTEXT}', retrievalResult)
.replace('{QUESTION}', prompt)
}
// Filter out all the messages that are not text
data.messages = data.messages.map((message) => {
if (
message.content &&
typeof message.content !== 'string' &&
(message.content.length ?? 0) > 0
) {
return {
...message,
content: [message.content[0]],
}
}
return message
})
// 4. Reroute the result to inference engine
const output = {
...data,
model: {
...data.model,
engine: data.model.proxy_model,
},
}
events.emit(MessageEvent.OnMessageSent, output)
}
/**

View File

@ -0,0 +1,108 @@
import {
AssistantTool,
executeOnMain,
fs,
InferenceTool,
joinPath,
MessageRequest,
} from '@janhq/core'
export class RetrievalTool extends InferenceTool {
private _threadDir = 'file://threads'
private retrievalThreadId: string | undefined = undefined
name: string = 'retrieval'
async process(
data: MessageRequest,
tool?: AssistantTool
): Promise<MessageRequest> {
if (!data.model || !data.messages) {
return Promise.resolve(data)
}
const latestMessage = data.messages[data.messages.length - 1]
// 1. Ingest the document if needed
if (
latestMessage &&
latestMessage.content &&
typeof latestMessage.content !== 'string' &&
latestMessage.content.length > 1
) {
const docFile = latestMessage.content[1]?.doc_url?.url
if (docFile) {
await executeOnMain(
NODE,
'toolRetrievalIngestNewDocument',
docFile,
data.model?.engine
)
}
} else if (
// Check whether we need to ingest document or not
// Otherwise wrong context will be sent
!(await fs.existsSync(
await joinPath([this._threadDir, data.threadId, 'memory'])
))
) {
// No document ingested, reroute the result to inference engine
return Promise.resolve(data)
}
// 2. Load agent on thread changed
if (this.retrievalThreadId !== data.threadId) {
await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
this.retrievalThreadId = data.threadId
// Update the text splitter
await executeOnMain(
NODE,
'toolRetrievalUpdateTextSplitter',
tool?.settings?.chunk_size ?? 4000,
tool?.settings?.chunk_overlap ?? 200
)
}
// 3. Using the retrieval template with the result and query
if (latestMessage.content) {
const prompt =
typeof latestMessage.content === 'string'
? latestMessage.content
: latestMessage.content[0].text
// Retrieve the result
const retrievalResult = await executeOnMain(
NODE,
'toolRetrievalQueryResult',
prompt
)
console.debug('toolRetrievalQueryResult', retrievalResult)
// Update message content
if (retrievalResult)
data.messages[data.messages.length - 1].content =
tool?.settings?.retrieval_template
?.replace('{CONTEXT}', retrievalResult)
.replace('{QUESTION}', prompt)
}
// Filter out all the messages that are not text
data.messages = data.messages.map((message) => {
if (
message.content &&
typeof message.content !== 'string' &&
(message.content.length ?? 0) > 0
) {
return {
...message,
content: [message.content[0]],
}
}
return message
})
// 4. Reroute the result to inference engine
return Promise.resolve(data)
}
}

View File

@ -91,15 +91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
return super.loadModel(model)
}
-override unloadModel(model: Model): void {
-super.unloadModel(model)
-if (model.engine && model.engine !== this.provider) return
+override async unloadModel(model?: Model) {
+if (model?.engine && model.engine !== this.provider) return
// stop the periocally health check
if (this.getNitroProcesHealthIntervalId) {
clearInterval(this.getNitroProcesHealthIntervalId)
this.getNitroProcesHealthIntervalId = undefined
}
+return super.unloadModel(model)
}
}

View File

@ -271,26 +271,7 @@ const DropdownListSidebar = ({
)}
>
<div className="relative flex w-full justify-between">
-{x.engine === InferenceEngine.openai && (
+<div>
<svg
width="20"
height="20"
viewBox="0 0 20 20"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className="absolute top-1"
>
<path
d="M18.5681 8.18423C18.7917 7.51079 18.8691 6.79739 18.795 6.09168C18.7209 5.38596 18.497 4.70419 18.1384 4.0919C17.6067 3.16642 16.7948 2.43369 15.8199 1.99936C14.8449 1.56503 13.7572 1.45153 12.7135 1.67523C12.1206 1.0157 11.3646 0.523789 10.5214 0.248906C9.67823 -0.0259764 8.77756 -0.0741542 7.90986 0.109212C7.04216 0.292577 6.23798 0.701031 5.57809 1.29355C4.91821 1.88607 4.42584 2.64179 4.15046 3.48481C3.45518 3.62739 2.79834 3.91672 2.22384 4.33347C1.64933 4.75023 1.1704 5.28481 0.81904 5.90148C0.281569 6.82542 0.0518576 7.89634 0.163116 8.95943C0.274374 10.0225 0.720837 11.0227 1.43796 11.8153C1.21351 12.4884 1.13539 13.2017 1.20883 13.9074C1.28227 14.6132 1.50557 15.2951 1.86379 15.9076C2.39616 16.8334 3.20872 17.5663 4.18438 18.0006C5.16004 18.4349 6.24841 18.5483 7.29262 18.3243C7.76367 18.8548 8.34248 19.2786 8.99038 19.5676C9.63828 19.8566 10.3404 20.004 11.0498 20C12.1195 20.001 13.1618 19.662 14.0263 19.032C14.8909 18.4021 15.5329 17.5137 15.8596 16.4951C16.5548 16.3523 17.2116 16.0629 17.786 15.6461C18.3605 15.2294 18.8395 14.6949 19.191 14.0784C19.7222 13.1558 19.9479 12.0889 19.836 11.0303C19.7242 9.97163 19.2804 8.9754 18.5681 8.18423ZM11.0498 18.691C10.1737 18.6924 9.32512 18.3853 8.65279 17.8236L8.77104 17.7566L12.753 15.4581C12.8521 15.4 12.9343 15.3171 12.9917 15.2176C13.0491 15.118 13.0796 15.0053 13.0802 14.8904V9.27631L14.7635 10.2501C14.7719 10.2544 14.7791 10.2605 14.7846 10.268C14.7901 10.2755 14.7937 10.2843 14.7952 10.2935V14.9456C14.7931 15.9383 14.3978 16.8898 13.6959 17.5917C12.9939 18.2936 12.0425 18.6889 11.0498 18.691ZM2.99921 15.2531C2.55985 14.4945 2.4021 13.6052 2.55371 12.7417L2.67204 12.8127L6.65787 15.1112C6.7565 15.1691 6.86877 15.1996 6.98312 15.1996C7.09747 15.1996 7.20975 15.1691 7.30837 15.1112L12.1774 12.3041V14.2478C12.1769 14.2579 12.1742 14.2677 12.1694 14.2766C12.1646 14.2855 12.1579 14.2932 12.1497 14.2991L8.11654 16.6251C7.25581 17.121 6.2335 17.255 5.27405 16.9978C4.3146 16.7405 3.49644 16.1131 2.99921 15.2531ZM1.95054 6.57965C2.39294 5.81612 3.09123 5.23375 3.92179 4.93565V9.66665C3.92029 9.78094 3.94949 9.89355 4.00635 9.99271C4.06321 10.0919 4.14564 10.174 4.24504 10.2304L9.09037 13.0256L7.40696 13.9994C7.39785 14.0042 7.38769 14.0068 7.37737 14.0068C7.36706 14.0068 7.3569 14.0042 7.34779 13.9994L3.32254 11.6773C2.46343 11.1793 1.83666 10.3612 1.57951 9.40204C1.32236 8.44291 1.45577 7.42095 1.95054 6.55998V6.57965ZM15.7808 9.79281L10.9197 6.96998L12.5992 5.99998C12.6083 5.99514 12.6185 5.99261 12.6288 5.99261C12.6391 5.99261 12.6493 5.99514 12.6584 5.99998L16.6836 8.32606C17.2991 8.68119 17.8008 9.20407 18.1303 9.83365C18.4597 10.4632 18.6032 11.1735 18.5441 11.8816C18.485 12.5898 18.2257 13.2664 17.7964 13.8327C17.3672 14.3989 16.7857 14.8314 16.1199 15.0796V10.3486C16.1164 10.2345 16.0833 10.1232 16.0238 10.0258C15.9644 9.92833 15.8807 9.8481 15.7808 9.79281ZM17.4564 7.27356L17.338 7.20256L13.3601 4.8844C13.2609 4.82617 13.1479 4.79547 13.0329 4.79547C12.9178 4.79547 12.8049 4.82617 12.7056 4.8844L7.84071 7.6914V5.74781C7.83967 5.73793 7.84132 5.72795 7.84549 5.71893C7.84965 5.70991 7.85618 5.70218 7.86437 5.69656L11.8896 3.3744C12.5066 3.01899 13.2119 2.84659 13.9232 2.87736C14.6345 2.90813 15.3224 3.14079 15.9063 3.54813C16.4903 3.95548 16.9461 4.52066 17.2206 5.17759C17.4952 5.83452 17.577 6.55602 17.4565 7.25773L17.4564 7.27356ZM6.92196 10.7191L5.23862 9.74931C5.2302 9.74424 5.223 9.73738 5.21753 9.72921C5.21205 9.72105 5.20845 9.71178 5.20696 9.70206V5.06181C5.20788 4.34996 5.41144 3.65307 5.79383 3.05265C6.17622 2.45222 
6.72164 1.97305 7.36632 1.67118C8.011 1.3693 8.7283 1.2572 9.43434 1.34796C10.1404 1.43873 10.806 1.72861 11.3534 2.18373L11.235 2.25081L7.25321 4.54915C7.1541 4.60727 7.07182 4.69017 7.01445 4.78971C6.95707 4.88925 6.92658 5.00201 6.92596 5.1169L6.92196 10.7191ZM7.83662 8.74798L10.005 7.49815L12.1774 8.74798V11.2475L10.0129 12.4972L7.84062 11.2475L7.83662 8.74798Z"
fill="#18181B"
/>
</svg>
)}
<div
className={twMerge(
x.engine === InferenceEngine.openai && 'pl-8'
)}
>
<span className="line-clamp-1 block">
{x.name}
</span>
@ -307,8 +288,7 @@ const DropdownListSidebar = ({
</SelectItem>
<div
className={twMerge(
-'absolute -mt-6 inline-flex items-center space-x-2 px-4 pb-2 text-muted-foreground',
-x.engine === InferenceEngine.openai && 'left-8'
+'absolute -mt-6 inline-flex items-center space-x-2 px-4 pb-2 text-muted-foreground'
)}
>
<span className="text-xs">{x.id}</span>

View File

@ -8,26 +8,17 @@ import {
ExtensionTypeEnum,
MessageStatus,
MessageRequest,
-Model,
ConversationalExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
Thread,
-ModelInitFailed,
+EngineManager,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulidx'
-import {
-activeModelAtom,
-loadModelErrorAtom,
-stateModelAtom,
-} from '@/hooks/useActiveModel'
-import { queuedMessageAtom } from '@/hooks/useSendChatMessage'
-import { toaster } from '../Toast'
+import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
import { extensionManager } from '@/extension'
import {
@ -51,8 +42,6 @@ export default function EventHandler({ children }: { children: ReactNode }) {
const activeModel = useAtomValue(activeModelAtom)
const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom)
-const setQueuedMessage = useSetAtom(queuedMessageAtom)
-const setLoadModelError = useSetAtom(loadModelErrorAtom)
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const threads = useAtomValue(threadsAtom)
@ -88,44 +77,11 @@ export default function EventHandler({ children }: { children: ReactNode }) {
[addNewMessage]
)
const onModelReady = useCallback(
(model: Model) => {
setActiveModel(model)
toaster({
title: 'Success!',
description: `Model ${model.id} has been started.`,
type: 'success',
})
setStateModel(() => ({
state: 'stop',
loading: false,
model: model.id,
}))
},
[setActiveModel, setStateModel]
)
const onModelStopped = useCallback(() => {
-setTimeout(() => {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
-}, 500)
}, [setActiveModel, setStateModel])
const onModelInitFailed = useCallback(
(res: ModelInitFailed) => {
console.error('Failed to load model: ', res.error.message)
setStateModel(() => ({
state: 'start',
loading: false,
model: res.id,
}))
setLoadModelError(res.error.message)
setQueuedMessage(false)
},
[setStateModel, setQueuedMessage, setLoadModelError]
)
const updateThreadTitle = useCallback(
(message: ThreadMessage) => {
// Update only when it's finished
@ -274,7 +230,10 @@ export default function EventHandler({ children }: { children: ReactNode }) {
// 2. Update the title with the result of the inference
setTimeout(() => {
-events.emit(MessageEvent.OnMessageSent, messageRequest)
+const engine = EngineManager.instance().get(
+messageRequest.model?.engine ?? activeModelRef.current?.engine ?? ''
+)
+engine?.inference(messageRequest)
}, 1000)
}
}
@ -283,23 +242,16 @@ export default function EventHandler({ children }: { children: ReactNode }) {
if (window.core?.events) {
events.on(MessageEvent.OnMessageResponse, onNewMessageResponse)
events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
-events.on(ModelEvent.OnModelReady, onModelReady)
-events.on(ModelEvent.OnModelFail, onModelInitFailed)
events.on(ModelEvent.OnModelStopped, onModelStopped)
}
-}, [
-onNewMessageResponse,
-onMessageResponseUpdate,
-onModelReady,
-onModelInitFailed,
-onModelStopped,
-])
+}, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
useEffect(() => {
return () => {
events.off(MessageEvent.OnMessageResponse, onNewMessageResponse)
events.off(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
+events.off(ModelEvent.OnModelStopped, onModelStopped)
}
-}, [onNewMessageResponse, onMessageResponseUpdate])
+}, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
return <Fragment>{children}</Fragment>
}

View File

@ -1,6 +1,6 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
-import { BaseExtension, ExtensionTypeEnum } from '@janhq/core'
+import { AIEngine, BaseExtension, ExtensionTypeEnum } from '@janhq/core'
import Extension from './Extension'
@ -8,14 +8,26 @@ import Extension from './Extension'
* Manages the registration and retrieval of extensions.
*/
export class ExtensionManager {
+// Registered extensions
private extensions = new Map<string, BaseExtension>()
+// Registered inference engines
+private engines = new Map<string, AIEngine>()
/**
* Registers an extension.
* @param extension - The extension to register.
*/
register<T extends BaseExtension>(name: string, extension: T) {
this.extensions.set(extension.type() ?? name, extension)
+// Register AI Engines
+if ('provider' in extension && typeof extension.provider === 'string') {
+this.engines.set(
+extension.provider as unknown as string,
+extension as unknown as AIEngine
+)
+}
}
/**
@ -29,6 +41,15 @@ export class ExtensionManager {
return this.extensions.get(type) as T | undefined
}
/**
* Retrieves a extension by its type.
* @param engine - The engine name to retrieve.
* @returns The extension, if found.
*/
getEngine<T extends AIEngine>(engine: string): T | undefined {
return this.engines.get(engine) as T | undefined
}
/**
* Loads all registered extension.
*/

View File

@ -1,6 +1,6 @@
import { useCallback, useEffect, useRef } from 'react'
-import { events, Model, ModelEvent } from '@janhq/core'
+import { EngineManager, Model } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { toaster } from '@/containers/Toast'
@ -38,19 +38,13 @@ export function useActiveModel() {
(stateModel.model === modelId && stateModel.loading)
) {
console.debug(`Model ${modelId} is already initialized. Ignore..`)
-return
+return Promise.resolve()
}
let model = downloadedModelsRef?.current.find((e) => e.id === modelId)
-// Switch between engines
-if (model && activeModel && activeModel.engine !== model.engine) {
-stopModel()
-// TODO: Refactor inference provider would address this
-await new Promise((res) => setTimeout(res, 1000))
-}
-// TODO: incase we have multiple assistants, the configuration will be from assistant
+await stopModel().catch()
setLoadModelError(undefined)
setActiveModel(undefined)
@ -68,7 +62,8 @@ export function useActiveModel() {
loading: false,
model: '',
}))
-return
+return Promise.reject(`Model ${modelId} not found!`)
}
/// Apply thread model settings
@ -83,15 +78,52 @@ export function useActiveModel() {
}
localStorage.setItem(LAST_USED_MODEL_ID, model.id)
-events.emit(ModelEvent.OnModelInit, model)
+const engine = EngineManager.instance().get(model.engine)
return engine
?.loadModel(model)
.then(() => {
setActiveModel(model)
setStateModel(() => ({
state: 'stop',
loading: false,
model: model.id,
}))
toaster({
title: 'Success!',
description: `Model ${model.id} has been started.`,
type: 'success',
})
})
.catch((error) => {
setStateModel(() => ({
state: 'start',
loading: false,
model: model.id,
}))
toaster({
title: 'Failed!',
description: `Model ${model.id} failed to start.`,
type: 'success',
})
setLoadModelError(error)
return Promise.reject(error)
})
}
const stopModel = useCallback(async () => {
if (activeModel) {
setStateModel({ state: 'stop', loading: true, model: activeModel.id })
-events.emit(ModelEvent.OnModelStop, activeModel)
+const engine = EngineManager.instance().get(activeModel.engine)
await engine
?.unloadModel(activeModel)
.catch()
.then(() => {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
})
}
-}, [activeModel, setStateModel])
+}, [activeModel, setActiveModel, setStateModel])
return { activeModel, startModel, stopModel, stateModel }
}

View File

@ -2,27 +2,18 @@
import { useEffect, useRef } from 'react'
import {
-ChatCompletionMessage,
ChatCompletionRole,
-ContentType,
-MessageRequest,
MessageRequestType,
-MessageStatus,
ExtensionTypeEnum,
Thread,
ThreadMessage,
-events,
Model,
ConversationalExtension,
-MessageEvent,
-InferenceEngine,
-ChatCompletionMessageContentType,
-AssistantTool,
+EngineManager,
+ToolManager,
} from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
-import { ulid } from 'ulidx'
import { selectedModelAtom } from '@/containers/DropdownListSidebar'
import {
currentPromptAtom,
@ -31,8 +22,11 @@ import {
} from '@/containers/Providers/Jotai'
import { compressImage, getBase64 } from '@/utils/base64'
+import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
+import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
import { loadModelErrorAtom, useActiveModel } from './useActiveModel'
import { extensionManager } from '@/extension/ExtensionManager'
@ -65,7 +59,6 @@ export default function useSendChatMessage() {
const currentMessages = useAtomValue(getCurrentChatMessagesAtom)
const selectedModel = useAtomValue(selectedModelAtom)
const { activeModel, startModel } = useActiveModel()
-const setQueuedMessage = useSetAtom(queuedMessageAtom)
const loadModelFailed = useAtomValue(loadModelErrorAtom)
const modelRef = useRef<Model | undefined>()
@ -78,6 +71,7 @@ export default function useSendChatMessage() {
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const activeThreadRef = useRef<Thread | undefined>()
+const setQueuedMessage = useSetAtom(queuedMessageAtom)
const selectedModelRef = useRef<Model | undefined>()
@ -103,51 +97,27 @@ export default function useSendChatMessage() {
return
}
updateThreadWaiting(activeThreadRef.current.id, true)
const messages: ChatCompletionMessage[] = [
activeThreadRef.current.assistants[0]?.instructions,
]
.filter((e) => e && e.trim() !== '')
.map<ChatCompletionMessage>((instructions) => {
const systemMessage: ChatCompletionMessage = {
role: ChatCompletionRole.System,
content: instructions,
}
return systemMessage
})
.concat(
currentMessages
.filter(
(e) =>
(currentMessage.role === ChatCompletionRole.User ||
e.id !== currentMessage.id) &&
e.status !== MessageStatus.Error
)
.map<ChatCompletionMessage>((msg) => ({
role: msg.role,
content: msg.content[0]?.text.value ?? '',
}))
)
-const messageRequest: MessageRequest = {
-id: ulid(),
-type: MessageRequestType.Thread,
-messages: messages,
-threadId: activeThreadRef.current.id,
-model:
+const requestBuilder = new MessageRequestBuilder(
+MessageRequestType.Thread,
activeThreadRef.current.assistants[0].model ?? selectedModelRef.current,
-}
+activeThreadRef.current,
+currentMessages
+).addSystemMessage(activeThreadRef.current.assistants[0]?.instructions)
const modelId =
selectedModelRef.current?.id ??
activeThreadRef.current.assistants[0].model.id
if (modelRef.current?.id !== modelId) {
-setQueuedMessage(true)
-startModel(modelId)
-await waitForModelStarting(modelId)
-setQueuedMessage(false)
+const error = await startModel(modelId).catch((error: Error) => error)
+if (error) {
+return
+}
}
setIsGeneratingResponse(true)
if (currentMessage.role !== ChatCompletionRole.User) {
// Delete last response before regenerating
deleteMessage(currentMessage.id ?? '')
@ -160,9 +130,22 @@ export default function useSendChatMessage() {
)
}
}
-events.emit(MessageEvent.OnMessageSent, messageRequest)
+// Process message request with Assistants tools
const request = await ToolManager.instance().process(
requestBuilder.build(),
activeThreadRef.current.assistants?.flatMap(
(assistant) => assistant.tools ?? []
) ?? []
)
const engine =
requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? ''
EngineManager.instance().get(engine)?.inference(request)
}
// Define interface extending Array prototype
const sendChatMessage = async (message: string) => {
if (!message || message.trim().length === 0) return
@ -176,8 +159,9 @@ export default function useSendChatMessage() {
const runtimeParams = toRuntimeParams(activeModelParams)
const settingParams = toSettingParams(activeModelParams)
-updateThreadWaiting(activeThreadRef.current.id, true)
const prompt = message.trim()
+updateThreadWaiting(activeThreadRef.current.id, true)
setCurrentPrompt('')
setEditPrompt('')
@ -185,69 +169,12 @@ export default function useSendChatMessage() {
? await getBase64(fileUpload[0].file)
: undefined
-const fileContentType = fileUpload[0]?.type
-const msgId = ulid()
-const isDocumentInput = base64Blob && fileContentType === 'pdf'
-const isImageInput = base64Blob && fileContentType === 'image'
-if (isImageInput && base64Blob) {
+if (base64Blob && fileUpload[0]?.type === 'image') {
// Compress image
base64Blob = await compressImage(base64Blob, 512)
}
-const messages: ChatCompletionMessage[] = [
+const modelRequest =
activeThreadRef.current.assistants[0]?.instructions,
]
.filter((e) => e && e.trim() !== '')
.map<ChatCompletionMessage>((instructions) => {
const systemMessage: ChatCompletionMessage = {
role: ChatCompletionRole.System,
content: instructions,
}
return systemMessage
})
.concat(
currentMessages
.filter((e) => e.status !== MessageStatus.Error)
.map<ChatCompletionMessage>((msg) => ({
role: msg.role,
content: msg.content[0]?.text.value ?? '',
}))
.concat([
{
role: ChatCompletionRole.User,
content:
selectedModelRef.current && base64Blob
? [
{
type: ChatCompletionMessageContentType.Text,
text: prompt,
},
isDocumentInput
? {
type: ChatCompletionMessageContentType.Doc,
doc_url: {
url: `threads/${activeThreadRef.current.id}/files/${msgId}.pdf`,
},
}
: null,
isImageInput
? {
type: ChatCompletionMessageContentType.Image,
image_url: {
url: base64Blob,
},
}
: null,
].filter((e) => e !== null)
: prompt,
} as ChatCompletionMessage,
])
)
-let modelRequest =
selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model
// Fallback support for previous broken threads
@ -261,131 +188,83 @@ export default function useSendChatMessage() {
if (runtimeParams.stream == null) {
runtimeParams.stream = true
}
-// Add middleware to the model request with tool retrieval enabled
-if (
-activeThreadRef.current.assistants[0].tools?.some(
-(tool: AssistantTool) => tool.type === 'retrieval' && tool.enabled
-)
-) {
-modelRequest = {
-...modelRequest,
-// Tool retrieval support document input only for now
-...(isDocumentInput
-? {
-engine: InferenceEngine.tool_retrieval_enabled,
-proxy_model: modelRequest.engine,
-}
-: {}),
-}
-}
-const messageRequest: MessageRequest = {
-id: msgId,
-type: MessageRequestType.Thread,
-threadId: activeThreadRef.current.id,
-messages,
-model: {
+// Build Message Request
+const requestBuilder = new MessageRequestBuilder(
+MessageRequestType.Thread,
+{
...modelRequest,
settings: settingParams,
parameters: runtimeParams,
},
-thread: activeThreadRef.current,
-}
-const timestamp = Date.now()
-const content: any = []
-if (base64Blob && fileUpload[0]?.type === 'image') {
-content.push({
-type: ContentType.Image,
-text: {
-value: prompt,
-annotations: [base64Blob],
-},
-})
-}
-if (base64Blob && fileUpload[0]?.type === 'pdf') {
-content.push({
-type: ContentType.Pdf,
-text: {
-value: prompt,
-annotations: [base64Blob],
-name: fileUpload[0].file.name,
-size: fileUpload[0].file.size,
-},
-})
-}
-if (prompt && !base64Blob) {
-content.push({
-type: ContentType.Text,
-text: {
-value: prompt,
-annotations: [],
-},
-})
-}
-const threadMessage: ThreadMessage = {
-id: msgId,
-thread_id: activeThreadRef.current.id,
-role: ChatCompletionRole.User,
-status: MessageStatus.Ready,
-created: timestamp,
-updated: timestamp,
-object: 'thread.message',
-content: content,
-}
-addNewMessage(threadMessage)
-if (base64Blob) {
-setFileUpload([])
-}
+activeThreadRef.current,
+currentMessages
+).addSystemMessage(activeThreadRef.current.assistants[0].instructions)
+requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type)
+// Build Thread Message to persist
+const threadMessageBuilder = new ThreadMessageBuilder(
+requestBuilder
+).pushMessage(prompt, base64Blob, fileUpload)
+const newMessage = threadMessageBuilder.build()
+// Push to states
+addNewMessage(newMessage)
+// Update thread state
const updatedThread: Thread = {
...activeThreadRef.current,
-updated: timestamp,
+updated: newMessage.created,
metadata: {
...(activeThreadRef.current.metadata ?? {}),
lastMessage: prompt,
},
}
-// change last update thread when send message
updateThread(updatedThread)
+// Add message
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
-?.addNewMessage(threadMessage)
+?.addNewMessage(newMessage)
+// Start Model if not started
const modelId =
selectedModelRef.current?.id ??
activeThreadRef.current.assistants[0].model.id
if (modelRef.current?.id !== modelId) {
setQueuedMessage(true)
-startModel(modelId)
-await waitForModelStarting(modelId)
+const error = await startModel(modelId).catch((error: Error) => error)
setQueuedMessage(false)
+if (error) {
+updateThreadWaiting(activeThreadRef.current.id, false)
+return
+}
}
setIsGeneratingResponse(true)
-events.emit(MessageEvent.OnMessageSent, messageRequest)
// Process message request with Assistants tools
const request = await ToolManager.instance().process(
requestBuilder.build(),
activeThreadRef.current.assistants?.flatMap(
(assistant) => assistant.tools ?? []
) ?? []
)
// Request for inference
EngineManager.instance()
.get(requestBuilder.model?.engine ?? modelRequest.engine ?? '')
?.inference(request)
// Reset states
setReloadModel(false)
setEngineParamsUpdate(false)
-}
-const waitForModelStarting = async (modelId: string) => {
-return new Promise<void>((resolve) => {
-setTimeout(async () => {
-if (modelRef.current?.id !== modelId && !loadModelFailedRef.current) {
-await waitForModelStarting(modelId)
-resolve()
-} else {
-resolve()
-}
-}, 200)
-})
-}
+if (base64Blob) {
+setFileUpload([])
+}
+}
return {

View File

@ -74,7 +74,8 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
</p>
<ModalTroubleShooting />
</div>
-) : loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
+) : loadModelError &&
+loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
<div
key={message.id}
className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500"

View File

@ -24,8 +24,8 @@ const FileUploadPreview: React.FC = () => {
<div className="relative inline-flex w-60 space-x-3 rounded-lg bg-secondary p-4"> <div className="relative inline-flex w-60 space-x-3 rounded-lg bg-secondary p-4">
<Icon type={fileUpload[0].type} /> <Icon type={fileUpload[0].type} />
<div> <div className="w-full">
<h6 className="line-clamp-1 font-medium"> <h6 className="line-clamp-1 w-3/4 truncate font-medium">
{fileUpload[0].file.name.replaceAll(/[-._]/g, ' ')} {fileUpload[0].file.name.replaceAll(/[-._]/g, ' ')}
</h6> </h6>
<p className="text-muted-foreground"> <p className="text-muted-foreground">

View File

@ -260,8 +260,8 @@ const SimpleTextMessage: React.FC<ThreadMessage> = (props) => {
<Icon type={props.content[0].type} />
-<div>
-<h6 className="line-clamp-1 font-medium">
+<div className="w-full">
+<h6 className="line-clamp-1 w-4/5 font-medium">
{props.content[0].text.name?.replaceAll(/[-._]/g, ' ')}
</h6>
<p className="text-muted-foreground">

View File

@ -1,3 +1,5 @@
+import { EngineManager, ToolManager } from '@janhq/core'
import { appService } from './appService'
import { EventEmitter } from './eventsService'
import { restAPI } from './restService'
@ -12,6 +14,8 @@ export const setupCoreServices = () => {
if (!window.core) {
window.core = {
events: new EventEmitter(),
+engineManager: new EngineManager(),
+toolManager: new ToolManager(),
api: {
...(window.electronAPI ? window.electronAPI : restAPI),
...appService,

View File

@ -0,0 +1,130 @@
import {
ChatCompletionMessage,
ChatCompletionMessageContent,
ChatCompletionMessageContentText,
ChatCompletionMessageContentType,
ChatCompletionRole,
MessageRequest,
MessageRequestType,
MessageStatus,
ModelInfo,
Thread,
ThreadMessage,
} from '@janhq/core'
import { ulid } from 'ulidx'
import { FileType } from '@/containers/Providers/Jotai'
export class MessageRequestBuilder {
msgId: string
type: MessageRequestType
messages: ChatCompletionMessage[]
model: ModelInfo
thread: Thread
constructor(
type: MessageRequestType,
model: ModelInfo,
thread: Thread,
messages: ThreadMessage[]
) {
this.msgId = ulid()
this.type = type
this.model = model
this.thread = thread
this.messages = messages
.filter((e) => e.status !== MessageStatus.Error)
.map<ChatCompletionMessage>((msg) => ({
role: msg.role,
content: msg.content[0]?.text.value ?? '',
}))
}
// Chainable
pushMessage(
message: string,
base64Blob: string | undefined,
fileContentType: FileType
) {
if (base64Blob && fileContentType === 'pdf')
return this.addDocMessage(message)
else if (base64Blob && fileContentType === 'image') {
return this.addImageMessage(message, base64Blob)
}
this.messages = [
...this.messages,
{
role: ChatCompletionRole.User,
content: message,
},
]
return this
}
// Chainable
addSystemMessage(message: string | undefined) {
if (!message || message.trim() === '') return this
this.messages = [
{
role: ChatCompletionRole.System,
content: message,
},
...this.messages,
]
return this
}
// Chainable
addDocMessage(prompt: string) {
const message: ChatCompletionMessage = {
role: ChatCompletionRole.User,
content: [
{
type: ChatCompletionMessageContentType.Text,
text: prompt,
} as ChatCompletionMessageContentText,
{
type: ChatCompletionMessageContentType.Doc,
doc_url: {
url: `threads/${this.thread.id}/files/${this.msgId}.pdf`,
},
},
] as ChatCompletionMessageContent,
}
this.messages = [message, ...this.messages]
return this
}
// Chainable
addImageMessage(prompt: string, base64: string) {
const message: ChatCompletionMessage = {
role: ChatCompletionRole.User,
content: [
{
type: ChatCompletionMessageContentType.Text,
text: prompt,
} as ChatCompletionMessageContentText,
{
type: ChatCompletionMessageContentType.Image,
image_url: {
url: base64,
},
},
] as ChatCompletionMessageContent,
}
this.messages = [message, ...this.messages]
return this
}
build(): MessageRequest {
return {
id: this.msgId,
type: this.type,
threadId: this.thread.id,
messages: this.messages,
model: this.model,
thread: this.thread,
}
}
}
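A rough usage sketch of the builder (argument values come from application state; 'image' is one of the FileType values checked above):

import { MessageRequestType, ModelInfo, Thread, ThreadMessage } from '@janhq/core'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'

// Build a MessageRequest from the active thread, its history, and a new user prompt.
function buildRequest(model: ModelInfo, thread: Thread, history: ThreadMessage[], prompt: string, base64Image: string) {
  return new MessageRequestBuilder(MessageRequestType.Thread, model, thread, history)
    .addSystemMessage(thread.assistants[0]?.instructions)
    .pushMessage(prompt, base64Image, 'image')
    .build()
}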

View File

@ -0,0 +1,74 @@
import {
ChatCompletionRole,
ContentType,
MessageStatus,
ThreadContent,
ThreadMessage,
} from '@janhq/core'
import { FileInfo } from '@/containers/Providers/Jotai'
import { MessageRequestBuilder } from './messageRequestBuilder'
export class ThreadMessageBuilder {
messageRequest: MessageRequestBuilder
content: ThreadContent[] = []
constructor(messageRequest: MessageRequestBuilder) {
this.messageRequest = messageRequest
}
build(): ThreadMessage {
const timestamp = Date.now()
return {
id: this.messageRequest.msgId,
thread_id: this.messageRequest.thread.id,
role: ChatCompletionRole.User,
status: MessageStatus.Ready,
created: timestamp,
updated: timestamp,
object: 'thread.message',
content: this.content,
}
}
pushMessage(
prompt: string,
base64: string | undefined,
fileUpload: FileInfo[]
) {
if (base64 && fileUpload[0]?.type === 'image') {
this.content.push({
type: ContentType.Image,
text: {
value: prompt,
annotations: [base64],
},
})
}
if (base64 && fileUpload[0]?.type === 'pdf') {
this.content.push({
type: ContentType.Pdf,
text: {
value: prompt,
annotations: [base64],
name: fileUpload[0].file.name,
size: fileUpload[0].file.size,
},
})
}
if (prompt && !base64) {
this.content.push({
type: ContentType.Text,
text: {
value: prompt,
annotations: [],
},
})
}
return this
}
}
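A rough usage sketch: the same builder chain that produced the inference request also feeds the ThreadMessage that gets persisted and rendered.

import { ThreadMessage } from '@janhq/core'
import { FileInfo } from '@/containers/Providers/Jotai'
import { MessageRequestBuilder } from './messageRequestBuilder'
import { ThreadMessageBuilder } from './threadMessageBuilder'

// Persist the user turn that was just pushed onto the request builder.
function buildUserMessage(requestBuilder: MessageRequestBuilder, prompt: string, base64Blob: string | undefined, fileUpload: FileInfo[]): ThreadMessage {
  return new ThreadMessageBuilder(requestBuilder).pushMessage(prompt, base64Blob, fileUpload).build()
}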