jan/web-app/src/hooks/useModelProvider.ts
Akarshan Biswas 1f1605bdf9
feat: Add support for overriding tensor buffer type (#6062)
* feat: Add support for overriding tensor buffer type

This commit introduces a new configuration option, `override_tensor_buffer_t`, which allows users to specify a regex for matching tensor names to override their buffer type. This is an advanced setting primarily useful for optimizing the performance of large models, particularly Mixture of Experts (MoE) models.

By overriding the tensor buffer type, users can keep critical parts of the model, like the attention layers, on the GPU while offloading other parts, such as the expert feed-forward networks, to the CPU. This can lead to significant speed improvements for massive models.

Additionally, this change refines the error message to be more specific when a model fails to load. The previous message "Failed to load llama-server" has been updated to "Failed to load model" to be more accurate.

* chore: update FE to support override-tensor

---------

Co-authored-by: Faisal Amir <urmauur@gmail.com>
2025-08-07 10:31:34 +05:30

310 lines
10 KiB
TypeScript

import { create } from 'zustand'
import { persist, createJSONStorage } from 'zustand/middleware'
import { localStorageKey } from '@/constants/localStorage'
import { sep } from '@tauri-apps/api/path'
import { modelSettings } from '@/lib/predefined'
/**
 * Shape of the model-provider store: the provider list, the current
 * selection, and the actions that read or change them.
 */
type ModelProviderState = {
  // ----- state -----

  /** Every provider currently held by the store. */
  providers: ModelProvider[]
  /** `provider` identifier of the currently selected provider. */
  selectedProvider: string
  /** Currently selected model, or `null` when nothing is selected. */
  selectedModel: Model | null
  /** Ids recorded by `deleteModel`; consulted by `setProviders` when merging models. */
  deletedModels: string[]

  // ----- lookups -----

  /** Finds a model by id among the models of the currently selected provider. */
  getModelBy: (modelId: string) => Model | undefined
  /** Finds a provider by its `provider` identifier. */
  getProviderByName: (providerName: string) => ModelProvider | undefined

  // ----- selection -----

  /**
   * Stores the given provider as selected together with the matching model
   * (or `null`), and returns that model if it was found.
   */
  selectModelProvider: (providerName: string, modelName: string) => Model | undefined

  // ----- mutations -----

  /** Merges an incoming provider list into the stored providers. */
  setProviders: (providers: ModelProvider[]) => void
  /** Appends a provider to the list. */
  addProvider: (provider: ModelProvider) => void
  /** Shallow-merges `data` into the provider identified by `providerName`. */
  updateProvider: (providerName: string, data: Partial<ModelProvider>) => void
  /** Removes the provider identified by `providerName`. */
  deleteProvider: (providerName: string) => void
  /** Removes the model with this id from every provider and records the id as deleted. */
  deleteModel: (modelId: string) => void
}
export const useModelProvider = create<ModelProviderState>()(
persist(
(set, get) => ({
providers: [],
selectedProvider: 'llamacpp',
selectedModel: null,
deletedModels: [],
// Look up a model by id, searching only the currently selected provider.
// Yields undefined when that provider or the model is not found.
getModelBy: (modelId: string) => {
  const { providers, selectedProvider } = get()
  const activeProvider = providers.find(
    (candidate) => candidate.provider === selectedProvider
  )
  return activeProvider?.models.find((entry) => entry.id === modelId)
},
// Merge an incoming provider list into the stored one.
// For each incoming provider the stored counterpart (matched on `provider`)
// contributes its models, setting overrides, api_key, base_url and active
// flag; stored providers absent from the incoming list are kept at the end.
setProviders: (providers) =>
set((state) => {
const existingProviders = state.providers
// Filter out legacy llama.cpp provider for migration
// Can remove after a couple of releases
.filter((e) => e.provider !== 'llama.cpp')
.map((provider) => {
return {
...provider,
// Keep only models that carry an `id` or `model` key holding a string.
models: provider.models.filter(
(e) =>
('id' in e || 'model' in e) &&
typeof (e.id ?? e.model) === 'string'
),
}
})
// Models of the legacy 'llama.cpp' provider; stays empty unless the
// one-time migration below runs, and becomes undefined when it runs but
// no such provider is stored.
let legacyModels: Model[] | undefined = []
// Cortex migration: runs while the localStorage flag is not 'true',
// then sets the flag so later calls skip it.
if (
localStorage.getItem('cortex_model_settings_migrated') !== 'true'
) {
legacyModels = state.providers.find(
(e) => e.provider === 'llama.cpp'
)?.models
localStorage.setItem('cortex_model_settings_migrated', 'true')
}
// Ensure deletedModels is always an array
const currentDeletedModels = Array.isArray(state.deletedModels)
? state.deletedModels
: []
const updatedProviders = providers.map((provider) => {
// Stored provider with the same identifier, if any (legacy one excluded).
const existingProvider = existingProviders.find(
(x) => x.provider === provider.provider
)
// Stored models of that provider, filtered with the same id/model check.
const models = (existingProvider?.models || []).filter(
(e) =>
('id' in e || 'model' in e) &&
typeof (e.id ?? e.model) === 'string'
)
// Incoming models that are neither already stored nor recorded as
// deleted, followed by all stored models.
const mergedModels = [
...(provider?.models ?? []).filter(
(e) =>
('id' in e || 'model' in e) &&
typeof (e.id ?? e.model) === 'string' &&
!models.some((m) => m.id === e.id) &&
!currentDeletedModels.includes(e.id)
),
...models,
]
// Incoming models with settings/capabilities carried over from storage.
const updatedModels = provider.models?.map((model) => {
// Settings come from the legacy models when any were collected,
// otherwise from the stored models. A candidate matches when the first
// two ':'-separated segments of its id, joined with sep(), equal the
// incoming model id; without a match (or with falsy settings) the
// incoming settings are used.
const settings =
(legacyModels && legacyModels?.length > 0
? legacyModels
: models
).find(
(m) => m.id.split(':').slice(0, 2).join(sep()) === model.id
)?.settings || model.settings
const existingModel = models.find((m) => m.id === model.id)
return {
...model,
settings: settings,
// Stored capabilities win when truthy.
capabilities: existingModel?.capabilities || model.capabilities,
}
})
return {
...provider,
// Providers flagged `persist` use the incoming model list (with
// carried-over settings); all others use the merged list.
models: provider.persist ? updatedModels : mergedModels,
settings: provider.settings.map((setting) => {
// Stored setting with the same key; never looked up for `persist`
// providers, so their incoming controller_props are kept as is.
const existingSetting = provider.persist
? undefined
: existingProvider?.settings?.find(
(x) => x.key === setting.key
)
return {
...setting,
// Stored controller_props override the incoming ones key by key.
controller_props: {
...setting.controller_props,
...(existingSetting?.controller_props || {}),
},
}
}),
// Stored api_key / base_url win when truthy.
api_key: existingProvider?.api_key || provider.api_key,
base_url: existingProvider?.base_url || provider.base_url,
// Providers seen for the first time start out active.
active: existingProvider ? existingProvider?.active : true,
}
})
return {
providers: [
...updatedProviders,
// Keep stored providers that the incoming list does not mention.
...existingProviders.filter(
(e) => !updatedProviders.some((p) => p.provider === e.provider)
),
],
}
}),
// Shallow-merge `data` into every provider whose `provider` field equals
// `providerName`; all other providers are passed through untouched.
updateProvider: (providerName, data) => {
  set((state) => {
    const providers = state.providers.map((entry) =>
      entry.provider === providerName ? { ...entry, ...data } : entry
    )
    return { providers }
  })
},
// Return the first provider whose `provider` field equals `providerName`,
// or undefined when there is none.
getProviderByName: (providerName: string) =>
  get().providers.find((entry) => entry.provider === providerName),
// Record `providerName` as the selected provider and the model with id
// `modelName` (looked up inside that provider) as the selected model.
// Returns the model when found; otherwise the selection is stored with a
// null model and undefined is returned.
selectModelProvider: (providerName: string, modelName: string) => {
  const owner = get().providers.find(
    (entry) => entry.provider === providerName
  )
  const match: Model | undefined = owner?.models
    ? owner.models.find((entry) => entry.id === modelName)
    : undefined

  set({ selectedProvider: providerName, selectedModel: match || null })
  return match
},
// Remove the model with id `modelId` from every provider and remember the id
// in `deletedModels`.
// The id is recorded at most once: deleting the same id repeatedly must not
// grow the persisted `deletedModels` list with duplicates.
deleteModel: (modelId: string) => {
  set((state) => {
    // Ensure deletedModels is always an array (persisted state may lack it)
    const currentDeletedModels = Array.isArray(state.deletedModels)
      ? state.deletedModels
      : []

    return {
      providers: state.providers.map((provider) => ({
        ...provider,
        models: provider.models.filter((model) => model.id !== modelId),
      })),
      deletedModels: currentDeletedModels.includes(modelId)
        ? currentDeletedModels
        : [...currentDeletedModels, modelId],
    }
  })
},
// Append `provider` to the end of the provider list.
addProvider: (provider: ModelProvider) => {
  set((state) => {
    const providers = [...state.providers, provider]
    return { providers }
  })
},
// Drop every provider whose `provider` field equals `providerName`.
deleteProvider: (providerName: string) => {
  set((state) => {
    const remaining = state.providers.filter(
      (entry) => entry.provider !== providerName
    )
    return { providers: remaining }
  })
},
}),
{
name: localStorageKey.modelProvider,
storage: createJSONStorage(() => localStorage),
// Upgrade state persisted under an older store version to the current shape.
//
// `persistedState` is the stored state and `version` the version it was
// saved with. The persist middleware calls this once and treats the result
// as current, so the steps below are cumulative: every step newer than
// `version` runs, in order. (Guarding each step with an exact `version === n`
// check would let state saved at version 0 or 1 skip the later steps.)
// Returns the same state object, mutated in place.
migrate: (persistedState: unknown, version: number) => {
  const state = persistedState as ModelProviderState & {
    providers: Array<
      ModelProvider & {
        models: Array<
          Model & {
            settings?: Record<string, unknown> & {
              chatTemplate?: string
              chat_template?: string
            }
          }
        >
      }
    >
  }

  // Migration for cont_batching description update (version 0 -> 1)
  if (version < 1 && state?.providers) {
    state.providers = state.providers.map((provider) => {
      if (provider.provider === 'llamacpp' && provider.settings) {
        provider.settings = provider.settings.map((setting) => {
          if (setting.key === 'cont_batching') {
            return {
              ...setting,
              description:
                'Enable continuous batching (a.k.a dynamic batching) for concurrent requests.',
            }
          }
          return setting
        })
      }
      return provider
    })
  }

  // Migration for chatTemplate key to chat_template (version 1 -> 2)
  if (version < 2 && state?.providers) {
    state.providers.forEach((provider) => {
      if (provider.models) {
        provider.models.forEach((model) => {
          // Initialize settings if it doesn't exist
          if (!model.settings) {
            model.settings = {}
          }

          // Migrate chatTemplate key to chat_template
          if (model.settings.chatTemplate) {
            model.settings.chat_template = model.settings.chatTemplate
            delete model.settings.chatTemplate
          }

          // Add missing chat_template setting if it doesn't exist
          if (!model.settings.chat_template) {
            model.settings.chat_template = {
              ...modelSettings.chatTemplate,
              controller_props: {
                ...modelSettings.chatTemplate.controller_props,
              },
            }
          }
        })
      }
    })
  }

  // Migration for override_tensor_buffer_t key (version 2 -> 3)
  if (version < 3 && state?.providers) {
    state.providers.forEach((provider) => {
      if (provider.models) {
        provider.models.forEach((model) => {
          // Initialize settings if it doesn't exist
          if (!model.settings) {
            model.settings = {}
          }

          // Add missing override_tensor_buffer_t setting if it doesn't exist
          if (!model.settings.override_tensor_buffer_t) {
            model.settings.override_tensor_buffer_t = {
              ...modelSettings.override_tensor_buffer_t,
              controller_props: {
                ...modelSettings.override_tensor_buffer_t
                  .controller_props,
              },
            }
          }
        })
      }
    })
  }

  return state
},
version: 3,
}
)
)