fix: model import edge cases

Louis committed 2024-11-19 11:44:19 +07:00
parent 04dd8367a1
commit 363008d37f
5 changed files with 40 additions and 51 deletions

View File

@@ -19,6 +19,7 @@ import {
   events,
   ModelEvent,
   SystemInformation,
+  dirName,
 } from '@janhq/core'
 import PQueue from 'p-queue'
 import ky from 'ky'
@@ -99,10 +100,12 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
       // Legacy chat model support
       model.settings = {
         ...model.settings,
-        llama_model_path: await getModelFilePath(
-          model,
-          model.settings.llama_model_path
-        ),
+        llama_model_path: model.file_path
+          ? await joinPath([
+              await dirName(model.file_path),
+              model.settings.llama_model_path,
+            ])
+          : await getModelFilePath(model, model.settings.llama_model_path),
       }
     } else {
       const { llama_model_path, ...settings } = model.settings
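The edge case this hunk fixes: a model imported from outside the default models folder carries a file_path (the location of its model.json), so the weight file named by llama_model_path must be resolved relative to that directory rather than through the default lookup. A minimal sketch of the rule, using node:path in place of @janhq/core's joinPath/dirName, with a hypothetical defaultDir standing in for the getModelFilePath fallback:

import { dirname, join } from 'node:path'

interface ModelLike {
  file_path?: string // set for imported models; points at their model.json
  settings: { llama_model_path: string }
}

// Imported model: the weight file sits next to its model.json.
// Default layout: the path is relative to the app's models folder.
function resolveLlamaModelPath(model: ModelLike, defaultDir: string): string {
  return model.file_path
    ? join(dirname(model.file_path), model.settings.llama_model_path)
    : join(defaultDir, model.settings.llama_model_path)
}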
@@ -168,7 +171,11 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
    * Set default engine variant on launch
    */
   private async setDefaultEngine(systemInfo: SystemInformation) {
-    const variant = await executeOnMain(NODE, 'engineVariant', systemInfo.gpuSetting)
+    const variant = await executeOnMain(
+      NODE,
+      'engineVariant',
+      systemInfo.gpuSetting
+    )
     return ky
       .post(
         `${CORTEX_API_URL}/v1/engines/${InferenceEngine.cortex_llamacpp}/default?version=${CORTEX_ENGINE_VERSION}&variant=${variant}`,

View File

@@ -20,13 +20,6 @@ import { deleteModelFiles } from './legacy/delete'
 declare const SETTINGS: Array<any>
 
-/**
- * Extension enum
- */
-enum ExtensionEnum {
-  downloadedModels = 'downloadedModels',
-}
-
 /**
  * A extension for models
  */
@@ -122,39 +115,16 @@ export default class JanModelExtension extends ModelExtension {
    * @returns A Promise that resolves with an array of all models.
    */
   async getModels(): Promise<Model[]> {
-    /**
-     * In this action, if return empty array right away
-     * it would reset app cache and app will not function properly
-     * should compare and try import
-     */
-    let currentModels: Model[] = []
-
     /**
      * Legacy models should be supported
      */
     let legacyModels = await scanModelsFolder()
-
-    try {
-      if (!localStorage.getItem(ExtensionEnum.downloadedModels)) {
-        // Updated from an older version than 0.5.5
-        // Scan through the models folder and import them (Legacy flow)
-        // Return models immediately
-        currentModels = legacyModels
-      } else {
-        currentModels = JSON.parse(
-          localStorage.getItem(ExtensionEnum.downloadedModels)
-        ) as Model[]
-      }
-    } catch (e) {
-      currentModels = []
-      console.error(e)
-    }
 
     /**
      * Here we are filtering out the models that are not imported
      * and are not using llama.cpp engine
      */
-    var toImportModels = currentModels.filter(
+    var toImportModels = legacyModels.filter(
       (e) => e.engine === InferenceEngine.nitro
     )
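The removed block is the old localStorage cache ('downloadedModels'): getModels used to trust the cached list and only fell back to a folder scan on first run, so a stale cache could hide models that were actually on disk. The scan is now the single source of truth on every call. A condensed sketch of the new flow, with the scan and the import pass taken as parameters (the tail of the real method is outside this hunk), and InferenceEngine as imported from @janhq/core:

async function getModelsSketch(
  scan: () => Promise<Model[]>,
  importLegacy: (models: Model[]) => Promise<void>
): Promise<Model[]> {
  // Fresh scan on every call: no cache to go stale or to reset.
  const legacyModels = await scan()
  // Only legacy llama.cpp (nitro) entries still need the import pass.
  await importLegacy(
    legacyModels.filter((m) => m.engine === InferenceEngine.nitro)
  )
  return legacyModels
}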
@@ -196,13 +166,17 @@ export default class JanModelExtension extends ModelExtension {
             ]) // Copied models
           : model.sources[0].url, // Symlink models,
         model.name
-      ).then((e) => {
-        this.updateModel({
-          id: model.id,
-          ...model.settings,
-          ...model.parameters,
-        } as Partial<Model>)
-      })
+      )
+        .then((e) => {
+          this.updateModel({
+            id: model.id,
+            ...model.settings,
+            ...model.parameters,
+          } as Partial<Model>)
+        })
+        .catch((e) => {
+          console.debug(e)
+        })
       })
     )
   }
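Adding the .catch means one failed import no longer rejects the surrounding Promise.all and abort the remaining imports. A minimal illustration of the pattern, with importOne standing in for the import-and-update call above:

async function importAll(
  ids: string[],
  importOne: (id: string) => Promise<void>
): Promise<void> {
  await Promise.all(
    ids.map((id) =>
      importOne(id)
        // Per-model failures are logged at debug level and swallowed,
        // so Promise.all always settles and the other imports finish.
        .catch((e) => console.debug(e))
    )
  )
}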

View File

@@ -1,10 +1,12 @@
-import { fs, joinPath } from '@janhq/core'
+import { dirName, fs } from '@janhq/core'
+import { scanModelsFolder } from './model-json'
 
 export const deleteModelFiles = async (id: string) => {
   try {
-    const dirPath = await joinPath(['file://models', id])
+    const models = await scanModelsFolder()
+    const dirPath = models.find((e) => e.id === id)?.file_path
     // remove model folder directory
-    await fs.rm(dirPath)
+    if (dirPath) await fs.rm(await dirName(dirPath))
   } catch (err) {
     console.error(err)
   }
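Deletion previously assumed the fixed layout models/<id>; it now looks up the model's real folder through the scan's file_path, which is what makes removal work for imported models living elsewhere. A sketch with plain Node APIs standing in for @janhq/core's fs and dirName (the recursive/force flags are an assumption about how the folder is removed):

import { rm } from 'node:fs/promises'
import { dirname } from 'node:path'

type ScannedModel = { id: string; file_path?: string }

async function deleteModelFolder(
  id: string,
  scan: () => Promise<ScannedModel[]>
): Promise<void> {
  const filePath = (await scan()).find((m) => m.id === id)?.file_path
  // Only delete when the model was actually found on disk; removing the
  // folder that contains model.json also drops the weight files beside it.
  if (filePath) await rm(dirname(filePath), { recursive: true, force: true })
}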

View File

@@ -12,7 +12,9 @@ const LocalEngines = [
  * Scan through models folder and return downloaded models
  * @returns
  */
-export const scanModelsFolder = async (): Promise<Model[]> => {
+export const scanModelsFolder = async (): Promise<
+  (Model & { file_path?: string })[]
+> => {
   const _homeDir = 'file://models'
   try {
     if (!(await fs.existsSync(_homeDir))) {
@@ -37,7 +39,7 @@ export const scanModelsFolder = async (): Promise<Model[]> => {
       const jsonPath = await getModelJsonPath(folderFullPath)
 
-      if (await fs.existsSync(jsonPath)) {
+      if (jsonPath && (await fs.existsSync(jsonPath))) {
         // if we have the model.json file, read it
         let model = await fs.readFileSync(jsonPath, 'utf-8')
@@ -83,7 +85,10 @@ export const scanModelsFolder = async (): Promise<Model[]> => {
               file.toLowerCase().endsWith('.gguf') || // GGUF
               file.toLowerCase().endsWith('.engine') // TensorRT-LLM
             )
-          })?.length >= (model.engine === InferenceEngine.nitro_tensorrt_llm ? 1 : (model.sources?.length ?? 1))
+          })?.length >=
+            (model.engine === InferenceEngine.nitro_tensorrt_llm
+              ? 1
+              : (model.sources?.length ?? 1))
         )
       })
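The reflowed condition is the download-completeness check: a model folder counts as fully downloaded once the number of local weight files (.gguf or .engine) reaches the expected count, meaning a single engine file for TensorRT-LLM and otherwise one file per declared source. An isolated sketch of the same rule (the 'nitro-tensorrt-llm' string is an assumed stand-in for the enum value):

function isDownloadComplete(
  files: string[],
  engine: string,
  sourceCount?: number
): boolean {
  // Count recognized weight files in the model folder.
  const weights = files.filter(
    (f) =>
      f.toLowerCase().endsWith('.gguf') || f.toLowerCase().endsWith('.engine')
  )
  // TensorRT-LLM ships one .engine file; GGUF models need one per source.
  const expected = engine === 'nitro-tensorrt-llm' ? 1 : (sourceCount ?? 1)
  return weights.length >= expected
}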

View File

@@ -34,7 +34,7 @@ const useModels = () => {
   const getDownloadedModels = async () => {
     const localModels = (await getModels()).map((e) => ({
       ...e,
-      name: ModelManager.instance().models.get(e.id)?.name ?? e.id,
+      name: ModelManager.instance().models.get(e.id)?.name ?? e.name ?? e.id,
       metadata:
         ModelManager.instance().models.get(e.id)?.metadata ?? e.metadata,
     }))
@@ -92,7 +92,8 @@ const useModels = () => {
   const getModels = async (): Promise<Model[]> =>
     extensionManager
       .get<ModelExtension>(ExtensionTypeEnum.Model)
-      ?.getModels() ?? []
+      ?.getModels()
+      .catch(() => []) ?? []
 
   useEffect(() => {
     // Listen for model updates
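Both hook changes are defensive: the display name now falls back to the scanned model's own name before its id, and a rejected getModels() call degrades to an empty list instead of propagating into React state. A minimal sketch of the hardened accessor, with the extension reduced to the one method used here and Model as above:

async function getModelsSafe(
  ext?: { getModels(): Promise<Model[]> }
): Promise<Model[]> {
  // Optional chaining covers a missing extension; catch covers a failing one.
  return (await ext?.getModels().catch(() => [] as Model[])) ?? []
}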