From 1d443e1f7d24a8cbd35c504423628c4528bad37a Mon Sep 17 00:00:00 2001 From: Faisal Amir Date: Tue, 22 Jul 2025 19:52:12 +0700 Subject: [PATCH] fix: support load model configurations (#5843) * fix: support load model configurations * chore: remove log * chore: sampling params add from send completion * chore: remove comment * chore: remove comment on predefined file * chore: update test model service --- .../browser/extensions/engines/AIEngine.ts | 2 +- web-app/src/containers/ModelSetting.tsx | 10 +++-- web-app/src/hooks/useChat.ts | 24 +++++++++++- web-app/src/services/__tests__/models.test.ts | 38 +++++++++++++++---- web-app/src/services/models.ts | 24 +++++++++++- 5 files changed, 85 insertions(+), 13 deletions(-) diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 90ce0543c..a23e8c45e 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -231,7 +231,7 @@ export abstract class AIEngine extends BaseExtension { /** * Loads a model into memory */ - abstract load(modelId: string): Promise + abstract load(modelId: string, settings?: any): Promise /** * Unloads a model from memory diff --git a/web-app/src/containers/ModelSetting.tsx b/web-app/src/containers/ModelSetting.tsx index bc5e810e1..726fdee71 100644 --- a/web-app/src/containers/ModelSetting.tsx +++ b/web-app/src/containers/ModelSetting.tsx @@ -87,8 +87,10 @@ export function ModelSetting({ ...(params as unknown as object), }) - // Call debounced stopModel after updating the model - debouncedStopModel(model.id) + // Call debounced stopModel only when updating ctx_len or ngl + if (key === 'ctx_len' || key === 'ngl') { + debouncedStopModel(model.id) + } } } @@ -106,7 +108,9 @@ export function ModelSetting({ - {t('common:modelSettings.title', { modelId: model.id })} + + {t('common:modelSettings.title', { modelId: model.id })} + {t('common:modelSettings.description')} diff --git 
a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index 4a841846d..c8c100243 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -261,6 +261,25 @@ export const useChat = () => { !abortController.signal.aborted && activeProvider ) { + const modelConfig = activeProvider.models.find( + (m) => m.id === selectedModel?.id + ) + + const modelSettings = modelConfig?.settings + ? Object.fromEntries( + Object.entries(modelConfig.settings) + .filter( + ([key, value]) => + key !== 'ctx_len' && + key !== 'ngl' && + value.controller_props?.value !== undefined && + value.controller_props?.value !== null && + value.controller_props?.value !== '' + ) + .map(([key, value]) => [key, value.controller_props?.value]) + ) + : undefined + const completion = await sendCompletion( activeThread, activeProvider, @@ -268,7 +287,10 @@ export const useChat = () => { abortController, availableTools, currentAssistant.parameters?.stream === false ? false : true, - currentAssistant.parameters as unknown as Record + { + ...modelSettings, + ...currentAssistant.parameters, + } as unknown as Record ) if (!completion) throw new Error('No completion received') diff --git a/web-app/src/services/__tests__/models.test.ts b/web-app/src/services/__tests__/models.test.ts index 2714ac930..d5e38b034 100644 --- a/web-app/src/services/__tests__/models.test.ts +++ b/web-app/src/services/__tests__/models.test.ts @@ -1,4 +1,5 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' + import { fetchModels, fetchModelCatalog, @@ -10,9 +11,8 @@ import { stopModel, stopAllModels, startModel, - configurePullOptions, } from '../models' -import { EngineManager } from '@janhq/core' +import { EngineManager, Model } from '@janhq/core' // Mock EngineManager vi.mock('@janhq/core', () => ({ @@ -118,7 +118,7 @@ describe('models service', () => { settings: [{ key: 'temperature', value: 0.7 }], } - await updateModel(model) + await updateModel(model as any) 
expect(mockEngine.updateSettings).toHaveBeenCalledWith(model.settings) }) @@ -209,7 +209,14 @@ describe('models service', () => { describe('startModel', () => { it('should start model successfully', async () => { - const provider = { provider: 'openai', models: [] } as ProviderObject + const mockSettings = { + ctx_len: { controller_props: { value: 4096 } }, + ngl: { controller_props: { value: 32 } }, + } + const provider = { + provider: 'openai', + models: [{ id: 'model1', settings: mockSettings }], + } as any const model = 'model1' const mockSession = { id: 'session1' } @@ -221,11 +228,21 @@ describe('models service', () => { const result = await startModel(provider, model) expect(result).toEqual(mockSession) - expect(mockEngine.load).toHaveBeenCalledWith(model) + expect(mockEngine.load).toHaveBeenCalledWith(model, { + ctx_size: 4096, + n_gpu_layers: 32, + }) }) it('should handle start model error', async () => { - const provider = { provider: 'openai', models: [] } as ProviderObject + const mockSettings = { + ctx_len: { controller_props: { value: 4096 } }, + ngl: { controller_props: { value: 32 } }, + } + const provider = { + provider: 'openai', + models: [{ id: 'model1', settings: mockSettings }], + } as any const model = 'model1' const error = new Error('Failed to start model') @@ -237,7 +254,14 @@ describe('models service', () => { await expect(startModel(provider, model)).rejects.toThrow(error) }) it('should not load model again', async () => { - const provider = { provider: 'openai', models: [] } as ProviderObject + const mockSettings = { + ctx_len: { controller_props: { value: 4096 } }, + ngl: { controller_props: { value: 32 } }, + } + const provider = { + provider: 'openai', + models: [{ id: 'model1', settings: mockSettings }], + } as any const model = 'model1' mockEngine.getLoadedModels.mockResolvedValue({ diff --git a/web-app/src/services/models.ts b/web-app/src/services/models.ts index fb76278b6..f38afa06f 100644 --- a/web-app/src/services/models.ts +++ 
b/web-app/src/services/models.ts @@ -150,7 +150,29 @@ export const startModel = async ( if (!engine) return undefined if ((await engine.getLoadedModels()).includes(model)) return undefined - return engine.load(model).catch((error) => { + + // Find the model configuration to get settings + const modelConfig = provider.models.find((m) => m.id === model) + + // Key mapping function to transform setting keys + const mapSettingKey = (key: string): string => { + const keyMappings: Record<string, string> = { + ctx_len: 'ctx_size', + ngl: 'n_gpu_layers', + } + return keyMappings[key] || key + } + + const settings = modelConfig?.settings + ? Object.fromEntries( + Object.entries(modelConfig.settings).map(([key, value]) => [ + mapSettingKey(key), + value.controller_props?.value, + ]) + ) + : undefined + + return engine.load(model, settings).catch((error) => { console.error( `Failed to start model ${model} for provider ${provider.provider}:`, error