feat: sync hub with model catalog

This commit is contained in:
Louis 2025-06-24 23:04:11 +07:00
parent c9c1ff1778
commit c6ac9f1d2a
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
16 changed files with 161 additions and 358 deletions

View File

@ -1,28 +0,0 @@
import { AllQuantizations } from './huggingfaceEntity';
test('testAllQuantizationsArray', () => {
expect(AllQuantizations).toEqual([
'Q3_K_S',
'Q3_K_M',
'Q3_K_L',
'Q4_K_S',
'Q4_K_M',
'Q5_K_S',
'Q5_K_M',
'Q4_0',
'Q4_1',
'Q5_0',
'Q5_1',
'IQ2_XXS',
'IQ2_XS',
'Q2_K',
'Q2_K_S',
'Q6_K',
'Q8_0',
'F16',
'F32',
'COPY',
]);
});

View File

@ -1,65 +0,0 @@
export interface HuggingFaceRepoData {
id: string
modelId: string
modelUrl?: string
author: string
sha: string
downloads: number
lastModified: string
private: boolean
disabled: boolean
gated: boolean
pipeline_tag: 'text-generation'
tags: Array<'transformers' | 'pytorch' | 'safetensors' | string>
cardData: Record<CardDataKeys | string, unknown>
siblings: {
rfilename: string
downloadUrl?: string
fileSize?: number
quantization?: Quantization
}[]
createdAt: string
}
const CardDataKeys = [
'base_model',
'datasets',
'inference',
'language',
'library_name',
'license',
'model_creator',
'model_name',
'model_type',
'pipeline_tag',
'prompt_template',
'quantized_by',
'tags',
] as const
export type CardDataKeysTuple = typeof CardDataKeys
export type CardDataKeys = CardDataKeysTuple[number]
export const AllQuantizations = [
'Q3_K_S',
'Q3_K_M',
'Q3_K_L',
'Q4_K_S',
'Q4_K_M',
'Q5_K_S',
'Q5_K_M',
'Q4_0',
'Q4_1',
'Q5_0',
'Q5_1',
'IQ2_XXS',
'IQ2_XS',
'Q2_K',
'Q2_K_S',
'Q6_K',
'Q8_0',
'F16',
'F32',
'COPY',
]
export type QuantizationsTuple = typeof AllQuantizations
export type Quantization = QuantizationsTuple[number]

View File

@ -1,8 +0,0 @@
import * as huggingfaceEntity from './huggingfaceEntity';
import * as index from './index';
test('test_exports_from_huggingfaceEntity', () => {
expect(index).toEqual(huggingfaceEntity);
});

View File

@ -1 +0,0 @@
export * from './huggingfaceEntity'

View File

@ -5,7 +5,6 @@ export * from './message'
export * from './inference'
export * from './file'
export * from './config'
export * from './huggingface'
export * from './miscellaneous'
export * from './api'
export * from './setting'

View File

@ -1,7 +0,0 @@
export type AppUpdateInfo = {
total: number
delta: number
transferred: number
percent: number
bytesPerSecond: number
}

View File

@ -1,4 +1 @@
export * from './systemResourceInfo'
export * from './promptTemplate'
export * from './appUpdate'
export * from './selectFiles'

View File

@ -1,6 +0,0 @@
export type PromptTemplate = {
system_prompt?: string
ai_prompt?: string
user_prompt?: string
error?: string
}

View File

@ -1,37 +0,0 @@
export type SelectFileOption = {
/**
* The title of the dialog.
*/
title?: string
/**
* Whether the dialog allows multiple selection.
*/
allowMultiple?: boolean
buttonLabel?: string
selectDirectory?: boolean
props?: SelectFileProp[]
filters?: FilterOption[]
}
export type FilterOption = {
name: string
extensions: string[]
}
export const SelectFilePropTuple = [
'openFile',
'openDirectory',
'multiSelections',
'showHiddenFiles',
'createDirectory',
'promptToCreate',
'noResolveAliases',
'treatPackageAsDirectory',
'dontAddToRecent',
] as const
export type SelectFileProp = (typeof SelectFilePropTuple)[number]

View File

@ -572,7 +572,10 @@ export default class llamacpp_extension extends AIEngine {
private createDownloadTaskId(modelId: string) {
// prepend provider to make taksId unique across providers
return `${this.provider}/${modelId}`
const cleanModelId = modelId.includes('.')
? modelId.slice(0, modelId.indexOf('.'))
: modelId
return `${this.provider}/${cleanModelId}`
}
private async *handleStreamingResponse(

View File

@ -4,6 +4,7 @@ export const localStorageKey = {
messages: 'messages',
theme: 'theme',
modelProvider: 'model-provider',
modelSources: 'model-sources',
settingAppearance: 'setting-appearance',
settingGeneral: 'setting-general',
settingCodeBlock: 'setting-code-block',

View File

@ -1,119 +1,65 @@
import { create } from 'zustand'
import { ModelSource } from '@janhq/core'
import {
addModelSource,
deleteModelSource,
fetchModelSources,
} from '@/services/models'
// Service functions for model sources
// Deep comparison function for model sources
const deepCompareModelSources = (
sources1: ModelSource[],
sources2: ModelSource[]
): boolean => {
if (sources1.length !== sources2.length) return false
return sources1.every((source1, index) => {
const source2 = sources2[index]
if (!source2) return false
// Compare basic properties
if (source1.id !== source2.id || source1.author !== source2.author) {
return false
}
// Compare metadata
if (JSON.stringify(source1.metadata) !== JSON.stringify(source2.metadata)) {
return false
}
// Compare models array
if (source1.models.length !== source2.models.length) return false
return source1.models.every((model1, modelIndex) => {
const model2 = source2.models[modelIndex]
return JSON.stringify(model1) === JSON.stringify(model2)
})
})
}
import { localStorageKey } from '@/constants/localStorage'
import { createJSONStorage, persist } from 'zustand/middleware'
import { fetchModelCatalog, CatalogModel } from '@/services/models'
// Zustand store for model sources
type ModelSourcesState = {
sources: ModelSource[]
sources: CatalogModel[]
error: Error | null
loading: boolean
fetchSources: () => Promise<void>
addSource: (source: string) => Promise<void>
deleteSource: (source: string) => Promise<void>
}
export const useModelSources = create<ModelSourcesState>()((set, get) => ({
sources: [],
error: null,
loading: false,
export const useModelSources = create<ModelSourcesState>()(
persist(
(set, get) => ({
sources: [],
error: null,
loading: false,
fetchSources: async () => {
set({ loading: true, error: null })
try {
const newSources = await fetchModelSources()
const currentSources = get().sources
fetchSources: async () => {
set({ loading: true, error: null })
try {
const newSources = await fetchModelCatalog()
const currentSources = get().sources
if (!deepCompareModelSources(currentSources, newSources)) {
set({ sources: newSources, loading: false })
} else {
set({ loading: false })
}
} catch (error) {
set({ error: error as Error, loading: false })
set({
sources: [
...newSources,
...currentSources.filter(
(e) => !newSources.some((s) => s.model_name === e.model_name)
),
],
loading: false,
})
} catch (error) {
set({ error: error as Error, loading: false })
}
},
addSource: async (source: string) => {
set({ loading: true, error: null })
console.log(source)
// try {
// await addModelSource(source)
// const newSources = await fetchModelSources()
// const currentSources = get().sources
// if (!deepCompareModelSources(currentSources, newSources)) {
// set({ sources: newSources, loading: false })
// } else {
// set({ loading: false })
// }
// } catch (error) {
// set({ error: error as Error, loading: false })
// }
},
}),
{
name: localStorageKey.modelSources,
storage: createJSONStorage(() => localStorage),
}
},
addSource: async (source: string) => {
set({ loading: true, error: null })
try {
await addModelSource(source)
const newSources = await fetchModelSources()
const currentSources = get().sources
if (!deepCompareModelSources(currentSources, newSources)) {
set({ sources: newSources, loading: false })
} else {
set({ loading: false })
}
} catch (error) {
set({ error: error as Error, loading: false })
}
},
deleteSource: async (source: string) => {
set({ loading: true, error: null })
try {
await deleteModelSource(source)
const newSources = await fetchModelSources()
const currentSources = get().sources
if (!deepCompareModelSources(currentSources, newSources)) {
set({ sources: newSources, loading: false })
} else {
set({ loading: false })
}
} catch (error) {
set({ error: error as Error, loading: false })
}
},
}))
/**
* @returns Featured model sources from the store
*/
export function useGetFeaturedSources() {
const { sources } = useModelSources()
const featuredSources = sources.filter((e) =>
e.metadata?.tags?.includes('featured')
)
return { sources: featuredSources }
}
)

View File

@ -7,7 +7,7 @@ import {
} from '@tanstack/react-router'
import { route } from '@/constants/routes'
import { useModelSources } from '@/hooks/useModelSources'
import { cn, fuzzySearch, toGigabytes } from '@/lib/utils'
import { cn, fuzzySearch } from '@/lib/utils'
import {
useState,
useMemo,
@ -31,7 +31,7 @@ import {
DropdownMenuItem,
DropdownMenuTrigger,
} from '@/components/ui/dropdown-menu'
import { addModelSource, fetchModelHub, pullModel } from '@/services/models'
import { CatalogModel, pullModel } from '@/services/models'
import { useDownloadStore } from '@/hooks/useDownloadStore'
import { Progress } from '@/components/ui/progress'
import HeaderPage from '@/containers/HeaderPage'
@ -39,13 +39,7 @@ import { Loader } from 'lucide-react'
import { useTranslation } from '@/i18n/react-i18next-compat'
type ModelProps = {
model: {
id: string
metadata?: any
models: {
id: string
}[]
}
model: CatalogModel
}
type SearchParams = {
repo: string
@ -65,7 +59,7 @@ function Hub() {
{ value: 'newest', name: t('hub:sortNewest') },
{ value: 'most-downloaded', name: t('hub:sortMostDownloaded') },
]
const { sources, fetchSources, loading } = useModelSources()
const { sources, fetchSources, addSource, loading } = useModelSources()
const search = useSearch({ from: route.hub as any })
const [searchValue, setSearchValue] = useState('')
const [sortSelected, setSortSelected] = useState('newest')
@ -97,7 +91,7 @@ function Hub() {
setSearchValue(search.repo || '')
setIsSearching(true)
addModelSourceTimeoutRef.current = setTimeout(() => {
addModelSource(search.repo)
addSource(search.repo)
.then(() => {
fetchSources()
})
@ -106,17 +100,17 @@ function Hub() {
})
}, 500)
}
}, [fetchSources, search])
}, [addSource, fetchSources, search])
// Sorting functionality
const sortedModels = useMemo(() => {
return [...sources].sort((a, b) => {
if (sortSelected === 'most-downloaded') {
return (b.metadata?.downloads || 0) - (a.metadata?.downloads || 0)
return (b.downloads || 0) - (a.downloads || 0)
} else {
return (
new Date(b.metadata?.createdAt || 0).getTime() -
new Date(a.metadata?.createdAt || 0).getTime()
new Date(b.created_at || 0).getTime() -
new Date(a.created_at || 0).getTime()
)
}
})
@ -132,12 +126,12 @@ function Hub() {
(e) =>
fuzzySearch(
searchValue.replace(/\s+/g, '').toLowerCase(),
e.id.toLowerCase()
e.model_name.toLowerCase()
) ||
e.models.some((model) =>
e.quants.some((model) =>
fuzzySearch(
searchValue.replace(/\s+/g, '').toLowerCase(),
model.id.toLowerCase()
model.model_id.toLowerCase()
)
)
)
@ -146,8 +140,10 @@ function Hub() {
// Apply downloaded filter
if (showOnlyDownloaded) {
filtered = filtered?.filter((model) =>
model.models.some((variant) =>
llamaProvider?.models.some((m: { id: string }) => m.id === variant.id)
model.quants.some((variant) =>
llamaProvider?.models.some(
(m: { id: string }) => m.id === variant.model_id
)
)
)
}
@ -156,7 +152,6 @@ function Hub() {
}, [searchValue, sortedModels, showOnlyDownloaded, llamaProvider?.models])
useEffect(() => {
fetchModelHub()
fetchSources()
}, [fetchSources])
@ -172,7 +167,7 @@ function Hub() {
) {
setIsSearching(true)
addModelSourceTimeoutRef.current = setTimeout(() => {
addModelSource(e.target.value)
addSource(e.target.value)
.then(() => {
fetchSources()
})
@ -223,10 +218,14 @@ function Hub() {
const DownloadButtonPlaceholder = useMemo(() => {
return ({ model }: ModelProps) => {
const modelId =
model.models.find((e) =>
defaultModelQuantizations.some((m) => e.id.toLowerCase().includes(m))
)?.id ?? model.models[0]?.id
const quant =
model.quants.find((e) =>
defaultModelQuantizations.some((m) =>
e.model_id.toLowerCase().includes(m)
)
) ?? model.quants[0]
const modelId = quant?.model_id || model.model_name
const modelUrl = quant?.path || modelId
const isDownloading =
localDownloadingModels.has(modelId) ||
downloadProcesses.some((e) => e.id === modelId)
@ -235,12 +234,12 @@ function Hub() {
const isDownloaded = llamaProvider?.models.some(
(m: { id: string }) => m.id === modelId
)
const isRecommended = isRecommendedModel(model.metadata?.id)
const isRecommended = isRecommendedModel(model.model_name)
const handleDownload = () => {
// Immediately set local downloading state
addLocalDownloadingModel(modelId)
pullModel(modelId, modelId)
pullModel(modelId, modelUrl)
}
return (
@ -316,9 +315,9 @@ function Hub() {
!hasTriggeredDownload.current
) {
const recommendedModel = filteredModels.find((model) =>
isRecommendedModel(model.metadata?.id)
isRecommendedModel(model.model_name)
)
if (recommendedModel && recommendedModel.models[0]?.id) {
if (recommendedModel && recommendedModel.quants[0]?.model_id) {
if (downloadButtonRef.current) {
hasTriggeredDownload.current = true
downloadButtonRef.current.click()
@ -475,20 +474,20 @@ function Hub() {
{renderFilter()}
</div>
{filteredModels.map((model) => (
<div key={model.id}>
<div key={model.model_name}>
<Card
header={
<div className="flex items-center justify-between gap-x-2">
<Link
to={
`https://huggingface.co/${model.metadata?.id}` as string
`https://huggingface.co/${model.model_name}` as string
}
target="_blank"
>
<h1
className={cn(
'text-main-view-fg font-medium text-base capitalize truncate max-w-38 sm:max-w-none',
isRecommendedModel(model.metadata?.id)
isRecommendedModel(model.model_name)
? 'hub-model-card-step'
: ''
)}
@ -496,20 +495,20 @@ function Hub() {
extractModelName(model.metadata?.id) || ''
}
>
{extractModelName(model.metadata?.id) || ''}
{extractModelName(model.model_name) || ''}
</h1>
</Link>
<div className="shrink-0 space-x-3 flex items-center">
<span className="text-main-view-fg/70 font-medium text-xs">
{toGigabytes(
{
(
model.models.find((m) =>
model.quants.find((m) =>
defaultModelQuantizations.some((e) =>
m.id.toLowerCase().includes(e)
m.model_id.toLowerCase().includes(e)
)
) ?? model.models?.[0]
)?.size
)}
) ?? model.quants?.[0]
)?.file_size
}
</span>
<DownloadButtonPlaceholder model={model} />
</div>
@ -530,14 +529,13 @@ function Hub() {
),
}}
content={
extractDescription(model.metadata?.description) ||
''
extractDescription(model?.description) || ''
}
/>
</div>
<div className="flex items-center gap-2 mt-2">
<span className="capitalize text-main-view-fg/80">
{t('hub:by')} {model?.author}
{t('hub:by')} {model?.developer}
</span>
<div className="flex items-center gap-4 ml-2">
<div className="flex items-center gap-1">
@ -547,7 +545,7 @@ function Hub() {
title={t('hub:downloads')}
/>
<span className="text-main-view-fg/80">
{model.metadata?.downloads || 0}
{model.downloads || 0}
</span>
</div>
<div className="flex items-center gap-1">
@ -557,15 +555,15 @@ function Hub() {
title={t('hub:variants')}
/>
<span className="text-main-view-fg/80">
{model.models?.length || 0}
{model.quants?.length || 0}
</span>
</div>
{model.models.length > 1 && (
{model.quants.length > 1 && (
<div className="flex items-center gap-2 hub-show-variants-step">
<Switch
checked={!!expandedModels[model.id]}
checked={!!expandedModels[model.model_name]}
onCheckedChange={() =>
toggleModelExpansion(model.id)
toggleModelExpansion(model.model_name)
}
/>
<p className="text-main-view-fg/70">
@ -575,34 +573,34 @@ function Hub() {
)}
</div>
</div>
{expandedModels[model.id] &&
model.models.length > 0 && (
{expandedModels[model.model_name] &&
model.quants.length > 0 && (
<div className="mt-5">
{model.models.map((variant) => (
{model.quants.map((variant) => (
<CardItem
key={variant.id}
title={variant.id}
key={variant.model_id}
title={variant.model_id}
actions={
<div className="flex items-center gap-2">
<p className="text-main-view-fg/70 font-medium text-xs">
{toGigabytes(variant.size)}
{variant.file_size}
</p>
{(() => {
const isDownloading =
localDownloadingModels.has(
variant.id
variant.model_id
) ||
downloadProcesses.some(
(e) => e.id === variant.id
(e) => e.id === variant.model_id
)
const downloadProgress =
downloadProcesses.find(
(e) => e.id === variant.id
(e) => e.id === variant.model_id
)?.progress || 0
const isDownloaded =
llamaProvider?.models.some(
(m: { id: string }) =>
m.id === variant.id
m.id === variant.model_id
)
if (isDownloading) {
@ -633,7 +631,9 @@ function Hub() {
variant="link"
size="sm"
onClick={() =>
handleUseModel(variant.id)
handleUseModel(
variant.model_id
)
}
>
{t('hub:use')}
@ -648,9 +648,12 @@ function Hub() {
title={t('hub:downloadModel')}
onClick={() => {
addLocalDownloadingModel(
variant.id
variant.model_id
)
pullModel(
variant.model_id,
variant.path
)
pullModel(variant.id, variant.id)
}}
>
<IconDownload

View File

@ -1,4 +1,3 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import {
AIEngine,
EngineManager,
@ -6,6 +5,25 @@ import {
SettingComponentProps,
} from '@janhq/core'
import { Model as CoreModel } from '@janhq/core'
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
// Types for model catalog
export interface ModelQuant {
model_id: string
path: string
file_size: string
}
export interface CatalogModel {
model_name: string
description: string
developer: string
downloads: number
num_quants: number
quants: ModelQuant[]
created_at?: string
}
export type ModelCatalog = CatalogModel[]
// TODO: Replace this with the actual provider later
const defaultProvider = 'llamacpp'
@ -22,43 +40,27 @@ export const fetchModels = async () => {
}
/**
* Fetches the sources of the models.
* @returns A promise that resolves to the model sources.
* Fetches the model catalog from the GitHub repository.
* @returns A promise that resolves to the model catalog.
*/
export const fetchModelSources = async () => {
// TODO: New Hub
return []
}
export const fetchModelCatalog = async (): Promise<ModelCatalog> => {
try {
const response = await fetchTauri(MODEL_CATALOG_URL)
/**
* Fetches the model hub.
* @returns A promise that resolves to the model hub.
*/
export const fetchModelHub = async () => {
// TODO: New Hub
return
}
if (!response.ok) {
throw new Error(
`Failed to fetch model catalog: ${response.status} ${response.statusText}`
)
}
/**
* Adds a new model source.
* @param source The source to add.
* @returns A promise that resolves when the source is added.
*/
export const addModelSource = async (source: string) => {
// TODO: New Hub
console.log(source)
return
}
/**
* Deletes a model source.
* @param source The source to delete.
* @returns A promise that resolves when the source is deleted.
*/
export const deleteModelSource = async (source: string) => {
// TODO: New Hub
console.log(source)
return
const catalog: ModelCatalog = await response.json()
return catalog
} catch (error) {
console.error('Error fetching model catalog:', error)
throw new Error(
`Failed to fetch model catalog: ${error instanceof Error ? error.message : 'Unknown error'}`
)
}
}
/**

View File

@ -18,6 +18,7 @@ declare global {
declare const VERSION: string
declare const POSTHOG_KEY: string
declare const POSTHOG_HOST: string
declare const MODEL_CATALOG_URL: string
interface Window {
core: AppCore | undefined
}

View File

@ -49,6 +49,9 @@ export default defineConfig(({ mode }) => {
POSTHOG_KEY: JSON.stringify(env.POSTHOG_KEY),
POSTHOG_HOST: JSON.stringify(env.POSTHOG_HOST),
MODEL_CATALOG_URL: JSON.stringify(
'https://raw.githubusercontent.com/menloresearch/model-catalog/main/model_catalog.json'
),
},
// Vite options tailored for Tauri development and only applied in `tauri dev` or `tauri build`