chore: decide model name on pull and import

This commit is contained in:
Louis 2024-11-01 16:35:55 +07:00
parent d0ffe6c611
commit a986c6de2d
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
13 changed files with 45 additions and 35 deletions

View File

@ -13,9 +13,9 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
} }
abstract getModels(): Promise<Model[]> abstract getModels(): Promise<Model[]>
abstract pullModel(model: string, id?: string): Promise<void> abstract pullModel(model: string, id?: string, name?: string): Promise<void>
abstract cancelModelPull(modelId: string): Promise<void> abstract cancelModelPull(modelId: string): Promise<void>
abstract importModel(model: string, modePath: string): Promise<void> abstract importModel(model: string, modePath: string, name?: string): Promise<void>
abstract updateModel(modelInfo: Partial<Model>): Promise<Model> abstract updateModel(modelInfo: Partial<Model>): Promise<Model>
abstract deleteModel(model: string): Promise<void> abstract deleteModel(model: string): Promise<void>
} }

View File

@ -9,7 +9,7 @@ export interface ModelInterface {
* @param model - The model to download. * @param model - The model to download.
* @returns A Promise that resolves when the model has been downloaded. * @returns A Promise that resolves when the model has been downloaded.
*/ */
pullModel(model: string, id?: string): Promise<void> pullModel(model: string, id?: string, name?: string): Promise<void>
/** /**
* Cancels the download of a specific model. * Cancels the download of a specific model.
@ -43,5 +43,5 @@ export interface ModelInterface {
* @param model id of the model to import * @param model id of the model to import
* @param modelPath - path of the model file * @param modelPath - path of the model file
*/ */
importModel(model: string, modePath: string): Promise<void> importModel(model: string, modePath: string, name?: string): Promise<void>
} }

View File

@ -1 +1 @@
1.0.2-rc1 1.0.2-rc2

View File

@ -8,8 +8,8 @@ import { extractInferenceParams } from '@janhq/core'
interface ICortexAPI { interface ICortexAPI {
getModel(model: string): Promise<Model> getModel(model: string): Promise<Model>
getModels(): Promise<Model[]> getModels(): Promise<Model[]>
pullModel(model: string, id?: string): Promise<void> pullModel(model: string, id?: string, name?: string): Promise<void>
importModel(path: string, modelPath: string): Promise<void> importModel(path: string, modelPath: string, name?: string): Promise<void>
deleteModel(model: string): Promise<void> deleteModel(model: string): Promise<void>
updateModel(model: object): Promise<void> updateModel(model: object): Promise<void>
cancelModelPull(model: string): Promise<void> cancelModelPull(model: string): Promise<void>
@ -68,10 +68,10 @@ export class CortexAPI implements ICortexAPI {
* @param model * @param model
* @returns * @returns
*/ */
pullModel(model: string, id?: string): Promise<void> { pullModel(model: string, id?: string, name?: string): Promise<void> {
return this.queue.add(() => return this.queue.add(() =>
ky ky
.post(`${API_URL}/v1/models/pull`, { json: { model, id } }) .post(`${API_URL}/v1/models/pull`, { json: { model, id, name } })
.json() .json()
.catch(async (e) => { .catch(async (e) => {
throw (await e.response?.json()) ?? e throw (await e.response?.json()) ?? e
@ -85,10 +85,10 @@ export class CortexAPI implements ICortexAPI {
* @param model * @param model
* @returns * @returns
*/ */
importModel(model: string, modelPath: string): Promise<void> { importModel(model: string, modelPath: string, name?: string): Promise<void> {
return this.queue.add(() => return this.queue.add(() =>
ky ky
.post(`${API_URL}/v1/models/import`, { json: { model, modelPath } }) .post(`${API_URL}/v1/models/import`, { json: { model, modelPath, name } })
.json() .json()
.catch((e) => console.debug(e)) // Ignore error .catch((e) => console.debug(e)) // Ignore error
.then() .then()

View File

@ -58,7 +58,7 @@ export default class JanModelExtension extends ModelExtension {
* @param model - The model to download. * @param model - The model to download.
* @returns A Promise that resolves when the model is downloaded. * @returns A Promise that resolves when the model is downloaded.
*/ */
async pullModel(model: string, id?: string): Promise<void> { async pullModel(model: string, id?: string, name?: string): Promise<void> {
if (id) { if (id) {
const model: Model = ModelManager.instance().get(id) const model: Model = ModelManager.instance().get(id)
// Clip vision model - should not be handled by cortex.cpp // Clip vision model - should not be handled by cortex.cpp
@ -74,7 +74,7 @@ export default class JanModelExtension extends ModelExtension {
/** /**
* Sending POST to /models/pull/{id} endpoint to pull the model * Sending POST to /models/pull/{id} endpoint to pull the model
*/ */
return this.cortexAPI.pullModel(model, id) return this.cortexAPI.pullModel(model, id, name)
} }
/** /**
@ -111,14 +111,12 @@ export default class JanModelExtension extends ModelExtension {
* @returns A Promise that resolves when the model is deleted. * @returns A Promise that resolves when the model is deleted.
*/ */
async deleteModel(model: string): Promise<void> { async deleteModel(model: string): Promise<void> {
const modelDto: Model = ModelManager.instance().get(model)
return this.cortexAPI return this.cortexAPI
.deleteModel(model) .deleteModel(model)
.catch((e) => console.debug(e)) .catch((e) => console.debug(e))
.finally(async () => { .finally(async () => {
// Delete legacy model files // Delete legacy model files
if (modelDto) await deleteModelFiles(model).catch((e) => console.debug(e))
await deleteModelFiles(modelDto).catch((e) => console.debug(e))
}) })
} }
@ -227,8 +225,12 @@ export default class JanModelExtension extends ModelExtension {
* @param model * @param model
* @param optionType * @param optionType
*/ */
async importModel(model: string, modelPath: string): Promise<void> { async importModel(
return this.cortexAPI.importModel(model, modelPath) model: string,
modelPath: string,
name?: string
): Promise<void> {
return this.cortexAPI.importModel(model, modelPath, name)
} }
/** /**

View File

@ -1,10 +1,10 @@
import { fs, joinPath, Model } from '@janhq/core' import { fs, joinPath } from '@janhq/core'
export const deleteModelFiles = async (model: Model) => { export const deleteModelFiles = async (id: string) => {
try { try {
const dirPath = await joinPath(['file://models', model.id]) const dirPath = await joinPath(['file://models', id])
// remove model folder directory // remove model folder directory
await fs.unlinkSync(dirPath) await fs.rm(dirPath)
} catch (err) { } catch (err) {
console.error(err) console.error(err)
} }

View File

@ -40,7 +40,8 @@ describe('useDownloadModel', () => {
expect(mockExtension.pullModel).toHaveBeenCalledWith( expect(mockExtension.pullModel).toHaveBeenCalledWith(
mockModel.sources[0].url, mockModel.sources[0].url,
mockModel.id mockModel.id,
undefined
) )
}) })
@ -87,7 +88,8 @@ describe('useDownloadModel', () => {
expect(mockExtension.pullModel).toHaveBeenCalledWith( expect(mockExtension.pullModel).toHaveBeenCalledWith(
mockModel.sources[0].url, mockModel.sources[0].url,
mockModel.id mockModel.id,
undefined
) )
}) })
}) })

View File

@ -18,9 +18,9 @@ export default function useDownloadModel() {
const addDownloadingModel = useSetAtom(addDownloadingModelAtom) const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
const downloadModel = useCallback( const downloadModel = useCallback(
async (model: string, id?: string) => { async (model: string, id?: string, name?: string) => {
addDownloadingModel(id ?? model) addDownloadingModel(id ?? model)
downloadLocalModel(model, id).catch((error) => { downloadLocalModel(model, id, name).catch((error) => {
if (error.message) { if (error.message) {
toaster({ toaster({
title: 'Download failed', title: 'Download failed',
@ -45,10 +45,10 @@ export default function useDownloadModel() {
} }
} }
const downloadLocalModel = async (model: string, id?: string) => const downloadLocalModel = async (model: string, id?: string, name?: string) =>
extensionManager extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model) .get<ModelExtension>(ExtensionTypeEnum.Model)
?.pullModel(model, id) ?.pullModel(model, id, name)
const cancelModelDownload = async (model: string) => const cancelModelDownload = async (model: string) =>
extensionManager extensionManager

View File

@ -34,8 +34,8 @@ describe('useImportModel', () => {
await result.current.importModels(models, 'local' as any) await result.current.importModels(models, 'local' as any)
}) })
expect(mockImportModels).toHaveBeenCalledWith('1', '/path/to/model1') expect(mockImportModels).toHaveBeenCalledWith('1', '/path/to/model1', undefined)
expect(mockImportModels).toHaveBeenCalledWith('2', '/path/to/model2') expect(mockImportModels).toHaveBeenCalledWith('2', '/path/to/model2', undefined)
}) })
it('should update model info successfully', async () => { it('should update model info successfully', async () => {

View File

@ -66,7 +66,7 @@ const useImportModel = () => {
addDownloadingModel(modelId) addDownloadingModel(modelId)
extensionManager extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model) .get<ModelExtension>(ExtensionTypeEnum.Model)
?.importModel(model.modelId, model.path) ?.importModel(model.modelId, model.path, model.name)
.finally(() => removeDownloadingModel(modelId)) .finally(() => removeDownloadingModel(modelId))
} }
}) })

View File

@ -64,7 +64,7 @@ const ModelItemHeader = ({ model, onClick, open }: Props) => {
const assistants = useAtomValue(assistantsAtom) const assistants = useAtomValue(assistantsAtom)
const onDownloadClick = useCallback(() => { const onDownloadClick = useCallback(() => {
downloadModel(model.sources[0].url, model.id) downloadModel(model.sources[0].url, model.id, model.name)
}, [model, downloadModel]) }, [model, downloadModel])
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null

View File

@ -63,7 +63,11 @@ const ModelDownloadRow: React.FC<Props> = ({
const onDownloadClick = useCallback(async () => { const onDownloadClick = useCallback(async () => {
if (downloadUrl) { if (downloadUrl) {
downloadModel(downloadUrl, normalizeModelId(downloadUrl)) downloadModel(
downloadUrl,
normalizeModelId(downloadUrl),
normalizeModelId(downloadUrl)
)
} }
}, [downloadUrl, downloadModel]) }, [downloadUrl, downloadModel])

View File

@ -170,7 +170,8 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
onClick={() => onClick={() =>
downloadModel( downloadModel(
model.sources[0].url, model.sources[0].url,
model.id model.id,
model.name
) )
} }
/> />
@ -261,7 +262,8 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
onClick={() => onClick={() =>
downloadModel( downloadModel(
featModel.sources[0].url, featModel.sources[0].url,
featModel.id featModel.id,
featModel.name
) )
} }
> >