feat: sync model hub and download progress from cortex.cpp
This commit is contained in:
parent
f44f291bd8
commit
03e15fb70f
@ -1 +1 @@
|
||||
npm run lint --fix
|
||||
oxlint --fix || npm run lint --fix
|
||||
@ -13,7 +13,7 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
|
||||
}
|
||||
|
||||
abstract getModels(): Promise<Model[]>
|
||||
abstract pullModel(model: string): Promise<void>
|
||||
abstract pullModel(model: string, id?: string): Promise<void>
|
||||
abstract cancelModelPull(modelId: string): Promise<void>
|
||||
abstract importModel(model: string, modePath: string): Promise<void>
|
||||
abstract updateModel(modelInfo: Partial<Model>): Promise<Model>
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import { Model } from './modelEntity'
|
||||
import { OptionType } from './modelImport'
|
||||
|
||||
/**
|
||||
* Model extension for managing models.
|
||||
@ -10,14 +9,14 @@ export interface ModelInterface {
|
||||
* @param model - The model to download.
|
||||
* @returns A Promise that resolves when the model has been downloaded.
|
||||
*/
|
||||
pullModel(model: string): Promise<void>
|
||||
pullModel(model: string, id?: string): Promise<void>
|
||||
|
||||
/**
|
||||
* Cancels the download of a specific model.
|
||||
* @param {string} modelId - The ID of the model to cancel the download for.
|
||||
* @returns {Promise<void>} A promise that resolves when the download has been cancelled.
|
||||
*/
|
||||
cancelModelPull(modelId: string): Promise<void>
|
||||
cancelModelPull(model: string): Promise<void>
|
||||
|
||||
/**
|
||||
* Deletes a model.
|
||||
|
||||
@ -1,16 +1,13 @@
|
||||
import * as monitoringInterface from './monitoringInterface'
|
||||
import * as resourceInfo from './resourceInfo'
|
||||
|
||||
import * as monitoringInterface from './monitoringInterface';
|
||||
import * as resourceInfo from './resourceInfo';
|
||||
import * as index from './index'
|
||||
|
||||
import * as index from './index';
|
||||
import * as monitoringInterface from './monitoringInterface';
|
||||
import * as resourceInfo from './resourceInfo';
|
||||
|
||||
it('should re-export all symbols from monitoringInterface and resourceInfo', () => {
|
||||
for (const key in monitoringInterface) {
|
||||
expect(index[key]).toBe(monitoringInterface[key]);
|
||||
}
|
||||
for (const key in resourceInfo) {
|
||||
expect(index[key]).toBe(resourceInfo[key]);
|
||||
}
|
||||
});
|
||||
it('should re-export all symbols from monitoringInterface and resourceInfo', () => {
|
||||
for (const key in monitoringInterface) {
|
||||
expect(index[key]).toBe(monitoringInterface[key])
|
||||
}
|
||||
for (const key in resourceInfo) {
|
||||
expect(index[key]).toBe(resourceInfo[key])
|
||||
}
|
||||
})
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import path from 'path'
|
||||
import { log, SystemInformation } from '@janhq/core/node'
|
||||
import { getJanDataFolderPath, log, SystemInformation } from '@janhq/core/node'
|
||||
import { executableCortexFile } from './execute'
|
||||
import { ProcessWatchdog } from './watchdog'
|
||||
|
||||
@ -40,9 +40,18 @@ function run(systemInfo?: SystemInformation): Promise<any> {
|
||||
executableOptions.enginePath
|
||||
)
|
||||
|
||||
const dataFolderPath = getJanDataFolderPath()
|
||||
watchdog = new ProcessWatchdog(
|
||||
executableOptions.executablePath,
|
||||
['--start-server', '--port', LOCAL_PORT.toString()],
|
||||
[
|
||||
'--start-server',
|
||||
'--port',
|
||||
LOCAL_PORT.toString(),
|
||||
'--config_file_path',
|
||||
`${path.join(dataFolderPath, '.janrc')}`,
|
||||
'--data_folder_path',
|
||||
dataFolderPath,
|
||||
],
|
||||
{
|
||||
cwd: executableOptions.enginePath,
|
||||
env: {
|
||||
|
||||
@ -20,6 +20,8 @@ export default [
|
||||
replace({
|
||||
preventAssignment: true,
|
||||
SETTINGS: JSON.stringify(settingJson),
|
||||
API_URL: 'http://127.0.0.1:39291',
|
||||
SOCKET_URL: 'ws://127.0.0.1:39291',
|
||||
}),
|
||||
// Allow json resolution
|
||||
json(),
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
export {}
|
||||
declare global {
|
||||
declare const NODE: string
|
||||
declare const API_URL: string
|
||||
declare const SOCKET_URL: string
|
||||
|
||||
interface Core {
|
||||
api: APIFunctions
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import PQueue from 'p-queue'
|
||||
import ky from 'ky'
|
||||
import {
|
||||
DownloadEvent,
|
||||
events,
|
||||
Model,
|
||||
ModelEvent,
|
||||
@ -13,18 +14,12 @@ import {
|
||||
interface ICortexAPI {
|
||||
getModel(model: string): Promise<Model>
|
||||
getModels(): Promise<Model[]>
|
||||
pullModel(model: string): Promise<void>
|
||||
pullModel(model: string, id?: string): Promise<void>
|
||||
importModel(path: string, modelPath: string): Promise<void>
|
||||
deleteModel(model: string): Promise<void>
|
||||
updateModel(model: object): Promise<void>
|
||||
cancelModelPull(model: string): Promise<void>
|
||||
}
|
||||
/**
|
||||
* Simple CortexAPI service
|
||||
* It could be replaced by cortex client sdk later on
|
||||
*/
|
||||
const API_URL = 'http://127.0.0.1:39291'
|
||||
const SOCKET_URL = 'ws://127.0.0.1:39291'
|
||||
|
||||
type ModelList = {
|
||||
data: any[]
|
||||
@ -71,10 +66,10 @@ export class CortexAPI implements ICortexAPI {
|
||||
* @param model
|
||||
* @returns
|
||||
*/
|
||||
pullModel(model: string): Promise<void> {
|
||||
pullModel(model: string, id?: string): Promise<void> {
|
||||
return this.queue.add(() =>
|
||||
ky
|
||||
.post(`${API_URL}/v1/models/pull`, { json: { model } })
|
||||
.post(`${API_URL}/v1/models/pull`, { json: { model, id } })
|
||||
.json()
|
||||
.catch(async (e) => {
|
||||
throw (await e.response?.json()) ?? e
|
||||
@ -160,7 +155,6 @@ export class CortexAPI implements ICortexAPI {
|
||||
() =>
|
||||
new Promise<void>((resolve) => {
|
||||
this.socket = new WebSocket(`${SOCKET_URL}/events`)
|
||||
console.log('Socket connected')
|
||||
|
||||
this.socket.addEventListener('message', (event) => {
|
||||
const data = JSON.parse(event.data)
|
||||
@ -173,7 +167,7 @@ export class CortexAPI implements ICortexAPI {
|
||||
(accumulator, currentValue) => accumulator + currentValue.bytes,
|
||||
0
|
||||
)
|
||||
const percent = ((transferred ?? 1) / (total ?? 1)) * 100
|
||||
const percent = (transferred / total || 0) * 100
|
||||
|
||||
events.emit(data.type, {
|
||||
modelId: data.task.id,
|
||||
@ -184,7 +178,13 @@ export class CortexAPI implements ICortexAPI {
|
||||
},
|
||||
})
|
||||
// Update models list from Hub
|
||||
events.emit(ModelEvent.OnModelsUpdate, {})
|
||||
if (data.type === DownloadEvent.onFileDownloadSuccess) {
|
||||
// Delay for the state update from cortex.cpp
|
||||
// Just to be sure
|
||||
setTimeout(() => {
|
||||
events.emit(ModelEvent.OnModelsUpdate, {})
|
||||
}, 500)
|
||||
}
|
||||
})
|
||||
resolve()
|
||||
})
|
||||
|
||||
@ -47,11 +47,11 @@ export default class JanModelExtension extends ModelExtension {
|
||||
* @param model - The model to download.
|
||||
* @returns A Promise that resolves when the model is downloaded.
|
||||
*/
|
||||
async pullModel(model: string): Promise<void> {
|
||||
async pullModel(model: string, id?: string): Promise<void> {
|
||||
/**
|
||||
* Sending POST to /models/pull/{id} endpoint to pull the model
|
||||
*/
|
||||
return this.cortexAPI.pullModel(model)
|
||||
return this.cortexAPI.pullModel(model, id)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -4,7 +4,7 @@ import { Model } from '@janhq/core'
|
||||
|
||||
import { Modal, Button, Progress, ModalClose } from '@janhq/joi'
|
||||
|
||||
import { useAtomValue } from 'jotai'
|
||||
import { useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import useDownloadModel from '@/hooks/useDownloadModel'
|
||||
|
||||
@ -12,7 +12,7 @@ import { modelDownloadStateAtom } from '@/hooks/useDownloadState'
|
||||
|
||||
import { formatDownloadPercentage } from '@/utils/converter'
|
||||
|
||||
import { getDownloadingModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import { removeDownloadingModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
|
||||
type Props = {
|
||||
model: Model
|
||||
@ -21,20 +21,16 @@ type Props = {
|
||||
|
||||
const ModalCancelDownload = ({ model, isFromList }: Props) => {
|
||||
const { abortModelDownload } = useDownloadModel()
|
||||
const downloadingModels = useAtomValue(getDownloadingModelAtom)
|
||||
const removeModelDownload = useSetAtom(removeDownloadingModelAtom)
|
||||
const allDownloadStates = useAtomValue(modelDownloadStateAtom)
|
||||
const downloadState = allDownloadStates[model.id]
|
||||
|
||||
const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}`
|
||||
const cancelText = `Cancel ${formatDownloadPercentage(downloadState?.percent ?? 0)}`
|
||||
|
||||
const onAbortDownloadClick = useCallback(() => {
|
||||
if (downloadState?.modelId) {
|
||||
const model = downloadingModels.find(
|
||||
(model) => model === downloadState.modelId
|
||||
)
|
||||
if (model) abortModelDownload(model)
|
||||
}
|
||||
}, [downloadState, downloadingModels, abortModelDownload])
|
||||
removeModelDownload(model.id)
|
||||
abortModelDownload(downloadState?.modelId ?? model.id)
|
||||
}, [downloadState, abortModelDownload, removeModelDownload, model])
|
||||
|
||||
return (
|
||||
<Modal
|
||||
@ -51,13 +47,13 @@ const ModalCancelDownload = ({ model, isFromList }: Props) => {
|
||||
<Progress
|
||||
className="w-[80px]"
|
||||
value={
|
||||
formatDownloadPercentage(downloadState?.percent, {
|
||||
formatDownloadPercentage(downloadState?.percent ?? 0, {
|
||||
hidePercentage: true,
|
||||
}) as number
|
||||
}
|
||||
/>
|
||||
<span className="tabular-nums">
|
||||
{formatDownloadPercentage(downloadState.percent)}
|
||||
{formatDownloadPercentage(downloadState?.percent ?? 0)}
|
||||
</span>
|
||||
</div>
|
||||
</Button>
|
||||
|
||||
@ -472,7 +472,10 @@ const ModelDropdown = ({
|
||||
size={18}
|
||||
className="cursor-pointer text-[hsla(var(--app-link))]"
|
||||
onClick={() =>
|
||||
downloadModel(model.sources[0].url)
|
||||
downloadModel(
|
||||
model.sources[0].url,
|
||||
model.id
|
||||
)
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
@ -559,7 +562,10 @@ const ModelDropdown = ({
|
||||
size={18}
|
||||
className="cursor-pointer text-[hsla(var(--app-link))]"
|
||||
onClick={() =>
|
||||
downloadModel(model.sources[0].url)
|
||||
downloadModel(
|
||||
model.sources[0].url,
|
||||
model.id
|
||||
)
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
|
||||
@ -23,11 +23,17 @@ import {
|
||||
removeInstallingExtensionAtom,
|
||||
setInstallingExtensionAtom,
|
||||
} from '@/helpers/atoms/Extension.atom'
|
||||
import {
|
||||
addDownloadingModelAtom,
|
||||
removeDownloadingModelAtom,
|
||||
} from '@/helpers/atoms/Model.atom'
|
||||
|
||||
const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
||||
const setDownloadState = useSetAtom(setDownloadStateAtom)
|
||||
const setInstallingExtension = useSetAtom(setInstallingExtensionAtom)
|
||||
const removeInstallingExtension = useSetAtom(removeInstallingExtensionAtom)
|
||||
const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
|
||||
const removeDownloadingModel = useSetAtom(removeDownloadingModelAtom)
|
||||
|
||||
const onFileDownloadUpdate = useCallback(
|
||||
async (state: DownloadState) => {
|
||||
@ -40,6 +46,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
||||
}
|
||||
setInstallingExtension(state.extensionId!, installingExtensionState)
|
||||
} else {
|
||||
addDownloadingModel(state.modelId)
|
||||
setDownloadState(state)
|
||||
}
|
||||
},
|
||||
@ -54,6 +61,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
||||
} else {
|
||||
state.downloadState = 'error'
|
||||
setDownloadState(state)
|
||||
removeDownloadingModel(state.modelId)
|
||||
}
|
||||
},
|
||||
[setDownloadState, removeInstallingExtension]
|
||||
@ -68,6 +76,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
||||
state.downloadState = 'error'
|
||||
state.error = 'aborted'
|
||||
setDownloadState(state)
|
||||
removeDownloadingModel(state.modelId)
|
||||
}
|
||||
},
|
||||
[setDownloadState, removeInstallingExtension]
|
||||
@ -79,6 +88,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
||||
if (state.downloadType !== 'extension') {
|
||||
state.downloadState = 'end'
|
||||
setDownloadState(state)
|
||||
removeDownloadingModel(state.modelId)
|
||||
}
|
||||
events.emit(ModelEvent.OnModelsUpdate, {})
|
||||
},
|
||||
|
||||
@ -1,11 +1,6 @@
|
||||
import { useCallback } from 'react'
|
||||
|
||||
import {
|
||||
events,
|
||||
ExtensionTypeEnum,
|
||||
ModelEvent,
|
||||
ModelExtension,
|
||||
} from '@janhq/core'
|
||||
import { ExtensionTypeEnum, ModelExtension } from '@janhq/core'
|
||||
|
||||
import { useSetAtom } from 'jotai'
|
||||
|
||||
@ -19,13 +14,13 @@ import {
|
||||
} from '@/helpers/atoms/Model.atom'
|
||||
|
||||
export default function useDownloadModel() {
|
||||
const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
|
||||
const removeDownloadingModel = useSetAtom(removeDownloadingModelAtom)
|
||||
const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
|
||||
|
||||
const downloadModel = useCallback(
|
||||
async (model: string) => {
|
||||
addDownloadingModel(model)
|
||||
localDownloadModel(model).catch((error) => {
|
||||
async (model: string, id?: string) => {
|
||||
addDownloadingModel(id ?? model)
|
||||
downloadLocalModel(model, id).catch((error) => {
|
||||
if (error.message) {
|
||||
toaster({
|
||||
title: 'Download failed',
|
||||
@ -37,7 +32,7 @@ export default function useDownloadModel() {
|
||||
removeDownloadingModel(model)
|
||||
})
|
||||
},
|
||||
[addDownloadingModel]
|
||||
[removeDownloadingModel, addDownloadingModel]
|
||||
)
|
||||
|
||||
const abortModelDownload = useCallback(async (model: string) => {
|
||||
@ -50,10 +45,10 @@ export default function useDownloadModel() {
|
||||
}
|
||||
}
|
||||
|
||||
const localDownloadModel = async (model: string) =>
|
||||
const downloadLocalModel = async (model: string, id?: string) =>
|
||||
extensionManager
|
||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||
?.pullModel(model)
|
||||
?.pullModel(model, id)
|
||||
|
||||
const cancelModelDownload = async (model: string) =>
|
||||
extensionManager
|
||||
|
||||
@ -2,6 +2,8 @@ import { useCallback, useState } from 'react'
|
||||
|
||||
import { HuggingFaceRepoData } from '@janhq/core'
|
||||
|
||||
import { fetchHuggingFaceRepoData } from '@/utils/huggingface'
|
||||
|
||||
export const useGetHFRepoData = () => {
|
||||
const [error, setError] = useState<string | undefined>(undefined)
|
||||
const [loading, setLoading] = useState(false)
|
||||
@ -29,8 +31,5 @@ export const useGetHFRepoData = () => {
|
||||
const extensionGetHfRepoData = async (
|
||||
repoId: string
|
||||
): Promise<HuggingFaceRepoData | undefined> => {
|
||||
return Promise.resolve(undefined)
|
||||
// return extensionManager
|
||||
// .get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||
// ?.fetchHuggingFaceRepoData(repoId)
|
||||
return fetchHuggingFaceRepoData(repoId)
|
||||
}
|
||||
|
||||
@ -216,7 +216,7 @@ export default function useSendChatMessage() {
|
||||
...activeThreadRef.current,
|
||||
updated: newMessage.created,
|
||||
metadata: {
|
||||
...(activeThreadRef.current.metadata ?? {}),
|
||||
...activeThreadRef.current.metadata,
|
||||
lastMessage: prompt,
|
||||
},
|
||||
}
|
||||
@ -256,7 +256,6 @@ export default function useSendChatMessage() {
|
||||
)
|
||||
request.messages = normalizeMessages(request.messages ?? [])
|
||||
|
||||
console.log(requestBuilder.model?.engine ?? modelRequest.engine, request)
|
||||
// Request for inference
|
||||
EngineManager.instance()
|
||||
.get(requestBuilder.model?.engine ?? modelRequest.engine ?? '')
|
||||
|
||||
@ -64,7 +64,7 @@ const ModelItemHeader = ({ model, onClick, open }: Props) => {
|
||||
const assistants = useAtomValue(assistantsAtom)
|
||||
|
||||
const onDownloadClick = useCallback(() => {
|
||||
downloadModel(model.sources[0].url)
|
||||
downloadModel(model.sources[0].url, model.id)
|
||||
}, [model, downloadModel])
|
||||
|
||||
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null
|
||||
@ -123,17 +123,6 @@ const ModelItemHeader = ({ model, onClick, open }: Props) => {
|
||||
className="cursor-pointer rounded-t-md bg-[hsla(var(--app-bg))]"
|
||||
onClick={onClick}
|
||||
>
|
||||
{/* TODO: @faisal are we still using cover? */}
|
||||
{/* {model.metadata.cover && imageLoaded && (
|
||||
<div className="relative h-full w-full">
|
||||
<img
|
||||
onError={() => setImageLoaded(false)}
|
||||
src={model.metadata.cover}
|
||||
className="h-[250px] w-full object-cover"
|
||||
alt={`Cover - ${model.id}`}
|
||||
/>
|
||||
</div>
|
||||
)} */}
|
||||
<div className="flex items-center justify-between px-4 py-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="line-clamp-1 text-base font-semibold">
|
||||
|
||||
@ -20,6 +20,7 @@ import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
|
||||
|
||||
import { importHuggingFaceModelStageAtom } from '@/helpers/atoms/HuggingFace.atom'
|
||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
import { normalizeModelId } from '@/utils/model'
|
||||
|
||||
type Props = {
|
||||
index: number
|
||||
@ -50,13 +51,13 @@ const ModelDownloadRow: React.FC<Props> = ({
|
||||
|
||||
const onAbortDownloadClick = useCallback(() => {
|
||||
if (downloadUrl) {
|
||||
abortModelDownload(downloadUrl)
|
||||
abortModelDownload(normalizeModelId(downloadUrl))
|
||||
}
|
||||
}, [downloadUrl, abortModelDownload])
|
||||
|
||||
const onDownloadClick = useCallback(async () => {
|
||||
if (downloadUrl) {
|
||||
downloadModel(downloadUrl)
|
||||
downloadModel(downloadUrl, normalizeModelId(downloadUrl))
|
||||
}
|
||||
}, [downloadUrl, downloadModel])
|
||||
|
||||
|
||||
@ -168,7 +168,10 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
|
||||
size={18}
|
||||
className="cursor-pointer text-[hsla(var(--app-link))]"
|
||||
onClick={() =>
|
||||
downloadModel(model.sources[0].url)
|
||||
downloadModel(
|
||||
model.sources[0].url,
|
||||
model.id
|
||||
)
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
@ -256,7 +259,10 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
|
||||
theme="ghost"
|
||||
className="!bg-[hsla(var(--secondary-bg))]"
|
||||
onClick={() =>
|
||||
downloadModel(featModel.sources[0].url)
|
||||
downloadModel(
|
||||
featModel.sources[0].url,
|
||||
featModel.id
|
||||
)
|
||||
}
|
||||
>
|
||||
Download
|
||||
|
||||
@ -9,7 +9,7 @@ export function openExternalUrl(url: string) {
|
||||
}
|
||||
|
||||
// Define API routes based on different route types
|
||||
export const APIRoutes = [...CoreRoutes.map((r) => ({ path: `app`, route: r }))]
|
||||
export const APIRoutes = CoreRoutes.map((r) => ({ path: `app`, route: r }))
|
||||
|
||||
// Define the restAPI object with methods for each API route
|
||||
export const restAPI = {
|
||||
|
||||
93
web/utils/huggingface.ts
Normal file
93
web/utils/huggingface.ts
Normal file
@ -0,0 +1,93 @@
|
||||
import { AllQuantizations, getFileSize, HuggingFaceRepoData } from '@janhq/core'
|
||||
|
||||
export const fetchHuggingFaceRepoData = async (
|
||||
repoId: string,
|
||||
huggingFaceAccessToken?: string
|
||||
): Promise<HuggingFaceRepoData> => {
|
||||
const sanitizedUrl = toHuggingFaceUrl(repoId)
|
||||
console.debug('sanitizedUrl', sanitizedUrl)
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
Accept: 'application/json',
|
||||
}
|
||||
|
||||
if (huggingFaceAccessToken && huggingFaceAccessToken.length > 0) {
|
||||
headers['Authorization'] = `Bearer ${huggingFaceAccessToken}`
|
||||
}
|
||||
|
||||
const res = await fetch(sanitizedUrl, {
|
||||
headers: headers,
|
||||
})
|
||||
const response = await res.json()
|
||||
if (response['error'] != null) {
|
||||
throw new Error(response['error'])
|
||||
}
|
||||
|
||||
const data = response as HuggingFaceRepoData
|
||||
|
||||
if (data.tags.indexOf('gguf') === -1) {
|
||||
throw new Error(
|
||||
`${repoId} is not supported. Only GGUF models are supported.`
|
||||
)
|
||||
}
|
||||
|
||||
const promises: Promise<number>[] = []
|
||||
|
||||
// fetching file sizes
|
||||
const url = new URL(sanitizedUrl)
|
||||
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)
|
||||
|
||||
for (const sibling of data.siblings) {
|
||||
const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${sibling.rfilename}`
|
||||
sibling.downloadUrl = downloadUrl
|
||||
promises.push(getFileSize(downloadUrl))
|
||||
}
|
||||
|
||||
const result = await Promise.all(promises)
|
||||
for (let i = 0; i < data.siblings.length; i++) {
|
||||
data.siblings[i].fileSize = result[i]
|
||||
}
|
||||
|
||||
AllQuantizations.forEach((quantization) => {
|
||||
data.siblings.forEach((sibling) => {
|
||||
if (!sibling.quantization && sibling.rfilename.includes(quantization)) {
|
||||
sibling.quantization = quantization
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
data.modelUrl = `https://huggingface.co/${paths[2]}/${paths[3]}`
|
||||
return data
|
||||
}
|
||||
|
||||
function toHuggingFaceUrl(repoId: string): string {
|
||||
try {
|
||||
const url = new URL(repoId)
|
||||
if (url.host !== 'huggingface.co') {
|
||||
throw new InvalidHostError(`Invalid Hugging Face repo URL: ${repoId}`)
|
||||
}
|
||||
|
||||
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)
|
||||
if (paths.length < 2) {
|
||||
throw new InvalidHostError(`Invalid Hugging Face repo URL: ${repoId}`)
|
||||
}
|
||||
|
||||
return `${url.origin}/api/models/${paths[0]}/${paths[1]}`
|
||||
} catch (err) {
|
||||
if (err instanceof InvalidHostError) {
|
||||
throw err
|
||||
}
|
||||
|
||||
if (repoId.startsWith('https')) {
|
||||
throw new Error(`Cannot parse url: ${repoId}`)
|
||||
}
|
||||
|
||||
return `https://huggingface.co/api/models/${repoId}`
|
||||
}
|
||||
}
|
||||
class InvalidHostError extends Error {
|
||||
constructor(message: string) {
|
||||
super(message)
|
||||
this.name = 'InvalidHostError'
|
||||
}
|
||||
}
|
||||
3
web/utils/model.ts
Normal file
3
web/utils/model.ts
Normal file
@ -0,0 +1,3 @@
|
||||
export const normalizeModelId = (downloadUrl: string): string => {
|
||||
return downloadUrl.split('/').pop() ?? downloadUrl
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user