Merge pull request #6222 from menloresearch/feat/model-tool-use-detection

feat: #5917 - model tool use capability should be auto detected
This commit is contained in:
Louis 2025-08-19 13:55:08 +07:00 committed by GitHub
commit 55390de070
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 161 additions and 361 deletions

View File

@ -271,4 +271,10 @@ export abstract class AIEngine extends BaseExtension {
* Optional method to get the underlying chat client * Optional method to get the underlying chat client
*/ */
getChatClient?(sessionId: string): any getChatClient?(sessionId: string): any
/**
* Check if a tool is supported by the model
* @param modelId
*/
abstract isToolSupported(modelId: string): Promise<boolean>
} }

View File

@ -58,6 +58,7 @@ export enum AppEvent {
onAppUpdateDownloadUpdate = 'onAppUpdateDownloadUpdate', onAppUpdateDownloadUpdate = 'onAppUpdateDownloadUpdate',
onAppUpdateDownloadError = 'onAppUpdateDownloadError', onAppUpdateDownloadError = 'onAppUpdateDownloadError',
onAppUpdateDownloadSuccess = 'onAppUpdateDownloadSuccess', onAppUpdateDownloadSuccess = 'onAppUpdateDownloadSuccess',
onModelImported = 'onModelImported',
onUserSubmitQuickAsk = 'onUserSubmitQuickAsk', onUserSubmitQuickAsk = 'onUserSubmitQuickAsk',
onSelectedText = 'onSelectedText', onSelectedText = 'onSelectedText',

View File

@ -19,6 +19,7 @@ import {
ImportOptions, ImportOptions,
chatCompletionRequest, chatCompletionRequest,
events, events,
AppEvent,
} from '@janhq/core' } from '@janhq/core'
import { error, info, warn } from '@tauri-apps/plugin-log' import { error, info, warn } from '@tauri-apps/plugin-log'
@ -32,6 +33,7 @@ import {
import { invoke } from '@tauri-apps/api/core' import { invoke } from '@tauri-apps/api/core'
import { getProxyConfig } from './util' import { getProxyConfig } from './util'
import { basename } from '@tauri-apps/api/path' import { basename } from '@tauri-apps/api/path'
import { readGgufMetadata } from '@janhq/tauri-plugin-llamacpp-api'
type LlamacppConfig = { type LlamacppConfig = {
version_backend: string version_backend: string
@ -1085,6 +1087,12 @@ export default class llamacpp_extension extends AIEngine {
data: modelConfig, data: modelConfig,
savePath: configPath, savePath: configPath,
}) })
events.emit(AppEvent.onModelImported, {
modelId,
modelPath,
mmprojPath,
size_bytes,
})
} }
override async abortImport(modelId: string): Promise<void> { override async abortImport(modelId: string): Promise<void> {
@ -1172,7 +1180,7 @@ export default class llamacpp_extension extends AIEngine {
const [version, backend] = cfg.version_backend.split('/') const [version, backend] = cfg.version_backend.split('/')
if (!version || !backend) { if (!version || !backend) {
throw new Error( throw new Error(
"Initial setup for the backend failed due to a network issue. Please restart the app!" 'Initial setup for the backend failed due to a network issue. Please restart the app!'
) )
} }
@ -1279,11 +1287,14 @@ export default class llamacpp_extension extends AIEngine {
try { try {
// TODO: add LIBRARY_PATH // TODO: add LIBRARY_PATH
const sInfo = await invoke<SessionInfo>('plugin:llamacpp|load_llama_model', { const sInfo = await invoke<SessionInfo>(
backendPath, 'plugin:llamacpp|load_llama_model',
libraryPath, {
args, backendPath,
}) libraryPath,
args,
}
)
return sInfo return sInfo
} catch (error) { } catch (error) {
logger.error('Error in load command:\n', error) logger.error('Error in load command:\n', error)
@ -1299,9 +1310,12 @@ export default class llamacpp_extension extends AIEngine {
const pid = sInfo.pid const pid = sInfo.pid
try { try {
// Pass the PID as the session_id // Pass the PID as the session_id
const result = await invoke<UnloadResult>('plugin:llamacpp|unload_llama_model', { const result = await invoke<UnloadResult>(
pid: pid, 'plugin:llamacpp|unload_llama_model',
}) {
pid: pid,
}
)
// If successful, remove from active sessions // If successful, remove from active sessions
if (result.success) { if (result.success) {
@ -1437,9 +1451,12 @@ export default class llamacpp_extension extends AIEngine {
private async findSessionByModel(modelId: string): Promise<SessionInfo> { private async findSessionByModel(modelId: string): Promise<SessionInfo> {
try { try {
let sInfo = await invoke<SessionInfo>('plugin:llamacpp|find_session_by_model', { let sInfo = await invoke<SessionInfo>(
modelId, 'plugin:llamacpp|find_session_by_model',
}) {
modelId,
}
)
return sInfo return sInfo
} catch (e) { } catch (e) {
logger.error(e) logger.error(e)
@ -1516,7 +1533,9 @@ export default class llamacpp_extension extends AIEngine {
override async getLoadedModels(): Promise<string[]> { override async getLoadedModels(): Promise<string[]> {
try { try {
let models: string[] = await invoke<string[]>('plugin:llamacpp|get_loaded_models') let models: string[] = await invoke<string[]>(
'plugin:llamacpp|get_loaded_models'
)
return models return models
} catch (e) { } catch (e) {
logger.error(e) logger.error(e)
@ -1599,14 +1618,31 @@ export default class llamacpp_extension extends AIEngine {
throw new Error('method not implemented yet') throw new Error('method not implemented yet')
} }
private async loadMetadata(path: string): Promise<GgufMetadata> { /**
try { * Check if a tool is supported by the model
const data = await invoke<GgufMetadata>('plugin:llamacpp|read_gguf_metadata', { * Currently read from GGUF chat_template
path: path, * @param modelId
}) * @returns
return data */
} catch (err) { async isToolSupported(modelId: string): Promise<boolean> {
throw err const janDataFolderPath = await getJanDataFolderPath()
} const modelConfigPath = await joinPath([
this.providerPath,
'models',
modelId,
'model.yml',
])
const modelConfig = await invoke<ModelConfig>('read_yaml', {
path: modelConfigPath,
})
// model option is required
// NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path
const modelPath = await joinPath([
janDataFolderPath,
modelConfig.model_path,
])
return (await readGgufMetadata(modelPath)).metadata?.[
'tokenizer.chat_template'
]?.includes('tools')
} }
} }

View File

@ -6,10 +6,8 @@ import {
import { Progress } from '@/components/ui/progress' import { Progress } from '@/components/ui/progress'
import { useDownloadStore } from '@/hooks/useDownloadStore' import { useDownloadStore } from '@/hooks/useDownloadStore'
import { useLeftPanel } from '@/hooks/useLeftPanel' import { useLeftPanel } from '@/hooks/useLeftPanel'
import { useModelProvider } from '@/hooks/useModelProvider'
import { useAppUpdater } from '@/hooks/useAppUpdater' import { useAppUpdater } from '@/hooks/useAppUpdater'
import { abortDownload } from '@/services/models' import { abortDownload } from '@/services/models'
import { getProviders } from '@/services/providers'
import { DownloadEvent, DownloadState, events, AppEvent } from '@janhq/core' import { DownloadEvent, DownloadState, events, AppEvent } from '@janhq/core'
import { IconDownload, IconX } from '@tabler/icons-react' import { IconDownload, IconX } from '@tabler/icons-react'
import { useCallback, useEffect, useMemo, useState } from 'react' import { useCallback, useEffect, useMemo, useState } from 'react'
@ -18,7 +16,6 @@ import { useTranslation } from '@/i18n/react-i18next-compat'
export function DownloadManagement() { export function DownloadManagement() {
const { t } = useTranslation() const { t } = useTranslation()
const { setProviders } = useModelProvider()
const { open: isLeftPanelOpen } = useLeftPanel() const { open: isLeftPanelOpen } = useLeftPanel()
const [isPopoverOpen, setIsPopoverOpen] = useState(false) const [isPopoverOpen, setIsPopoverOpen] = useState(false)
const { const {
@ -185,7 +182,6 @@ export function DownloadManagement() {
console.debug('onFileDownloadSuccess', state) console.debug('onFileDownloadSuccess', state)
removeDownload(state.modelId) removeDownload(state.modelId)
removeLocalDownloadingModel(state.modelId) removeLocalDownloadingModel(state.modelId)
getProviders().then(setProviders)
toast.success(t('common:toast.downloadComplete.title'), { toast.success(t('common:toast.downloadComplete.title'), {
id: 'download-complete', id: 'download-complete',
description: t('common:toast.downloadComplete.description', { description: t('common:toast.downloadComplete.description', {
@ -193,7 +189,7 @@ export function DownloadManagement() {
}), }),
}) })
}, },
[removeDownload, removeLocalDownloadingModel, setProviders, t] [removeDownload, removeLocalDownloadingModel, t]
) )
useEffect(() => { useEffect(() => {

View File

@ -1,253 +0,0 @@
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
DialogTrigger,
} from '@/components/ui/dialog'
import { Switch } from '@/components/ui/switch'
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from '@/components/ui/tooltip'
import { useModelProvider } from '@/hooks/useModelProvider'
import {
IconPencil,
IconEye,
IconTool,
// IconWorld,
// IconAtom,
IconCodeCircle2,
} from '@tabler/icons-react'
import { useState, useEffect } from 'react'
import { useTranslation } from '@/i18n/react-i18next-compat'
// No need to define our own interface, we'll use the existing Model type
/**
 * Props for the model-capability edit dialog.
 */
type DialogEditModelProps = {
  // Provider that owns the models being edited (project-declared type).
  provider: ModelProvider
  modelId?: string // Optional model ID to edit; defaults to the provider's first model
}
/**
 * Dialog for viewing and toggling a model's capability flags
 * (tools / vision / embeddings; web_search and reasoning are commented out).
 *
 * Only the "tools" switch is interactive — vision and embeddings are rendered
 * disabled with a "not available" tooltip. Changes are written back through
 * `updateProvider` by replacing the edited model's `capabilities` array.
 */
export const DialogEditModel = ({
  provider,
  modelId,
}: DialogEditModelProps) => {
  const { t } = useTranslation()
  const { updateProvider } = useModelProvider()
  // Id of the model currently shown in the dialog.
  const [selectedModelId, setSelectedModelId] = useState<string>('')
  // Local switch state; keys mirror capability names stored on the model.
  const [capabilities, setCapabilities] = useState<Record<string, boolean>>({
    completion: false,
    vision: false,
    tools: false,
    reasoning: false,
    embeddings: false,
    web_search: false,
  })

  // Initialize with the provided model ID or the first model if available
  useEffect(() => {
    if (modelId) {
      setSelectedModelId(modelId)
    } else if (provider.models && provider.models.length > 0) {
      setSelectedModelId(provider.models[0].id)
    }
  }, [provider, modelId])

  // Get the currently selected model
  const selectedModel = provider.models.find(
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    (m: any) => m.id === selectedModelId
  )

  // Initialize capabilities from selected model
  // (re-runs whenever the selection — or the provider data — changes).
  useEffect(() => {
    if (selectedModel) {
      const modelCapabilities = selectedModel.capabilities || []
      setCapabilities({
        completion: modelCapabilities.includes('completion'),
        vision: modelCapabilities.includes('vision'),
        tools: modelCapabilities.includes('tools'),
        embeddings: modelCapabilities.includes('embeddings'),
        web_search: modelCapabilities.includes('web_search'),
        reasoning: modelCapabilities.includes('reasoning'),
      })
    }
  }, [selectedModel])

  // Track if capabilities were updated by user action
  // (distinguishes user toggles from the initialization effect above,
  // so initialization never triggers a provider write-back).
  const [capabilitiesUpdated, setCapabilitiesUpdated] = useState(false)

  // Update model capabilities - only update local state
  const handleCapabilityChange = (capability: string, enabled: boolean) => {
    setCapabilities((prev) => ({
      ...prev,
      [capability]: enabled,
    }))
    // Mark that capabilities were updated by user action
    setCapabilitiesUpdated(true)
  }

  // Use effect to update the provider when capabilities are explicitly changed by user
  useEffect(() => {
    // Only run if capabilities were updated by user action and we have a selected model
    if (!capabilitiesUpdated || !selectedModel) return

    // Reset the flag
    setCapabilitiesUpdated(false)

    // Create updated capabilities array from the state
    const updatedCapabilities = Object.entries(capabilities)
      .filter(([, isEnabled]) => isEnabled)
      .map(([capName]) => capName)

    // Find and update the model in the provider
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const updatedModels = provider.models.map((m: any) => {
      if (m.id === selectedModelId) {
        return {
          ...m,
          capabilities: updatedCapabilities,
        }
      }
      return m
    })

    // Update the provider with the updated models
    updateProvider(provider.provider, {
      ...provider,
      models: updatedModels,
    })
  }, [
    capabilitiesUpdated,
    capabilities,
    provider,
    selectedModel,
    selectedModelId,
    updateProvider,
  ])

  // Nothing to edit until a model is resolved.
  if (!selectedModel) {
    return null
  }

  return (
    <Dialog>
      <DialogTrigger asChild>
        {/* Pencil icon button that opens the dialog */}
        <div className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out">
          <IconPencil size={18} className="text-main-view-fg/50" />
        </div>
      </DialogTrigger>
      <DialogContent>
        <DialogHeader>
          <DialogTitle className="line-clamp-1" title={selectedModel.id}>
            {t('providers:editModel.title', { modelId: selectedModel.id })}
          </DialogTitle>
          <DialogDescription>
            {t('providers:editModel.description')}
          </DialogDescription>
        </DialogHeader>
        <div className="py-1">
          <h3 className="text-sm font-medium mb-3">
            {t('providers:editModel.capabilities')}
          </h3>
          <div className="space-y-4">
            {/* Tools — the only user-editable capability */}
            <div className="flex items-center justify-between">
              <div className="flex items-center space-x-2">
                <IconTool className="size-4 text-main-view-fg/70" />
                <span className="text-sm">
                  {t('providers:editModel.tools')}
                </span>
              </div>
              <Switch
                id="tools-capability"
                checked={capabilities.tools}
                onCheckedChange={(checked) =>
                  handleCapabilityChange('tools', checked)
                }
              />
            </div>
            {/* Vision — shown but disabled; tooltip explains unavailability */}
            <div className="flex items-center justify-between">
              <div className="flex items-center space-x-2">
                <IconEye className="size-4 text-main-view-fg/70" />
                <span className="text-sm">
                  {t('providers:editModel.vision')}
                </span>
              </div>
              <Tooltip>
                <TooltipTrigger>
                  <Switch
                    id="vision-capability"
                    checked={capabilities.vision}
                    disabled={true}
                    onCheckedChange={(checked) =>
                      handleCapabilityChange('vision', checked)
                    }
                  />
                </TooltipTrigger>
                <TooltipContent>
                  {t('providers:editModel.notAvailable')}
                </TooltipContent>
              </Tooltip>
            </div>
            {/* Embeddings — shown but disabled; tooltip explains unavailability */}
            <div className="flex items-center justify-between">
              <div className="flex items-center space-x-2">
                <IconCodeCircle2 className="size-4 text-main-view-fg/70" />
                <span className="text-sm">
                  {t('providers:editModel.embeddings')}
                </span>
              </div>
              <Tooltip>
                <TooltipTrigger>
                  <Switch
                    id="embedding-capability"
                    disabled={true}
                    checked={capabilities.embeddings}
                    onCheckedChange={(checked) =>
                      handleCapabilityChange('embeddings', checked)
                    }
                  />
                </TooltipTrigger>
                <TooltipContent>
                  {t('providers:editModel.notAvailable')}
                </TooltipContent>
              </Tooltip>
            </div>
            {/* <div className="flex items-center justify-between">
              <div className="flex items-center space-x-2">
                <IconWorld className="size-4 text-main-view-fg/70" />
                <span className="text-sm">Web Search</span>
              </div>
              <Switch
                id="web_search-capability"
                checked={capabilities.web_search}
                onCheckedChange={(checked) =>
                  handleCapabilityChange('web_search', checked)
                }
              />
            </div> */}
            {/* <div className="flex items-center justify-between">
              <div className="flex items-center space-x-2">
                <IconAtom className="size-4 text-main-view-fg/70" />
                <span className="text-sm">{t('reasoning')}</span>
              </div>
              <Switch
                id="reasoning-capability"
                checked={capabilities.reasoning}
                onCheckedChange={(checked) =>
                  handleCapabilityChange('reasoning', checked)
                }
              />
            </div> */}
          </div>
        </div>
      </DialogContent>
    </Dialog>
  )
}

View File

@ -17,6 +17,7 @@ import {
import { useNavigate } from '@tanstack/react-router' import { useNavigate } from '@tanstack/react-router'
import { route } from '@/constants/routes' import { route } from '@/constants/routes'
import { useThreads } from '@/hooks/useThreads' import { useThreads } from '@/hooks/useThreads'
import { AppEvent, events } from '@janhq/core'
export function DataProvider() { export function DataProvider() {
const { setProviders } = useModelProvider() const { setProviders } = useModelProvider()
@ -70,6 +71,13 @@ export function DataProvider() {
} }
}, [checkForUpdate]) }, [checkForUpdate])
useEffect(() => {
events.on(AppEvent.onModelImported, () => {
getProviders().then(setProviders)
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
const handleDeepLink = (urls: string[] | null) => { const handleDeepLink = (urls: string[] | null) => {
if (!urls) return if (!urls) return
console.log('Received deeplink:', urls) console.log('Received deeplink:', urls)

View File

@ -17,7 +17,12 @@ import { useModelProvider } from '@/hooks/useModelProvider'
import { Card, CardItem } from '@/containers/Card' import { Card, CardItem } from '@/containers/Card'
import { RenderMarkdown } from '@/containers/RenderMarkdown' import { RenderMarkdown } from '@/containers/RenderMarkdown'
import { extractModelName, extractDescription } from '@/lib/models' import { extractModelName, extractDescription } from '@/lib/models'
import { IconDownload, IconFileCode, IconSearch } from '@tabler/icons-react' import {
IconDownload,
IconFileCode,
IconSearch,
IconTool,
} from '@tabler/icons-react'
import { Switch } from '@/components/ui/switch' import { Switch } from '@/components/ui/switch'
import Joyride, { CallBackProps, STATUS } from 'react-joyride' import Joyride, { CallBackProps, STATUS } from 'react-joyride'
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide' import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
@ -133,7 +138,10 @@ function Hub() {
if (debouncedSearchValue.length) { if (debouncedSearchValue.length) {
const fuse = new Fuse(filtered, searchOptions) const fuse = new Fuse(filtered, searchOptions)
// Remove domain from search value (e.g., "huggingface.co/author/model" -> "author/model") // Remove domain from search value (e.g., "huggingface.co/author/model" -> "author/model")
const cleanedSearchValue = debouncedSearchValue.replace(/^https?:\/\/[^/]+\//, '') const cleanedSearchValue = debouncedSearchValue.replace(
/^https?:\/\/[^/]+\//,
''
)
filtered = fuse.search(cleanedSearchValue).map((result) => result.item) filtered = fuse.search(cleanedSearchValue).map((result) => result.item)
} }
// Apply downloaded filter // Apply downloaded filter
@ -647,6 +655,15 @@ function Hub() {
?.length || 0} ?.length || 0}
</span> </span>
</div> </div>
{filteredModels[virtualItem.index].tools && (
<div className="flex items-center gap-1">
<IconTool
size={17}
className="text-main-view-fg/50"
title={t('hub:tools')}
/>
</div>
)}
{filteredModels[virtualItem.index].quants.length > {filteredModels[virtualItem.index].quants.length >
1 && ( 1 && (
<div className="flex items-center gap-2 hub-show-variants-step"> <div className="flex items-center gap-2 hub-show-variants-step">

View File

@ -22,7 +22,6 @@ import { useTranslation } from '@/i18n/react-i18next-compat'
import Capabilities from '@/containers/Capabilities' import Capabilities from '@/containers/Capabilities'
import { DynamicControllerSetting } from '@/containers/dynamicControllerSetting' import { DynamicControllerSetting } from '@/containers/dynamicControllerSetting'
import { RenderMarkdown } from '@/containers/RenderMarkdown' import { RenderMarkdown } from '@/containers/RenderMarkdown'
import { DialogEditModel } from '@/containers/dialogs/EditModel'
import { DialogAddModel } from '@/containers/dialogs/AddModel' import { DialogAddModel } from '@/containers/dialogs/AddModel'
import { ModelSetting } from '@/containers/ModelSetting' import { ModelSetting } from '@/containers/ModelSetting'
import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel' import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel'
@ -556,10 +555,6 @@ function ProviderDetail() {
} }
actions={ actions={
<div className="flex items-center gap-0.5"> <div className="flex items-center gap-0.5">
<DialogEditModel
provider={provider}
modelId={model.id}
/>
{model.settings && ( {model.settings && (
<ModelSetting <ModelSetting
provider={provider} provider={provider}

View File

@ -5,7 +5,7 @@ import {
updateSettings, updateSettings,
} from '../providers' } from '../providers'
import { models as providerModels } from 'token.js' import { models as providerModels } from 'token.js'
import { predefinedProviders } from '@/mock/data' import { predefinedProviders } from '@/consts/providers'
import { EngineManager } from '@janhq/core' import { EngineManager } from '@janhq/core'
import { fetchModels } from '../models' import { fetchModels } from '../models'
import { ExtensionManager } from '@/lib/extension' import { ExtensionManager } from '@/lib/extension'
@ -21,7 +21,7 @@ vi.mock('token.js', () => ({
}, },
})) }))
vi.mock('@/mock/data', () => ({ vi.mock('@/consts/providers', () => ({
predefinedProviders: [ predefinedProviders: [
{ {
active: true, active: true,
@ -69,6 +69,7 @@ vi.mock('../models', () => ({
{ id: 'llama-2-7b', name: 'Llama 2 7B', description: 'Llama model' }, { id: 'llama-2-7b', name: 'Llama 2 7B', description: 'Llama model' },
]) ])
), ),
isToolSupported: vi.fn(() => Promise.resolve(false)),
})) }))
vi.mock('@/lib/extension', () => ({ vi.mock('@/lib/extension', () => ({
@ -116,7 +117,7 @@ describe('providers service', () => {
it('should return builtin and runtime providers', async () => { it('should return builtin and runtime providers', async () => {
const providers = await getProviders() const providers = await getProviders()
expect(providers).toHaveLength(9) // 8 runtime + 1 builtin expect(providers).toHaveLength(2) // 1 runtime + 1 builtin (mocked)
expect(providers.some((p) => p.provider === 'llamacpp')).toBe(true) expect(providers.some((p) => p.provider === 'llamacpp')).toBe(true)
expect(providers.some((p) => p.provider === 'openai')).toBe(true) expect(providers.some((p) => p.provider === 'openai')).toBe(true)
}) })
@ -156,7 +157,7 @@ describe('providers service', () => {
provider: 'openai', provider: 'openai',
base_url: 'https://api.openai.com/v1', base_url: 'https://api.openai.com/v1',
api_key: 'test-key', api_key: 'test-key',
} as ModelProvider }
const models = await fetchModelsFromProvider(provider) const models = await fetchModelsFromProvider(provider)
@ -185,7 +186,7 @@ describe('providers service', () => {
provider: 'custom', provider: 'custom',
base_url: 'https://api.custom.com', base_url: 'https://api.custom.com',
api_key: '', api_key: '',
} as ModelProvider }
const models = await fetchModelsFromProvider(provider) const models = await fetchModelsFromProvider(provider)
@ -204,7 +205,7 @@ describe('providers service', () => {
const provider = { const provider = {
provider: 'custom', provider: 'custom',
base_url: 'https://api.custom.com', base_url: 'https://api.custom.com',
} as ModelProvider }
const models = await fetchModelsFromProvider(provider) const models = await fetchModelsFromProvider(provider)
@ -214,7 +215,7 @@ describe('providers service', () => {
it('should throw error when provider has no base_url', async () => { it('should throw error when provider has no base_url', async () => {
const provider = { const provider = {
provider: 'custom', provider: 'custom',
} as ModelProvider }
await expect(fetchModelsFromProvider(provider)).rejects.toThrow( await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
'Provider must have base_url configured' 'Provider must have base_url configured'
@ -232,10 +233,10 @@ describe('providers service', () => {
const provider = { const provider = {
provider: 'custom', provider: 'custom',
base_url: 'https://api.custom.com', base_url: 'https://api.custom.com',
} as ModelProvider }
await expect(fetchModelsFromProvider(provider)).rejects.toThrow( await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
'Cannot connect to custom at https://api.custom.com' 'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.'
) )
}) })
@ -245,10 +246,10 @@ describe('providers service', () => {
const provider = { const provider = {
provider: 'custom', provider: 'custom',
base_url: 'https://api.custom.com', base_url: 'https://api.custom.com',
} as ModelProvider }
await expect(fetchModelsFromProvider(provider)).rejects.toThrow( await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
'Cannot connect to custom at https://api.custom.com' 'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.'
) )
}) })
@ -264,7 +265,7 @@ describe('providers service', () => {
const provider = { const provider = {
provider: 'custom', provider: 'custom',
base_url: 'https://api.custom.com', base_url: 'https://api.custom.com',
} as ModelProvider }
const models = await fetchModelsFromProvider(provider) const models = await fetchModelsFromProvider(provider)
@ -298,7 +299,7 @@ describe('providers service', () => {
controller_type: 'input', controller_type: 'input',
controller_props: { value: 'test-key' }, controller_props: { value: 'test-key' },
}, },
] as ProviderSetting[] ]
await updateSettings('openai', settings) await updateSettings('openai', settings)
@ -324,7 +325,7 @@ describe('providers service', () => {
mockExtensionManager mockExtensionManager
) )
const settings = [] as ProviderSetting[] const settings = []
const result = await updateSettings('nonexistent', settings) const result = await updateSettings('nonexistent', settings)
@ -350,7 +351,7 @@ describe('providers service', () => {
controller_type: 'input', controller_type: 'input',
controller_props: { value: undefined }, controller_props: { value: undefined },
}, },
] as ProviderSetting[] ]
await updateSettings('openai', settings) await updateSettings('openai', settings)

View File

@ -29,6 +29,7 @@ export interface CatalogModel {
mmproj_models?: MMProjModel[] mmproj_models?: MMProjModel[]
created_at?: string created_at?: string
readme?: string readme?: string
tools?: boolean
} }
export type ModelCatalog = CatalogModel[] export type ModelCatalog = CatalogModel[]
@ -313,3 +314,16 @@ export const startModel = async (
throw error throw error
}) })
} }
/**
 * Check whether a model supports tool use.
 * Delegates to the active backend engine (e.g. the llama.cpp engine, which
 * inspects the GGUF chat template for a `tools` placeholder).
 * @param modelId - id of the model to check
 * @returns true only when the engine affirmatively reports tool support;
 *          false when no engine is registered or the check fails
 */
export const isToolSupported = async (modelId: string): Promise<boolean> => {
  const engine = getEngine()
  // No engine registered yet — treat as unsupported rather than throwing.
  if (!engine) return false
  try {
    // Coerce to a strict boolean: an engine implementation may resolve
    // undefined when the chat-template metadata key is absent.
    return (await engine.isToolSupported(modelId)) === true
  } catch {
    // A metadata-read failure for one model must not break provider listing.
    return false
  }
}

View File

@ -1,12 +1,9 @@
import { models as providerModels } from 'token.js' import { models as providerModels } from 'token.js'
import { predefinedProviders } from '@/consts/providers' import { predefinedProviders } from '@/consts/providers'
import { EngineManager, SettingComponentProps } from '@janhq/core' import { EngineManager, SettingComponentProps } from '@janhq/core'
import { import { ModelCapabilities } from '@/types/models'
DefaultToolUseSupportedModels,
ModelCapabilities,
} from '@/types/models'
import { modelSettings } from '@/lib/predefined' import { modelSettings } from '@/lib/predefined'
import { fetchModels } from './models' import { fetchModels, isToolSupported } from './models'
import { ExtensionManager } from '@/lib/extension' import { ExtensionManager } from '@/lib/extension'
import { fetch as fetchTauri } from '@tauri-apps/plugin-http' import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
@ -65,52 +62,41 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
controller_props: setting.controllerProps as unknown, controller_props: setting.controllerProps as unknown,
} }
}) as ProviderSetting[], }) as ProviderSetting[],
models: models.map((model) => ({ models: await Promise.all(
id: model.id, models.map(
model: model.id, async (model) =>
name: model.name, ({
description: model.description, id: model.id,
capabilities: model: model.id,
'capabilities' in model name: model.name,
? (model.capabilities as string[]) description: model.description,
: [ capabilities:
ModelCapabilities.COMPLETION, 'capabilities' in model
...(Object.values(DefaultToolUseSupportedModels).some((v) => ? (model.capabilities as string[])
model.id.toLowerCase().includes(v.toLowerCase()) : (await isToolSupported(model.id))
) ? [ModelCapabilities.TOOLS]
? [ModelCapabilities.TOOLS] : [],
: []), provider: providerName,
], settings: Object.values(modelSettings).reduce(
provider: providerName, (acc, setting) => {
settings: Object.values(modelSettings).reduce( let value = setting.controller_props.value
(acc, setting) => { if (setting.key === 'ctx_len') {
let value = setting.controller_props.value value = 8192 // Default context length for Llama.cpp models
if (setting.key === 'ctx_len') { }
value = 8192 // Default context length for Llama.cpp models acc[setting.key] = {
} ...setting,
// Set temperature to 0.6 for DefaultToolUseSupportedModels controller_props: {
if ( ...setting.controller_props,
Object.values(DefaultToolUseSupportedModels).some((v) => value: value,
model.id.toLowerCase().includes(v.toLowerCase()) },
) }
) { return acc
if (setting.key === 'temperature') value = 0.7 // Default temperature for tool-supported models },
if (setting.key === 'top_k') value = 20 // Default top_k for tool-supported models {} as Record<string, ProviderSetting>
if (setting.key === 'top_p') value = 0.8 // Default top_p for tool-supported models ),
if (setting.key === 'min_p') value = 0 // Default min_p for tool-supported models }) as Model
} )
acc[setting.key] = { ),
...setting,
controller_props: {
...setting.controller_props,
value: value,
},
}
return acc
},
{} as Record<string, ProviderSetting>
),
})),
} }
runtimeProviders.push(provider) runtimeProviders.push(provider)
} }

View File

@ -14,10 +14,3 @@ export enum ModelCapabilities {
TEXT_TO_AUDIO = 'text_to_audio', TEXT_TO_AUDIO = 'text_to_audio',
AUDIO_TO_TEXT = 'audio_to_text', AUDIO_TO_TEXT = 'audio_to_text',
} }
// TODO: Remove this enum when we integrate llama.cpp extension
/**
 * Model-id fragments for models assumed to support tool use by default.
 * NOTE(review): values appear to be matched as case-insensitive substrings
 * of the model id — confirm against the capability-detection call site.
 */
export enum DefaultToolUseSupportedModels {
  JanNano = 'jan-',
  Qwen3 = 'qwen3',
  Lucy = 'lucy',
}