fix: support load model configurations (#5843)

* fix: support load model configurations

* chore: remove log

* chore: add sampling params from send completion

* chore: remove comment

* chore: remove comment on predefined file

* chore: update test model service
This commit is contained in:
Faisal Amir 2025-07-22 19:52:12 +07:00 committed by GitHub
parent 7b3b6cc8be
commit 1d443e1f7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 85 additions and 13 deletions

View File

@ -231,7 +231,7 @@ export abstract class AIEngine extends BaseExtension {
/**
* Loads a model into memory
*/
abstract load(modelId: string): Promise<SessionInfo>
abstract load(modelId: string, settings?: any): Promise<SessionInfo>
/**
* Unloads a model from memory

View File

@ -87,8 +87,10 @@ export function ModelSetting({
...(params as unknown as object),
})
// Call debounced stopModel after updating the model
debouncedStopModel(model.id)
// Call debounced stopModel only when updating ctx_len or ngl
if (key === 'ctx_len' || key === 'ngl') {
debouncedStopModel(model.id)
}
}
}
@ -106,7 +108,9 @@ export function ModelSetting({
</SheetTrigger>
<SheetContent className="h-[calc(100%-8px)] top-1 right-1 rounded-e-md overflow-y-auto">
<SheetHeader>
<SheetTitle>{t('common:modelSettings.title', { modelId: model.id })}</SheetTitle>
<SheetTitle>
{t('common:modelSettings.title', { modelId: model.id })}
</SheetTitle>
<SheetDescription>
{t('common:modelSettings.description')}
</SheetDescription>

View File

@ -261,6 +261,25 @@ export const useChat = () => {
!abortController.signal.aborted &&
activeProvider
) {
const modelConfig = activeProvider.models.find(
(m) => m.id === selectedModel?.id
)
const modelSettings = modelConfig?.settings
? Object.fromEntries(
Object.entries(modelConfig.settings)
.filter(
([key, value]) =>
key !== 'ctx_len' &&
key !== 'ngl' &&
value.controller_props?.value !== undefined &&
value.controller_props?.value !== null &&
value.controller_props?.value !== ''
)
.map(([key, value]) => [key, value.controller_props?.value])
)
: undefined
const completion = await sendCompletion(
activeThread,
activeProvider,
@ -268,7 +287,10 @@ export const useChat = () => {
abortController,
availableTools,
currentAssistant.parameters?.stream === false ? false : true,
currentAssistant.parameters as unknown as Record<string, object>
{
...modelSettings,
...currentAssistant.parameters,
} as unknown as Record<string, object>
)
if (!completion) throw new Error('No completion received')

View File

@ -1,4 +1,5 @@
import { describe, it, expect, vi, beforeEach } from 'vitest'
import {
fetchModels,
fetchModelCatalog,
@ -10,9 +11,8 @@ import {
stopModel,
stopAllModels,
startModel,
configurePullOptions,
} from '../models'
import { EngineManager } from '@janhq/core'
import { EngineManager, Model } from '@janhq/core'
// Mock EngineManager
vi.mock('@janhq/core', () => ({
@ -118,7 +118,7 @@ describe('models service', () => {
settings: [{ key: 'temperature', value: 0.7 }],
}
await updateModel(model)
await updateModel(model as any)
expect(mockEngine.updateSettings).toHaveBeenCalledWith(model.settings)
})
@ -209,7 +209,14 @@ describe('models service', () => {
describe('startModel', () => {
it('should start model successfully', async () => {
const provider = { provider: 'openai', models: [] } as ProviderObject
const mockSettings = {
ctx_len: { controller_props: { value: 4096 } },
ngl: { controller_props: { value: 32 } },
}
const provider = {
provider: 'openai',
models: [{ id: 'model1', settings: mockSettings }],
} as any
const model = 'model1'
const mockSession = { id: 'session1' }
@ -221,11 +228,21 @@ describe('models service', () => {
const result = await startModel(provider, model)
expect(result).toEqual(mockSession)
expect(mockEngine.load).toHaveBeenCalledWith(model)
expect(mockEngine.load).toHaveBeenCalledWith(model, {
ctx_size: 4096,
n_gpu_layers: 32,
})
})
it('should handle start model error', async () => {
const provider = { provider: 'openai', models: [] } as ProviderObject
const mockSettings = {
ctx_len: { controller_props: { value: 4096 } },
ngl: { controller_props: { value: 32 } },
}
const provider = {
provider: 'openai',
models: [{ id: 'model1', settings: mockSettings }],
} as any
const model = 'model1'
const error = new Error('Failed to start model')
@ -237,7 +254,14 @@ describe('models service', () => {
await expect(startModel(provider, model)).rejects.toThrow(error)
})
it('should not load model again', async () => {
const provider = { provider: 'openai', models: [] } as ProviderObject
const mockSettings = {
ctx_len: { controller_props: { value: 4096 } },
ngl: { controller_props: { value: 32 } },
}
const provider = {
provider: 'openai',
models: [{ id: 'model1', settings: mockSettings }],
} as any
const model = 'model1'
mockEngine.getLoadedModels.mockResolvedValue({

View File

@ -150,7 +150,29 @@ export const startModel = async (
if (!engine) return undefined
if ((await engine.getLoadedModels()).includes(model)) return undefined
return engine.load(model).catch((error) => {
// Find the model configuration to get settings
const modelConfig = provider.models.find((m) => m.id === model)
// Key mapping function to transform setting keys
// Translate a model-setting key to the name the engine expects
// (e.g. ctx_len -> ctx_size); unknown keys pass through unchanged.
const mapSettingKey = (key: string): string => {
  const engineKeyBySettingKey: Record<string, string> = {
    ctx_len: 'ctx_size',
    ngl: 'n_gpu_layers',
  }
  const mapped = engineKeyBySettingKey[key]
  return mapped !== undefined ? mapped : key
}
const settings = modelConfig?.settings
? Object.fromEntries(
Object.entries(modelConfig.settings).map(([key, value]) => [
mapSettingKey(key),
value.controller_props?.value,
])
)
: undefined
return engine.load(model, settings).catch((error) => {
console.error(
`Failed to start model ${model} for provider ${provider.provider}:`,
error