fix: allow users to download the same model from different authors (#6577)
* fix: allow users to download the same model from different authors * fix: importing models should have author name in the ID * fix: incorrect model id show * fix: tests * fix: default to mmproj f16 instead of bf16 * fix: type * fix: build error
This commit is contained in:
parent
fe05478336
commit
57110d2bd7
@ -240,6 +240,12 @@ export abstract class AIEngine extends BaseExtension {
|
|||||||
EngineManager.instance().register(this)
|
EngineManager.instance().register(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets model info
|
||||||
|
* @param modelId
|
||||||
|
*/
|
||||||
|
abstract get(modelId: string): Promise<modelInfo | undefined>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Lists available models
|
* Lists available models
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -47,6 +47,29 @@ export default class JanProviderWeb extends AIEngine {
|
|||||||
console.log('Jan Provider Extension unloaded')
|
console.log('Jan Provider Extension unloaded')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async get(modelId: string): Promise<modelInfo | undefined> {
|
||||||
|
return janApiClient
|
||||||
|
.getModels()
|
||||||
|
.then((list) => list.find((e) => e.id === modelId))
|
||||||
|
.then((model) =>
|
||||||
|
model
|
||||||
|
? {
|
||||||
|
id: model.id,
|
||||||
|
name: model.id, // Use ID as name for now
|
||||||
|
quant_type: undefined,
|
||||||
|
providerId: this.provider,
|
||||||
|
port: 443, // HTTPS port for API
|
||||||
|
sizeBytes: 0, // Size not provided by Jan API
|
||||||
|
tags: [],
|
||||||
|
path: undefined, // Remote model, no local path
|
||||||
|
owned_by: model.owned_by,
|
||||||
|
object: model.object,
|
||||||
|
capabilities: ['tools'], // Jan models support both tools via MCP
|
||||||
|
}
|
||||||
|
: undefined
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
async list(): Promise<modelInfo[]> {
|
async list(): Promise<modelInfo[]> {
|
||||||
try {
|
try {
|
||||||
const janModels = await janApiClient.getModels()
|
const janModels = await janApiClient.getModels()
|
||||||
@ -86,7 +109,9 @@ export default class JanProviderWeb extends AIEngine {
|
|||||||
|
|
||||||
this.activeSessions.set(sessionId, sessionInfo)
|
this.activeSessions.set(sessionId, sessionInfo)
|
||||||
|
|
||||||
console.log(`Jan model session created: ${sessionId} for model ${modelId}`)
|
console.log(
|
||||||
|
`Jan model session created: ${sessionId} for model ${modelId}`
|
||||||
|
)
|
||||||
return sessionInfo
|
return sessionInfo
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Failed to load Jan model ${modelId}:`, error)
|
console.error(`Failed to load Jan model ${modelId}:`, error)
|
||||||
@ -101,7 +126,7 @@ export default class JanProviderWeb extends AIEngine {
|
|||||||
if (!session) {
|
if (!session) {
|
||||||
return {
|
return {
|
||||||
success: false,
|
success: false,
|
||||||
error: `Session ${sessionId} not found`
|
error: `Session ${sessionId} not found`,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,7 +138,7 @@ export default class JanProviderWeb extends AIEngine {
|
|||||||
console.error(`Failed to unload Jan session ${sessionId}:`, error)
|
console.error(`Failed to unload Jan session ${sessionId}:`, error)
|
||||||
return {
|
return {
|
||||||
success: false,
|
success: false,
|
||||||
error: error instanceof Error ? error.message : 'Unknown error'
|
error: error instanceof Error ? error.message : 'Unknown error',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -136,9 +161,12 @@ export default class JanProviderWeb extends AIEngine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Convert core chat completion request to Jan API format
|
// Convert core chat completion request to Jan API format
|
||||||
const janMessages: JanChatMessage[] = opts.messages.map(msg => ({
|
const janMessages: JanChatMessage[] = opts.messages.map((msg) => ({
|
||||||
role: msg.role as 'system' | 'user' | 'assistant',
|
role: msg.role as 'system' | 'user' | 'assistant',
|
||||||
content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content)
|
content:
|
||||||
|
typeof msg.content === 'string'
|
||||||
|
? msg.content
|
||||||
|
: JSON.stringify(msg.content),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
const janRequest = {
|
const janRequest = {
|
||||||
@ -173,7 +201,7 @@ export default class JanProviderWeb extends AIEngine {
|
|||||||
object: 'chat.completion' as const,
|
object: 'chat.completion' as const,
|
||||||
created: response.created,
|
created: response.created,
|
||||||
model: response.model,
|
model: response.model,
|
||||||
choices: response.choices.map(choice => ({
|
choices: response.choices.map((choice) => ({
|
||||||
index: choice.index,
|
index: choice.index,
|
||||||
message: {
|
message: {
|
||||||
role: choice.message.role,
|
role: choice.message.role,
|
||||||
@ -182,7 +210,12 @@ export default class JanProviderWeb extends AIEngine {
|
|||||||
reasoning_content: choice.message.reasoning_content,
|
reasoning_content: choice.message.reasoning_content,
|
||||||
tool_calls: choice.message.tool_calls,
|
tool_calls: choice.message.tool_calls,
|
||||||
},
|
},
|
||||||
finish_reason: (choice.finish_reason || 'stop') as 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call',
|
finish_reason: (choice.finish_reason || 'stop') as
|
||||||
|
| 'stop'
|
||||||
|
| 'length'
|
||||||
|
| 'tool_calls'
|
||||||
|
| 'content_filter'
|
||||||
|
| 'function_call',
|
||||||
})),
|
})),
|
||||||
usage: response.usage,
|
usage: response.usage,
|
||||||
}
|
}
|
||||||
@ -193,7 +226,10 @@ export default class JanProviderWeb extends AIEngine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async *createStreamingGenerator(janRequest: any, abortController?: AbortController) {
|
private async *createStreamingGenerator(
|
||||||
|
janRequest: any,
|
||||||
|
abortController?: AbortController
|
||||||
|
) {
|
||||||
let resolve: () => void
|
let resolve: () => void
|
||||||
let reject: (error: Error) => void
|
let reject: (error: Error) => void
|
||||||
const chunks: any[] = []
|
const chunks: any[] = []
|
||||||
@ -231,7 +267,7 @@ export default class JanProviderWeb extends AIEngine {
|
|||||||
object: chunk.object,
|
object: chunk.object,
|
||||||
created: chunk.created,
|
created: chunk.created,
|
||||||
model: chunk.model,
|
model: chunk.model,
|
||||||
choices: chunk.choices.map(choice => ({
|
choices: chunk.choices.map((choice) => ({
|
||||||
index: choice.index,
|
index: choice.index,
|
||||||
delta: {
|
delta: {
|
||||||
role: choice.delta.role,
|
role: choice.delta.role,
|
||||||
@ -268,7 +304,7 @@ export default class JanProviderWeb extends AIEngine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Wait a bit before checking again
|
// Wait a bit before checking again
|
||||||
await new Promise(resolve => setTimeout(resolve, 10))
|
await new Promise((resolve) => setTimeout(resolve, 10))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Yield any remaining chunks
|
// Yield any remaining chunks
|
||||||
@ -291,24 +327,32 @@ export default class JanProviderWeb extends AIEngine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async delete(modelId: string): Promise<void> {
|
async delete(modelId: string): Promise<void> {
|
||||||
throw new Error(`Delete operation not supported for remote Jan API model: ${modelId}`)
|
throw new Error(
|
||||||
|
`Delete operation not supported for remote Jan API model: ${modelId}`
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async import(modelId: string, _opts: ImportOptions): Promise<void> {
|
async import(modelId: string, _opts: ImportOptions): Promise<void> {
|
||||||
throw new Error(`Import operation not supported for remote Jan API model: ${modelId}`)
|
throw new Error(
|
||||||
|
`Import operation not supported for remote Jan API model: ${modelId}`
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async abortImport(modelId: string): Promise<void> {
|
async abortImport(modelId: string): Promise<void> {
|
||||||
throw new Error(`Abort import operation not supported for remote Jan API model: ${modelId}`)
|
throw new Error(
|
||||||
|
`Abort import operation not supported for remote Jan API model: ${modelId}`
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async getLoadedModels(): Promise<string[]> {
|
async getLoadedModels(): Promise<string[]> {
|
||||||
return Array.from(this.activeSessions.values()).map(session => session.model_id)
|
return Array.from(this.activeSessions.values()).map(
|
||||||
|
(session) => session.model_id
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async isToolSupported(modelId: string): Promise<boolean> {
|
async isToolSupported(modelId: string): Promise<boolean> {
|
||||||
// Jan models support tool calls via MCP
|
// Jan models support tool calls via MCP
|
||||||
console.log(`Checking tool support for Jan model ${modelId}: supported`);
|
console.log(`Checking tool support for Jan model ${modelId}: supported`)
|
||||||
return true;
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -922,6 +922,30 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
return hash
|
return hash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override async get(modelId: string): Promise<modelInfo | undefined> {
|
||||||
|
const modelPath = await joinPath([
|
||||||
|
await this.getProviderPath(),
|
||||||
|
'models',
|
||||||
|
modelId,
|
||||||
|
])
|
||||||
|
const path = await joinPath([modelPath, 'model.yml'])
|
||||||
|
|
||||||
|
if (!(await fs.existsSync(path))) return undefined
|
||||||
|
|
||||||
|
const modelConfig = await invoke<ModelConfig>('read_yaml', {
|
||||||
|
path,
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
id: modelId,
|
||||||
|
name: modelConfig.name ?? modelId,
|
||||||
|
quant_type: undefined, // TODO: parse quantization type from model.yml or model.gguf
|
||||||
|
providerId: this.provider,
|
||||||
|
port: 0, // port is not known until the model is loaded
|
||||||
|
sizeBytes: modelConfig.size_bytes ?? 0,
|
||||||
|
} as modelInfo
|
||||||
|
}
|
||||||
|
|
||||||
// Implement the required LocalProvider interface methods
|
// Implement the required LocalProvider interface methods
|
||||||
override async list(): Promise<modelInfo[]> {
|
override async list(): Promise<modelInfo[]> {
|
||||||
const modelsDir = await joinPath([await this.getProviderPath(), 'models'])
|
const modelsDir = await joinPath([await this.getProviderPath(), 'models'])
|
||||||
@ -1085,7 +1109,10 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
const archiveName = await basename(path)
|
const archiveName = await basename(path)
|
||||||
logger.info(`Installing backend from path: ${path}`)
|
logger.info(`Installing backend from path: ${path}`)
|
||||||
|
|
||||||
if (!(await fs.existsSync(path)) || (!path.endsWith('tar.gz') && !path.endsWith('zip'))) {
|
if (
|
||||||
|
!(await fs.existsSync(path)) ||
|
||||||
|
(!path.endsWith('tar.gz') && !path.endsWith('zip'))
|
||||||
|
) {
|
||||||
logger.error(`Invalid path or file ${path}`)
|
logger.error(`Invalid path or file ${path}`)
|
||||||
throw new Error(`Invalid path or file ${path}`)
|
throw new Error(`Invalid path or file ${path}`)
|
||||||
}
|
}
|
||||||
@ -2601,7 +2628,8 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
metadata: Record<string, string>
|
metadata: Record<string, string>
|
||||||
): Promise<number> {
|
): Promise<number> {
|
||||||
// Extract vision parameters from metadata
|
// Extract vision parameters from metadata
|
||||||
const projectionDim = Math.floor(Number(metadata['clip.vision.projection_dim']) / 10) || 256
|
const projectionDim =
|
||||||
|
Math.floor(Number(metadata['clip.vision.projection_dim']) / 10) || 256
|
||||||
|
|
||||||
// Count images in messages
|
// Count images in messages
|
||||||
let imageCount = 0
|
let imageCount = 0
|
||||||
|
|||||||
142
web-app/src/containers/DownloadButton.tsx
Normal file
142
web-app/src/containers/DownloadButton.tsx
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
import { Button } from '@/components/ui/button'
|
||||||
|
import { Progress } from '@/components/ui/progress'
|
||||||
|
import { useDownloadStore } from '@/hooks/useDownloadStore'
|
||||||
|
import { useGeneralSetting } from '@/hooks/useGeneralSetting'
|
||||||
|
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||||
|
import { useServiceHub } from '@/hooks/useServiceHub'
|
||||||
|
import { useTranslation } from '@/i18n'
|
||||||
|
import { extractModelName } from '@/lib/models'
|
||||||
|
import { cn, sanitizeModelId } from '@/lib/utils'
|
||||||
|
import { CatalogModel } from '@/services/models/types'
|
||||||
|
import { useCallback, useMemo } from 'react'
|
||||||
|
import { useShallow } from 'zustand/shallow'
|
||||||
|
|
||||||
|
type ModelProps = {
|
||||||
|
model: CatalogModel
|
||||||
|
handleUseModel: (modelId: string) => void
|
||||||
|
}
|
||||||
|
const defaultModelQuantizations = ['iq4_xs', 'q4_k_m']
|
||||||
|
|
||||||
|
export function DownloadButtonPlaceholder({
|
||||||
|
model,
|
||||||
|
handleUseModel,
|
||||||
|
}: ModelProps) {
|
||||||
|
const { downloads, localDownloadingModels, addLocalDownloadingModel } =
|
||||||
|
useDownloadStore(
|
||||||
|
useShallow((state) => ({
|
||||||
|
downloads: state.downloads,
|
||||||
|
localDownloadingModels: state.localDownloadingModels,
|
||||||
|
addLocalDownloadingModel: state.addLocalDownloadingModel,
|
||||||
|
}))
|
||||||
|
)
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const getProviderByName = useModelProvider((state) => state.getProviderByName)
|
||||||
|
const llamaProvider = getProviderByName('llamacpp')
|
||||||
|
|
||||||
|
const serviceHub = useServiceHub()
|
||||||
|
const huggingfaceToken = useGeneralSetting((state) => state.huggingfaceToken)
|
||||||
|
|
||||||
|
const quant =
|
||||||
|
model.quants.find((e) =>
|
||||||
|
defaultModelQuantizations.some((m) =>
|
||||||
|
e.model_id.toLowerCase().includes(m)
|
||||||
|
)
|
||||||
|
) ?? model.quants[0]
|
||||||
|
|
||||||
|
const modelId = quant?.model_id || model.model_name
|
||||||
|
|
||||||
|
const downloadProcesses = useMemo(
|
||||||
|
() =>
|
||||||
|
Object.values(downloads).map((download) => ({
|
||||||
|
id: download.name,
|
||||||
|
name: download.name,
|
||||||
|
progress: download.progress,
|
||||||
|
current: download.current,
|
||||||
|
total: download.total,
|
||||||
|
})),
|
||||||
|
[downloads]
|
||||||
|
)
|
||||||
|
|
||||||
|
const isRecommendedModel = useCallback((modelId: string) => {
|
||||||
|
return (extractModelName(modelId)?.toLowerCase() ===
|
||||||
|
'jan-nano-gguf') as boolean
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
if (model.quants.length === 0) {
|
||||||
|
return (
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
onClick={() => {
|
||||||
|
window.open(`https://huggingface.co/${model.model_name}`, '_blank')
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
View on HuggingFace
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const modelUrl = quant?.path || modelId
|
||||||
|
const isDownloading =
|
||||||
|
localDownloadingModels.has(modelId) ||
|
||||||
|
downloadProcesses.some((e) => e.id === modelId)
|
||||||
|
|
||||||
|
const downloadProgress =
|
||||||
|
downloadProcesses.find((e) => e.id === modelId)?.progress || 0
|
||||||
|
const isDownloaded = llamaProvider?.models.some(
|
||||||
|
(m: { id: string }) =>
|
||||||
|
m.id === modelId ||
|
||||||
|
m.id === `${model.developer}/${sanitizeModelId(modelId)}`
|
||||||
|
)
|
||||||
|
const isRecommended = isRecommendedModel(model.model_name)
|
||||||
|
|
||||||
|
const handleDownload = () => {
|
||||||
|
// Immediately set local downloading state
|
||||||
|
addLocalDownloadingModel(modelId)
|
||||||
|
const mmprojPath = (
|
||||||
|
model.mmproj_models?.find(
|
||||||
|
(e) => e.model_id.toLowerCase() === 'mmproj-f16'
|
||||||
|
) || model.mmproj_models?.[0]
|
||||||
|
)?.path
|
||||||
|
serviceHub
|
||||||
|
.models()
|
||||||
|
.pullModelWithMetadata(modelId, modelUrl, mmprojPath, huggingfaceToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
'flex items-center',
|
||||||
|
isRecommended && 'hub-download-button-step'
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{isDownloading && !isDownloaded && (
|
||||||
|
<div className={cn('flex items-center gap-2 w-20')}>
|
||||||
|
<Progress value={downloadProgress * 100} />
|
||||||
|
<span className="text-xs text-center text-main-view-fg/70">
|
||||||
|
{Math.round(downloadProgress * 100)}%
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{isDownloaded ? (
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
onClick={() => handleUseModel(modelId)}
|
||||||
|
data-test-id={`hub-model-${modelId}`}
|
||||||
|
>
|
||||||
|
{t('hub:use')}
|
||||||
|
</Button>
|
||||||
|
) : (
|
||||||
|
<Button
|
||||||
|
data-test-id={`hub-model-${modelId}`}
|
||||||
|
size="sm"
|
||||||
|
onClick={handleDownload}
|
||||||
|
className={cn(isDownloading && 'hidden')}
|
||||||
|
>
|
||||||
|
{t('hub:download')}
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
@ -89,6 +89,7 @@ const CodeComponent = memo(
|
|||||||
onCopy,
|
onCopy,
|
||||||
copiedId,
|
copiedId,
|
||||||
...props
|
...props
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
}: any) => {
|
}: any) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const match = /language-(\w+)/.exec(className || '')
|
const match = /language-(\w+)/.exec(className || '')
|
||||||
|
|||||||
@ -21,10 +21,7 @@ import { useEffect, useMemo, useCallback, useState } from 'react'
|
|||||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||||
import { useDownloadStore } from '@/hooks/useDownloadStore'
|
import { useDownloadStore } from '@/hooks/useDownloadStore'
|
||||||
import { useServiceHub } from '@/hooks/useServiceHub'
|
import { useServiceHub } from '@/hooks/useServiceHub'
|
||||||
import type {
|
import type { CatalogModel, ModelQuant } from '@/services/models/types'
|
||||||
CatalogModel,
|
|
||||||
ModelQuant,
|
|
||||||
} from '@/services/models/types'
|
|
||||||
import { Progress } from '@/components/ui/progress'
|
import { Progress } from '@/components/ui/progress'
|
||||||
import { Button } from '@/components/ui/button'
|
import { Button } from '@/components/ui/button'
|
||||||
import { cn } from '@/lib/utils'
|
import { cn } from '@/lib/utils'
|
||||||
@ -80,12 +77,13 @@ function HubModelDetailContent() {
|
|||||||
}, [fetchSources])
|
}, [fetchSources])
|
||||||
|
|
||||||
const fetchRepo = useCallback(async () => {
|
const fetchRepo = useCallback(async () => {
|
||||||
const repoInfo = await serviceHub.models().fetchHuggingFaceRepo(
|
const repoInfo = await serviceHub
|
||||||
search.repo || modelId,
|
.models()
|
||||||
huggingfaceToken
|
.fetchHuggingFaceRepo(search.repo || modelId, huggingfaceToken)
|
||||||
)
|
|
||||||
if (repoInfo) {
|
if (repoInfo) {
|
||||||
const repoDetail = serviceHub.models().convertHfRepoToCatalogModel(repoInfo)
|
const repoDetail = serviceHub
|
||||||
|
.models()
|
||||||
|
.convertHfRepoToCatalogModel(repoInfo)
|
||||||
setRepoData(repoDetail || undefined)
|
setRepoData(repoDetail || undefined)
|
||||||
}
|
}
|
||||||
}, [serviceHub, modelId, search, huggingfaceToken])
|
}, [serviceHub, modelId, search, huggingfaceToken])
|
||||||
@ -168,7 +166,9 @@ function HubModelDetailContent() {
|
|||||||
try {
|
try {
|
||||||
// Use the HuggingFace path for the model
|
// Use the HuggingFace path for the model
|
||||||
const modelPath = variant.path
|
const modelPath = variant.path
|
||||||
const supported = await serviceHub.models().isModelSupported(modelPath, 8192)
|
const supported = await serviceHub
|
||||||
|
.models()
|
||||||
|
.isModelSupported(modelPath, 8192)
|
||||||
setModelSupportStatus((prev) => ({
|
setModelSupportStatus((prev) => ({
|
||||||
...prev,
|
...prev,
|
||||||
[modelKey]: supported,
|
[modelKey]: supported,
|
||||||
@ -473,12 +473,20 @@ function HubModelDetailContent() {
|
|||||||
addLocalDownloadingModel(
|
addLocalDownloadingModel(
|
||||||
variant.model_id
|
variant.model_id
|
||||||
)
|
)
|
||||||
serviceHub.models().pullModelWithMetadata(
|
serviceHub
|
||||||
variant.model_id,
|
.models()
|
||||||
variant.path,
|
.pullModelWithMetadata(
|
||||||
modelData.mmproj_models?.[0]?.path,
|
variant.model_id,
|
||||||
huggingfaceToken
|
variant.path,
|
||||||
)
|
(
|
||||||
|
modelData.mmproj_models?.find(
|
||||||
|
(e) =>
|
||||||
|
e.model_id.toLowerCase() ===
|
||||||
|
'mmproj-f16'
|
||||||
|
) || modelData.mmproj_models?.[0]
|
||||||
|
)?.path,
|
||||||
|
huggingfaceToken
|
||||||
|
)
|
||||||
}}
|
}}
|
||||||
className={cn(isDownloading && 'hidden')}
|
className={cn(isDownloading && 'hidden')}
|
||||||
>
|
>
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||||
import { useVirtualizer } from '@tanstack/react-virtual'
|
import { useVirtualizer } from '@tanstack/react-virtual'
|
||||||
import { createFileRoute, useNavigate, useSearch } from '@tanstack/react-router'
|
import { createFileRoute, useNavigate } from '@tanstack/react-router'
|
||||||
import { route } from '@/constants/routes'
|
import { route } from '@/constants/routes'
|
||||||
import { useModelSources } from '@/hooks/useModelSources'
|
import { useModelSources } from '@/hooks/useModelSources'
|
||||||
import { cn } from '@/lib/utils'
|
import { cn } from '@/lib/utils'
|
||||||
@ -34,8 +34,6 @@ import {
|
|||||||
TooltipTrigger,
|
TooltipTrigger,
|
||||||
} from '@/components/ui/tooltip'
|
} from '@/components/ui/tooltip'
|
||||||
import { ModelInfoHoverCard } from '@/containers/ModelInfoHoverCard'
|
import { ModelInfoHoverCard } from '@/containers/ModelInfoHoverCard'
|
||||||
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
|
|
||||||
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
|
||||||
import {
|
import {
|
||||||
DropdownMenu,
|
DropdownMenu,
|
||||||
DropdownMenuContent,
|
DropdownMenuContent,
|
||||||
@ -51,10 +49,9 @@ import { Loader } from 'lucide-react'
|
|||||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||||
import Fuse from 'fuse.js'
|
import Fuse from 'fuse.js'
|
||||||
import { useGeneralSetting } from '@/hooks/useGeneralSetting'
|
import { useGeneralSetting } from '@/hooks/useGeneralSetting'
|
||||||
|
import { DownloadButtonPlaceholder } from '@/containers/DownloadButton'
|
||||||
|
import { useShallow } from 'zustand/shallow'
|
||||||
|
|
||||||
type ModelProps = {
|
|
||||||
model: CatalogModel
|
|
||||||
}
|
|
||||||
type SearchParams = {
|
type SearchParams = {
|
||||||
repo: string
|
repo: string
|
||||||
}
|
}
|
||||||
@ -77,7 +74,7 @@ function Hub() {
|
|||||||
|
|
||||||
function HubContent() {
|
function HubContent() {
|
||||||
const parentRef = useRef(null)
|
const parentRef = useRef(null)
|
||||||
const { huggingfaceToken } = useGeneralSetting()
|
const huggingfaceToken = useGeneralSetting((state) => state.huggingfaceToken)
|
||||||
const serviceHub = useServiceHub()
|
const serviceHub = useServiceHub()
|
||||||
|
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
@ -93,7 +90,13 @@ function HubContent() {
|
|||||||
}
|
}
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
const { sources, fetchSources, loading } = useModelSources()
|
const { sources, fetchSources, loading } = useModelSources(
|
||||||
|
useShallow((state) => ({
|
||||||
|
sources: state.sources,
|
||||||
|
fetchSources: state.fetchSources,
|
||||||
|
loading: state.loading,
|
||||||
|
}))
|
||||||
|
)
|
||||||
|
|
||||||
const [searchValue, setSearchValue] = useState('')
|
const [searchValue, setSearchValue] = useState('')
|
||||||
const [sortSelected, setSortSelected] = useState('newest')
|
const [sortSelected, setSortSelected] = useState('newest')
|
||||||
@ -108,16 +111,9 @@ function HubContent() {
|
|||||||
const [modelSupportStatus, setModelSupportStatus] = useState<
|
const [modelSupportStatus, setModelSupportStatus] = useState<
|
||||||
Record<string, 'RED' | 'YELLOW' | 'GREEN' | 'LOADING'>
|
Record<string, 'RED' | 'YELLOW' | 'GREEN' | 'LOADING'>
|
||||||
>({})
|
>({})
|
||||||
const [joyrideReady, setJoyrideReady] = useState(false)
|
|
||||||
const [currentStepIndex, setCurrentStepIndex] = useState(0)
|
|
||||||
const addModelSourceTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
|
const addModelSourceTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
|
||||||
null
|
null
|
||||||
)
|
)
|
||||||
const downloadButtonRef = useRef<HTMLButtonElement>(null)
|
|
||||||
const hasTriggeredDownload = useRef(false)
|
|
||||||
|
|
||||||
const { getProviderByName } = useModelProvider()
|
|
||||||
const llamaProvider = getProviderByName('llamacpp')
|
|
||||||
|
|
||||||
const toggleModelExpansion = (modelId: string) => {
|
const toggleModelExpansion = (modelId: string) => {
|
||||||
setExpandedModels((prev) => ({
|
setExpandedModels((prev) => ({
|
||||||
@ -168,9 +164,10 @@ function HubContent() {
|
|||||||
?.map((model) => ({
|
?.map((model) => ({
|
||||||
...model,
|
...model,
|
||||||
quants: model.quants.filter((variant) =>
|
quants: model.quants.filter((variant) =>
|
||||||
llamaProvider?.models.some(
|
useModelProvider
|
||||||
(m: { id: string }) => m.id === variant.model_id
|
.getState()
|
||||||
)
|
.getProviderByName('llamacpp')
|
||||||
|
?.models.some((m: { id: string }) => m.id === variant.model_id)
|
||||||
),
|
),
|
||||||
}))
|
}))
|
||||||
.filter((model) => model.quants.length > 0)
|
.filter((model) => model.quants.length > 0)
|
||||||
@ -186,7 +183,6 @@ function HubContent() {
|
|||||||
showOnlyDownloaded,
|
showOnlyDownloaded,
|
||||||
huggingFaceRepo,
|
huggingFaceRepo,
|
||||||
searchOptions,
|
searchOptions,
|
||||||
llamaProvider?.models,
|
|
||||||
])
|
])
|
||||||
|
|
||||||
// The virtualizer
|
// The virtualizer
|
||||||
@ -215,9 +211,13 @@ function HubContent() {
|
|||||||
|
|
||||||
addModelSourceTimeoutRef.current = setTimeout(async () => {
|
addModelSourceTimeoutRef.current = setTimeout(async () => {
|
||||||
try {
|
try {
|
||||||
const repoInfo = await serviceHub.models().fetchHuggingFaceRepo(searchValue, huggingfaceToken)
|
const repoInfo = await serviceHub
|
||||||
|
.models()
|
||||||
|
.fetchHuggingFaceRepo(searchValue, huggingfaceToken)
|
||||||
if (repoInfo) {
|
if (repoInfo) {
|
||||||
const catalogModel = serviceHub.models().convertHfRepoToCatalogModel(repoInfo)
|
const catalogModel = serviceHub
|
||||||
|
.models()
|
||||||
|
.convertHfRepoToCatalogModel(repoInfo)
|
||||||
if (
|
if (
|
||||||
!sources.some(
|
!sources.some(
|
||||||
(s) =>
|
(s) =>
|
||||||
@ -303,7 +303,9 @@ function HubContent() {
|
|||||||
try {
|
try {
|
||||||
// Use the HuggingFace path for the model
|
// Use the HuggingFace path for the model
|
||||||
const modelPath = variant.path
|
const modelPath = variant.path
|
||||||
const supportStatus = await serviceHub.models().isModelSupported(modelPath, 8192)
|
const supportStatus = await serviceHub
|
||||||
|
.models()
|
||||||
|
.isModelSupported(modelPath, 8192)
|
||||||
|
|
||||||
setModelSupportStatus((prev) => ({
|
setModelSupportStatus((prev) => ({
|
||||||
...prev,
|
...prev,
|
||||||
@ -320,178 +322,7 @@ function HubContent() {
|
|||||||
[modelSupportStatus, serviceHub]
|
[modelSupportStatus, serviceHub]
|
||||||
)
|
)
|
||||||
|
|
||||||
const DownloadButtonPlaceholder = useMemo(() => {
|
|
||||||
return ({ model }: ModelProps) => {
|
|
||||||
// Check if this is a HuggingFace repository (no quants)
|
|
||||||
if (model.quants.length === 0) {
|
|
||||||
return (
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<Button
|
|
||||||
size="sm"
|
|
||||||
onClick={() => {
|
|
||||||
window.open(
|
|
||||||
`https://huggingface.co/${model.model_name}`,
|
|
||||||
'_blank'
|
|
||||||
)
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
View on HuggingFace
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const quant =
|
|
||||||
model.quants.find((e) =>
|
|
||||||
defaultModelQuantizations.some((m) =>
|
|
||||||
e.model_id.toLowerCase().includes(m)
|
|
||||||
)
|
|
||||||
) ?? model.quants[0]
|
|
||||||
const modelId = quant?.model_id || model.model_name
|
|
||||||
const modelUrl = quant?.path || modelId
|
|
||||||
const isDownloading =
|
|
||||||
localDownloadingModels.has(modelId) ||
|
|
||||||
downloadProcesses.some((e) => e.id === modelId)
|
|
||||||
const downloadProgress =
|
|
||||||
downloadProcesses.find((e) => e.id === modelId)?.progress || 0
|
|
||||||
const isDownloaded = llamaProvider?.models.some(
|
|
||||||
(m: { id: string }) => m.id === modelId
|
|
||||||
)
|
|
||||||
const isRecommended = isRecommendedModel(model.model_name)
|
|
||||||
|
|
||||||
const handleDownload = () => {
|
|
||||||
// Immediately set local downloading state
|
|
||||||
addLocalDownloadingModel(modelId)
|
|
||||||
const mmprojPath = model.mmproj_models?.[0]?.path
|
|
||||||
serviceHub.models().pullModelWithMetadata(
|
|
||||||
modelId,
|
|
||||||
modelUrl,
|
|
||||||
mmprojPath,
|
|
||||||
huggingfaceToken
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div
|
|
||||||
className={cn(
|
|
||||||
'flex items-center',
|
|
||||||
isRecommended && 'hub-download-button-step'
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
{isDownloading && !isDownloaded && (
|
|
||||||
<div className={cn('flex items-center gap-2 w-20')}>
|
|
||||||
<Progress value={downloadProgress * 100} />
|
|
||||||
<span className="text-xs text-center text-main-view-fg/70">
|
|
||||||
{Math.round(downloadProgress * 100)}%
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
{isDownloaded ? (
|
|
||||||
<Button
|
|
||||||
size="sm"
|
|
||||||
onClick={() => handleUseModel(modelId)}
|
|
||||||
data-test-id={`hub-model-${modelId}`}
|
|
||||||
>
|
|
||||||
{t('hub:use')}
|
|
||||||
</Button>
|
|
||||||
) : (
|
|
||||||
<Button
|
|
||||||
data-test-id={`hub-model-${modelId}`}
|
|
||||||
size="sm"
|
|
||||||
onClick={handleDownload}
|
|
||||||
className={cn(isDownloading && 'hidden')}
|
|
||||||
ref={isRecommended ? downloadButtonRef : undefined}
|
|
||||||
>
|
|
||||||
{t('hub:download')}
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}, [
|
|
||||||
localDownloadingModels,
|
|
||||||
downloadProcesses,
|
|
||||||
llamaProvider?.models,
|
|
||||||
isRecommendedModel,
|
|
||||||
t,
|
|
||||||
addLocalDownloadingModel,
|
|
||||||
huggingfaceToken,
|
|
||||||
handleUseModel,
|
|
||||||
serviceHub,
|
|
||||||
])
|
|
||||||
|
|
||||||
const { step } = useSearch({ from: Route.id })
|
|
||||||
const isSetup = step === 'setup_local_provider'
|
|
||||||
|
|
||||||
// Wait for DOM to be ready before starting Joyride
|
|
||||||
useEffect(() => {
|
|
||||||
if (!loading && filteredModels.length > 0 && isSetup) {
|
|
||||||
const timer = setTimeout(() => {
|
|
||||||
setJoyrideReady(true)
|
|
||||||
}, 100)
|
|
||||||
return () => clearTimeout(timer)
|
|
||||||
} else {
|
|
||||||
setJoyrideReady(false)
|
|
||||||
}
|
|
||||||
}, [loading, filteredModels.length, isSetup])
|
|
||||||
|
|
||||||
const handleJoyrideCallback = (data: CallBackProps) => {
|
|
||||||
const { status, index } = data
|
|
||||||
|
|
||||||
if (
|
|
||||||
status === STATUS.FINISHED &&
|
|
||||||
!isDownloading &&
|
|
||||||
isLastStep &&
|
|
||||||
!hasTriggeredDownload.current
|
|
||||||
) {
|
|
||||||
const recommendedModel = filteredModels.find((model) =>
|
|
||||||
isRecommendedModel(model.model_name)
|
|
||||||
)
|
|
||||||
if (recommendedModel && recommendedModel.quants[0]?.model_id) {
|
|
||||||
if (downloadButtonRef.current) {
|
|
||||||
hasTriggeredDownload.current = true
|
|
||||||
downloadButtonRef.current.click()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (status === STATUS.FINISHED) {
|
|
||||||
navigate({
|
|
||||||
to: route.hub.index,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track current step index
|
|
||||||
setCurrentStepIndex(index)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if any model is currently downloading
|
|
||||||
const isDownloading =
|
|
||||||
localDownloadingModels.size > 0 || downloadProcesses.length > 0
|
|
||||||
|
|
||||||
const steps = [
|
|
||||||
{
|
|
||||||
target: '.hub-model-card-step',
|
|
||||||
title: t('hub:joyride.recommendedModelTitle'),
|
|
||||||
disableBeacon: true,
|
|
||||||
content: t('hub:joyride.recommendedModelContent'),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
target: '.hub-download-button-step',
|
|
||||||
title: isDownloading
|
|
||||||
? t('hub:joyride.downloadInProgressTitle')
|
|
||||||
: t('hub:joyride.downloadModelTitle'),
|
|
||||||
disableBeacon: true,
|
|
||||||
content: isDownloading
|
|
||||||
? t('hub:joyride.downloadInProgressContent')
|
|
||||||
: t('hub:joyride.downloadModelContent'),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
// Check if we're on the last step
|
// Check if we're on the last step
|
||||||
const isLastStep = currentStepIndex === steps.length - 1
|
|
||||||
|
|
||||||
const renderFilter = () => {
|
const renderFilter = () => {
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@ -544,31 +375,6 @@ function HubContent() {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Joyride
|
|
||||||
run={joyrideReady}
|
|
||||||
floaterProps={{
|
|
||||||
hideArrow: true,
|
|
||||||
}}
|
|
||||||
steps={steps}
|
|
||||||
tooltipComponent={CustomTooltipJoyRide}
|
|
||||||
spotlightPadding={0}
|
|
||||||
continuous={true}
|
|
||||||
showSkipButton={!isLastStep}
|
|
||||||
hideCloseButton={true}
|
|
||||||
spotlightClicks={true}
|
|
||||||
disableOverlay={IS_LINUX}
|
|
||||||
disableOverlayClose={true}
|
|
||||||
callback={handleJoyrideCallback}
|
|
||||||
locale={{
|
|
||||||
back: t('hub:joyride.back'),
|
|
||||||
close: t('hub:joyride.close'),
|
|
||||||
last: !isDownloading
|
|
||||||
? t('hub:joyride.lastWithDownload')
|
|
||||||
: t('hub:joyride.last'),
|
|
||||||
next: t('hub:joyride.next'),
|
|
||||||
skip: t('hub:joyride.skip'),
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
<div className="flex h-full w-full">
|
<div className="flex h-full w-full">
|
||||||
<div className="flex flex-col h-full w-full ">
|
<div className="flex flex-col h-full w-full ">
|
||||||
<HeaderPage>
|
<HeaderPage>
|
||||||
@ -698,6 +504,7 @@ function HubContent() {
|
|||||||
/>
|
/>
|
||||||
<DownloadButtonPlaceholder
|
<DownloadButtonPlaceholder
|
||||||
model={filteredModels[virtualItem.index]}
|
model={filteredModels[virtualItem.index]}
|
||||||
|
handleUseModel={handleUseModel}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@ -908,10 +715,13 @@ function HubContent() {
|
|||||||
(e) => e.id === variant.model_id
|
(e) => e.id === variant.model_id
|
||||||
)?.progress || 0
|
)?.progress || 0
|
||||||
const isDownloaded =
|
const isDownloaded =
|
||||||
llamaProvider?.models.some(
|
useModelProvider
|
||||||
(m: { id: string }) =>
|
.getState()
|
||||||
m.id === variant.model_id
|
.getProviderByName('llamacpp')
|
||||||
)
|
?.models.some(
|
||||||
|
(m: { id: string }) =>
|
||||||
|
m.id === variant.model_id
|
||||||
|
)
|
||||||
|
|
||||||
if (isDownloading) {
|
if (isDownloading) {
|
||||||
return (
|
return (
|
||||||
@ -962,14 +772,26 @@ function HubContent() {
|
|||||||
addLocalDownloadingModel(
|
addLocalDownloadingModel(
|
||||||
variant.model_id
|
variant.model_id
|
||||||
)
|
)
|
||||||
serviceHub.models().pullModelWithMetadata(
|
serviceHub
|
||||||
variant.model_id,
|
.models()
|
||||||
variant.path,
|
.pullModelWithMetadata(
|
||||||
filteredModels[
|
variant.model_id,
|
||||||
virtualItem.index
|
variant.path,
|
||||||
].mmproj_models?.[0]?.path,
|
|
||||||
huggingfaceToken
|
(
|
||||||
)
|
filteredModels[
|
||||||
|
virtualItem.index
|
||||||
|
].mmproj_models?.find(
|
||||||
|
(e) =>
|
||||||
|
e.model_id.toLowerCase() ===
|
||||||
|
'mmproj-f16'
|
||||||
|
) ||
|
||||||
|
filteredModels[
|
||||||
|
virtualItem.index
|
||||||
|
].mmproj_models?.[0]
|
||||||
|
)?.path,
|
||||||
|
huggingfaceToken
|
||||||
|
)
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<IconDownload
|
<IconDownload
|
||||||
|
|||||||
@ -246,7 +246,9 @@ describe('DefaultModelsService', () => {
|
|||||||
})
|
})
|
||||||
mockEngine.load.mockRejectedValue(error)
|
mockEngine.load.mockRejectedValue(error)
|
||||||
|
|
||||||
await expect(modelsService.startModel(provider, model)).rejects.toThrow(error)
|
await expect(modelsService.startModel(provider, model)).rejects.toThrow(
|
||||||
|
error
|
||||||
|
)
|
||||||
})
|
})
|
||||||
it('should not load model again', async () => {
|
it('should not load model again', async () => {
|
||||||
const mockSettings = {
|
const mockSettings = {
|
||||||
@ -263,7 +265,9 @@ describe('DefaultModelsService', () => {
|
|||||||
includes: () => true,
|
includes: () => true,
|
||||||
})
|
})
|
||||||
expect(mockEngine.load).toBeCalledTimes(0)
|
expect(mockEngine.load).toBeCalledTimes(0)
|
||||||
await expect(modelsService.startModel(provider, model)).resolves.toBe(undefined)
|
await expect(modelsService.startModel(provider, model)).resolves.toBe(
|
||||||
|
undefined
|
||||||
|
)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -312,7 +316,9 @@ describe('DefaultModelsService', () => {
|
|||||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||||
})
|
})
|
||||||
|
|
||||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
const result = await modelsService.fetchHuggingFaceRepo(
|
||||||
|
'microsoft/DialoGPT-medium'
|
||||||
|
)
|
||||||
|
|
||||||
expect(result).toEqual(mockRepoData)
|
expect(result).toEqual(mockRepoData)
|
||||||
expect(fetch).toHaveBeenCalledWith(
|
expect(fetch).toHaveBeenCalledWith(
|
||||||
@ -342,7 +348,9 @@ describe('DefaultModelsService', () => {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Test with domain prefix
|
// Test with domain prefix
|
||||||
await modelsService.fetchHuggingFaceRepo('huggingface.co/microsoft/DialoGPT-medium')
|
await modelsService.fetchHuggingFaceRepo(
|
||||||
|
'huggingface.co/microsoft/DialoGPT-medium'
|
||||||
|
)
|
||||||
expect(fetch).toHaveBeenCalledWith(
|
expect(fetch).toHaveBeenCalledWith(
|
||||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
|
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
|
||||||
{
|
{
|
||||||
@ -365,7 +373,9 @@ describe('DefaultModelsService', () => {
|
|||||||
expect(await modelsService.fetchHuggingFaceRepo('')).toBeNull()
|
expect(await modelsService.fetchHuggingFaceRepo('')).toBeNull()
|
||||||
|
|
||||||
// Test string without slash
|
// Test string without slash
|
||||||
expect(await modelsService.fetchHuggingFaceRepo('invalid-repo')).toBeNull()
|
expect(
|
||||||
|
await modelsService.fetchHuggingFaceRepo('invalid-repo')
|
||||||
|
).toBeNull()
|
||||||
|
|
||||||
// Test whitespace only
|
// Test whitespace only
|
||||||
expect(await modelsService.fetchHuggingFaceRepo(' ')).toBeNull()
|
expect(await modelsService.fetchHuggingFaceRepo(' ')).toBeNull()
|
||||||
@ -378,7 +388,8 @@ describe('DefaultModelsService', () => {
|
|||||||
statusText: 'Not Found',
|
statusText: 'Not Found',
|
||||||
})
|
})
|
||||||
|
|
||||||
const result = await modelsService.fetchHuggingFaceRepo('nonexistent/model')
|
const result =
|
||||||
|
await modelsService.fetchHuggingFaceRepo('nonexistent/model')
|
||||||
|
|
||||||
expect(result).toBeNull()
|
expect(result).toBeNull()
|
||||||
expect(fetch).toHaveBeenCalledWith(
|
expect(fetch).toHaveBeenCalledWith(
|
||||||
@ -398,7 +409,9 @@ describe('DefaultModelsService', () => {
|
|||||||
statusText: 'Internal Server Error',
|
statusText: 'Internal Server Error',
|
||||||
})
|
})
|
||||||
|
|
||||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
const result = await modelsService.fetchHuggingFaceRepo(
|
||||||
|
'microsoft/DialoGPT-medium'
|
||||||
|
)
|
||||||
|
|
||||||
expect(result).toBeNull()
|
expect(result).toBeNull()
|
||||||
expect(consoleSpy).toHaveBeenCalledWith(
|
expect(consoleSpy).toHaveBeenCalledWith(
|
||||||
@ -414,7 +427,9 @@ describe('DefaultModelsService', () => {
|
|||||||
|
|
||||||
;(fetch as any).mockRejectedValue(new Error('Network error'))
|
;(fetch as any).mockRejectedValue(new Error('Network error'))
|
||||||
|
|
||||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
const result = await modelsService.fetchHuggingFaceRepo(
|
||||||
|
'microsoft/DialoGPT-medium'
|
||||||
|
)
|
||||||
|
|
||||||
expect(result).toBeNull()
|
expect(result).toBeNull()
|
||||||
expect(consoleSpy).toHaveBeenCalledWith(
|
expect(consoleSpy).toHaveBeenCalledWith(
|
||||||
@ -448,7 +463,9 @@ describe('DefaultModelsService', () => {
|
|||||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||||
})
|
})
|
||||||
|
|
||||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
const result = await modelsService.fetchHuggingFaceRepo(
|
||||||
|
'microsoft/DialoGPT-medium'
|
||||||
|
)
|
||||||
|
|
||||||
expect(result).toEqual(mockRepoData)
|
expect(result).toEqual(mockRepoData)
|
||||||
})
|
})
|
||||||
@ -487,7 +504,9 @@ describe('DefaultModelsService', () => {
|
|||||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||||
})
|
})
|
||||||
|
|
||||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
const result = await modelsService.fetchHuggingFaceRepo(
|
||||||
|
'microsoft/DialoGPT-medium'
|
||||||
|
)
|
||||||
|
|
||||||
expect(result).toEqual(mockRepoData)
|
expect(result).toEqual(mockRepoData)
|
||||||
})
|
})
|
||||||
@ -531,7 +550,9 @@ describe('DefaultModelsService', () => {
|
|||||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||||
})
|
})
|
||||||
|
|
||||||
const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
const result = await modelsService.fetchHuggingFaceRepo(
|
||||||
|
'microsoft/DialoGPT-medium'
|
||||||
|
)
|
||||||
|
|
||||||
expect(result).toEqual(mockRepoData)
|
expect(result).toEqual(mockRepoData)
|
||||||
// Verify the GGUF file is present in siblings
|
// Verify the GGUF file is present in siblings
|
||||||
@ -576,7 +597,8 @@ describe('DefaultModelsService', () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
it('should convert HuggingFace repo to catalog model format', () => {
|
it('should convert HuggingFace repo to catalog model format', () => {
|
||||||
const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
const result =
|
||||||
|
modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||||
|
|
||||||
const expected: CatalogModel = {
|
const expected: CatalogModel = {
|
||||||
model_name: 'microsoft/DialoGPT-medium',
|
model_name: 'microsoft/DialoGPT-medium',
|
||||||
@ -586,12 +608,12 @@ describe('DefaultModelsService', () => {
|
|||||||
num_quants: 2,
|
num_quants: 2,
|
||||||
quants: [
|
quants: [
|
||||||
{
|
{
|
||||||
model_id: 'model-q4_0',
|
model_id: 'microsoft/model-q4_0',
|
||||||
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q4_0.gguf',
|
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q4_0.gguf',
|
||||||
file_size: '2.0 GB',
|
file_size: '2.0 GB',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
model_id: 'model-q8_0',
|
model_id: 'microsoft/model-q8_0',
|
||||||
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q8_0.GGUF',
|
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q8_0.GGUF',
|
||||||
file_size: '4.0 GB',
|
file_size: '4.0 GB',
|
||||||
},
|
},
|
||||||
@ -635,7 +657,8 @@ describe('DefaultModelsService', () => {
|
|||||||
siblings: undefined,
|
siblings: undefined,
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithoutSiblings)
|
const result =
|
||||||
|
modelsService.convertHfRepoToCatalogModel(repoWithoutSiblings)
|
||||||
|
|
||||||
expect(result.num_quants).toBe(0)
|
expect(result.num_quants).toBe(0)
|
||||||
expect(result.quants).toEqual([])
|
expect(result.quants).toEqual([])
|
||||||
@ -663,7 +686,9 @@ describe('DefaultModelsService', () => {
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithVariousFileSizes)
|
const result = modelsService.convertHfRepoToCatalogModel(
|
||||||
|
repoWithVariousFileSizes
|
||||||
|
)
|
||||||
|
|
||||||
expect(result.quants[0].file_size).toBe('500.0 MB')
|
expect(result.quants[0].file_size).toBe('500.0 MB')
|
||||||
expect(result.quants[1].file_size).toBe('3.5 GB')
|
expect(result.quants[1].file_size).toBe('3.5 GB')
|
||||||
@ -676,7 +701,8 @@ describe('DefaultModelsService', () => {
|
|||||||
tags: [],
|
tags: [],
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithEmptyTags)
|
const result =
|
||||||
|
modelsService.convertHfRepoToCatalogModel(repoWithEmptyTags)
|
||||||
|
|
||||||
expect(result.description).toBe('**Tags**: ')
|
expect(result.description).toBe('**Tags**: ')
|
||||||
})
|
})
|
||||||
@ -687,7 +713,8 @@ describe('DefaultModelsService', () => {
|
|||||||
downloads: undefined as any,
|
downloads: undefined as any,
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithoutDownloads)
|
const result =
|
||||||
|
modelsService.convertHfRepoToCatalogModel(repoWithoutDownloads)
|
||||||
|
|
||||||
expect(result.downloads).toBe(0)
|
expect(result.downloads).toBe(0)
|
||||||
})
|
})
|
||||||
@ -714,15 +741,17 @@ describe('DefaultModelsService', () => {
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithVariousGGUF)
|
const result =
|
||||||
|
modelsService.convertHfRepoToCatalogModel(repoWithVariousGGUF)
|
||||||
|
|
||||||
expect(result.quants[0].model_id).toBe('model')
|
expect(result.quants[0].model_id).toBe('microsoft/model')
|
||||||
expect(result.quants[1].model_id).toBe('MODEL')
|
expect(result.quants[1].model_id).toBe('microsoft/MODEL')
|
||||||
expect(result.quants[2].model_id).toBe('complex-model-name')
|
expect(result.quants[2].model_id).toBe('microsoft/complex-model-name')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should generate correct download paths', () => {
|
it('should generate correct download paths', () => {
|
||||||
const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
const result =
|
||||||
|
modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||||
|
|
||||||
expect(result.quants[0].path).toBe(
|
expect(result.quants[0].path).toBe(
|
||||||
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q4_0.gguf'
|
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q4_0.gguf'
|
||||||
@ -733,7 +762,8 @@ describe('DefaultModelsService', () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
it('should generate correct readme URL', () => {
|
it('should generate correct readme URL', () => {
|
||||||
const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
const result =
|
||||||
|
modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||||
|
|
||||||
expect(result.readme).toBe(
|
expect(result.readme).toBe(
|
||||||
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md'
|
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md'
|
||||||
@ -767,13 +797,14 @@ describe('DefaultModelsService', () => {
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithMixedCase)
|
const result =
|
||||||
|
modelsService.convertHfRepoToCatalogModel(repoWithMixedCase)
|
||||||
|
|
||||||
expect(result.num_quants).toBe(3)
|
expect(result.num_quants).toBe(3)
|
||||||
expect(result.quants).toHaveLength(3)
|
expect(result.quants).toHaveLength(3)
|
||||||
expect(result.quants[0].model_id).toBe('model-1')
|
expect(result.quants[0].model_id).toBe('microsoft/model-1')
|
||||||
expect(result.quants[1].model_id).toBe('model-2')
|
expect(result.quants[1].model_id).toBe('microsoft/model-2')
|
||||||
expect(result.quants[2].model_id).toBe('model-3')
|
expect(result.quants[2].model_id).toBe('microsoft/model-3')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle edge cases with file size formatting', () => {
|
it('should handle edge cases with file size formatting', () => {
|
||||||
@ -798,7 +829,8 @@ describe('DefaultModelsService', () => {
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = modelsService.convertHfRepoToCatalogModel(repoWithEdgeCases)
|
const result =
|
||||||
|
modelsService.convertHfRepoToCatalogModel(repoWithEdgeCases)
|
||||||
|
|
||||||
expect(result.quants[0].file_size).toBe('0.0 MB')
|
expect(result.quants[0].file_size).toBe('0.0 MB')
|
||||||
expect(result.quants[1].file_size).toBe('1.0 GB')
|
expect(result.quants[1].file_size).toBe('1.0 GB')
|
||||||
@ -850,7 +882,10 @@ describe('DefaultModelsService', () => {
|
|||||||
|
|
||||||
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
|
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
|
||||||
|
|
||||||
const result = await modelsService.isModelSupported('/path/to/model.gguf', 4096)
|
const result = await modelsService.isModelSupported(
|
||||||
|
'/path/to/model.gguf',
|
||||||
|
4096
|
||||||
|
)
|
||||||
|
|
||||||
expect(result).toBe('GREEN')
|
expect(result).toBe('GREEN')
|
||||||
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
|
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
|
||||||
@ -867,7 +902,10 @@ describe('DefaultModelsService', () => {
|
|||||||
|
|
||||||
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
|
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
|
||||||
|
|
||||||
const result = await modelsService.isModelSupported('/path/to/model.gguf', 8192)
|
const result = await modelsService.isModelSupported(
|
||||||
|
'/path/to/model.gguf',
|
||||||
|
8192
|
||||||
|
)
|
||||||
|
|
||||||
expect(result).toBe('YELLOW')
|
expect(result).toBe('YELLOW')
|
||||||
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
|
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
|
||||||
@ -884,7 +922,9 @@ describe('DefaultModelsService', () => {
|
|||||||
|
|
||||||
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
|
mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
|
||||||
|
|
||||||
const result = await modelsService.isModelSupported('/path/to/large-model.gguf')
|
const result = await modelsService.isModelSupported(
|
||||||
|
'/path/to/large-model.gguf'
|
||||||
|
)
|
||||||
|
|
||||||
expect(result).toBe('RED')
|
expect(result).toBe('RED')
|
||||||
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
|
expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
|
||||||
|
|||||||
@ -30,6 +30,10 @@ export class DefaultModelsService implements ModelsService {
|
|||||||
return EngineManager.instance().get(provider) as AIEngine | undefined
|
return EngineManager.instance().get(provider) as AIEngine | undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async getModel(modelId: string): Promise<modelInfo | undefined> {
|
||||||
|
return this.getEngine()?.get(modelId)
|
||||||
|
}
|
||||||
|
|
||||||
async fetchModels(): Promise<modelInfo[]> {
|
async fetchModels(): Promise<modelInfo[]> {
|
||||||
return this.getEngine()?.list() ?? []
|
return this.getEngine()?.list() ?? []
|
||||||
}
|
}
|
||||||
@ -127,7 +131,7 @@ export class DefaultModelsService implements ModelsService {
|
|||||||
const modelId = file.rfilename.replace(/\.gguf$/i, '')
|
const modelId = file.rfilename.replace(/\.gguf$/i, '')
|
||||||
|
|
||||||
return {
|
return {
|
||||||
model_id: sanitizeModelId(modelId),
|
model_id: `${repo.author}/${sanitizeModelId(modelId)}`,
|
||||||
path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`,
|
path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`,
|
||||||
file_size: formatFileSize(file.size),
|
file_size: formatFileSize(file.size),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -90,6 +90,7 @@ export interface ModelPlan {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface ModelsService {
|
export interface ModelsService {
|
||||||
|
getModel(modelId: string): Promise<modelInfo | undefined>
|
||||||
fetchModels(): Promise<modelInfo[]>
|
fetchModels(): Promise<modelInfo[]>
|
||||||
fetchModelCatalog(): Promise<ModelCatalog>
|
fetchModelCatalog(): Promise<ModelCatalog>
|
||||||
fetchHuggingFaceRepo(
|
fetchHuggingFaceRepo(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user