Merge pull request #6715 from menloresearch/fix/get-model-capabilities-correctly
fix: Extract model capabilities correctly for various providers on various platforms
This commit is contained in:
parent
f537429d2c
commit
c14e1ea00f
@ -15,8 +15,7 @@ import { IconPlus } from '@tabler/icons-react'
|
||||
import { useState } from 'react'
|
||||
import { getProviderTitle } from '@/lib/utils'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import { ModelCapabilities } from '@/types/models'
|
||||
import { models as providerModels } from 'token.js'
|
||||
import { getModelCapabilities } from '@/lib/models'
|
||||
import { toast } from 'sonner'
|
||||
|
||||
type DialogAddModelProps = {
|
||||
@ -52,23 +51,7 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => {
|
||||
id: modelId,
|
||||
model: modelId,
|
||||
name: modelId,
|
||||
capabilities: [
|
||||
ModelCapabilities.COMPLETION,
|
||||
(
|
||||
providerModels[
|
||||
provider.provider as unknown as keyof typeof providerModels
|
||||
]?.supportsToolCalls as unknown as string[]
|
||||
)?.includes(modelId)
|
||||
? ModelCapabilities.TOOLS
|
||||
: undefined,
|
||||
(
|
||||
providerModels[
|
||||
provider.provider as unknown as keyof typeof providerModels
|
||||
]?.supportsImages as unknown as string[]
|
||||
)?.includes(modelId)
|
||||
? ModelCapabilities.VISION
|
||||
: undefined,
|
||||
].filter(Boolean) as string[],
|
||||
capabilities: getModelCapabilities(provider.provider, modelId),
|
||||
version: '1.0',
|
||||
}
|
||||
|
||||
|
||||
@ -5,19 +5,30 @@ import {
|
||||
removeYamlFrontMatter,
|
||||
extractModelName,
|
||||
extractModelRepo,
|
||||
getModelCapabilities,
|
||||
} from '../models'
|
||||
import { ModelCapabilities } from '@/types/models'
|
||||
|
||||
// Mock the token.js module
|
||||
vi.mock('token.js', () => ({
|
||||
models: {
|
||||
openai: {
|
||||
models: ['gpt-3.5-turbo', 'gpt-4'],
|
||||
supportsToolCalls: ['gpt-3.5-turbo', 'gpt-4'],
|
||||
supportsImages: ['gpt-4-vision-preview'],
|
||||
},
|
||||
anthropic: {
|
||||
models: ['claude-3-sonnet', 'claude-3-haiku'],
|
||||
supportsToolCalls: ['claude-3-sonnet'],
|
||||
supportsImages: ['claude-3-sonnet', 'claude-3-haiku'],
|
||||
},
|
||||
mistral: {
|
||||
models: ['mistral-7b', 'mistral-8x7b'],
|
||||
supportsToolCalls: ['mistral-8x7b'],
|
||||
},
|
||||
// Provider with no capability arrays
|
||||
cohere: {
|
||||
models: ['command', 'command-light'],
|
||||
},
|
||||
},
|
||||
}))
|
||||
@ -223,3 +234,74 @@ describe('extractModelRepo', () => {
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getModelCapabilities', () => {
|
||||
it('returns completion capability for all models', () => {
|
||||
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
|
||||
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
|
||||
})
|
||||
|
||||
it('includes tools capability when model supports it', () => {
|
||||
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
|
||||
expect(capabilities).toContain(ModelCapabilities.TOOLS)
|
||||
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
|
||||
})
|
||||
|
||||
it('excludes tools capability when model does not support it', () => {
|
||||
const capabilities = getModelCapabilities('mistral', 'mistral-7b')
|
||||
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
|
||||
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
|
||||
})
|
||||
|
||||
it('includes vision capability when model supports it', () => {
|
||||
const capabilities = getModelCapabilities('openai', 'gpt-4-vision-preview')
|
||||
expect(capabilities).toContain(ModelCapabilities.VISION)
|
||||
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
|
||||
})
|
||||
|
||||
it('excludes vision capability when model does not support it', () => {
|
||||
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
|
||||
expect(capabilities).not.toContain(ModelCapabilities.VISION)
|
||||
})
|
||||
|
||||
it('includes both tools and vision when model supports both', () => {
|
||||
const capabilities = getModelCapabilities('anthropic', 'claude-3-sonnet')
|
||||
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
|
||||
expect(capabilities).toContain(ModelCapabilities.TOOLS)
|
||||
expect(capabilities).toContain(ModelCapabilities.VISION)
|
||||
})
|
||||
|
||||
it('handles provider with no capability arrays gracefully', () => {
|
||||
const capabilities = getModelCapabilities('cohere', 'command')
|
||||
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
|
||||
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
|
||||
expect(capabilities).not.toContain(ModelCapabilities.VISION)
|
||||
})
|
||||
|
||||
it('handles unknown provider gracefully', () => {
|
||||
const capabilities = getModelCapabilities('openrouter', 'some-model')
|
||||
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
|
||||
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
|
||||
expect(capabilities).not.toContain(ModelCapabilities.VISION)
|
||||
})
|
||||
|
||||
it('handles model not in capability list', () => {
|
||||
const capabilities = getModelCapabilities('anthropic', 'claude-3-haiku')
|
||||
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
|
||||
expect(capabilities).toContain(ModelCapabilities.VISION)
|
||||
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
|
||||
})
|
||||
|
||||
it('returns only completion for provider with partial capability data', () => {
|
||||
// Mistral has supportsToolCalls but no supportsImages
|
||||
const capabilities = getModelCapabilities('mistral', 'mistral-7b')
|
||||
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
|
||||
})
|
||||
|
||||
it('handles model that supports tools but not vision', () => {
|
||||
const capabilities = getModelCapabilities('mistral', 'mistral-8x7b')
|
||||
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
|
||||
expect(capabilities).toContain(ModelCapabilities.TOOLS)
|
||||
expect(capabilities).not.toContain(ModelCapabilities.VISION)
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { models } from 'token.js'
|
||||
import { ModelCapabilities } from '@/types/models'
|
||||
|
||||
export const defaultModel = (provider?: string) => {
|
||||
if (!provider || !Object.keys(models).includes(provider)) {
|
||||
@ -10,6 +11,38 @@ export const defaultModel = (provider?: string) => {
|
||||
)[0]
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines model capabilities based on provider configuration from token.js
|
||||
* @param providerName - The provider name (e.g., 'openai', 'anthropic', 'openrouter')
|
||||
* @param modelId - The model ID to check capabilities for
|
||||
* @returns Array of model capabilities
|
||||
*/
|
||||
export const getModelCapabilities = (
|
||||
providerName: string,
|
||||
modelId: string
|
||||
): string[] => {
|
||||
const providerConfig =
|
||||
models[providerName as unknown as keyof typeof models]
|
||||
|
||||
const supportsToolCalls = Array.isArray(
|
||||
providerConfig?.supportsToolCalls as unknown
|
||||
)
|
||||
? (providerConfig.supportsToolCalls as unknown as string[])
|
||||
: []
|
||||
|
||||
const supportsImages = Array.isArray(
|
||||
providerConfig?.supportsImages as unknown
|
||||
)
|
||||
? (providerConfig.supportsImages as unknown as string[])
|
||||
: []
|
||||
|
||||
return [
|
||||
ModelCapabilities.COMPLETION,
|
||||
supportsToolCalls.includes(modelId) ? ModelCapabilities.TOOLS : undefined,
|
||||
supportsImages.includes(modelId) ? ModelCapabilities.VISION : undefined,
|
||||
].filter(Boolean) as string[]
|
||||
}
|
||||
|
||||
/**
|
||||
* This utility is to extract cortexso model description from README.md file
|
||||
* @returns
|
||||
|
||||
@ -10,6 +10,7 @@ import { modelSettings } from '@/lib/predefined'
|
||||
import { ExtensionManager } from '@/lib/extension'
|
||||
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
|
||||
import { DefaultProvidersService } from './default'
|
||||
import { getModelCapabilities } from '@/lib/models'
|
||||
|
||||
export class TauriProvidersService extends DefaultProvidersService {
|
||||
fetch(): typeof fetch {
|
||||
@ -26,32 +27,16 @@ export class TauriProvidersService extends DefaultProvidersService {
|
||||
provider.provider as unknown as keyof typeof providerModels
|
||||
].models as unknown as string[]
|
||||
|
||||
if (Array.isArray(builtInModels))
|
||||
if (Array.isArray(builtInModels)) {
|
||||
models = builtInModels.map((model) => {
|
||||
const modelManifest = models.find((e) => e.id === model)
|
||||
// TODO: Check chat_template for tool call support
|
||||
const capabilities = [
|
||||
ModelCapabilities.COMPLETION,
|
||||
(
|
||||
providerModels[
|
||||
provider.provider as unknown as keyof typeof providerModels
|
||||
]?.supportsToolCalls as unknown as string[]
|
||||
)?.includes(model)
|
||||
? ModelCapabilities.TOOLS
|
||||
: undefined,
|
||||
(
|
||||
providerModels[
|
||||
provider.provider as unknown as keyof typeof providerModels
|
||||
]?.supportsImages as unknown as string[]
|
||||
)?.includes(model)
|
||||
? ModelCapabilities.VISION
|
||||
: undefined,
|
||||
].filter(Boolean) as string[]
|
||||
return {
|
||||
...(modelManifest ?? { id: model, name: model }),
|
||||
capabilities,
|
||||
capabilities: getModelCapabilities(provider.provider, model),
|
||||
} as Model
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@ -11,6 +11,7 @@ import { ExtensionManager } from '@/lib/extension'
|
||||
import type { ProvidersService } from './types'
|
||||
import { PlatformFeatures } from '@/lib/platform/const'
|
||||
import { PlatformFeature } from '@/lib/platform/types'
|
||||
import { getModelCapabilities } from '@/lib/models'
|
||||
|
||||
export class WebProvidersService implements ProvidersService {
|
||||
async getProviders(): Promise<ModelProvider[]> {
|
||||
@ -88,19 +89,9 @@ export class WebProvidersService implements ProvidersService {
|
||||
models = builtInModels.map((model) => {
|
||||
const modelManifest = models.find((e) => e.id === model)
|
||||
// TODO: Check chat_template for tool call support
|
||||
const capabilities = [
|
||||
ModelCapabilities.COMPLETION,
|
||||
(
|
||||
providerModels[
|
||||
provider.provider as unknown as keyof typeof providerModels
|
||||
]?.supportsToolCalls as unknown as string[]
|
||||
)?.includes(model)
|
||||
? ModelCapabilities.TOOLS
|
||||
: undefined,
|
||||
].filter(Boolean) as string[]
|
||||
return {
|
||||
...(modelManifest ?? { id: model, name: model }),
|
||||
capabilities,
|
||||
capabilities: getModelCapabilities(provider.provider, model),
|
||||
} as Model
|
||||
})
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user