// Source listing metadata: 402 lines, 14 KiB, TypeScript
import { create } from 'zustand'
|
|
import { persist, createJSONStorage } from 'zustand/middleware'
|
|
import { localStorageKey } from '@/constants/localStorage'
|
|
import { getServiceHub } from '@/hooks/useServiceHub'
|
|
import { modelSettings } from '@/lib/predefined'
|
|
|
|
/**
 * Shape of the `useModelProvider` store: the provider/model state plus the
 * actions that read and update it.
 */
type ModelProviderState = {
  /** Every known provider together with its models and settings. */
  providers: ModelProvider[]
  /** `provider` name of the currently selected provider (store default: 'llamacpp'). */
  selectedProvider: string
  /** The currently selected model object, or null when nothing is selected. */
  selectedModel: Model | null
  /** Ids of models the user deleted; `setProviders` will not re-add them. */
  deletedModels: string[]

  /** Looks up a model by id inside the currently selected provider. */
  getModelBy: (modelId: string) => Model | undefined
  /** Returns the provider whose `provider` field equals `providerName`. */
  getProviderByName: (providerName: string) => ModelProvider | undefined

  /** Merges an incoming provider list into the stored one. */
  setProviders: (providers: ModelProvider[]) => void
  /** Shallow-merges `data` into the provider named `providerName`. */
  updateProvider: (providerName: string, data: Partial<ModelProvider>) => void
  /** Selects a provider and (if found) one of its models; returns that model. */
  selectModelProvider: (providerName: string, modelName: string) => Model | undefined
  /** Appends a provider to the list. */
  addProvider: (provider: ModelProvider) => void
  /** Removes the provider named `providerName`. */
  deleteProvider: (providerName: string) => void
  /** Removes a model from every provider and records its id as deleted. */
  deleteModel: (modelId: string) => void
}
|
|
|
|
/**
 * Whether a stored model entry has a usable string identifier: `id`, or
 * `model` when `id` is absent. Entries failing this check are dropped when
 * providers are merged in `setProviders`.
 */
const hasModelIdentifier = (entry: Model): boolean =>
  ('id' in entry || 'model' in entry) &&
  typeof (entry.id ?? entry.model) === 'string'

/**
 * Whether `candidate`'s id, normalized the way legacy Cortex ids are (keep
 * the first two `:`-separated segments and join them with the service hub's
 * path separator), equals `modelId`. Ids without `:` compare unchanged.
 * Candidates without a string `id` never match, so malformed stored or
 * legacy entries cannot make the lookup throw.
 */
const matchesNormalizedModelId = (
  candidate: Model,
  modelId: string
): boolean =>
  typeof candidate.id === 'string' &&
  candidate.id
    .split(':')
    .slice(0, 2)
    .join(getServiceHub().path().sep()) === modelId

/**
 * Replaces `from` with `to` on a provider's `base_url` and on the
 * `controller_props.value` / `controller_props.placeholder` of its
 * `base-url` setting. Only exact matches are replaced. Mutates `provider`
 * in place (used by the persisted-state `migrate` step).
 */
const migrateBaseUrl = (
  provider: ModelProvider,
  from: string,
  to: string
): void => {
  if (provider.base_url === from) {
    provider.base_url = to
  }

  const baseUrlSetting = provider.settings?.find((s) => s.key === 'base-url')
  if (baseUrlSetting?.controller_props?.value === from) {
    baseUrlSetting.controller_props.value = to
  }
  if (baseUrlSetting?.controller_props?.placeholder === from) {
    baseUrlSetting.controller_props.placeholder = to
  }
}

/**
 * Store of model providers, their models and the current selection.
 * Persisted to localStorage under `localStorageKey.modelProvider`; `migrate`
 * upgrades state written by older schema versions (current version: 4).
 */
export const useModelProvider = create<ModelProviderState>()(
  persist(
    (set, get) => ({
      providers: [],
      selectedProvider: 'llamacpp',
      selectedModel: null,
      deletedModels: [],

      getModelBy: (modelId: string) => {
        const { providers, selectedProvider } = get()
        const provider = providers.find(
          (candidate) => candidate.provider === selectedProvider
        )
        return provider?.models.find((model) => model.id === modelId)
      },

      /**
       * Merges a freshly supplied provider list with the stored one:
       * - persisted providers (`provider.persist`) take the incoming model
       *   list, restoring each model's stored settings, capabilities and
       *   display name;
       * - other providers keep their stored models, gain incoming ones that
       *   are new and not deleted, and keep stored setting `controller_props`;
       * - stored `api_key` / `base_url` win when set, `active` is preserved
       *   (new providers start active);
       * - stored providers absent from `providers` are kept at the end.
       */
      setProviders: (providers) =>
        set((state) => {
          // Stored providers, without the legacy `llama.cpp` provider (kept
          // only for migration; can remove after a couple of releases) and
          // without model entries that lack a usable identifier.
          const existingProviders = state.providers
            .filter((e) => e.provider !== 'llama.cpp')
            .map((provider) => ({
              ...provider,
              models: (provider.models ?? []).filter(hasModelIdentifier),
            }))

          /// Cortex Migration: on the first run only, collect the legacy
          /// `llama.cpp` models so their per-model settings can be carried over.
          let legacyModels: Model[] = []
          if (
            localStorage.getItem('cortex_model_settings_migrated') !== 'true'
          ) {
            legacyModels =
              state.providers.find((e) => e.provider === 'llama.cpp')
                ?.models ?? []
            localStorage.setItem('cortex_model_settings_migrated', 'true')
          }

          // Ensure deletedModels is always an array
          const currentDeletedModels = Array.isArray(state.deletedModels)
            ? state.deletedModels
            : []

          const updatedProviders = providers.map((provider) => {
            const existingProvider = existingProviders.find(
              (x) => x.provider === provider.provider
            )
            // Already restricted to valid entries in `existingProviders`.
            const models = existingProvider?.models ?? []
            // Treat a missing incoming list as empty so `models` is never
            // stored as undefined (other actions call `models.find` directly).
            const incomingModels = provider.models ?? []

            // Non-persisted providers: incoming models that are valid, not
            // already stored and not deleted, followed by the stored models.
            const mergedModels = [
              ...incomingModels.filter(
                (e) =>
                  hasModelIdentifier(e) &&
                  !models.some((m) => m.id === e.id) &&
                  !currentDeletedModels.includes(e.id)
              ),
              ...models,
            ]

            // Persisted providers: the incoming list with stored data restored.
            // Legacy (Cortex) settings take precedence; otherwise fall back to
            // this provider's own stored settings so other providers do not
            // lose theirs during the migration run.
            const updatedModels = incomingModels.map((model) => {
              const existingModel = models.find((m) => m.id === model.id)
              const settings =
                legacyModels.find((m) =>
                  matchesNormalizedModelId(m, model.id)
                )?.settings ||
                models.find((m) => matchesNormalizedModelId(m, model.id))
                  ?.settings ||
                model.settings

              return {
                ...model,
                settings,
                capabilities:
                  existingModel?.capabilities || model.capabilities,
                displayName: existingModel?.displayName || model.displayName,
              }
            })

            return {
              ...provider,
              models: provider.persist ? updatedModels : mergedModels,
              settings: provider.settings.map((setting) => {
                // Persisted providers always use the incoming settings;
                // otherwise stored controller_props override incoming ones.
                const existingSetting = provider.persist
                  ? undefined
                  : existingProvider?.settings?.find(
                      (x) => x.key === setting.key
                    )
                return {
                  ...setting,
                  controller_props: {
                    ...setting.controller_props,
                    ...(existingSetting?.controller_props || {}),
                  },
                }
              }),
              api_key: existingProvider?.api_key || provider.api_key,
              base_url: existingProvider?.base_url || provider.base_url,
              active: existingProvider ? existingProvider.active : true,
            }
          })

          return {
            providers: [
              ...updatedProviders,
              ...existingProviders.filter(
                (e) => !updatedProviders.some((p) => p.provider === e.provider)
              ),
            ],
          }
        }),

      updateProvider: (providerName, data) => {
        set((state) => ({
          providers: state.providers.map((provider) =>
            provider.provider === providerName
              ? { ...provider, ...data }
              : provider
          ),
        }))
      },

      getProviderByName: (providerName: string) =>
        get().providers.find((provider) => provider.provider === providerName),

      selectModelProvider: (providerName: string, modelName: string) => {
        // Find the model object (undefined when provider or model is missing)
        const modelObject = get()
          .providers.find((provider) => provider.provider === providerName)
          ?.models?.find((model) => model.id === modelName)

        // Update state with provider name and model object
        set({
          selectedProvider: providerName,
          selectedModel: modelObject ?? null,
        })

        return modelObject
      },

      deleteModel: (modelId: string) => {
        set((state) => {
          // Ensure deletedModels is always an array
          const currentDeletedModels = Array.isArray(state.deletedModels)
            ? state.deletedModels
            : []

          return {
            providers: state.providers.map((provider) => ({
              ...provider,
              models: (provider.models ?? []).filter(
                (model) => model.id !== modelId
              ),
            })),
            // Record each deleted id once, even if deleted repeatedly.
            deletedModels: currentDeletedModels.includes(modelId)
              ? currentDeletedModels
              : [...currentDeletedModels, modelId],
          }
        })
      },

      addProvider: (provider: ModelProvider) => {
        set((state) => ({
          providers: [...state.providers, provider],
        }))
      },

      deleteProvider: (providerName: string) => {
        set((state) => ({
          providers: state.providers.filter(
            (provider) => provider.provider !== providerName
          ),
        }))
      },
    }),
    {
      name: localStorageKey.modelProvider,
      storage: createJSONStorage(() => localStorage),
      /**
       * Upgrades persisted state in place. Each step runs for every stored
       * version at or below its threshold, so old state passes through all
       * later steps as well.
       */
      migrate: (persistedState: unknown, version: number) => {
        const state = persistedState as ModelProviderState & {
          providers: Array<
            ModelProvider & {
              models: Array<
                Model & {
                  settings?: Record<string, unknown> & {
                    chatTemplate?: string
                    chat_template?: string
                  }
                }
              >
            }
          >
        }

        if (version <= 1 && state?.providers) {
          state.providers.forEach((provider) => {
            if (!provider.models || provider.provider !== 'llamacpp') return

            // Migrate llamacpp model settings
            provider.models.forEach((model) => {
              if (!model.settings) model.settings = {}

              // Migrate chatTemplate key to chat_template
              if (model.settings.chatTemplate) {
                model.settings.chat_template = model.settings.chatTemplate
                delete model.settings.chatTemplate
              }

              // Add missing settings as copies of the predefined defaults
              if (!model.settings.chat_template) {
                model.settings.chat_template = {
                  ...modelSettings.chatTemplate,
                  controller_props: {
                    ...modelSettings.chatTemplate.controller_props,
                  },
                }
              }

              if (!model.settings.override_tensor_buffer_t) {
                model.settings.override_tensor_buffer_t = {
                  ...modelSettings.override_tensor_buffer_t,
                  controller_props: {
                    ...modelSettings.override_tensor_buffer_t
                      .controller_props,
                  },
                }
              }

              if (!model.settings.no_kv_offload) {
                model.settings.no_kv_offload = {
                  ...modelSettings.no_kv_offload,
                  controller_props: {
                    ...modelSettings.no_kv_offload.controller_props,
                  },
                }
              }
            })
          })
        }

        // Also reached by version <= 1 state, so the cont_batching
        // description update lives here only.
        if (version <= 2 && state?.providers) {
          state.providers.forEach((provider) => {
            if (provider.provider !== 'llamacpp') return

            // Update cont_batching description for llamacpp provider
            const contBatchingSetting = provider.settings?.find(
              (s) => s.key === 'cont_batching'
            )
            if (contBatchingSetting) {
              contBatchingSetting.description =
                'Enable continuous batching (a.k.a dynamic batching) for concurrent requests.'
            }

            // Add a default batch_size to models that lack one
            provider.models?.forEach((model) => {
              if (!model.settings) model.settings = {}

              if (!model.settings.batch_size) {
                model.settings.batch_size = {
                  ...modelSettings.batch_size,
                  controller_props: {
                    ...modelSettings.batch_size.controller_props,
                  },
                }
              }
            })
          })
        }

        if (version <= 3 && state?.providers) {
          state.providers.forEach((provider) => {
            // Migrate Anthropic provider base URL and add custom headers
            if (provider.provider === 'anthropic') {
              migrateBaseUrl(
                provider,
                'https://api.anthropic.com',
                'https://api.anthropic.com/v1'
              )

              if (!provider.custom_header) {
                provider.custom_header = [
                  {
                    header: 'anthropic-version',
                    value: '2023-06-01',
                  },
                  {
                    header: 'anthropic-dangerous-direct-browser-access',
                    value: 'true',
                  },
                ]
              }
            }

            if (provider.provider === 'cohere') {
              migrateBaseUrl(
                provider,
                'https://api.cohere.ai/compatibility/v1',
                'https://api.cohere.ai/v1'
              )
            }
          })
        }

        return state
      },
      version: 4,
    }
  )
)
|