fix: distinguish between hub and models search (#4989)
* fix: distinguish between hub and models search * chore: refresh models hub when going to hub screen
This commit is contained in:
parent
db43008813
commit
b64749b4bb
@ -40,4 +40,9 @@ export abstract class ModelExtension
|
|||||||
* Delete a model source
|
* Delete a model source
|
||||||
*/
|
*/
|
||||||
abstract deleteSource(source: string): Promise<void>
|
abstract deleteSource(source: string): Promise<void>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetch models hub
|
||||||
|
*/
|
||||||
|
abstract fetchModelsHub(): Promise<void>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -67,7 +67,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sync with cortexsohub
|
// Sync with cortexsohub
|
||||||
this.fetchCortexsoModels()
|
this.fetchModelsHub()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -450,7 +450,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
/**
|
/**
|
||||||
* Fetch models from cortex.so
|
* Fetch models from cortex.so
|
||||||
*/
|
*/
|
||||||
private fetchCortexsoModels = async () => {
|
fetchModelsHub = async () => {
|
||||||
const models = await this.fetchModels()
|
const models = await this.fetchModels()
|
||||||
|
|
||||||
return this.queue.add(() =>
|
return this.queue.add(() =>
|
||||||
|
|||||||
@ -5,12 +5,23 @@ import { SearchIcon } from 'lucide-react'
|
|||||||
|
|
||||||
import { useDebouncedCallback } from 'use-debounce'
|
import { useDebouncedCallback } from 'use-debounce'
|
||||||
|
|
||||||
|
import {
|
||||||
|
useGetModelSources,
|
||||||
|
useModelSourcesMutation,
|
||||||
|
} from '@/hooks/useModelSource'
|
||||||
|
|
||||||
|
import Spinner from '../Loader/Spinner'
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
|
supportModelImport?: boolean
|
||||||
onSearchLocal?: (searchText: string) => void
|
onSearchLocal?: (searchText: string) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
const ModelSearch = ({ onSearchLocal }: Props) => {
|
const ModelSearch = ({ supportModelImport, onSearchLocal }: Props) => {
|
||||||
const [searchText, setSearchText] = useState('')
|
const [searchText, setSearchText] = useState('')
|
||||||
|
const [isSearching, setSearching] = useState(false)
|
||||||
|
const { mutate } = useGetModelSources()
|
||||||
|
const { addModelSource } = useModelSourcesMutation()
|
||||||
const inputRef = useRef<HTMLInputElement | null>(null)
|
const inputRef = useRef<HTMLInputElement | null>(null)
|
||||||
const debounced = useDebouncedCallback(async () => {
|
const debounced = useDebouncedCallback(async () => {
|
||||||
if (searchText.indexOf('/') === -1) {
|
if (searchText.indexOf('/') === -1) {
|
||||||
@ -20,6 +31,15 @@ const ModelSearch = ({ onSearchLocal }: Props) => {
|
|||||||
}
|
}
|
||||||
// Attempt to search local
|
// Attempt to search local
|
||||||
onSearchLocal?.(searchText)
|
onSearchLocal?.(searchText)
|
||||||
|
|
||||||
|
setSearching(true)
|
||||||
|
// Attempt to search model source
|
||||||
|
if (supportModelImport)
|
||||||
|
addModelSource(searchText)
|
||||||
|
.then(() => mutate())
|
||||||
|
.then(() => onSearchLocal?.(searchText))
|
||||||
|
.catch(console.debug)
|
||||||
|
.finally(() => setSearching(false))
|
||||||
}, 300)
|
}, 300)
|
||||||
|
|
||||||
const onSearchChanged = useCallback(
|
const onSearchChanged = useCallback(
|
||||||
@ -50,8 +70,18 @@ const ModelSearch = ({ onSearchLocal }: Props) => {
|
|||||||
return (
|
return (
|
||||||
<Input
|
<Input
|
||||||
ref={inputRef}
|
ref={inputRef}
|
||||||
prefixIcon={<SearchIcon size={16} />}
|
prefixIcon={
|
||||||
placeholder="Search models..."
|
isSearching ? (
|
||||||
|
<Spinner size={16} strokeWidth={2} />
|
||||||
|
) : (
|
||||||
|
<SearchIcon size={16} />
|
||||||
|
)
|
||||||
|
}
|
||||||
|
placeholder={
|
||||||
|
supportModelImport
|
||||||
|
? 'Search or enter Hugging Face URL'
|
||||||
|
: 'Search models'
|
||||||
|
}
|
||||||
onChange={onSearchChanged}
|
onChange={onSearchChanged}
|
||||||
onKeyDown={onKeyDown}
|
onKeyDown={onKeyDown}
|
||||||
value={searchText}
|
value={searchText}
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import {
|
|||||||
ModelEvent,
|
ModelEvent,
|
||||||
ModelSource,
|
ModelSource,
|
||||||
ModelSibling,
|
ModelSibling,
|
||||||
|
ModelExtension,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { useAtom, useAtomValue } from 'jotai'
|
import { useAtom, useAtomValue } from 'jotai'
|
||||||
import { atomWithStorage } from 'jotai/utils'
|
import { atomWithStorage } from 'jotai/utils'
|
||||||
@ -497,3 +498,21 @@ export const useRefreshModelList = (engine: string) => {
|
|||||||
|
|
||||||
return { refreshingModels, refreshModels }
|
return { refreshingModels, refreshModels }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const useFetchModelsHub = () => {
|
||||||
|
const extension = useMemo(
|
||||||
|
() => extensionManager.get<ModelExtension>(ExtensionTypeEnum.Model) ?? null,
|
||||||
|
[]
|
||||||
|
)
|
||||||
|
|
||||||
|
const { data, error, mutate } = useSWR(
|
||||||
|
extension ? 'fetchModelsHub' : null,
|
||||||
|
() => extension?.fetchModelsHub(),
|
||||||
|
{
|
||||||
|
revalidateOnFocus: false,
|
||||||
|
revalidateOnReconnect: true,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return { modelsHub: data, error, mutate }
|
||||||
|
}
|
||||||
|
|||||||
@ -25,7 +25,10 @@ import { twMerge } from 'tailwind-merge'
|
|||||||
import CenterPanelContainer from '@/containers/CenterPanelContainer'
|
import CenterPanelContainer from '@/containers/CenterPanelContainer'
|
||||||
import ModelSearch from '@/containers/ModelSearch'
|
import ModelSearch from '@/containers/ModelSearch'
|
||||||
|
|
||||||
import { useGetEngineModelSources } from '@/hooks/useEngineManagement'
|
import {
|
||||||
|
useFetchModelsHub,
|
||||||
|
useGetEngineModelSources,
|
||||||
|
} from '@/hooks/useEngineManagement'
|
||||||
import { setImportModelStageAtom } from '@/hooks/useImportModel'
|
import { setImportModelStageAtom } from '@/hooks/useImportModel'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@ -85,6 +88,7 @@ const hubCompatibleAtom = atom(false)
|
|||||||
const HubScreen = () => {
|
const HubScreen = () => {
|
||||||
const { sources } = useGetModelSources()
|
const { sources } = useGetModelSources()
|
||||||
const { sources: remoteModelSources } = useGetEngineModelSources()
|
const { sources: remoteModelSources } = useGetEngineModelSources()
|
||||||
|
const { mutate: fetchModelsHub } = useFetchModelsHub()
|
||||||
const { addModelSource } = useModelSourcesMutation()
|
const { addModelSource } = useModelSourcesMutation()
|
||||||
const [searchValue, setSearchValue] = useState('')
|
const [searchValue, setSearchValue] = useState('')
|
||||||
const [sortSelected, setSortSelected] = useState('newest')
|
const [sortSelected, setSortSelected] = useState('newest')
|
||||||
@ -268,6 +272,10 @@ const HubScreen = () => {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchModelsHub()
|
||||||
|
}, [])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<CenterPanelContainer>
|
<CenterPanelContainer>
|
||||||
<m.div
|
<m.div
|
||||||
@ -422,7 +430,10 @@ const HubScreen = () => {
|
|||||||
<div className="absolute left-1/2 top-1/2 z-10 mx-auto w-4/5 -translate-x-1/2 -translate-y-1/2 rounded-xl sm:w-2/6">
|
<div className="absolute left-1/2 top-1/2 z-10 mx-auto w-4/5 -translate-x-1/2 -translate-y-1/2 rounded-xl sm:w-2/6">
|
||||||
<div className="flex flex-col items-center justify-between gap-2 sm:flex-row">
|
<div className="flex flex-col items-center justify-between gap-2 sm:flex-row">
|
||||||
<div className="w-full" ref={dropdownRef}>
|
<div className="w-full" ref={dropdownRef}>
|
||||||
<ModelSearch onSearchLocal={onSearchUpdate} />
|
<ModelSearch
|
||||||
|
onSearchLocal={onSearchUpdate}
|
||||||
|
supportModelImport
|
||||||
|
/>
|
||||||
<div
|
<div
|
||||||
className={twMerge(
|
className={twMerge(
|
||||||
'invisible absolute mt-2 max-h-[400px] w-full overflow-y-auto rounded-lg border border-[hsla(var(--app-border))] bg-[hsla(var(--app-bg))] shadow-lg',
|
'invisible absolute mt-2 max-h-[400px] w-full overflow-y-auto rounded-lg border border-[hsla(var(--app-border))] bg-[hsla(var(--app-bg))] shadow-lg',
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user