jan/web-app/src/hooks/useModelProvider.ts
Akarshan Biswas 706dad2687
feat: Add support for llamacpp MoE offloading setting (#6748)
* feat: Add support for llamacpp MoE offloading setting

Introduces the n_cpu_moe configuration setting for the llamacpp provider. This allows users to specify the number of Mixture of Experts (MoE) layers whose weights should be offloaded to the CPU via the --n-cpu-moe flag in llama.cpp.

This is useful for running large MoE models by balancing resource usage, for example, by keeping attention on the GPU and offloading expert FFNs to the CPU.

The changes include:

 - Updating the llamacpp-extension to accept and pass the --n-cpu-moe argument.

 - Adding the input field to the Model Settings UI (ModelSetting.tsx).

 - Including model setting migration logic and bumping the store version to 4.

* remove unused import

* feat: add cpu-moe boolean flag

* chore: remove unused migration cont_batching

* chore: fix migration delete old key and add new one

* chore: fix migration

---------

Co-authored-by: Faisal Amir <urmauur@gmail.com>
2025-10-07 19:37:58 +05:30

432 lines
15 KiB
TypeScript

import { create } from 'zustand'
import { persist, createJSONStorage } from 'zustand/middleware'
import { localStorageKey } from '@/constants/localStorage'
import { getServiceHub } from '@/hooks/useServiceHub'
import { modelSettings } from '@/lib/predefined'
/**
 * State and actions exposed by the `useModelProvider` store.
 *
 * The first group is plain data that gets persisted; the second group is the
 * set of actions that read or replace that data.
 */
type ModelProviderState = {
  // ---- data ----

  /** Every provider currently known to the store, each carrying its models. */
  providers: ModelProvider[]
  /** `provider` name of the provider currently in use. */
  selectedProvider: string
  /** Model currently in use, or `null` when none is selected. */
  selectedModel: Model | null
  /** Ids recorded by `deleteModel`; `setProviders` skips them while merging. */
  deletedModels: string[]

  // ---- lookups ----

  /** Finds a model by id inside the currently selected provider. */
  getModelBy: (modelId: string) => Model | undefined
  /** Finds a provider by its `provider` name. */
  getProviderByName: (providerName: string) => ModelProvider | undefined

  // ---- mutations ----

  /** Merges an incoming provider list with what the store already holds. */
  setProviders: (providers: ModelProvider[]) => void
  /** Shallow-merges `data` into the provider with the given name. */
  updateProvider: (providerName: string, data: Partial<ModelProvider>) => void
  /** Selects a provider/model pair and returns the resolved model, if any. */
  selectModelProvider: (
    providerName: string,
    modelName: string
  ) => Model | undefined
  /** Appends a provider to the list. */
  addProvider: (provider: ModelProvider) => void
  /** Removes every provider with the given name. */
  deleteProvider: (providerName: string) => void
  /** Removes a model id from all providers and records it as deleted. */
  deleteModel: (modelId: string) => void
}
export const useModelProvider = create<ModelProviderState>()(
persist(
(set, get) => ({
providers: [],
selectedProvider: 'llamacpp',
selectedModel: null,
deletedModels: [],
// Resolve a model id against the currently selected provider only; models
// that belong to other providers are never returned.
getModelBy: (modelId: string) => {
  const { providers, selectedProvider } = get()
  const activeProvider = providers.find(
    (candidate) => candidate.provider === selectedProvider
  )
  return activeProvider
    ? activeProvider.models.find((candidate) => candidate.id === modelId)
    : undefined
},
/**
 * Merge an incoming provider list with the providers already in the store.
 *
 * For each incoming provider:
 *  - `persist` providers take the incoming model list, enriched with the
 *    stored settings / capabilities / displayName of matching models;
 *  - other providers keep their stored models and gain any incoming model
 *    that is neither stored already nor listed in `deletedModels`;
 *  - stored `controller_props` override the incoming ones (non-persist only),
 *    and stored `api_key` / `base_url` / `active` win over incoming values.
 *
 * Stored providers that are absent from the incoming list are kept, except
 * the legacy `llama.cpp` provider, which is dropped.
 */
setProviders: (providers) =>
  set((state) => {
    // A model entry is usable only when it carries a string `id` (or the
    // older `model`) key; anything else is treated as corrupt and dropped.
    const hasStringKey = (e: Model) =>
      ('id' in e || 'model' in e) && typeof (e.id ?? e.model) === 'string'

    // The path separator does not change during one merge, so resolve it at
    // most once, and only if some stored model actually needs comparing.
    let cachedSep: string | undefined
    const pathSep = () => {
      if (cachedSep === undefined) cachedSep = getServiceHub().path().sep()
      return cachedSep
    }

    const existingProviders = state.providers
      // Filter out legacy llama.cpp provider for migration
      // Can remove after a couple of releases
      .filter((e) => e.provider !== 'llama.cpp')
      .map((provider) => ({
        ...provider,
        // Tolerate stored providers that have no model list at all.
        models: (provider.models ?? []).filter(hasStringKey),
      }))

    let legacyModels: Model[] | undefined = []
    /// Cortex Migration: runs once, guarded by a localStorage flag.
    if (localStorage.getItem('cortex_model_settings_migrated') !== 'true') {
      legacyModels = state.providers.find(
        (e) => e.provider === 'llama.cpp'
      )?.models
      localStorage.setItem('cortex_model_settings_migrated', 'true')
    }

    // Ensure deletedModels is always an array
    const currentDeletedModels = Array.isArray(state.deletedModels)
      ? state.deletedModels
      : []

    const updatedProviders = providers.map((provider) => {
      const existingProvider = existingProviders.find(
        (x) => x.provider === provider.provider
      )
      const models = (existingProvider?.models || []).filter(hasStringKey)

      // Incoming models first (minus duplicates and deleted ids), then the
      // stored ones.
      const mergedModels = [
        ...(provider?.models ?? []).filter(
          (e) =>
            hasStringKey(e) &&
            !models.some((m) => m.id === e.id) &&
            !currentDeletedModels.includes(e.id)
        ),
        ...models,
      ]

      // Legacy models, when present, are the preferred source of settings.
      const settingsSource =
        legacyModels && legacyModels.length > 0 ? legacyModels : models

      // `?? []` keeps `models` an array even when the incoming provider
      // carries no model list; storing `undefined` would break later reads.
      const updatedModels = (provider.models ?? []).map((model) => {
        const settings =
          settingsSource.find(
            (m) =>
              // Entries admitted via the `model` key have no `id` to split.
              typeof m.id === 'string' &&
              // Stored ids use ':' where the incoming id uses the platform
              // path separator; only the first two segments are compared.
              m.id.split(':').slice(0, 2).join(pathSep()) === model.id
          )?.settings || model.settings
        const existingModel = models.find((m) => m.id === model.id)
        return {
          ...model,
          settings: settings,
          capabilities: existingModel?.capabilities || model.capabilities,
          displayName: existingModel?.displayName || model.displayName,
        }
      })

      return {
        ...provider,
        models: provider.persist ? updatedModels : mergedModels,
        settings: provider.settings.map((setting) => {
          // Persisted providers always take the incoming controller props.
          const existingSetting = provider.persist
            ? undefined
            : existingProvider?.settings?.find((x) => x.key === setting.key)
          return {
            ...setting,
            controller_props: {
              ...setting.controller_props,
              ...(existingSetting?.controller_props || {}),
            },
          }
        }),
        api_key: existingProvider?.api_key || provider.api_key,
        base_url: existingProvider?.base_url || provider.base_url,
        // Providers seen for the first time start out active.
        active: existingProvider ? existingProvider?.active : true,
      }
    })

    return {
      providers: [
        ...updatedProviders,
        ...existingProviders.filter(
          (e) => !updatedProviders.some((p) => p.provider === e.provider)
        ),
      ],
    }
  }),
// Shallow-merge `data` over every provider whose name matches; all other
// providers are passed through untouched (same object reference).
updateProvider: (providerName, data) => {
  set(({ providers }) => ({
    providers: providers.map((entry) =>
      entry.provider === providerName ? { ...entry, ...data } : entry
    ),
  }))
},
// Return the first provider whose `provider` field equals `providerName`,
// or undefined when there is no such provider.
getProviderByName: (providerName: string) =>
  get().providers.find((entry) => entry.provider === providerName),
// Make `providerName` the selected provider and, when that provider has a
// model whose id equals `modelName`, make it the selected model. The
// selection is written even if nothing matches (selectedModel becomes null).
selectModelProvider: (providerName: string, modelName: string) => {
  const owner = get().providers.find(
    (entry) => entry.provider === providerName
  )
  const modelObject: Model | undefined = owner?.models
    ? owner.models.find((entry) => entry.id === modelName)
    : undefined

  set({
    selectedProvider: providerName,
    selectedModel: modelObject ?? null,
  })
  return modelObject
},
/**
 * Remove the model with `modelId` from every provider and remember the id
 * in `deletedModels`, which `setProviders` consults when merging models.
 *
 * The id is recorded at most once: deleting an id that is already recorded
 * leaves `deletedModels` unchanged, so the persisted list does not grow with
 * duplicates.
 */
deleteModel: (modelId: string) => {
  set((state) => {
    // Ensure deletedModels is always an array
    const currentDeletedModels = Array.isArray(state.deletedModels)
      ? state.deletedModels
      : []

    return {
      providers: state.providers.map((provider) => ({
        ...provider,
        models: provider.models.filter((model) => model.id !== modelId),
      })),
      deletedModels: currentDeletedModels.includes(modelId)
        ? currentDeletedModels
        : [...currentDeletedModels, modelId],
    }
  })
},
// Append `provider` to the end of the list; no check is made for an existing
// provider with the same name.
addProvider: (provider: ModelProvider) => {
  set(({ providers }) => ({ providers: providers.concat([provider]) }))
},
// Drop every provider whose `provider` field equals `providerName`.
deleteProvider: (providerName: string) => {
  set(({ providers }) => ({
    providers: providers.filter((entry) => entry.provider !== providerName),
  }))
},
}),
{
name: localStorageKey.modelProvider,
storage: createJSONStorage(() => localStorage),
/**
 * Upgrade a persisted store snapshot in place and return it.
 *
 * Steps are cumulative: a snapshot saved at `version` runs every step whose
 * threshold is >= that version, in ascending order.
 *  - <= 1: refresh the llamacpp `cont_batching` description, rename the
 *          `chatTemplate` model setting to `chat_template`, and add default
 *          `chat_template`, `override_tensor_buffer_t`, `no_kv_offload`.
 *  - <= 2: refresh the `cont_batching` description again and add a default
 *          `batch_size` model setting.
 *  - <= 3: rewrite the old Anthropic and Cohere base URLs and seed the
 *          Anthropic custom headers when none are stored.
 *  - <= 4: add default `cpu_moe` and `n_cpu_moe` model settings.
 */
migrate: (persistedState: unknown, version: number) => {
  type PersistedModel = Model & {
    settings?: Record<string, unknown> & {
      chatTemplate?: string
      chat_template?: string
    }
  }
  type PersistedProvider = ModelProvider & { models: PersistedModel[] }

  const state = persistedState as ModelProviderState & {
    providers: PersistedProvider[]
  }

  // Copy a predefined setting so that stored models never share the
  // `controller_props` object with the predefined template.
  const cloneSetting = <T extends { controller_props?: object }>(
    template: T
  ) => ({
    ...template,
    controller_props: { ...template.controller_props },
  })

  // Return the model's settings bag, creating an empty one when absent.
  const settingsOf = (model: PersistedModel) => {
    if (!model.settings) model.settings = {}
    return model.settings
  }

  // Fill `key` from `template` only when the model has no value for it yet.
  const ensureSetting = <T extends { controller_props?: object }>(
    settings: Record<string, unknown>,
    key: string,
    template: T
  ) => {
    if (!settings[key]) settings[key] = cloneSetting(template)
  }

  // Models that receive llamacpp-specific setting migrations.
  const llamacppModels = (provider: PersistedProvider): PersistedModel[] =>
    provider.models && provider.provider === 'llamacpp' ? provider.models : []

  // Update cont_batching description for llamacpp provider
  const refreshContBatchingDescription = (provider: PersistedProvider) => {
    if (provider.provider !== 'llamacpp' || !provider.settings) return
    const contBatching = provider.settings.find(
      (s) => s.key === 'cont_batching'
    )
    if (contBatching) {
      contBatching.description =
        'Enable continuous batching (a.k.a dynamic batching) for concurrent requests.'
    }
  }

  // Swap `from` for `to` in the provider's base_url and in the value and
  // placeholder of its `base-url` setting; anything else is left untouched.
  const rewriteBaseUrl = (
    provider: PersistedProvider,
    from: string,
    to: string
  ) => {
    if (provider.base_url === from) provider.base_url = to
    if (!provider.settings) return
    const baseUrlSetting = provider.settings.find((s) => s.key === 'base-url')
    if (baseUrlSetting?.controller_props?.value === from) {
      baseUrlSetting.controller_props.value = to
    }
    if (baseUrlSetting?.controller_props?.placeholder === from) {
      baseUrlSetting.controller_props.placeholder = to
    }
  }

  if (version <= 1 && state?.providers) {
    for (const provider of state.providers) {
      refreshContBatchingDescription(provider)
      for (const model of llamacppModels(provider)) {
        const settings = settingsOf(model)
        // Migrate chatTemplate key to chat_template
        if (settings.chatTemplate) {
          settings.chat_template = settings.chatTemplate
          delete settings.chatTemplate
        }
        // Add missing settings with defaults
        ensureSetting(settings, 'chat_template', modelSettings.chatTemplate)
        ensureSetting(
          settings,
          'override_tensor_buffer_t',
          modelSettings.override_tensor_buffer_t
        )
        ensureSetting(settings, 'no_kv_offload', modelSettings.no_kv_offload)
      }
    }
  }

  if (version <= 2 && state?.providers) {
    for (const provider of state.providers) {
      refreshContBatchingDescription(provider)
      for (const model of llamacppModels(provider)) {
        ensureSetting(settingsOf(model), 'batch_size', modelSettings.batch_size)
      }
    }
  }

  if (version <= 3 && state?.providers) {
    for (const provider of state.providers) {
      // Migrate Anthropic provider base URL and add custom headers
      if (provider.provider === 'anthropic') {
        rewriteBaseUrl(
          provider,
          'https://api.anthropic.com',
          'https://api.anthropic.com/v1'
        )
        if (!provider.custom_header) {
          provider.custom_header = [
            { header: 'anthropic-version', value: '2023-06-01' },
            {
              header: 'anthropic-dangerous-direct-browser-access',
              value: 'true',
            },
          ]
        }
      }
      if (provider.provider === 'cohere') {
        rewriteBaseUrl(
          provider,
          'https://api.cohere.ai/compatibility/v1',
          'https://api.cohere.ai/v1'
        )
      }
    }
  }

  if (version <= 4 && state?.providers) {
    for (const provider of state.providers) {
      for (const model of llamacppModels(provider)) {
        const settings = settingsOf(model)
        ensureSetting(settings, 'cpu_moe', modelSettings.cpu_moe)
        ensureSetting(settings, 'n_cpu_moe', modelSettings.n_cpu_moe)
      }
    }
  }

  return state
},
version: 5,
}
)
)