fix: Add dynamic values from engine settings and model params

This commit is contained in:
hiro 2023-12-05 00:45:49 +07:00
parent 0c3e23665b
commit 7ed8c31629

View File

@ -17,20 +17,13 @@ import {
ThreadContent, ThreadContent,
ThreadMessage, ThreadMessage,
events, events,
executeOnMain,
getUserSpace,
fs, fs,
Model,
} from "@janhq/core"; } from "@janhq/core";
import { InferenceExtension } from "@janhq/core"; import { InferenceExtension } from "@janhq/core";
import { requestInference } from "./helpers/sse"; import { requestInference } from "./helpers/sse";
import { ulid } from "ulid"; import { ulid } from "ulid";
import { join } from "path"; import { join } from "path";
import { EngineSettings, OpenAIModel } from "./@types/global";
interface EngineSettings {
base_url?: string;
api_key?: string;
}
/** /**
* A class that implements the InferenceExtension interface from the @janhq/core package. * A class that implements the InferenceExtension interface from the @janhq/core package.
@ -41,12 +34,16 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
private static readonly _homeDir = 'engines' private static readonly _homeDir = 'engines'
private static readonly _engineMetadataFileName = 'openai.json' private static readonly _engineMetadataFileName = 'openai.json'
private _engineSettings: EngineSettings = { static _currentModel: OpenAIModel;
static _engineSettings: EngineSettings = {
"base_url": "https://api.openai.com/v1", "base_url": "https://api.openai.com/v1",
"api_key": "sk-<your key here>" "api_key": "sk-<your key here>"
} };
controller = new AbortController(); controller = new AbortController();
isCancelled = false; isCancelled = false;
/** /**
* Returns the type of the extension. * Returns the type of the extension.
* @returns {ExtensionType} The type of the extension. * @returns {ExtensionType} The type of the extension.
@ -55,7 +52,6 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
type(): ExtensionType { type(): ExtensionType {
return undefined; return undefined;
} }
// janroot/engine/nitro.json
/** /**
* Subscribes to events emitted by the @janhq/core package. * Subscribes to events emitted by the @janhq/core package.
*/ */
@ -68,12 +64,12 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
JanInferenceOpenAIExtension.handleMessageRequest(data, this) JanInferenceOpenAIExtension.handleMessageRequest(data, this)
); );
events.on(EventName.OnModelInit, (data: Model) => { events.on(EventName.OnModelInit, (model: OpenAIModel) => {
JanInferenceOpenAIExtension.handleModelInit(data); JanInferenceOpenAIExtension.handleModelInit(model);
}); });
events.on(EventName.OnModelStop, (data: Model) => { events.on(EventName.OnModelStop, (model: OpenAIModel) => {
JanInferenceOpenAIExtension.handleModelStop(data); JanInferenceOpenAIExtension.handleModelStop(model);
}); });
} }
@ -98,10 +94,10 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
try { try {
const engine_json = join(JanInferenceOpenAIExtension._homeDir, JanInferenceOpenAIExtension._engineMetadataFileName) const engine_json = join(JanInferenceOpenAIExtension._homeDir, JanInferenceOpenAIExtension._engineMetadataFileName)
if (await fs.checkFileExists(engine_json)) { if (await fs.checkFileExists(engine_json)) {
this._engineSettings = JSON.parse(await fs.readFile(engine_json)) JanInferenceOpenAIExtension._engineSettings = JSON.parse(await fs.readFile(engine_json))
} }
else { else {
await fs.writeFile(engine_json, JSON.stringify(this._engineSettings, null, 2)) await fs.writeFile(engine_json, JSON.stringify(JanInferenceOpenAIExtension._engineSettings, null, 2))
} }
} catch (err) { } catch (err) {
console.error(err) console.error(err)
@ -141,32 +137,34 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
}; };
return new Promise(async (resolve, reject) => { return new Promise(async (resolve, reject) => {
requestInference(data.messages ?? []).subscribe({ requestInference(data.messages ?? [],
next: (_content) => {}, JanInferenceOpenAIExtension._engineSettings,
complete: async () => { JanInferenceOpenAIExtension._currentModel)
resolve(message); .subscribe({
}, next: (_content) => {},
error: async (err) => { complete: async () => {
reject(err); resolve(message);
}, },
error: async (err) => {
reject(err);
},
}); });
}); });
} }
private static async handleModelInit(data: Model) { private static async handleModelInit(model: OpenAIModel) {
console.log('Model init success', data) if (model.engine !== 'openai') { return }
// Add filter data engine = openai else {
if (data.engine !== 'openai') { return } JanInferenceOpenAIExtension._currentModel = model
// If model success // Todo: Check model list with API key
events.emit(EventName.OnModelReady, {modelId: data.id}) events.emit(EventName.OnModelReady, model)
// If model failed // events.emit(EventName.OnModelFail, model)
// events.emit(EventName.OnModelFail, {modelId: data.id}) }
} }
private static async handleModelStop(data: Model) { private static async handleModelStop(model: OpenAIModel) {
// Add filter data engine = openai if (model.engine !== 'openai') { return }
if (data.engine !== 'openai') { return } events.emit(EventName.OnModelStopped, model)
events.emit(EventName.OnModelStop, {modelId: data.id})
} }
/** /**
@ -179,6 +177,8 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
data: MessageRequest, data: MessageRequest,
instance: JanInferenceOpenAIExtension instance: JanInferenceOpenAIExtension
) { ) {
if (data.model.engine !== 'openai') { return }
const timestamp = Date.now(); const timestamp = Date.now();
const message: ThreadMessage = { const message: ThreadMessage = {
id: ulid(), id: ulid(),
@ -196,7 +196,12 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
instance.isCancelled = false; instance.isCancelled = false;
instance.controller = new AbortController(); instance.controller = new AbortController();
requestInference(data.messages, instance.controller).subscribe({ requestInference(
data?.messages ?? [],
this._engineSettings,
JanInferenceOpenAIExtension._currentModel,
instance.controller
).subscribe({
next: (content) => { next: (content) => {
const messageContent: ThreadContent = { const messageContent: ThreadContent = {
type: ContentType.Text, type: ContentType.Text,