Merge pull request #6475 from menloresearch/feat/bump-tokenjs
feat: fix remote provider vision capability
This commit is contained in:
parent
8cdb021b3d
commit
b0b84b7eda
@ -82,7 +82,7 @@
|
||||
"remark-math": "^6.0.0",
|
||||
"sonner": "^2.0.3",
|
||||
"tailwindcss": "^4.1.4",
|
||||
"token.js": "npm:token.js-fork@0.7.23",
|
||||
"token.js": "npm:token.js-fork@0.7.27",
|
||||
"tw-animate-css": "^1.2.7",
|
||||
"ulidx": "^2.4.1",
|
||||
"unified": "^11.0.5",
|
||||
|
||||
@ -15,6 +15,8 @@ import { IconPlus } from '@tabler/icons-react'
|
||||
import { useState } from 'react'
|
||||
import { getProviderTitle } from '@/lib/utils'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import { ModelCapabilities } from '@/types/models'
|
||||
import { models as providerModels } from 'token.js'
|
||||
|
||||
type DialogAddModelProps = {
|
||||
provider: ModelProvider
|
||||
@ -44,7 +46,23 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => {
|
||||
id: modelId,
|
||||
model: modelId,
|
||||
name: modelId,
|
||||
capabilities: ['completion'], // Default capability
|
||||
capabilities: [
|
||||
ModelCapabilities.COMPLETION,
|
||||
(
|
||||
providerModels[
|
||||
provider.provider as unknown as keyof typeof providerModels
|
||||
].supportsToolCalls as unknown as string[]
|
||||
).includes(modelId)
|
||||
? ModelCapabilities.TOOLS
|
||||
: undefined,
|
||||
(
|
||||
providerModels[
|
||||
provider.provider as unknown as keyof typeof providerModels
|
||||
].supportsImages as unknown as string[]
|
||||
).includes(modelId)
|
||||
? ModelCapabilities.VISION
|
||||
: undefined,
|
||||
].filter(Boolean) as string[],
|
||||
version: '1.0',
|
||||
}
|
||||
|
||||
|
||||
@ -39,6 +39,13 @@ export class TauriProvidersService extends DefaultProvidersService {
|
||||
).includes(model)
|
||||
? ModelCapabilities.TOOLS
|
||||
: undefined,
|
||||
(
|
||||
providerModels[
|
||||
provider.provider as unknown as keyof typeof providerModels
|
||||
].supportsImages as unknown as string[]
|
||||
).includes(model)
|
||||
? ModelCapabilities.VISION
|
||||
: undefined,
|
||||
].filter(Boolean) as string[]
|
||||
return {
|
||||
...(modelManifest ?? { id: model, name: model }),
|
||||
@ -74,53 +81,54 @@ export class TauriProvidersService extends DefaultProvidersService {
|
||||
}
|
||||
}) as ProviderSetting[],
|
||||
models: await Promise.all(
|
||||
models.map(
|
||||
async (model) => {
|
||||
let capabilities: string[] = []
|
||||
|
||||
// Check for capabilities
|
||||
if ('capabilities' in model) {
|
||||
capabilities = model.capabilities as string[]
|
||||
} else {
|
||||
// Try to check tool support, but don't let failures block the model
|
||||
try {
|
||||
const toolSupported = await value.isToolSupported(model.id)
|
||||
if (toolSupported) {
|
||||
capabilities = [ModelCapabilities.TOOLS]
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn(`Failed to check tool support for model ${model.id}:`, error)
|
||||
// Continue without tool capabilities if check fails
|
||||
models.map(async (model) => {
|
||||
let capabilities: string[] = []
|
||||
|
||||
// Check for capabilities
|
||||
if ('capabilities' in model) {
|
||||
capabilities = model.capabilities as string[]
|
||||
} else {
|
||||
// Try to check tool support, but don't let failures block the model
|
||||
try {
|
||||
const toolSupported = await value.isToolSupported(model.id)
|
||||
if (toolSupported) {
|
||||
capabilities = [ModelCapabilities.TOOLS]
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
`Failed to check tool support for model ${model.id}:`,
|
||||
error
|
||||
)
|
||||
// Continue without tool capabilities if check fails
|
||||
}
|
||||
|
||||
return {
|
||||
id: model.id,
|
||||
model: model.id,
|
||||
name: model.name,
|
||||
description: model.description,
|
||||
capabilities,
|
||||
provider: providerName,
|
||||
settings: Object.values(modelSettings).reduce(
|
||||
(acc, setting) => {
|
||||
let value = setting.controller_props.value
|
||||
if (setting.key === 'ctx_len') {
|
||||
value = 8192 // Default context length for Llama.cpp models
|
||||
}
|
||||
acc[setting.key] = {
|
||||
...setting,
|
||||
controller_props: {
|
||||
...setting.controller_props,
|
||||
value: value,
|
||||
},
|
||||
}
|
||||
return acc
|
||||
},
|
||||
{} as Record<string, ProviderSetting>
|
||||
),
|
||||
} as Model
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
id: model.id,
|
||||
model: model.id,
|
||||
name: model.name,
|
||||
description: model.description,
|
||||
capabilities,
|
||||
provider: providerName,
|
||||
settings: Object.values(modelSettings).reduce(
|
||||
(acc, setting) => {
|
||||
let value = setting.controller_props.value
|
||||
if (setting.key === 'ctx_len') {
|
||||
value = 8192 // Default context length for Llama.cpp models
|
||||
}
|
||||
acc[setting.key] = {
|
||||
...setting,
|
||||
controller_props: {
|
||||
...setting.controller_props,
|
||||
value: value,
|
||||
},
|
||||
}
|
||||
return acc
|
||||
},
|
||||
{} as Record<string, ProviderSetting>
|
||||
),
|
||||
} as Model
|
||||
})
|
||||
),
|
||||
}
|
||||
runtimeProviders.push(provider)
|
||||
@ -145,7 +153,10 @@ export class TauriProvidersService extends DefaultProvidersService {
|
||||
|
||||
// Add Origin header for local providers to avoid CORS issues
|
||||
// Some local providers (like Ollama) require an Origin header
|
||||
if (provider.base_url.includes('localhost:') || provider.base_url.includes('127.0.0.1:')) {
|
||||
if (
|
||||
provider.base_url.includes('localhost:') ||
|
||||
provider.base_url.includes('127.0.0.1:')
|
||||
) {
|
||||
headers['Origin'] = 'tauri://localhost'
|
||||
}
|
||||
|
||||
@ -187,7 +198,9 @@ export class TauriProvidersService extends DefaultProvidersService {
|
||||
// Handle different response formats that providers might use
|
||||
if (data.data && Array.isArray(data.data)) {
|
||||
// OpenAI format: { data: [{ id: "model-id" }, ...] }
|
||||
return data.data.map((model: { id: string }) => model.id).filter(Boolean)
|
||||
return data.data
|
||||
.map((model: { id: string }) => model.id)
|
||||
.filter(Boolean)
|
||||
} else if (Array.isArray(data)) {
|
||||
// Direct array format: ["model-id1", "model-id2", ...]
|
||||
return data
|
||||
@ -214,11 +227,15 @@ export class TauriProvidersService extends DefaultProvidersService {
|
||||
'Authentication failed',
|
||||
'Access forbidden',
|
||||
'Models endpoint not found',
|
||||
'Failed to fetch models from'
|
||||
'Failed to fetch models from',
|
||||
]
|
||||
|
||||
if (error instanceof Error &&
|
||||
structuredErrorPrefixes.some(prefix => (error as Error).message.startsWith(prefix))) {
|
||||
if (
|
||||
error instanceof Error &&
|
||||
structuredErrorPrefixes.some((prefix) =>
|
||||
(error as Error).message.startsWith(prefix)
|
||||
)
|
||||
) {
|
||||
throw new Error(error.message)
|
||||
}
|
||||
|
||||
@ -236,7 +253,10 @@ export class TauriProvidersService extends DefaultProvidersService {
|
||||
}
|
||||
}
|
||||
|
||||
async updateSettings(providerName: string, settings: ProviderSetting[]): Promise<void> {
|
||||
async updateSettings(
|
||||
providerName: string,
|
||||
settings: ProviderSetting[]
|
||||
): Promise<void> {
|
||||
try {
|
||||
return ExtensionManager.getInstance()
|
||||
.getEngine(providerName)
|
||||
@ -258,4 +278,4 @@ export class TauriProvidersService extends DefaultProvidersService {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -13,4 +13,6 @@ export enum ModelCapabilities {
|
||||
IMAGE_TO_IMAGE = 'image_to_image',
|
||||
TEXT_TO_AUDIO = 'text_to_audio',
|
||||
AUDIO_TO_TEXT = 'audio_to_text',
|
||||
}
|
||||
// Need to consolidate the capabilities list
|
||||
VISION = 'vision',
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user