fix: support load model configurations (#5843)

* fix: support load model configurations
* chore: remove log
* chore: sampling params add from send completion
* chore: remove comment
* chore: remove comment on predefined file
* chore: update test model service
This commit is contained in:
parent 7b3b6cc8be
commit 1d443e1f7d
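At its core, the change reads each model setting's controller_props.value, renames the llama.cpp-style keys, and hands the result to engine.load when the model starts. A condensed sketch of that mapping, outside the diff (the SettingValue shape mirrors how settings appear in the hunks below; the helper name and sample call are made up for illustration):

// Sketch only: condenses the load-time mapping this commit introduces.
type SettingValue = { controller_props?: { value?: unknown } }

const keyMappings: Record<string, string> = {
  ctx_len: 'ctx_size', // context length, as the backend expects it
  ngl: 'n_gpu_layers', // GPU offload layers
}

// Turn a model's settings map into the plain object passed to engine.load.
const toLoadSettings = (
  settings: Record<string, SettingValue>
): Record<string, unknown> =>
  Object.fromEntries(
    Object.entries(settings).map(([key, value]) => [
      keyMappings[key] ?? key,
      value.controller_props?.value,
    ])
  )

// Example (values are hypothetical):
// toLoadSettings({ ctx_len: { controller_props: { value: 4096 } } })
// => { ctx_size: 4096 }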
@@ -231,7 +231,7 @@ export abstract class AIEngine extends BaseExtension {
   /**
    * Loads a model into memory
    */
-  abstract load(modelId: string): Promise<SessionInfo>
+  abstract load(modelId: string, settings?: any): Promise<SessionInfo>

   /**
    * Unloads a model from memory
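Because settings is optional, existing engines and call sites that pass only a model id keep compiling. A minimal sketch of an engine honouring the widened signature (the class and the SessionInfo stand-in are illustrative, not part of this commit):

// Illustrative only; SessionInfo stands in for the type from @janhq/core.
type SessionInfo = { id: string }

class ExampleEngine {
  private sessions = new Map<string, SessionInfo>()

  // Mirrors the new abstract signature: settings may be omitted.
  async load(
    modelId: string,
    settings?: Record<string, unknown>
  ): Promise<SessionInfo> {
    // A concrete engine could forward known keys (ctx_size, n_gpu_layers, ...)
    // to its backend here and ignore the rest.
    void settings
    const session = { id: `${modelId}-session` }
    this.sessions.set(modelId, session)
    return session
  }

  async getLoadedModels(): Promise<string[]> {
    return [...this.sessions.keys()]
  }
}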
@@ -87,10 +87,12 @@ export function ModelSetting({
       ...(params as unknown as object),
     })

-    // Call debounced stopModel after updating the model
+    // Call debounced stopModel only when updating ctx_len or ngl
+    if (key === 'ctx_len' || key === 'ngl') {
       debouncedStopModel(model.id)
     }
   }
+  }

   return (
     <Sheet>
@@ -106,7 +108,9 @@ export function ModelSetting({
         </SheetTrigger>
         <SheetContent className="h-[calc(100%-8px)] top-1 right-1 rounded-e-md overflow-y-auto">
           <SheetHeader>
-            <SheetTitle>{t('common:modelSettings.title', { modelId: model.id })}</SheetTitle>
+            <SheetTitle>
+              {t('common:modelSettings.title', { modelId: model.id })}
+            </SheetTitle>
             <SheetDescription>
               {t('common:modelSettings.description')}
             </SheetDescription>
@@ -261,6 +261,25 @@ export const useChat = () => {
         !abortController.signal.aborted &&
         activeProvider
       ) {
+        const modelConfig = activeProvider.models.find(
+          (m) => m.id === selectedModel?.id
+        )
+
+        const modelSettings = modelConfig?.settings
+          ? Object.fromEntries(
+              Object.entries(modelConfig.settings)
+                .filter(
+                  ([key, value]) =>
+                    key !== 'ctx_len' &&
+                    key !== 'ngl' &&
+                    value.controller_props?.value !== undefined &&
+                    value.controller_props?.value !== null &&
+                    value.controller_props?.value !== ''
+                )
+                .map(([key, value]) => [key, value.controller_props?.value])
+            )
+          : undefined
+
         const completion = await sendCompletion(
           activeThread,
           activeProvider,
@@ -268,7 +287,10 @@ export const useChat = () => {
           abortController,
           availableTools,
           currentAssistant.parameters?.stream === false ? false : true,
-          currentAssistant.parameters as unknown as Record<string, object>
+          {
+            ...modelSettings,
+            ...currentAssistant.parameters,
+          } as unknown as Record<string, object>
         )

         if (!completion) throw new Error('No completion received')
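Note the spread order in the new sendCompletion argument: currentAssistant.parameters comes after modelSettings, so assistant-level parameters win when a key appears in both. A tiny illustration with made-up values:

// Hypothetical values, only to show the precedence of the spread above.
const modelSettings = { temperature: 0.6, top_p: 0.9 }
const assistantParameters = { temperature: 0.8 }

const merged = { ...modelSettings, ...assistantParameters }
// => { temperature: 0.8, top_p: 0.9 }  (assistant parameters override)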
@@ -1,4 +1,5 @@
 import { describe, it, expect, vi, beforeEach } from 'vitest'
+
 import {
   fetchModels,
   fetchModelCatalog,
@@ -10,9 +11,8 @@ import {
   stopModel,
   stopAllModels,
   startModel,
-  configurePullOptions,
 } from '../models'
-import { EngineManager } from '@janhq/core'
+import { EngineManager, Model } from '@janhq/core'

 // Mock EngineManager
 vi.mock('@janhq/core', () => ({
@@ -118,7 +118,7 @@ describe('models service', () => {
        settings: [{ key: 'temperature', value: 0.7 }],
      }

-      await updateModel(model)
+      await updateModel(model as any)

      expect(mockEngine.updateSettings).toHaveBeenCalledWith(model.settings)
    })
@@ -209,7 +209,14 @@ describe('models service', () => {

  describe('startModel', () => {
    it('should start model successfully', async () => {
-      const provider = { provider: 'openai', models: [] } as ProviderObject
+      const mockSettings = {
+        ctx_len: { controller_props: { value: 4096 } },
+        ngl: { controller_props: { value: 32 } },
+      }
+      const provider = {
+        provider: 'openai',
+        models: [{ id: 'model1', settings: mockSettings }],
+      } as any
      const model = 'model1'
      const mockSession = { id: 'session1' }

@@ -221,11 +228,21 @@ describe('models service', () => {
      const result = await startModel(provider, model)

      expect(result).toEqual(mockSession)
-      expect(mockEngine.load).toHaveBeenCalledWith(model)
+      expect(mockEngine.load).toHaveBeenCalledWith(model, {
+        ctx_size: 4096,
+        n_gpu_layers: 32,
+      })
    })

    it('should handle start model error', async () => {
-      const provider = { provider: 'openai', models: [] } as ProviderObject
+      const mockSettings = {
+        ctx_len: { controller_props: { value: 4096 } },
+        ngl: { controller_props: { value: 32 } },
+      }
+      const provider = {
+        provider: 'openai',
+        models: [{ id: 'model1', settings: mockSettings }],
+      } as any
      const model = 'model1'
      const error = new Error('Failed to start model')

@@ -237,7 +254,14 @@ describe('models service', () => {
      await expect(startModel(provider, model)).rejects.toThrow(error)
    })
    it('should not load model again', async () => {
-      const provider = { provider: 'openai', models: [] } as ProviderObject
+      const mockSettings = {
+        ctx_len: { controller_props: { value: 4096 } },
+        ngl: { controller_props: { value: 32 } },
+      }
+      const provider = {
+        provider: 'openai',
+        models: [{ id: 'model1', settings: mockSettings }],
+      } as any
      const model = 'model1'

      mockEngine.getLoadedModels.mockResolvedValue({
@@ -150,7 +150,29 @@ export const startModel = async (
  if (!engine) return undefined

  if ((await engine.getLoadedModels()).includes(model)) return undefined
-  return engine.load(model).catch((error) => {
+
+  // Find the model configuration to get settings
+  const modelConfig = provider.models.find((m) => m.id === model)
+
+  // Key mapping function to transform setting keys
+  const mapSettingKey = (key: string): string => {
+    const keyMappings: Record<string, string> = {
+      ctx_len: 'ctx_size',
+      ngl: 'n_gpu_layers',
+    }
+    return keyMappings[key] || key
+  }
+
+  const settings = modelConfig?.settings
+    ? Object.fromEntries(
+        Object.entries(modelConfig.settings).map(([key, value]) => [
+          mapSettingKey(key),
+          value.controller_props?.value,
+        ])
+      )
+    : undefined
+
+  return engine.load(model, settings).catch((error) => {
    console.error(
      `Failed to start model ${model} for provider ${provider.provider}:`,
      error