Merge pull request #2470 from janhq/chore/load-unload-model-sync
This commit is contained in:
commit 66f7d3dae3
@@ -1,4 +1,4 @@
-import { DownloadRequest, FileStat, NetworkConfig, SystemInformation } from './types'
+import { DownloadRequest, FileStat, NetworkConfig, SystemInformation } from '../types'
 
 /**
  * Execute a extension module function in main process
@@ -1,4 +1,4 @@
-import { Assistant, AssistantInterface } from '../index'
+import { Assistant, AssistantInterface } from '../../types'
 import { BaseExtension, ExtensionTypeEnum } from '../extension'
 
 /**
@@ -1,4 +1,4 @@
-import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../index'
+import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../../types'
 import { BaseExtension, ExtensionTypeEnum } from '../extension'
 
 /**
@@ -2,7 +2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core'
 import { events } from '../../events'
 import { BaseExtension } from '../../extension'
 import { fs } from '../../fs'
-import { Model, ModelEvent } from '../../types'
+import { MessageRequest, Model, ModelEvent } from '../../../types'
+import { EngineManager } from './EngineManager'
 
 /**
  * Base AIEngine
@@ -11,30 +12,71 @@ import { Model, ModelEvent } from '../../types'
 export abstract class AIEngine extends BaseExtension {
   // The inference engine
   abstract provider: string
-  // The model folder
-  modelFolder: string = 'models'
 
+  /**
+   * On extension load, subscribe to events.
+   */
+  override onLoad() {
+    this.registerEngine()
+
+    events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
+    events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
+
+    this.prePopulateModels()
+  }
+
   /**
    * Defines models
    */
   models(): Promise<Model[]> {
     return Promise.resolve([])
   }
 
   /**
-   * On extension load, subscribe to events.
+   * Registers AI Engines
    */
-  onLoad() {
-    this.prePopulateModels()
+  registerEngine() {
+    EngineManager.instance().register(this)
   }
 
+  /**
+   * Loads the model.
+   */
+  async loadModel(model: Model): Promise<any> {
+    if (model.engine.toString() !== this.provider) return Promise.resolve()
+    events.emit(ModelEvent.OnModelReady, model)
+    return Promise.resolve()
+  }
+
+  /**
+   * Stops the model.
+   */
+  async unloadModel(model?: Model): Promise<any> {
+    if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
+    events.emit(ModelEvent.OnModelStopped, model ?? {})
+    return Promise.resolve()
+  }
 
   /*
    * Inference request
   */
   inference(data: MessageRequest) {}
 
   /**
    * Stop inference
   */
   stopInference() {}
 
   /**
   * Pre-populate models to App Data Folder
   */
   prePopulateModels(): Promise<void> {
+    const modelFolder = 'models'
     return this.models().then((models) => {
       const prePoluateOperations = models.map((model) =>
         getJanDataFolderPath()
           .then((janDataFolder) =>
             // Attempt to create the model folder
-            joinPath([janDataFolder, this.modelFolder, model.id]).then((path) =>
+            joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
               fs
                 .mkdir(path)
                 .catch()
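For orientation, here is a minimal sketch of an engine built on this base. It is not part of the diff; the provider id is hypothetical, and the inherited `onLoad()` is what wires up registration and the model events:

```ts
import { AIEngine, Model } from '@janhq/core'

// Hypothetical minimal engine. The inherited onLoad() calls registerEngine(),
// subscribes to OnModelInit/OnModelStop, and runs prePopulateModels(),
// which creates <jan-data>/models/<id> for every model returned here.
export default class EchoEngine extends AIEngine {
  provider = 'echo' // hypothetical provider id

  override models(): Promise<Model[]> {
    return Promise.resolve([]) // advertise no bundled models
  }

  override onUnload() {}
}
```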
core/src/browser/extensions/engines/EngineManager.ts (new file, 32 lines)
@@ -0,0 +1,32 @@
+import { AIEngine } from './AIEngine'
+
+/**
+ * Manages the registration and retrieval of inference engines.
+ */
+export class EngineManager {
+  public engines = new Map<string, AIEngine>()
+
+  /**
+   * Registers an engine.
+   * @param engine - The engine to register.
+   */
+  register<T extends AIEngine>(engine: T) {
+    this.engines.set(engine.provider, engine)
+  }
+
+  /**
+   * Retrieves an engine by provider.
+   * @param provider - The name of the engine to retrieve.
+   * @returns The engine, if found.
+   */
+  get<T extends AIEngine>(provider: string): T | undefined {
+    return this.engines.get(provider) as T | undefined
+  }
+
+  /**
+   * The instance of the engine manager.
+   */
+  static instance(): EngineManager {
+    return (window.core?.engineManager as EngineManager) ?? new EngineManager()
+  }
+}
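As a usage sketch (not from the diff; the helper and its error handling are illustrative), resolving an engine by a model's provider string looks like:

```ts
import { EngineManager, Model } from '@janhq/core'

// Hypothetical helper: route a model to whichever engine registered its
// provider. Registration happens in AIEngine.onLoad() via registerEngine().
async function loadOnProvider(model: Model): Promise<void> {
  const engine = EngineManager.instance().get(model.engine)
  if (!engine) throw new Error(`No engine registered for ${model.engine}`)
  await engine.loadModel(model)
}
```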
@@ -1,6 +1,6 @@
 import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core'
 import { events } from '../../events'
-import { Model, ModelEvent } from '../../types'
+import { Model, ModelEvent } from '../../../types'
 import { OAIEngine } from './OAIEngine'
 
 /**
@@ -16,7 +16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
   /**
    * On extension load, subscribe to events.
    */
-  onLoad() {
+  override onLoad() {
     super.onLoad()
     // These events are applicable to local inference providers
     events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
@@ -26,10 +26,10 @@ export abstract class LocalOAIEngine extends OAIEngine {
   /**
    * Load the model.
    */
-  async loadModel(model: Model) {
+  override async loadModel(model: Model): Promise<void> {
     if (model.engine.toString() !== this.provider) return
 
-    const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
+    const modelFolderName = 'models'
+    const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
     const systemInfo = await systemInformation()
     const res = await executeOnMain(
       this.nodeModule,
@@ -42,24 +42,22 @@ export abstract class LocalOAIEngine extends OAIEngine {
     )
 
     if (res?.error) {
-      events.emit(ModelEvent.OnModelFail, {
-        ...model,
-        error: res.error,
-      })
-      return
+      events.emit(ModelEvent.OnModelFail, { error: res.error })
+      return Promise.reject(res.error)
     } else {
       this.loadedModel = model
       events.emit(ModelEvent.OnModelReady, model)
+      return Promise.resolve()
     }
   }
   /**
    * Stops the model.
    */
-  unloadModel(model: Model) {
-    if (model.engine && model.engine?.toString() !== this.provider) return
-    this.loadedModel = undefined
+  override async unloadModel(model?: Model): Promise<void> {
+    if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()
 
-    executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
+    this.loadedModel = undefined
+    return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
       events.emit(ModelEvent.OnModelStopped, {})
     })
   }
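Since `loadModel` and `unloadModel` now return promises that reject on failure rather than only emitting events, callers can await them directly. A minimal sketch under that assumption (the restart flow itself is illustrative):

```ts
import { EngineManager, Model } from '@janhq/core'

// Hypothetical restart flow: failures surface as rejections, so no
// ModelEvent.OnModelFail subscription is needed to observe them.
async function restartModel(model: Model): Promise<void> {
  const engine = EngineManager.instance().get(model.engine)
  if (!engine) return

  await engine.unloadModel(model).catch(() => {}) // ignore "nothing loaded" failures
  await engine.loadModel(model) // rejects with res.error on node-module failure
}
```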
@@ -13,7 +13,7 @@ import {
   ModelInfo,
   ThreadContent,
   ThreadMessage,
-} from '../../types'
+} from '../../../types'
 import { events } from '../../events'
 
 /**
@@ -34,7 +34,7 @@ export abstract class OAIEngine extends AIEngine {
   /**
    * On extension load, subscribe to events.
    */
-  onLoad() {
+  override onLoad() {
     super.onLoad()
     events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
     events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
@@ -43,12 +43,12 @@ export abstract class OAIEngine extends AIEngine {
   /**
    * On extension unload
    */
-  onUnload(): void {}
+  override onUnload(): void {}
 
   /*
    * Inference request
    */
-  inference(data: MessageRequest) {
+  override inference(data: MessageRequest) {
     if (data.model?.engine?.toString() !== this.provider) return
 
     const timestamp = Date.now()
@@ -106,6 +106,7 @@ export abstract class OAIEngine extends AIEngine {
           return
         }
         message.status = MessageStatus.Error
+        message.error_code = err.code
        events.emit(MessageEvent.OnMessageUpdate, message)
       },
     })
@@ -114,7 +115,7 @@ export abstract class OAIEngine extends AIEngine {
   /**
    * Stops the inference.
    */
-  stopInference() {
+  override stopInference() {
     this.isCancelled = true
     this.controller?.abort()
   }
core/src/browser/extensions/engines/RemoteOAIEngine.ts (new file, 26 lines)
@@ -0,0 +1,26 @@
+import { OAIEngine } from './OAIEngine'
+
+/**
+ * Base OAI Remote Inference Provider
+ * Added the implementation of loading and unloading model (applicable to local inference providers)
+ */
+export abstract class RemoteOAIEngine extends OAIEngine {
+  // The inference engine
+  abstract apiKey: string
+  /**
+   * On extension load, subscribe to events.
+   */
+  override onLoad() {
+    super.onLoad()
+  }
+
+  /**
+   * Headers for the inference request
+   */
+  override headers(): HeadersInit {
+    return {
+      'Authorization': `Bearer ${this.apiKey}`,
+      'api-key': `${this.apiKey}`,
+    }
+  }
+}
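A concrete remote provider built on this base only needs to supply its provider id and API key; a minimal sketch, assuming the parent `OAIEngine` exposes an `inferenceUrl` endpoint field (the class name, provider string, and URL are all illustrative):

```ts
import { RemoteOAIEngine } from '@janhq/core'

// Hypothetical OpenAI-compatible remote provider. headers() from the base
// class injects both Authorization and api-key headers from apiKey.
export default class ExampleRemoteEngine extends RemoteOAIEngine {
  provider = 'example-remote' // hypothetical provider id
  apiKey = process.env.EXAMPLE_API_KEY ?? ''
  inferenceUrl = 'https://api.example.com/v1/chat/completions' // assumed base-class field
}
```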
@@ -1,5 +1,5 @@
 import { Observable } from 'rxjs'
-import { ModelRuntimeParams } from '../../../types'
+import { ErrorCode, ModelRuntimeParams } from '../../../../types'
 /**
  * Sends a request to the inference server to generate a response based on the recent messages.
  * @param recentMessages - An array of recent messages to use as context for the inference.
@@ -34,6 +34,16 @@ export function requestInference(
     signal: controller?.signal,
   })
     .then(async (response) => {
+      if (!response.ok) {
+        const data = await response.json()
+        const error = {
+          message: data.error?.message ?? 'Error occurred.',
+          code: data.error?.code ?? ErrorCode.Unknown,
+        }
+        subscriber.error(error)
+        subscriber.complete()
+        return
+      }
       if (model.parameters.stream === false) {
         const data = await response.json()
         subscriber.next(data.choices[0]?.message?.content ?? '')
@@ -2,3 +2,4 @@ export * from './AIEngine'
 export * from './OAIEngine'
 export * from './LocalOAIEngine'
 export * from './RemoteOAIEngine'
+export * from './EngineManager'
@@ -1,6 +1,6 @@
 import { BaseExtension, ExtensionTypeEnum } from '../extension'
-import { HuggingFaceInterface, HuggingFaceRepoData, Quantization } from '../types/huggingface'
-import { Model } from '../types/model'
+import { HuggingFaceInterface, HuggingFaceRepoData, Quantization } from '../../types/huggingface'
+import { Model } from '../../types/model'
 
 /**
  * Hugging Face extension for converting HF models to GGUF.
@@ -32,4 +32,4 @@ export { HuggingFaceExtension } from './huggingface'
 /**
  * Base AI Engines.
  */
-export * from './ai-engines'
+export * from './engines'
@@ -1,4 +1,4 @@
-import { InferenceInterface, MessageRequest, ThreadMessage } from '../index'
+import { InferenceInterface, MessageRequest, ThreadMessage } from '../../types'
 import { BaseExtension, ExtensionTypeEnum } from '../extension'
 
 /**
@@ -1,5 +1,5 @@
 import { BaseExtension, ExtensionTypeEnum } from '../extension'
-import { GpuSetting, ImportingModel, Model, ModelInterface, OptionType } from '../index'
+import { GpuSetting, ImportingModel, Model, ModelInterface, OptionType } from '../../types'
 
 /**
  * Model extension for managing models.
@@ -1,5 +1,5 @@
 import { BaseExtension, ExtensionTypeEnum } from '../extension'
-import { GpuSetting, MonitoringInterface, OperatingSystemInfo } from '../index'
+import { GpuSetting, MonitoringInterface, OperatingSystemInfo } from '../../types'
 
 /**
  * Monitoring extension for system monitoring.
@@ -1,4 +1,4 @@
-import { FileStat } from './types'
+import { FileStat } from '../types'
 
 /**
  * Writes data to a file at the specified path.
core/src/browser/index.ts (new file, 35 lines)
@@ -0,0 +1,35 @@
+/**
+ * Export Core module
+ * @module
+ */
+export * from './core'
+
+/**
+ * Export Event module.
+ * @module
+ */
+export * from './events'
+
+/**
+ * Export Filesystem module.
+ * @module
+ */
+export * from './fs'
+
+/**
+ * Export Extension module.
+ * @module
+ */
+export * from './extension'
+
+/**
+ * Export all base extensions.
+ * @module
+ */
+export * from './extensions'
+
+/**
+ * Export all base tools.
+ * @module
+ */
+export * from './tools'
core/src/browser/tools/index.ts (new file, 2 lines)
@@ -0,0 +1,2 @@
+export * from './manager'
+export * from './tool'
core/src/browser/tools/manager.ts (new file, 47 lines)
@@ -0,0 +1,47 @@
+import { AssistantTool, MessageRequest } from '../../types'
+import { InferenceTool } from './tool'
+
+/**
+ * Manages the registration and retrieval of inference tools.
+ */
+export class ToolManager {
+  public tools = new Map<string, InferenceTool>()
+
+  /**
+   * Registers a tool.
+   * @param tool - The tool to register.
+   */
+  register<T extends InferenceTool>(tool: T) {
+    this.tools.set(tool.name, tool)
+  }
+
+  /**
+   * Retrieves a tool by its name.
+   * @param name - The name of the tool to retrieve.
+   * @returns The tool, if found.
+   */
+  get<T extends InferenceTool>(name: string): T | undefined {
+    return this.tools.get(name) as T | undefined
+  }
+
+  /*
+   ** Process the message request with the tools.
+   */
+  process(request: MessageRequest, tools: AssistantTool[]): Promise<MessageRequest> {
+    return tools.reduce((prevPromise, currentTool) => {
+      return prevPromise.then((prevResult) => {
+        return currentTool.enabled
+          ? this.get(currentTool.type)?.process(prevResult, currentTool) ??
+              Promise.resolve(prevResult)
+          : Promise.resolve(prevResult)
+      })
+    }, Promise.resolve(request))
+  }
+
+  /**
+   * The instance of the tool manager.
+   */
+  static instance(): ToolManager {
+    return (window.core?.toolManager as ToolManager) ?? new ToolManager()
+  }
+}
core/src/browser/tools/tool.ts (new file, 12 lines)
@@ -0,0 +1,12 @@
+import { AssistantTool, MessageRequest } from '../../types'
+
+/**
+ * Represents a base inference tool.
+ */
+export abstract class InferenceTool {
+  abstract name: string
+  /*
+   ** Process a message request and return the processed message request.
+   */
+  abstract process(request: MessageRequest, tool?: AssistantTool): Promise<MessageRequest>
+}
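A sketch of a custom tool on this base, mirroring how the assistant extension registers its `RetrievalTool` below; the tool name and transformation are illustrative, not part of the diff. Note that `ToolManager.process` threads the request through each enabled tool in order, so each tool sees the previous tool's output:

```ts
import {
  AssistantTool,
  InferenceTool,
  MessageRequest,
  ToolManager,
} from '@janhq/core'

// Hypothetical tool that tags the latest user message before inference.
class AnnotateTool extends InferenceTool {
  name = 'annotate'

  async process(request: MessageRequest, _tool?: AssistantTool): Promise<MessageRequest> {
    if (!request.messages?.length) return request
    const messages = [...request.messages]
    const last = messages[messages.length - 1]
    // Only plain-text content is rewritten; array content passes through.
    if (typeof last.content === 'string') {
      messages[messages.length - 1] = { ...last, content: `[annotated] ${last.content}` }
    }
    return { ...request, messages }
  }
}

// ToolManager.process() will invoke it for any enabled AssistantTool whose
// type matches the registered name.
ToolManager.instance().register(new AnnotateTool())
```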
@@ -1,46 +0,0 @@
-import { events } from '../../events'
-import { Model, ModelEvent } from '../../types'
-import { OAIEngine } from './OAIEngine'
-
-/**
- * Base OAI Remote Inference Provider
- * Added the implementation of loading and unloading model (applicable to local inference providers)
- */
-export abstract class RemoteOAIEngine extends OAIEngine {
-  // The inference engine
-  abstract apiKey: string
-  /**
-   * On extension load, subscribe to events.
-   */
-  onLoad() {
-    super.onLoad()
-    // These events are applicable to local inference providers
-    events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
-    events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
-  }
-
-  /**
-   * Load the model.
-   */
-  async loadModel(model: Model) {
-    if (model.engine.toString() !== this.provider) return
-    events.emit(ModelEvent.OnModelReady, model)
-  }
-  /**
-   * Stops the model.
-   */
-  unloadModel(model: Model) {
-    if (model.engine && model.engine.toString() !== this.provider) return
-    events.emit(ModelEvent.OnModelStopped, {})
-  }
-
-  /**
-   * Headers for the inference request
-   */
-  override headers(): HeadersInit {
-    return {
-      'Authorization': `Bearer ${this.apiKey}`,
-      'api-key': `${this.apiKey}`,
-    }
-  }
-}
@@ -2,42 +2,13 @@
  * Export all types.
  * @module
  */
-export * from './types/index'
+export * from './types'
 
 /**
  * Export all routes
 */
 export * from './api'
 
 /**
- * Export Core module
+ * Export browser module
  * @module
  */
-export * from './core'
-
-/**
- * Export Event module.
- * @module
- */
-export * from './events'
-
-/**
- * Export Filesystem module.
- * @module
- */
-export * from './fs'
-
-/**
- * Export Extension module.
- * @module
- */
-export * from './extension'
-
-/**
- * Export all base extensions.
- * @module
- */
-export * from './extensions/index'
+export * from './browser'
 
 /**
  * Declare global object
@@ -4,7 +4,7 @@ import {
   ExtensionRoute,
   FileManagerRoute,
   FileSystemRoute,
-} from '../../../api'
+} from '../../../types/api'
 import { Downloader } from '../processors/download'
 import { FileSystem } from '../processors/fs'
 import { Extension } from '../processors/extension'
@@ -1,4 +1,4 @@
-import { CoreRoutes } from '../../../api'
+import { CoreRoutes } from '../../../types/api'
 import { RequestAdapter } from './adapter'
 
 export type Handler = (route: string, args: any) => any
@@ -1,5 +1,5 @@
 import { resolve, sep } from 'path'
-import { DownloadEvent } from '../../../api'
+import { DownloadEvent } from '../../../types/api'
 import { normalizeFilePath } from '../../helper/path'
 import { getJanDataFolderPath } from '../../helper'
 import { DownloadManager } from '../../helper/download'
@@ -1,4 +1,4 @@
-import { DownloadRoute } from '../../../../api'
+import { DownloadRoute } from '../../../../types/api'
 import { DownloadManager } from '../../../helper/download'
 import { HttpServer } from '../../HttpServer'
@@ -5,4 +5,4 @@ export * from './extension/store'
 export * from './api'
 export * from './helper'
 export * from './../types'
-export * from './../api'
+export * from '../types/api'
@@ -8,3 +8,4 @@ export * from './file'
 export * from './config'
 export * from './huggingface'
 export * from './miscellaneous'
+export * from './api'
@@ -7,7 +7,6 @@ export type ModelInfo = {
   settings: ModelSettingParams
   parameters: ModelRuntimeParams
   engine?: InferenceEngine
-  proxy_model?: InferenceEngine
 }
 
 /**
@@ -21,8 +20,6 @@ export enum InferenceEngine {
   groq = 'groq',
   triton_trtllm = 'triton_trtllm',
   nitro_tensorrt_llm = 'nitro-tensorrt-llm',
-
-  tool_retrieval_enabled = 'tool_retrieval_enabled',
 }
 
 export type ModelArtifact = {
@@ -94,8 +91,6 @@ export type Model = {
   * The model engine.
   */
  engine: InferenceEngine
-
- proxy_model?: InferenceEngine
 }
 
 export type ModelMetadata = {
@@ -1,26 +1,21 @@
 import {
   fs,
   Assistant,
   MessageRequest,
   events,
   InferenceEngine,
   MessageEvent,
   InferenceEvent,
   joinPath,
   executeOnMain,
   AssistantExtension,
   AssistantEvent,
+  ToolManager,
 } from '@janhq/core'
+import { RetrievalTool } from './tools/retrieval'
 
 export default class JanAssistantExtension extends AssistantExtension {
   private static readonly _homeDir = 'file://assistants'
-  private static readonly _threadDir = 'file://threads'
 
-  controller = new AbortController()
-  isCancelled = false
-  retrievalThreadId: string | undefined = undefined
-
   async onLoad() {
+    // Register the retrieval tool
+    ToolManager.instance().register(new RetrievalTool())
+
     // making the assistant directory
     const assistantDirExist = await fs.existsSync(
       JanAssistantExtension._homeDir
@@ -38,140 +33,6 @@ export default class JanAssistantExtension extends AssistantExtension {
       // Update the assistant list
       events.emit(AssistantEvent.OnAssistantsUpdate, {})
     }
-
-    // Events subscription
-    events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
-      JanAssistantExtension.handleMessageRequest(data, this)
-    )
-
-    events.on(InferenceEvent.OnInferenceStopped, () => {
-      JanAssistantExtension.handleInferenceStopped(this)
-    })
-  }
-
-  private static async handleInferenceStopped(instance: JanAssistantExtension) {
-    instance.isCancelled = true
-    instance.controller?.abort()
-  }
-
-  private static async handleMessageRequest(
-    data: MessageRequest,
-    instance: JanAssistantExtension
-  ) {
-    instance.isCancelled = false
-    instance.controller = new AbortController()
-
-    if (
-      data.model?.engine !== InferenceEngine.tool_retrieval_enabled ||
-      !data.messages ||
-      // TODO: Since the engine is defined, its unsafe to assume that assistant tools are defined
-      // That could lead to an issue where thread stuck at generating response
-      !data.thread?.assistants[0]?.tools
-    ) {
-      return
-    }
-
-    const latestMessage = data.messages[data.messages.length - 1]
-
-    // 1. Ingest the document if needed
-    if (
-      latestMessage &&
-      latestMessage.content &&
-      typeof latestMessage.content !== 'string' &&
-      latestMessage.content.length > 1
-    ) {
-      const docFile = latestMessage.content[1]?.doc_url?.url
-      if (docFile) {
-        await executeOnMain(
-          NODE,
-          'toolRetrievalIngestNewDocument',
-          docFile,
-          data.model?.proxy_model
-        )
-      }
-    } else if (
-      // Check whether we need to ingest document or not
-      // Otherwise wrong context will be sent
-      !(await fs.existsSync(
-        await joinPath([
-          JanAssistantExtension._threadDir,
-          data.threadId,
-          'memory',
-        ])
-      ))
-    ) {
-      // No document ingested, reroute the result to inference engine
-      const output = {
-        ...data,
-        model: {
-          ...data.model,
-          engine: data.model.proxy_model,
-        },
-      }
-      events.emit(MessageEvent.OnMessageSent, output)
-      return
-    }
-    // 2. Load agent on thread changed
-    if (instance.retrievalThreadId !== data.threadId) {
-      await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
-
-      instance.retrievalThreadId = data.threadId
-
-      // Update the text splitter
-      await executeOnMain(
-        NODE,
-        'toolRetrievalUpdateTextSplitter',
-        data.thread.assistants[0].tools[0]?.settings?.chunk_size ?? 4000,
-        data.thread.assistants[0].tools[0]?.settings?.chunk_overlap ?? 200
-      )
-    }
-
-    // 3. Using the retrieval template with the result and query
-    if (latestMessage.content) {
-      const prompt =
-        typeof latestMessage.content === 'string'
-          ? latestMessage.content
-          : latestMessage.content[0].text
-      // Retrieve the result
-      const retrievalResult = await executeOnMain(
-        NODE,
-        'toolRetrievalQueryResult',
-        prompt
-      )
-      console.debug('toolRetrievalQueryResult', retrievalResult)
-
-      // Update message content
-      if (data.thread?.assistants[0]?.tools && retrievalResult)
-        data.messages[data.messages.length - 1].content =
-          data.thread.assistants[0].tools[0].settings?.retrieval_template
-            ?.replace('{CONTEXT}', retrievalResult)
-            .replace('{QUESTION}', prompt)
-    }
-
-    // Filter out all the messages that are not text
-    data.messages = data.messages.map((message) => {
-      if (
-        message.content &&
-        typeof message.content !== 'string' &&
-        (message.content.length ?? 0) > 0
-      ) {
-        return {
-          ...message,
-          content: [message.content[0]],
-        }
-      }
-      return message
-    })
-
-    // 4. Reroute the result to inference engine
-    const output = {
-      ...data,
-      model: {
-        ...data.model,
-        engine: data.model.proxy_model,
-      },
-    }
-    events.emit(MessageEvent.OnMessageSent, output)
-  }
 
   /**
extensions/assistant-extension/src/tools/retrieval.ts (new file, 108 lines)
@@ -0,0 +1,108 @@
+import {
+  AssistantTool,
+  executeOnMain,
+  fs,
+  InferenceTool,
+  joinPath,
+  MessageRequest,
+} from '@janhq/core'
+
+export class RetrievalTool extends InferenceTool {
+  private _threadDir = 'file://threads'
+  private retrievalThreadId: string | undefined = undefined
+
+  name: string = 'retrieval'
+
+  async process(
+    data: MessageRequest,
+    tool?: AssistantTool
+  ): Promise<MessageRequest> {
+    if (!data.model || !data.messages) {
+      return Promise.resolve(data)
+    }
+
+    const latestMessage = data.messages[data.messages.length - 1]
+
+    // 1. Ingest the document if needed
+    if (
+      latestMessage &&
+      latestMessage.content &&
+      typeof latestMessage.content !== 'string' &&
+      latestMessage.content.length > 1
+    ) {
+      const docFile = latestMessage.content[1]?.doc_url?.url
+      if (docFile) {
+        await executeOnMain(
+          NODE,
+          'toolRetrievalIngestNewDocument',
+          docFile,
+          data.model?.engine
+        )
+      }
+    } else if (
+      // Check whether we need to ingest document or not
+      // Otherwise wrong context will be sent
+      !(await fs.existsSync(
+        await joinPath([this._threadDir, data.threadId, 'memory'])
+      ))
+    ) {
+      // No document ingested, reroute the result to inference engine
+
+      return Promise.resolve(data)
+    }
+    // 2. Load agent on thread changed
+    if (this.retrievalThreadId !== data.threadId) {
+      await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
+
+      this.retrievalThreadId = data.threadId
+
+      // Update the text splitter
+      await executeOnMain(
+        NODE,
+        'toolRetrievalUpdateTextSplitter',
+        tool?.settings?.chunk_size ?? 4000,
+        tool?.settings?.chunk_overlap ?? 200
+      )
+    }
+
+    // 3. Using the retrieval template with the result and query
+    if (latestMessage.content) {
+      const prompt =
+        typeof latestMessage.content === 'string'
+          ? latestMessage.content
+          : latestMessage.content[0].text
+      // Retrieve the result
+      const retrievalResult = await executeOnMain(
+        NODE,
+        'toolRetrievalQueryResult',
+        prompt
+      )
+      console.debug('toolRetrievalQueryResult', retrievalResult)
+
+      // Update message content
+      if (retrievalResult)
+        data.messages[data.messages.length - 1].content =
+          tool?.settings?.retrieval_template
+            ?.replace('{CONTEXT}', retrievalResult)
+            .replace('{QUESTION}', prompt)
+    }
+
+    // Filter out all the messages that are not text
+    data.messages = data.messages.map((message) => {
+      if (
+        message.content &&
+        typeof message.content !== 'string' &&
+        (message.content.length ?? 0) > 0
+      ) {
+        return {
+          ...message,
+          content: [message.content[0]],
+        }
+      }
+      return message
+    })
+
+    // 4. Reroute the result to inference engine
+    return Promise.resolve(data)
+  }
+}
@@ -91,15 +91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
     return super.loadModel(model)
   }
 
-  override unloadModel(model: Model): void {
-    super.unloadModel(model)
-
-    if (model.engine && model.engine !== this.provider) return
+  override async unloadModel(model?: Model) {
+    if (model?.engine && model.engine !== this.provider) return
 
     // stop the periocally health check
     if (this.getNitroProcesHealthIntervalId) {
       clearInterval(this.getNitroProcesHealthIntervalId)
       this.getNitroProcesHealthIntervalId = undefined
     }
+    return super.unloadModel(model)
   }
 }
@@ -271,26 +271,7 @@ const DropdownListSidebar = ({
             )}
           >
             <div className="relative flex w-full justify-between">
-              {x.engine === InferenceEngine.openai && (
-                <svg
-                  width="20"
-                  height="20"
-                  viewBox="0 0 20 20"
-                  fill="none"
-                  xmlns="http://www.w3.org/2000/svg"
-                  className="absolute top-1"
-                >
-                  <path
-                    d="M18.5681 8.18423 …" (full OpenAI-logo path omitted)
-                    fill="#18181B"
-                  />
-                </svg>
-              )}
-              <div
-                className={twMerge(
-                  x.engine === InferenceEngine.openai && 'pl-8'
-                )}
-              >
               <div>
                 <span className="line-clamp-1 block">
                   {x.name}
                 </span>
@@ -307,8 +288,7 @@ const DropdownListSidebar = ({
             </SelectItem>
             <div
               className={twMerge(
-                'absolute -mt-6 inline-flex items-center space-x-2 px-4 pb-2 text-muted-foreground',
-                x.engine === InferenceEngine.openai && 'left-8'
+                'absolute -mt-6 inline-flex items-center space-x-2 px-4 pb-2 text-muted-foreground'
               )}
             >
               <span className="text-xs">{x.id}</span>
@@ -8,26 +8,17 @@ import {
   ExtensionTypeEnum,
   MessageStatus,
   MessageRequest,
-  Model,
   ConversationalExtension,
   MessageEvent,
   MessageRequestType,
   ModelEvent,
   Thread,
-  ModelInitFailed,
+  EngineManager,
 } from '@janhq/core'
 import { useAtomValue, useSetAtom } from 'jotai'
 import { ulid } from 'ulidx'
 
-import {
-  activeModelAtom,
-  loadModelErrorAtom,
-  stateModelAtom,
-} from '@/hooks/useActiveModel'
-
-import { queuedMessageAtom } from '@/hooks/useSendChatMessage'
-
-import { toaster } from '../Toast'
+import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
 
 import { extensionManager } from '@/extension'
 import {
@@ -51,8 +42,6 @@ export default function EventHandler({ children }: { children: ReactNode }) {
   const activeModel = useAtomValue(activeModelAtom)
   const setActiveModel = useSetAtom(activeModelAtom)
   const setStateModel = useSetAtom(stateModelAtom)
-  const setQueuedMessage = useSetAtom(queuedMessageAtom)
-  const setLoadModelError = useSetAtom(loadModelErrorAtom)
 
   const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
   const threads = useAtomValue(threadsAtom)
@@ -88,44 +77,11 @@ export default function EventHandler({ children }: { children: ReactNode }) {
     [addNewMessage]
   )
 
-  const onModelReady = useCallback(
-    (model: Model) => {
-      setActiveModel(model)
-      toaster({
-        title: 'Success!',
-        description: `Model ${model.id} has been started.`,
-        type: 'success',
-      })
-      setStateModel(() => ({
-        state: 'stop',
-        loading: false,
-        model: model.id,
-      }))
-    },
-    [setActiveModel, setStateModel]
-  )
-
   const onModelStopped = useCallback(() => {
     setTimeout(() => {
       setActiveModel(undefined)
       setStateModel({ state: 'start', loading: false, model: '' })
     }, 500)
   }, [setActiveModel, setStateModel])
 
-  const onModelInitFailed = useCallback(
-    (res: ModelInitFailed) => {
-      console.error('Failed to load model: ', res.error.message)
-      setStateModel(() => ({
-        state: 'start',
-        loading: false,
-        model: res.id,
-      }))
-      setLoadModelError(res.error.message)
-      setQueuedMessage(false)
-    },
-    [setStateModel, setQueuedMessage, setLoadModelError]
-  )
-
   const updateThreadTitle = useCallback(
     (message: ThreadMessage) => {
       // Update only when it's finished
@@ -274,7 +230,10 @@ export default function EventHandler({ children }: { children: ReactNode }) {
 
       // 2. Update the title with the result of the inference
       setTimeout(() => {
-        events.emit(MessageEvent.OnMessageSent, messageRequest)
+        const engine = EngineManager.instance().get(
+          messageRequest.model?.engine ?? activeModelRef.current?.engine ?? ''
+        )
+        engine?.inference(messageRequest)
       }, 1000)
     }
   }
@@ -283,23 +242,16 @@ export default function EventHandler({ children }: { children: ReactNode }) {
     if (window.core?.events) {
       events.on(MessageEvent.OnMessageResponse, onNewMessageResponse)
       events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
-      events.on(ModelEvent.OnModelReady, onModelReady)
-      events.on(ModelEvent.OnModelFail, onModelInitFailed)
       events.on(ModelEvent.OnModelStopped, onModelStopped)
     }
-  }, [
-    onNewMessageResponse,
-    onMessageResponseUpdate,
-    onModelReady,
-    onModelInitFailed,
-    onModelStopped,
-  ])
+  }, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
 
   useEffect(() => {
     return () => {
       events.off(MessageEvent.OnMessageResponse, onNewMessageResponse)
       events.off(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
+      events.off(ModelEvent.OnModelStopped, onModelStopped)
     }
-  }, [onNewMessageResponse, onMessageResponseUpdate])
+  }, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
   return <Fragment>{children}</Fragment>
 }
@@ -1,6 +1,6 @@
 /* eslint-disable @typescript-eslint/no-explicit-any */
 
-import { BaseExtension, ExtensionTypeEnum } from '@janhq/core'
+import { AIEngine, BaseExtension, ExtensionTypeEnum } from '@janhq/core'
 
 import Extension from './Extension'
@@ -8,14 +8,26 @@ import Extension from './Extension'
  * Manages the registration and retrieval of extensions.
 */
 export class ExtensionManager {
+  // Registered extensions
   private extensions = new Map<string, BaseExtension>()
 
+  // Registered inference engines
+  private engines = new Map<string, AIEngine>()
+
   /**
    * Registers an extension.
    * @param extension - The extension to register.
    */
   register<T extends BaseExtension>(name: string, extension: T) {
     this.extensions.set(extension.type() ?? name, extension)
+
+    // Register AI Engines
+    if ('provider' in extension && typeof extension.provider === 'string') {
+      this.engines.set(
+        extension.provider as unknown as string,
+        extension as unknown as AIEngine
+      )
+    }
   }
 
   /**
@@ -29,6 +41,15 @@ export class ExtensionManager {
     return this.extensions.get(type) as T | undefined
   }
 
+  /**
+   * Retrieves an extension by its type.
+   * @param engine - The engine name to retrieve.
+   * @returns The extension, if found.
+   */
+  getEngine<T extends AIEngine>(engine: string): T | undefined {
+    return this.engines.get(engine) as T | undefined
+  }
+
   /**
    * Loads all registered extension.
    */
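A usage sketch for the new lookup (the provider string is illustrative): since `register` indexes anything with a string `provider` field, engines can be fetched without knowing their extension type:

```ts
import { extensionManager } from '@/extension'

// Hypothetical: fetch the engine registered under a provider id and
// stop any in-flight inference on it.
const engine = extensionManager.getEngine('nitro') // 'nitro' is illustrative
engine?.stopInference()
```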
@@ -1,6 +1,6 @@
 import { useCallback, useEffect, useRef } from 'react'
 
-import { events, Model, ModelEvent } from '@janhq/core'
+import { EngineManager, Model } from '@janhq/core'
 import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
 
 import { toaster } from '@/containers/Toast'
@@ -38,19 +38,13 @@ export function useActiveModel() {
       (stateModel.model === modelId && stateModel.loading)
     ) {
       console.debug(`Model ${modelId} is already initialized. Ignore..`)
-      return
+      return Promise.resolve()
     }
 
     let model = downloadedModelsRef?.current.find((e) => e.id === modelId)
 
     // Switch between engines
-    if (model && activeModel && activeModel.engine !== model.engine) {
-      stopModel()
-      // TODO: Refactor inference provider would address this
-      await new Promise((res) => setTimeout(res, 1000))
-    }
+    await stopModel().catch()
 
     // TODO: incase we have multiple assistants, the configuration will be from assistant
     setLoadModelError(undefined)
 
     setActiveModel(undefined)
@@ -68,7 +62,8 @@ export function useActiveModel() {
         loading: false,
         model: '',
       }))
-      return
+
+      return Promise.reject(`Model ${modelId} not found!`)
     }
 
     /// Apply thread model settings
@@ -83,15 +78,52 @@ export function useActiveModel() {
     }
 
     localStorage.setItem(LAST_USED_MODEL_ID, model.id)
-    events.emit(ModelEvent.OnModelInit, model)
+    const engine = EngineManager.instance().get(model.engine)
+    return engine
+      ?.loadModel(model)
+      .then(() => {
+        setActiveModel(model)
+        setStateModel(() => ({
+          state: 'stop',
+          loading: false,
+          model: model.id,
+        }))
+        toaster({
+          title: 'Success!',
+          description: `Model ${model.id} has been started.`,
+          type: 'success',
+        })
+      })
+      .catch((error) => {
+        setStateModel(() => ({
+          state: 'start',
+          loading: false,
+          model: model.id,
+        }))
+
+        toaster({
+          title: 'Failed!',
+          description: `Model ${model.id} failed to start.`,
+          type: 'success',
+        })
+        setLoadModelError(error)
+        return Promise.reject(error)
+      })
   }
 
   const stopModel = useCallback(async () => {
     if (activeModel) {
       setStateModel({ state: 'stop', loading: true, model: activeModel.id })
-      events.emit(ModelEvent.OnModelStop, activeModel)
+      const engine = EngineManager.instance().get(activeModel.engine)
+      await engine
+        ?.unloadModel(activeModel)
+        .catch()
+        .then(() => {
+          setActiveModel(undefined)
+          setStateModel({ state: 'start', loading: false, model: '' })
+        })
     }
-  }, [activeModel, setStateModel])
+  }, [activeModel, setActiveModel, setStateModel])
 
   return { activeModel, startModel, stopModel, stateModel }
 }
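Because `startModel` now returns the engine's own promise, callers can branch on the outcome directly instead of watching `ModelEvent.OnModelReady` / `OnModelFail`. A sketch of the consuming pattern (the wrapper hook is illustrative, not part of the diff):

```ts
import { useActiveModel } from '@/hooks/useActiveModel'

// Hypothetical consumer of the promise-based API.
export function useEnsureModel() {
  const { startModel } = useActiveModel()
  return async (modelId: string) => {
    try {
      await startModel(modelId) // resolves after engine.loadModel() succeeds
    } catch (err) {
      console.error(`Model ${modelId} failed to start`, err)
    }
  }
}
```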
@@ -2,27 +2,18 @@
 import { useEffect, useRef } from 'react'
 
 import {
-  ChatCompletionMessage,
-  ChatCompletionRole,
-  ContentType,
-  MessageRequest,
   MessageRequestType,
-  MessageStatus,
   ExtensionTypeEnum,
   Thread,
   ThreadMessage,
-  events,
   Model,
   ConversationalExtension,
-  MessageEvent,
-  InferenceEngine,
-  ChatCompletionMessageContentType,
-  AssistantTool,
+  EngineManager,
+  ToolManager,
 } from '@janhq/core'
 import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
 
 import { ulid } from 'ulidx'
 
 import { selectedModelAtom } from '@/containers/DropdownListSidebar'
 import {
   currentPromptAtom,
@@ -31,8 +22,11 @@ import {
 } from '@/containers/Providers/Jotai'
 
 import { compressImage, getBase64 } from '@/utils/base64'
+import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
 import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
 
+import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
+
 import { loadModelErrorAtom, useActiveModel } from './useActiveModel'
 
 import { extensionManager } from '@/extension/ExtensionManager'
@@ -65,7 +59,6 @@ export default function useSendChatMessage() {
   const currentMessages = useAtomValue(getCurrentChatMessagesAtom)
   const selectedModel = useAtomValue(selectedModelAtom)
   const { activeModel, startModel } = useActiveModel()
-  const setQueuedMessage = useSetAtom(queuedMessageAtom)
   const loadModelFailed = useAtomValue(loadModelErrorAtom)
 
   const modelRef = useRef<Model | undefined>()
@@ -78,6 +71,7 @@ export default function useSendChatMessage() {
   const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
   const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
   const activeThreadRef = useRef<Thread | undefined>()
+  const setQueuedMessage = useSetAtom(queuedMessageAtom)
 
   const selectedModelRef = useRef<Model | undefined>()
@@ -103,51 +97,27 @@ export default function useSendChatMessage() {
       return
     }
     updateThreadWaiting(activeThreadRef.current.id, true)
-    const messages: ChatCompletionMessage[] = [
-      activeThreadRef.current.assistants[0]?.instructions,
-    ]
-      .filter((e) => e && e.trim() !== '')
-      .map<ChatCompletionMessage>((instructions) => {
-        const systemMessage: ChatCompletionMessage = {
-          role: ChatCompletionRole.System,
-          content: instructions,
-        }
-        return systemMessage
-      })
-      .concat(
-        currentMessages
-          .filter(
-            (e) =>
-              (currentMessage.role === ChatCompletionRole.User ||
-                e.id !== currentMessage.id) &&
-              e.status !== MessageStatus.Error
-          )
-          .map<ChatCompletionMessage>((msg) => ({
-            role: msg.role,
-            content: msg.content[0]?.text.value ?? '',
-          }))
-      )
-
-    const messageRequest: MessageRequest = {
-      id: ulid(),
-      type: MessageRequestType.Thread,
-      messages: messages,
-      threadId: activeThreadRef.current.id,
-      model:
-        activeThreadRef.current.assistants[0].model ?? selectedModelRef.current,
-    }
+    const requestBuilder = new MessageRequestBuilder(
+      MessageRequestType.Thread,
+      activeThreadRef.current.assistants[0].model ?? selectedModelRef.current,
+      activeThreadRef.current,
+      currentMessages
+    ).addSystemMessage(activeThreadRef.current.assistants[0]?.instructions)
 
     const modelId =
       selectedModelRef.current?.id ??
       activeThreadRef.current.assistants[0].model.id
 
     if (modelRef.current?.id !== modelId) {
       setQueuedMessage(true)
-      startModel(modelId)
-      await waitForModelStarting(modelId)
-      setQueuedMessage(false)
+      const error = await startModel(modelId).catch((error: Error) => error)
+      setQueuedMessage(false)
+      if (error) {
+        return
+      }
     }
 
     setIsGeneratingResponse(true)
 
     if (currentMessage.role !== ChatCompletionRole.User) {
       // Delete last response before regenerating
       deleteMessage(currentMessage.id ?? '')
@@ -160,9 +130,22 @@ export default function useSendChatMessage() {
         )
       }
     }
-    events.emit(MessageEvent.OnMessageSent, messageRequest)
+    // Process message request with Assistants tools
+    const request = await ToolManager.instance().process(
+      requestBuilder.build(),
+      activeThreadRef.current.assistants?.flatMap(
+        (assistant) => assistant.tools ?? []
+      ) ?? []
+    )
+
+    const engine =
+      requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? ''
+
+    EngineManager.instance().get(engine)?.inference(request)
   }
 
   // Define interface extending Array prototype
 
   const sendChatMessage = async (message: string) => {
     if (!message || message.trim().length === 0) return
@@ -176,8 +159,9 @@ export default function useSendChatMessage() {
     const runtimeParams = toRuntimeParams(activeModelParams)
     const settingParams = toSettingParams(activeModelParams)
 
-    updateThreadWaiting(activeThreadRef.current.id, true)
     const prompt = message.trim()
+
+    updateThreadWaiting(activeThreadRef.current.id, true)
     setCurrentPrompt('')
     setEditPrompt('')
@@ -185,69 +169,12 @@ export default function useSendChatMessage() {
       ? await getBase64(fileUpload[0].file)
       : undefined
 
-    const fileContentType = fileUpload[0]?.type
-
-    const msgId = ulid()
-
-    const isDocumentInput = base64Blob && fileContentType === 'pdf'
-    const isImageInput = base64Blob && fileContentType === 'image'
-
-    if (isImageInput && base64Blob) {
+    if (base64Blob && fileUpload[0]?.type === 'image') {
       // Compress image
       base64Blob = await compressImage(base64Blob, 512)
     }
 
-    const messages: ChatCompletionMessage[] = [
-      activeThreadRef.current.assistants[0]?.instructions,
-    ]
-      .filter((e) => e && e.trim() !== '')
-      .map<ChatCompletionMessage>((instructions) => {
-        const systemMessage: ChatCompletionMessage = {
-          role: ChatCompletionRole.System,
-          content: instructions,
-        }
-        return systemMessage
-      })
-      .concat(
-        currentMessages
-          .filter((e) => e.status !== MessageStatus.Error)
-          .map<ChatCompletionMessage>((msg) => ({
-            role: msg.role,
-            content: msg.content[0]?.text.value ?? '',
-          }))
-          .concat([
-            {
-              role: ChatCompletionRole.User,
-              content:
-                selectedModelRef.current && base64Blob
-                  ? [
-                      {
-                        type: ChatCompletionMessageContentType.Text,
-                        text: prompt,
-                      },
-                      isDocumentInput
-                        ? {
-                            type: ChatCompletionMessageContentType.Doc,
-                            doc_url: {
-                              url: `threads/${activeThreadRef.current.id}/files/${msgId}.pdf`,
-                            },
-                          }
-                        : null,
-                      isImageInput
-                        ? {
-                            type: ChatCompletionMessageContentType.Image,
-                            image_url: {
-                              url: base64Blob,
-                            },
-                          }
-                        : null,
-                    ].filter((e) => e !== null)
-                  : prompt,
-            } as ChatCompletionMessage,
-          ])
-      )
-
-    let modelRequest =
+    const modelRequest =
       selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model
 
     // Fallback support for previous broken threads
@@ -261,131 +188,83 @@ export default function useSendChatMessage() {
     if (runtimeParams.stream == null) {
       runtimeParams.stream = true
     }
-    // Add middleware to the model request with tool retrieval enabled
-    if (
-      activeThreadRef.current.assistants[0].tools?.some(
-        (tool: AssistantTool) => tool.type === 'retrieval' && tool.enabled
-      )
-    ) {
-      modelRequest = {
-        ...modelRequest,
-        // Tool retrieval support document input only for now
-        ...(isDocumentInput
-          ? {
-              engine: InferenceEngine.tool_retrieval_enabled,
-              proxy_model: modelRequest.engine,
-            }
-          : {}),
-      }
-    }
-    const messageRequest: MessageRequest = {
-      id: msgId,
-      type: MessageRequestType.Thread,
-      threadId: activeThreadRef.current.id,
-      messages,
-      model: {
+
+    // Build Message Request
+    const requestBuilder = new MessageRequestBuilder(
+      MessageRequestType.Thread,
+      {
         ...modelRequest,
         settings: settingParams,
         parameters: runtimeParams,
       },
-      thread: activeThreadRef.current,
-    }
+      activeThreadRef.current,
+      currentMessages
+    ).addSystemMessage(activeThreadRef.current.assistants[0].instructions)
 
-    const timestamp = Date.now()
-    const content: any = []
+    requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type)
 
-    if (base64Blob && fileUpload[0]?.type === 'image') {
-      content.push({
-        type: ContentType.Image,
-        text: {
-          value: prompt,
-          annotations: [base64Blob],
-        },
-      })
-    }
+    // Build Thread Message to persist
+    const threadMessageBuilder = new ThreadMessageBuilder(
+      requestBuilder
+    ).pushMessage(prompt, base64Blob, fileUpload)
 
-    if (base64Blob && fileUpload[0]?.type === 'pdf') {
-      content.push({
-        type: ContentType.Pdf,
-        text: {
-          value: prompt,
-          annotations: [base64Blob],
-          name: fileUpload[0].file.name,
-          size: fileUpload[0].file.size,
-        },
-      })
-    }
+    const newMessage = threadMessageBuilder.build()
 
-    if (prompt && !base64Blob) {
-      content.push({
-        type: ContentType.Text,
-        text: {
-          value: prompt,
-          annotations: [],
-        },
-      })
-    }
-
-    const threadMessage: ThreadMessage = {
-      id: msgId,
-      thread_id: activeThreadRef.current.id,
-      role: ChatCompletionRole.User,
-      status: MessageStatus.Ready,
-      created: timestamp,
-      updated: timestamp,
-      object: 'thread.message',
-      content: content,
-    }
-
-    addNewMessage(threadMessage)
-    if (base64Blob) {
-      setFileUpload([])
-    }
+    // Push to states
+    addNewMessage(newMessage)
 
     // Update thread state
     const updatedThread: Thread = {
       ...activeThreadRef.current,
-      updated: timestamp,
+      updated: newMessage.created,
       metadata: {
         ...(activeThreadRef.current.metadata ?? {}),
         lastMessage: prompt,
       },
     }
 
     // change last update thread when send message
     updateThread(updatedThread)
 
     // Add message
     await extensionManager
       .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
-      ?.addNewMessage(threadMessage)
+      ?.addNewMessage(newMessage)
 
     // Start Model if not started
     const modelId =
       selectedModelRef.current?.id ??
       activeThreadRef.current.assistants[0].model.id
 
     if (modelRef.current?.id !== modelId) {
       setQueuedMessage(true)
-      startModel(modelId)
-      await waitForModelStarting(modelId)
+      const error = await startModel(modelId).catch((error: Error) => error)
       setQueuedMessage(false)
+      if (error) {
+        updateThreadWaiting(activeThreadRef.current.id, false)
+        return
+      }
     }
     setIsGeneratingResponse(true)
-    events.emit(MessageEvent.OnMessageSent, messageRequest)
 
+    // Process message request with Assistants tools
+    const request = await ToolManager.instance().process(
+      requestBuilder.build(),
+      activeThreadRef.current.assistants?.flatMap(
+        (assistant) => assistant.tools ?? []
+      ) ?? []
+    )
+
+    // Request for inference
+    EngineManager.instance()
+      .get(requestBuilder.model?.engine ?? modelRequest.engine ?? '')
+      ?.inference(request)
+
+    if (base64Blob) {
+      setFileUpload([])
+    }
+
     // Reset states
     setReloadModel(false)
     setEngineParamsUpdate(false)
   }
 
-  const waitForModelStarting = async (modelId: string) => {
-    return new Promise<void>((resolve) => {
-      setTimeout(async () => {
-        if (modelRef.current?.id !== modelId && !loadModelFailedRef.current) {
-          await waitForModelStarting(modelId)
-          resolve()
-        } else {
-          resolve()
-        }
-      }, 200)
-    })
-  }
-
   return {
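Taken together, sending now follows a build, tool-process, dispatch pipeline instead of emitting `MessageEvent.OnMessageSent`. A condensed, hypothetical rendering of that flow using the diff's own names (error handling and thread-state updates from the real hook are omitted):

```ts
import {
  AssistantTool,
  EngineManager,
  MessageRequestType,
  ModelInfo,
  Thread,
  ThreadMessage,
  ToolManager,
} from '@janhq/core'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'

// Condensed sketch of the new send path; not a drop-in replacement.
async function dispatchMessage(
  model: ModelInfo,
  thread: Thread,
  history: ThreadMessage[],
  tools: AssistantTool[],
  instructions?: string
): Promise<void> {
  const builder = new MessageRequestBuilder(
    MessageRequestType.Thread,
    model,
    thread,
    history
  ).addSystemMessage(instructions)

  // 1. Enabled assistant tools (e.g. retrieval) may rewrite the request.
  const request = await ToolManager.instance().process(builder.build(), tools)

  // 2. Hand the processed request straight to the owning engine.
  EngineManager.instance().get(request.model?.engine ?? '')?.inference(request)
}
```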
@@ -74,7 +74,8 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
           </p>
           <ModalTroubleShooting />
         </div>
-      ) : loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
+      ) : loadModelError &&
+        loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
         <div
           key={message.id}
           className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500"
@ -24,8 +24,8 @@ const FileUploadPreview: React.FC = () => {
    <div className="relative inline-flex w-60 space-x-3 rounded-lg bg-secondary p-4">
      <Icon type={fileUpload[0].type} />

      <div>
        <h6 className="line-clamp-1 font-medium">
      <div className="w-full">
        <h6 className="line-clamp-1 w-3/4 truncate font-medium">
          {fileUpload[0].file.name.replaceAll(/[-._]/g, ' ')}
        </h6>
        <p className="text-muted-foreground">
@ -260,8 +260,8 @@ const SimpleTextMessage: React.FC<ThreadMessage> = (props) => {

      <Icon type={props.content[0].type} />

      <div>
        <h6 className="line-clamp-1 font-medium">
      <div className="w-full">
        <h6 className="line-clamp-1 w-4/5 font-medium">
          {props.content[0].text.name?.replaceAll(/[-._]/g, ' ')}
        </h6>
        <p className="text-muted-foreground">
@ -1,3 +1,5 @@
import { EngineManager, ToolManager } from '@janhq/core'

import { appService } from './appService'
import { EventEmitter } from './eventsService'
import { restAPI } from './restService'
@ -12,6 +14,8 @@ export const setupCoreServices = () => {
  if (!window.core) {
    window.core = {
      events: new EventEmitter(),
      engineManager: new EngineManager(),
      toolManager: new ToolManager(),
      api: {
        ...(window.electronAPI ? window.electronAPI : restAPI),
        ...appService,
web/utils/messageRequestBuilder.ts (new file, 130 lines)
@ -0,0 +1,130 @@
import {
  ChatCompletionMessage,
  ChatCompletionMessageContent,
  ChatCompletionMessageContentText,
  ChatCompletionMessageContentType,
  ChatCompletionRole,
  MessageRequest,
  MessageRequestType,
  MessageStatus,
  ModelInfo,
  Thread,
  ThreadMessage,
} from '@janhq/core'
import { ulid } from 'ulidx'

import { FileType } from '@/containers/Providers/Jotai'

export class MessageRequestBuilder {
  msgId: string
  type: MessageRequestType
  messages: ChatCompletionMessage[]
  model: ModelInfo
  thread: Thread

  constructor(
    type: MessageRequestType,
    model: ModelInfo,
    thread: Thread,
    messages: ThreadMessage[]
  ) {
    this.msgId = ulid()
    this.type = type
    this.model = model
    this.thread = thread
    this.messages = messages
      .filter((e) => e.status !== MessageStatus.Error)
      .map<ChatCompletionMessage>((msg) => ({
        role: msg.role,
        content: msg.content[0]?.text.value ?? '',
      }))
  }

  // Chainable
  pushMessage(
    message: string,
    base64Blob: string | undefined,
    fileContentType: FileType
  ) {
    if (base64Blob && fileContentType === 'pdf')
      return this.addDocMessage(message)
    else if (base64Blob && fileContentType === 'image') {
      return this.addImageMessage(message, base64Blob)
    }
    this.messages = [
      ...this.messages,
      {
        role: ChatCompletionRole.User,
        content: message,
      },
    ]
    return this
  }

  // Chainable
  addSystemMessage(message: string | undefined) {
    if (!message || message.trim() === '') return this
    this.messages = [
      {
        role: ChatCompletionRole.System,
        content: message,
      },
      ...this.messages,
    ]
    return this
  }

  // Chainable
  addDocMessage(prompt: string) {
    const message: ChatCompletionMessage = {
      role: ChatCompletionRole.User,
      content: [
        {
          type: ChatCompletionMessageContentType.Text,
          text: prompt,
        } as ChatCompletionMessageContentText,
        {
          type: ChatCompletionMessageContentType.Doc,
          doc_url: {
            url: `threads/${this.thread.id}/files/${this.msgId}.pdf`,
          },
        },
      ] as ChatCompletionMessageContent,
    }
    this.messages = [message, ...this.messages]
    return this
  }

  // Chainable
  addImageMessage(prompt: string, base64: string) {
    const message: ChatCompletionMessage = {
      role: ChatCompletionRole.User,
      content: [
        {
          type: ChatCompletionMessageContentType.Text,
          text: prompt,
        } as ChatCompletionMessageContentText,
        {
          type: ChatCompletionMessageContentType.Image,
          image_url: {
            url: base64,
          },
        },
      ] as ChatCompletionMessageContent,
    }

    this.messages = [message, ...this.messages]
    return this
  }

  build(): MessageRequest {
    return {
      id: this.msgId,
      type: this.type,
      threadId: this.thread.id,
      messages: this.messages,
      model: this.model,
      thread: this.thread,
    }
  }
}
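A usage sketch for the builder above (variable names are illustrative, and MessageRequestType.Thread is assumed to be the enum member used for an ordinary chat turn). One ordering quirk worth noting: addDocMessage and addImageMessage prepend to the message list, while a plain-text pushMessage appends.

// Illustrative inputs: model, thread, history, and assistant come from app state
const request = new MessageRequestBuilder(
  MessageRequestType.Thread, // assumed enum member for a chat turn
  model,                     // ModelInfo of the selected model
  thread,                    // the active Thread
  history                    // prior ThreadMessage[]; errored messages are filtered out
)
  .addSystemMessage(assistant?.instructions) // prepended when non-empty
  .pushMessage('What does this PDF say?', pdfBase64, 'pdf')
  .build()
// request.id is a fresh ulid, and the doc message points at
// threads/<threadId>/files/<msgId>.pdf, so the PDF must be written there.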
web/utils/threadMessageBuilder.ts (new file, 74 lines)
@ -0,0 +1,74 @@
import {
  ChatCompletionRole,
  ContentType,
  MessageStatus,
  ThreadContent,
  ThreadMessage,
} from '@janhq/core'

import { FileInfo } from '@/containers/Providers/Jotai'

import { MessageRequestBuilder } from './messageRequestBuilder'

export class ThreadMessageBuilder {
  messageRequest: MessageRequestBuilder

  content: ThreadContent[] = []

  constructor(messageRequest: MessageRequestBuilder) {
    this.messageRequest = messageRequest
  }

  build(): ThreadMessage {
    const timestamp = Date.now()
    return {
      id: this.messageRequest.msgId,
      thread_id: this.messageRequest.thread.id,
      role: ChatCompletionRole.User,
      status: MessageStatus.Ready,
      created: timestamp,
      updated: timestamp,
      object: 'thread.message',
      content: this.content,
    }
  }

  pushMessage(
    prompt: string,
    base64: string | undefined,
    fileUpload: FileInfo[]
  ) {
    if (base64 && fileUpload[0]?.type === 'image') {
      this.content.push({
        type: ContentType.Image,
        text: {
          value: prompt,
          annotations: [base64],
        },
      })
    }

    if (base64 && fileUpload[0]?.type === 'pdf') {
      this.content.push({
        type: ContentType.Pdf,
        text: {
          value: prompt,
          annotations: [base64],
          name: fileUpload[0].file.name,
          size: fileUpload[0].file.size,
        },
      })
    }

    if (prompt && !base64) {
      this.content.push({
        type: ContentType.Text,
        text: {
          value: prompt,
          annotations: [],
        },
      })
    }
    return this
  }
}
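Taken together with messageRequestBuilder.ts, the two builders are meant to be used in tandem, as the useSendChatMessage hunk earlier in this diff does: one MessageRequestBuilder feeds both the inference request and the persisted ThreadMessage, so the two share a single ulid-based id and thread reference. A condensed sketch with illustrative variable names:

// Build the inference payload first
const requestBuilder = new MessageRequestBuilder(type, model, thread, history)
requestBuilder.pushMessage(prompt, base64, fileUpload[0]?.type)

// Derive the persisted/UI message from the same builder so the ids line up
const threadMessage = new ThreadMessageBuilder(requestBuilder)
  .pushMessage(prompt, base64, fileUpload)
  .build()

// threadMessage.id === requestBuilder.build().id  // both reuse msgId (ulid)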