refactor: remove hardcoded provider names (#4995)

* refactor: remove hardcoded provider names

* chore: continue the replacement
Louis 2025-05-15 22:10:43 +07:00 committed by GitHub
parent bf3f22c854
commit e9f37e98d1
10 changed files with 49 additions and 114 deletions

View File

@ -1,4 +1,3 @@
import { InferenceEngine } from '../../../types'
import { AIEngine } from './AIEngine'
/**
@ -22,22 +21,6 @@ export class EngineManager {
* @returns The engine, if found.
*/
get<T extends AIEngine>(provider: string): T | undefined {
// Backward compatible provider
// nitro is migrated to cortex
if (
[
InferenceEngine.nitro,
InferenceEngine.cortex,
InferenceEngine.cortex_llamacpp,
InferenceEngine.cortex_onnx,
InferenceEngine.cortex_tensorrtllm,
InferenceEngine.cortex_onnx,
]
.map((e) => e.toString())
.includes(provider)
)
provider = InferenceEngine.cortex
return this.engines.get(provider) as T | undefined
}
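
With the backward-compatibility shim removed, `get` is a plain map lookup, so any legacy alias handling has to happen before the call. A minimal sketch of a caller-side mapping, assuming `AIEngine` and `EngineManager` are exported from `@janhq/core` (the alias table and helper are illustrative, not part of this commit):

import { AIEngine, EngineManager } from '@janhq/core'

// Hypothetical alias table: EngineManager.get no longer rewrites legacy
// names, so 'nitro' must be mapped to 'cortex' by the caller if old data
// can still appear.
const LEGACY_ALIASES: Record<string, string> = { nitro: 'cortex' }

function resolveEngine<T extends AIEngine>(provider: string): T | undefined {
  const name = LEGACY_ALIASES[provider] ?? provider
  return EngineManager.instance().get<T>(name)
}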

View File

@ -1,5 +1,4 @@
import {
InferenceEngine,
Engines,
EngineVariant,
EngineReleased,
@ -28,7 +27,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @param name - Inference engine name.
* @returns A Promise that resolves to an array of installed engines.
*/
abstract getInstalledEngines(name: InferenceEngine): Promise<EngineVariant[]>
abstract getInstalledEngines(name: string): Promise<EngineVariant[]>
/**
* @param name - Inference engine name.
@ -37,7 +36,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves to an array of the latest released engines for a version.
*/
abstract getReleasedEnginesByVersion(
name: InferenceEngine,
name: string,
version: string,
platform?: string
): Promise<EngineReleased[]>
@ -48,7 +47,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves to an array of the latest released engines.
*/
abstract getLatestReleasedEngine(
name: InferenceEngine,
name: string,
platform?: string
): Promise<EngineReleased[]>
@ -74,7 +73,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves once the engine is uninstalled.
*/
abstract uninstallEngine(
name: InferenceEngine,
name: string,
engineConfig: EngineConfig
): Promise<{ messages: string }>
@ -83,7 +82,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves to the default engine variant.
*/
abstract getDefaultEngineVariant(
name: InferenceEngine
name: string
): Promise<DefaultEngineVariant>
/**
@ -92,7 +91,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves once the default engine is set.
*/
abstract setDefaultEngineVariant(
name: InferenceEngine,
name: string,
engineConfig: EngineConfig
): Promise<{ messages: string }>
@ -100,7 +99,7 @@ export abstract class EngineManagementExtension extends BaseExtension {
* @returns A Promise that resolves once the engine is updated.
*/
abstract updateEngine(
name: InferenceEngine,
name: string,
engineConfig?: EngineConfig
): Promise<{ messages: string }>
@ -112,5 +111,5 @@ export abstract class EngineManagementExtension extends BaseExtension {
/**
* @returns A Promise that resolves to the remote models list.
*/
abstract getRemoteModels(name: InferenceEngine | string): Promise<any>
abstract getRemoteModels(name: string): Promise<any>
}
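
Since every abstract method now takes a plain string, a caller no longer needs an enum member to name an engine. A hedged sketch of the new calling convention, assuming `EngineManagementExtension` and `EngineVariant` are exported from `@janhq/core` (the helper function is illustrative):

import { EngineManagementExtension, EngineVariant } from '@janhq/core'

// Any provider string is accepted; no enum membership is required.
// 'llama-cpp' is the name this commit uses for the bundled engine.
async function listVariants(
  ext: EngineManagementExtension,
  engine: string = 'llama-cpp'
): Promise<EngineVariant[]> {
  return ext.getInstalledEngines(engine)
}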

View File

@ -1,7 +1,5 @@
import { InferenceEngine } from '../../types'
export type Engines = {
[key in InferenceEngine]: (EngineVariant & EngineConfig)[]
[key: string]: (EngineVariant & EngineConfig)[]
}
export type EngineMetadata = {
@ -22,13 +20,13 @@ export type EngineMetadata = {
}
export type EngineVariant = {
engine: InferenceEngine
engine: string
name: string
version: string
}
export type DefaultEngineVariant = {
engine: InferenceEngine
engine: string
variant: string
version: string
}
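
With the enum key replaced by a string index signature, third-party provider names type-check without touching core types. A small sketch under that assumption (the keys, variant name, and version below are illustrative):

import { DefaultEngineVariant, Engines } from '@janhq/core'

// Any string key is now valid at compile time.
const engines: Engines = {
  'llama-cpp': [],
  'my-remote-provider': [], // hypothetical third-party key
}

const fallback: DefaultEngineVariant = {
  engine: 'llama-cpp',
  variant: 'linux-amd64-avx2', // illustrative variant name
  version: '1.0.0', // illustrative version
}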

View File

@ -6,29 +6,7 @@ export type ModelInfo = {
id: string
settings?: ModelSettingParams
parameters?: ModelRuntimeParams
engine?: InferenceEngine
}
/**
* Represents the inference engine.
* @stored
*/
export enum InferenceEngine {
anthropic = 'anthropic',
mistral = 'mistral',
martian = 'martian',
openrouter = 'openrouter',
nitro = 'nitro',
openai = 'openai',
groq = 'groq',
triton_trtllm = 'triton_trtllm',
nitro_tensorrt_llm = 'nitro-tensorrt-llm',
cohere = 'cohere',
nvidia = 'nvidia',
cortex = 'cortex',
cortex_llamacpp = 'llama-cpp',
cortex_onnx = 'onnxruntime',
cortex_tensorrtllm = 'tensorrt-llm',
engine?: string
}
// Represents an artifact of a model, including its filename and URL
@ -105,7 +83,7 @@ export type Model = {
/**
* The model engine.
*/
engine: InferenceEngine
engine: string
}
// Represents metadata associated with a model
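
After the enum removal, `engine` is a plain string on both `ModelInfo` and `Model`, so unknown providers no longer require a core change. A minimal sketch, assuming `ModelInfo` is exported from `@janhq/core` (the id and engine values are illustrative):

import { ModelInfo } from '@janhq/core'

// `engine` is a free-form string now; 'llama-cpp' replaces the old
// InferenceEngine.cortex_llamacpp member.
const info: ModelInfo = {
  id: 'llama3.2-1b-instruct',
  engine: 'llama-cpp',
}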

View File

@ -1,6 +1,5 @@
import {
EngineManagementExtension,
InferenceEngine,
DefaultEngineVariant,
Engines,
EngineConfig,
@ -35,7 +34,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
*/
async apiInstance(): Promise<KyInstance> {
if (this.api) return this.api
const apiKey = (await window.core?.api.appToken())
const apiKey = await window.core?.api.appToken()
this.api = ky.extend({
prefixUrl: API_URL,
headers: apiKey
@ -96,7 +95,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @param name - Inference engine name.
* @returns A Promise that resolves to an array of installed engines.
*/
async getInstalledEngines(name: InferenceEngine): Promise<EngineVariant[]> {
async getInstalledEngines(name: string): Promise<EngineVariant[]> {
return this.apiInstance().then((api) =>
api
.get(`v1/engines/${name}`)
@ -112,7 +111,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @returns A Promise that resolves to an array of the latest released engines for a version.
*/
async getReleasedEnginesByVersion(
name: InferenceEngine,
name: string,
version: string,
platform?: string
) {
@ -131,7 +130,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @param platform - Optional to sort by operating system. macOS, linux, windows.
* @returns A Promise that resolves to an array of the latest released engines.
*/
async getLatestReleasedEngine(name: InferenceEngine, platform?: string) {
async getLatestReleasedEngine(name: string, platform?: string) {
return this.apiInstance().then((api) =>
api
.get(`v1/engines/${name}/releases/latest`)
@ -197,7 +196,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @param name - Inference engine name.
* @returns A Promise that resolves once the engine is uninstalled.
*/
async uninstallEngine(name: InferenceEngine, engineConfig: EngineConfig) {
async uninstallEngine(name: string, engineConfig: EngineConfig) {
return this.apiInstance().then((api) =>
api
.delete(`v1/engines/${name}/install`, { json: engineConfig })
@ -234,7 +233,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @param name - Inference engine name.
* @returns A Promise that resolves to the default engine variant.
*/
async getDefaultEngineVariant(name: InferenceEngine) {
async getDefaultEngineVariant(name: string) {
return this.apiInstance().then((api) =>
api
.get(`v1/engines/${name}/default`)
@ -248,10 +247,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
* @body version - string
* @returns A Promise that resolves once the default engine is set.
*/
async setDefaultEngineVariant(
name: InferenceEngine,
engineConfig: EngineConfig
) {
async setDefaultEngineVariant(name: string, engineConfig: EngineConfig) {
return this.apiInstance().then((api) =>
api
.post(`v1/engines/${name}/default`, { json: engineConfig })
@ -262,7 +258,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
/**
* @returns A Promise that resolves once the engine is updated.
*/
async updateEngine(name: InferenceEngine, engineConfig?: EngineConfig) {
async updateEngine(name: string, engineConfig?: EngineConfig) {
return this.apiInstance().then((api) =>
api
.post(`v1/engines/${name}/update`, { json: engineConfig })
@ -276,12 +272,8 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
*/
async updateDefaultEngine() {
try {
const variant = await this.getDefaultEngineVariant(
InferenceEngine.cortex_llamacpp
)
const installedEngines = await this.getInstalledEngines(
InferenceEngine.cortex_llamacpp
)
const variant = await this.getDefaultEngineVariant('llama-cpp')
const installedEngines = await this.getInstalledEngines('llama-cpp')
if (
!installedEngines.some(
(e) => e.name === variant.variant && e.version === variant.version
@ -299,7 +291,8 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
) {
const systemInfo = await systemInformation()
const variant = await engineVariant(systemInfo.gpuSetting)
await this.setDefaultEngineVariant(InferenceEngine.cortex_llamacpp, {
// TODO: Use correct provider name when moving to llama.cpp extension
await this.setDefaultEngineVariant('llama-cpp', {
variant: variant,
version: `${CORTEX_ENGINE_VERSION}`,
})
@ -368,7 +361,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
models.data.map((model) =>
this.addRemoteModel({
...model,
engine: engineConfig.engine as InferenceEngine,
engine: engineConfig.engine,
model: model.model ?? model.id,
}).catch(console.info)
)
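
The `'llama-cpp'` literal now stands in for `InferenceEngine.cortex_llamacpp` in two places here. A minimal sketch of the same version check with the literal named once, assuming `EngineManagementExtension` is exported from `@janhq/core` (the constant and helper are an assumed refactor, not part of this commit):

import { EngineManagementExtension } from '@janhq/core'

// Assumed helper mirroring the updateDefaultEngine check above.
const LLAMACPP_ENGINE = 'llama-cpp'

async function isDefaultVariantInstalled(
  ext: EngineManagementExtension
): Promise<boolean> {
  const variant = await ext.getDefaultEngineVariant(LLAMACPP_ENGINE)
  const installed = await ext.getInstalledEngines(LLAMACPP_ENGINE)
  return installed.some(
    (e) => e.name === variant.variant && e.version === variant.version
  )
}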

View File

@ -11,7 +11,6 @@ import {
executeOnMain,
EngineEvent,
LocalOAIEngine,
InferenceEngine,
extractModelLoadParams,
events,
ModelEvent,
@ -49,7 +48,7 @@ type LoadedModelResponse = { data: { engine: string; id: string }[] }
export default class JanInferenceCortexExtension extends LocalOAIEngine {
nodeModule: string = 'node'
provider: string = InferenceEngine.cortex
provider: string = 'cortex'
shouldReconnect = true
@ -198,8 +197,8 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
...extractModelLoadParams(model.settings),
model: model.id,
engine:
model.engine === InferenceEngine.nitro // Legacy model cache
? InferenceEngine.cortex_llamacpp
model.engine === "nitro" // Legacy model cache
? "llama-cpp"
: model.engine,
cont_batching: this.cont_batching,
n_parallel: this.n_parallel,
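
The inline ternary above normalizes the legacy engine name. The same logic as a pure helper, for illustration only (the function itself is an assumption, not in the commit):

// 'nitro' is the pre-cortex name still found in old model caches.
function normalizeLegacyEngine(engine: string): string {
  return engine === 'nitro' ? 'llama-cpp' : engine
}

console.log(normalizeLegacyEngine('nitro')) // 'llama-cpp'
console.log(normalizeLegacyEngine('openai')) // 'openai'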

View File

@ -1,7 +1,6 @@
import {
ModelExtension,
Model,
InferenceEngine,
joinPath,
dirName,
fs,
@ -37,7 +36,7 @@ export default class JanModelExtension extends ModelExtension {
*/
async apiInstance(): Promise<KyInstance> {
if (this.api) return this.api
const apiKey = (await window.core?.api.appToken())
const apiKey = await window.core?.api.appToken()
this.api = ky.extend({
prefixUrl: CORTEX_API_URL,
headers: apiKey
@ -45,7 +44,7 @@ export default class JanModelExtension extends ModelExtension {
Authorization: `Bearer ${apiKey}`,
}
: {},
retry: 10
retry: 10,
})
return this.api
}
@ -153,9 +152,7 @@ export default class JanModelExtension extends ModelExtension {
* Here we keep only the legacy models that have not been imported
* and that use the llama.cpp (nitro) engine
*/
var toImportModels = legacyModels.filter(
(e) => e.engine === InferenceEngine.nitro
)
var toImportModels = legacyModels.filter((e) => e.engine === 'nitro')
/**
* Fetch models from cortex.cpp
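
The import filter above compares against the raw `'nitro'` string. A sketch of the same check as a named function, assuming `Model` is exported from `@janhq/core` (the function name is illustrative):

import { Model } from '@janhq/core'

// Only models recorded under the old 'nitro' engine name are candidates
// for re-import into cortex.cpp.
function selectLegacyImports(legacyModels: Model[]): Model[] {
  return legacyModels.filter((m) => m.engine === 'nitro')
}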

View File

@ -1,13 +1,5 @@
import { InferenceEngine, Model, fs, joinPath } from '@janhq/core'
import { Model, fs, joinPath } from '@janhq/core'
//// LEGACY MODEL FOLDER ////
const LocalEngines = [
InferenceEngine.cortex,
InferenceEngine.cortex_llamacpp,
InferenceEngine.cortex_tensorrtllm,
InferenceEngine.cortex_onnx,
InferenceEngine.nitro_tensorrt_llm,
InferenceEngine.nitro,
]
/**
* Scan through models folder and return downloaded models
* @returns
@ -68,7 +60,7 @@ export const scanModelsFolder = async (): Promise<
)
)
if (
!LocalEngines.includes(model.engine) ||
!['cortex', 'llama-cpp', 'nitro'].includes(model.engine) ||
existFiles.every((exist) => exist)
)
return model
@ -86,9 +78,9 @@ export const scanModelsFolder = async (): Promise<
file.toLowerCase().endsWith('.engine') // TensorRT-LLM
)
})?.length >=
(model.engine === InferenceEngine.nitro_tensorrt_llm
(model.engine === 'nitro-tensorrt-llm'
? 1
: (model.sources?.length ?? 1))
: model.sources?.length ?? 1)
)
})
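
The former `LocalEngines` enum list collapses to an inline literal above. A sketch of the same membership test as a reusable predicate (the constant and function names are assumptions):

const LOCAL_ENGINES = ['cortex', 'llama-cpp', 'nitro']

// True when the model was produced by one of the locally hosted engines.
function isLocalEngine(engine: string): boolean {
  return LOCAL_ENGINES.includes(engine)
}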

View File

@ -1,6 +1,6 @@
import { models as providerModels } from 'token.js'
import { mockModelProvider } from '@/mock/data'
import { EngineManager, InferenceEngine, ModelManager } from '@janhq/core'
import { EngineManager, ModelManager } from '@janhq/core'
import { ModelCapabilities } from '@/types/models'
import { modelSettings } from '@/lib/predefined'
@ -40,7 +40,8 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
const runtimeProviders: ModelProvider[] = []
for (const [key, value] of EngineManager.instance().engines) {
const providerName = key === InferenceEngine.cortex ? 'llama.cpp' : key
// TODO: Remove this when the cortex extension is removed
const providerName = key === 'cortex' ? 'llama.cpp' : key
const models =
Array.from(ModelManager.instance().models.values()).filter(
(model) =>
@ -70,7 +71,7 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
description: model.description,
capabilities:
'capabilities' in model
? model.capabilities as string[]
? (model.capabilities as string[])
: [ModelCapabilities.COMPLETION],
provider: providerName,
settings: modelSettings,
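
The `'cortex'` to `'llama.cpp'` mapping above is a display-name shim. Isolated for illustration (the helper is assumed; per the TODO it disappears once the cortex extension is removed):

// Maps the internal engine key to the provider name shown in the UI.
function providerDisplayName(engineKey: string): string {
  return engineKey === 'cortex' ? 'llama.cpp' : engineKey
}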

View File

@ -1,9 +1,5 @@
import { ExtensionManager } from '@/lib/extension'
import {
ConversationalExtension,
ExtensionTypeEnum,
InferenceEngine,
} from '@janhq/core'
import { ConversationalExtension, ExtensionTypeEnum } from '@janhq/core'
/**
* Fetches all threads from the conversational extension.
@ -52,16 +48,15 @@ export const createThread = async (thread: Thread): Promise<Thread> => {
{
model: {
id: thread.model?.id ?? '*',
engine: (thread.model?.provider ??
'llama.cpp') as InferenceEngine,
engine: thread.model?.provider ?? 'llama.cpp',
},
assistant_id: 'jan',
assistant_name: 'Jan',
},
],
metadata: {
order: 1
}
order: 1,
},
})
.then((e) => {
return {
@ -71,7 +66,7 @@ export const createThread = async (thread: Thread): Promise<Thread> => {
id: e.assistants?.[0]?.model.id,
provider: e.assistants?.[0]?.model.engine,
},
order: 1
order: 1,
} as Thread
})
.catch(() => thread) ?? thread
@ -91,7 +86,7 @@ export const updateThread = (thread: Thread) => {
{
model: {
id: thread.model?.id ?? '*',
engine: (thread.model?.provider ?? 'llama.cpp') as InferenceEngine,
engine: (thread.model?.provider ?? 'llama.cpp'),
},
assistant_id: 'jan',
assistant_name: 'Jan',
@ -99,11 +94,11 @@ export const updateThread = (thread: Thread) => {
],
metadata: {
is_favorite: thread.isFavorite,
order: thread.order
order: thread.order,
},
object: 'thread',
created: Date.now()/ 1000,
updated: Date.now()/ 1000,
created: Date.now() / 1000,
updated: Date.now() / 1000,
})
}
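
With the `as InferenceEngine` cast gone, the provider string flows straight into `engine`, defaulting to `'llama.cpp'`. A minimal sketch of the payload shape (the input thread below is illustrative; the field names match the diff):

const thread = { model: { id: 'gpt-4o-mini', provider: 'openai' } } // assumed input

const assistantModel = {
  id: thread.model?.id ?? '*',
  engine: thread.model?.provider ?? 'llama.cpp',
}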