fix: #963 cannot run OpenAI models on Windows (#974)

This commit is contained in:
Louis 2023-12-13 16:26:26 +07:00 committed by GitHub
parent f7c7ad5ecf
commit 3266014b29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 37 deletions

View File

@@ -195,7 +195,10 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
requestInference( requestInference(
data?.messages ?? [], data?.messages ?? [],
this._engineSettings, this._engineSettings,
JanInferenceOpenAIExtension._currentModel, {
...JanInferenceOpenAIExtension._currentModel,
parameters: data.model.parameters,
},
instance.controller instance.controller
).subscribe({ ).subscribe({
next: (content) => { next: (content) => {

View File

@@ -31,14 +31,16 @@ import { EngineSettings } from "./@types/global";
* The class provides methods for initializing and stopping a model, and for making inference requests. * The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/ */
export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension { export default class JanInferenceTritonTrtLLMExtension
private static readonly _homeDir = 'engines' implements InferenceExtension
private static readonly _engineMetadataFileName = 'triton_trtllm.json' {
private static readonly _homeDir = "engines";
private static readonly _engineMetadataFileName = "triton_trtllm.json";
static _currentModel: Model; static _currentModel: Model;
static _engineSettings: EngineSettings = { static _engineSettings: EngineSettings = {
"base_url": "", base_url: "",
}; };
controller = new AbortController(); controller = new AbortController();
@@ -56,8 +58,8 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExten
* Subscribes to events emitted by the @janhq/core package. * Subscribes to events emitted by the @janhq/core package.
*/ */
onLoad(): void { onLoad(): void {
fs.mkdir(JanInferenceTritonTrtLLMExtension._homeDir) fs.mkdir(JanInferenceTritonTrtLLMExtension._homeDir);
JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings() JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings();
// Events subscription // Events subscription
events.on(EventName.OnMessageSent, (data) => events.on(EventName.OnMessageSent, (data) =>
@@ -87,20 +89,31 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExten
modelId: string, modelId: string,
settings?: ModelSettingParams settings?: ModelSettingParams
): Promise<void> { ): Promise<void> {
return return;
} }
static async writeDefaultEngineSettings() { static async writeDefaultEngineSettings() {
try { try {
const engine_json = join(JanInferenceTritonTrtLLMExtension._homeDir, JanInferenceTritonTrtLLMExtension._engineMetadataFileName) const engine_json = join(
JanInferenceTritonTrtLLMExtension._homeDir,
JanInferenceTritonTrtLLMExtension._engineMetadataFileName
);
if (await fs.exists(engine_json)) { if (await fs.exists(engine_json)) {
JanInferenceTritonTrtLLMExtension._engineSettings = JSON.parse(await fs.readFile(engine_json)) JanInferenceTritonTrtLLMExtension._engineSettings = JSON.parse(
} await fs.readFile(engine_json)
else { );
await fs.writeFile(engine_json, JSON.stringify(JanInferenceTritonTrtLLMExtension._engineSettings, null, 2)) } else {
await fs.writeFile(
engine_json,
JSON.stringify(
JanInferenceTritonTrtLLMExtension._engineSettings,
null,
2
)
);
} }
} catch (err) { } catch (err) {
console.error(err) console.error(err);
} }
} }
/** /**
@@ -137,10 +150,11 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExten
}; };
return new Promise(async (resolve, reject) => { return new Promise(async (resolve, reject) => {
requestInference(data.messages ?? [], requestInference(
data.messages ?? [],
JanInferenceTritonTrtLLMExtension._engineSettings, JanInferenceTritonTrtLLMExtension._engineSettings,
JanInferenceTritonTrtLLMExtension._currentModel) JanInferenceTritonTrtLLMExtension._currentModel
.subscribe({ ).subscribe({
next: (_content) => {}, next: (_content) => {},
complete: async () => { complete: async () => {
resolve(message); resolve(message);
@@ -153,19 +167,22 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExten
} }
private static async handleModelInit(model: Model) { private static async handleModelInit(model: Model) {
if (model.engine !== 'triton_trtllm') { return } if (model.engine !== "triton_trtllm") {
else { return;
JanInferenceTritonTrtLLMExtension._currentModel = model } else {
JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings() JanInferenceTritonTrtLLMExtension._currentModel = model;
JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings();
// Todo: Check model list with API key // Todo: Check model list with API key
events.emit(EventName.OnModelReady, model) events.emit(EventName.OnModelReady, model);
// events.emit(EventName.OnModelFail, model) // events.emit(EventName.OnModelFail, model)
} }
} }
private static async handleModelStop(model: Model) { private static async handleModelStop(model: Model) {
if (model.engine !== 'triton_trtllm') { return } if (model.engine !== "triton_trtllm") {
events.emit(EventName.OnModelStopped, model) return;
}
events.emit(EventName.OnModelStopped, model);
} }
/** /**
@@ -178,7 +195,9 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExten
data: MessageRequest, data: MessageRequest,
instance: JanInferenceTritonTrtLLMExtension instance: JanInferenceTritonTrtLLMExtension
) { ) {
if (data.model.engine !== 'triton_trtllm') { return } if (data.model.engine !== "triton_trtllm") {
return;
}
const timestamp = Date.now(); const timestamp = Date.now();
const message: ThreadMessage = { const message: ThreadMessage = {
@@ -200,7 +219,10 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExten
requestInference( requestInference(
data?.messages ?? [], data?.messages ?? [],
this._engineSettings, this._engineSettings,
JanInferenceTritonTrtLLMExtension._currentModel, {
...JanInferenceTritonTrtLLMExtension._currentModel,
parameters: data.model.parameters,
},
instance.controller instance.controller
).subscribe({ ).subscribe({
next: (content) => { next: (content) => {

View File

@@ -85,7 +85,6 @@ export const useCreateNewThread = () => {
created: createdAt, created: createdAt,
updated: createdAt, updated: createdAt,
} }
setThreadModelRuntimeParams(thread.id, assistantInfo.model.parameters) setThreadModelRuntimeParams(thread.id, assistantInfo.model.parameters)
// add the new thread on top of the thread list to the state // add the new thread on top of the thread list to the state