2025-10-03 20:54:54 +07:00

273 lines
9.3 KiB
TypeScript

/**
* Tauri Providers Service - Desktop implementation
*/
import { models as providerModels } from 'token.js'
import { predefinedProviders } from '@/consts/providers'
import { EngineManager, SettingComponentProps } from '@janhq/core'
import { ModelCapabilities } from '@/types/models'
import { modelSettings } from '@/lib/predefined'
import { ExtensionManager } from '@/lib/extension'
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
import { DefaultProvidersService } from './default'
import { getModelCapabilities } from '@/lib/models'
/**
 * Value forced into the `ctx_len` setting of every model reported by a
 * runtime engine. The original note describes it as the default context
 * length for Llama.cpp models; it is applied to all engines regardless.
 */
const DEFAULT_ENGINE_CTX_LEN = 8192

/**
 * Prefixes of the user-facing messages produced by
 * `describeModelsHttpFailure`. Errors whose message starts with one of these
 * are already meaningful to the user and are rethrown untouched.
 */
const STRUCTURED_ERROR_PREFIXES = [
  'Authentication failed',
  'Access forbidden',
  'Models endpoint not found',
  'Failed to fetch models from',
] as const

/**
 * Build a user-facing error message for a non-OK response from a provider's
 * `/models` endpoint.
 *
 * @param status - HTTP status code of the response.
 * @param statusText - HTTP status text of the response.
 * @param providerName - Provider identifier, interpolated into the message.
 * @returns The message to throw.
 */
function describeModelsHttpFailure(
  status: number,
  statusText: string,
  providerName: string
): string {
  switch (status) {
    case 401:
      return `Authentication failed: API key is required or invalid for ${providerName}`
    case 403:
      return `Access forbidden: Check your API key permissions for ${providerName}`
    case 404:
      return `Models endpoint not found for ${providerName}. Check the base URL configuration.`
    default:
      return `Failed to fetch models from ${providerName}: ${status} ${statusText}`
  }
}

/**
 * Convert one entry of a model listing to its id.
 *
 * @param entry - Either a bare id string or an object carrying an `id` field.
 * @returns The non-empty string id, or `undefined` when the entry has none.
 */
function toModelId(entry: unknown): string | undefined {
  if (typeof entry === 'string') {
    return entry.length > 0 ? entry : undefined
  }
  if (typeof entry === 'object' && entry !== null) {
    const { id } = entry as { id?: unknown }
    return typeof id === 'string' && id.length > 0 ? id : undefined
  }
  return undefined
}

/**
 * Extract model ids from a parsed `/models` response body.
 *
 * Supported shapes:
 * - OpenAI format: `{ data: [{ id: "model-id" }, ...] }`
 * - Direct array: `["model-id1", ...]` or `[{ id: "model-id1" }, ...]`
 * - Alternative format: `{ models: [...] }` (strings or `{ id }` objects)
 *
 * Every shape goes through the same map-then-filter step, so only non-empty
 * string ids are returned.
 *
 * @param data - The parsed JSON body.
 * @returns The model ids; `[]` (with a warning) for unrecognised shapes.
 */
function extractModelIds(data: unknown): string[] {
  let entries: unknown[] | undefined
  if (Array.isArray(data)) {
    entries = data
  } else if (typeof data === 'object' && data !== null) {
    const body = data as { data?: unknown; models?: unknown }
    if (Array.isArray(body.data)) {
      entries = body.data
    } else if (Array.isArray(body.models)) {
      entries = body.models
    }
  }
  if (entries === undefined) {
    console.warn('Unexpected response format from provider API:', data)
    return []
  }
  return entries
    .map(toModelId)
    .filter((id): id is string => id !== undefined)
}

/**
 * Build a fresh copy of the predefined model settings for a runtime-engine
 * model. It uses each setting's own default, except that `ctx_len` is
 * forced to `DEFAULT_ENGINE_CTX_LEN`.
 *
 * @returns Settings keyed by `setting.key`; a new object on every call.
 */
function buildDefaultModelSettings(): Record<string, ProviderSetting> {
  return Object.values(modelSettings).reduce(
    (acc, setting) => {
      const defaultValue =
        setting.key === 'ctx_len'
          ? DEFAULT_ENGINE_CTX_LEN
          : setting.controller_props.value
      acc[setting.key] = {
        ...setting,
        controller_props: {
          ...setting.controller_props,
          value: defaultValue,
        },
      }
      return acc
    },
    {} as Record<string, ProviderSetting>
  )
}

/**
 * Desktop (Tauri) implementation of the providers service.
 *
 * All HTTP traffic goes through Tauri's HTTP plugin instead of the webview's
 * `fetch`, to avoid CORS restrictions.
 */
export class TauriProvidersService extends DefaultProvidersService {
  /**
   * Return the fetch implementation for this platform.
   *
   * @returns Tauri's `fetch`, which avoids CORS issues.
   */
  fetch(): typeof fetch {
    return fetchTauri as typeof fetch
  }

  /**
   * List every provider the app knows about. Runtime-engine providers come
   * first, followed by the predefined providers.
   *
   * Predefined providers whose name appears in token.js's model catalog have
   * their model list replaced by that catalog. Existing manifests are reused
   * when ids match, and capabilities come from `getModelCapabilities`.
   *
   * An engine whose `list()`/`getSettings()` fails is logged and skipped, so
   * one broken engine no longer hides every other provider.
   *
   * @returns The providers; `[]` only if building the list fails outside
   *   the per-engine isolation.
   */
  async getProviders(): Promise<ModelProvider[]> {
    try {
      const builtinProviders = predefinedProviders.map((provider) => {
        let models = provider.models as Model[]
        if (Object.keys(providerModels).includes(provider.provider)) {
          const builtInModels = providerModels[
            provider.provider as unknown as keyof typeof providerModels
          ].models as unknown as string[]
          if (Array.isArray(builtInModels)) {
            const manifests = models
            models = builtInModels.map((modelId) => {
              const manifest = manifests.find((m) => m.id === modelId)
              // TODO: Check chat_template for tool call support
              return {
                ...(manifest ?? { id: modelId, name: modelId }),
                capabilities: getModelCapabilities(provider.provider, modelId),
              } as Model
            })
          }
        }
        return {
          ...provider,
          models,
        }
      })

      const runtimeProviders: ModelProvider[] = []
      for (const [providerName, engine] of EngineManager.instance().engines) {
        try {
          const engineModels = (await engine.list()) ?? []
          const baseUrl =
            'inferenceUrl' in engine
              ? (engine.inferenceUrl as string).replace('/chat/completions', '')
              : ''
          // Engine settings use camelCase keys; the UI expects snake_case.
          const settings = (await engine.getSettings()).map((setting) => ({
            key: setting.key,
            title: setting.title,
            description: setting.description,
            controller_type: setting.controllerType as unknown,
            controller_props: setting.controllerProps as unknown,
          })) as ProviderSetting[]

          const models = await Promise.all(
            engineModels.map(async (model) => {
              let capabilities: string[] = []
              if ('capabilities' in model) {
                capabilities = model.capabilities as string[]
              } else {
                // Tool-support probing is best-effort: a failure must not
                // drop the model, it just leaves it without TOOLS.
                try {
                  if (await engine.isToolSupported(model.id)) {
                    capabilities = [ModelCapabilities.TOOLS]
                  }
                } catch (error) {
                  console.warn(
                    `Failed to check tool support for model ${model.id}:`,
                    error
                  )
                }
              }
              return {
                id: model.id,
                model: model.id,
                name: model.name,
                description: model.description,
                capabilities,
                provider: providerName,
                settings: buildDefaultModelSettings(),
              } as Model
            })
          )

          const provider: ModelProvider = {
            active: false,
            persist: true,
            provider: providerName,
            base_url: baseUrl,
            settings,
            models,
          }
          runtimeProviders.push(provider)
        } catch (error) {
          // Isolate engine failures so the remaining providers still load.
          console.warn(
            `Failed to load provider ${providerName}, skipping it:`,
            error
          )
        }
      }
      return runtimeProviders.concat(builtinProviders as ModelProvider[])
    } catch (error: unknown) {
      console.error('Error getting providers in Tauri:', error)
      return []
    }
  }

  /**
   * Fetch the model ids exposed by a provider's `GET {base_url}/models`
   * endpoint.
   *
   * Trailing slashes on `base_url` are ignored. An `Origin` header is added
   * for `localhost:`/`127.0.0.1:` URLs, because some local providers (e.g.
   * Ollama) require one. When an API key is set it is sent as both
   * `x-api-key` and a Bearer token. Custom headers are applied last and may
   * override the others.
   *
   * @param provider - The provider to query; must have `base_url` set.
   * @returns The model ids reported by the provider.
   * @throws Error when `base_url` is missing, on non-OK HTTP status (message
   *   depends on 401/403/404/other), on connection failure, or on any other
   *   unexpected error.
   */
  async fetchModelsFromProvider(provider: ModelProvider): Promise<string[]> {
    if (!provider.base_url) {
      throw new Error('Provider must have base_url configured')
    }
    const baseUrl = provider.base_url
    try {
      const headers: Record<string, string> = {
        'Content-Type': 'application/json',
      }
      if (baseUrl.includes('localhost:') || baseUrl.includes('127.0.0.1:')) {
        headers['Origin'] = 'tauri://localhost'
      }
      if (provider.api_key) {
        headers['x-api-key'] = provider.api_key
        headers['Authorization'] = `Bearer ${provider.api_key}`
      }
      if (provider.custom_header) {
        for (const { header, value } of provider.custom_header) {
          headers[header] = value
        }
      }

      // Strip trailing slashes so "http://host/v1/" does not yield "//models".
      const endpoint = `${baseUrl.replace(/\/+$/, '')}/models`
      const response = await fetchTauri(endpoint, {
        method: 'GET',
        headers,
      })
      if (!response.ok) {
        throw new Error(
          describeModelsHttpFailure(
            response.status,
            response.statusText,
            provider.provider
          )
        )
      }
      return extractModelIds(await response.json())
    } catch (error) {
      console.error('Error fetching models from provider:', error)
      if (error instanceof Error) {
        const { message } = error
        // Already user-facing (thrown above): rethrow as-is to keep the stack.
        if (STRUCTURED_ERROR_PREFIXES.some((prefix) => message.startsWith(prefix))) {
          throw error
        }
        if (message.includes('fetch')) {
          throw new Error(
            `Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.`
          )
        }
      }
      throw new Error(
        `Unexpected error while fetching models from ${provider.provider}: ${error instanceof Error ? error.message : 'Unknown error'}`
      )
    }
  }

  /**
   * Push updated provider settings to the matching engine extension.
   *
   * Settings are converted from the UI's snake_case shape to the engine's
   * camelCase `SettingComponentProps`. An `undefined` value is sent as `''`.
   * If no engine is registered under `providerName`, nothing happens.
   *
   * @param providerName - Engine/provider identifier.
   * @param settings - Settings to apply.
   * @throws Whatever the engine's `updateSettings` throws or rejects with,
   *   after logging it.
   */
  async updateSettings(
    providerName: string,
    settings: ProviderSetting[]
  ): Promise<void> {
    try {
      // `await` (not a bare `return`) so async rejections reach the catch
      // below and get logged.
      await ExtensionManager.getInstance()
        .getEngine(providerName)
        ?.updateSettings(
          settings.map((setting) => ({
            ...setting,
            controllerProps: {
              ...setting.controller_props,
              value:
                setting.controller_props.value !== undefined
                  ? setting.controller_props.value
                  : '',
            },
            controllerType: setting.controller_type,
          })) as SettingComponentProps[]
        )
    } catch (error) {
      console.error('Error updating settings in Tauri:', error)
      throw error
    }
  }
}