diff --git a/extensions/inference-openai-extension/src/index.ts b/extensions/inference-openai-extension/src/index.ts index 6bab563dd..c719e405f 100644 --- a/extensions/inference-openai-extension/src/index.ts +++ b/extensions/inference-openai-extension/src/index.ts @@ -195,7 +195,10 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { requestInference( data?.messages ?? [], this._engineSettings, - JanInferenceOpenAIExtension._currentModel, + { + ...JanInferenceOpenAIExtension._currentModel, + parameters: data.model.parameters, + }, instance.controller ).subscribe({ next: (content) => { diff --git a/extensions/inference-triton-trtllm-extension/src/index.ts b/extensions/inference-triton-trtllm-extension/src/index.ts index 9e8d64bb2..103d61f68 100644 --- a/extensions/inference-triton-trtllm-extension/src/index.ts +++ b/extensions/inference-triton-trtllm-extension/src/index.ts @@ -31,14 +31,16 @@ import { EngineSettings } from "./@types/global"; * The class provides methods for initializing and stopping a model, and for making inference requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests. */ -export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension { - private static readonly _homeDir = 'engines' - private static readonly _engineMetadataFileName = 'triton_trtllm.json' - +export default class JanInferenceTritonTrtLLMExtension + implements InferenceExtension +{ + private static readonly _homeDir = "engines"; + private static readonly _engineMetadataFileName = "triton_trtllm.json"; + static _currentModel: Model; static _engineSettings: EngineSettings = { - "base_url": "", + base_url: "", }; controller = new AbortController(); @@ -56,8 +58,8 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExten * Subscribes to events emitted by the @janhq/core package. */ onLoad(): void { - fs.mkdir(JanInferenceTritonTrtLLMExtension._homeDir) - JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings() + fs.mkdir(JanInferenceTritonTrtLLMExtension._homeDir); + JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings(); // Events subscription events.on(EventName.OnMessageSent, (data) => @@ -87,20 +89,31 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExten modelId: string, settings?: ModelSettingParams ): Promise { - return + return; } static async writeDefaultEngineSettings() { try { - const engine_json = join(JanInferenceTritonTrtLLMExtension._homeDir, JanInferenceTritonTrtLLMExtension._engineMetadataFileName) + const engine_json = join( + JanInferenceTritonTrtLLMExtension._homeDir, + JanInferenceTritonTrtLLMExtension._engineMetadataFileName + ); if (await fs.exists(engine_json)) { - JanInferenceTritonTrtLLMExtension._engineSettings = JSON.parse(await fs.readFile(engine_json)) - } - else { - await fs.writeFile(engine_json, JSON.stringify(JanInferenceTritonTrtLLMExtension._engineSettings, null, 2)) + JanInferenceTritonTrtLLMExtension._engineSettings = JSON.parse( + await fs.readFile(engine_json) + ); + } else { + await fs.writeFile( + engine_json, + JSON.stringify( + JanInferenceTritonTrtLLMExtension._engineSettings, + null, + 2 + ) + ); } } catch (err) { - console.error(err) + console.error(err); } } /** @@ -137,35 +150,39 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExten }; return new Promise(async (resolve, reject) => { - requestInference(data.messages ?? [], - JanInferenceTritonTrtLLMExtension._engineSettings, - JanInferenceTritonTrtLLMExtension._currentModel) - .subscribe({ - next: (_content) => {}, - complete: async () => { - resolve(message); - }, - error: async (err) => { - reject(err); - }, + requestInference( + data.messages ?? [], + JanInferenceTritonTrtLLMExtension._engineSettings, + JanInferenceTritonTrtLLMExtension._currentModel + ).subscribe({ + next: (_content) => {}, + complete: async () => { + resolve(message); + }, + error: async (err) => { + reject(err); + }, }); }); } private static async handleModelInit(model: Model) { - if (model.engine !== 'triton_trtllm') { return } - else { - JanInferenceTritonTrtLLMExtension._currentModel = model - JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings() + if (model.engine !== "triton_trtllm") { + return; + } else { + JanInferenceTritonTrtLLMExtension._currentModel = model; + JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings(); // Todo: Check model list with API key - events.emit(EventName.OnModelReady, model) + events.emit(EventName.OnModelReady, model); // events.emit(EventName.OnModelFail, model) } } private static async handleModelStop(model: Model) { - if (model.engine !== 'triton_trtllm') { return } - events.emit(EventName.OnModelStopped, model) + if (model.engine !== "triton_trtllm") { + return; + } + events.emit(EventName.OnModelStopped, model); } /** @@ -178,8 +195,10 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExten data: MessageRequest, instance: JanInferenceTritonTrtLLMExtension ) { - if (data.model.engine !== 'triton_trtllm') { return } - + if (data.model.engine !== "triton_trtllm") { + return; + } + const timestamp = Date.now(); const message: ThreadMessage = { id: ulid(), @@ -200,7 +219,10 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExten requestInference( data?.messages ?? [], this._engineSettings, - JanInferenceTritonTrtLLMExtension._currentModel, + { + ...JanInferenceTritonTrtLLMExtension._currentModel, + parameters: data.model.parameters, + }, instance.controller ).subscribe({ next: (content) => { diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts index ff0b4d049..954929553 100644 --- a/web/hooks/useCreateNewThread.ts +++ b/web/hooks/useCreateNewThread.ts @@ -85,7 +85,6 @@ export const useCreateNewThread = () => { created: createdAt, updated: createdAt, } - setThreadModelRuntimeParams(thread.id, assistantInfo.model.parameters) // add the new thread on top of the thread list to the state