refactor: remove hardcoded provider names (#4995)

* refactor: remove hardcoded provider names

* chore: continue the replacement
This commit is contained in:
Louis 2025-05-15 22:10:43 +07:00 committed by GitHub
parent bf3f22c854
commit e9f37e98d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 49 additions and 114 deletions

View File

@ -1,4 +1,3 @@
import { InferenceEngine } from '../../../types'
import { AIEngine } from './AIEngine' import { AIEngine } from './AIEngine'
/** /**
@ -22,22 +21,6 @@ export class EngineManager {
* @returns The engine, if found. * @returns The engine, if found.
*/ */
get<T extends AIEngine>(provider: string): T | undefined { get<T extends AIEngine>(provider: string): T | undefined {
// Backward compatible provider
// nitro is migrated to cortex
if (
[
InferenceEngine.nitro,
InferenceEngine.cortex,
InferenceEngine.cortex_llamacpp,
InferenceEngine.cortex_onnx,
InferenceEngine.cortex_tensorrtllm,
InferenceEngine.cortex_onnx,
]
.map((e) => e.toString())
.includes(provider)
)
provider = InferenceEngine.cortex
return this.engines.get(provider) as T | undefined return this.engines.get(provider) as T | undefined
} }

View File

@ -1,5 +1,4 @@
import { import {
InferenceEngine,
Engines, Engines,
EngineVariant, EngineVariant,
EngineReleased, EngineReleased,
@ -28,7 +27,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @param name - Inference engine name. * @param name - Inference engine name.
* @returns A Promise that resolves to an array of installed engine. * @returns A Promise that resolves to an array of installed engine.
*/ */
abstract getInstalledEngines(name: InferenceEngine): Promise<EngineVariant[]> abstract getInstalledEngines(name: string): Promise<EngineVariant[]>
/** /**
* @param name - Inference engine name. * @param name - Inference engine name.
@ -37,7 +36,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves to an array of latest released engine by version. * @returns A Promise that resolves to an array of latest released engine by version.
*/ */
abstract getReleasedEnginesByVersion( abstract getReleasedEnginesByVersion(
name: InferenceEngine, name: string,
version: string, version: string,
platform?: string platform?: string
): Promise<EngineReleased[]> ): Promise<EngineReleased[]>
@ -48,7 +47,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves to an array of latest released engine. * @returns A Promise that resolves to an array of latest released engine.
*/ */
abstract getLatestReleasedEngine( abstract getLatestReleasedEngine(
name: InferenceEngine, name: string,
platform?: string platform?: string
): Promise<EngineReleased[]> ): Promise<EngineReleased[]>
@ -74,7 +73,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves to unintall of engine. * @returns A Promise that resolves to unintall of engine.
*/ */
abstract uninstallEngine( abstract uninstallEngine(
name: InferenceEngine, name: string,
engineConfig: EngineConfig engineConfig: EngineConfig
): Promise<{ messages: string }> ): Promise<{ messages: string }>
@ -83,7 +82,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves to an object of default engine. * @returns A Promise that resolves to an object of default engine.
*/ */
abstract getDefaultEngineVariant( abstract getDefaultEngineVariant(
name: InferenceEngine name: string
): Promise<DefaultEngineVariant> ): Promise<DefaultEngineVariant>
/** /**
@ -92,7 +91,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves to set default engine. * @returns A Promise that resolves to set default engine.
*/ */
abstract setDefaultEngineVariant( abstract setDefaultEngineVariant(
name: InferenceEngine, name: string,
engineConfig: EngineConfig engineConfig: EngineConfig
): Promise<{ messages: string }> ): Promise<{ messages: string }>
@ -100,7 +99,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves to update engine. * @returns A Promise that resolves to update engine.
*/ */
abstract updateEngine( abstract updateEngine(
name: InferenceEngine, name: string,
engineConfig?: EngineConfig engineConfig?: EngineConfig
): Promise<{ messages: string }> ): Promise<{ messages: string }>
@ -112,5 +111,5 @@ export abstract class EngineManagementExtension extends BaseExtension {
/** /**
* @returns A Promise that resolves to an object of remote models list . * @returns A Promise that resolves to an object of remote models list .
*/ */
abstract getRemoteModels(name: InferenceEngine | string): Promise<any> abstract getRemoteModels(name: string): Promise<any>
} }

View File

@ -1,7 +1,5 @@
import { InferenceEngine } from '../../types'
export type Engines = { export type Engines = {
[key in InferenceEngine]: (EngineVariant & EngineConfig)[] [key: string]: (EngineVariant & EngineConfig)[]
} }
export type EngineMetadata = { export type EngineMetadata = {
@ -22,13 +20,13 @@ export type EngineMetadata = {
} }
export type EngineVariant = { export type EngineVariant = {
engine: InferenceEngine engine: string
name: string name: string
version: string version: string
} }
export type DefaultEngineVariant = { export type DefaultEngineVariant = {
engine: InferenceEngine engine: string
variant: string variant: string
version: string version: string
} }

View File

@ -6,29 +6,7 @@ export type ModelInfo = {
id: string id: string
settings?: ModelSettingParams settings?: ModelSettingParams
parameters?: ModelRuntimeParams parameters?: ModelRuntimeParams
engine?: InferenceEngine engine?: string
}
/**
* Represents the inference engine.
* @stored
*/
export enum InferenceEngine {
anthropic = 'anthropic',
mistral = 'mistral',
martian = 'martian',
openrouter = 'openrouter',
nitro = 'nitro',
openai = 'openai',
groq = 'groq',
triton_trtllm = 'triton_trtllm',
nitro_tensorrt_llm = 'nitro-tensorrt-llm',
cohere = 'cohere',
nvidia = 'nvidia',
cortex = 'cortex',
cortex_llamacpp = 'llama-cpp',
cortex_onnx = 'onnxruntime',
cortex_tensorrtllm = 'tensorrt-llm',
} }
// Represents an artifact of a model, including its filename and URL // Represents an artifact of a model, including its filename and URL
@ -105,7 +83,7 @@ export type Model = {
/** /**
* The model engine. * The model engine.
*/ */
engine: InferenceEngine engine: string
} }
// Represents metadata associated with a model // Represents metadata associated with a model

View File

@ -1,6 +1,5 @@
import { import {
EngineManagementExtension, EngineManagementExtension,
InferenceEngine,
DefaultEngineVariant, DefaultEngineVariant,
Engines, Engines,
EngineConfig, EngineConfig,
@ -35,7 +34,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
*/ */
async apiInstance(): Promise<KyInstance> { async apiInstance(): Promise<KyInstance> {
if (this.api) return this.api if (this.api) return this.api
const apiKey = (await window.core?.api.appToken()) const apiKey = await window.core?.api.appToken()
this.api = ky.extend({ this.api = ky.extend({
prefixUrl: API_URL, prefixUrl: API_URL,
headers: apiKey headers: apiKey
@ -96,7 +95,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @param name - Inference engine name. * @param name - Inference engine name.
* @returns A Promise that resolves to an array of installed engine. * @returns A Promise that resolves to an array of installed engine.
*/ */
async getInstalledEngines(name: InferenceEngine): Promise<EngineVariant[]> { async getInstalledEngines(name: string): Promise<EngineVariant[]> {
return this.apiInstance().then((api) => return this.apiInstance().then((api) =>
api api
.get(`v1/engines/${name}`) .get(`v1/engines/${name}`)
@ -112,7 +111,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @returns A Promise that resolves to an array of latest released engine by version. * @returns A Promise that resolves to an array of latest released engine by version.
*/ */
async getReleasedEnginesByVersion( async getReleasedEnginesByVersion(
name: InferenceEngine, name: string,
version: string, version: string,
platform?: string platform?: string
) { ) {
@ -131,7 +130,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @param platform - Optional to sort by operating system. macOS, linux, windows. * @param platform - Optional to sort by operating system. macOS, linux, windows.
* @returns A Promise that resolves to an array of latest released engine by version. * @returns A Promise that resolves to an array of latest released engine by version.
*/ */
async getLatestReleasedEngine(name: InferenceEngine, platform?: string) { async getLatestReleasedEngine(name: string, platform?: string) {
return this.apiInstance().then((api) => return this.apiInstance().then((api) =>
api api
.get(`v1/engines/${name}/releases/latest`) .get(`v1/engines/${name}/releases/latest`)
@ -197,7 +196,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @param name - Inference engine name. * @param name - Inference engine name.
* @returns A Promise that resolves to unintall of engine. * @returns A Promise that resolves to unintall of engine.
*/ */
async uninstallEngine(name: InferenceEngine, engineConfig: EngineConfig) { async uninstallEngine(name: string, engineConfig: EngineConfig) {
return this.apiInstance().then((api) => return this.apiInstance().then((api) =>
api api
.delete(`v1/engines/${name}/install`, { json: engineConfig }) .delete(`v1/engines/${name}/install`, { json: engineConfig })
@ -234,7 +233,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @param name - Inference engine name. * @param name - Inference engine name.
* @returns A Promise that resolves to an object of default engine. * @returns A Promise that resolves to an object of default engine.
*/ */
async getDefaultEngineVariant(name: InferenceEngine) { async getDefaultEngineVariant(name: string) {
return this.apiInstance().then((api) => return this.apiInstance().then((api) =>
api api
.get(`v1/engines/${name}/default`) .get(`v1/engines/${name}/default`)
@ -248,10 +247,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @body version - string * @body version - string
* @returns A Promise that resolves to set default engine. * @returns A Promise that resolves to set default engine.
*/ */
async setDefaultEngineVariant( async setDefaultEngineVariant(name: string, engineConfig: EngineConfig) {
name: InferenceEngine,
engineConfig: EngineConfig
) {
return this.apiInstance().then((api) => return this.apiInstance().then((api) =>
api api
.post(`v1/engines/${name}/default`, { json: engineConfig }) .post(`v1/engines/${name}/default`, { json: engineConfig })
@ -262,7 +258,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
/** /**
* @returns A Promise that resolves to update engine. * @returns A Promise that resolves to update engine.
*/ */
async updateEngine(name: InferenceEngine, engineConfig?: EngineConfig) { async updateEngine(name: string, engineConfig?: EngineConfig) {
return this.apiInstance().then((api) => return this.apiInstance().then((api) =>
api api
.post(`v1/engines/${name}/update`, { json: engineConfig }) .post(`v1/engines/${name}/update`, { json: engineConfig })
@ -276,12 +272,8 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
*/ */
async updateDefaultEngine() { async updateDefaultEngine() {
try { try {
const variant = await this.getDefaultEngineVariant( const variant = await this.getDefaultEngineVariant('llama-cpp')
InferenceEngine.cortex_llamacpp const installedEngines = await this.getInstalledEngines('llama-cpp')
)
const installedEngines = await this.getInstalledEngines(
InferenceEngine.cortex_llamacpp
)
if ( if (
!installedEngines.some( !installedEngines.some(
(e) => e.name === variant.variant && e.version === variant.version (e) => e.name === variant.variant && e.version === variant.version
@ -299,7 +291,8 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
) { ) {
const systemInfo = await systemInformation() const systemInfo = await systemInformation()
const variant = await engineVariant(systemInfo.gpuSetting) const variant = await engineVariant(systemInfo.gpuSetting)
await this.setDefaultEngineVariant(InferenceEngine.cortex_llamacpp, { // TODO: Use correct provider name when moving to llama.cpp extension
await this.setDefaultEngineVariant('llama-cpp', {
variant: variant, variant: variant,
version: `${CORTEX_ENGINE_VERSION}`, version: `${CORTEX_ENGINE_VERSION}`,
}) })
@ -368,7 +361,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
models.data.map((model) => models.data.map((model) =>
this.addRemoteModel({ this.addRemoteModel({
...model, ...model,
engine: engineConfig.engine as InferenceEngine, engine: engineConfig.engine,
model: model.model ?? model.id, model: model.model ?? model.id,
}).catch(console.info) }).catch(console.info)
) )

View File

@ -11,7 +11,6 @@ import {
executeOnMain, executeOnMain,
EngineEvent, EngineEvent,
LocalOAIEngine, LocalOAIEngine,
InferenceEngine,
extractModelLoadParams, extractModelLoadParams,
events, events,
ModelEvent, ModelEvent,
@ -49,7 +48,7 @@ type LoadedModelResponse = { data: { engine: string; id: string }[] }
export default class JanInferenceCortexExtension extends LocalOAIEngine { export default class JanInferenceCortexExtension extends LocalOAIEngine {
nodeModule: string = 'node' nodeModule: string = 'node'
provider: string = InferenceEngine.cortex provider: string = 'cortex'
shouldReconnect = true shouldReconnect = true
@ -198,8 +197,8 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
...extractModelLoadParams(model.settings), ...extractModelLoadParams(model.settings),
model: model.id, model: model.id,
engine: engine:
model.engine === InferenceEngine.nitro // Legacy model cache model.engine === "nitro" // Legacy model cache
? InferenceEngine.cortex_llamacpp ? "llama-cpp"
: model.engine, : model.engine,
cont_batching: this.cont_batching, cont_batching: this.cont_batching,
n_parallel: this.n_parallel, n_parallel: this.n_parallel,

View File

@ -1,7 +1,6 @@
import { import {
ModelExtension, ModelExtension,
Model, Model,
InferenceEngine,
joinPath, joinPath,
dirName, dirName,
fs, fs,
@ -37,7 +36,7 @@ export default class JanModelExtension extends ModelExtension {
*/ */
async apiInstance(): Promise<KyInstance> { async apiInstance(): Promise<KyInstance> {
if (this.api) return this.api if (this.api) return this.api
const apiKey = (await window.core?.api.appToken()) const apiKey = await window.core?.api.appToken()
this.api = ky.extend({ this.api = ky.extend({
prefixUrl: CORTEX_API_URL, prefixUrl: CORTEX_API_URL,
headers: apiKey headers: apiKey
@ -45,7 +44,7 @@ export default class JanModelExtension extends ModelExtension {
Authorization: `Bearer ${apiKey}`, Authorization: `Bearer ${apiKey}`,
} }
: {}, : {},
retry: 10 retry: 10,
}) })
return this.api return this.api
} }
@ -153,9 +152,7 @@ export default class JanModelExtension extends ModelExtension {
* Here we are filtering out the models that are not imported * Here we are filtering out the models that are not imported
* and are not using llama.cpp engine * and are not using llama.cpp engine
*/ */
var toImportModels = legacyModels.filter( var toImportModels = legacyModels.filter((e) => e.engine === 'nitro')
(e) => e.engine === InferenceEngine.nitro
)
/** /**
* Fetch models from cortex.cpp * Fetch models from cortex.cpp

View File

@ -1,13 +1,5 @@
import { InferenceEngine, Model, fs, joinPath } from '@janhq/core' import { Model, fs, joinPath } from '@janhq/core'
//// LEGACY MODEL FOLDER //// //// LEGACY MODEL FOLDER ////
const LocalEngines = [
InferenceEngine.cortex,
InferenceEngine.cortex_llamacpp,
InferenceEngine.cortex_tensorrtllm,
InferenceEngine.cortex_onnx,
InferenceEngine.nitro_tensorrt_llm,
InferenceEngine.nitro,
]
/** /**
* Scan through models folder and return downloaded models * Scan through models folder and return downloaded models
* @returns * @returns
@ -68,7 +60,7 @@ export const scanModelsFolder = async (): Promise<
) )
) )
if ( if (
!LocalEngines.includes(model.engine) || !['cortex', 'llama-cpp', 'nitro'].includes(model.engine) ||
existFiles.every((exist) => exist) existFiles.every((exist) => exist)
) )
return model return model
@ -86,9 +78,9 @@ export const scanModelsFolder = async (): Promise<
file.toLowerCase().endsWith('.engine') // Tensort-LLM file.toLowerCase().endsWith('.engine') // Tensort-LLM
) )
})?.length >= })?.length >=
(model.engine === InferenceEngine.nitro_tensorrt_llm (model.engine === 'nitro-tensorrt-llm'
? 1 ? 1
: (model.sources?.length ?? 1)) : model.sources?.length ?? 1)
) )
}) })

View File

@ -1,6 +1,6 @@
import { models as providerModels } from 'token.js' import { models as providerModels } from 'token.js'
import { mockModelProvider } from '@/mock/data' import { mockModelProvider } from '@/mock/data'
import { EngineManager, InferenceEngine, ModelManager } from '@janhq/core' import { EngineManager, ModelManager } from '@janhq/core'
import { ModelCapabilities } from '@/types/models' import { ModelCapabilities } from '@/types/models'
import { modelSettings } from '@/lib/predefined' import { modelSettings } from '@/lib/predefined'
@ -40,7 +40,8 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
const runtimeProviders: ModelProvider[] = [] const runtimeProviders: ModelProvider[] = []
for (const [key, value] of EngineManager.instance().engines) { for (const [key, value] of EngineManager.instance().engines) {
const providerName = key === InferenceEngine.cortex ? 'llama.cpp' : key // TODO: Remove this when the cortex extension is removed
const providerName = key === 'cortex' ? 'llama.cpp' : key
const models = const models =
Array.from(ModelManager.instance().models.values()).filter( Array.from(ModelManager.instance().models.values()).filter(
(model) => (model) =>
@ -70,7 +71,7 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
description: model.description, description: model.description,
capabilities: capabilities:
'capabilities' in model 'capabilities' in model
? model.capabilities as string[] ? (model.capabilities as string[])
: [ModelCapabilities.COMPLETION], : [ModelCapabilities.COMPLETION],
provider: providerName, provider: providerName,
settings: modelSettings, settings: modelSettings,

View File

@ -1,9 +1,5 @@
import { ExtensionManager } from '@/lib/extension' import { ExtensionManager } from '@/lib/extension'
import { import { ConversationalExtension, ExtensionTypeEnum } from '@janhq/core'
ConversationalExtension,
ExtensionTypeEnum,
InferenceEngine,
} from '@janhq/core'
/** /**
* Fetches all threads from the conversational extension. * Fetches all threads from the conversational extension.
@ -52,16 +48,15 @@ export const createThread = async (thread: Thread): Promise<Thread> => {
{ {
model: { model: {
id: thread.model?.id ?? '*', id: thread.model?.id ?? '*',
engine: (thread.model?.provider ?? engine: thread.model?.provider ?? 'llama.cpp',
'llama.cpp') as InferenceEngine,
}, },
assistant_id: 'jan', assistant_id: 'jan',
assistant_name: 'Jan', assistant_name: 'Jan',
}, },
], ],
metadata: { metadata: {
order: 1 order: 1,
} },
}) })
.then((e) => { .then((e) => {
return { return {
@ -71,7 +66,7 @@ export const createThread = async (thread: Thread): Promise<Thread> => {
id: e.assistants?.[0]?.model.id, id: e.assistants?.[0]?.model.id,
provider: e.assistants?.[0]?.model.engine, provider: e.assistants?.[0]?.model.engine,
}, },
order: 1 order: 1,
} as Thread } as Thread
}) })
.catch(() => thread) ?? thread .catch(() => thread) ?? thread
@ -91,7 +86,7 @@ export const updateThread = (thread: Thread) => {
{ {
model: { model: {
id: thread.model?.id ?? '*', id: thread.model?.id ?? '*',
engine: (thread.model?.provider ?? 'llama.cpp') as InferenceEngine, engine: (thread.model?.provider ?? 'llama.cpp'),
}, },
assistant_id: 'jan', assistant_id: 'jan',
assistant_name: 'Jan', assistant_name: 'Jan',
@ -99,11 +94,11 @@ export const updateThread = (thread: Thread) => {
], ],
metadata: { metadata: {
is_favorite: thread.isFavorite, is_favorite: thread.isFavorite,
order: thread.order order: thread.order,
}, },
object: 'thread', object: 'thread',
created: Date.now()/ 1000, created: Date.now() / 1000,
updated: Date.now()/ 1000, updated: Date.now() / 1000,
}) })
} }