fix: support load model configurations (#5843)
* fix: support load model configurations * chore: remove log * chore: sampling params add from send completion * chore: remove comment * chore: remove comment on predefined file * chore: update test model service
This commit is contained in:
parent
7b3b6cc8be
commit
1d443e1f7d
@ -231,7 +231,7 @@ export abstract class AIEngine extends BaseExtension {
|
||||
/**
|
||||
* Loads a model into memory
|
||||
*/
|
||||
abstract load(modelId: string): Promise<SessionInfo>
|
||||
abstract load(modelId: string, settings?: any): Promise<SessionInfo>
|
||||
|
||||
/**
|
||||
* Unloads a model from memory
|
||||
|
||||
@ -87,8 +87,10 @@ export function ModelSetting({
|
||||
...(params as unknown as object),
|
||||
})
|
||||
|
||||
// Call debounced stopModel after updating the model
|
||||
debouncedStopModel(model.id)
|
||||
// Call debounced stopModel only when updating ctx_len or ngl
|
||||
if (key === 'ctx_len' || key === 'ngl') {
|
||||
debouncedStopModel(model.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,7 +108,9 @@ export function ModelSetting({
|
||||
</SheetTrigger>
|
||||
<SheetContent className="h-[calc(100%-8px)] top-1 right-1 rounded-e-md overflow-y-auto">
|
||||
<SheetHeader>
|
||||
<SheetTitle>{t('common:modelSettings.title', { modelId: model.id })}</SheetTitle>
|
||||
<SheetTitle>
|
||||
{t('common:modelSettings.title', { modelId: model.id })}
|
||||
</SheetTitle>
|
||||
<SheetDescription>
|
||||
{t('common:modelSettings.description')}
|
||||
</SheetDescription>
|
||||
|
||||
@ -261,6 +261,25 @@ export const useChat = () => {
|
||||
!abortController.signal.aborted &&
|
||||
activeProvider
|
||||
) {
|
||||
const modelConfig = activeProvider.models.find(
|
||||
(m) => m.id === selectedModel?.id
|
||||
)
|
||||
|
||||
const modelSettings = modelConfig?.settings
|
||||
? Object.fromEntries(
|
||||
Object.entries(modelConfig.settings)
|
||||
.filter(
|
||||
([key, value]) =>
|
||||
key !== 'ctx_len' &&
|
||||
key !== 'ngl' &&
|
||||
value.controller_props?.value !== undefined &&
|
||||
value.controller_props?.value !== null &&
|
||||
value.controller_props?.value !== ''
|
||||
)
|
||||
.map(([key, value]) => [key, value.controller_props?.value])
|
||||
)
|
||||
: undefined
|
||||
|
||||
const completion = await sendCompletion(
|
||||
activeThread,
|
||||
activeProvider,
|
||||
@ -268,7 +287,10 @@ export const useChat = () => {
|
||||
abortController,
|
||||
availableTools,
|
||||
currentAssistant.parameters?.stream === false ? false : true,
|
||||
currentAssistant.parameters as unknown as Record<string, object>
|
||||
{
|
||||
...modelSettings,
|
||||
...currentAssistant.parameters,
|
||||
} as unknown as Record<string, object>
|
||||
)
|
||||
|
||||
if (!completion) throw new Error('No completion received')
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
|
||||
import {
|
||||
fetchModels,
|
||||
fetchModelCatalog,
|
||||
@ -10,9 +11,8 @@ import {
|
||||
stopModel,
|
||||
stopAllModels,
|
||||
startModel,
|
||||
configurePullOptions,
|
||||
} from '../models'
|
||||
import { EngineManager } from '@janhq/core'
|
||||
import { EngineManager, Model } from '@janhq/core'
|
||||
|
||||
// Mock EngineManager
|
||||
vi.mock('@janhq/core', () => ({
|
||||
@ -118,7 +118,7 @@ describe('models service', () => {
|
||||
settings: [{ key: 'temperature', value: 0.7 }],
|
||||
}
|
||||
|
||||
await updateModel(model)
|
||||
await updateModel(model as any)
|
||||
|
||||
expect(mockEngine.updateSettings).toHaveBeenCalledWith(model.settings)
|
||||
})
|
||||
@ -209,7 +209,14 @@ describe('models service', () => {
|
||||
|
||||
describe('startModel', () => {
|
||||
it('should start model successfully', async () => {
|
||||
const provider = { provider: 'openai', models: [] } as ProviderObject
|
||||
const mockSettings = {
|
||||
ctx_len: { controller_props: { value: 4096 } },
|
||||
ngl: { controller_props: { value: 32 } },
|
||||
}
|
||||
const provider = {
|
||||
provider: 'openai',
|
||||
models: [{ id: 'model1', settings: mockSettings }],
|
||||
} as any
|
||||
const model = 'model1'
|
||||
const mockSession = { id: 'session1' }
|
||||
|
||||
@ -221,11 +228,21 @@ describe('models service', () => {
|
||||
const result = await startModel(provider, model)
|
||||
|
||||
expect(result).toEqual(mockSession)
|
||||
expect(mockEngine.load).toHaveBeenCalledWith(model)
|
||||
expect(mockEngine.load).toHaveBeenCalledWith(model, {
|
||||
ctx_size: 4096,
|
||||
n_gpu_layers: 32,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle start model error', async () => {
|
||||
const provider = { provider: 'openai', models: [] } as ProviderObject
|
||||
const mockSettings = {
|
||||
ctx_len: { controller_props: { value: 4096 } },
|
||||
ngl: { controller_props: { value: 32 } },
|
||||
}
|
||||
const provider = {
|
||||
provider: 'openai',
|
||||
models: [{ id: 'model1', settings: mockSettings }],
|
||||
} as any
|
||||
const model = 'model1'
|
||||
const error = new Error('Failed to start model')
|
||||
|
||||
@ -237,7 +254,14 @@ describe('models service', () => {
|
||||
await expect(startModel(provider, model)).rejects.toThrow(error)
|
||||
})
|
||||
it('should not load model again', async () => {
|
||||
const provider = { provider: 'openai', models: [] } as ProviderObject
|
||||
const mockSettings = {
|
||||
ctx_len: { controller_props: { value: 4096 } },
|
||||
ngl: { controller_props: { value: 32 } },
|
||||
}
|
||||
const provider = {
|
||||
provider: 'openai',
|
||||
models: [{ id: 'model1', settings: mockSettings }],
|
||||
} as any
|
||||
const model = 'model1'
|
||||
|
||||
mockEngine.getLoadedModels.mockResolvedValue({
|
||||
|
||||
@ -150,7 +150,29 @@ export const startModel = async (
|
||||
if (!engine) return undefined
|
||||
|
||||
if ((await engine.getLoadedModels()).includes(model)) return undefined
|
||||
return engine.load(model).catch((error) => {
|
||||
|
||||
// Find the model configuration to get settings
|
||||
const modelConfig = provider.models.find((m) => m.id === model)
|
||||
|
||||
// Key mapping function to transform setting keys
|
||||
const mapSettingKey = (key: string): string => {
|
||||
const keyMappings: Record<string, string> = {
|
||||
ctx_len: 'ctx_size',
|
||||
ngl: 'n_gpu_layers',
|
||||
}
|
||||
return keyMappings[key] || key
|
||||
}
|
||||
|
||||
const settings = modelConfig?.settings
|
||||
? Object.fromEntries(
|
||||
Object.entries(modelConfig.settings).map(([key, value]) => [
|
||||
mapSettingKey(key),
|
||||
value.controller_props?.value,
|
||||
])
|
||||
)
|
||||
: undefined
|
||||
|
||||
return engine.load(model, settings).catch((error) => {
|
||||
console.error(
|
||||
`Failed to start model ${model} for provider ${provider.provider}:`,
|
||||
error
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user