From 03e15fb70fa9bacd901dd3e31de49b31594c4f61 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 21 Oct 2024 12:18:14 +0700 Subject: [PATCH] feat: sync model hub and download progress from cortex.cpp --- .husky/pre-commit | 2 +- core/src/browser/extensions/model.ts | 2 +- core/src/types/model/modelInterface.ts | 5 +- core/src/types/monitoring/index.test.ts | 25 +++-- .../src/node/index.ts | 13 ++- extensions/model-extension/rollup.config.ts | 2 + .../model-extension/src/@types/global.d.ts | 2 + extensions/model-extension/src/cortex.ts | 24 ++--- extensions/model-extension/src/index.ts | 4 +- web/containers/ModalCancelDownload/index.tsx | 22 ++--- web/containers/ModelDropdown/index.tsx | 10 +- web/containers/Providers/EventListener.tsx | 10 ++ web/hooks/useDownloadModel.ts | 21 ++--- web/hooks/useGetHFRepoData.ts | 7 +- web/hooks/useSendChatMessage.ts | 3 +- .../Hub/ModelList/ModelHeader/index.tsx | 13 +-- .../ModelDownloadRow/index.tsx | 5 +- .../ChatBody/OnDeviceStarterScreen/index.tsx | 10 +- web/services/restService.ts | 2 +- web/utils/huggingface.ts | 93 +++++++++++++++++++ web/utils/model.ts | 3 + 21 files changed, 192 insertions(+), 86 deletions(-) create mode 100644 web/utils/huggingface.ts create mode 100644 web/utils/model.ts diff --git a/.husky/pre-commit b/.husky/pre-commit index a4aa5add4..177cd4216 100644 --- a/.husky/pre-commit +++ b/.husky/pre-commit @@ -1 +1 @@ -npm run lint --fix \ No newline at end of file +oxlint --fix || npm run lint --fix \ No newline at end of file diff --git a/core/src/browser/extensions/model.ts b/core/src/browser/extensions/model.ts index d111c1d3a..f3609b3b2 100644 --- a/core/src/browser/extensions/model.ts +++ b/core/src/browser/extensions/model.ts @@ -13,7 +13,7 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter } abstract getModels(): Promise - abstract pullModel(model: string): Promise + abstract pullModel(model: string, id?: string): Promise abstract cancelModelPull(modelId: string): Promise abstract importModel(model: string, modePath: string): Promise abstract updateModel(modelInfo: Partial): Promise diff --git a/core/src/types/model/modelInterface.ts b/core/src/types/model/modelInterface.ts index 088118f69..b676db949 100644 --- a/core/src/types/model/modelInterface.ts +++ b/core/src/types/model/modelInterface.ts @@ -1,5 +1,4 @@ import { Model } from './modelEntity' -import { OptionType } from './modelImport' /** * Model extension for managing models. @@ -10,14 +9,14 @@ export interface ModelInterface { * @param model - The model to download. * @returns A Promise that resolves when the model has been downloaded. */ - pullModel(model: string): Promise + pullModel(model: string, id?: string): Promise /** * Cancels the download of a specific model. * @param {string} modelId - The ID of the model to cancel the download for. * @returns {Promise} A promise that resolves when the download has been cancelled. */ - cancelModelPull(modelId: string): Promise + cancelModelPull(model: string): Promise /** * Deletes a model. diff --git a/core/src/types/monitoring/index.test.ts b/core/src/types/monitoring/index.test.ts index 010fcb97a..56c5879e4 100644 --- a/core/src/types/monitoring/index.test.ts +++ b/core/src/types/monitoring/index.test.ts @@ -1,16 +1,13 @@ +import * as monitoringInterface from './monitoringInterface' +import * as resourceInfo from './resourceInfo' -import * as monitoringInterface from './monitoringInterface'; -import * as resourceInfo from './resourceInfo'; +import * as index from './index' - import * as index from './index'; - import * as monitoringInterface from './monitoringInterface'; - import * as resourceInfo from './resourceInfo'; - - it('should re-export all symbols from monitoringInterface and resourceInfo', () => { - for (const key in monitoringInterface) { - expect(index[key]).toBe(monitoringInterface[key]); - } - for (const key in resourceInfo) { - expect(index[key]).toBe(resourceInfo[key]); - } - }); +it('should re-export all symbols from monitoringInterface and resourceInfo', () => { + for (const key in monitoringInterface) { + expect(index[key]).toBe(monitoringInterface[key]) + } + for (const key in resourceInfo) { + expect(index[key]).toBe(resourceInfo[key]) + } +}) diff --git a/extensions/inference-cortex-extension/src/node/index.ts b/extensions/inference-cortex-extension/src/node/index.ts index f1c365ade..788318c84 100644 --- a/extensions/inference-cortex-extension/src/node/index.ts +++ b/extensions/inference-cortex-extension/src/node/index.ts @@ -1,5 +1,5 @@ import path from 'path' -import { log, SystemInformation } from '@janhq/core/node' +import { getJanDataFolderPath, log, SystemInformation } from '@janhq/core/node' import { executableCortexFile } from './execute' import { ProcessWatchdog } from './watchdog' @@ -40,9 +40,18 @@ function run(systemInfo?: SystemInformation): Promise { executableOptions.enginePath ) + const dataFolderPath = getJanDataFolderPath() watchdog = new ProcessWatchdog( executableOptions.executablePath, - ['--start-server', '--port', LOCAL_PORT.toString()], + [ + '--start-server', + '--port', + LOCAL_PORT.toString(), + '--config_file_path', + `${path.join(dataFolderPath, '.janrc')}`, + '--data_folder_path', + dataFolderPath, + ], { cwd: executableOptions.enginePath, env: { diff --git a/extensions/model-extension/rollup.config.ts b/extensions/model-extension/rollup.config.ts index 6e506140f..781c4df84 100644 --- a/extensions/model-extension/rollup.config.ts +++ b/extensions/model-extension/rollup.config.ts @@ -20,6 +20,8 @@ export default [ replace({ preventAssignment: true, SETTINGS: JSON.stringify(settingJson), + API_URL: 'http://127.0.0.1:39291', + SOCKET_URL: 'ws://127.0.0.1:39291', }), // Allow json resolution json(), diff --git a/extensions/model-extension/src/@types/global.d.ts b/extensions/model-extension/src/@types/global.d.ts index 01bd272f2..bff3811e3 100644 --- a/extensions/model-extension/src/@types/global.d.ts +++ b/extensions/model-extension/src/@types/global.d.ts @@ -1,6 +1,8 @@ export {} declare global { declare const NODE: string + declare const API_URL: string + declare const SOCKET_URL: string interface Core { api: APIFunctions diff --git a/extensions/model-extension/src/cortex.ts b/extensions/model-extension/src/cortex.ts index 4945e4756..b0acd6d08 100644 --- a/extensions/model-extension/src/cortex.ts +++ b/extensions/model-extension/src/cortex.ts @@ -1,6 +1,7 @@ import PQueue from 'p-queue' import ky from 'ky' import { + DownloadEvent, events, Model, ModelEvent, @@ -13,18 +14,12 @@ import { interface ICortexAPI { getModel(model: string): Promise getModels(): Promise - pullModel(model: string): Promise + pullModel(model: string, id?: string): Promise importModel(path: string, modelPath: string): Promise deleteModel(model: string): Promise updateModel(model: object): Promise cancelModelPull(model: string): Promise } -/** - * Simple CortexAPI service - * It could be replaced by cortex client sdk later on - */ -const API_URL = 'http://127.0.0.1:39291' -const SOCKET_URL = 'ws://127.0.0.1:39291' type ModelList = { data: any[] @@ -71,10 +66,10 @@ export class CortexAPI implements ICortexAPI { * @param model * @returns */ - pullModel(model: string): Promise { + pullModel(model: string, id?: string): Promise { return this.queue.add(() => ky - .post(`${API_URL}/v1/models/pull`, { json: { model } }) + .post(`${API_URL}/v1/models/pull`, { json: { model, id } }) .json() .catch(async (e) => { throw (await e.response?.json()) ?? e @@ -160,7 +155,6 @@ export class CortexAPI implements ICortexAPI { () => new Promise((resolve) => { this.socket = new WebSocket(`${SOCKET_URL}/events`) - console.log('Socket connected') this.socket.addEventListener('message', (event) => { const data = JSON.parse(event.data) @@ -173,7 +167,7 @@ export class CortexAPI implements ICortexAPI { (accumulator, currentValue) => accumulator + currentValue.bytes, 0 ) - const percent = ((transferred ?? 1) / (total ?? 1)) * 100 + const percent = (transferred / total || 0) * 100 events.emit(data.type, { modelId: data.task.id, @@ -184,7 +178,13 @@ export class CortexAPI implements ICortexAPI { }, }) // Update models list from Hub - events.emit(ModelEvent.OnModelsUpdate, {}) + if (data.type === DownloadEvent.onFileDownloadSuccess) { + // Delay for the state update from cortex.cpp + // Just to be sure + setTimeout(() => { + events.emit(ModelEvent.OnModelsUpdate, {}) + }, 500) + } }) resolve() }) diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index c154c3754..38fd0634a 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -47,11 +47,11 @@ export default class JanModelExtension extends ModelExtension { * @param model - The model to download. * @returns A Promise that resolves when the model is downloaded. */ - async pullModel(model: string): Promise { + async pullModel(model: string, id?: string): Promise { /** * Sending POST to /models/pull/{id} endpoint to pull the model */ - return this.cortexAPI.pullModel(model) + return this.cortexAPI.pullModel(model, id) } /** diff --git a/web/containers/ModalCancelDownload/index.tsx b/web/containers/ModalCancelDownload/index.tsx index fdc583911..8a92c9279 100644 --- a/web/containers/ModalCancelDownload/index.tsx +++ b/web/containers/ModalCancelDownload/index.tsx @@ -4,7 +4,7 @@ import { Model } from '@janhq/core' import { Modal, Button, Progress, ModalClose } from '@janhq/joi' -import { useAtomValue } from 'jotai' +import { useAtomValue, useSetAtom } from 'jotai' import useDownloadModel from '@/hooks/useDownloadModel' @@ -12,7 +12,7 @@ import { modelDownloadStateAtom } from '@/hooks/useDownloadState' import { formatDownloadPercentage } from '@/utils/converter' -import { getDownloadingModelAtom } from '@/helpers/atoms/Model.atom' +import { removeDownloadingModelAtom } from '@/helpers/atoms/Model.atom' type Props = { model: Model @@ -21,20 +21,16 @@ type Props = { const ModalCancelDownload = ({ model, isFromList }: Props) => { const { abortModelDownload } = useDownloadModel() - const downloadingModels = useAtomValue(getDownloadingModelAtom) + const removeModelDownload = useSetAtom(removeDownloadingModelAtom) const allDownloadStates = useAtomValue(modelDownloadStateAtom) const downloadState = allDownloadStates[model.id] - const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}` + const cancelText = `Cancel ${formatDownloadPercentage(downloadState?.percent ?? 0)}` const onAbortDownloadClick = useCallback(() => { - if (downloadState?.modelId) { - const model = downloadingModels.find( - (model) => model === downloadState.modelId - ) - if (model) abortModelDownload(model) - } - }, [downloadState, downloadingModels, abortModelDownload]) + removeModelDownload(model.id) + abortModelDownload(downloadState?.modelId ?? model.id) + }, [downloadState, abortModelDownload, removeModelDownload, model]) return ( { - {formatDownloadPercentage(downloadState.percent)} + {formatDownloadPercentage(downloadState?.percent ?? 0)} diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx index 7415f1165..a5874b3de 100644 --- a/web/containers/ModelDropdown/index.tsx +++ b/web/containers/ModelDropdown/index.tsx @@ -472,7 +472,10 @@ const ModelDropdown = ({ size={18} className="cursor-pointer text-[hsla(var(--app-link))]" onClick={() => - downloadModel(model.sources[0].url) + downloadModel( + model.sources[0].url, + model.id + ) } /> ) : ( @@ -559,7 +562,10 @@ const ModelDropdown = ({ size={18} className="cursor-pointer text-[hsla(var(--app-link))]" onClick={() => - downloadModel(model.sources[0].url) + downloadModel( + model.sources[0].url, + model.id + ) } /> ) : ( diff --git a/web/containers/Providers/EventListener.tsx b/web/containers/Providers/EventListener.tsx index 1832256e2..5df59b0fd 100644 --- a/web/containers/Providers/EventListener.tsx +++ b/web/containers/Providers/EventListener.tsx @@ -23,11 +23,17 @@ import { removeInstallingExtensionAtom, setInstallingExtensionAtom, } from '@/helpers/atoms/Extension.atom' +import { + addDownloadingModelAtom, + removeDownloadingModelAtom, +} from '@/helpers/atoms/Model.atom' const EventListenerWrapper = ({ children }: PropsWithChildren) => { const setDownloadState = useSetAtom(setDownloadStateAtom) const setInstallingExtension = useSetAtom(setInstallingExtensionAtom) const removeInstallingExtension = useSetAtom(removeInstallingExtensionAtom) + const addDownloadingModel = useSetAtom(addDownloadingModelAtom) + const removeDownloadingModel = useSetAtom(removeDownloadingModelAtom) const onFileDownloadUpdate = useCallback( async (state: DownloadState) => { @@ -40,6 +46,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => { } setInstallingExtension(state.extensionId!, installingExtensionState) } else { + addDownloadingModel(state.modelId) setDownloadState(state) } }, @@ -54,6 +61,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => { } else { state.downloadState = 'error' setDownloadState(state) + removeDownloadingModel(state.modelId) } }, [setDownloadState, removeInstallingExtension] @@ -68,6 +76,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => { state.downloadState = 'error' state.error = 'aborted' setDownloadState(state) + removeDownloadingModel(state.modelId) } }, [setDownloadState, removeInstallingExtension] @@ -79,6 +88,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => { if (state.downloadType !== 'extension') { state.downloadState = 'end' setDownloadState(state) + removeDownloadingModel(state.modelId) } events.emit(ModelEvent.OnModelsUpdate, {}) }, diff --git a/web/hooks/useDownloadModel.ts b/web/hooks/useDownloadModel.ts index 82ce593e2..3b25cb86f 100644 --- a/web/hooks/useDownloadModel.ts +++ b/web/hooks/useDownloadModel.ts @@ -1,11 +1,6 @@ import { useCallback } from 'react' -import { - events, - ExtensionTypeEnum, - ModelEvent, - ModelExtension, -} from '@janhq/core' +import { ExtensionTypeEnum, ModelExtension } from '@janhq/core' import { useSetAtom } from 'jotai' @@ -19,13 +14,13 @@ import { } from '@/helpers/atoms/Model.atom' export default function useDownloadModel() { - const addDownloadingModel = useSetAtom(addDownloadingModelAtom) const removeDownloadingModel = useSetAtom(removeDownloadingModelAtom) + const addDownloadingModel = useSetAtom(addDownloadingModelAtom) const downloadModel = useCallback( - async (model: string) => { - addDownloadingModel(model) - localDownloadModel(model).catch((error) => { + async (model: string, id?: string) => { + addDownloadingModel(id ?? model) + downloadLocalModel(model, id).catch((error) => { if (error.message) { toaster({ title: 'Download failed', @@ -37,7 +32,7 @@ export default function useDownloadModel() { removeDownloadingModel(model) }) }, - [addDownloadingModel] + [removeDownloadingModel, addDownloadingModel] ) const abortModelDownload = useCallback(async (model: string) => { @@ -50,10 +45,10 @@ export default function useDownloadModel() { } } -const localDownloadModel = async (model: string) => +const downloadLocalModel = async (model: string, id?: string) => extensionManager .get(ExtensionTypeEnum.Model) - ?.pullModel(model) + ?.pullModel(model, id) const cancelModelDownload = async (model: string) => extensionManager diff --git a/web/hooks/useGetHFRepoData.ts b/web/hooks/useGetHFRepoData.ts index 4e3308116..6f2ec2b57 100644 --- a/web/hooks/useGetHFRepoData.ts +++ b/web/hooks/useGetHFRepoData.ts @@ -2,6 +2,8 @@ import { useCallback, useState } from 'react' import { HuggingFaceRepoData } from '@janhq/core' +import { fetchHuggingFaceRepoData } from '@/utils/huggingface' + export const useGetHFRepoData = () => { const [error, setError] = useState(undefined) const [loading, setLoading] = useState(false) @@ -29,8 +31,5 @@ export const useGetHFRepoData = () => { const extensionGetHfRepoData = async ( repoId: string ): Promise => { - return Promise.resolve(undefined) - // return extensionManager - // .get(ExtensionTypeEnum.Model) - // ?.fetchHuggingFaceRepoData(repoId) + return fetchHuggingFaceRepoData(repoId) } diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index bab515a30..4bc91cad2 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -216,7 +216,7 @@ export default function useSendChatMessage() { ...activeThreadRef.current, updated: newMessage.created, metadata: { - ...(activeThreadRef.current.metadata ?? {}), + ...activeThreadRef.current.metadata, lastMessage: prompt, }, } @@ -256,7 +256,6 @@ export default function useSendChatMessage() { ) request.messages = normalizeMessages(request.messages ?? []) - console.log(requestBuilder.model?.engine ?? modelRequest.engine, request) // Request for inference EngineManager.instance() .get(requestBuilder.model?.engine ?? modelRequest.engine ?? '') diff --git a/web/screens/Hub/ModelList/ModelHeader/index.tsx b/web/screens/Hub/ModelList/ModelHeader/index.tsx index ce5a12957..725b0216a 100644 --- a/web/screens/Hub/ModelList/ModelHeader/index.tsx +++ b/web/screens/Hub/ModelList/ModelHeader/index.tsx @@ -64,7 +64,7 @@ const ModelItemHeader = ({ model, onClick, open }: Props) => { const assistants = useAtomValue(assistantsAtom) const onDownloadClick = useCallback(() => { - downloadModel(model.sources[0].url) + downloadModel(model.sources[0].url, model.id) }, [model, downloadModel]) const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null @@ -123,17 +123,6 @@ const ModelItemHeader = ({ model, onClick, open }: Props) => { className="cursor-pointer rounded-t-md bg-[hsla(var(--app-bg))]" onClick={onClick} > - {/* TODO: @faisal are we still using cover? */} - {/* {model.metadata.cover && imageLoaded && ( -
- setImageLoaded(false)} - src={model.metadata.cover} - className="h-[250px] w-full object-cover" - alt={`Cover - ${model.id}`} - /> -
- )} */}
diff --git a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx index 454905332..03413006f 100644 --- a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx +++ b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx @@ -20,6 +20,7 @@ import { assistantsAtom } from '@/helpers/atoms/Assistant.atom' import { importHuggingFaceModelStageAtom } from '@/helpers/atoms/HuggingFace.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' +import { normalizeModelId } from '@/utils/model' type Props = { index: number @@ -50,13 +51,13 @@ const ModelDownloadRow: React.FC = ({ const onAbortDownloadClick = useCallback(() => { if (downloadUrl) { - abortModelDownload(downloadUrl) + abortModelDownload(normalizeModelId(downloadUrl)) } }, [downloadUrl, abortModelDownload]) const onDownloadClick = useCallback(async () => { if (downloadUrl) { - downloadModel(downloadUrl) + downloadModel(downloadUrl, normalizeModelId(downloadUrl)) } }, [downloadUrl, downloadModel]) diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx index 0adc7ddd4..366575a40 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx @@ -168,7 +168,10 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => { size={18} className="cursor-pointer text-[hsla(var(--app-link))]" onClick={() => - downloadModel(model.sources[0].url) + downloadModel( + model.sources[0].url, + model.id + ) } /> ) : ( @@ -256,7 +259,10 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => { theme="ghost" className="!bg-[hsla(var(--secondary-bg))]" onClick={() => - downloadModel(featModel.sources[0].url) + downloadModel( + featModel.sources[0].url, + featModel.id + ) } > Download diff --git a/web/services/restService.ts b/web/services/restService.ts index 73348caeb..3c1cfc6a8 100644 --- a/web/services/restService.ts +++ b/web/services/restService.ts @@ -9,7 +9,7 @@ export function openExternalUrl(url: string) { } // Define API routes based on different route types -export const APIRoutes = [...CoreRoutes.map((r) => ({ path: `app`, route: r }))] +export const APIRoutes = CoreRoutes.map((r) => ({ path: `app`, route: r })) // Define the restAPI object with methods for each API route export const restAPI = { diff --git a/web/utils/huggingface.ts b/web/utils/huggingface.ts new file mode 100644 index 000000000..328d684e6 --- /dev/null +++ b/web/utils/huggingface.ts @@ -0,0 +1,93 @@ +import { AllQuantizations, getFileSize, HuggingFaceRepoData } from '@janhq/core' + +export const fetchHuggingFaceRepoData = async ( + repoId: string, + huggingFaceAccessToken?: string +): Promise => { + const sanitizedUrl = toHuggingFaceUrl(repoId) + console.debug('sanitizedUrl', sanitizedUrl) + + const headers: Record = { + Accept: 'application/json', + } + + if (huggingFaceAccessToken && huggingFaceAccessToken.length > 0) { + headers['Authorization'] = `Bearer ${huggingFaceAccessToken}` + } + + const res = await fetch(sanitizedUrl, { + headers: headers, + }) + const response = await res.json() + if (response['error'] != null) { + throw new Error(response['error']) + } + + const data = response as HuggingFaceRepoData + + if (data.tags.indexOf('gguf') === -1) { + throw new Error( + `${repoId} is not supported. Only GGUF models are supported.` + ) + } + + const promises: Promise[] = [] + + // fetching file sizes + const url = new URL(sanitizedUrl) + const paths = url.pathname.split('/').filter((e) => e.trim().length > 0) + + for (const sibling of data.siblings) { + const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${sibling.rfilename}` + sibling.downloadUrl = downloadUrl + promises.push(getFileSize(downloadUrl)) + } + + const result = await Promise.all(promises) + for (let i = 0; i < data.siblings.length; i++) { + data.siblings[i].fileSize = result[i] + } + + AllQuantizations.forEach((quantization) => { + data.siblings.forEach((sibling) => { + if (!sibling.quantization && sibling.rfilename.includes(quantization)) { + sibling.quantization = quantization + } + }) + }) + + data.modelUrl = `https://huggingface.co/${paths[2]}/${paths[3]}` + return data +} + +function toHuggingFaceUrl(repoId: string): string { + try { + const url = new URL(repoId) + if (url.host !== 'huggingface.co') { + throw new InvalidHostError(`Invalid Hugging Face repo URL: ${repoId}`) + } + + const paths = url.pathname.split('/').filter((e) => e.trim().length > 0) + if (paths.length < 2) { + throw new InvalidHostError(`Invalid Hugging Face repo URL: ${repoId}`) + } + + return `${url.origin}/api/models/${paths[0]}/${paths[1]}` + } catch (err) { + if (err instanceof InvalidHostError) { + throw err + } + + if (repoId.startsWith('https')) { + throw new Error(`Cannot parse url: ${repoId}`) + } + + return `https://huggingface.co/api/models/${repoId}` + } +} +class InvalidHostError extends Error { + constructor(message: string) { + super(message) + this.name = 'InvalidHostError' + } +} diff --git a/web/utils/model.ts b/web/utils/model.ts new file mode 100644 index 000000000..00efc1155 --- /dev/null +++ b/web/utils/model.ts @@ -0,0 +1,3 @@ +export const normalizeModelId = (downloadUrl: string): string => { + return downloadUrl.split('/').pop() ?? downloadUrl +}