diff --git a/extensions/inference-nitro-extension/rollup.config.ts b/extensions/inference-nitro-extension/rollup.config.ts
index 374a054cd..77a9fb208 100644
--- a/extensions/inference-nitro-extension/rollup.config.ts
+++ b/extensions/inference-nitro-extension/rollup.config.ts
@@ -27,6 +27,9 @@ export default [
TROUBLESHOOTING_URL: JSON.stringify(
"https://jan.ai/guides/troubleshooting"
),
+ JAN_SERVER_INFERENCE_URL: JSON.stringify(
+ "http://localhost:1337/v1/chat/completions"
+ ),
}),
// Allow json resolution
json(),
diff --git a/extensions/inference-nitro-extension/src/@types/global.d.ts b/extensions/inference-nitro-extension/src/@types/global.d.ts
index bc126337f..7a4fb4805 100644
--- a/extensions/inference-nitro-extension/src/@types/global.d.ts
+++ b/extensions/inference-nitro-extension/src/@types/global.d.ts
@@ -1,6 +1,7 @@
declare const NODE: string;
declare const INFERENCE_URL: string;
declare const TROUBLESHOOTING_URL: string;
+declare const JAN_SERVER_INFERENCE_URL: string;
/**
* The response from the initModel function.
diff --git a/extensions/inference-nitro-extension/src/helpers/sse.ts b/extensions/inference-nitro-extension/src/helpers/sse.ts
index c6352383d..aab260828 100644
--- a/extensions/inference-nitro-extension/src/helpers/sse.ts
+++ b/extensions/inference-nitro-extension/src/helpers/sse.ts
@@ -6,6 +6,7 @@ import { Observable } from "rxjs";
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
+ inferenceUrl: string,
recentMessages: any[],
model: Model,
controller?: AbortController
@@ -17,7 +18,7 @@ export function requestInference(
stream: true,
...model.parameters,
});
- fetch(INFERENCE_URL, {
+ fetch(inferenceUrl, {
method: "POST",
headers: {
"Content-Type": "application/json",
diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts
index 9f1f00263..81a0031ac 100644
--- a/extensions/inference-nitro-extension/src/index.ts
+++ b/extensions/inference-nitro-extension/src/index.ts
@@ -68,35 +68,48 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
*/
private nitroProcessInfo: any = undefined;
+ private inferenceUrl = "";
+
/**
* Subscribes to events emitted by the @janhq/core package.
*/
async onLoad() {
if (!(await fs.existsSync(JanInferenceNitroExtension._homeDir))) {
- await fs
- .mkdirSync(JanInferenceNitroExtension._homeDir)
- .catch((err: Error) => console.debug(err));
+ try {
+ await fs.mkdirSync(JanInferenceNitroExtension._homeDir);
+ } catch (e) {
+ console.debug(e);
+ }
}
+ // init inference url
+ // @ts-ignore
+ const electronApi = window?.electronAPI;
+ this.inferenceUrl = INFERENCE_URL;
+ if (!electronApi) {
+ this.inferenceUrl = JAN_SERVER_INFERENCE_URL;
+ }
+ console.debug("Inference url: ", this.inferenceUrl);
+
if (!(await fs.existsSync(JanInferenceNitroExtension._settingsDir)))
await fs.mkdirSync(JanInferenceNitroExtension._settingsDir);
this.writeDefaultEngineSettings();
// Events subscription
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
- this.onMessageRequest(data),
+ this.onMessageRequest(data)
);
events.on(ModelEvent.OnModelInit, (model: Model) =>
- this.onModelInit(model),
+ this.onModelInit(model)
);
events.on(ModelEvent.OnModelStop, (model: Model) =>
- this.onModelStop(model),
+ this.onModelStop(model)
);
events.on(InferenceEvent.OnInferenceStopped, () =>
- this.onInferenceStopped(),
+ this.onInferenceStopped()
);
// Attempt to fetch nvidia info
@@ -121,7 +134,7 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
} else {
await fs.writeFileSync(
engineFile,
- JSON.stringify(this._engineSettings, null, 2),
+ JSON.stringify(this._engineSettings, null, 2)
);
}
} catch (err) {
@@ -149,7 +162,7 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
this.getNitroProcesHealthIntervalId = setInterval(
() => this.periodicallyGetNitroHealth(),
- JanInferenceNitroExtension._intervalHealthCheck,
+ JanInferenceNitroExtension._intervalHealthCheck
);
}
@@ -206,7 +219,11 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
return new Promise(async (resolve, reject) => {
if (!this._currentModel) return Promise.reject("No model loaded");
- requestInference(data.messages ?? [], this._currentModel).subscribe({
+ requestInference(
+ this.inferenceUrl,
+ data.messages ?? [],
+ this._currentModel
+ ).subscribe({
next: (_content: any) => {},
complete: async () => {
resolve(message);
@@ -254,7 +271,12 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
...(this._currentModel || {}),
...(data.model || {}),
};
- requestInference(data.messages ?? [], model, this.controller).subscribe({
+ requestInference(
+ this.inferenceUrl,
+ data.messages ?? [],
+ model,
+ this.controller
+ ).subscribe({
next: (content: any) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
diff --git a/web/containers/Toast/index.tsx b/web/containers/Toast/index.tsx
index 7cffa89b9..eae340fee 100644
--- a/web/containers/Toast/index.tsx
+++ b/web/containers/Toast/index.tsx
@@ -19,8 +19,8 @@ const ErrorIcon = () => {
xmlns="http://www.w3.org/2000/svg"
>
@@ -38,8 +38,8 @@ const WarningIcon = () => {
xmlns="http://www.w3.org/2000/svg"
>
@@ -57,8 +57,8 @@ const SuccessIcon = () => {
xmlns="http://www.w3.org/2000/svg"
>
@@ -76,8 +76,8 @@ const DefaultIcon = () => {
xmlns="http://www.w3.org/2000/svg"
>