diff --git a/core/src/browser/extensions/model.ts b/core/src/browser/extensions/model.ts index f3609b3b2..b237fad9d 100644 --- a/core/src/browser/extensions/model.ts +++ b/core/src/browser/extensions/model.ts @@ -13,9 +13,9 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter } abstract getModels(): Promise - abstract pullModel(model: string, id?: string): Promise + abstract pullModel(model: string, id?: string, name?: string): Promise abstract cancelModelPull(modelId: string): Promise - abstract importModel(model: string, modePath: string): Promise + abstract importModel(model: string, modePath: string, name?: string): Promise abstract updateModel(modelInfo: Partial): Promise abstract deleteModel(model: string): Promise } diff --git a/core/src/types/model/modelInterface.ts b/core/src/types/model/modelInterface.ts index b676db949..c35bae9ce 100644 --- a/core/src/types/model/modelInterface.ts +++ b/core/src/types/model/modelInterface.ts @@ -9,7 +9,7 @@ export interface ModelInterface { * @param model - The model to download. * @returns A Promise that resolves when the model has been downloaded. */ - pullModel(model: string, id?: string): Promise + pullModel(model: string, id?: string, name?: string): Promise /** * Cancels the download of a specific model. @@ -43,5 +43,5 @@ export interface ModelInterface { * @param model id of the model to import * @param modelPath - path of the model file */ - importModel(model: string, modePath: string): Promise + importModel(model: string, modePath: string, name?: string): Promise } diff --git a/extensions/inference-cortex-extension/bin/version.txt b/extensions/inference-cortex-extension/bin/version.txt index a9d40871b..57d77db55 100644 --- a/extensions/inference-cortex-extension/bin/version.txt +++ b/extensions/inference-cortex-extension/bin/version.txt @@ -1 +1 @@ -1.0.2-rc1 \ No newline at end of file +1.0.2-rc2 \ No newline at end of file diff --git a/extensions/model-extension/src/cortex.ts b/extensions/model-extension/src/cortex.ts index ca9c2b921..50eace5e5 100644 --- a/extensions/model-extension/src/cortex.ts +++ b/extensions/model-extension/src/cortex.ts @@ -8,8 +8,8 @@ import { extractInferenceParams } from '@janhq/core' interface ICortexAPI { getModel(model: string): Promise getModels(): Promise - pullModel(model: string, id?: string): Promise - importModel(path: string, modelPath: string): Promise + pullModel(model: string, id?: string, name?: string): Promise + importModel(path: string, modelPath: string, name?: string): Promise deleteModel(model: string): Promise updateModel(model: object): Promise cancelModelPull(model: string): Promise @@ -68,10 +68,10 @@ export class CortexAPI implements ICortexAPI { * @param model * @returns */ - pullModel(model: string, id?: string): Promise { + pullModel(model: string, id?: string, name?: string): Promise { return this.queue.add(() => ky - .post(`${API_URL}/v1/models/pull`, { json: { model, id } }) + .post(`${API_URL}/v1/models/pull`, { json: { model, id, name } }) .json() .catch(async (e) => { throw (await e.response?.json()) ?? e @@ -85,10 +85,10 @@ export class CortexAPI implements ICortexAPI { * @param model * @returns */ - importModel(model: string, modelPath: string): Promise { + importModel(model: string, modelPath: string, name?: string): Promise { return this.queue.add(() => ky - .post(`${API_URL}/v1/models/import`, { json: { model, modelPath } }) + .post(`${API_URL}/v1/models/import`, { json: { model, modelPath, name } }) .json() .catch((e) => console.debug(e)) // Ignore error .then() diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index 439481bc4..17c00263d 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -58,7 +58,7 @@ export default class JanModelExtension extends ModelExtension { * @param model - The model to download. * @returns A Promise that resolves when the model is downloaded. */ - async pullModel(model: string, id?: string): Promise { + async pullModel(model: string, id?: string, name?: string): Promise { if (id) { const model: Model = ModelManager.instance().get(id) // Clip vision model - should not be handled by cortex.cpp @@ -74,7 +74,7 @@ export default class JanModelExtension extends ModelExtension { /** * Sending POST to /models/pull/{id} endpoint to pull the model */ - return this.cortexAPI.pullModel(model, id) + return this.cortexAPI.pullModel(model, id, name) } /** @@ -111,14 +111,12 @@ export default class JanModelExtension extends ModelExtension { * @returns A Promise that resolves when the model is deleted. */ async deleteModel(model: string): Promise { - const modelDto: Model = ModelManager.instance().get(model) return this.cortexAPI .deleteModel(model) .catch((e) => console.debug(e)) .finally(async () => { // Delete legacy model files - if (modelDto) - await deleteModelFiles(modelDto).catch((e) => console.debug(e)) + await deleteModelFiles(model).catch((e) => console.debug(e)) }) } @@ -227,8 +225,12 @@ export default class JanModelExtension extends ModelExtension { * @param model * @param optionType */ - async importModel(model: string, modelPath: string): Promise { - return this.cortexAPI.importModel(model, modelPath) + async importModel( + model: string, + modelPath: string, + name?: string + ): Promise { + return this.cortexAPI.importModel(model, modelPath, name) } /** diff --git a/extensions/model-extension/src/legacy/delete.ts b/extensions/model-extension/src/legacy/delete.ts index 039eab4cf..5288e30ee 100644 --- a/extensions/model-extension/src/legacy/delete.ts +++ b/extensions/model-extension/src/legacy/delete.ts @@ -1,10 +1,10 @@ -import { fs, joinPath, Model } from '@janhq/core' +import { fs, joinPath } from '@janhq/core' -export const deleteModelFiles = async (model: Model) => { +export const deleteModelFiles = async (id: string) => { try { - const dirPath = await joinPath(['file://models', model.id]) + const dirPath = await joinPath(['file://models', id]) // remove model folder directory - await fs.unlinkSync(dirPath) + await fs.rm(dirPath) } catch (err) { console.error(err) } diff --git a/web/hooks/useDownloadModel.test.ts b/web/hooks/useDownloadModel.test.ts index ff75fbcd8..7e9d7b518 100644 --- a/web/hooks/useDownloadModel.test.ts +++ b/web/hooks/useDownloadModel.test.ts @@ -40,7 +40,8 @@ describe('useDownloadModel', () => { expect(mockExtension.pullModel).toHaveBeenCalledWith( mockModel.sources[0].url, - mockModel.id + mockModel.id, + undefined ) }) @@ -87,7 +88,8 @@ describe('useDownloadModel', () => { expect(mockExtension.pullModel).toHaveBeenCalledWith( mockModel.sources[0].url, - mockModel.id + mockModel.id, + undefined ) }) }) diff --git a/web/hooks/useDownloadModel.ts b/web/hooks/useDownloadModel.ts index 3b25cb86f..bbf03e2e7 100644 --- a/web/hooks/useDownloadModel.ts +++ b/web/hooks/useDownloadModel.ts @@ -18,9 +18,9 @@ export default function useDownloadModel() { const addDownloadingModel = useSetAtom(addDownloadingModelAtom) const downloadModel = useCallback( - async (model: string, id?: string) => { + async (model: string, id?: string, name?: string) => { addDownloadingModel(id ?? model) - downloadLocalModel(model, id).catch((error) => { + downloadLocalModel(model, id, name).catch((error) => { if (error.message) { toaster({ title: 'Download failed', @@ -45,10 +45,10 @@ export default function useDownloadModel() { } } -const downloadLocalModel = async (model: string, id?: string) => +const downloadLocalModel = async (model: string, id?: string, name?: string) => extensionManager .get(ExtensionTypeEnum.Model) - ?.pullModel(model, id) + ?.pullModel(model, id, name) const cancelModelDownload = async (model: string) => extensionManager diff --git a/web/hooks/useImportModel.test.ts b/web/hooks/useImportModel.test.ts index d37e4a853..9b623226d 100644 --- a/web/hooks/useImportModel.test.ts +++ b/web/hooks/useImportModel.test.ts @@ -34,8 +34,8 @@ describe('useImportModel', () => { await result.current.importModels(models, 'local' as any) }) - expect(mockImportModels).toHaveBeenCalledWith('1', '/path/to/model1') - expect(mockImportModels).toHaveBeenCalledWith('2', '/path/to/model2') + expect(mockImportModels).toHaveBeenCalledWith('1', '/path/to/model1', undefined) + expect(mockImportModels).toHaveBeenCalledWith('2', '/path/to/model2', undefined) }) it('should update model info successfully', async () => { diff --git a/web/hooks/useImportModel.ts b/web/hooks/useImportModel.ts index 951e93bef..b8f64db98 100644 --- a/web/hooks/useImportModel.ts +++ b/web/hooks/useImportModel.ts @@ -66,7 +66,7 @@ const useImportModel = () => { addDownloadingModel(modelId) extensionManager .get(ExtensionTypeEnum.Model) - ?.importModel(model.modelId, model.path) + ?.importModel(model.modelId, model.path, model.name) .finally(() => removeDownloadingModel(modelId)) } }) diff --git a/web/screens/Hub/ModelList/ModelHeader/index.tsx b/web/screens/Hub/ModelList/ModelHeader/index.tsx index 725b0216a..da98e41e3 100644 --- a/web/screens/Hub/ModelList/ModelHeader/index.tsx +++ b/web/screens/Hub/ModelList/ModelHeader/index.tsx @@ -64,7 +64,7 @@ const ModelItemHeader = ({ model, onClick, open }: Props) => { const assistants = useAtomValue(assistantsAtom) const onDownloadClick = useCallback(() => { - downloadModel(model.sources[0].url, model.id) + downloadModel(model.sources[0].url, model.id, model.name) }, [model, downloadModel]) const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null diff --git a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx index bd9f67ebb..dbd2798b7 100644 --- a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx +++ b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx @@ -63,7 +63,11 @@ const ModelDownloadRow: React.FC = ({ const onDownloadClick = useCallback(async () => { if (downloadUrl) { - downloadModel(downloadUrl, normalizeModelId(downloadUrl)) + downloadModel( + downloadUrl, + normalizeModelId(downloadUrl), + normalizeModelId(downloadUrl) + ) } }, [downloadUrl, downloadModel]) diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx index 366575a40..0b999c19d 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx @@ -170,7 +170,8 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => { onClick={() => downloadModel( model.sources[0].url, - model.id + model.id, + model.name ) } /> @@ -261,7 +262,8 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => { onClick={() => downloadModel( featModel.sources[0].url, - featModel.id + featModel.id, + featModel.name ) } >