Merge pull request #6715 from menloresearch/fix/get-model-capabilities-correctly

fix: Extract model capabilities correctly for various providers on various platforms
This commit is contained in:
Nghia Doan 2025-10-02 23:16:22 +07:00 committed by GitHub
commit bdd8549d3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 123 additions and 49 deletions

View File

@ -15,8 +15,7 @@ import { IconPlus } from '@tabler/icons-react'
import { useState } from 'react'
import { getProviderTitle } from '@/lib/utils'
import { useTranslation } from '@/i18n/react-i18next-compat'
import { ModelCapabilities } from '@/types/models'
import { models as providerModels } from 'token.js'
import { getModelCapabilities } from '@/lib/models'
import { toast } from 'sonner'
type DialogAddModelProps = {
@ -52,23 +51,7 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => {
id: modelId,
model: modelId,
name: modelId,
capabilities: [
ModelCapabilities.COMPLETION,
(
providerModels[
provider.provider as unknown as keyof typeof providerModels
]?.supportsToolCalls as unknown as string[]
)?.includes(modelId)
? ModelCapabilities.TOOLS
: undefined,
(
providerModels[
provider.provider as unknown as keyof typeof providerModels
]?.supportsImages as unknown as string[]
)?.includes(modelId)
? ModelCapabilities.VISION
: undefined,
].filter(Boolean) as string[],
capabilities: getModelCapabilities(provider.provider, modelId),
version: '1.0',
}

View File

@ -5,19 +5,30 @@ import {
removeYamlFrontMatter,
extractModelName,
extractModelRepo,
getModelCapabilities,
} from '../models'
import { ModelCapabilities } from '@/types/models'
// Mock the token.js module
vi.mock('token.js', () => ({
models: {
openai: {
models: ['gpt-3.5-turbo', 'gpt-4'],
supportsToolCalls: ['gpt-3.5-turbo', 'gpt-4'],
supportsImages: ['gpt-4-vision-preview'],
},
anthropic: {
models: ['claude-3-sonnet', 'claude-3-haiku'],
supportsToolCalls: ['claude-3-sonnet'],
supportsImages: ['claude-3-sonnet', 'claude-3-haiku'],
},
mistral: {
models: ['mistral-7b', 'mistral-8x7b'],
supportsToolCalls: ['mistral-8x7b'],
},
// Provider with no capability arrays
cohere: {
models: ['command', 'command-light'],
},
},
}))
@ -223,3 +234,74 @@ describe('extractModelRepo', () => {
)
})
})
describe('getModelCapabilities', () => {
it('returns completion capability for all models', () => {
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
})
it('includes tools capability when model supports it', () => {
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
expect(capabilities).toContain(ModelCapabilities.TOOLS)
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
})
it('excludes tools capability when model does not support it', () => {
const capabilities = getModelCapabilities('mistral', 'mistral-7b')
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
})
it('includes vision capability when model supports it', () => {
const capabilities = getModelCapabilities('openai', 'gpt-4-vision-preview')
expect(capabilities).toContain(ModelCapabilities.VISION)
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
})
it('excludes vision capability when model does not support it', () => {
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
expect(capabilities).not.toContain(ModelCapabilities.VISION)
})
it('includes both tools and vision when model supports both', () => {
const capabilities = getModelCapabilities('anthropic', 'claude-3-sonnet')
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
expect(capabilities).toContain(ModelCapabilities.TOOLS)
expect(capabilities).toContain(ModelCapabilities.VISION)
})
it('handles provider with no capability arrays gracefully', () => {
const capabilities = getModelCapabilities('cohere', 'command')
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
expect(capabilities).not.toContain(ModelCapabilities.VISION)
})
it('handles unknown provider gracefully', () => {
const capabilities = getModelCapabilities('openrouter', 'some-model')
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
expect(capabilities).not.toContain(ModelCapabilities.VISION)
})
it('handles model not in capability list', () => {
const capabilities = getModelCapabilities('anthropic', 'claude-3-haiku')
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
expect(capabilities).toContain(ModelCapabilities.VISION)
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
})
it('returns only completion for provider with partial capability data', () => {
// Mistral has supportsToolCalls but no supportsImages
const capabilities = getModelCapabilities('mistral', 'mistral-7b')
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
})
it('handles model that supports tools but not vision', () => {
const capabilities = getModelCapabilities('mistral', 'mistral-8x7b')
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
expect(capabilities).toContain(ModelCapabilities.TOOLS)
expect(capabilities).not.toContain(ModelCapabilities.VISION)
})
})

View File

@ -1,4 +1,5 @@
import { models } from 'token.js'
import { ModelCapabilities } from '@/types/models'
export const defaultModel = (provider?: string) => {
if (!provider || !Object.keys(models).includes(provider)) {
@ -10,6 +11,38 @@ export const defaultModel = (provider?: string) => {
)[0]
}
/**
* Determines model capabilities based on provider configuration from token.js
* @param providerName - The provider name (e.g., 'openai', 'anthropic', 'openrouter')
* @param modelId - The model ID to check capabilities for
* @returns Array of model capabilities
*/
export const getModelCapabilities = (
providerName: string,
modelId: string
): string[] => {
const providerConfig =
models[providerName as unknown as keyof typeof models]
const supportsToolCalls = Array.isArray(
providerConfig?.supportsToolCalls as unknown
)
? (providerConfig.supportsToolCalls as unknown as string[])
: []
const supportsImages = Array.isArray(
providerConfig?.supportsImages as unknown
)
? (providerConfig.supportsImages as unknown as string[])
: []
return [
ModelCapabilities.COMPLETION,
supportsToolCalls.includes(modelId) ? ModelCapabilities.TOOLS : undefined,
supportsImages.includes(modelId) ? ModelCapabilities.VISION : undefined,
].filter(Boolean) as string[]
}
/**
* This utility is to extract cortexso model description from README.md file
* @returns

View File

@ -10,6 +10,7 @@ import { modelSettings } from '@/lib/predefined'
import { ExtensionManager } from '@/lib/extension'
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
import { DefaultProvidersService } from './default'
import { getModelCapabilities } from '@/lib/models'
export class TauriProvidersService extends DefaultProvidersService {
fetch(): typeof fetch {
@ -26,32 +27,16 @@ export class TauriProvidersService extends DefaultProvidersService {
provider.provider as unknown as keyof typeof providerModels
].models as unknown as string[]
if (Array.isArray(builtInModels))
if (Array.isArray(builtInModels)) {
models = builtInModels.map((model) => {
const modelManifest = models.find((e) => e.id === model)
// TODO: Check chat_template for tool call support
const capabilities = [
ModelCapabilities.COMPLETION,
(
providerModels[
provider.provider as unknown as keyof typeof providerModels
]?.supportsToolCalls as unknown as string[]
)?.includes(model)
? ModelCapabilities.TOOLS
: undefined,
(
providerModels[
provider.provider as unknown as keyof typeof providerModels
]?.supportsImages as unknown as string[]
)?.includes(model)
? ModelCapabilities.VISION
: undefined,
].filter(Boolean) as string[]
return {
...(modelManifest ?? { id: model, name: model }),
capabilities,
capabilities: getModelCapabilities(provider.provider, model),
} as Model
})
}
}
return {

View File

@ -11,6 +11,7 @@ import { ExtensionManager } from '@/lib/extension'
import type { ProvidersService } from './types'
import { PlatformFeatures } from '@/lib/platform/const'
import { PlatformFeature } from '@/lib/platform/types'
import { getModelCapabilities } from '@/lib/models'
export class WebProvidersService implements ProvidersService {
async getProviders(): Promise<ModelProvider[]> {
@ -88,19 +89,9 @@ export class WebProvidersService implements ProvidersService {
models = builtInModels.map((model) => {
const modelManifest = models.find((e) => e.id === model)
// TODO: Check chat_template for tool call support
const capabilities = [
ModelCapabilities.COMPLETION,
(
providerModels[
provider.provider as unknown as keyof typeof providerModels
]?.supportsToolCalls as unknown as string[]
)?.includes(model)
? ModelCapabilities.TOOLS
: undefined,
].filter(Boolean) as string[]
return {
...(modelManifest ?? { id: model, name: model }),
capabilities,
capabilities: getModelCapabilities(provider.provider, model),
} as Model
})
}