Merge pull request #6715 from menloresearch/fix/get-model-capabilities-correctly
fix: Extract model capabilities correctly for various providers on various platforms
This commit is contained in:
commit
bdd8549d3e
@ -15,8 +15,7 @@ import { IconPlus } from '@tabler/icons-react'
|
|||||||
import { useState } from 'react'
|
import { useState } from 'react'
|
||||||
import { getProviderTitle } from '@/lib/utils'
|
import { getProviderTitle } from '@/lib/utils'
|
||||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||||
import { ModelCapabilities } from '@/types/models'
|
import { getModelCapabilities } from '@/lib/models'
|
||||||
import { models as providerModels } from 'token.js'
|
|
||||||
import { toast } from 'sonner'
|
import { toast } from 'sonner'
|
||||||
|
|
||||||
type DialogAddModelProps = {
|
type DialogAddModelProps = {
|
||||||
@ -52,23 +51,7 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => {
|
|||||||
id: modelId,
|
id: modelId,
|
||||||
model: modelId,
|
model: modelId,
|
||||||
name: modelId,
|
name: modelId,
|
||||||
capabilities: [
|
capabilities: getModelCapabilities(provider.provider, modelId),
|
||||||
ModelCapabilities.COMPLETION,
|
|
||||||
(
|
|
||||||
providerModels[
|
|
||||||
provider.provider as unknown as keyof typeof providerModels
|
|
||||||
]?.supportsToolCalls as unknown as string[]
|
|
||||||
)?.includes(modelId)
|
|
||||||
? ModelCapabilities.TOOLS
|
|
||||||
: undefined,
|
|
||||||
(
|
|
||||||
providerModels[
|
|
||||||
provider.provider as unknown as keyof typeof providerModels
|
|
||||||
]?.supportsImages as unknown as string[]
|
|
||||||
)?.includes(modelId)
|
|
||||||
? ModelCapabilities.VISION
|
|
||||||
: undefined,
|
|
||||||
].filter(Boolean) as string[],
|
|
||||||
version: '1.0',
|
version: '1.0',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -5,19 +5,30 @@ import {
|
|||||||
removeYamlFrontMatter,
|
removeYamlFrontMatter,
|
||||||
extractModelName,
|
extractModelName,
|
||||||
extractModelRepo,
|
extractModelRepo,
|
||||||
|
getModelCapabilities,
|
||||||
} from '../models'
|
} from '../models'
|
||||||
|
import { ModelCapabilities } from '@/types/models'
|
||||||
|
|
||||||
// Mock the token.js module
|
// Mock the token.js module
|
||||||
vi.mock('token.js', () => {
  // Stand-in for the `models` export of token.js, keyed by provider name.
  // Each entry lists its model ids and, optionally, the ids that support
  // tool calls (`supportsToolCalls`) and image input (`supportsImages`).
  const models = {
    openai: {
      models: ['gpt-3.5-turbo', 'gpt-4'],
      supportsToolCalls: ['gpt-3.5-turbo', 'gpt-4'],
      supportsImages: ['gpt-4-vision-preview'],
    },
    anthropic: {
      models: ['claude-3-sonnet', 'claude-3-haiku'],
      supportsToolCalls: ['claude-3-sonnet'],
      supportsImages: ['claude-3-sonnet', 'claude-3-haiku'],
    },
    mistral: {
      models: ['mistral-7b', 'mistral-8x7b'],
      supportsToolCalls: ['mistral-8x7b'],
    },
    // Provider with no capability arrays
    cohere: {
      models: ['command', 'command-light'],
    },
  }

  return { models }
})
|
||||||
@ -223,3 +234,74 @@ describe('extractModelRepo', () => {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('getModelCapabilities', () => {
  // Provider/model fixtures used below come from the token.js mock declared
  // with vi.mock at the top of this test file.

  it('returns completion capability for all models', () => {
    expect(getModelCapabilities('openai', 'gpt-3.5-turbo')).toContain(
      ModelCapabilities.COMPLETION
    )
  })

  it('includes tools capability when model supports it', () => {
    const result = getModelCapabilities('openai', 'gpt-3.5-turbo')

    expect(result).toContain(ModelCapabilities.TOOLS)
    expect(result).toContain(ModelCapabilities.COMPLETION)
  })

  it('excludes tools capability when model does not support it', () => {
    const result = getModelCapabilities('mistral', 'mistral-7b')

    expect(result).not.toContain(ModelCapabilities.TOOLS)
    expect(result).toContain(ModelCapabilities.COMPLETION)
  })

  it('includes vision capability when model supports it', () => {
    const result = getModelCapabilities('openai', 'gpt-4-vision-preview')

    expect(result).toContain(ModelCapabilities.VISION)
    expect(result).toContain(ModelCapabilities.COMPLETION)
  })

  it('excludes vision capability when model does not support it', () => {
    expect(getModelCapabilities('openai', 'gpt-3.5-turbo')).not.toContain(
      ModelCapabilities.VISION
    )
  })

  it('includes both tools and vision when model supports both', () => {
    const result = getModelCapabilities('anthropic', 'claude-3-sonnet')

    expect(result).toContain(ModelCapabilities.COMPLETION)
    expect(result).toContain(ModelCapabilities.TOOLS)
    expect(result).toContain(ModelCapabilities.VISION)
  })

  it('handles provider with no capability arrays gracefully', () => {
    // The mocked cohere entry only lists `models`.
    const result = getModelCapabilities('cohere', 'command')

    expect(result).toEqual([ModelCapabilities.COMPLETION])
    expect(result).not.toContain(ModelCapabilities.TOOLS)
    expect(result).not.toContain(ModelCapabilities.VISION)
  })

  it('handles unknown provider gracefully', () => {
    // 'openrouter' is not a key of the mocked `models` object.
    const result = getModelCapabilities('openrouter', 'some-model')

    expect(result).toEqual([ModelCapabilities.COMPLETION])
    expect(result).not.toContain(ModelCapabilities.TOOLS)
    expect(result).not.toContain(ModelCapabilities.VISION)
  })

  it('handles model not in capability list', () => {
    // claude-3-haiku is mocked with image support but without tool calls.
    const result = getModelCapabilities('anthropic', 'claude-3-haiku')

    expect(result).toContain(ModelCapabilities.COMPLETION)
    expect(result).toContain(ModelCapabilities.VISION)
    expect(result).not.toContain(ModelCapabilities.TOOLS)
  })

  it('returns only completion for provider with partial capability data', () => {
    // The mocked mistral entry has supportsToolCalls but no supportsImages,
    // and mistral-7b is not in its supportsToolCalls list.
    expect(getModelCapabilities('mistral', 'mistral-7b')).toEqual([
      ModelCapabilities.COMPLETION,
    ])
  })

  it('handles model that supports tools but not vision', () => {
    const result = getModelCapabilities('mistral', 'mistral-8x7b')

    expect(result).toContain(ModelCapabilities.COMPLETION)
    expect(result).toContain(ModelCapabilities.TOOLS)
    expect(result).not.toContain(ModelCapabilities.VISION)
  })
})
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import { models } from 'token.js'
|
import { models } from 'token.js'
|
||||||
|
import { ModelCapabilities } from '@/types/models'
|
||||||
|
|
||||||
export const defaultModel = (provider?: string) => {
|
export const defaultModel = (provider?: string) => {
|
||||||
if (!provider || !Object.keys(models).includes(provider)) {
|
if (!provider || !Object.keys(models).includes(provider)) {
|
||||||
@ -10,6 +11,38 @@ export const defaultModel = (provider?: string) => {
|
|||||||
)[0]
|
)[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Builds the capability list for a model from the provider entry found in
 * the `models` export of token.js.
 *
 * `ModelCapabilities.COMPLETION` is always included. `TOOLS` and `VISION`
 * are appended (in that order) only when the provider entry exposes
 * `supportsToolCalls` / `supportsImages` as an array that contains `modelId`.
 * A missing provider, a missing field, or a field that is not an array is
 * treated as an empty list.
 *
 * @param providerName - The provider name (e.g., 'openai', 'anthropic', 'openrouter')
 * @param modelId - The model ID to check capabilities for
 * @returns Array of model capabilities
 */
export const getModelCapabilities = (
  providerName: string,
  modelId: string
): string[] => {
  // View the provider table through a loose shape so that any provider name
  // can be looked up and the capability fields can be inspected at runtime.
  const providerEntry = (
    models as unknown as Record<
      string,
      { supportsToolCalls?: unknown; supportsImages?: unknown } | undefined
    >
  )[providerName]

  // Only real arrays are searched; every other value counts as "no models".
  const asModelList = (candidate: unknown): string[] =>
    Array.isArray(candidate) ? (candidate as string[]) : []

  const capabilities: string[] = [ModelCapabilities.COMPLETION]

  if (asModelList(providerEntry?.supportsToolCalls).includes(modelId)) {
    capabilities.push(ModelCapabilities.TOOLS)
  }

  if (asModelList(providerEntry?.supportsImages).includes(modelId)) {
    capabilities.push(ModelCapabilities.VISION)
  }

  return capabilities
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This utility is to extract cortexso model description from README.md file
|
* This utility is to extract cortexso model description from README.md file
|
||||||
* @returns
|
* @returns
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import { modelSettings } from '@/lib/predefined'
|
|||||||
import { ExtensionManager } from '@/lib/extension'
|
import { ExtensionManager } from '@/lib/extension'
|
||||||
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
|
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
|
||||||
import { DefaultProvidersService } from './default'
|
import { DefaultProvidersService } from './default'
|
||||||
|
import { getModelCapabilities } from '@/lib/models'
|
||||||
|
|
||||||
export class TauriProvidersService extends DefaultProvidersService {
|
export class TauriProvidersService extends DefaultProvidersService {
|
||||||
fetch(): typeof fetch {
|
fetch(): typeof fetch {
|
||||||
@ -26,33 +27,17 @@ export class TauriProvidersService extends DefaultProvidersService {
|
|||||||
provider.provider as unknown as keyof typeof providerModels
|
provider.provider as unknown as keyof typeof providerModels
|
||||||
].models as unknown as string[]
|
].models as unknown as string[]
|
||||||
|
|
||||||
if (Array.isArray(builtInModels))
|
if (Array.isArray(builtInModels)) {
|
||||||
models = builtInModels.map((model) => {
|
models = builtInModels.map((model) => {
|
||||||
const modelManifest = models.find((e) => e.id === model)
|
const modelManifest = models.find((e) => e.id === model)
|
||||||
// TODO: Check chat_template for tool call support
|
// TODO: Check chat_template for tool call support
|
||||||
const capabilities = [
|
|
||||||
ModelCapabilities.COMPLETION,
|
|
||||||
(
|
|
||||||
providerModels[
|
|
||||||
provider.provider as unknown as keyof typeof providerModels
|
|
||||||
]?.supportsToolCalls as unknown as string[]
|
|
||||||
)?.includes(model)
|
|
||||||
? ModelCapabilities.TOOLS
|
|
||||||
: undefined,
|
|
||||||
(
|
|
||||||
providerModels[
|
|
||||||
provider.provider as unknown as keyof typeof providerModels
|
|
||||||
]?.supportsImages as unknown as string[]
|
|
||||||
)?.includes(model)
|
|
||||||
? ModelCapabilities.VISION
|
|
||||||
: undefined,
|
|
||||||
].filter(Boolean) as string[]
|
|
||||||
return {
|
return {
|
||||||
...(modelManifest ?? { id: model, name: model }),
|
...(modelManifest ?? { id: model, name: model }),
|
||||||
capabilities,
|
capabilities: getModelCapabilities(provider.provider, model),
|
||||||
} as Model
|
} as Model
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...provider,
|
...provider,
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import { ExtensionManager } from '@/lib/extension'
|
|||||||
import type { ProvidersService } from './types'
|
import type { ProvidersService } from './types'
|
||||||
import { PlatformFeatures } from '@/lib/platform/const'
|
import { PlatformFeatures } from '@/lib/platform/const'
|
||||||
import { PlatformFeature } from '@/lib/platform/types'
|
import { PlatformFeature } from '@/lib/platform/types'
|
||||||
|
import { getModelCapabilities } from '@/lib/models'
|
||||||
|
|
||||||
export class WebProvidersService implements ProvidersService {
|
export class WebProvidersService implements ProvidersService {
|
||||||
async getProviders(): Promise<ModelProvider[]> {
|
async getProviders(): Promise<ModelProvider[]> {
|
||||||
@ -88,19 +89,9 @@ export class WebProvidersService implements ProvidersService {
|
|||||||
models = builtInModels.map((model) => {
|
models = builtInModels.map((model) => {
|
||||||
const modelManifest = models.find((e) => e.id === model)
|
const modelManifest = models.find((e) => e.id === model)
|
||||||
// TODO: Check chat_template for tool call support
|
// TODO: Check chat_template for tool call support
|
||||||
const capabilities = [
|
|
||||||
ModelCapabilities.COMPLETION,
|
|
||||||
(
|
|
||||||
providerModels[
|
|
||||||
provider.provider as unknown as keyof typeof providerModels
|
|
||||||
]?.supportsToolCalls as unknown as string[]
|
|
||||||
)?.includes(model)
|
|
||||||
? ModelCapabilities.TOOLS
|
|
||||||
: undefined,
|
|
||||||
].filter(Boolean) as string[]
|
|
||||||
return {
|
return {
|
||||||
...(modelManifest ?? { id: model, name: model }),
|
...(modelManifest ?? { id: model, name: model }),
|
||||||
capabilities,
|
capabilities: getModelCapabilities(provider.provider, model),
|
||||||
} as Model
|
} as Model
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user