diff --git a/web-app/src/containers/dialogs/AddModel.tsx b/web-app/src/containers/dialogs/AddModel.tsx index e8fd4e0fd..c44d3a0a5 100644 --- a/web-app/src/containers/dialogs/AddModel.tsx +++ b/web-app/src/containers/dialogs/AddModel.tsx @@ -15,8 +15,7 @@ import { IconPlus } from '@tabler/icons-react' import { useState } from 'react' import { getProviderTitle } from '@/lib/utils' import { useTranslation } from '@/i18n/react-i18next-compat' -import { ModelCapabilities } from '@/types/models' -import { models as providerModels } from 'token.js' +import { getModelCapabilities } from '@/lib/models' import { toast } from 'sonner' type DialogAddModelProps = { @@ -52,23 +51,7 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => { id: modelId, model: modelId, name: modelId, - capabilities: [ - ModelCapabilities.COMPLETION, - ( - providerModels[ - provider.provider as unknown as keyof typeof providerModels - ]?.supportsToolCalls as unknown as string[] - )?.includes(modelId) - ? ModelCapabilities.TOOLS - : undefined, - ( - providerModels[ - provider.provider as unknown as keyof typeof providerModels - ]?.supportsImages as unknown as string[] - )?.includes(modelId) - ? ModelCapabilities.VISION - : undefined, - ].filter(Boolean) as string[], + capabilities: getModelCapabilities(provider.provider, modelId), version: '1.0', } diff --git a/web-app/src/lib/__tests__/models.test.ts b/web-app/src/lib/__tests__/models.test.ts index 67f37f873..bba4a64a1 100644 --- a/web-app/src/lib/__tests__/models.test.ts +++ b/web-app/src/lib/__tests__/models.test.ts @@ -5,19 +5,30 @@ import { removeYamlFrontMatter, extractModelName, extractModelRepo, + getModelCapabilities, } from '../models' +import { ModelCapabilities } from '@/types/models' // Mock the token.js module vi.mock('token.js', () => ({ models: { openai: { models: ['gpt-3.5-turbo', 'gpt-4'], + supportsToolCalls: ['gpt-3.5-turbo', 'gpt-4'], + supportsImages: ['gpt-4-vision-preview'], }, anthropic: { models: ['claude-3-sonnet', 'claude-3-haiku'], + supportsToolCalls: ['claude-3-sonnet'], + supportsImages: ['claude-3-sonnet', 'claude-3-haiku'], }, mistral: { models: ['mistral-7b', 'mistral-8x7b'], + supportsToolCalls: ['mistral-8x7b'], + }, + // Provider with no capability arrays + cohere: { + models: ['command', 'command-light'], }, }, })) @@ -223,3 +234,74 @@ describe('extractModelRepo', () => { ) }) }) + +describe('getModelCapabilities', () => { + it('returns completion capability for all models', () => { + const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo') + expect(capabilities).toContain(ModelCapabilities.COMPLETION) + }) + + it('includes tools capability when model supports it', () => { + const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo') + expect(capabilities).toContain(ModelCapabilities.TOOLS) + expect(capabilities).toContain(ModelCapabilities.COMPLETION) + }) + + it('excludes tools capability when model does not support it', () => { + const capabilities = getModelCapabilities('mistral', 'mistral-7b') + expect(capabilities).not.toContain(ModelCapabilities.TOOLS) + expect(capabilities).toContain(ModelCapabilities.COMPLETION) + }) + + it('includes vision capability when model supports it', () => { + const capabilities = getModelCapabilities('openai', 'gpt-4-vision-preview') + expect(capabilities).toContain(ModelCapabilities.VISION) + expect(capabilities).toContain(ModelCapabilities.COMPLETION) + }) + + it('excludes vision capability when model does not support it', () => { + const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo') + expect(capabilities).not.toContain(ModelCapabilities.VISION) + }) + + it('includes both tools and vision when model supports both', () => { + const capabilities = getModelCapabilities('anthropic', 'claude-3-sonnet') + expect(capabilities).toContain(ModelCapabilities.COMPLETION) + expect(capabilities).toContain(ModelCapabilities.TOOLS) + expect(capabilities).toContain(ModelCapabilities.VISION) + }) + + it('handles provider with no capability arrays gracefully', () => { + const capabilities = getModelCapabilities('cohere', 'command') + expect(capabilities).toEqual([ModelCapabilities.COMPLETION]) + expect(capabilities).not.toContain(ModelCapabilities.TOOLS) + expect(capabilities).not.toContain(ModelCapabilities.VISION) + }) + + it('handles unknown provider gracefully', () => { + const capabilities = getModelCapabilities('openrouter', 'some-model') + expect(capabilities).toEqual([ModelCapabilities.COMPLETION]) + expect(capabilities).not.toContain(ModelCapabilities.TOOLS) + expect(capabilities).not.toContain(ModelCapabilities.VISION) + }) + + it('handles model not in capability list', () => { + const capabilities = getModelCapabilities('anthropic', 'claude-3-haiku') + expect(capabilities).toContain(ModelCapabilities.COMPLETION) + expect(capabilities).toContain(ModelCapabilities.VISION) + expect(capabilities).not.toContain(ModelCapabilities.TOOLS) + }) + + it('returns only completion for provider with partial capability data', () => { + // Mistral has supportsToolCalls but no supportsImages + const capabilities = getModelCapabilities('mistral', 'mistral-7b') + expect(capabilities).toEqual([ModelCapabilities.COMPLETION]) + }) + + it('handles model that supports tools but not vision', () => { + const capabilities = getModelCapabilities('mistral', 'mistral-8x7b') + expect(capabilities).toContain(ModelCapabilities.COMPLETION) + expect(capabilities).toContain(ModelCapabilities.TOOLS) + expect(capabilities).not.toContain(ModelCapabilities.VISION) + }) +}) diff --git a/web-app/src/lib/models.ts b/web-app/src/lib/models.ts index 0f9b79c40..18d0b6d8e 100644 --- a/web-app/src/lib/models.ts +++ b/web-app/src/lib/models.ts @@ -1,4 +1,5 @@ import { models } from 'token.js' +import { ModelCapabilities } from '@/types/models' export const defaultModel = (provider?: string) => { if (!provider || !Object.keys(models).includes(provider)) { @@ -10,6 +11,38 @@ export const defaultModel = (provider?: string) => { )[0] } +/** + * Determines model capabilities based on provider configuration from token.js + * @param providerName - The provider name (e.g., 'openai', 'anthropic', 'openrouter') + * @param modelId - The model ID to check capabilities for + * @returns Array of model capabilities + */ +export const getModelCapabilities = ( + providerName: string, + modelId: string +): string[] => { + const providerConfig = + models[providerName as unknown as keyof typeof models] + + const supportsToolCalls = Array.isArray( + providerConfig?.supportsToolCalls as unknown + ) + ? (providerConfig.supportsToolCalls as unknown as string[]) + : [] + + const supportsImages = Array.isArray( + providerConfig?.supportsImages as unknown + ) + ? (providerConfig.supportsImages as unknown as string[]) + : [] + + return [ + ModelCapabilities.COMPLETION, + supportsToolCalls.includes(modelId) ? ModelCapabilities.TOOLS : undefined, + supportsImages.includes(modelId) ? ModelCapabilities.VISION : undefined, + ].filter(Boolean) as string[] +} + /** * This utility is to extract cortexso model description from README.md file * @returns diff --git a/web-app/src/services/providers/tauri.ts b/web-app/src/services/providers/tauri.ts index 50f1217da..a8ca36fbb 100644 --- a/web-app/src/services/providers/tauri.ts +++ b/web-app/src/services/providers/tauri.ts @@ -10,6 +10,7 @@ import { modelSettings } from '@/lib/predefined' import { ExtensionManager } from '@/lib/extension' import { fetch as fetchTauri } from '@tauri-apps/plugin-http' import { DefaultProvidersService } from './default' +import { getModelCapabilities } from '@/lib/models' export class TauriProvidersService extends DefaultProvidersService { fetch(): typeof fetch { @@ -26,32 +27,16 @@ export class TauriProvidersService extends DefaultProvidersService { provider.provider as unknown as keyof typeof providerModels ].models as unknown as string[] - if (Array.isArray(builtInModels)) + if (Array.isArray(builtInModels)) { models = builtInModels.map((model) => { const modelManifest = models.find((e) => e.id === model) // TODO: Check chat_template for tool call support - const capabilities = [ - ModelCapabilities.COMPLETION, - ( - providerModels[ - provider.provider as unknown as keyof typeof providerModels - ]?.supportsToolCalls as unknown as string[] - )?.includes(model) - ? ModelCapabilities.TOOLS - : undefined, - ( - providerModels[ - provider.provider as unknown as keyof typeof providerModels - ]?.supportsImages as unknown as string[] - )?.includes(model) - ? ModelCapabilities.VISION - : undefined, - ].filter(Boolean) as string[] return { ...(modelManifest ?? { id: model, name: model }), - capabilities, + capabilities: getModelCapabilities(provider.provider, model), } as Model }) + } } return { diff --git a/web-app/src/services/providers/web.ts b/web-app/src/services/providers/web.ts index 6a7865be8..29d4a9cb7 100644 --- a/web-app/src/services/providers/web.ts +++ b/web-app/src/services/providers/web.ts @@ -11,6 +11,7 @@ import { ExtensionManager } from '@/lib/extension' import type { ProvidersService } from './types' import { PlatformFeatures } from '@/lib/platform/const' import { PlatformFeature } from '@/lib/platform/types' +import { getModelCapabilities } from '@/lib/models' export class WebProvidersService implements ProvidersService { async getProviders(): Promise { @@ -88,19 +89,9 @@ export class WebProvidersService implements ProvidersService { models = builtInModels.map((model) => { const modelManifest = models.find((e) => e.id === model) // TODO: Check chat_template for tool call support - const capabilities = [ - ModelCapabilities.COMPLETION, - ( - providerModels[ - provider.provider as unknown as keyof typeof providerModels - ]?.supportsToolCalls as unknown as string[] - )?.includes(model) - ? ModelCapabilities.TOOLS - : undefined, - ].filter(Boolean) as string[] return { ...(modelManifest ?? { id: model, name: model }), - capabilities, + capabilities: getModelCapabilities(provider.provider, model), } as Model }) }