import { create } from 'zustand'
import { persist, createJSONStorage } from 'zustand/middleware'
import { localStorageKey } from '@/constants/localStorage'
import { sep } from '@tauri-apps/api/path'
import { modelSettings } from '@/lib/predefined'

/**
 * State and actions of the model-provider store.
 *
 * The data fields hold the provider list, the current provider/model
 * selection, and the ids of models the user deleted. The remaining members
 * are the actions implemented by `useModelProvider` below.
 */
type ModelProviderState = {
  providers: ModelProvider[]
  selectedProvider: string
  selectedModel: Model | null
  deletedModels: string[]
  getModelBy: (modelId: string) => Model | undefined
  setProviders: (providers: ModelProvider[]) => void
  getProviderByName: (providerName: string) => ModelProvider | undefined
  updateProvider: (providerName: string, data: Partial<ModelProvider>) => void
  selectModelProvider: (
    providerName: string,
    modelName: string
  ) => Model | undefined
  addProvider: (provider: ModelProvider) => void
  deleteProvider: (providerName: string) => void
  deleteModel: (modelId: string) => void
}

/**
 * Returns true when a model entry carries an `id` or `model` key and the
 * first non-nullish of the two is a string. Used to drop malformed entries
 * before provider lists are merged.
 */
const hasStringModelId = (model: Model): boolean =>
  ('id' in model || 'model' in model) &&
  typeof (model.id ?? model.model) === 'string'

/**
 * Zustand store for model providers, persisted to `localStorage` under
 * `localStorageKey.modelProvider` (persist version 1).
 */
export const useModelProvider = create<ModelProviderState>()(
  persist(
    (set, get) => ({
      providers: [],
      selectedProvider: 'llamacpp',
      selectedModel: null,
      deletedModels: [],

      /**
       * Finds a model by id inside the currently selected provider only.
       * Returns undefined when the selected provider is not in the list or
       * has no model with that id.
       */
      getModelBy: (modelId: string) => {
        const { providers, selectedProvider } = get()
        const provider = providers.find(
          (candidate) => candidate.provider === selectedProvider
        )
        if (!provider) return undefined
        return provider.models.find((model) => model.id === modelId)
      },

      /**
       * Merges an incoming provider list with the providers already in the
       * store.
       *
       * For each incoming provider, the stored `api_key`, `base_url`,
       * `active` flag and model capabilities win over the incoming values
       * when present. Stored providers that are absent from the incoming
       * list are kept and appended after the incoming ones.
       */
      setProviders: (providers) =>
        set((state) => {
          const existingProviders = state.providers
            // Filter out legacy llama.cpp provider for migration
            // Can remove after a couple of releases
            .filter((e) => e.provider !== 'llama.cpp')
            .map((provider) => ({
              ...provider,
              models: provider.models.filter(hasStringModelId),
            }))

          let legacyModels: Model[] | undefined = []
          // Cortex migration: the first time this runs (flag not yet set in
          // localStorage), take the legacy 'llama.cpp' provider's models so
          // their settings can be carried over below, then set the flag.
          if (
            localStorage.getItem('cortex_model_settings_migrated') !== 'true'
          ) {
            legacyModels = state.providers.find(
              (e) => e.provider === 'llama.cpp'
            )?.models
            localStorage.setItem('cortex_model_settings_migrated', 'true')
          }

          // Ensure deletedModels is always an array
          const currentDeletedModels = Array.isArray(state.deletedModels)
            ? state.deletedModels
            : []

          const updatedProviders = providers.map((provider) => {
            const existingProvider = existingProviders.find(
              (x) => x.provider === provider.provider
            )
            const models = (existingProvider?.models || []).filter(
              hasStringModelId
            )

            // Incoming models that are well-formed, not already stored and
            // not deleted by the user, followed by the stored models.
            const mergedModels = [
              ...(provider?.models ?? []).filter(
                (e) =>
                  hasStringModelId(e) &&
                  !models.some((m) => m.id === e.id) &&
                  !currentDeletedModels.includes(e.id)
              ),
              ...models,
            ]

            // Settings are looked up in the legacy models when any were
            // collected above, otherwise in the stored models.
            const settingsSource =
              legacyModels && legacyModels.length > 0 ? legacyModels : models

            const updatedModels = provider.models?.map((model) => {
              // A source id matches when its first two ':'-separated parts,
              // joined with the platform path separator, equal the incoming id.
              const settings =
                settingsSource.find(
                  (m) => m.id.split(':').slice(0, 2).join(sep()) === model.id
                )?.settings || model.settings
              const existingModel = models.find((m) => m.id === model.id)
              return {
                ...model,
                settings: settings,
                capabilities: existingModel?.capabilities || model.capabilities,
              }
            })

            return {
              ...provider,
              // Providers flagged `persist` use the incoming model list (with
              // carried-over settings/capabilities); others use the merge.
              models: provider.persist ? updatedModels : mergedModels,
              settings: provider.settings.map((setting) => {
                // Stored controller_props override the incoming ones, except
                // for `persist` providers, which ignore stored settings.
                const existingSetting = provider.persist
                  ? undefined
                  : existingProvider?.settings?.find(
                      (x) => x.key === setting.key
                    )
                return {
                  ...setting,
                  controller_props: {
                    ...setting.controller_props,
                    ...(existingSetting?.controller_props || {}),
                  },
                }
              }),
              api_key: existingProvider?.api_key || provider.api_key,
              base_url: existingProvider?.base_url || provider.base_url,
              // Providers not seen before default to active.
              active: existingProvider ? existingProvider.active : true,
            }
          })

          return {
            providers: [
              ...updatedProviders,
              ...existingProviders.filter(
                (e) => !updatedProviders.some((p) => p.provider === e.provider)
              ),
            ],
          }
        }),

      /**
       * Shallow-merges `data` into every provider whose name equals
       * `providerName`; other providers are left untouched.
       */
      updateProvider: (providerName, data) => {
        set((state) => ({
          providers: state.providers.map((provider) =>
            provider.provider === providerName
              ? { ...provider, ...data }
              : provider
          ),
        }))
      },

      /** Returns the provider with the given name, or undefined. */
      getProviderByName: (providerName: string) =>
        get().providers.find(
          (provider) => provider.provider === providerName
        ),

      /**
       * Selects a provider and one of its models.
       *
       * `selectedProvider` is always set to `providerName`; `selectedModel`
       * becomes the matching model or null when the provider or model is not
       * found. Returns the matching model, or undefined.
       */
      selectModelProvider: (providerName: string, modelName: string) => {
        const provider = get().providers.find(
          (candidate) => candidate.provider === providerName
        )
        const modelObject: Model | undefined = provider?.models
          ? provider.models.find((model) => model.id === modelName)
          : undefined

        set({
          selectedProvider: providerName,
          selectedModel: modelObject || null,
        })
        return modelObject
      },

      /**
       * Removes the model with the given id from every provider and records
       * the id in `deletedModels`, which `setProviders` consults when merging
       * incoming models.
       */
      deleteModel: (modelId: string) => {
        set((state) => {
          // Ensure deletedModels is always an array
          const currentDeletedModels = Array.isArray(state.deletedModels)
            ? state.deletedModels
            : []
          return {
            providers: state.providers.map((provider) => ({
              ...provider,
              models: provider.models.filter((model) => model.id !== modelId),
            })),
            // Record each id once so repeated deletions do not grow the
            // persisted list with duplicates.
            deletedModels: currentDeletedModels.includes(modelId)
              ? currentDeletedModels
              : [...currentDeletedModels, modelId],
          }
        })
      },

      /** Appends a provider to the end of the provider list. */
      addProvider: (provider: ModelProvider) => {
        set((state) => ({
          providers: [...state.providers, provider],
        }))
      },

      /** Removes every provider whose name equals `providerName`. */
      deleteProvider: (providerName: string) => {
        set((state) => ({
          providers: state.providers.filter(
            (provider) => provider.provider !== providerName
          ),
        }))
      },
    }),
    {
      name: localStorageKey.modelProvider,
      storage: createJSONStorage(() => localStorage),
      /**
       * Upgrades persisted state in place and returns it.
       *
       * Only state stored at version 0 is touched. For the 'llamacpp'
       * provider it rewrites the `cont_batching` setting description, renames
       * the model setting `chatTemplate` to `chat_template`, and fills in
       * defaults for `chat_template` and `override_tensor_buffer_t` when
       * they are missing.
       */
      migrate: (persistedState: unknown, version: number) => {
        const state = persistedState as ModelProviderState & {
          providers: Array<
            ModelProvider & {
              models: Array<
                Model & {
                  settings?: Record<string, unknown> & {
                    chatTemplate?: string
                    chat_template?: string
                  }
                }
              >
            }
          >
        }

        if (version === 0 && state?.providers) {
          state.providers.forEach((provider) => {
            // Update cont_batching description for llamacpp provider
            if (provider.provider === 'llamacpp' && provider.settings) {
              const contBatchingSetting = provider.settings.find(
                (s) => s.key === 'cont_batching'
              )
              if (contBatchingSetting) {
                contBatchingSetting.description =
                  'Enable continuous batching (a.k.a dynamic batching) for concurrent requests.'
              }
            }

            // Migrate model settings
            if (provider.models && provider.provider === 'llamacpp') {
              provider.models.forEach((model) => {
                if (!model.settings) model.settings = {}

                // Migrate chatTemplate key to chat_template
                if (model.settings.chatTemplate) {
                  model.settings.chat_template = model.settings.chatTemplate
                  delete model.settings.chatTemplate
                }

                // Add missing settings with defaults; controller_props is
                // copied so models do not share the default object.
                if (!model.settings.chat_template) {
                  model.settings.chat_template = {
                    ...modelSettings.chatTemplate,
                    controller_props: {
                      ...modelSettings.chatTemplate.controller_props,
                    },
                  }
                }
                if (!model.settings.override_tensor_buffer_t) {
                  model.settings.override_tensor_buffer_t = {
                    ...modelSettings.override_tensor_buffer_t,
                    controller_props: {
                      ...modelSettings.override_tensor_buffer_t
                        .controller_props,
                    },
                  }
                }
              })
            }
          })
        }

        return state
      },
      version: 1,
    }
  )
)