fix: allow users to download the same model from different authors (#6577)
* fix: allow users to download the same model from different authors * fix: importing models should have author name in the ID * fix: incorrect model id show * fix: tests * fix: default to mmproj f16 instead of bf16 * fix: type * fix: build error
This commit is contained in:
parent
fe05478336
commit
57110d2bd7
@ -240,6 +240,12 @@ export abstract class AIEngine extends BaseExtension {
|
||||
EngineManager.instance().register(this)
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets model info
|
||||
* @param modelId
|
||||
*/
|
||||
abstract get(modelId: string): Promise<modelInfo | undefined>
|
||||
|
||||
/**
|
||||
* Lists available models
|
||||
*/
|
||||
|
||||
@ -22,7 +22,7 @@ export default class JanProviderWeb extends AIEngine {
|
||||
|
||||
override async onLoad() {
|
||||
console.log('Loading Jan Provider Extension...')
|
||||
|
||||
|
||||
try {
|
||||
// Initialize authentication and fetch models
|
||||
await janApiClient.initialize()
|
||||
@ -37,20 +37,43 @@ export default class JanProviderWeb extends AIEngine {
|
||||
|
||||
override async onUnload() {
|
||||
console.log('Unloading Jan Provider Extension...')
|
||||
|
||||
|
||||
// Clear all sessions
|
||||
for (const sessionId of this.activeSessions.keys()) {
|
||||
await this.unload(sessionId)
|
||||
}
|
||||
|
||||
|
||||
janProviderStore.reset()
|
||||
console.log('Jan Provider Extension unloaded')
|
||||
}
|
||||
|
||||
async get(modelId: string): Promise<modelInfo | undefined> {
|
||||
return janApiClient
|
||||
.getModels()
|
||||
.then((list) => list.find((e) => e.id === modelId))
|
||||
.then((model) =>
|
||||
model
|
||||
? {
|
||||
id: model.id,
|
||||
name: model.id, // Use ID as name for now
|
||||
quant_type: undefined,
|
||||
providerId: this.provider,
|
||||
port: 443, // HTTPS port for API
|
||||
sizeBytes: 0, // Size not provided by Jan API
|
||||
tags: [],
|
||||
path: undefined, // Remote model, no local path
|
||||
owned_by: model.owned_by,
|
||||
object: model.object,
|
||||
capabilities: ['tools'], // Jan models support both tools via MCP
|
||||
}
|
||||
: undefined
|
||||
)
|
||||
}
|
||||
|
||||
async list(): Promise<modelInfo[]> {
|
||||
try {
|
||||
const janModels = await janApiClient.getModels()
|
||||
|
||||
|
||||
return janModels.map((model) => ({
|
||||
id: model.id,
|
||||
name: model.id, // Use ID as name for now
|
||||
@ -75,7 +98,7 @@ export default class JanProviderWeb extends AIEngine {
|
||||
// For Jan API, we don't actually "load" models in the traditional sense
|
||||
// We just create a session reference for tracking
|
||||
const sessionId = `jan-${modelId}-${Date.now()}`
|
||||
|
||||
|
||||
const sessionInfo: SessionInfo = {
|
||||
pid: Date.now(), // Use timestamp as pseudo-PID
|
||||
port: 443, // HTTPS port
|
||||
@ -85,8 +108,10 @@ export default class JanProviderWeb extends AIEngine {
|
||||
}
|
||||
|
||||
this.activeSessions.set(sessionId, sessionInfo)
|
||||
|
||||
console.log(`Jan model session created: ${sessionId} for model ${modelId}`)
|
||||
|
||||
console.log(
|
||||
`Jan model session created: ${sessionId} for model ${modelId}`
|
||||
)
|
||||
return sessionInfo
|
||||
} catch (error) {
|
||||
console.error(`Failed to load Jan model ${modelId}:`, error)
|
||||
@ -97,23 +122,23 @@ export default class JanProviderWeb extends AIEngine {
|
||||
async unload(sessionId: string): Promise<UnloadResult> {
|
||||
try {
|
||||
const session = this.activeSessions.get(sessionId)
|
||||
|
||||
|
||||
if (!session) {
|
||||
return {
|
||||
success: false,
|
||||
error: `Session ${sessionId} not found`
|
||||
error: `Session ${sessionId} not found`,
|
||||
}
|
||||
}
|
||||
|
||||
this.activeSessions.delete(sessionId)
|
||||
console.log(`Jan model session unloaded: ${sessionId}`)
|
||||
|
||||
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
console.error(`Failed to unload Jan session ${sessionId}:`, error)
|
||||
return {
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : 'Unknown error'
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -136,9 +161,12 @@ export default class JanProviderWeb extends AIEngine {
|
||||
}
|
||||
|
||||
// Convert core chat completion request to Jan API format
|
||||
const janMessages: JanChatMessage[] = opts.messages.map(msg => ({
|
||||
const janMessages: JanChatMessage[] = opts.messages.map((msg) => ({
|
||||
role: msg.role as 'system' | 'user' | 'assistant',
|
||||
content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content)
|
||||
content:
|
||||
typeof msg.content === 'string'
|
||||
? msg.content
|
||||
: JSON.stringify(msg.content),
|
||||
}))
|
||||
|
||||
const janRequest = {
|
||||
@ -162,18 +190,18 @@ export default class JanProviderWeb extends AIEngine {
|
||||
} else {
|
||||
// Return single response
|
||||
const response = await janApiClient.createChatCompletion(janRequest)
|
||||
|
||||
|
||||
// Check if aborted after completion
|
||||
if (abortController?.signal?.aborted) {
|
||||
throw new Error('Request was aborted')
|
||||
}
|
||||
|
||||
|
||||
return {
|
||||
id: response.id,
|
||||
object: 'chat.completion' as const,
|
||||
created: response.created,
|
||||
model: response.model,
|
||||
choices: response.choices.map(choice => ({
|
||||
choices: response.choices.map((choice) => ({
|
||||
index: choice.index,
|
||||
message: {
|
||||
role: choice.message.role,
|
||||
@ -182,7 +210,12 @@ export default class JanProviderWeb extends AIEngine {
|
||||
reasoning_content: choice.message.reasoning_content,
|
||||
tool_calls: choice.message.tool_calls,
|
||||
},
|
||||
finish_reason: (choice.finish_reason || 'stop') as 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call',
|
||||
finish_reason: (choice.finish_reason || 'stop') as
|
||||
| 'stop'
|
||||
| 'length'
|
||||
| 'tool_calls'
|
||||
| 'content_filter'
|
||||
| 'function_call',
|
||||
})),
|
||||
usage: response.usage,
|
||||
}
|
||||
@ -193,7 +226,10 @@ export default class JanProviderWeb extends AIEngine {
|
||||
}
|
||||
}
|
||||
|
||||
private async *createStreamingGenerator(janRequest: any, abortController?: AbortController) {
|
||||
private async *createStreamingGenerator(
|
||||
janRequest: any,
|
||||
abortController?: AbortController
|
||||
) {
|
||||
let resolve: () => void
|
||||
let reject: (error: Error) => void
|
||||
const chunks: any[] = []
|
||||
@ -231,7 +267,7 @@ export default class JanProviderWeb extends AIEngine {
|
||||
object: chunk.object,
|
||||
created: chunk.created,
|
||||
model: chunk.model,
|
||||
choices: chunk.choices.map(choice => ({
|
||||
choices: chunk.choices.map((choice) => ({
|
||||
index: choice.index,
|
||||
delta: {
|
||||
role: choice.delta.role,
|
||||
@ -261,14 +297,14 @@ export default class JanProviderWeb extends AIEngine {
|
||||
if (abortController?.signal?.aborted) {
|
||||
throw new Error('Request was aborted')
|
||||
}
|
||||
|
||||
|
||||
while (yieldedIndex < chunks.length) {
|
||||
yield chunks[yieldedIndex]
|
||||
yieldedIndex++
|
||||
}
|
||||
|
||||
|
||||
// Wait a bit before checking again
|
||||
await new Promise(resolve => setTimeout(resolve, 10))
|
||||
await new Promise((resolve) => setTimeout(resolve, 10))
|
||||
}
|
||||
|
||||
// Yield any remaining chunks
|
||||
@ -291,24 +327,32 @@ export default class JanProviderWeb extends AIEngine {
|
||||
}
|
||||
|
||||
async delete(modelId: string): Promise<void> {
|
||||
throw new Error(`Delete operation not supported for remote Jan API model: ${modelId}`)
|
||||
throw new Error(
|
||||
`Delete operation not supported for remote Jan API model: ${modelId}`
|
||||
)
|
||||
}
|
||||
|
||||
async import(modelId: string, _opts: ImportOptions): Promise<void> {
|
||||
throw new Error(`Import operation not supported for remote Jan API model: ${modelId}`)
|
||||
throw new Error(
|
||||
`Import operation not supported for remote Jan API model: ${modelId}`
|
||||
)
|
||||
}
|
||||
|
||||
async abortImport(modelId: string): Promise<void> {
|
||||
throw new Error(`Abort import operation not supported for remote Jan API model: ${modelId}`)
|
||||
throw new Error(
|
||||
`Abort import operation not supported for remote Jan API model: ${modelId}`
|
||||
)
|
||||
}
|
||||
|
||||
async getLoadedModels(): Promise<string[]> {
|
||||
return Array.from(this.activeSessions.values()).map(session => session.model_id)
|
||||
return Array.from(this.activeSessions.values()).map(
|
||||
(session) => session.model_id
|
||||
)
|
||||
}
|
||||
|
||||
async isToolSupported(modelId: string): Promise<boolean> {
|
||||
// Jan models support tool calls via MCP
|
||||
console.log(`Checking tool support for Jan model ${modelId}: supported`);
|
||||
return true;
|
||||
console.log(`Checking tool support for Jan model ${modelId}: supported`)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -922,6 +922,30 @@ export default class llamacpp_extension extends AIEngine {
|
||||
return hash
|
||||
}
|
||||
|
||||
override async get(modelId: string): Promise<modelInfo | undefined> {
|
||||
const modelPath = await joinPath([
|
||||
await this.getProviderPath(),
|
||||
'models',
|
||||
modelId,
|
||||
])
|
||||
const path = await joinPath([modelPath, 'model.yml'])
|
||||
|
||||
if (!(await fs.existsSync(path))) return undefined
|
||||
|
||||
const modelConfig = await invoke<ModelConfig>('read_yaml', {
|
||||
path,
|
||||
})
|
||||
|
||||
return {
|
||||
id: modelId,
|
||||
name: modelConfig.name ?? modelId,
|
||||
quant_type: undefined, // TODO: parse quantization type from model.yml or model.gguf
|
||||
providerId: this.provider,
|
||||
port: 0, // port is not known until the model is loaded
|
||||
sizeBytes: modelConfig.size_bytes ?? 0,
|
||||
} as modelInfo
|
||||
}
|
||||
|
||||
// Implement the required LocalProvider interface methods
|
||||
override async list(): Promise<modelInfo[]> {
|
||||
const modelsDir = await joinPath([await this.getProviderPath(), 'models'])
|
||||
@ -1085,7 +1109,10 @@ export default class llamacpp_extension extends AIEngine {
|
||||
const archiveName = await basename(path)
|
||||
logger.info(`Installing backend from path: ${path}`)
|
||||
|
||||
if (!(await fs.existsSync(path)) || (!path.endsWith('tar.gz') && !path.endsWith('zip'))) {
|
||||
if (
|
||||
!(await fs.existsSync(path)) ||
|
||||
(!path.endsWith('tar.gz') && !path.endsWith('zip'))
|
||||
) {
|
||||
logger.error(`Invalid path or file ${path}`)
|
||||
throw new Error(`Invalid path or file ${path}`)
|
||||
}
|
||||
@ -2601,7 +2628,8 @@ export default class llamacpp_extension extends AIEngine {
|
||||
metadata: Record<string, string>
|
||||
): Promise<number> {
|
||||
// Extract vision parameters from metadata
|
||||
const projectionDim = Math.floor(Number(metadata['clip.vision.projection_dim']) / 10) || 256
|
||||
const projectionDim =
|
||||
Math.floor(Number(metadata['clip.vision.projection_dim']) / 10) || 256
|
||||
|
||||
// Count images in messages
|
||||
let imageCount = 0
|
||||
|
||||
142
web-app/src/containers/DownloadButton.tsx
Normal file
142
web-app/src/containers/DownloadButton.tsx
Normal file
@ -0,0 +1,142 @@
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
import { useDownloadStore } from '@/hooks/useDownloadStore'
|
||||
import { useGeneralSetting } from '@/hooks/useGeneralSetting'
|
||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||
import { useServiceHub } from '@/hooks/useServiceHub'
|
||||
import { useTranslation } from '@/i18n'
|
||||
import { extractModelName } from '@/lib/models'
|
||||
import { cn, sanitizeModelId } from '@/lib/utils'
|
||||
import { CatalogModel } from '@/services/models/types'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { useShallow } from 'zustand/shallow'
|
||||
|
||||
type ModelProps = {
|
||||
model: CatalogModel
|
||||
handleUseModel: (modelId: string) => void
|
||||
}
|
||||
const defaultModelQuantizations = ['iq4_xs', 'q4_k_m']
|
||||
|
||||
export function DownloadButtonPlaceholder({
|
||||
model,
|
||||
handleUseModel,
|
||||
}: ModelProps) {
|
||||
const { downloads, localDownloadingModels, addLocalDownloadingModel } =
|
||||
useDownloadStore(
|
||||
useShallow((state) => ({
|
||||
downloads: state.downloads,
|
||||
localDownloadingModels: state.localDownloadingModels,
|
||||
addLocalDownloadingModel: state.addLocalDownloadingModel,
|
||||
}))
|
||||
)
|
||||
const { t } = useTranslation()
|
||||
const getProviderByName = useModelProvider((state) => state.getProviderByName)
|
||||
const llamaProvider = getProviderByName('llamacpp')
|
||||
|
||||
const serviceHub = useServiceHub()
|
||||
const huggingfaceToken = useGeneralSetting((state) => state.huggingfaceToken)
|
||||
|
||||
const quant =
|
||||
model.quants.find((e) =>
|
||||
defaultModelQuantizations.some((m) =>
|
||||
e.model_id.toLowerCase().includes(m)
|
||||
)
|
||||
) ?? model.quants[0]
|
||||
|
||||
const modelId = quant?.model_id || model.model_name
|
||||
|
||||
const downloadProcesses = useMemo(
|
||||
() =>
|
||||
Object.values(downloads).map((download) => ({
|
||||
id: download.name,
|
||||
name: download.name,
|
||||
progress: download.progress,
|
||||
current: download.current,
|
||||
total: download.total,
|
||||
})),
|
||||
[downloads]
|
||||
)
|
||||
|
||||
const isRecommendedModel = useCallback((modelId: string) => {
|
||||
return (extractModelName(modelId)?.toLowerCase() ===
|
||||
'jan-nano-gguf') as boolean
|
||||
}, [])
|
||||
|
||||
if (model.quants.length === 0) {
|
||||
return (
|
||||
<div className="flex items-center gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
window.open(`https://huggingface.co/${model.model_name}`, '_blank')
|
||||
}}
|
||||
>
|
||||
View on HuggingFace
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const modelUrl = quant?.path || modelId
|
||||
const isDownloading =
|
||||
localDownloadingModels.has(modelId) ||
|
||||
downloadProcesses.some((e) => e.id === modelId)
|
||||
|
||||
const downloadProgress =
|
||||
downloadProcesses.find((e) => e.id === modelId)?.progress || 0
|
||||
const isDownloaded = llamaProvider?.models.some(
|
||||
(m: { id: string }) =>
|
||||
m.id === modelId ||
|
||||
m.id === `${model.developer}/${sanitizeModelId(modelId)}`
|
||||
)
|
||||
const isRecommended = isRecommendedModel(model.model_name)
|
||||
|
||||
const handleDownload = () => {
|
||||
// Immediately set local downloading state
|
||||
addLocalDownloadingModel(modelId)
|
||||
const mmprojPath = (
|
||||
model.mmproj_models?.find(
|
||||
(e) => e.model_id.toLowerCase() === 'mmproj-f16'
|
||||
) || model.mmproj_models?.[0]
|
||||
)?.path
|
||||
serviceHub
|
||||
.models()
|
||||
.pullModelWithMetadata(modelId, modelUrl, mmprojPath, huggingfaceToken)
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex items-center',
|
||||
isRecommended && 'hub-download-button-step'
|
||||
)}
|
||||
>
|
||||
{isDownloading && !isDownloaded && (
|
||||
<div className={cn('flex items-center gap-2 w-20')}>
|
||||
<Progress value={downloadProgress * 100} />
|
||||
<span className="text-xs text-center text-main-view-fg/70">
|
||||
{Math.round(downloadProgress * 100)}%
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
{isDownloaded ? (
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => handleUseModel(modelId)}
|
||||
data-test-id={`hub-model-${modelId}`}
|
||||
>
|
||||
{t('hub:use')}
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
data-test-id={`hub-model-${modelId}`}
|
||||
size="sm"
|
||||
onClick={handleDownload}
|
||||
className={cn(isDownloading && 'hidden')}
|
||||
>
|
||||
{t('hub:download')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@ -89,6 +89,7 @@ const CodeComponent = memo(
|
||||
onCopy,
|
||||
copiedId,
|
||||
...props
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}: any) => {
|
||||
const { t } = useTranslation()
|
||||
const match = /language-(\w+)/.exec(className || '')
|
||||
|
||||
@ -21,10 +21,7 @@ import { useEffect, useMemo, useCallback, useState } from 'react'
|
||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||
import { useDownloadStore } from '@/hooks/useDownloadStore'
|
||||
import { useServiceHub } from '@/hooks/useServiceHub'
|
||||
import type {
|
||||
CatalogModel,
|
||||
ModelQuant,
|
||||
} from '@/services/models/types'
|
||||
import type { CatalogModel, ModelQuant } from '@/services/models/types'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { cn } from '@/lib/utils'
|
||||
@ -80,12 +77,13 @@ function HubModelDetailContent() {
|
||||
}, [fetchSources])
|
||||
|
||||
const fetchRepo = useCallback(async () => {
|
||||
const repoInfo = await serviceHub.models().fetchHuggingFaceRepo(
|
||||
search.repo || modelId,
|
||||
huggingfaceToken
|
||||
)
|
||||
const repoInfo = await serviceHub
|
||||
.models()
|
||||
.fetchHuggingFaceRepo(search.repo || modelId, huggingfaceToken)
|
||||
if (repoInfo) {
|
||||
const repoDetail = serviceHub.models().convertHfRepoToCatalogModel(repoInfo)
|
||||
const repoDetail = serviceHub
|
||||
.models()
|
||||
.convertHfRepoToCatalogModel(repoInfo)
|
||||
setRepoData(repoDetail || undefined)
|
||||
}
|
||||
}, [serviceHub, modelId, search, huggingfaceToken])
|
||||
@ -168,7 +166,9 @@ function HubModelDetailContent() {
|
||||
try {
|
||||
// Use the HuggingFace path for the model
|
||||
const modelPath = variant.path
|
||||
const supported = await serviceHub.models().isModelSupported(modelPath, 8192)
|
||||
const supported = await serviceHub
|
||||
.models()
|
||||
.isModelSupported(modelPath, 8192)
|
||||
setModelSupportStatus((prev) => ({
|
||||
...prev,
|
||||
[modelKey]: supported,
|
||||
@ -473,12 +473,20 @@ function HubModelDetailContent() {
|
||||
addLocalDownloadingModel(
|
||||
variant.model_id
|
||||
)
|
||||
serviceHub.models().pullModelWithMetadata(
|
||||
variant.model_id,
|
||||
variant.path,
|
||||
modelData.mmproj_models?.[0]?.path,
|
||||
huggingfaceToken
|
||||
)
|
||||
serviceHub
|
||||
.models()
|
||||
.pullModelWithMetadata(
|
||||
variant.model_id,
|
||||
variant.path,
|
||||
(
|
||||
modelData.mmproj_models?.find(
|
||||
(e) =>
|
||||
e.model_id.toLowerCase() ===
|
||||
'mmproj-f16'
|
||||
) || modelData.mmproj_models?.[0]
|
||||
)?.path,
|
||||
huggingfaceToken
|
||||
)
|
||||
}}
|
||||
className={cn(isDownloading && 'hidden')}
|
||||
>
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import { useVirtualizer } from '@tanstack/react-virtual'
|
||||
import { createFileRoute, useNavigate, useSearch } from '@tanstack/react-router'
|
||||
import { createFileRoute, useNavigate } from '@tanstack/react-router'
|
||||
import { route } from '@/constants/routes'
|
||||
import { useModelSources } from '@/hooks/useModelSources'
|
||||
import { cn } from '@/lib/utils'
|
||||
@ -34,8 +34,6 @@ import {
|
||||
TooltipTrigger,
|
||||
} from '@/components/ui/tooltip'
|
||||
import { ModelInfoHoverCard } from '@/containers/ModelInfoHoverCard'
|
||||
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
|
||||
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
@ -51,10 +49,9 @@ import { Loader } from 'lucide-react'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import Fuse from 'fuse.js'
|
||||
import { useGeneralSetting } from '@/hooks/useGeneralSetting'
|
||||
import { DownloadButtonPlaceholder } from '@/containers/DownloadButton'
|
||||
import { useShallow } from 'zustand/shallow'
|
||||
|
||||
type ModelProps = {
|
||||
model: CatalogModel
|
||||
}
|
||||
type SearchParams = {
|
||||
repo: string
|
||||
}
|
||||
@ -77,7 +74,7 @@ function Hub() {
|
||||
|
||||
function HubContent() {
|
||||
const parentRef = useRef(null)
|
||||
const { huggingfaceToken } = useGeneralSetting()
|
||||
const huggingfaceToken = useGeneralSetting((state) => state.huggingfaceToken)
|
||||
const serviceHub = useServiceHub()
|
||||
|
||||
const { t } = useTranslation()
|
||||
@ -93,7 +90,13 @@ function HubContent() {
|
||||
}
|
||||
}, [])
|
||||
|
||||
const { sources, fetchSources, loading } = useModelSources()
|
||||
const { sources, fetchSources, loading } = useModelSources(
|
||||
useShallow((state) => ({
|
||||
sources: state.sources,
|
||||
fetchSources: state.fetchSources,
|
||||
loading: state.loading,
|
||||
}))
|
||||
)
|
||||
|
||||
const [searchValue, setSearchValue] = useState('')
|
||||
const [sortSelected, setSortSelected] = useState('newest')
|
||||
@ -108,16 +111,9 @@ function HubContent() {
|
||||
const [modelSupportStatus, setModelSupportStatus] = useState<
|
||||
Record<string, 'RED' | 'YELLOW' | 'GREEN' | 'LOADING'>
|
||||
>({})
|
||||
const [joyrideReady, setJoyrideReady] = useState(false)
|
||||
const [currentStepIndex, setCurrentStepIndex] = useState(0)
|
||||
const addModelSourceTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
|
||||
null
|
||||
)
|
||||
const downloadButtonRef = useRef<HTMLButtonElement>(null)
|
||||
const hasTriggeredDownload = useRef(false)
|
||||
|
||||
const { getProviderByName } = useModelProvider()
|
||||
const llamaProvider = getProviderByName('llamacpp')
|
||||
|
||||
const toggleModelExpansion = (modelId: string) => {
|
||||
setExpandedModels((prev) => ({
|
||||
@ -168,9 +164,10 @@ function HubContent() {
|
||||
?.map((model) => ({
|
||||
...model,
|
||||
quants: model.quants.filter((variant) =>
|
||||
llamaProvider?.models.some(
|
||||
(m: { id: string }) => m.id === variant.model_id
|
||||
)
|
||||
useModelProvider
|
||||
.getState()
|
||||
.getProviderByName('llamacpp')
|
||||
?.models.some((m: { id: string }) => m.id === variant.model_id)
|
||||
),
|
||||
}))
|
||||
.filter((model) => model.quants.length > 0)
|
||||
@ -186,7 +183,6 @@ function HubContent() {
|
||||
showOnlyDownloaded,
|
||||
huggingFaceRepo,
|
||||
searchOptions,
|
||||
llamaProvider?.models,
|
||||
])
|
||||
|
||||
// The virtualizer
|
||||
@ -215,9 +211,13 @@ function HubContent() {
|
||||
|
||||
addModelSourceTimeoutRef.current = setTimeout(async () => {
|
||||
try {
|
||||
const repoInfo = await serviceHub.models().fetchHuggingFaceRepo(searchValue, huggingfaceToken)
|
||||
const repoInfo = await serviceHub
|
||||
.models()
|
||||
.fetchHuggingFaceRepo(searchValue, huggingfaceToken)
|
||||
if (repoInfo) {
|
||||
const catalogModel = serviceHub.models().convertHfRepoToCatalogModel(repoInfo)
|
||||
const catalogModel = serviceHub
|
||||
.models()
|
||||
.convertHfRepoToCatalogModel(repoInfo)
|
||||
if (
|
||||
!sources.some(
|
||||
(s) =>
|
||||
@ -303,7 +303,9 @@ function HubContent() {
|
||||
try {
|
||||
// Use the HuggingFace path for the model
|
||||
const modelPath = variant.path
|
||||
const supportStatus = await serviceHub.models().isModelSupported(modelPath, 8192)
|
||||
const supportStatus = await serviceHub
|
||||
.models()
|
||||
.isModelSupported(modelPath, 8192)
|
||||
|
||||
setModelSupportStatus((prev) => ({
|
||||
...prev,
|
||||
@ -320,178 +322,7 @@ function HubContent() {
|
||||
[modelSupportStatus, serviceHub]
|
||||
)
|
||||
|
||||
const DownloadButtonPlaceholder = useMemo(() => {
|
||||
return ({ model }: ModelProps) => {
|
||||
// Check if this is a HuggingFace repository (no quants)
|
||||
if (model.quants.length === 0) {
|
||||
return (
|
||||
<div className="flex items-center gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
window.open(
|
||||
`https://huggingface.co/${model.model_name}`,
|
||||
'_blank'
|
||||
)
|
||||
}}
|
||||
>
|
||||
View on HuggingFace
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const quant =
|
||||
model.quants.find((e) =>
|
||||
defaultModelQuantizations.some((m) =>
|
||||
e.model_id.toLowerCase().includes(m)
|
||||
)
|
||||
) ?? model.quants[0]
|
||||
const modelId = quant?.model_id || model.model_name
|
||||
const modelUrl = quant?.path || modelId
|
||||
const isDownloading =
|
||||
localDownloadingModels.has(modelId) ||
|
||||
downloadProcesses.some((e) => e.id === modelId)
|
||||
const downloadProgress =
|
||||
downloadProcesses.find((e) => e.id === modelId)?.progress || 0
|
||||
const isDownloaded = llamaProvider?.models.some(
|
||||
(m: { id: string }) => m.id === modelId
|
||||
)
|
||||
const isRecommended = isRecommendedModel(model.model_name)
|
||||
|
||||
const handleDownload = () => {
|
||||
// Immediately set local downloading state
|
||||
addLocalDownloadingModel(modelId)
|
||||
const mmprojPath = model.mmproj_models?.[0]?.path
|
||||
serviceHub.models().pullModelWithMetadata(
|
||||
modelId,
|
||||
modelUrl,
|
||||
mmprojPath,
|
||||
huggingfaceToken
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex items-center',
|
||||
isRecommended && 'hub-download-button-step'
|
||||
)}
|
||||
>
|
||||
{isDownloading && !isDownloaded && (
|
||||
<div className={cn('flex items-center gap-2 w-20')}>
|
||||
<Progress value={downloadProgress * 100} />
|
||||
<span className="text-xs text-center text-main-view-fg/70">
|
||||
{Math.round(downloadProgress * 100)}%
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
{isDownloaded ? (
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => handleUseModel(modelId)}
|
||||
data-test-id={`hub-model-${modelId}`}
|
||||
>
|
||||
{t('hub:use')}
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
data-test-id={`hub-model-${modelId}`}
|
||||
size="sm"
|
||||
onClick={handleDownload}
|
||||
className={cn(isDownloading && 'hidden')}
|
||||
ref={isRecommended ? downloadButtonRef : undefined}
|
||||
>
|
||||
{t('hub:download')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
}, [
|
||||
localDownloadingModels,
|
||||
downloadProcesses,
|
||||
llamaProvider?.models,
|
||||
isRecommendedModel,
|
||||
t,
|
||||
addLocalDownloadingModel,
|
||||
huggingfaceToken,
|
||||
handleUseModel,
|
||||
serviceHub,
|
||||
])
|
||||
|
||||
const { step } = useSearch({ from: Route.id })
|
||||
const isSetup = step === 'setup_local_provider'
|
||||
|
||||
// Wait for DOM to be ready before starting Joyride
|
||||
useEffect(() => {
|
||||
if (!loading && filteredModels.length > 0 && isSetup) {
|
||||
const timer = setTimeout(() => {
|
||||
setJoyrideReady(true)
|
||||
}, 100)
|
||||
return () => clearTimeout(timer)
|
||||
} else {
|
||||
setJoyrideReady(false)
|
||||
}
|
||||
}, [loading, filteredModels.length, isSetup])
|
||||
|
||||
const handleJoyrideCallback = (data: CallBackProps) => {
|
||||
const { status, index } = data
|
||||
|
||||
if (
|
||||
status === STATUS.FINISHED &&
|
||||
!isDownloading &&
|
||||
isLastStep &&
|
||||
!hasTriggeredDownload.current
|
||||
) {
|
||||
const recommendedModel = filteredModels.find((model) =>
|
||||
isRecommendedModel(model.model_name)
|
||||
)
|
||||
if (recommendedModel && recommendedModel.quants[0]?.model_id) {
|
||||
if (downloadButtonRef.current) {
|
||||
hasTriggeredDownload.current = true
|
||||
downloadButtonRef.current.click()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if (status === STATUS.FINISHED) {
|
||||
navigate({
|
||||
to: route.hub.index,
|
||||
})
|
||||
}
|
||||
|
||||
// Track current step index
|
||||
setCurrentStepIndex(index)
|
||||
}
|
||||
|
||||
// Check if any model is currently downloading
|
||||
const isDownloading =
|
||||
localDownloadingModels.size > 0 || downloadProcesses.length > 0
|
||||
|
||||
const steps = [
|
||||
{
|
||||
target: '.hub-model-card-step',
|
||||
title: t('hub:joyride.recommendedModelTitle'),
|
||||
disableBeacon: true,
|
||||
content: t('hub:joyride.recommendedModelContent'),
|
||||
},
|
||||
{
|
||||
target: '.hub-download-button-step',
|
||||
title: isDownloading
|
||||
? t('hub:joyride.downloadInProgressTitle')
|
||||
: t('hub:joyride.downloadModelTitle'),
|
||||
disableBeacon: true,
|
||||
content: isDownloading
|
||||
? t('hub:joyride.downloadInProgressContent')
|
||||
: t('hub:joyride.downloadModelContent'),
|
||||
},
|
||||
]
|
||||
|
||||
// Check if we're on the last step
|
||||
const isLastStep = currentStepIndex === steps.length - 1
|
||||
|
||||
const renderFilter = () => {
|
||||
return (
|
||||
<>
|
||||
@ -544,31 +375,6 @@ function HubContent() {
|
||||
|
||||
return (
|
||||
<>
|
||||
<Joyride
|
||||
run={joyrideReady}
|
||||
floaterProps={{
|
||||
hideArrow: true,
|
||||
}}
|
||||
steps={steps}
|
||||
tooltipComponent={CustomTooltipJoyRide}
|
||||
spotlightPadding={0}
|
||||
continuous={true}
|
||||
showSkipButton={!isLastStep}
|
||||
hideCloseButton={true}
|
||||
spotlightClicks={true}
|
||||
disableOverlay={IS_LINUX}
|
||||
disableOverlayClose={true}
|
||||
callback={handleJoyrideCallback}
|
||||
locale={{
|
||||
back: t('hub:joyride.back'),
|
||||
close: t('hub:joyride.close'),
|
||||
last: !isDownloading
|
||||
? t('hub:joyride.lastWithDownload')
|
||||
: t('hub:joyride.last'),
|
||||
next: t('hub:joyride.next'),
|
||||
skip: t('hub:joyride.skip'),
|
||||
}}
|
||||
/>
|
||||
<div className="flex h-full w-full">
|
||||
<div className="flex flex-col h-full w-full ">
|
||||
<HeaderPage>
|
||||
@ -698,6 +504,7 @@ function HubContent() {
|
||||
/>
|
||||
<DownloadButtonPlaceholder
|
||||
model={filteredModels[virtualItem.index]}
|
||||
handleUseModel={handleUseModel}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@ -908,10 +715,13 @@ function HubContent() {
|
||||
(e) => e.id === variant.model_id
|
||||
)?.progress || 0
|
||||
const isDownloaded =
|
||||
llamaProvider?.models.some(
|
||||
(m: { id: string }) =>
|
||||
m.id === variant.model_id
|
||||
)
|
||||
useModelProvider
|
||||
.getState()
|
||||
.getProviderByName('llamacpp')
|
||||
?.models.some(
|
||||
(m: { id: string }) =>
|
||||
m.id === variant.model_id
|
||||
)
|
||||
|
||||
if (isDownloading) {
|
||||
return (
|
||||
@ -962,14 +772,26 @@ function HubContent() {
|
||||
addLocalDownloadingModel(
|
||||
variant.model_id
|
||||
)
|
||||
serviceHub.models().pullModelWithMetadata(
|
||||
variant.model_id,
|
||||
variant.path,
|
||||
filteredModels[
|
||||
virtualItem.index
|
||||
].mmproj_models?.[0]?.path,
|
||||
huggingfaceToken
|
||||
)
|
||||
serviceHub
|
||||
.models()
|
||||
.pullModelWithMetadata(
|
||||
variant.model_id,
|
||||
variant.path,
|
||||
|
||||
(
|
||||
filteredModels[
|
||||
virtualItem.index
|
||||
].mmproj_models?.find(
|
||||
(e) =>
|
||||
e.model_id.toLowerCase() ===
|
||||
'mmproj-f16'
|
||||
) ||
|
||||
filteredModels[
|
||||
virtualItem.index
|
||||
].mmproj_models?.[0]
|
||||
)?.path,
|
||||
huggingfaceToken
|
||||
)
|
||||
}}
|
||||
>
|
||||
<IconDownload
|
||||
|
||||
@ -22,7 +22,7 @@ Object.defineProperty(global, 'MODEL_CATALOG_URL', {
|
||||
|
||||
describe('DefaultModelsService', () => {
|
||||
let modelsService: DefaultModelsService
|
||||
|
||||
|
||||
const mockEngine = {
|
||||
list: vi.fn(),
|
||||
updateSettings: vi.fn(),
|
||||
@ -246,7 +246,9 @@ describe('DefaultModelsService', () => {
|
||||
})
|
||||
mockEngine.load.mockRejectedValue(error)
|
||||
|
||||
await expect(modelsService.startModel(provider, model)).rejects.toThrow(error)
|
||||
await expect(modelsService.startModel(provider, model)).rejects.toThrow(
|
||||
error
|
||||
)
|
||||
})
|
||||
it('should not load model again', async () => {
|
||||
const mockSettings = {
|
||||
@ -263,7 +265,9 @@ describe('DefaultModelsService', () => {
|
||||
includes: () => true,
|
||||
})
|
||||
expect(mockEngine.load).toBeCalledTimes(0)
|
||||
await expect(modelsService.startModel(provider, model)).resolves.toBe(undefined)
|
||||
await expect(modelsService.startModel(provider, model)).resolves.toBe(
|
||||
undefined
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@ -312,7 +316,9 @@ describe('DefaultModelsService', () => {
|
||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||
})
|
||||
|
||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
const result = await modelsService.fetchHuggingFaceRepo(
|
||||
'microsoft/DialoGPT-medium'
|
||||
)
|
||||
|
||||
expect(result).toEqual(mockRepoData)
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
@ -342,7 +348,9 @@ describe('DefaultModelsService', () => {
|
||||
)
|
||||
|
||||
// Test with domain prefix
|
||||
await modelsService.fetchHuggingFaceRepo('huggingface.co/microsoft/DialoGPT-medium')
|
||||
await modelsService.fetchHuggingFaceRepo(
|
||||
'huggingface.co/microsoft/DialoGPT-medium'
|
||||
)
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
|
||||
{
|
||||
@ -365,7 +373,9 @@ describe('DefaultModelsService', () => {
|
||||
expect(await modelsService.fetchHuggingFaceRepo('')).toBeNull()
|
||||
|
||||
// Test string without slash
|
||||
expect(await modelsService.fetchHuggingFaceRepo('invalid-repo')).toBeNull()
|
||||
expect(
|
||||
await modelsService.fetchHuggingFaceRepo('invalid-repo')
|
||||
).toBeNull()
|
||||
|
||||
// Test whitespace only
|
||||
expect(await modelsService.fetchHuggingFaceRepo(' ')).toBeNull()
|
||||
@ -378,7 +388,8 @@ describe('DefaultModelsService', () => {
|
||||
statusText: 'Not Found',
|
||||
})
|
||||
|
||||
const result = await modelsService.fetchHuggingFaceRepo('nonexistent/model')
|
||||
const result =
|
||||
await modelsService.fetchHuggingFaceRepo('nonexistent/model')
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
@ -398,7 +409,9 @@ describe('DefaultModelsService', () => {
|
||||
statusText: 'Internal Server Error',
|
||||
})
|
||||
|
||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
const result = await modelsService.fetchHuggingFaceRepo(
|
||||
'microsoft/DialoGPT-medium'
|
||||
)
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
@ -414,7 +427,9 @@ describe('DefaultModelsService', () => {
|
||||
|
||||
;(fetch as any).mockRejectedValue(new Error('Network error'))
|
||||
|
||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
const result = await modelsService.fetchHuggingFaceRepo(
|
||||
'microsoft/DialoGPT-medium'
|
||||
)
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
@ -448,7 +463,9 @@ describe('DefaultModelsService', () => {
|
||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||
})
|
||||
|
||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
const result = await modelsService.fetchHuggingFaceRepo(
|
||||
'microsoft/DialoGPT-medium'
|
||||
)
|
||||
|
||||
expect(result).toEqual(mockRepoData)
|
||||
})
|
||||
@ -487,7 +504,9 @@ describe('DefaultModelsService', () => {
|
||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||
})
|
||||
|
||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
const result = await modelsService.fetchHuggingFaceRepo(
|
||||
'microsoft/DialoGPT-medium'
|
||||
)
|
||||
|
||||
expect(result).toEqual(mockRepoData)
|
||||
})
|
||||
@ -531,7 +550,9 @@ describe('DefaultModelsService', () => {
|
||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||
})
|
||||
|
||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
const result = await modelsService.fetchHuggingFaceRepo(
|
||||
'microsoft/DialoGPT-medium'
|
||||
)
|
||||
|
||||
expect(result).toEqual(mockRepoData)
|
||||
// Verify the GGUF file is present in siblings
|
||||
@ -576,7 +597,8 @@ describe('DefaultModelsService', () => {
|
||||
}
|
||||
|
||||
it('should convert HuggingFace repo to catalog model format', () => {
|
||||
const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
const result =
|
||||
modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
|
||||
const expected: CatalogModel = {
|
||||
model_name: 'microsoft/DialoGPT-medium',
|
||||
@ -586,12 +608,12 @@ describe('DefaultModelsService', () => {
|
||||
num_quants: 2,
|
||||
quants: [
|
||||
{
|
||||
model_id: 'model-q4_0',
|
||||
model_id: 'microsoft/model-q4_0',
|
||||
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q4_0.gguf',
|
||||
file_size: '2.0 GB',
|
||||
},
|
||||
{
|
||||
model_id: 'model-q8_0',
|
||||
model_id: 'microsoft/model-q8_0',
|
||||
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q8_0.GGUF',
|
||||
file_size: '4.0 GB',
|
||||
},
|
||||
@ -635,7 +657,8 @@ describe('DefaultModelsService', () => {
|
||||
siblings: undefined,
|
||||
}
|
||||
|
||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithoutSiblings)
|
||||
const result =
|
||||
modelsService.convertHfRepoToCatalogModel(repoWithoutSiblings)
|
||||
|
||||
expect(result.num_quants).toBe(0)
|
||||
expect(result.quants).toEqual([])
|
||||
@ -663,7 +686,9 @@ describe('DefaultModelsService', () => {
|
||||
],
|
||||
}
|
||||
|
||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithVariousFileSizes)
|
||||
const result = modelsService.convertHfRepoToCatalogModel(
|
||||
repoWithVariousFileSizes
|
||||
)
|
||||
|
||||
expect(result.quants[0].file_size).toBe('500.0 MB')
|
||||
expect(result.quants[1].file_size).toBe('3.5 GB')
|
||||
@ -676,7 +701,8 @@ describe('DefaultModelsService', () => {
|
||||
tags: [],
|
||||
}
|
||||
|
||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithEmptyTags)
|
||||
const result =
|
||||
modelsService.convertHfRepoToCatalogModel(repoWithEmptyTags)
|
||||
|
||||
expect(result.description).toBe('**Tags**: ')
|
||||
})
|
||||
@ -687,7 +713,8 @@ describe('DefaultModelsService', () => {
|
||||
downloads: undefined as any,
|
||||
}
|
||||
|
||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithoutDownloads)
|
||||
const result =
|
||||
modelsService.convertHfRepoToCatalogModel(repoWithoutDownloads)
|
||||
|
||||
expect(result.downloads).toBe(0)
|
||||
})
|
||||
@ -714,15 +741,17 @@ describe('DefaultModelsService', () => {
|
||||
],
|
||||
}
|
||||
|
||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithVariousGGUF)
|
||||
const result =
|
||||
modelsService.convertHfRepoToCatalogModel(repoWithVariousGGUF)
|
||||
|
||||
expect(result.quants[0].model_id).toBe('model')
|
||||
expect(result.quants[1].model_id).toBe('MODEL')
|
||||
expect(result.quants[2].model_id).toBe('complex-model-name')
|
||||
expect(result.quants[0].model_id).toBe('microsoft/model')
|
||||
expect(result.quants[1].model_id).toBe('microsoft/MODEL')
|
||||
expect(result.quants[2].model_id).toBe('microsoft/complex-model-name')
|
||||
})
|
||||
|
||||
it('should generate correct download paths', () => {
|
||||
const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
const result =
|
||||
modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
|
||||
expect(result.quants[0].path).toBe(
|
||||
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q4_0.gguf'
|
||||
@ -733,7 +762,8 @@ describe('DefaultModelsService', () => {
|
||||
})
|
||||
|
||||
it('should generate correct readme URL', () => {
|
||||
const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
const result =
|
||||
modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
|
||||
expect(result.readme).toBe(
|
||||
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md'
|
||||
@ -767,13 +797,14 @@ describe('DefaultModelsService', () => {
|
||||
],
|
||||
}
|
||||
|
||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithMixedCase)
|
||||
const result =
|
||||
modelsService.convertHfRepoToCatalogModel(repoWithMixedCase)
|
||||
|
||||
expect(result.num_quants).toBe(3)
|
||||
expect(result.quants).toHaveLength(3)
|
||||
expect(result.quants[0].model_id).toBe('model-1')
|
||||
expect(result.quants[1].model_id).toBe('model-2')
|
||||
expect(result.quants[2].model_id).toBe('model-3')
|
||||
expect(result.quants[0].model_id).toBe('microsoft/model-1')
|
||||
expect(result.quants[1].model_id).toBe('microsoft/model-2')
|
||||
expect(result.quants[2].model_id).toBe('microsoft/model-3')
|
||||
})
|
||||
|
||||
it('should handle edge cases with file size formatting', () => {
|
||||
@ -798,7 +829,8 @@ describe('DefaultModelsService', () => {
|
||||
],
|
||||
}
|
||||
|
||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithEdgeCases)
|
||||
const result =
|
||||
modelsService.convertHfRepoToCatalogModel(repoWithEdgeCases)
|
||||
|
||||
expect(result.quants[0].file_size).toBe('0.0 MB')
|
||||
expect(result.quants[1].file_size).toBe('1.0 GB')
|
||||
@ -850,7 +882,10 @@ describe('DefaultModelsService', () => {
|
||||
|
||||
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
|
||||
|
||||
const result = await modelsService.isModelSupported('/path/to/model.gguf', 4096)
|
||||
const result = await modelsService.isModelSupported(
|
||||
'/path/to/model.gguf',
|
||||
4096
|
||||
)
|
||||
|
||||
expect(result).toBe('GREEN')
|
||||
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
|
||||
@ -867,7 +902,10 @@ describe('DefaultModelsService', () => {
|
||||
|
||||
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
|
||||
|
||||
const result = await modelsService.isModelSupported('/path/to/model.gguf', 8192)
|
||||
const result = await modelsService.isModelSupported(
|
||||
'/path/to/model.gguf',
|
||||
8192
|
||||
)
|
||||
|
||||
expect(result).toBe('YELLOW')
|
||||
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
|
||||
@ -884,7 +922,9 @@ describe('DefaultModelsService', () => {
|
||||
|
||||
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
|
||||
|
||||
const result = await modelsService.isModelSupported('/path/to/large-model.gguf')
|
||||
const result = await modelsService.isModelSupported(
|
||||
'/path/to/large-model.gguf'
|
||||
)
|
||||
|
||||
expect(result).toBe('RED')
|
||||
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
|
||||
|
||||
@ -30,6 +30,10 @@ export class DefaultModelsService implements ModelsService {
|
||||
return EngineManager.instance().get(provider) as AIEngine | undefined
|
||||
}
|
||||
|
||||
async getModel(modelId: string): Promise<modelInfo | undefined> {
|
||||
return this.getEngine()?.get(modelId)
|
||||
}
|
||||
|
||||
async fetchModels(): Promise<modelInfo[]> {
|
||||
return this.getEngine()?.list() ?? []
|
||||
}
|
||||
@ -127,7 +131,7 @@ export class DefaultModelsService implements ModelsService {
|
||||
const modelId = file.rfilename.replace(/\.gguf$/i, '')
|
||||
|
||||
return {
|
||||
model_id: sanitizeModelId(modelId),
|
||||
model_id: `${repo.author}/${sanitizeModelId(modelId)}`,
|
||||
path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`,
|
||||
file_size: formatFileSize(file.size),
|
||||
}
|
||||
|
||||
@ -90,6 +90,7 @@ export interface ModelPlan {
|
||||
}
|
||||
|
||||
export interface ModelsService {
|
||||
getModel(modelId: string): Promise<modelInfo | undefined>
|
||||
fetchModels(): Promise<modelInfo[]>
|
||||
fetchModelCatalog(): Promise<ModelCatalog>
|
||||
fetchHuggingFaceRepo(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user