chore: Update openai engine

hiro 2023-12-01 18:14:29 +07:00
parent 337da50840
commit d69f0e3321
3 changed files with 127 additions and 32 deletions

View File

@@ -0,0 +1,56 @@
import { Observable } from "rxjs";

// INFERENCE_URL is defined elsewhere in this module (its definition is not
// captured in this diff view); declared here so the file type-checks standalone.
declare const INFERENCE_URL: string;

/**
 * Sends a request to the inference server to generate a response based on the recent messages.
 * @param recentMessages - An array of recent messages to use as context for the inference.
 * @param controller - An optional AbortController used to cancel the in-flight request.
 * @returns An Observable that emits the accumulated response as a string.
 */
export function requestInference(
  recentMessages: any[],
  controller?: AbortController
): Observable<string> {
  return new Observable((subscriber) => {
    const requestBody = JSON.stringify({
      messages: recentMessages,
      stream: true,
      model: "gpt-3.5-turbo",
      max_tokens: 2048,
    });
    fetch(INFERENCE_URL, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Accept: "text/event-stream",
        "Access-Control-Allow-Origin": "*",
      },
      body: requestBody,
      signal: controller?.signal,
    })
      .then(async (response) => {
        const stream = response.body;
        const decoder = new TextDecoder("utf-8");
        const reader = stream?.getReader();
        let content = "";

        while (reader) {
          const { done, value } = await reader.read();
          if (done) {
            break;
          }
          const text = decoder.decode(value);
          const lines = text.trim().split("\n");
          for (const line of lines) {
            if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
              const data = JSON.parse(line.replace("data: ", ""));
              content += data.choices[0]?.delta?.content ?? "";
              if (content.startsWith("assistant: ")) {
                content = content.replace("assistant: ", "");
              }
              subscriber.next(content);
            }
          }
        }
        subscriber.complete();
      })
      .catch((err) => subscriber.error(err));
  });
}
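
For reference, a minimal usage sketch of the helper above. The message shape follows the OpenAI chat format the request body already uses; the subscriber callbacks are illustrative and not part of this commit:

import { requestInference } from "./helpers/sse";

// Stream a completion; each emission is the accumulated text so far.
const controller = new AbortController();
requestInference([{ role: "user", content: "Hello!" }], controller).subscribe({
  next: (content) => console.log(content),
  complete: () => console.log("stream finished"),
  error: (err) => console.error(err),
});

// Cancel mid-stream if needed:
// controller.abort();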

View File

@@ -3,7 +3,7 @@
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 * @version 1.0.0
- * @module inference-extension/src/index
+ * @module inference-openai-extension/src/index
 */
import {

@@ -19,6 +19,7 @@ import {
  events,
  executeOnMain,
  getUserSpace,
+ fs
} from "@janhq/core";
import { InferenceExtension } from "@janhq/core";
import { requestInference } from "./helpers/sse";

@@ -31,20 +32,26 @@ import { join } from "path";
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 */
export default class JanInferenceExtension implements InferenceExtension {
+  private static readonly _homeDir = 'engines'
+  private static readonly _engineMetadataFileName = 'openai.json'
  controller = new AbortController();
  isCancelled = false;

  /**
   * Returns the type of the extension.
   * @returns {ExtensionType} The type of the extension.
   */
+  // TODO: To fix
  type(): ExtensionType {
-    return ExtensionType.Inference;
+    return undefined;
  }

+  // janroot/engine/nitro.json
  /**
   * Subscribes to events emitted by the @janhq/core package.
   */
  onLoad(): void {
+    fs.mkdir(JanInferenceExtension._homeDir)
+    // TODO: Copy nitro.json to janroot/engine/nitro.json
    events.on(EventName.OnMessageSent, (data) =>
      JanInferenceExtension.handleMessageRequest(data, this)
    );

@@ -53,9 +60,7 @@ export default class JanInferenceExtension implements InferenceExtension {
  /**
   * Stops the model inference.
   */
-  onUnload(): void {
-    this.stopModel();
-  }
+  onUnload(): void {}

  /**
   * Initializes the model with the specified file name.

@@ -79,9 +84,7 @@ export default class JanInferenceExtension implements InferenceExtension {
   * Stops the model.
   * @returns {Promise<void>} A promise that resolves when the model is stopped.
   */
-  async stopModel(): Promise<void> {
-    return executeOnMain(MODULE, "killSubprocess");
-  }
+  async stopModel(): Promise<void> {}

  /**
   * Stops streaming inference.

@@ -92,35 +95,37 @@ export default class JanInferenceExtension implements InferenceExtension {
    this.controller?.abort();
  }

+  private async copyModelsToHomeDir() {
+    try {
+      // list all of the files under the home directory
+      const files = await fs.listFiles('')
+      if (files.includes(JanInferenceExtension._homeDir)) {
+        // ignore if the model is already downloaded
+        console.debug('Model already downloaded')
+        return
+      }
+
+      // copy models folder from resources to home directory
+      const resourePath = await getResourcePath()
+      const srcPath = join(resourePath, 'models')
+      const userSpace = await getUserSpace()
+      const destPath = join(userSpace, JanInferenceExtension._homeDir)
+      await fs.copyFile(srcPath, destPath)
+    } catch (err) {
+      console.error(err)
+    }
+  }

  /**
   * Makes a single response inference request.
   * @param {MessageRequest} data - The data for the inference request.
   * @returns {Promise<any>} A promise that resolves with the inference response.
   */
  async inferenceRequest(data: MessageRequest): Promise<ThreadMessage> {
    const timestamp = Date.now(); // TODO: @louis
    const message: ThreadMessage = {
      thread_id: data.threadId,
      created: timestamp,
      updated: timestamp,
      status: MessageStatus.Ready,
      id: "",
      role: ChatCompletionRole.Assistant,
      object: "thread.message",
      content: [],
    };

    return new Promise(async (resolve, reject) => {
      requestInference(data.messages ?? []).subscribe({
        next: (_content) => {},
        complete: async () => {
          resolve(message);
        },
        error: async (err) => {
          reject(err);
        },
      });
    });
  }

  /**

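As committed, the reworked inferenceRequest discards streamed tokens (next: (_content) => {}) and resolves a message whose content stays empty. A sketch of how the subscriber might fold the stream into the message; the text-content shape is an assumption about @janhq/core's ThreadMessage type and is not part of this commit:

// Hypothetical subscriber body (assumed content shape; not in this commit):
requestInference(data.messages ?? [], this.controller).subscribe({
  next: (content) => {
    // keep the latest accumulated text on the message
    message.content = [
      { type: "text", text: { value: content, annotations: [] } },
    ];
  },
  complete: () => resolve(message),
  error: (err) => reject(err),
});
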
View File

@@ -0,0 +1,34 @@
const fetchRetry = require("fetch-retry")(global.fetch);
const log = require("electron-log");

const OPENAI_BASE_URL = "https://api.openai.com/v1";
const OPENAI_API_KEY = process.env.OPENAI_API_KEY;

/**
 * The response from the initModel function.
 * @property error - An error message if the model fails to load.
 */
interface InitModelResponse {
  error?: any;
  modelFile?: string;
}

// /root/engine/nitro.json
/**
 * Initializes a Nitro subprocess to load a machine learning model.
 * @param modelFile - The name of the machine learning model file.
 * @returns A Promise that resolves when the model is loaded successfully, or rejects with an error message if the model is not found or fails to load.
 */
function initModel(wrapper: any): Promise<InitModelResponse> {
  const engine_settings = {
    ...wrapper.settings,
  };
  // TODO: wire engine_settings into the OpenAI engine; resolve empty for now
  // so the module compiles.
  log.debug("initModel settings:", engine_settings);
  return Promise.resolve({});
}

module.exports = {
  initModel,
};
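
Since OPENAI_BASE_URL and OPENAI_API_KEY are declared but not yet used, a hedged sketch of one way the finished initModel could use them: checking the key against OpenAI's GET /v1/models endpoint. The helper name and retry values are illustrative assumptions, not part of this commit:

// Hypothetical helper: verify the API key before reporting success.
// retries/retryDelay are standard fetch-retry options.
function validateApiKey(): Promise<InitModelResponse> {
  return fetchRetry(`${OPENAI_BASE_URL}/models`, {
    headers: { Authorization: `Bearer ${OPENAI_API_KEY}` },
    retries: 3,
    retryDelay: 500,
  })
    .then((res: any) => (res.ok ? {} : { error: `HTTP ${res.status}` }))
    .catch((err: any) => ({ error: err }));
}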