diff --git a/extensions/inference-nitro-extension/rollup.config.ts b/extensions/inference-nitro-extension/rollup.config.ts index 374a054cd..77a9fb208 100644 --- a/extensions/inference-nitro-extension/rollup.config.ts +++ b/extensions/inference-nitro-extension/rollup.config.ts @@ -27,6 +27,9 @@ export default [ TROUBLESHOOTING_URL: JSON.stringify( "https://jan.ai/guides/troubleshooting" ), + JAN_SERVER_INFERENCE_URL: JSON.stringify( + "http://localhost:1337/v1/chat/completions" + ), }), // Allow json resolution json(), diff --git a/extensions/inference-nitro-extension/src/@types/global.d.ts b/extensions/inference-nitro-extension/src/@types/global.d.ts index bc126337f..7a4fb4805 100644 --- a/extensions/inference-nitro-extension/src/@types/global.d.ts +++ b/extensions/inference-nitro-extension/src/@types/global.d.ts @@ -1,6 +1,7 @@ declare const NODE: string; declare const INFERENCE_URL: string; declare const TROUBLESHOOTING_URL: string; +declare const JAN_SERVER_INFERENCE_URL: string; /** * The response from the initModel function. diff --git a/extensions/inference-nitro-extension/src/helpers/sse.ts b/extensions/inference-nitro-extension/src/helpers/sse.ts index c6352383d..aab260828 100644 --- a/extensions/inference-nitro-extension/src/helpers/sse.ts +++ b/extensions/inference-nitro-extension/src/helpers/sse.ts @@ -6,6 +6,7 @@ import { Observable } from "rxjs"; * @returns An Observable that emits the generated response as a string. */ export function requestInference( + inferenceUrl: string, recentMessages: any[], model: Model, controller?: AbortController @@ -17,7 +18,7 @@ export function requestInference( stream: true, ...model.parameters, }); - fetch(INFERENCE_URL, { + fetch(inferenceUrl, { method: "POST", headers: { "Content-Type": "application/json", diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index 9f1f00263..81a0031ac 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -68,35 +68,48 @@ export default class JanInferenceNitroExtension extends InferenceExtension { */ private nitroProcessInfo: any = undefined; + private inferenceUrl = ""; + /** * Subscribes to events emitted by the @janhq/core package. */ async onLoad() { if (!(await fs.existsSync(JanInferenceNitroExtension._homeDir))) { - await fs - .mkdirSync(JanInferenceNitroExtension._homeDir) - .catch((err: Error) => console.debug(err)); + try { + await fs.mkdirSync(JanInferenceNitroExtension._homeDir); + } catch (e) { + console.debug(e); + } } + // init inference url + // @ts-ignore + const electronApi = window?.electronAPI; + this.inferenceUrl = INFERENCE_URL; + if (!electronApi) { + this.inferenceUrl = JAN_SERVER_INFERENCE_URL; + } + console.debug("Inference url: ", this.inferenceUrl); + if (!(await fs.existsSync(JanInferenceNitroExtension._settingsDir))) await fs.mkdirSync(JanInferenceNitroExtension._settingsDir); this.writeDefaultEngineSettings(); // Events subscription events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => - this.onMessageRequest(data), + this.onMessageRequest(data) ); events.on(ModelEvent.OnModelInit, (model: Model) => - this.onModelInit(model), + this.onModelInit(model) ); events.on(ModelEvent.OnModelStop, (model: Model) => - this.onModelStop(model), + this.onModelStop(model) ); events.on(InferenceEvent.OnInferenceStopped, () => - this.onInferenceStopped(), + this.onInferenceStopped() ); // Attempt to fetch nvidia info @@ -121,7 +134,7 @@ export default class JanInferenceNitroExtension extends InferenceExtension { } else { await fs.writeFileSync( engineFile, - JSON.stringify(this._engineSettings, null, 2), + JSON.stringify(this._engineSettings, null, 2) ); } } catch (err) { @@ -149,7 +162,7 @@ export default class JanInferenceNitroExtension extends InferenceExtension { this.getNitroProcesHealthIntervalId = setInterval( () => this.periodicallyGetNitroHealth(), - JanInferenceNitroExtension._intervalHealthCheck, + JanInferenceNitroExtension._intervalHealthCheck ); } @@ -206,7 +219,11 @@ export default class JanInferenceNitroExtension extends InferenceExtension { return new Promise(async (resolve, reject) => { if (!this._currentModel) return Promise.reject("No model loaded"); - requestInference(data.messages ?? [], this._currentModel).subscribe({ + requestInference( + this.inferenceUrl, + data.messages ?? [], + this._currentModel + ).subscribe({ next: (_content: any) => {}, complete: async () => { resolve(message); @@ -254,7 +271,12 @@ export default class JanInferenceNitroExtension extends InferenceExtension { ...(this._currentModel || {}), ...(data.model || {}), }; - requestInference(data.messages ?? [], model, this.controller).subscribe({ + requestInference( + this.inferenceUrl, + data.messages ?? [], + model, + this.controller + ).subscribe({ next: (content: any) => { const messageContent: ThreadContent = { type: ContentType.Text, diff --git a/web/containers/Toast/index.tsx b/web/containers/Toast/index.tsx index 7cffa89b9..eae340fee 100644 --- a/web/containers/Toast/index.tsx +++ b/web/containers/Toast/index.tsx @@ -19,8 +19,8 @@ const ErrorIcon = () => { xmlns="http://www.w3.org/2000/svg" > @@ -38,8 +38,8 @@ const WarningIcon = () => { xmlns="http://www.w3.org/2000/svg" > @@ -57,8 +57,8 @@ const SuccessIcon = () => { xmlns="http://www.w3.org/2000/svg" > @@ -76,8 +76,8 @@ const DefaultIcon = () => { xmlns="http://www.w3.org/2000/svg" >