chore: decide model name on pull and import
This commit is contained in:
parent
d0ffe6c611
commit
a986c6de2d
@ -13,9 +13,9 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
|
|||||||
}
|
}
|
||||||
|
|
||||||
abstract getModels(): Promise<Model[]>
|
abstract getModels(): Promise<Model[]>
|
||||||
abstract pullModel(model: string, id?: string): Promise<void>
|
abstract pullModel(model: string, id?: string, name?: string): Promise<void>
|
||||||
abstract cancelModelPull(modelId: string): Promise<void>
|
abstract cancelModelPull(modelId: string): Promise<void>
|
||||||
abstract importModel(model: string, modePath: string): Promise<void>
|
abstract importModel(model: string, modePath: string, name?: string): Promise<void>
|
||||||
abstract updateModel(modelInfo: Partial<Model>): Promise<Model>
|
abstract updateModel(modelInfo: Partial<Model>): Promise<Model>
|
||||||
abstract deleteModel(model: string): Promise<void>
|
abstract deleteModel(model: string): Promise<void>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,7 +9,7 @@ export interface ModelInterface {
|
|||||||
* @param model - The model to download.
|
* @param model - The model to download.
|
||||||
* @returns A Promise that resolves when the model has been downloaded.
|
* @returns A Promise that resolves when the model has been downloaded.
|
||||||
*/
|
*/
|
||||||
pullModel(model: string, id?: string): Promise<void>
|
pullModel(model: string, id?: string, name?: string): Promise<void>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cancels the download of a specific model.
|
* Cancels the download of a specific model.
|
||||||
@ -43,5 +43,5 @@ export interface ModelInterface {
|
|||||||
* @param model id of the model to import
|
* @param model id of the model to import
|
||||||
* @param modelPath - path of the model file
|
* @param modelPath - path of the model file
|
||||||
*/
|
*/
|
||||||
importModel(model: string, modePath: string): Promise<void>
|
importModel(model: string, modePath: string, name?: string): Promise<void>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
1.0.2-rc1
|
1.0.2-rc2
|
||||||
@ -8,8 +8,8 @@ import { extractInferenceParams } from '@janhq/core'
|
|||||||
interface ICortexAPI {
|
interface ICortexAPI {
|
||||||
getModel(model: string): Promise<Model>
|
getModel(model: string): Promise<Model>
|
||||||
getModels(): Promise<Model[]>
|
getModels(): Promise<Model[]>
|
||||||
pullModel(model: string, id?: string): Promise<void>
|
pullModel(model: string, id?: string, name?: string): Promise<void>
|
||||||
importModel(path: string, modelPath: string): Promise<void>
|
importModel(path: string, modelPath: string, name?: string): Promise<void>
|
||||||
deleteModel(model: string): Promise<void>
|
deleteModel(model: string): Promise<void>
|
||||||
updateModel(model: object): Promise<void>
|
updateModel(model: object): Promise<void>
|
||||||
cancelModelPull(model: string): Promise<void>
|
cancelModelPull(model: string): Promise<void>
|
||||||
@ -68,10 +68,10 @@ export class CortexAPI implements ICortexAPI {
|
|||||||
* @param model
|
* @param model
|
||||||
* @returns
|
* @returns
|
||||||
*/
|
*/
|
||||||
pullModel(model: string, id?: string): Promise<void> {
|
pullModel(model: string, id?: string, name?: string): Promise<void> {
|
||||||
return this.queue.add(() =>
|
return this.queue.add(() =>
|
||||||
ky
|
ky
|
||||||
.post(`${API_URL}/v1/models/pull`, { json: { model, id } })
|
.post(`${API_URL}/v1/models/pull`, { json: { model, id, name } })
|
||||||
.json()
|
.json()
|
||||||
.catch(async (e) => {
|
.catch(async (e) => {
|
||||||
throw (await e.response?.json()) ?? e
|
throw (await e.response?.json()) ?? e
|
||||||
@ -85,10 +85,10 @@ export class CortexAPI implements ICortexAPI {
|
|||||||
* @param model
|
* @param model
|
||||||
* @returns
|
* @returns
|
||||||
*/
|
*/
|
||||||
importModel(model: string, modelPath: string): Promise<void> {
|
importModel(model: string, modelPath: string, name?: string): Promise<void> {
|
||||||
return this.queue.add(() =>
|
return this.queue.add(() =>
|
||||||
ky
|
ky
|
||||||
.post(`${API_URL}/v1/models/import`, { json: { model, modelPath } })
|
.post(`${API_URL}/v1/models/import`, { json: { model, modelPath, name } })
|
||||||
.json()
|
.json()
|
||||||
.catch((e) => console.debug(e)) // Ignore error
|
.catch((e) => console.debug(e)) // Ignore error
|
||||||
.then()
|
.then()
|
||||||
|
|||||||
@ -58,7 +58,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
* @param model - The model to download.
|
* @param model - The model to download.
|
||||||
* @returns A Promise that resolves when the model is downloaded.
|
* @returns A Promise that resolves when the model is downloaded.
|
||||||
*/
|
*/
|
||||||
async pullModel(model: string, id?: string): Promise<void> {
|
async pullModel(model: string, id?: string, name?: string): Promise<void> {
|
||||||
if (id) {
|
if (id) {
|
||||||
const model: Model = ModelManager.instance().get(id)
|
const model: Model = ModelManager.instance().get(id)
|
||||||
// Clip vision model - should not be handled by cortex.cpp
|
// Clip vision model - should not be handled by cortex.cpp
|
||||||
@ -74,7 +74,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
/**
|
/**
|
||||||
* Sending POST to /models/pull/{id} endpoint to pull the model
|
* Sending POST to /models/pull/{id} endpoint to pull the model
|
||||||
*/
|
*/
|
||||||
return this.cortexAPI.pullModel(model, id)
|
return this.cortexAPI.pullModel(model, id, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -111,14 +111,12 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
* @returns A Promise that resolves when the model is deleted.
|
* @returns A Promise that resolves when the model is deleted.
|
||||||
*/
|
*/
|
||||||
async deleteModel(model: string): Promise<void> {
|
async deleteModel(model: string): Promise<void> {
|
||||||
const modelDto: Model = ModelManager.instance().get(model)
|
|
||||||
return this.cortexAPI
|
return this.cortexAPI
|
||||||
.deleteModel(model)
|
.deleteModel(model)
|
||||||
.catch((e) => console.debug(e))
|
.catch((e) => console.debug(e))
|
||||||
.finally(async () => {
|
.finally(async () => {
|
||||||
// Delete legacy model files
|
// Delete legacy model files
|
||||||
if (modelDto)
|
await deleteModelFiles(model).catch((e) => console.debug(e))
|
||||||
await deleteModelFiles(modelDto).catch((e) => console.debug(e))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,8 +225,12 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
* @param model
|
* @param model
|
||||||
* @param optionType
|
* @param optionType
|
||||||
*/
|
*/
|
||||||
async importModel(model: string, modelPath: string): Promise<void> {
|
async importModel(
|
||||||
return this.cortexAPI.importModel(model, modelPath)
|
model: string,
|
||||||
|
modelPath: string,
|
||||||
|
name?: string
|
||||||
|
): Promise<void> {
|
||||||
|
return this.cortexAPI.importModel(model, modelPath, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
import { fs, joinPath, Model } from '@janhq/core'
|
import { fs, joinPath } from '@janhq/core'
|
||||||
|
|
||||||
export const deleteModelFiles = async (model: Model) => {
|
export const deleteModelFiles = async (id: string) => {
|
||||||
try {
|
try {
|
||||||
const dirPath = await joinPath(['file://models', model.id])
|
const dirPath = await joinPath(['file://models', id])
|
||||||
// remove model folder directory
|
// remove model folder directory
|
||||||
await fs.unlinkSync(dirPath)
|
await fs.rm(dirPath)
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error(err)
|
console.error(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -40,7 +40,8 @@ describe('useDownloadModel', () => {
|
|||||||
|
|
||||||
expect(mockExtension.pullModel).toHaveBeenCalledWith(
|
expect(mockExtension.pullModel).toHaveBeenCalledWith(
|
||||||
mockModel.sources[0].url,
|
mockModel.sources[0].url,
|
||||||
mockModel.id
|
mockModel.id,
|
||||||
|
undefined
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -87,7 +88,8 @@ describe('useDownloadModel', () => {
|
|||||||
|
|
||||||
expect(mockExtension.pullModel).toHaveBeenCalledWith(
|
expect(mockExtension.pullModel).toHaveBeenCalledWith(
|
||||||
mockModel.sources[0].url,
|
mockModel.sources[0].url,
|
||||||
mockModel.id
|
mockModel.id,
|
||||||
|
undefined
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@ -18,9 +18,9 @@ export default function useDownloadModel() {
|
|||||||
const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
|
const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
|
||||||
|
|
||||||
const downloadModel = useCallback(
|
const downloadModel = useCallback(
|
||||||
async (model: string, id?: string) => {
|
async (model: string, id?: string, name?: string) => {
|
||||||
addDownloadingModel(id ?? model)
|
addDownloadingModel(id ?? model)
|
||||||
downloadLocalModel(model, id).catch((error) => {
|
downloadLocalModel(model, id, name).catch((error) => {
|
||||||
if (error.message) {
|
if (error.message) {
|
||||||
toaster({
|
toaster({
|
||||||
title: 'Download failed',
|
title: 'Download failed',
|
||||||
@ -45,10 +45,10 @@ export default function useDownloadModel() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const downloadLocalModel = async (model: string, id?: string) =>
|
const downloadLocalModel = async (model: string, id?: string, name?: string) =>
|
||||||
extensionManager
|
extensionManager
|
||||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||||
?.pullModel(model, id)
|
?.pullModel(model, id, name)
|
||||||
|
|
||||||
const cancelModelDownload = async (model: string) =>
|
const cancelModelDownload = async (model: string) =>
|
||||||
extensionManager
|
extensionManager
|
||||||
|
|||||||
@ -34,8 +34,8 @@ describe('useImportModel', () => {
|
|||||||
await result.current.importModels(models, 'local' as any)
|
await result.current.importModels(models, 'local' as any)
|
||||||
})
|
})
|
||||||
|
|
||||||
expect(mockImportModels).toHaveBeenCalledWith('1', '/path/to/model1')
|
expect(mockImportModels).toHaveBeenCalledWith('1', '/path/to/model1', undefined)
|
||||||
expect(mockImportModels).toHaveBeenCalledWith('2', '/path/to/model2')
|
expect(mockImportModels).toHaveBeenCalledWith('2', '/path/to/model2', undefined)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should update model info successfully', async () => {
|
it('should update model info successfully', async () => {
|
||||||
|
|||||||
@ -66,7 +66,7 @@ const useImportModel = () => {
|
|||||||
addDownloadingModel(modelId)
|
addDownloadingModel(modelId)
|
||||||
extensionManager
|
extensionManager
|
||||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||||
?.importModel(model.modelId, model.path)
|
?.importModel(model.modelId, model.path, model.name)
|
||||||
.finally(() => removeDownloadingModel(modelId))
|
.finally(() => removeDownloadingModel(modelId))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@ -64,7 +64,7 @@ const ModelItemHeader = ({ model, onClick, open }: Props) => {
|
|||||||
const assistants = useAtomValue(assistantsAtom)
|
const assistants = useAtomValue(assistantsAtom)
|
||||||
|
|
||||||
const onDownloadClick = useCallback(() => {
|
const onDownloadClick = useCallback(() => {
|
||||||
downloadModel(model.sources[0].url, model.id)
|
downloadModel(model.sources[0].url, model.id, model.name)
|
||||||
}, [model, downloadModel])
|
}, [model, downloadModel])
|
||||||
|
|
||||||
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null
|
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null
|
||||||
|
|||||||
@ -63,7 +63,11 @@ const ModelDownloadRow: React.FC<Props> = ({
|
|||||||
|
|
||||||
const onDownloadClick = useCallback(async () => {
|
const onDownloadClick = useCallback(async () => {
|
||||||
if (downloadUrl) {
|
if (downloadUrl) {
|
||||||
downloadModel(downloadUrl, normalizeModelId(downloadUrl))
|
downloadModel(
|
||||||
|
downloadUrl,
|
||||||
|
normalizeModelId(downloadUrl),
|
||||||
|
normalizeModelId(downloadUrl)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}, [downloadUrl, downloadModel])
|
}, [downloadUrl, downloadModel])
|
||||||
|
|
||||||
|
|||||||
@ -170,7 +170,8 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
|
|||||||
onClick={() =>
|
onClick={() =>
|
||||||
downloadModel(
|
downloadModel(
|
||||||
model.sources[0].url,
|
model.sources[0].url,
|
||||||
model.id
|
model.id,
|
||||||
|
model.name
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
@ -261,7 +262,8 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
|
|||||||
onClick={() =>
|
onClick={() =>
|
||||||
downloadModel(
|
downloadModel(
|
||||||
featModel.sources[0].url,
|
featModel.sources[0].url,
|
||||||
featModel.id
|
featModel.id,
|
||||||
|
featModel.name
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
>
|
>
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user