fix: allow users to download the same model from different authors (#6577)

* fix: allow users to download the same model from different authors

* fix: imported models should have the author name in the ID

* fix: incorrect model id shown

* fix: tests

* fix: default to mmproj f16 instead of bf16

* fix: type

* fix: build error
Louis 2025-09-24 17:57:10 +07:00 committed by GitHub
parent fe05478336
commit 57110d2bd7
10 changed files with 407 additions and 311 deletions

View File

@ -240,6 +240,12 @@ export abstract class AIEngine extends BaseExtension {
EngineManager.instance().register(this)
}
/**
* Gets model info
* @param modelId The ID of the model to retrieve
*/
abstract get(modelId: string): Promise<modelInfo | undefined>
/**
* Lists available models
*/
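For context, a provider that already implements `list` can satisfy the new abstract `get` contract with a simple lookup, as the Jan web provider does below. A minimal sketch, assuming a concrete `AIEngine` subclass (other abstract members omitted):

```ts
// Minimal sketch of the new `get` contract inside a hypothetical
// AIEngine subclass; all other abstract members are omitted.
async get(modelId: string): Promise<modelInfo | undefined> {
  const models = await this.list()
  // IDs are now author-qualified (e.g. 'microsoft/model-q4_0'), so an
  // exact match distinguishes same-named models from different authors.
  return models.find((m) => m.id === modelId)
}
```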

View File

@ -22,7 +22,7 @@ export default class JanProviderWeb extends AIEngine {
override async onLoad() {
console.log('Loading Jan Provider Extension...')
try {
// Initialize authentication and fetch models
await janApiClient.initialize()
@ -37,20 +37,43 @@ export default class JanProviderWeb extends AIEngine {
override async onUnload() {
console.log('Unloading Jan Provider Extension...')
// Clear all sessions
for (const sessionId of this.activeSessions.keys()) {
await this.unload(sessionId)
}
janProviderStore.reset()
console.log('Jan Provider Extension unloaded')
}
async get(modelId: string): Promise<modelInfo | undefined> {
return janApiClient
.getModels()
.then((list) => list.find((e) => e.id === modelId))
.then((model) =>
model
? {
id: model.id,
name: model.id, // Use ID as name for now
quant_type: undefined,
providerId: this.provider,
port: 443, // HTTPS port for API
sizeBytes: 0, // Size not provided by Jan API
tags: [],
path: undefined, // Remote model, no local path
owned_by: model.owned_by,
object: model.object,
capabilities: ['tools'], // Jan models support tool calls via MCP
}
: undefined
)
}
async list(): Promise<modelInfo[]> {
try {
const janModels = await janApiClient.getModels()
return janModels.map((model) => ({
id: model.id,
name: model.id, // Use ID as name for now
@ -75,7 +98,7 @@ export default class JanProviderWeb extends AIEngine {
// For Jan API, we don't actually "load" models in the traditional sense
// We just create a session reference for tracking
const sessionId = `jan-${modelId}-${Date.now()}`
const sessionInfo: SessionInfo = {
pid: Date.now(), // Use timestamp as pseudo-PID
port: 443, // HTTPS port
@ -85,8 +108,10 @@ export default class JanProviderWeb extends AIEngine {
}
this.activeSessions.set(sessionId, sessionInfo)
console.log(`Jan model session created: ${sessionId} for model ${modelId}`)
console.log(
`Jan model session created: ${sessionId} for model ${modelId}`
)
return sessionInfo
} catch (error) {
console.error(`Failed to load Jan model ${modelId}:`, error)
@ -97,23 +122,23 @@ export default class JanProviderWeb extends AIEngine {
async unload(sessionId: string): Promise<UnloadResult> {
try {
const session = this.activeSessions.get(sessionId)
if (!session) {
return {
success: false,
error: `Session ${sessionId} not found`
error: `Session ${sessionId} not found`,
}
}
this.activeSessions.delete(sessionId)
console.log(`Jan model session unloaded: ${sessionId}`)
return { success: true }
} catch (error) {
console.error(`Failed to unload Jan session ${sessionId}:`, error)
return {
success: false,
error: error instanceof Error ? error.message : 'Unknown error'
error: error instanceof Error ? error.message : 'Unknown error',
}
}
}
@ -136,9 +161,12 @@ export default class JanProviderWeb extends AIEngine {
}
// Convert core chat completion request to Jan API format
const janMessages: JanChatMessage[] = opts.messages.map(msg => ({
const janMessages: JanChatMessage[] = opts.messages.map((msg) => ({
role: msg.role as 'system' | 'user' | 'assistant',
content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content)
content:
typeof msg.content === 'string'
? msg.content
: JSON.stringify(msg.content),
}))
const janRequest = {
@ -162,18 +190,18 @@ export default class JanProviderWeb extends AIEngine {
} else {
// Return single response
const response = await janApiClient.createChatCompletion(janRequest)
// Check if aborted after completion
if (abortController?.signal?.aborted) {
throw new Error('Request was aborted')
}
return {
id: response.id,
object: 'chat.completion' as const,
created: response.created,
model: response.model,
choices: response.choices.map(choice => ({
choices: response.choices.map((choice) => ({
index: choice.index,
message: {
role: choice.message.role,
@ -182,7 +210,12 @@ export default class JanProviderWeb extends AIEngine {
reasoning_content: choice.message.reasoning_content,
tool_calls: choice.message.tool_calls,
},
finish_reason: (choice.finish_reason || 'stop') as 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call',
finish_reason: (choice.finish_reason || 'stop') as
| 'stop'
| 'length'
| 'tool_calls'
| 'content_filter'
| 'function_call',
})),
usage: response.usage,
}
@ -193,7 +226,10 @@ export default class JanProviderWeb extends AIEngine {
}
}
private async *createStreamingGenerator(janRequest: any, abortController?: AbortController) {
private async *createStreamingGenerator(
janRequest: any,
abortController?: AbortController
) {
let resolve: () => void
let reject: (error: Error) => void
const chunks: any[] = []
@ -231,7 +267,7 @@ export default class JanProviderWeb extends AIEngine {
object: chunk.object,
created: chunk.created,
model: chunk.model,
choices: chunk.choices.map(choice => ({
choices: chunk.choices.map((choice) => ({
index: choice.index,
delta: {
role: choice.delta.role,
@ -261,14 +297,14 @@ export default class JanProviderWeb extends AIEngine {
if (abortController?.signal?.aborted) {
throw new Error('Request was aborted')
}
while (yieldedIndex < chunks.length) {
yield chunks[yieldedIndex]
yieldedIndex++
}
// Wait a bit before checking again
await new Promise(resolve => setTimeout(resolve, 10))
await new Promise((resolve) => setTimeout(resolve, 10))
}
// Yield any remaining chunks
@ -291,24 +327,32 @@ export default class JanProviderWeb extends AIEngine {
}
async delete(modelId: string): Promise<void> {
throw new Error(`Delete operation not supported for remote Jan API model: ${modelId}`)
throw new Error(
`Delete operation not supported for remote Jan API model: ${modelId}`
)
}
async import(modelId: string, _opts: ImportOptions): Promise<void> {
throw new Error(`Import operation not supported for remote Jan API model: ${modelId}`)
throw new Error(
`Import operation not supported for remote Jan API model: ${modelId}`
)
}
async abortImport(modelId: string): Promise<void> {
throw new Error(`Abort import operation not supported for remote Jan API model: ${modelId}`)
throw new Error(
`Abort import operation not supported for remote Jan API model: ${modelId}`
)
}
async getLoadedModels(): Promise<string[]> {
return Array.from(this.activeSessions.values()).map(session => session.model_id)
return Array.from(this.activeSessions.values()).map(
(session) => session.model_id
)
}
async isToolSupported(modelId: string): Promise<boolean> {
// Jan models support tool calls via MCP
console.log(`Checking tool support for Jan model ${modelId}: supported`);
return true;
console.log(`Checking tool support for Jan model ${modelId}: supported`)
return true
}
}
}

View File

@ -922,6 +922,30 @@ export default class llamacpp_extension extends AIEngine {
return hash
}
override async get(modelId: string): Promise<modelInfo | undefined> {
const modelPath = await joinPath([
await this.getProviderPath(),
'models',
modelId,
])
const path = await joinPath([modelPath, 'model.yml'])
if (!(await fs.existsSync(path))) return undefined
const modelConfig = await invoke<ModelConfig>('read_yaml', {
path,
})
return {
id: modelId,
name: modelConfig.name ?? modelId,
quant_type: undefined, // TODO: parse quantization type from model.yml or model.gguf
providerId: this.provider,
port: 0, // port is not known until the model is loaded
sizeBytes: modelConfig.size_bytes ?? 0,
} as modelInfo
}
// Implement the required LocalProvider interface methods
override async list(): Promise<modelInfo[]> {
const modelsDir = await joinPath([await this.getProviderPath(), 'models'])
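Since model IDs now carry the author prefix, entries under `models` gain one directory level. An illustrative lookup against the `get` implementation above (the model ID and `engine` instance are examples, not taken from this PR):

```ts
// Illustrative only: an author-qualified ID resolves to a nested path,
// roughly <providerPath>/models/<author>/<model>/model.yml.
const info = await engine.get('unsloth/gemma-3-4b-it-GGUF') // example ID
if (info) {
  console.log(info.name, info.sizeBytes) // populated from model.yml
}
```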
@ -1085,7 +1109,10 @@ export default class llamacpp_extension extends AIEngine {
const archiveName = await basename(path)
logger.info(`Installing backend from path: ${path}`)
if (!(await fs.existsSync(path)) || (!path.endsWith('tar.gz') && !path.endsWith('zip'))) {
if (
!(await fs.existsSync(path)) ||
(!path.endsWith('tar.gz') && !path.endsWith('zip'))
) {
logger.error(`Invalid path or file ${path}`)
throw new Error(`Invalid path or file ${path}`)
}
@ -2601,7 +2628,8 @@ export default class llamacpp_extension extends AIEngine {
metadata: Record<string, string>
): Promise<number> {
// Extract vision parameters from metadata
const projectionDim = Math.floor(Number(metadata['clip.vision.projection_dim']) / 10) || 256
const projectionDim =
Math.floor(Number(metadata['clip.vision.projection_dim']) / 10) || 256
// Count images in messages
let imageCount = 0
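As a rough worked example of the heuristic above: a `clip.vision.projection_dim` of 2560 gives `floor(2560 / 10) = 256` tokens per image, and the `|| 256` fallback covers repos whose metadata omits the key or stores a non-numeric value.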

View File

@ -0,0 +1,142 @@
import { Button } from '@/components/ui/button'
import { Progress } from '@/components/ui/progress'
import { useDownloadStore } from '@/hooks/useDownloadStore'
import { useGeneralSetting } from '@/hooks/useGeneralSetting'
import { useModelProvider } from '@/hooks/useModelProvider'
import { useServiceHub } from '@/hooks/useServiceHub'
import { useTranslation } from '@/i18n'
import { extractModelName } from '@/lib/models'
import { cn, sanitizeModelId } from '@/lib/utils'
import { CatalogModel } from '@/services/models/types'
import { useCallback, useMemo } from 'react'
import { useShallow } from 'zustand/shallow'
type ModelProps = {
model: CatalogModel
handleUseModel: (modelId: string) => void
}
// Quantizations to auto-select when available, in priority order
const defaultModelQuantizations = ['iq4_xs', 'q4_k_m']
export function DownloadButtonPlaceholder({
model,
handleUseModel,
}: ModelProps) {
const { downloads, localDownloadingModels, addLocalDownloadingModel } =
useDownloadStore(
useShallow((state) => ({
downloads: state.downloads,
localDownloadingModels: state.localDownloadingModels,
addLocalDownloadingModel: state.addLocalDownloadingModel,
}))
)
const { t } = useTranslation()
const getProviderByName = useModelProvider((state) => state.getProviderByName)
const llamaProvider = getProviderByName('llamacpp')
const serviceHub = useServiceHub()
const huggingfaceToken = useGeneralSetting((state) => state.huggingfaceToken)
const quant =
model.quants.find((e) =>
defaultModelQuantizations.some((m) =>
e.model_id.toLowerCase().includes(m)
)
) ?? model.quants[0]
const modelId = quant?.model_id || model.model_name
const downloadProcesses = useMemo(
() =>
Object.values(downloads).map((download) => ({
id: download.name,
name: download.name,
progress: download.progress,
current: download.current,
total: download.total,
})),
[downloads]
)
const isRecommendedModel = useCallback((modelId: string) => {
return (extractModelName(modelId)?.toLowerCase() ===
'jan-nano-gguf') as boolean
}, [])
if (model.quants.length === 0) {
return (
<div className="flex items-center gap-2">
<Button
size="sm"
onClick={() => {
window.open(`https://huggingface.co/${model.model_name}`, '_blank')
}}
>
View on HuggingFace
</Button>
</div>
)
}
const modelUrl = quant?.path || modelId
const isDownloading =
localDownloadingModels.has(modelId) ||
downloadProcesses.some((e) => e.id === modelId)
const downloadProgress =
downloadProcesses.find((e) => e.id === modelId)?.progress || 0
const isDownloaded = llamaProvider?.models.some(
(m: { id: string }) =>
m.id === modelId ||
m.id === `${model.developer}/${sanitizeModelId(modelId)}`
)
const isRecommended = isRecommendedModel(model.model_name)
const handleDownload = () => {
// Immediately set local downloading state
addLocalDownloadingModel(modelId)
const mmprojPath = (
model.mmproj_models?.find(
(e) => e.model_id.toLowerCase() === 'mmproj-f16'
) || model.mmproj_models?.[0]
)?.path
serviceHub
.models()
.pullModelWithMetadata(modelId, modelUrl, mmprojPath, huggingfaceToken)
}
return (
<div
className={cn(
'flex items-center',
isRecommended && 'hub-download-button-step'
)}
>
{isDownloading && !isDownloaded && (
<div className={cn('flex items-center gap-2 w-20')}>
<Progress value={downloadProgress * 100} />
<span className="text-xs text-center text-main-view-fg/70">
{Math.round(downloadProgress * 100)}%
</span>
</div>
)}
{isDownloaded ? (
<Button
size="sm"
onClick={() => handleUseModel(modelId)}
data-test-id={`hub-model-${modelId}`}
>
{t('hub:use')}
</Button>
) : (
<Button
data-test-id={`hub-model-${modelId}`}
size="sm"
onClick={handleDownload}
className={cn(isDownloading && 'hidden')}
>
{t('hub:download')}
</Button>
)}
</div>
)
}
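The f16-first mmproj selection in `handleDownload` above is repeated in the hub routes later in this diff; a hypothetical helper capturing the shared logic (not part of this change) could look like:

```ts
import type { ModelQuant } from '@/services/models/types'

// Hypothetical helper, not in this PR: prefer the f16 projector
// (matching the 'default to mmproj f16 instead of bf16' fix) and
// fall back to the first available mmproj file.
function pickMmprojPath(mmprojModels?: ModelQuant[]): string | undefined {
  const f16 = mmprojModels?.find(
    (e) => e.model_id.toLowerCase() === 'mmproj-f16'
  )
  return (f16 ?? mmprojModels?.[0])?.path
}
```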

View File

@ -89,6 +89,7 @@ const CodeComponent = memo(
onCopy,
copiedId,
...props
// eslint-disable-next-line @typescript-eslint/no-explicit-any
}: any) => {
const { t } = useTranslation()
const match = /language-(\w+)/.exec(className || '')

View File

@ -21,10 +21,7 @@ import { useEffect, useMemo, useCallback, useState } from 'react'
import { useModelProvider } from '@/hooks/useModelProvider'
import { useDownloadStore } from '@/hooks/useDownloadStore'
import { useServiceHub } from '@/hooks/useServiceHub'
import type {
CatalogModel,
ModelQuant,
} from '@/services/models/types'
import type { CatalogModel, ModelQuant } from '@/services/models/types'
import { Progress } from '@/components/ui/progress'
import { Button } from '@/components/ui/button'
import { cn } from '@/lib/utils'
@ -80,12 +77,13 @@ function HubModelDetailContent() {
}, [fetchSources])
const fetchRepo = useCallback(async () => {
const repoInfo = await serviceHub.models().fetchHuggingFaceRepo(
search.repo || modelId,
huggingfaceToken
)
const repoInfo = await serviceHub
.models()
.fetchHuggingFaceRepo(search.repo || modelId, huggingfaceToken)
if (repoInfo) {
const repoDetail = serviceHub.models().convertHfRepoToCatalogModel(repoInfo)
const repoDetail = serviceHub
.models()
.convertHfRepoToCatalogModel(repoInfo)
setRepoData(repoDetail || undefined)
}
}, [serviceHub, modelId, search, huggingfaceToken])
@ -168,7 +166,9 @@ function HubModelDetailContent() {
try {
// Use the HuggingFace path for the model
const modelPath = variant.path
const supported = await serviceHub.models().isModelSupported(modelPath, 8192)
const supported = await serviceHub
.models()
.isModelSupported(modelPath, 8192)
setModelSupportStatus((prev) => ({
...prev,
[modelKey]: supported,
@ -473,12 +473,20 @@ function HubModelDetailContent() {
addLocalDownloadingModel(
variant.model_id
)
serviceHub.models().pullModelWithMetadata(
variant.model_id,
variant.path,
modelData.mmproj_models?.[0]?.path,
huggingfaceToken
)
serviceHub
.models()
.pullModelWithMetadata(
variant.model_id,
variant.path,
(
modelData.mmproj_models?.find(
(e) =>
e.model_id.toLowerCase() ===
'mmproj-f16'
) || modelData.mmproj_models?.[0]
)?.path,
huggingfaceToken
)
}}
className={cn(isDownloading && 'hidden')}
>

View File

@ -1,6 +1,6 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { useVirtualizer } from '@tanstack/react-virtual'
import { createFileRoute, useNavigate, useSearch } from '@tanstack/react-router'
import { createFileRoute, useNavigate } from '@tanstack/react-router'
import { route } from '@/constants/routes'
import { useModelSources } from '@/hooks/useModelSources'
import { cn } from '@/lib/utils'
@ -34,8 +34,6 @@ import {
TooltipTrigger,
} from '@/components/ui/tooltip'
import { ModelInfoHoverCard } from '@/containers/ModelInfoHoverCard'
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
import {
DropdownMenu,
DropdownMenuContent,
@ -51,10 +49,9 @@ import { Loader } from 'lucide-react'
import { useTranslation } from '@/i18n/react-i18next-compat'
import Fuse from 'fuse.js'
import { useGeneralSetting } from '@/hooks/useGeneralSetting'
import { DownloadButtonPlaceholder } from '@/containers/DownloadButton'
import { useShallow } from 'zustand/shallow'
type ModelProps = {
model: CatalogModel
}
type SearchParams = {
repo: string
}
@ -77,7 +74,7 @@ function Hub() {
function HubContent() {
const parentRef = useRef(null)
const { huggingfaceToken } = useGeneralSetting()
const huggingfaceToken = useGeneralSetting((state) => state.huggingfaceToken)
const serviceHub = useServiceHub()
const { t } = useTranslation()
@ -93,7 +90,13 @@ function HubContent() {
}
}, [])
const { sources, fetchSources, loading } = useModelSources()
const { sources, fetchSources, loading } = useModelSources(
useShallow((state) => ({
sources: state.sources,
fetchSources: state.fetchSources,
loading: state.loading,
}))
)
const [searchValue, setSearchValue] = useState('')
const [sortSelected, setSortSelected] = useState('newest')
@ -108,16 +111,9 @@ function HubContent() {
const [modelSupportStatus, setModelSupportStatus] = useState<
Record<string, 'RED' | 'YELLOW' | 'GREEN' | 'LOADING'>
>({})
const [joyrideReady, setJoyrideReady] = useState(false)
const [currentStepIndex, setCurrentStepIndex] = useState(0)
const addModelSourceTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
null
)
const downloadButtonRef = useRef<HTMLButtonElement>(null)
const hasTriggeredDownload = useRef(false)
const { getProviderByName } = useModelProvider()
const llamaProvider = getProviderByName('llamacpp')
const toggleModelExpansion = (modelId: string) => {
setExpandedModels((prev) => ({
@ -168,9 +164,10 @@ function HubContent() {
?.map((model) => ({
...model,
quants: model.quants.filter((variant) =>
llamaProvider?.models.some(
(m: { id: string }) => m.id === variant.model_id
)
useModelProvider
.getState()
.getProviderByName('llamacpp')
?.models.some((m: { id: string }) => m.id === variant.model_id)
),
}))
.filter((model) => model.quants.length > 0)
@ -186,7 +183,6 @@ function HubContent() {
showOnlyDownloaded,
huggingFaceRepo,
searchOptions,
llamaProvider?.models,
])
// The virtualizer
@ -215,9 +211,13 @@ function HubContent() {
addModelSourceTimeoutRef.current = setTimeout(async () => {
try {
const repoInfo = await serviceHub.models().fetchHuggingFaceRepo(searchValue, huggingfaceToken)
const repoInfo = await serviceHub
.models()
.fetchHuggingFaceRepo(searchValue, huggingfaceToken)
if (repoInfo) {
const catalogModel = serviceHub.models().convertHfRepoToCatalogModel(repoInfo)
const catalogModel = serviceHub
.models()
.convertHfRepoToCatalogModel(repoInfo)
if (
!sources.some(
(s) =>
@ -303,7 +303,9 @@ function HubContent() {
try {
// Use the HuggingFace path for the model
const modelPath = variant.path
const supportStatus = await serviceHub.models().isModelSupported(modelPath, 8192)
const supportStatus = await serviceHub
.models()
.isModelSupported(modelPath, 8192)
setModelSupportStatus((prev) => ({
...prev,
@ -320,178 +322,7 @@ function HubContent() {
[modelSupportStatus, serviceHub]
)
const DownloadButtonPlaceholder = useMemo(() => {
return ({ model }: ModelProps) => {
// Check if this is a HuggingFace repository (no quants)
if (model.quants.length === 0) {
return (
<div className="flex items-center gap-2">
<Button
size="sm"
onClick={() => {
window.open(
`https://huggingface.co/${model.model_name}`,
'_blank'
)
}}
>
View on HuggingFace
</Button>
</div>
)
}
const quant =
model.quants.find((e) =>
defaultModelQuantizations.some((m) =>
e.model_id.toLowerCase().includes(m)
)
) ?? model.quants[0]
const modelId = quant?.model_id || model.model_name
const modelUrl = quant?.path || modelId
const isDownloading =
localDownloadingModels.has(modelId) ||
downloadProcesses.some((e) => e.id === modelId)
const downloadProgress =
downloadProcesses.find((e) => e.id === modelId)?.progress || 0
const isDownloaded = llamaProvider?.models.some(
(m: { id: string }) => m.id === modelId
)
const isRecommended = isRecommendedModel(model.model_name)
const handleDownload = () => {
// Immediately set local downloading state
addLocalDownloadingModel(modelId)
const mmprojPath = model.mmproj_models?.[0]?.path
serviceHub.models().pullModelWithMetadata(
modelId,
modelUrl,
mmprojPath,
huggingfaceToken
)
}
return (
<div
className={cn(
'flex items-center',
isRecommended && 'hub-download-button-step'
)}
>
{isDownloading && !isDownloaded && (
<div className={cn('flex items-center gap-2 w-20')}>
<Progress value={downloadProgress * 100} />
<span className="text-xs text-center text-main-view-fg/70">
{Math.round(downloadProgress * 100)}%
</span>
</div>
)}
{isDownloaded ? (
<Button
size="sm"
onClick={() => handleUseModel(modelId)}
data-test-id={`hub-model-${modelId}`}
>
{t('hub:use')}
</Button>
) : (
<Button
data-test-id={`hub-model-${modelId}`}
size="sm"
onClick={handleDownload}
className={cn(isDownloading && 'hidden')}
ref={isRecommended ? downloadButtonRef : undefined}
>
{t('hub:download')}
</Button>
)}
</div>
)
}
}, [
localDownloadingModels,
downloadProcesses,
llamaProvider?.models,
isRecommendedModel,
t,
addLocalDownloadingModel,
huggingfaceToken,
handleUseModel,
serviceHub,
])
const { step } = useSearch({ from: Route.id })
const isSetup = step === 'setup_local_provider'
// Wait for DOM to be ready before starting Joyride
useEffect(() => {
if (!loading && filteredModels.length > 0 && isSetup) {
const timer = setTimeout(() => {
setJoyrideReady(true)
}, 100)
return () => clearTimeout(timer)
} else {
setJoyrideReady(false)
}
}, [loading, filteredModels.length, isSetup])
const handleJoyrideCallback = (data: CallBackProps) => {
const { status, index } = data
if (
status === STATUS.FINISHED &&
!isDownloading &&
isLastStep &&
!hasTriggeredDownload.current
) {
const recommendedModel = filteredModels.find((model) =>
isRecommendedModel(model.model_name)
)
if (recommendedModel && recommendedModel.quants[0]?.model_id) {
if (downloadButtonRef.current) {
hasTriggeredDownload.current = true
downloadButtonRef.current.click()
}
return
}
}
if (status === STATUS.FINISHED) {
navigate({
to: route.hub.index,
})
}
// Track current step index
setCurrentStepIndex(index)
}
// Check if any model is currently downloading
const isDownloading =
localDownloadingModels.size > 0 || downloadProcesses.length > 0
const steps = [
{
target: '.hub-model-card-step',
title: t('hub:joyride.recommendedModelTitle'),
disableBeacon: true,
content: t('hub:joyride.recommendedModelContent'),
},
{
target: '.hub-download-button-step',
title: isDownloading
? t('hub:joyride.downloadInProgressTitle')
: t('hub:joyride.downloadModelTitle'),
disableBeacon: true,
content: isDownloading
? t('hub:joyride.downloadInProgressContent')
: t('hub:joyride.downloadModelContent'),
},
]
// Check if we're on the last step
const isLastStep = currentStepIndex === steps.length - 1
const renderFilter = () => {
return (
<>
@ -544,31 +375,6 @@ function HubContent() {
return (
<>
<Joyride
run={joyrideReady}
floaterProps={{
hideArrow: true,
}}
steps={steps}
tooltipComponent={CustomTooltipJoyRide}
spotlightPadding={0}
continuous={true}
showSkipButton={!isLastStep}
hideCloseButton={true}
spotlightClicks={true}
disableOverlay={IS_LINUX}
disableOverlayClose={true}
callback={handleJoyrideCallback}
locale={{
back: t('hub:joyride.back'),
close: t('hub:joyride.close'),
last: !isDownloading
? t('hub:joyride.lastWithDownload')
: t('hub:joyride.last'),
next: t('hub:joyride.next'),
skip: t('hub:joyride.skip'),
}}
/>
<div className="flex h-full w-full">
<div className="flex flex-col h-full w-full ">
<HeaderPage>
@ -698,6 +504,7 @@ function HubContent() {
/>
<DownloadButtonPlaceholder
model={filteredModels[virtualItem.index]}
handleUseModel={handleUseModel}
/>
</div>
</div>
@ -908,10 +715,13 @@ function HubContent() {
(e) => e.id === variant.model_id
)?.progress || 0
const isDownloaded =
llamaProvider?.models.some(
(m: { id: string }) =>
m.id === variant.model_id
)
useModelProvider
.getState()
.getProviderByName('llamacpp')
?.models.some(
(m: { id: string }) =>
m.id === variant.model_id
)
if (isDownloading) {
return (
@ -962,14 +772,26 @@ function HubContent() {
addLocalDownloadingModel(
variant.model_id
)
serviceHub.models().pullModelWithMetadata(
variant.model_id,
variant.path,
filteredModels[
virtualItem.index
].mmproj_models?.[0]?.path,
huggingfaceToken
)
serviceHub
.models()
.pullModelWithMetadata(
variant.model_id,
variant.path,
(
filteredModels[
virtualItem.index
].mmproj_models?.find(
(e) =>
e.model_id.toLowerCase() ===
'mmproj-f16'
) ||
filteredModels[
virtualItem.index
].mmproj_models?.[0]
)?.path,
huggingfaceToken
)
}}
>
<IconDownload

View File

@ -22,7 +22,7 @@ Object.defineProperty(global, 'MODEL_CATALOG_URL', {
describe('DefaultModelsService', () => {
let modelsService: DefaultModelsService
const mockEngine = {
list: vi.fn(),
updateSettings: vi.fn(),
@ -246,7 +246,9 @@ describe('DefaultModelsService', () => {
})
mockEngine.load.mockRejectedValue(error)
await expect(modelsService.startModel(provider, model)).rejects.toThrow(error)
await expect(modelsService.startModel(provider, model)).rejects.toThrow(
error
)
})
it('should not load model again', async () => {
const mockSettings = {
@ -263,7 +265,9 @@ describe('DefaultModelsService', () => {
includes: () => true,
})
expect(mockEngine.load).toBeCalledTimes(0)
await expect(modelsService.startModel(provider, model)).resolves.toBe(undefined)
await expect(modelsService.startModel(provider, model)).resolves.toBe(
undefined
)
})
})
@ -312,7 +316,9 @@ describe('DefaultModelsService', () => {
json: vi.fn().mockResolvedValue(mockRepoData),
})
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
const result = await modelsService.fetchHuggingFaceRepo(
'microsoft/DialoGPT-medium'
)
expect(result).toEqual(mockRepoData)
expect(fetch).toHaveBeenCalledWith(
@ -342,7 +348,9 @@ describe('DefaultModelsService', () => {
)
// Test with domain prefix
await modelsService.fetchHuggingFaceRepo('huggingface.co/microsoft/DialoGPT-medium')
await modelsService.fetchHuggingFaceRepo(
'huggingface.co/microsoft/DialoGPT-medium'
)
expect(fetch).toHaveBeenCalledWith(
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
{
@ -365,7 +373,9 @@ describe('DefaultModelsService', () => {
expect(await modelsService.fetchHuggingFaceRepo('')).toBeNull()
// Test string without slash
expect(await modelsService.fetchHuggingFaceRepo('invalid-repo')).toBeNull()
expect(
await modelsService.fetchHuggingFaceRepo('invalid-repo')
).toBeNull()
// Test whitespace only
expect(await modelsService.fetchHuggingFaceRepo(' ')).toBeNull()
@ -378,7 +388,8 @@ describe('DefaultModelsService', () => {
statusText: 'Not Found',
})
const result = await modelsService.fetchHuggingFaceRepo('nonexistent/model')
const result =
await modelsService.fetchHuggingFaceRepo('nonexistent/model')
expect(result).toBeNull()
expect(fetch).toHaveBeenCalledWith(
@ -398,7 +409,9 @@ describe('DefaultModelsService', () => {
statusText: 'Internal Server Error',
})
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
const result = await modelsService.fetchHuggingFaceRepo(
'microsoft/DialoGPT-medium'
)
expect(result).toBeNull()
expect(consoleSpy).toHaveBeenCalledWith(
@ -414,7 +427,9 @@ describe('DefaultModelsService', () => {
;(fetch as any).mockRejectedValue(new Error('Network error'))
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
const result = await modelsService.fetchHuggingFaceRepo(
'microsoft/DialoGPT-medium'
)
expect(result).toBeNull()
expect(consoleSpy).toHaveBeenCalledWith(
@ -448,7 +463,9 @@ describe('DefaultModelsService', () => {
json: vi.fn().mockResolvedValue(mockRepoData),
})
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
const result = await modelsService.fetchHuggingFaceRepo(
'microsoft/DialoGPT-medium'
)
expect(result).toEqual(mockRepoData)
})
@ -487,7 +504,9 @@ describe('DefaultModelsService', () => {
json: vi.fn().mockResolvedValue(mockRepoData),
})
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
const result = await modelsService.fetchHuggingFaceRepo(
'microsoft/DialoGPT-medium'
)
expect(result).toEqual(mockRepoData)
})
@ -531,7 +550,9 @@ describe('DefaultModelsService', () => {
json: vi.fn().mockResolvedValue(mockRepoData),
})
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
const result = await modelsService.fetchHuggingFaceRepo(
'microsoft/DialoGPT-medium'
)
expect(result).toEqual(mockRepoData)
// Verify the GGUF file is present in siblings
@ -576,7 +597,8 @@ describe('DefaultModelsService', () => {
}
it('should convert HuggingFace repo to catalog model format', () => {
const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
const result =
modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
const expected: CatalogModel = {
model_name: 'microsoft/DialoGPT-medium',
@ -586,12 +608,12 @@ describe('DefaultModelsService', () => {
num_quants: 2,
quants: [
{
model_id: 'model-q4_0',
model_id: 'microsoft/model-q4_0',
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q4_0.gguf',
file_size: '2.0 GB',
},
{
model_id: 'model-q8_0',
model_id: 'microsoft/model-q8_0',
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q8_0.GGUF',
file_size: '4.0 GB',
},
@ -635,7 +657,8 @@ describe('DefaultModelsService', () => {
siblings: undefined,
}
const result = modelsService.convertHfRepoToCatalogModel(repoWithoutSiblings)
const result =
modelsService.convertHfRepoToCatalogModel(repoWithoutSiblings)
expect(result.num_quants).toBe(0)
expect(result.quants).toEqual([])
@ -663,7 +686,9 @@ describe('DefaultModelsService', () => {
],
}
const result = modelsService.convertHfRepoToCatalogModel(repoWithVariousFileSizes)
const result = modelsService.convertHfRepoToCatalogModel(
repoWithVariousFileSizes
)
expect(result.quants[0].file_size).toBe('500.0 MB')
expect(result.quants[1].file_size).toBe('3.5 GB')
@ -676,7 +701,8 @@ describe('DefaultModelsService', () => {
tags: [],
}
const result = modelsService.convertHfRepoToCatalogModel(repoWithEmptyTags)
const result =
modelsService.convertHfRepoToCatalogModel(repoWithEmptyTags)
expect(result.description).toBe('**Tags**: ')
})
@ -687,7 +713,8 @@ describe('DefaultModelsService', () => {
downloads: undefined as any,
}
const result = modelsService.convertHfRepoToCatalogModel(repoWithoutDownloads)
const result =
modelsService.convertHfRepoToCatalogModel(repoWithoutDownloads)
expect(result.downloads).toBe(0)
})
@ -714,15 +741,17 @@ describe('DefaultModelsService', () => {
],
}
const result = modelsService.convertHfRepoToCatalogModel(repoWithVariousGGUF)
const result =
modelsService.convertHfRepoToCatalogModel(repoWithVariousGGUF)
expect(result.quants[0].model_id).toBe('model')
expect(result.quants[1].model_id).toBe('MODEL')
expect(result.quants[2].model_id).toBe('complex-model-name')
expect(result.quants[0].model_id).toBe('microsoft/model')
expect(result.quants[1].model_id).toBe('microsoft/MODEL')
expect(result.quants[2].model_id).toBe('microsoft/complex-model-name')
})
it('should generate correct download paths', () => {
const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
const result =
modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
expect(result.quants[0].path).toBe(
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q4_0.gguf'
@ -733,7 +762,8 @@ describe('DefaultModelsService', () => {
})
it('should generate correct readme URL', () => {
const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
const result =
modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
expect(result.readme).toBe(
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md'
@ -767,13 +797,14 @@ describe('DefaultModelsService', () => {
],
}
const result = modelsService.convertHfRepoToCatalogModel(repoWithMixedCase)
const result =
modelsService.convertHfRepoToCatalogModel(repoWithMixedCase)
expect(result.num_quants).toBe(3)
expect(result.quants).toHaveLength(3)
expect(result.quants[0].model_id).toBe('model-1')
expect(result.quants[1].model_id).toBe('model-2')
expect(result.quants[2].model_id).toBe('model-3')
expect(result.quants[0].model_id).toBe('microsoft/model-1')
expect(result.quants[1].model_id).toBe('microsoft/model-2')
expect(result.quants[2].model_id).toBe('microsoft/model-3')
})
it('should handle edge cases with file size formatting', () => {
@ -798,7 +829,8 @@ describe('DefaultModelsService', () => {
],
}
const result = modelsService.convertHfRepoToCatalogModel(repoWithEdgeCases)
const result =
modelsService.convertHfRepoToCatalogModel(repoWithEdgeCases)
expect(result.quants[0].file_size).toBe('0.0 MB')
expect(result.quants[1].file_size).toBe('1.0 GB')
@ -850,7 +882,10 @@ describe('DefaultModelsService', () => {
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
const result = await modelsService.isModelSupported('/path/to/model.gguf', 4096)
const result = await modelsService.isModelSupported(
'/path/to/model.gguf',
4096
)
expect(result).toBe('GREEN')
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
@ -867,7 +902,10 @@ describe('DefaultModelsService', () => {
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
const result = await modelsService.isModelSupported('/path/to/model.gguf', 8192)
const result = await modelsService.isModelSupported(
'/path/to/model.gguf',
8192
)
expect(result).toBe('YELLOW')
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
@ -884,7 +922,9 @@ describe('DefaultModelsService', () => {
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
const result = await modelsService.isModelSupported('/path/to/large-model.gguf')
const result = await modelsService.isModelSupported(
'/path/to/large-model.gguf'
)
expect(result).toBe('RED')
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(

View File

@ -30,6 +30,10 @@ export class DefaultModelsService implements ModelsService {
return EngineManager.instance().get(provider) as AIEngine | undefined
}
async getModel(modelId: string): Promise<modelInfo | undefined> {
return this.getEngine()?.get(modelId)
}
async fetchModels(): Promise<modelInfo[]> {
return this.getEngine()?.list() ?? []
}
@ -127,7 +131,7 @@ export class DefaultModelsService implements ModelsService {
const modelId = file.rfilename.replace(/\.gguf$/i, '')
return {
model_id: sanitizeModelId(modelId),
model_id: `${repo.author}/${sanitizeModelId(modelId)}`,
path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`,
file_size: formatFileSize(file.size),
}
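This is the heart of the fix: quant IDs are namespaced by `repo.author`, so identically named GGUF files from different authors no longer collide. An illustrative before/after (repo and file names are examples):

```ts
// Illustrative: the same filename from two authors now yields distinct IDs.
const a = `bartowski/${sanitizeModelId('llama-3-q4_k_m')}` // bartowski's repo
const b = `unsloth/${sanitizeModelId('llama-3-q4_k_m')}` // unsloth's repo
console.assert(a !== b) // previously both were 'llama-3-q4_k_m' and collided
```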

View File

@ -90,6 +90,7 @@ export interface ModelPlan {
}
export interface ModelsService {
getModel(modelId: string): Promise<modelInfo | undefined>
fetchModels(): Promise<modelInfo[]>
fetchModelCatalog(): Promise<ModelCatalog>
fetchHuggingFaceRepo(