diff --git a/extensions/inference-openai-extension/src/index.ts b/extensions/inference-openai-extension/src/index.ts index 8d72422d3..95fac2fc0 100644 --- a/extensions/inference-openai-extension/src/index.ts +++ b/extensions/inference-openai-extension/src/index.ts @@ -17,20 +17,13 @@ import { ThreadContent, ThreadMessage, events, - executeOnMain, - getUserSpace, fs, - Model, } from "@janhq/core"; import { InferenceExtension } from "@janhq/core"; import { requestInference } from "./helpers/sse"; import { ulid } from "ulid"; import { join } from "path"; - -interface EngineSettings { - base_url?: string; - api_key?: string; -} +import { EngineSettings, OpenAIModel } from "./@types/global"; /** * A class that implements the InferenceExtension interface from the @janhq/core package. @@ -41,12 +34,16 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { private static readonly _homeDir = 'engines' private static readonly _engineMetadataFileName = 'openai.json' - private _engineSettings: EngineSettings = { + static _currentModel: OpenAIModel; + + static _engineSettings: EngineSettings = { "base_url": "https://api.openai.com/v1", "api_key": "sk-" - } + }; + controller = new AbortController(); isCancelled = false; + /** * Returns the type of the extension. * @returns {ExtensionType} The type of the extension. @@ -55,7 +52,6 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { type(): ExtensionType { return undefined; } -// janroot/engine/nitro.json /** * Subscribes to events emitted by the @janhq/core package. */ @@ -68,12 +64,12 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { JanInferenceOpenAIExtension.handleMessageRequest(data, this) ); - events.on(EventName.OnModelInit, (data: Model) => { - JanInferenceOpenAIExtension.handleModelInit(data); + events.on(EventName.OnModelInit, (model: OpenAIModel) => { + JanInferenceOpenAIExtension.handleModelInit(model); }); - events.on(EventName.OnModelStop, (data: Model) => { - JanInferenceOpenAIExtension.handleModelStop(data); + events.on(EventName.OnModelStop, (model: OpenAIModel) => { + JanInferenceOpenAIExtension.handleModelStop(model); }); } @@ -98,10 +94,10 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { try { const engine_json = join(JanInferenceOpenAIExtension._homeDir, JanInferenceOpenAIExtension._engineMetadataFileName) if (await fs.checkFileExists(engine_json)) { - this._engineSettings = JSON.parse(await fs.readFile(engine_json)) + JanInferenceOpenAIExtension._engineSettings = JSON.parse(await fs.readFile(engine_json)) } else { - await fs.writeFile(engine_json, JSON.stringify(this._engineSettings, null, 2)) + await fs.writeFile(engine_json, JSON.stringify(JanInferenceOpenAIExtension._engineSettings, null, 2)) } } catch (err) { console.error(err) @@ -141,32 +137,34 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { }; return new Promise(async (resolve, reject) => { - requestInference(data.messages ?? []).subscribe({ - next: (_content) => {}, - complete: async () => { - resolve(message); - }, - error: async (err) => { - reject(err); - }, + requestInference(data.messages ?? [], + JanInferenceOpenAIExtension._engineSettings, + JanInferenceOpenAIExtension._currentModel) + .subscribe({ + next: (_content) => {}, + complete: async () => { + resolve(message); + }, + error: async (err) => { + reject(err); + }, }); }); } - private static async handleModelInit(data: Model) { - console.log('Model init success', data) - // Add filter data engine = openai - if (data.engine !== 'openai') { return } - // If model success - events.emit(EventName.OnModelReady, {modelId: data.id}) - // If model failed - // events.emit(EventName.OnModelFail, {modelId: data.id}) + private static async handleModelInit(model: OpenAIModel) { + if (model.engine !== 'openai') { return } + else { + JanInferenceOpenAIExtension._currentModel = model + // Todo: Check model list with API key + events.emit(EventName.OnModelReady, model) + // events.emit(EventName.OnModelFail, model) + } } - private static async handleModelStop(data: Model) { - // Add filter data engine = openai - if (data.engine !== 'openai') { return } - events.emit(EventName.OnModelStop, {modelId: data.id}) + private static async handleModelStop(model: OpenAIModel) { + if (model.engine !== 'openai') { return } + events.emit(EventName.OnModelStopped, model) } /** @@ -179,6 +177,8 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { data: MessageRequest, instance: JanInferenceOpenAIExtension ) { + if (data.model.engine !== 'openai') { return } + const timestamp = Date.now(); const message: ThreadMessage = { id: ulid(), @@ -196,7 +196,12 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { instance.isCancelled = false; instance.controller = new AbortController(); - requestInference(data.messages, instance.controller).subscribe({ + requestInference( + data?.messages ?? [], + this._engineSettings, + JanInferenceOpenAIExtension._currentModel, + instance.controller + ).subscribe({ next: (content) => { const messageContent: ThreadContent = { type: ContentType.Text,