Merge pull request #6222 from menloresearch/feat/model-tool-use-detection
feat: #5917 - model tool use capability should be auto detected
This commit is contained in:
commit
55390de070
@ -271,4 +271,10 @@ export abstract class AIEngine extends BaseExtension {
|
|||||||
* Optional method to get the underlying chat client
|
* Optional method to get the underlying chat client
|
||||||
*/
|
*/
|
||||||
getChatClient?(sessionId: string): any
|
getChatClient?(sessionId: string): any
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a tool is supported by the model
|
||||||
|
* @param modelId
|
||||||
|
*/
|
||||||
|
abstract isToolSupported(modelId: string): Promise<boolean>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -58,6 +58,7 @@ export enum AppEvent {
|
|||||||
onAppUpdateDownloadUpdate = 'onAppUpdateDownloadUpdate',
|
onAppUpdateDownloadUpdate = 'onAppUpdateDownloadUpdate',
|
||||||
onAppUpdateDownloadError = 'onAppUpdateDownloadError',
|
onAppUpdateDownloadError = 'onAppUpdateDownloadError',
|
||||||
onAppUpdateDownloadSuccess = 'onAppUpdateDownloadSuccess',
|
onAppUpdateDownloadSuccess = 'onAppUpdateDownloadSuccess',
|
||||||
|
onModelImported = 'onModelImported',
|
||||||
|
|
||||||
onUserSubmitQuickAsk = 'onUserSubmitQuickAsk',
|
onUserSubmitQuickAsk = 'onUserSubmitQuickAsk',
|
||||||
onSelectedText = 'onSelectedText',
|
onSelectedText = 'onSelectedText',
|
||||||
|
|||||||
@ -19,6 +19,7 @@ import {
|
|||||||
ImportOptions,
|
ImportOptions,
|
||||||
chatCompletionRequest,
|
chatCompletionRequest,
|
||||||
events,
|
events,
|
||||||
|
AppEvent,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
|
|
||||||
import { error, info, warn } from '@tauri-apps/plugin-log'
|
import { error, info, warn } from '@tauri-apps/plugin-log'
|
||||||
@ -32,6 +33,7 @@ import {
|
|||||||
import { invoke } from '@tauri-apps/api/core'
|
import { invoke } from '@tauri-apps/api/core'
|
||||||
import { getProxyConfig } from './util'
|
import { getProxyConfig } from './util'
|
||||||
import { basename } from '@tauri-apps/api/path'
|
import { basename } from '@tauri-apps/api/path'
|
||||||
|
import { readGgufMetadata } from '@janhq/tauri-plugin-llamacpp-api'
|
||||||
|
|
||||||
type LlamacppConfig = {
|
type LlamacppConfig = {
|
||||||
version_backend: string
|
version_backend: string
|
||||||
@ -1085,6 +1087,12 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
data: modelConfig,
|
data: modelConfig,
|
||||||
savePath: configPath,
|
savePath: configPath,
|
||||||
})
|
})
|
||||||
|
events.emit(AppEvent.onModelImported, {
|
||||||
|
modelId,
|
||||||
|
modelPath,
|
||||||
|
mmprojPath,
|
||||||
|
size_bytes,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
override async abortImport(modelId: string): Promise<void> {
|
override async abortImport(modelId: string): Promise<void> {
|
||||||
@ -1172,7 +1180,7 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
const [version, backend] = cfg.version_backend.split('/')
|
const [version, backend] = cfg.version_backend.split('/')
|
||||||
if (!version || !backend) {
|
if (!version || !backend) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
"Initial setup for the backend failed due to a network issue. Please restart the app!"
|
'Initial setup for the backend failed due to a network issue. Please restart the app!'
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1279,11 +1287,14 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
// TODO: add LIBRARY_PATH
|
// TODO: add LIBRARY_PATH
|
||||||
const sInfo = await invoke<SessionInfo>('plugin:llamacpp|load_llama_model', {
|
const sInfo = await invoke<SessionInfo>(
|
||||||
backendPath,
|
'plugin:llamacpp|load_llama_model',
|
||||||
libraryPath,
|
{
|
||||||
args,
|
backendPath,
|
||||||
})
|
libraryPath,
|
||||||
|
args,
|
||||||
|
}
|
||||||
|
)
|
||||||
return sInfo
|
return sInfo
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Error in load command:\n', error)
|
logger.error('Error in load command:\n', error)
|
||||||
@ -1299,9 +1310,12 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
const pid = sInfo.pid
|
const pid = sInfo.pid
|
||||||
try {
|
try {
|
||||||
// Pass the PID as the session_id
|
// Pass the PID as the session_id
|
||||||
const result = await invoke<UnloadResult>('plugin:llamacpp|unload_llama_model', {
|
const result = await invoke<UnloadResult>(
|
||||||
pid: pid,
|
'plugin:llamacpp|unload_llama_model',
|
||||||
})
|
{
|
||||||
|
pid: pid,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
// If successful, remove from active sessions
|
// If successful, remove from active sessions
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
@ -1437,9 +1451,12 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
|
|
||||||
private async findSessionByModel(modelId: string): Promise<SessionInfo> {
|
private async findSessionByModel(modelId: string): Promise<SessionInfo> {
|
||||||
try {
|
try {
|
||||||
let sInfo = await invoke<SessionInfo>('plugin:llamacpp|find_session_by_model', {
|
let sInfo = await invoke<SessionInfo>(
|
||||||
modelId,
|
'plugin:llamacpp|find_session_by_model',
|
||||||
})
|
{
|
||||||
|
modelId,
|
||||||
|
}
|
||||||
|
)
|
||||||
return sInfo
|
return sInfo
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
@ -1516,7 +1533,9 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
|
|
||||||
override async getLoadedModels(): Promise<string[]> {
|
override async getLoadedModels(): Promise<string[]> {
|
||||||
try {
|
try {
|
||||||
let models: string[] = await invoke<string[]>('plugin:llamacpp|get_loaded_models')
|
let models: string[] = await invoke<string[]>(
|
||||||
|
'plugin:llamacpp|get_loaded_models'
|
||||||
|
)
|
||||||
return models
|
return models
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
@ -1599,14 +1618,31 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
throw new Error('method not implemented yet')
|
throw new Error('method not implemented yet')
|
||||||
}
|
}
|
||||||
|
|
||||||
private async loadMetadata(path: string): Promise<GgufMetadata> {
|
/**
|
||||||
try {
|
* Check if a tool is supported by the model
|
||||||
const data = await invoke<GgufMetadata>('plugin:llamacpp|read_gguf_metadata', {
|
* Currently read from GGUF chat_template
|
||||||
path: path,
|
* @param modelId
|
||||||
})
|
* @returns
|
||||||
return data
|
*/
|
||||||
} catch (err) {
|
async isToolSupported(modelId: string): Promise<boolean> {
|
||||||
throw err
|
const janDataFolderPath = await getJanDataFolderPath()
|
||||||
}
|
const modelConfigPath = await joinPath([
|
||||||
|
this.providerPath,
|
||||||
|
'models',
|
||||||
|
modelId,
|
||||||
|
'model.yml',
|
||||||
|
])
|
||||||
|
const modelConfig = await invoke<ModelConfig>('read_yaml', {
|
||||||
|
path: modelConfigPath,
|
||||||
|
})
|
||||||
|
// model option is required
|
||||||
|
// NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path
|
||||||
|
const modelPath = await joinPath([
|
||||||
|
janDataFolderPath,
|
||||||
|
modelConfig.model_path,
|
||||||
|
])
|
||||||
|
return (await readGgufMetadata(modelPath)).metadata?.[
|
||||||
|
'tokenizer.chat_template'
|
||||||
|
]?.includes('tools')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,10 +6,8 @@ import {
|
|||||||
import { Progress } from '@/components/ui/progress'
|
import { Progress } from '@/components/ui/progress'
|
||||||
import { useDownloadStore } from '@/hooks/useDownloadStore'
|
import { useDownloadStore } from '@/hooks/useDownloadStore'
|
||||||
import { useLeftPanel } from '@/hooks/useLeftPanel'
|
import { useLeftPanel } from '@/hooks/useLeftPanel'
|
||||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
|
||||||
import { useAppUpdater } from '@/hooks/useAppUpdater'
|
import { useAppUpdater } from '@/hooks/useAppUpdater'
|
||||||
import { abortDownload } from '@/services/models'
|
import { abortDownload } from '@/services/models'
|
||||||
import { getProviders } from '@/services/providers'
|
|
||||||
import { DownloadEvent, DownloadState, events, AppEvent } from '@janhq/core'
|
import { DownloadEvent, DownloadState, events, AppEvent } from '@janhq/core'
|
||||||
import { IconDownload, IconX } from '@tabler/icons-react'
|
import { IconDownload, IconX } from '@tabler/icons-react'
|
||||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||||
@ -18,7 +16,6 @@ import { useTranslation } from '@/i18n/react-i18next-compat'
|
|||||||
|
|
||||||
export function DownloadManagement() {
|
export function DownloadManagement() {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const { setProviders } = useModelProvider()
|
|
||||||
const { open: isLeftPanelOpen } = useLeftPanel()
|
const { open: isLeftPanelOpen } = useLeftPanel()
|
||||||
const [isPopoverOpen, setIsPopoverOpen] = useState(false)
|
const [isPopoverOpen, setIsPopoverOpen] = useState(false)
|
||||||
const {
|
const {
|
||||||
@ -185,7 +182,6 @@ export function DownloadManagement() {
|
|||||||
console.debug('onFileDownloadSuccess', state)
|
console.debug('onFileDownloadSuccess', state)
|
||||||
removeDownload(state.modelId)
|
removeDownload(state.modelId)
|
||||||
removeLocalDownloadingModel(state.modelId)
|
removeLocalDownloadingModel(state.modelId)
|
||||||
getProviders().then(setProviders)
|
|
||||||
toast.success(t('common:toast.downloadComplete.title'), {
|
toast.success(t('common:toast.downloadComplete.title'), {
|
||||||
id: 'download-complete',
|
id: 'download-complete',
|
||||||
description: t('common:toast.downloadComplete.description', {
|
description: t('common:toast.downloadComplete.description', {
|
||||||
@ -193,7 +189,7 @@ export function DownloadManagement() {
|
|||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
[removeDownload, removeLocalDownloadingModel, setProviders, t]
|
[removeDownload, removeLocalDownloadingModel, t]
|
||||||
)
|
)
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
|||||||
@ -1,253 +0,0 @@
|
|||||||
import {
|
|
||||||
Dialog,
|
|
||||||
DialogContent,
|
|
||||||
DialogDescription,
|
|
||||||
DialogHeader,
|
|
||||||
DialogTitle,
|
|
||||||
DialogTrigger,
|
|
||||||
} from '@/components/ui/dialog'
|
|
||||||
import { Switch } from '@/components/ui/switch'
|
|
||||||
import {
|
|
||||||
Tooltip,
|
|
||||||
TooltipContent,
|
|
||||||
TooltipTrigger,
|
|
||||||
} from '@/components/ui/tooltip'
|
|
||||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
|
||||||
import {
|
|
||||||
IconPencil,
|
|
||||||
IconEye,
|
|
||||||
IconTool,
|
|
||||||
// IconWorld,
|
|
||||||
// IconAtom,
|
|
||||||
IconCodeCircle2,
|
|
||||||
} from '@tabler/icons-react'
|
|
||||||
import { useState, useEffect } from 'react'
|
|
||||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
|
||||||
|
|
||||||
// No need to define our own interface, we'll use the existing Model type
|
|
||||||
type DialogEditModelProps = {
|
|
||||||
provider: ModelProvider
|
|
||||||
modelId?: string // Optional model ID to edit
|
|
||||||
}
|
|
||||||
|
|
||||||
export const DialogEditModel = ({
|
|
||||||
provider,
|
|
||||||
modelId,
|
|
||||||
}: DialogEditModelProps) => {
|
|
||||||
const { t } = useTranslation()
|
|
||||||
const { updateProvider } = useModelProvider()
|
|
||||||
const [selectedModelId, setSelectedModelId] = useState<string>('')
|
|
||||||
const [capabilities, setCapabilities] = useState<Record<string, boolean>>({
|
|
||||||
completion: false,
|
|
||||||
vision: false,
|
|
||||||
tools: false,
|
|
||||||
reasoning: false,
|
|
||||||
embeddings: false,
|
|
||||||
web_search: false,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Initialize with the provided model ID or the first model if available
|
|
||||||
useEffect(() => {
|
|
||||||
if (modelId) {
|
|
||||||
setSelectedModelId(modelId)
|
|
||||||
} else if (provider.models && provider.models.length > 0) {
|
|
||||||
setSelectedModelId(provider.models[0].id)
|
|
||||||
}
|
|
||||||
}, [provider, modelId])
|
|
||||||
|
|
||||||
// Get the currently selected model
|
|
||||||
const selectedModel = provider.models.find(
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
(m: any) => m.id === selectedModelId
|
|
||||||
)
|
|
||||||
|
|
||||||
// Initialize capabilities from selected model
|
|
||||||
useEffect(() => {
|
|
||||||
if (selectedModel) {
|
|
||||||
const modelCapabilities = selectedModel.capabilities || []
|
|
||||||
setCapabilities({
|
|
||||||
completion: modelCapabilities.includes('completion'),
|
|
||||||
vision: modelCapabilities.includes('vision'),
|
|
||||||
tools: modelCapabilities.includes('tools'),
|
|
||||||
embeddings: modelCapabilities.includes('embeddings'),
|
|
||||||
web_search: modelCapabilities.includes('web_search'),
|
|
||||||
reasoning: modelCapabilities.includes('reasoning'),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}, [selectedModel])
|
|
||||||
|
|
||||||
// Track if capabilities were updated by user action
|
|
||||||
const [capabilitiesUpdated, setCapabilitiesUpdated] = useState(false)
|
|
||||||
|
|
||||||
// Update model capabilities - only update local state
|
|
||||||
const handleCapabilityChange = (capability: string, enabled: boolean) => {
|
|
||||||
setCapabilities((prev) => ({
|
|
||||||
...prev,
|
|
||||||
[capability]: enabled,
|
|
||||||
}))
|
|
||||||
// Mark that capabilities were updated by user action
|
|
||||||
setCapabilitiesUpdated(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use effect to update the provider when capabilities are explicitly changed by user
|
|
||||||
useEffect(() => {
|
|
||||||
// Only run if capabilities were updated by user action and we have a selected model
|
|
||||||
if (!capabilitiesUpdated || !selectedModel) return
|
|
||||||
|
|
||||||
// Reset the flag
|
|
||||||
setCapabilitiesUpdated(false)
|
|
||||||
|
|
||||||
// Create updated capabilities array from the state
|
|
||||||
const updatedCapabilities = Object.entries(capabilities)
|
|
||||||
.filter(([, isEnabled]) => isEnabled)
|
|
||||||
.map(([capName]) => capName)
|
|
||||||
|
|
||||||
// Find and update the model in the provider
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
const updatedModels = provider.models.map((m: any) => {
|
|
||||||
if (m.id === selectedModelId) {
|
|
||||||
return {
|
|
||||||
...m,
|
|
||||||
capabilities: updatedCapabilities,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
})
|
|
||||||
|
|
||||||
// Update the provider with the updated models
|
|
||||||
updateProvider(provider.provider, {
|
|
||||||
...provider,
|
|
||||||
models: updatedModels,
|
|
||||||
})
|
|
||||||
}, [
|
|
||||||
capabilitiesUpdated,
|
|
||||||
capabilities,
|
|
||||||
provider,
|
|
||||||
selectedModel,
|
|
||||||
selectedModelId,
|
|
||||||
updateProvider,
|
|
||||||
])
|
|
||||||
|
|
||||||
if (!selectedModel) {
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Dialog>
|
|
||||||
<DialogTrigger asChild>
|
|
||||||
<div className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out">
|
|
||||||
<IconPencil size={18} className="text-main-view-fg/50" />
|
|
||||||
</div>
|
|
||||||
</DialogTrigger>
|
|
||||||
<DialogContent>
|
|
||||||
<DialogHeader>
|
|
||||||
<DialogTitle className="line-clamp-1" title={selectedModel.id}>
|
|
||||||
{t('providers:editModel.title', { modelId: selectedModel.id })}
|
|
||||||
</DialogTitle>
|
|
||||||
<DialogDescription>
|
|
||||||
{t('providers:editModel.description')}
|
|
||||||
</DialogDescription>
|
|
||||||
</DialogHeader>
|
|
||||||
|
|
||||||
<div className="py-1">
|
|
||||||
<h3 className="text-sm font-medium mb-3">
|
|
||||||
{t('providers:editModel.capabilities')}
|
|
||||||
</h3>
|
|
||||||
<div className="space-y-4">
|
|
||||||
<div className="flex items-center justify-between">
|
|
||||||
<div className="flex items-center space-x-2">
|
|
||||||
<IconTool className="size-4 text-main-view-fg/70" />
|
|
||||||
<span className="text-sm">
|
|
||||||
{t('providers:editModel.tools')}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
<Switch
|
|
||||||
id="tools-capability"
|
|
||||||
checked={capabilities.tools}
|
|
||||||
onCheckedChange={(checked) =>
|
|
||||||
handleCapabilityChange('tools', checked)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="flex items-center justify-between">
|
|
||||||
<div className="flex items-center space-x-2">
|
|
||||||
<IconEye className="size-4 text-main-view-fg/70" />
|
|
||||||
<span className="text-sm">
|
|
||||||
{t('providers:editModel.vision')}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
<Tooltip>
|
|
||||||
<TooltipTrigger>
|
|
||||||
<Switch
|
|
||||||
id="vision-capability"
|
|
||||||
checked={capabilities.vision}
|
|
||||||
disabled={true}
|
|
||||||
onCheckedChange={(checked) =>
|
|
||||||
handleCapabilityChange('vision', checked)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</TooltipTrigger>
|
|
||||||
<TooltipContent>
|
|
||||||
{t('providers:editModel.notAvailable')}
|
|
||||||
</TooltipContent>
|
|
||||||
</Tooltip>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="flex items-center justify-between">
|
|
||||||
<div className="flex items-center space-x-2">
|
|
||||||
<IconCodeCircle2 className="size-4 text-main-view-fg/70" />
|
|
||||||
<span className="text-sm">
|
|
||||||
{t('providers:editModel.embeddings')}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
<Tooltip>
|
|
||||||
<TooltipTrigger>
|
|
||||||
<Switch
|
|
||||||
id="embedding-capability"
|
|
||||||
disabled={true}
|
|
||||||
checked={capabilities.embeddings}
|
|
||||||
onCheckedChange={(checked) =>
|
|
||||||
handleCapabilityChange('embeddings', checked)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</TooltipTrigger>
|
|
||||||
<TooltipContent>
|
|
||||||
{t('providers:editModel.notAvailable')}
|
|
||||||
</TooltipContent>
|
|
||||||
</Tooltip>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* <div className="flex items-center justify-between">
|
|
||||||
<div className="flex items-center space-x-2">
|
|
||||||
<IconWorld className="size-4 text-main-view-fg/70" />
|
|
||||||
<span className="text-sm">Web Search</span>
|
|
||||||
</div>
|
|
||||||
<Switch
|
|
||||||
id="web_search-capability"
|
|
||||||
checked={capabilities.web_search}
|
|
||||||
onCheckedChange={(checked) =>
|
|
||||||
handleCapabilityChange('web_search', checked)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</div> */}
|
|
||||||
|
|
||||||
{/* <div className="flex items-center justify-between">
|
|
||||||
<div className="flex items-center space-x-2">
|
|
||||||
<IconAtom className="size-4 text-main-view-fg/70" />
|
|
||||||
<span className="text-sm">{t('reasoning')}</span>
|
|
||||||
</div>
|
|
||||||
<Switch
|
|
||||||
id="reasoning-capability"
|
|
||||||
checked={capabilities.reasoning}
|
|
||||||
onCheckedChange={(checked) =>
|
|
||||||
handleCapabilityChange('reasoning', checked)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</div> */}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</DialogContent>
|
|
||||||
</Dialog>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@ -17,6 +17,7 @@ import {
|
|||||||
import { useNavigate } from '@tanstack/react-router'
|
import { useNavigate } from '@tanstack/react-router'
|
||||||
import { route } from '@/constants/routes'
|
import { route } from '@/constants/routes'
|
||||||
import { useThreads } from '@/hooks/useThreads'
|
import { useThreads } from '@/hooks/useThreads'
|
||||||
|
import { AppEvent, events } from '@janhq/core'
|
||||||
|
|
||||||
export function DataProvider() {
|
export function DataProvider() {
|
||||||
const { setProviders } = useModelProvider()
|
const { setProviders } = useModelProvider()
|
||||||
@ -70,6 +71,13 @@ export function DataProvider() {
|
|||||||
}
|
}
|
||||||
}, [checkForUpdate])
|
}, [checkForUpdate])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
events.on(AppEvent.onModelImported, () => {
|
||||||
|
getProviders().then(setProviders)
|
||||||
|
})
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [])
|
||||||
|
|
||||||
const handleDeepLink = (urls: string[] | null) => {
|
const handleDeepLink = (urls: string[] | null) => {
|
||||||
if (!urls) return
|
if (!urls) return
|
||||||
console.log('Received deeplink:', urls)
|
console.log('Received deeplink:', urls)
|
||||||
|
|||||||
@ -17,7 +17,12 @@ import { useModelProvider } from '@/hooks/useModelProvider'
|
|||||||
import { Card, CardItem } from '@/containers/Card'
|
import { Card, CardItem } from '@/containers/Card'
|
||||||
import { RenderMarkdown } from '@/containers/RenderMarkdown'
|
import { RenderMarkdown } from '@/containers/RenderMarkdown'
|
||||||
import { extractModelName, extractDescription } from '@/lib/models'
|
import { extractModelName, extractDescription } from '@/lib/models'
|
||||||
import { IconDownload, IconFileCode, IconSearch } from '@tabler/icons-react'
|
import {
|
||||||
|
IconDownload,
|
||||||
|
IconFileCode,
|
||||||
|
IconSearch,
|
||||||
|
IconTool,
|
||||||
|
} from '@tabler/icons-react'
|
||||||
import { Switch } from '@/components/ui/switch'
|
import { Switch } from '@/components/ui/switch'
|
||||||
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
|
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
|
||||||
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
||||||
@ -133,7 +138,10 @@ function Hub() {
|
|||||||
if (debouncedSearchValue.length) {
|
if (debouncedSearchValue.length) {
|
||||||
const fuse = new Fuse(filtered, searchOptions)
|
const fuse = new Fuse(filtered, searchOptions)
|
||||||
// Remove domain from search value (e.g., "huggingface.co/author/model" -> "author/model")
|
// Remove domain from search value (e.g., "huggingface.co/author/model" -> "author/model")
|
||||||
const cleanedSearchValue = debouncedSearchValue.replace(/^https?:\/\/[^/]+\//, '')
|
const cleanedSearchValue = debouncedSearchValue.replace(
|
||||||
|
/^https?:\/\/[^/]+\//,
|
||||||
|
''
|
||||||
|
)
|
||||||
filtered = fuse.search(cleanedSearchValue).map((result) => result.item)
|
filtered = fuse.search(cleanedSearchValue).map((result) => result.item)
|
||||||
}
|
}
|
||||||
// Apply downloaded filter
|
// Apply downloaded filter
|
||||||
@ -647,6 +655,15 @@ function Hub() {
|
|||||||
?.length || 0}
|
?.length || 0}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
|
{filteredModels[virtualItem.index].tools && (
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<IconTool
|
||||||
|
size={17}
|
||||||
|
className="text-main-view-fg/50"
|
||||||
|
title={t('hub:tools')}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
{filteredModels[virtualItem.index].quants.length >
|
{filteredModels[virtualItem.index].quants.length >
|
||||||
1 && (
|
1 && (
|
||||||
<div className="flex items-center gap-2 hub-show-variants-step">
|
<div className="flex items-center gap-2 hub-show-variants-step">
|
||||||
|
|||||||
@ -22,7 +22,6 @@ import { useTranslation } from '@/i18n/react-i18next-compat'
|
|||||||
import Capabilities from '@/containers/Capabilities'
|
import Capabilities from '@/containers/Capabilities'
|
||||||
import { DynamicControllerSetting } from '@/containers/dynamicControllerSetting'
|
import { DynamicControllerSetting } from '@/containers/dynamicControllerSetting'
|
||||||
import { RenderMarkdown } from '@/containers/RenderMarkdown'
|
import { RenderMarkdown } from '@/containers/RenderMarkdown'
|
||||||
import { DialogEditModel } from '@/containers/dialogs/EditModel'
|
|
||||||
import { DialogAddModel } from '@/containers/dialogs/AddModel'
|
import { DialogAddModel } from '@/containers/dialogs/AddModel'
|
||||||
import { ModelSetting } from '@/containers/ModelSetting'
|
import { ModelSetting } from '@/containers/ModelSetting'
|
||||||
import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel'
|
import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel'
|
||||||
@ -556,10 +555,6 @@ function ProviderDetail() {
|
|||||||
}
|
}
|
||||||
actions={
|
actions={
|
||||||
<div className="flex items-center gap-0.5">
|
<div className="flex items-center gap-0.5">
|
||||||
<DialogEditModel
|
|
||||||
provider={provider}
|
|
||||||
modelId={model.id}
|
|
||||||
/>
|
|
||||||
{model.settings && (
|
{model.settings && (
|
||||||
<ModelSetting
|
<ModelSetting
|
||||||
provider={provider}
|
provider={provider}
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import {
|
|||||||
updateSettings,
|
updateSettings,
|
||||||
} from '../providers'
|
} from '../providers'
|
||||||
import { models as providerModels } from 'token.js'
|
import { models as providerModels } from 'token.js'
|
||||||
import { predefinedProviders } from '@/mock/data'
|
import { predefinedProviders } from '@/consts/providers'
|
||||||
import { EngineManager } from '@janhq/core'
|
import { EngineManager } from '@janhq/core'
|
||||||
import { fetchModels } from '../models'
|
import { fetchModels } from '../models'
|
||||||
import { ExtensionManager } from '@/lib/extension'
|
import { ExtensionManager } from '@/lib/extension'
|
||||||
@ -21,7 +21,7 @@ vi.mock('token.js', () => ({
|
|||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@/mock/data', () => ({
|
vi.mock('@/consts/providers', () => ({
|
||||||
predefinedProviders: [
|
predefinedProviders: [
|
||||||
{
|
{
|
||||||
active: true,
|
active: true,
|
||||||
@ -69,6 +69,7 @@ vi.mock('../models', () => ({
|
|||||||
{ id: 'llama-2-7b', name: 'Llama 2 7B', description: 'Llama model' },
|
{ id: 'llama-2-7b', name: 'Llama 2 7B', description: 'Llama model' },
|
||||||
])
|
])
|
||||||
),
|
),
|
||||||
|
isToolSupported: vi.fn(() => Promise.resolve(false)),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@/lib/extension', () => ({
|
vi.mock('@/lib/extension', () => ({
|
||||||
@ -116,7 +117,7 @@ describe('providers service', () => {
|
|||||||
it('should return builtin and runtime providers', async () => {
|
it('should return builtin and runtime providers', async () => {
|
||||||
const providers = await getProviders()
|
const providers = await getProviders()
|
||||||
|
|
||||||
expect(providers).toHaveLength(9) // 8 runtime + 1 builtin
|
expect(providers).toHaveLength(2) // 1 runtime + 1 builtin (mocked)
|
||||||
expect(providers.some((p) => p.provider === 'llamacpp')).toBe(true)
|
expect(providers.some((p) => p.provider === 'llamacpp')).toBe(true)
|
||||||
expect(providers.some((p) => p.provider === 'openai')).toBe(true)
|
expect(providers.some((p) => p.provider === 'openai')).toBe(true)
|
||||||
})
|
})
|
||||||
@ -156,7 +157,7 @@ describe('providers service', () => {
|
|||||||
provider: 'openai',
|
provider: 'openai',
|
||||||
base_url: 'https://api.openai.com/v1',
|
base_url: 'https://api.openai.com/v1',
|
||||||
api_key: 'test-key',
|
api_key: 'test-key',
|
||||||
} as ModelProvider
|
}
|
||||||
|
|
||||||
const models = await fetchModelsFromProvider(provider)
|
const models = await fetchModelsFromProvider(provider)
|
||||||
|
|
||||||
@ -185,7 +186,7 @@ describe('providers service', () => {
|
|||||||
provider: 'custom',
|
provider: 'custom',
|
||||||
base_url: 'https://api.custom.com',
|
base_url: 'https://api.custom.com',
|
||||||
api_key: '',
|
api_key: '',
|
||||||
} as ModelProvider
|
}
|
||||||
|
|
||||||
const models = await fetchModelsFromProvider(provider)
|
const models = await fetchModelsFromProvider(provider)
|
||||||
|
|
||||||
@ -204,7 +205,7 @@ describe('providers service', () => {
|
|||||||
const provider = {
|
const provider = {
|
||||||
provider: 'custom',
|
provider: 'custom',
|
||||||
base_url: 'https://api.custom.com',
|
base_url: 'https://api.custom.com',
|
||||||
} as ModelProvider
|
}
|
||||||
|
|
||||||
const models = await fetchModelsFromProvider(provider)
|
const models = await fetchModelsFromProvider(provider)
|
||||||
|
|
||||||
@ -214,7 +215,7 @@ describe('providers service', () => {
|
|||||||
it('should throw error when provider has no base_url', async () => {
|
it('should throw error when provider has no base_url', async () => {
|
||||||
const provider = {
|
const provider = {
|
||||||
provider: 'custom',
|
provider: 'custom',
|
||||||
} as ModelProvider
|
}
|
||||||
|
|
||||||
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
|
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
|
||||||
'Provider must have base_url configured'
|
'Provider must have base_url configured'
|
||||||
@ -232,10 +233,10 @@ describe('providers service', () => {
|
|||||||
const provider = {
|
const provider = {
|
||||||
provider: 'custom',
|
provider: 'custom',
|
||||||
base_url: 'https://api.custom.com',
|
base_url: 'https://api.custom.com',
|
||||||
} as ModelProvider
|
}
|
||||||
|
|
||||||
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
|
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
|
||||||
'Cannot connect to custom at https://api.custom.com'
|
'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.'
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -245,10 +246,10 @@ describe('providers service', () => {
|
|||||||
const provider = {
|
const provider = {
|
||||||
provider: 'custom',
|
provider: 'custom',
|
||||||
base_url: 'https://api.custom.com',
|
base_url: 'https://api.custom.com',
|
||||||
} as ModelProvider
|
}
|
||||||
|
|
||||||
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
|
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
|
||||||
'Cannot connect to custom at https://api.custom.com'
|
'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.'
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -264,7 +265,7 @@ describe('providers service', () => {
|
|||||||
const provider = {
|
const provider = {
|
||||||
provider: 'custom',
|
provider: 'custom',
|
||||||
base_url: 'https://api.custom.com',
|
base_url: 'https://api.custom.com',
|
||||||
} as ModelProvider
|
}
|
||||||
|
|
||||||
const models = await fetchModelsFromProvider(provider)
|
const models = await fetchModelsFromProvider(provider)
|
||||||
|
|
||||||
@ -298,7 +299,7 @@ describe('providers service', () => {
|
|||||||
controller_type: 'input',
|
controller_type: 'input',
|
||||||
controller_props: { value: 'test-key' },
|
controller_props: { value: 'test-key' },
|
||||||
},
|
},
|
||||||
] as ProviderSetting[]
|
]
|
||||||
|
|
||||||
await updateSettings('openai', settings)
|
await updateSettings('openai', settings)
|
||||||
|
|
||||||
@ -324,7 +325,7 @@ describe('providers service', () => {
|
|||||||
mockExtensionManager
|
mockExtensionManager
|
||||||
)
|
)
|
||||||
|
|
||||||
const settings = [] as ProviderSetting[]
|
const settings = []
|
||||||
|
|
||||||
const result = await updateSettings('nonexistent', settings)
|
const result = await updateSettings('nonexistent', settings)
|
||||||
|
|
||||||
@ -350,7 +351,7 @@ describe('providers service', () => {
|
|||||||
controller_type: 'input',
|
controller_type: 'input',
|
||||||
controller_props: { value: undefined },
|
controller_props: { value: undefined },
|
||||||
},
|
},
|
||||||
] as ProviderSetting[]
|
]
|
||||||
|
|
||||||
await updateSettings('openai', settings)
|
await updateSettings('openai', settings)
|
||||||
|
|
||||||
|
|||||||
@ -29,6 +29,7 @@ export interface CatalogModel {
|
|||||||
mmproj_models?: MMProjModel[]
|
mmproj_models?: MMProjModel[]
|
||||||
created_at?: string
|
created_at?: string
|
||||||
readme?: string
|
readme?: string
|
||||||
|
tools?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ModelCatalog = CatalogModel[]
|
export type ModelCatalog = CatalogModel[]
|
||||||
@ -313,3 +314,16 @@ export const startModel = async (
|
|||||||
throw error
|
throw error
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if model support tool use capability
|
||||||
|
* Returned by backend engine
|
||||||
|
* @param modelId
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
export const isToolSupported = async (modelId: string): Promise<boolean> => {
|
||||||
|
const engine = getEngine()
|
||||||
|
if (!engine) return false
|
||||||
|
|
||||||
|
return engine.isToolSupported(modelId)
|
||||||
|
}
|
||||||
|
|||||||
@ -1,12 +1,9 @@
|
|||||||
import { models as providerModels } from 'token.js'
|
import { models as providerModels } from 'token.js'
|
||||||
import { predefinedProviders } from '@/consts/providers'
|
import { predefinedProviders } from '@/consts/providers'
|
||||||
import { EngineManager, SettingComponentProps } from '@janhq/core'
|
import { EngineManager, SettingComponentProps } from '@janhq/core'
|
||||||
import {
|
import { ModelCapabilities } from '@/types/models'
|
||||||
DefaultToolUseSupportedModels,
|
|
||||||
ModelCapabilities,
|
|
||||||
} from '@/types/models'
|
|
||||||
import { modelSettings } from '@/lib/predefined'
|
import { modelSettings } from '@/lib/predefined'
|
||||||
import { fetchModels } from './models'
|
import { fetchModels, isToolSupported } from './models'
|
||||||
import { ExtensionManager } from '@/lib/extension'
|
import { ExtensionManager } from '@/lib/extension'
|
||||||
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
|
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
|
||||||
|
|
||||||
@ -65,52 +62,41 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
|
|||||||
controller_props: setting.controllerProps as unknown,
|
controller_props: setting.controllerProps as unknown,
|
||||||
}
|
}
|
||||||
}) as ProviderSetting[],
|
}) as ProviderSetting[],
|
||||||
models: models.map((model) => ({
|
models: await Promise.all(
|
||||||
id: model.id,
|
models.map(
|
||||||
model: model.id,
|
async (model) =>
|
||||||
name: model.name,
|
({
|
||||||
description: model.description,
|
id: model.id,
|
||||||
capabilities:
|
model: model.id,
|
||||||
'capabilities' in model
|
name: model.name,
|
||||||
? (model.capabilities as string[])
|
description: model.description,
|
||||||
: [
|
capabilities:
|
||||||
ModelCapabilities.COMPLETION,
|
'capabilities' in model
|
||||||
...(Object.values(DefaultToolUseSupportedModels).some((v) =>
|
? (model.capabilities as string[])
|
||||||
model.id.toLowerCase().includes(v.toLowerCase())
|
: (await isToolSupported(model.id))
|
||||||
)
|
? [ModelCapabilities.TOOLS]
|
||||||
? [ModelCapabilities.TOOLS]
|
: [],
|
||||||
: []),
|
provider: providerName,
|
||||||
],
|
settings: Object.values(modelSettings).reduce(
|
||||||
provider: providerName,
|
(acc, setting) => {
|
||||||
settings: Object.values(modelSettings).reduce(
|
let value = setting.controller_props.value
|
||||||
(acc, setting) => {
|
if (setting.key === 'ctx_len') {
|
||||||
let value = setting.controller_props.value
|
value = 8192 // Default context length for Llama.cpp models
|
||||||
if (setting.key === 'ctx_len') {
|
}
|
||||||
value = 8192 // Default context length for Llama.cpp models
|
acc[setting.key] = {
|
||||||
}
|
...setting,
|
||||||
// Set temperature to 0.6 for DefaultToolUseSupportedModels
|
controller_props: {
|
||||||
if (
|
...setting.controller_props,
|
||||||
Object.values(DefaultToolUseSupportedModels).some((v) =>
|
value: value,
|
||||||
model.id.toLowerCase().includes(v.toLowerCase())
|
},
|
||||||
)
|
}
|
||||||
) {
|
return acc
|
||||||
if (setting.key === 'temperature') value = 0.7 // Default temperature for tool-supported models
|
},
|
||||||
if (setting.key === 'top_k') value = 20 // Default top_k for tool-supported models
|
{} as Record<string, ProviderSetting>
|
||||||
if (setting.key === 'top_p') value = 0.8 // Default top_p for tool-supported models
|
),
|
||||||
if (setting.key === 'min_p') value = 0 // Default min_p for tool-supported models
|
}) as Model
|
||||||
}
|
)
|
||||||
acc[setting.key] = {
|
),
|
||||||
...setting,
|
|
||||||
controller_props: {
|
|
||||||
...setting.controller_props,
|
|
||||||
value: value,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return acc
|
|
||||||
},
|
|
||||||
{} as Record<string, ProviderSetting>
|
|
||||||
),
|
|
||||||
})),
|
|
||||||
}
|
}
|
||||||
runtimeProviders.push(provider)
|
runtimeProviders.push(provider)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,10 +14,3 @@ export enum ModelCapabilities {
|
|||||||
TEXT_TO_AUDIO = 'text_to_audio',
|
TEXT_TO_AUDIO = 'text_to_audio',
|
||||||
AUDIO_TO_TEXT = 'audio_to_text',
|
AUDIO_TO_TEXT = 'audio_to_text',
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Remove this enum when we integrate llama.cpp extension
|
|
||||||
export enum DefaultToolUseSupportedModels {
|
|
||||||
JanNano = 'jan-',
|
|
||||||
Qwen3 = 'qwen3',
|
|
||||||
Lucy = 'lucy',
|
|
||||||
}
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user