Merge branch 'dev' into copyfix

Commit 338232b173

Changed file: .gitignore (vendored)
@@ -45,3 +45,4 @@ core/test_results.html
 coverage
 .yarn
 .yarnrc
+*.tsbuildinfo
@@ -1,98 +1,109 @@
import { openExternalUrl } from './core'
import { joinPath } from './core'
import { openFileExplorer } from './core'
import { getJanDataFolderPath } from './core'
import { abortDownload } from './core'
import { getFileSize } from './core'
import { executeOnMain } from './core'

describe('test core apis', () => {
  it('should open external url', async () => {
    const url = 'http://example.com'
    globalThis.core = {
      api: {
        openExternalUrl: jest.fn().mockResolvedValue('opened'),
      },
    }
    const result = await openExternalUrl(url)
    expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url)
    expect(result).toBe('opened')
  })

  it('should join paths', async () => {
    const paths = ['/path/one', '/path/two']
    globalThis.core = {
      api: {
        joinPath: jest.fn().mockResolvedValue('/path/one/path/two'),
      },
    }
    const result = await joinPath(paths)
    expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths)
    expect(result).toBe('/path/one/path/two')
  })

  it('should open file explorer', async () => {
    const path = '/path/to/open'
    globalThis.core = {
      api: {
        openFileExplorer: jest.fn().mockResolvedValue('opened'),
      },
    }
    const result = await openFileExplorer(path)
    expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path)
    expect(result).toBe('opened')
  })

  it('should get jan data folder path', async () => {
    globalThis.core = {
      api: {
        getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data'),
      },
    }
    const result = await getJanDataFolderPath()
    expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled()
    expect(result).toBe('/path/to/jan/data')
  })

  it('should abort download', async () => {
    const fileName = 'testFile'
    globalThis.core = {
      api: {
        abortDownload: jest.fn().mockResolvedValue('aborted'),
      },
    }
    const result = await abortDownload(fileName)
    expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName)
    expect(result).toBe('aborted')
  })

  it('should get file size', async () => {
    const url = 'http://example.com/file'
    globalThis.core = {
      api: {
        getFileSize: jest.fn().mockResolvedValue(1024),
      },
    }
    const result = await getFileSize(url)
    expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url)
    expect(result).toBe(1024)
  })

  it('should execute function on main process', async () => {
    const extension = 'testExtension'
    const method = 'testMethod'
    const args = ['arg1', 'arg2']
    globalThis.core = {
      api: {
        invokeExtensionFunc: jest.fn().mockResolvedValue('result'),
      },
    }
    const result = await executeOnMain(extension, method, ...args)
    expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args)
    expect(result).toBe('result')
  })
})

describe('dirName - just a pass thru api', () => {
  it('should retrieve the directory name from a file path', async () => {
    const mockDirName = jest.fn()
    globalThis.core = {
      api: {
        dirName: mockDirName.mockResolvedValue('/path/to'),
      },
    }
    // Normal file path with extension
    const path = '/path/to/file.txt'
    await globalThis.core.api.dirName(path)
    expect(mockDirName).toHaveBeenCalledWith(path)
  })
})
@@ -68,6 +68,13 @@ const openFileExplorer: (path: string) => Promise<any> = (path) =>
 const joinPath: (paths: string[]) => Promise<string> = (paths) =>
   globalThis.core.api?.joinPath(paths)
 
+/**
+ * Get dirname of a file path.
+ * @param path - The file path to retrieve dirname.
+ * @returns {Promise<string>} A promise that resolves the dirname.
+ */
+const dirName: (path: string) => Promise<string> = (path) => globalThis.core.api?.dirName(path)
+
 /**
  * Retrieve the basename from an url.
  * @param path - The path to retrieve.
@@ -161,5 +168,6 @@ export {
   systemInformation,
   showToast,
   getFileSize,
+  dirName,
   FileStat,
 }
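The new dirName helper is a plain pass-through to the host bridge (globalThis.core.api), mirroring joinPath above. A minimal usage sketch; the model path below is only an illustrative value, not something this commit ships:

import { dirName, joinPath } from '@janhq/core'

// Hypothetical input; any path the caller already holds works the same way.
const modelJsonPath = '/home/user/jan/models/llama-3/model.json'

async function siblingGgufPath(): Promise<string> {
  const folder = await dirName(modelJsonPath) // resolved by the App processor on the host side
  return joinPath([folder, 'model.gguf'])
}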
@@ -2,7 +2,7 @@ import { getJanDataFolderPath, joinPath } from '../../core'
 import { events } from '../../events'
 import { BaseExtension } from '../../extension'
 import { fs } from '../../fs'
-import { MessageRequest, Model, ModelEvent } from '../../../types'
+import { MessageRequest, Model, ModelEvent, ModelFile } from '../../../types'
 import { EngineManager } from './EngineManager'
 
 /**
@@ -21,7 +21,7 @@ export abstract class AIEngine extends BaseExtension {
   override onLoad() {
     this.registerEngine()
 
-    events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
+    events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
     events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
   }
 
@@ -78,7 +78,7 @@ export abstract class AIEngine extends BaseExtension {
   /**
    * Loads the model.
    */
-  async loadModel(model: Model): Promise<any> {
+  async loadModel(model: ModelFile): Promise<any> {
     if (model.engine.toString() !== this.provider) return Promise.resolve()
     events.emit(ModelEvent.OnModelReady, model)
     return Promise.resolve()
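With this change the OnModelInit payload is typed as ModelFile, i.e. a Model plus its file metadata, so handlers can rely on file_path being present. A small sketch of an emitter, with made-up field values:

import { events } from '../../events'
import { ModelEvent, ModelFile } from '../../../types'

// Illustrative payload only; real values come from the model extension.
const model = {
  id: 'llama-3',
  engine: 'nitro',
  file_path: 'file://models/llama-3/model.json',
  file_name: 'model.json',
} as ModelFile

events.emit(ModelEvent.OnModelInit, model) // picked up by AIEngine.onLoad handlers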
@@ -1,6 +1,6 @@
-import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core'
+import { executeOnMain, systemInformation, dirName } from '../../core'
 import { events } from '../../events'
-import { Model, ModelEvent } from '../../../types'
+import { Model, ModelEvent, ModelFile } from '../../../types'
 import { OAIEngine } from './OAIEngine'
 
 /**
@@ -14,22 +14,24 @@ export abstract class LocalOAIEngine extends OAIEngine {
   unloadModelFunctionName: string = 'unloadModel'
 
   /**
-   * On extension load, subscribe to events.
+   * This class represents a base for local inference providers in the OpenAI architecture.
+   * It extends the OAIEngine class and provides the implementation of loading and unloading models locally.
+   * The loadModel function subscribes to the ModelEvent.OnModelInit event, loading models when initiated.
+   * The unloadModel function subscribes to the ModelEvent.OnModelStop event, unloading models when stopped.
    */
   override onLoad() {
     super.onLoad()
     // These events are applicable to local inference providers
-    events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
+    events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
     events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
   }
 
   /**
    * Load the model.
    */
-  override async loadModel(model: Model): Promise<void> {
+  override async loadModel(model: ModelFile): Promise<void> {
     if (model.engine.toString() !== this.provider) return
-    const modelFolderName = 'models'
-    const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
+    const modelFolder = await dirName(model.file_path)
+
     const systemInfo = await systemInformation()
     const res = await executeOnMain(
       this.nodeModule,
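Deriving the model folder from model.file_path rather than from the Jan data folder plus the model id means a model whose files live outside the data folder (for example an imported model) now resolves correctly. A sketch of the new lookup, with an illustrative path:

import { dirName } from '../../core'
import { ModelFile } from '../../../types'

async function resolveModelFolder(model: ModelFile): Promise<string> {
  // Before this change: joinPath([await getJanDataFolderPath(), 'models', model.id])
  // Now: simply the folder that contains the model's model.json
  return dirName(model.file_path)
}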
@@ -4,6 +4,7 @@ import {
   HuggingFaceRepoData,
   ImportingModel,
   Model,
+  ModelFile,
   ModelInterface,
   OptionType,
 } from '../../types'
@@ -25,12 +26,11 @@ export abstract class ModelExtension extends BaseExtension implements ModelInterface
     network?: { proxy: string; ignoreSSL?: boolean }
   ): Promise<void>
   abstract cancelModelDownload(modelId: string): Promise<void>
-  abstract deleteModel(modelId: string): Promise<void>
-  abstract saveModel(model: Model): Promise<void>
-  abstract getDownloadedModels(): Promise<Model[]>
-  abstract getConfiguredModels(): Promise<Model[]>
+  abstract deleteModel(model: ModelFile): Promise<void>
+  abstract getDownloadedModels(): Promise<ModelFile[]>
+  abstract getConfiguredModels(): Promise<ModelFile[]>
   abstract importModels(models: ImportingModel[], optionType: OptionType): Promise<void>
-  abstract updateModelInfo(modelInfo: Partial<Model>): Promise<Model>
+  abstract updateModelInfo(modelInfo: Partial<ModelFile>): Promise<ModelFile>
   abstract fetchHuggingFaceRepoData(repoId: string): Promise<HuggingFaceRepoData>
   abstract getDefaultModel(): Promise<Model>
 }
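Extension authors now implement deleteModel against the full ModelFile rather than a bare id, and the listing APIs return ModelFile[]. A hedged sketch of a call site; the helper name and model value are illustrative, not part of this commit:

import { ModelExtension, ModelFile } from '@janhq/core'

async function removeModel(extension: ModelExtension, model: ModelFile): Promise<void> {
  // The extension derives the on-disk folder from model.file_path internally.
  await extension.deleteModel(model)
}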
@@ -1,40 +1,57 @@
jest.mock('../../helper', () => ({
  ...jest.requireActual('../../helper'),
  getJanDataFolderPath: () => './app',
}))
import { dirname } from 'path'
import { App } from './app'

it('should call stopServer', () => {
  const app = new App()
  const stopServerMock = jest.fn().mockResolvedValue('Server stopped')
  jest.mock('@janhq/server', () => ({
    stopServer: stopServerMock,
  }))
  app.stopServer()
  expect(stopServerMock).toHaveBeenCalled()
})

it('should correctly retrieve basename', () => {
  const app = new App()
  const result = app.baseName('/path/to/file.txt')
  expect(result).toBe('file.txt')
})

it('should correctly identify subdirectories', () => {
  const app = new App()
  const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to'
  const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir'
  const result = app.isSubdirectory(basePath, subPath)
  expect(result).toBe(true)
})

it('should correctly join multiple paths', () => {
  const app = new App()
  const result = app.joinPath(['path', 'to', 'file'])
  const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file'
  expect(result).toBe(expectedPath)
})

it('should call correct function with provided arguments using process method', () => {
  const app = new App()
  const mockFunc = jest.fn()
  app.joinPath = mockFunc
  app.process('joinPath', ['path1', 'path2'])
  expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2'])
})

it('should retrieve the directory name from a file path (Unix/Windows)', async () => {
  const app = new App()
  const path = 'C:/Users/John Doe/Desktop/file.txt'
  expect(await app.dirName(path)).toBe('C:/Users/John Doe/Desktop')
})

it('should retrieve the directory name when using file protocol', async () => {
  const app = new App()
  const path = 'file:/models/file.txt'
  expect(await app.dirName(path)).toBe(process.platform === 'win32' ? 'app\\models' : 'app/models')
})
@@ -1,4 +1,4 @@
-import { basename, isAbsolute, join, relative } from 'path'
+import { basename, dirname, isAbsolute, join, relative } from 'path'
 
 import { Processor } from './Processor'
 import {
@@ -6,6 +6,8 @@ import {
   appResourcePath,
   getAppConfigurations as appConfiguration,
   updateAppConfiguration,
+  normalizeFilePath,
+  getJanDataFolderPath,
 } from '../../helper'
 
 export class App implements Processor {
@@ -28,6 +30,18 @@ export class App implements Processor {
     return join(...args)
   }
 
+  /**
+   * Get dirname of a file path.
+   * @param path - The file path to retrieve dirname.
+   */
+  dirName(path: string) {
+    const arg =
+      path.startsWith(`file:/`) || path.startsWith(`file:\\`)
+        ? join(getJanDataFolderPath(), normalizeFilePath(path))
+        : path
+    return dirname(arg)
+  }
+
   /**
    * Checks if the given path is a subdirectory of the given directory.
    *
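On the Node side, dirName re-roots file:/ (or file:\) prefixed paths under the Jan data folder before taking the directory name, while plain paths go straight to path.dirname. A short sketch that mirrors the unit tests above (the mocked data folder is './app'):

import { App } from './app'

async function dirNameExamples() {
  const app = new App()
  // Plain path: just the parent directory.
  await app.dirName('C:/Users/John Doe/Desktop/file.txt') // 'C:/Users/John Doe/Desktop'
  // file: protocol path: normalized into the Jan data folder first.
  await app.dirName('file:/models/file.txt') // 'app/models' ('app\\models' on Windows)
}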
@@ -37,6 +37,7 @@ export enum AppRoute {
   getAppConfigurations = 'getAppConfigurations',
   updateAppConfiguration = 'updateAppConfiguration',
   joinPath = 'joinPath',
+  dirName = 'dirName',
   isSubdirectory = 'isSubdirectory',
   baseName = 'baseName',
   startServer = 'startServer',
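Adding dirName to AppRoute is what exposes the new App.dirName handler to the renderer-side pass-through. Assuming the route is dispatched by name the same way the process() unit test above demonstrates for joinPath, the mapping is direct:

import { App } from './app'

const app = new App()
// process() routes by name, so the new enum member maps straight to App.dirName.
app.process('dirName', '/path/to/file.txt') // equivalent to app.dirName('/path/to/file.txt')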
@@ -52,3 +52,18 @@ type DownloadSize = {
   total: number
   transferred: number
 }
+
+/**
+ * The file metadata
+ */
+export type FileMetadata = {
+  /**
+   * The origin file path.
+   */
+  file_path: string
+
+  /**
+   * The file name.
+   */
+  file_name: string
+}
@@ -1,3 +1,5 @@
+import { FileMetadata } from '../file'
+
 /**
  * Represents the information about a model.
  * @stored
@@ -151,3 +153,8 @@ export type ModelRuntimeParams = {
 export type ModelInitFailed = Model & {
   error: Error
 }
+
+/**
+ * ModelFile is the model.json entity and it's file metadata
+ */
+export type ModelFile = Model & FileMetadata
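ModelFile is a plain intersection type, so an existing Model can be widened by attaching file metadata at read time; no runtime wrapper is needed. A small sketch (the helper name is illustrative):

import { Model, ModelFile } from './modelEntity'
import { FileMetadata } from '../file'

function toModelFile(model: Model, meta: FileMetadata): ModelFile {
  // The intersection is purely structural; spreading both objects satisfies it.
  return { ...model, ...meta }
}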
@@ -1,5 +1,5 @@
 import { GpuSetting } from '../miscellaneous'
-import { Model } from './modelEntity'
+import { Model, ModelFile } from './modelEntity'
 
 /**
  * Model extension for managing models.
@@ -29,14 +29,7 @@ export interface ModelInterface {
   * @param modelId - The ID of the model to delete.
   * @returns A Promise that resolves when the model has been deleted.
   */
-  deleteModel(modelId: string): Promise<void>
+  deleteModel(model: ModelFile): Promise<void>
 
-  /**
-   * Saves a model.
-   * @param model - The model to save.
-   * @returns A Promise that resolves when the model has been saved.
-   */
-  saveModel(model: Model): Promise<void>
-
   /**
    * Gets a list of downloaded models.
@@ -1,32 +1,29 @@
 import { expect } from '@playwright/test'
 import { page, test, TIMEOUT } from '../config/fixtures'
 
-test('Select GPT model from Hub and Chat with Invalid API Key', async ({ hubPage }) => {
+test('Select GPT model from Hub and Chat with Invalid API Key', async ({
+  hubPage,
+}) => {
   await hubPage.navigateByMenu()
   await hubPage.verifyContainerVisible()
 
   // Select the first GPT model
   await page
     .locator('[data-testid^="use-model-btn"][data-testid*="gpt"]')
-    .first().click()
-
-  // Attempt to create thread and chat in Thread page
-  await page
-    .getByTestId('btn-create-thread')
+    .first()
     .click()
 
-  await page
-    .getByTestId('txt-input-chat')
-    .fill('dummy value')
+  await page.getByTestId('txt-input-chat').fill('dummy value')
 
-  await page
-    .getByTestId('btn-send-chat')
-    .click()
+  await page.getByTestId('btn-send-chat').click()
 
-  await page.waitForFunction(() => {
-    const loaders = document.querySelectorAll('[data-testid$="loader"]');
-    return !loaders.length;
-  }, { timeout: TIMEOUT });
+  await page.waitForFunction(
+    () => {
+      const loaders = document.querySelectorAll('[data-testid$="loader"]')
+      return !loaders.length
+    },
+    { timeout: TIMEOUT }
+  )
 
   const APIKeyError = page.getByTestId('invalid-API-key-error')
   await expect(APIKeyError).toBeVisible({
@@ -22,6 +22,7 @@ import {
   downloadFile,
   DownloadState,
   DownloadEvent,
+  ModelFile,
 } from '@janhq/core'
 
 declare const CUDA_DOWNLOAD_URL: string
@@ -94,7 +95,7 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
     this.nitroProcessInfo = health
   }
 
-  override loadModel(model: Model): Promise<void> {
+  override loadModel(model: ModelFile): Promise<void> {
     if (model.engine !== this.provider) return Promise.resolve()
     this.getNitroProcessHealthIntervalId = setInterval(
       () => this.periodicallyGetNitroHealth(),
@@ -6,12 +6,12 @@ import fetchRT from 'fetch-retry'
 import {
   log,
   getSystemResourceInfo,
-  Model,
   InferenceEngine,
   ModelSettingParams,
   PromptTemplate,
   SystemInformation,
   getJanDataFolderPath,
+  ModelFile,
 } from '@janhq/core/node'
 import { executableNitroFile } from './execute'
 import terminate from 'terminate'
@@ -25,7 +25,7 @@ const fetchRetry = fetchRT(fetch)
  */
 interface ModelInitOptions {
   modelFolder: string
-  model: Model
+  model: ModelFile
 }
 // The PORT to use for the Nitro subprocess
 const PORT = 3928
@@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise<Response> {
   if (!settings?.ngl) {
     settings.ngl = 100
   }
-  log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`)
+  log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`)
   return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, {
     method: 'POST',
     headers: {
@@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise<Response> {
   })
     .then((res) => {
       log(
-        `[CORTEX]::Debug: Load model success with response ${JSON.stringify(
+        `[CORTEX]:: Load model success with response ${JSON.stringify(
           res
         )}`
       )
@@ -260,7 +260,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
 async function validateModelStatus(modelId: string): Promise<void> {
   // Send a GET request to the validation URL.
   // Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries.
-  log(`[CORTEX]::Debug: Validating model ${modelId}`)
+  log(`[CORTEX]:: Validating model ${modelId}`)
   return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, {
     method: 'POST',
     body: JSON.stringify({
@@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
     retryDelay: 300,
   }).then(async (res: Response) => {
     log(
-      `[CORTEX]::Debug: Validate model state with response ${JSON.stringify(
+      `[CORTEX]:: Validate model state with response ${JSON.stringify(
        res.status
      )}`
    )
@@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
     // Otherwise, return an object with an error message.
     if (body.model_loaded) {
       log(
-        `[CORTEX]::Debug: Validate model state success with response ${JSON.stringify(
+        `[CORTEX]:: Validate model state success with response ${JSON.stringify(
           body
         )}`
       )
@@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
     }
     const errorBody = await res.text()
     log(
-      `[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
+      `[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
        res.statusText
      )}`
    )
@@ -310,7 +310,7 @@ async function killSubprocess(): Promise<void> {
 async function killSubprocess(): Promise<void> {
   const controller = new AbortController()
   setTimeout(() => controller.abort(), 5000)
-  log(`[CORTEX]::Debug: Request to kill cortex`)
+  log(`[CORTEX]:: Request to kill cortex`)
 
   const killRequest = () => {
     return fetch(NITRO_HTTP_KILL_URL, {
@@ -321,17 +321,17 @@ async function killSubprocess(): Promise<void> {
       .then(() =>
         tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
       )
-      .then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
+      .then(() => log(`[CORTEX]:: cortex process is terminated`))
       .catch((err) => {
         log(
-          `[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
+          `[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
         )
         throw 'PORT_NOT_AVAILABLE'
       })
   }
 
   if (subprocess?.pid && process.platform !== 'darwin') {
-    log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
+    log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
     const pid = subprocess.pid
     return new Promise((resolve, reject) => {
       terminate(pid, function (err) {
@@ -341,7 +341,7 @@ async function killSubprocess(): Promise<void> {
       } else {
         tcpPortUsed
           .waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
-          .then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
+          .then(() => log(`[CORTEX]:: cortex process is terminated`))
           .then(() => resolve())
           .catch(() => {
             log(
@@ -362,7 +362,7 @@ async function killSubprocess(): Promise<void> {
  * @returns A promise that resolves when the Nitro subprocess is started.
  */
 function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
-  log(`[CORTEX]::Debug: Spawning cortex subprocess...`)
+  log(`[CORTEX]:: Spawning cortex subprocess...`)
 
   return new Promise<void>(async (resolve, reject) => {
     let executableOptions = executableNitroFile(
@@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
     const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
     // Execute the binary
     log(
-      `[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
+      `[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
     )
     log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`)
 
@@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
 
     // Handle subprocess output
     subprocess.stdout.on('data', (data: any) => {
-      log(`[CORTEX]::Debug: ${data}`)
+      log(`[CORTEX]:: ${data}`)
     })
 
     subprocess.stderr.on('data', (data: any) => {
@@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
     })
 
     subprocess.on('close', (code: any) => {
-      log(`[CORTEX]::Debug: cortex exited with code: ${code}`)
+      log(`[CORTEX]:: cortex exited with code: ${code}`)
       subprocess = undefined
       reject(`child process exited with code ${code}`)
     })
@@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
     tcpPortUsed
       .waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000)
       .then(() => {
-        log(`[CORTEX]::Debug: cortex is ready`)
+        log(`[CORTEX]:: cortex is ready`)
         resolve()
       })
   })
@@ -119,5 +119,65 @@
       ]
     },
     "engine": "openai"
+  },
+  {
+    "sources": [
+      {
+        "url": "https://openai.com"
+      }
+    ],
+    "id": "o1-preview",
+    "object": "model",
+    "name": "OpenAI o1-preview",
+    "version": "1.0",
+    "description": "OpenAI o1-preview is a new model with complex reasoning",
+    "format": "api",
+    "settings": {},
+    "parameters": {
+      "max_tokens": 4096,
+      "temperature": 0.7,
+      "top_p": 0.95,
+      "stream": true,
+      "stop": [],
+      "frequency_penalty": 0,
+      "presence_penalty": 0
+    },
+    "metadata": {
+      "author": "OpenAI",
+      "tags": [
+        "General"
+      ]
+    },
+    "engine": "openai"
+  },
+  {
+    "sources": [
+      {
+        "url": "https://openai.com"
+      }
+    ],
+    "id": "o1-mini",
+    "object": "model",
+    "name": "OpenAI o1-mini",
+    "version": "1.0",
+    "description": "OpenAI o1-mini is a lightweight reasoning model",
+    "format": "api",
+    "settings": {},
+    "parameters": {
+      "max_tokens": 4096,
+      "temperature": 0.7,
+      "top_p": 0.95,
+      "stream": true,
+      "stop": [],
+      "frequency_penalty": 0,
+      "presence_penalty": 0
+    },
+    "metadata": {
+      "author": "OpenAI",
+      "tags": [
+        "General"
+      ]
+    },
+    "engine": "openai"
   }
 ]
extensions/model-extension/jest.config.js (new file, 9 lines)
@@ -0,0 +1,9 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
  preset: 'ts-jest',
  testEnvironment: 'node',
  transform: {
    'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
  },
  transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
}
@@ -8,6 +8,7 @@
   "author": "Jan <service@jan.ai>",
   "license": "AGPL-3.0",
   "scripts": {
+    "test": "jest",
     "build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs",
     "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install"
   },
@@ -27,7 +27,7 @@ export default [
     // Allow json resolution
     json(),
     // Compile TypeScript files
-    typescript({ useTsconfigDeclarationDir: true }),
+    typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
     // Compile TypeScript files
     // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
     // commonjs(),
@@ -62,7 +62,7 @@ export default [
     // Allow json resolution
     json(),
     // Compile TypeScript files
-    typescript({ useTsconfigDeclarationDir: true }),
+    typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
     // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
     commonjs(),
     // Allow node_modules resolution, so you can use 'external' to control
564
extensions/model-extension/src/index.test.ts
Normal file
564
extensions/model-extension/src/index.test.ts
Normal file
@ -0,0 +1,564 @@
|
|||||||
|
const readDirSyncMock = jest.fn()
|
||||||
|
const existMock = jest.fn()
|
||||||
|
const readFileSyncMock = jest.fn()
|
||||||
|
|
||||||
|
jest.mock('@janhq/core', () => ({
|
||||||
|
...jest.requireActual('@janhq/core/node'),
|
||||||
|
fs: {
|
||||||
|
existsSync: existMock,
|
||||||
|
readdirSync: readDirSyncMock,
|
||||||
|
readFileSync: readFileSyncMock,
|
||||||
|
fileStat: () => ({
|
||||||
|
isDirectory: false,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
dirName: jest.fn(),
|
||||||
|
joinPath: (paths) => paths.join('/'),
|
||||||
|
ModelExtension: jest.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
import JanModelExtension from '.'
|
||||||
|
import { fs, dirName } from '@janhq/core'
|
||||||
|
|
||||||
|
describe('JanModelExtension', () => {
|
||||||
|
let sut: JanModelExtension
|
||||||
|
|
||||||
|
beforeAll(() => {
|
||||||
|
// @ts-ignore
|
||||||
|
sut = new JanModelExtension()
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
jest.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getConfiguredModels', () => {
|
||||||
|
describe("when there's no models are pre-populated", () => {
|
||||||
|
it('should return empty array', async () => {
|
||||||
|
// Mock configured models data
|
||||||
|
const configuredModels = []
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
readDirSyncMock.mockReturnValue([])
|
||||||
|
|
||||||
|
const result = await sut.getConfiguredModels()
|
||||||
|
expect(result).toEqual([])
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("when there's are pre-populated models - all flattened", () => {
|
||||||
|
it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
|
||||||
|
// Mock configured models data
|
||||||
|
const configuredModels = [
|
||||||
|
{
|
||||||
|
id: '1',
|
||||||
|
name: 'Model 1',
|
||||||
|
version: '1.0.0',
|
||||||
|
description: 'Model 1 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model1',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
created: new Date(),
|
||||||
|
updated: new Date(),
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
{
|
||||||
|
id: '2',
|
||||||
|
name: 'Model 2',
|
||||||
|
version: '2.0.0',
|
||||||
|
description: 'Model 2 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model2',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
]
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
|
||||||
|
readDirSyncMock.mockImplementation((path) => {
|
||||||
|
if (path === 'file://models') return ['model1', 'model2']
|
||||||
|
else return ['model.json']
|
||||||
|
})
|
||||||
|
|
||||||
|
readFileSyncMock.mockImplementation((path) => {
|
||||||
|
if (path.includes('model1'))
|
||||||
|
return JSON.stringify(configuredModels[0])
|
||||||
|
else return JSON.stringify(configuredModels[1])
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await sut.getConfiguredModels()
|
||||||
|
expect(result).toEqual(
|
||||||
|
expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
id: '1',
|
||||||
|
}),
|
||||||
|
expect.objectContaining({
|
||||||
|
file_path: 'file://models/model2/model.json',
|
||||||
|
id: '2',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("when there's are pre-populated models - there are nested folders", () => {
|
||||||
|
it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
|
||||||
|
// Mock configured models data
|
||||||
|
const configuredModels = [
|
||||||
|
{
|
||||||
|
id: '1',
|
||||||
|
name: 'Model 1',
|
||||||
|
version: '1.0.0',
|
||||||
|
description: 'Model 1 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model1',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
created: new Date(),
|
||||||
|
updated: new Date(),
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
{
|
||||||
|
id: '2',
|
||||||
|
name: 'Model 2',
|
||||||
|
version: '2.0.0',
|
||||||
|
description: 'Model 2 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model2',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
]
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
|
||||||
|
readDirSyncMock.mockImplementation((path) => {
|
||||||
|
if (path === 'file://models') return ['model1', 'model2/model2-1']
|
||||||
|
else return ['model.json']
|
||||||
|
})
|
||||||
|
|
||||||
|
readFileSyncMock.mockImplementation((path) => {
|
||||||
|
if (path.includes('model1'))
|
||||||
|
return JSON.stringify(configuredModels[0])
|
||||||
|
else if (path.includes('model2/model2-1'))
|
||||||
|
return JSON.stringify(configuredModels[1])
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await sut.getConfiguredModels()
|
||||||
|
expect(result).toEqual(
|
||||||
|
expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
id: '1',
|
||||||
|
}),
|
||||||
|
expect.objectContaining({
|
||||||
|
file_path: 'file://models/model2/model2-1/model.json',
|
||||||
|
id: '2',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getDownloadedModels', () => {
|
||||||
|
describe('no models downloaded', () => {
|
||||||
|
it('should return empty array', async () => {
|
||||||
|
// Mock downloaded models data
|
||||||
|
const downloadedModels = []
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
readDirSyncMock.mockReturnValue([])
|
||||||
|
|
||||||
|
const result = await sut.getDownloadedModels()
|
||||||
|
expect(result).toEqual([])
|
||||||
|
})
|
||||||
|
})
|
||||||
|
describe('only one model is downloaded', () => {
|
||||||
|
describe('flatten folder', () => {
|
||||||
|
it('returns downloaded models - with correct file_path and model id', async () => {
|
||||||
|
// Mock configured models data
|
||||||
|
const configuredModels = [
|
||||||
|
{
|
||||||
|
id: '1',
|
||||||
|
name: 'Model 1',
|
||||||
|
version: '1.0.0',
|
||||||
|
description: 'Model 1 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model1',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
created: new Date(),
|
||||||
|
updated: new Date(),
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
{
|
||||||
|
id: '2',
|
||||||
|
name: 'Model 2',
|
||||||
|
version: '2.0.0',
|
||||||
|
description: 'Model 2 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model2',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
]
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
|
||||||
|
readDirSyncMock.mockImplementation((path) => {
|
||||||
|
if (path === 'file://models') return ['model1', 'model2']
|
||||||
|
else if (path === 'file://models/model1')
|
||||||
|
return ['model.json', 'test.gguf']
|
||||||
|
else return ['model.json']
|
||||||
|
})
|
||||||
|
|
||||||
|
readFileSyncMock.mockImplementation((path) => {
|
||||||
|
if (path.includes('model1'))
|
||||||
|
return JSON.stringify(configuredModels[0])
|
||||||
|
else return JSON.stringify(configuredModels[1])
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await sut.getDownloadedModels()
|
||||||
|
expect(result).toEqual(
|
||||||
|
expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
id: '1',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('all models are downloaded', () => {
|
||||||
|
describe('nested folders', () => {
|
||||||
|
it('returns downloaded models - with correct file_path and model id', async () => {
|
||||||
|
// Mock configured models data
|
||||||
|
const configuredModels = [
|
||||||
|
{
|
||||||
|
id: '1',
|
||||||
|
name: 'Model 1',
|
||||||
|
version: '1.0.0',
|
||||||
|
description: 'Model 1 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model1',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
created: new Date(),
|
||||||
|
updated: new Date(),
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
{
|
||||||
|
id: '2',
|
||||||
|
name: 'Model 2',
|
||||||
|
version: '2.0.0',
|
||||||
|
description: 'Model 2 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model2',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
]
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
|
||||||
|
readDirSyncMock.mockImplementation((path) => {
|
||||||
|
if (path === 'file://models') return ['model1', 'model2/model2-1']
|
||||||
|
else return ['model.json', 'test.gguf']
|
||||||
|
})
|
||||||
|
|
||||||
|
readFileSyncMock.mockImplementation((path) => {
|
||||||
|
if (path.includes('model1'))
|
||||||
|
return JSON.stringify(configuredModels[0])
|
||||||
|
else return JSON.stringify(configuredModels[1])
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await sut.getDownloadedModels()
|
||||||
|
expect(result).toEqual(
|
||||||
|
expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
id: '1',
|
||||||
|
}),
|
||||||
|
expect.objectContaining({
|
||||||
|
file_path: 'file://models/model2/model2-1/model.json',
|
||||||
|
id: '2',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('all models are downloaded with uppercased GGUF files', () => {
|
||||||
|
it('returns downloaded models - with correct file_path and model id', async () => {
|
||||||
|
// Mock configured models data
|
||||||
|
const configuredModels = [
|
||||||
|
{
|
||||||
|
id: '1',
|
||||||
|
name: 'Model 1',
|
||||||
|
version: '1.0.0',
|
||||||
|
description: 'Model 1 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model1',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
created: new Date(),
|
||||||
|
updated: new Date(),
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
{
|
||||||
|
id: '2',
|
||||||
|
name: 'Model 2',
|
||||||
|
version: '2.0.0',
|
||||||
|
description: 'Model 2 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model2',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
]
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
|
||||||
|
readDirSyncMock.mockImplementation((path) => {
|
||||||
|
if (path === 'file://models') return ['model1', 'model2/model2-1']
|
||||||
|
else if (path === 'file://models/model1')
|
||||||
|
return ['model.json', 'test.GGUF']
|
||||||
|
else return ['model.json', 'test.gguf']
|
||||||
|
})
|
||||||
|
|
||||||
|
readFileSyncMock.mockImplementation((path) => {
|
||||||
|
if (path.includes('model1'))
|
||||||
|
return JSON.stringify(configuredModels[0])
|
||||||
|
else return JSON.stringify(configuredModels[1])
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await sut.getDownloadedModels()
|
||||||
|
expect(result).toEqual(
|
||||||
|
expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
id: '1',
|
||||||
|
}),
|
||||||
|
expect.objectContaining({
|
||||||
|
file_path: 'file://models/model2/model2-1/model.json',
|
||||||
|
id: '2',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('all models are downloaded - GGUF & Tensort RT', () => {
|
||||||
|
it('returns downloaded models - with correct file_path and model id', async () => {
|
||||||
|
// Mock configured models data
|
||||||
|
const configuredModels = [
|
||||||
|
{
|
||||||
|
id: '1',
|
||||||
|
name: 'Model 1',
|
||||||
|
version: '1.0.0',
|
||||||
|
description: 'Model 1 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model1',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
created: new Date(),
|
||||||
|
updated: new Date(),
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
{
|
||||||
|
id: '2',
|
||||||
|
name: 'Model 2',
|
||||||
|
version: '2.0.0',
|
||||||
|
description: 'Model 2 description',
|
||||||
|
object: {
|
||||||
|
type: 'model',
|
||||||
|
uri: 'http://localhost:5000/models/model2',
|
||||||
|
},
|
||||||
|
format: 'onnx',
|
||||||
|
sources: [],
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
metadata: {},
|
||||||
|
engine: 'test',
|
||||||
|
} as any,
|
||||||
|
]
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
|
||||||
|
readDirSyncMock.mockImplementation((path) => {
|
||||||
|
if (path === 'file://models') return ['model1', 'model2/model2-1']
|
||||||
|
else if (path === 'file://models/model1')
|
||||||
|
return ['model.json', 'test.gguf']
|
||||||
|
else return ['model.json', 'test.engine']
|
||||||
|
})
|
||||||
|
|
||||||
|
readFileSyncMock.mockImplementation((path) => {
|
||||||
|
if (path.includes('model1'))
|
||||||
|
return JSON.stringify(configuredModels[0])
|
||||||
|
else return JSON.stringify(configuredModels[1])
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await sut.getDownloadedModels()
|
||||||
|
expect(result).toEqual(
|
||||||
|
expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
id: '1',
|
||||||
|
}),
|
||||||
|
expect.objectContaining({
|
||||||
|
file_path: 'file://models/model2/model2-1/model.json',
|
||||||
|
id: '2',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('deleteModel', () => {
|
||||||
|
describe('model is a GGUF model', () => {
|
||||||
|
it('should delete the GGUF file', async () => {
|
||||||
|
fs.unlinkSync = jest.fn()
|
||||||
|
const dirMock = dirName as jest.Mock
|
||||||
|
dirMock.mockReturnValue('file://models/model1')
|
||||||
|
|
||||||
|
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
|
||||||
|
|
||||||
|
readDirSyncMock.mockImplementation((path) => {
|
||||||
|
return ['model.json', 'test.gguf']
|
||||||
|
})
|
||||||
|
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
|
||||||
|
await sut.deleteModel({
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
} as any)
|
||||||
|
|
||||||
|
expect(fs.unlinkSync).toHaveBeenCalledWith(
|
||||||
|
'file://models/model1/test.gguf'
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('no gguf file presented', async () => {
|
||||||
|
fs.unlinkSync = jest.fn()
|
||||||
|
const dirMock = dirName as jest.Mock
|
||||||
|
dirMock.mockReturnValue('file://models/model1')
|
||||||
|
|
||||||
|
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
|
||||||
|
|
||||||
|
readDirSyncMock.mockReturnValue(['model.json'])
|
||||||
|
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
|
||||||
|
await sut.deleteModel({
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
} as any)
|
||||||
|
|
||||||
|
expect(fs.unlinkSync).toHaveBeenCalledTimes(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('delete an imported model', async () => {
|
||||||
|
fs.rm = jest.fn()
|
||||||
|
const dirMock = dirName as jest.Mock
|
||||||
|
dirMock.mockReturnValue('file://models/model1')
|
||||||
|
|
||||||
|
readDirSyncMock.mockReturnValue(['model.json', 'test.gguf'])
|
||||||
|
|
||||||
|
// NOTE: this mirrors some tricky legacy logic.
|
||||||
|
// Adding a test for now; it should be aligned with the legacy implementation later.
|
||||||
|
fs.readFileSync = jest.fn().mockReturnValue(
|
||||||
|
JSON.stringify({
|
||||||
|
metadata: {
|
||||||
|
author: 'user',
|
||||||
|
},
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
|
||||||
|
await sut.deleteModel({
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
} as any)
|
||||||
|
|
||||||
|
expect(fs.rm).toHaveBeenCalledWith('file://models/model1')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should delete the engine file of a TensorRT-LLM model', async () => {
|
||||||
|
fs.rm = jest.fn()
|
||||||
|
const dirMock = dirName as jest.Mock
|
||||||
|
dirMock.mockReturnValue('file://models/model1')
|
||||||
|
|
||||||
|
readDirSyncMock.mockReturnValue(['model.json', 'test.engine'])
|
||||||
|
|
||||||
|
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
|
||||||
|
|
||||||
|
existMock.mockReturnValue(true)
|
||||||
|
|
||||||
|
await sut.deleteModel({
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
} as any)
|
||||||
|
|
||||||
|
expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
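A hedged outline of the behavior these deleteModel cases pin down (illustrative only; error handling and the exact file matching of the shipped method are simplified here):

import { fs, dirName, joinPath, ModelFile } from '@janhq/core'

// Sketch of the branching the tests assert on: user-imported models lose the
// whole folder, other models only lose their binaries.
async function deleteModelSketch(model: ModelFile): Promise<void> {
  const dirPath = await dirName(model.file_path)
  const modelInfo = JSON.parse(await fs.readFileSync(model.file_path, 'utf-8'))

  // User-imported models: remove the entire folder.
  if (modelInfo.metadata?.author?.toLowerCase() === 'user') {
    return fs.rm(dirPath)
  }

  // Otherwise delete only the model binaries (.gguf / .engine) and keep the rest.
  const files: string[] = await fs.readdirSync(dirPath)
  for (const file of files) {
    if (file.endsWith('.gguf') || file.endsWith('.engine')) {
      await fs.unlinkSync(await joinPath([dirPath, file]))
    }
  }
}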
|
||||||
@ -22,6 +22,8 @@ import {
|
|||||||
getFileSize,
|
getFileSize,
|
||||||
AllQuantizations,
|
AllQuantizations,
|
||||||
ModelEvent,
|
ModelEvent,
|
||||||
|
ModelFile,
|
||||||
|
dirName,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
|
|
||||||
import { extractFileName } from './helpers/path'
|
import { extractFileName } from './helpers/path'
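The hunks below move several APIs from Model to the newly imported ModelFile. Its exact definition is not part of this diff; a plausible shape, inferred from the file_path and file_name fields used later, would be:

import { Model } from '@janhq/core'

// Assumed shape only: a Model plus the on-disk location of its model.json.
type ModelFile = Model & {
  file_path: string // e.g. 'file://models/model1/model.json'
  file_name: string // usually 'model.json'
}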
|
||||||
@ -48,16 +50,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
]
|
]
|
||||||
private static readonly _tensorRtEngineFormat = '.engine'
|
private static readonly _tensorRtEngineFormat = '.engine'
|
||||||
private static readonly _supportedGpuArch = ['ampere', 'ada']
|
private static readonly _supportedGpuArch = ['ampere', 'ada']
|
||||||
private static readonly _safetensorsRegexs = [
|
|
||||||
/model\.safetensors$/,
|
|
||||||
/model-[0-9]+-of-[0-9]+\.safetensors$/,
|
|
||||||
]
|
|
||||||
private static readonly _pytorchRegexs = [
|
|
||||||
/pytorch_model\.bin$/,
|
|
||||||
/consolidated\.[0-9]+\.pth$/,
|
|
||||||
/pytorch_model-[0-9]+-of-[0-9]+\.bin$/,
|
|
||||||
/.*\.pt$/,
|
|
||||||
]
|
|
||||||
interrupted = false
|
interrupted = false
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -319,9 +312,9 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
* @param filePath - The path to the model file to delete.
|
* @param filePath - The path to the model file to delete.
|
||||||
* @returns A Promise that resolves when the model is deleted.
|
* @returns A Promise that resolves when the model is deleted.
|
||||||
*/
|
*/
|
||||||
async deleteModel(modelId: string): Promise<void> {
|
async deleteModel(model: ModelFile): Promise<void> {
|
||||||
try {
|
try {
|
||||||
const dirPath = await joinPath([JanModelExtension._homeDir, modelId])
|
const dirPath = await dirName(model.file_path)
|
||||||
const jsonFilePath = await joinPath([
|
const jsonFilePath = await joinPath([
|
||||||
dirPath,
|
dirPath,
|
||||||
JanModelExtension._modelMetadataFileName,
|
JanModelExtension._modelMetadataFileName,
|
||||||
@ -330,9 +323,11 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
await this.readModelMetadata(jsonFilePath)
|
await this.readModelMetadata(jsonFilePath)
|
||||||
) as Model
|
) as Model
|
||||||
|
|
||||||
|
// TODO: this author check is brittle.
|
||||||
|
// It should probably rely on the model's sources instead.
|
||||||
const isUserImportModel =
|
const isUserImportModel =
|
||||||
modelInfo.metadata?.author?.toLowerCase() === 'user'
|
modelInfo.metadata?.author?.toLowerCase() === 'user'
|
||||||
if (isUserImportModel) {
|
if (isUserImportModel) {
|
||||||
// just delete the folder
|
// just delete the folder
|
||||||
return fs.rm(dirPath)
|
return fs.rm(dirPath)
|
||||||
}
|
}
|
||||||
@ -350,30 +345,11 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Saves a model file.
|
|
||||||
* @param model - The model to save.
|
|
||||||
* @returns A Promise that resolves when the model is saved.
|
|
||||||
*/
|
|
||||||
async saveModel(model: Model): Promise<void> {
|
|
||||||
const jsonFilePath = await joinPath([
|
|
||||||
JanModelExtension._homeDir,
|
|
||||||
model.id,
|
|
||||||
JanModelExtension._modelMetadataFileName,
|
|
||||||
])
|
|
||||||
|
|
||||||
try {
|
|
||||||
await fs.writeFileSync(jsonFilePath, JSON.stringify(model, null, 2))
|
|
||||||
} catch (err) {
|
|
||||||
console.error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets all downloaded models.
|
* Gets all downloaded models.
|
||||||
* @returns A Promise that resolves with an array of all models.
|
* @returns A Promise that resolves with an array of all models.
|
||||||
*/
|
*/
|
||||||
async getDownloadedModels(): Promise<Model[]> {
|
async getDownloadedModels(): Promise<ModelFile[]> {
|
||||||
return await this.getModelsMetadata(
|
return await this.getModelsMetadata(
|
||||||
async (modelDir: string, model: Model) => {
|
async (modelDir: string, model: Model) => {
|
||||||
if (!JanModelExtension._offlineInferenceEngine.includes(model.engine))
|
if (!JanModelExtension._offlineInferenceEngine.includes(model.engine))
|
||||||
@ -425,8 +401,10 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
): Promise<string | undefined> {
|
): Promise<string | undefined> {
|
||||||
// try to find model.json recursively inside each folder
|
// try to find model.json recursively inside each folder
|
||||||
if (!(await fs.existsSync(folderFullPath))) return undefined
|
if (!(await fs.existsSync(folderFullPath))) return undefined
|
||||||
|
|
||||||
const files: string[] = await fs.readdirSync(folderFullPath)
|
const files: string[] = await fs.readdirSync(folderFullPath)
|
||||||
if (files.length === 0) return undefined
|
if (files.length === 0) return undefined
|
||||||
|
|
||||||
if (files.includes(JanModelExtension._modelMetadataFileName)) {
|
if (files.includes(JanModelExtension._modelMetadataFileName)) {
|
||||||
return joinPath([
|
return joinPath([
|
||||||
folderFullPath,
|
folderFullPath,
|
||||||
@ -446,7 +424,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
|
|
||||||
private async getModelsMetadata(
|
private async getModelsMetadata(
|
||||||
selector?: (path: string, model: Model) => Promise<boolean>
|
selector?: (path: string, model: Model) => Promise<boolean>
|
||||||
): Promise<Model[]> {
|
): Promise<ModelFile[]> {
|
||||||
try {
|
try {
|
||||||
if (!(await fs.existsSync(JanModelExtension._homeDir))) {
|
if (!(await fs.existsSync(JanModelExtension._homeDir))) {
|
||||||
console.debug('Model folder not found')
|
console.debug('Model folder not found')
|
||||||
@ -469,6 +447,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
JanModelExtension._homeDir,
|
JanModelExtension._homeDir,
|
||||||
dirName,
|
dirName,
|
||||||
])
|
])
|
||||||
|
|
||||||
const jsonPath = await this.getModelJsonPath(folderFullPath)
|
const jsonPath = await this.getModelJsonPath(folderFullPath)
|
||||||
|
|
||||||
if (await fs.existsSync(jsonPath)) {
|
if (await fs.existsSync(jsonPath)) {
|
||||||
@ -486,6 +465,8 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
model.file_path = jsonPath
|
||||||
|
model.file_name = JanModelExtension._modelMetadataFileName
|
||||||
|
|
||||||
if (selector && !(await selector?.(dirName, model))) {
|
if (selector && !(await selector?.(dirName, model))) {
|
||||||
return
|
return
|
||||||
@ -506,7 +487,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
typeof result.value === 'object'
|
typeof result.value === 'object'
|
||||||
? result.value
|
? result.value
|
||||||
: JSON.parse(result.value)
|
: JSON.parse(result.value)
|
||||||
return model as Model
|
return model as ModelFile
|
||||||
} catch {
|
} catch {
|
||||||
console.debug(`Unable to parse model metadata: ${result.value}`)
|
console.debug(`Unable to parse model metadata: ${result.value}`)
|
||||||
}
|
}
|
||||||
@ -637,7 +618,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
* Gets all available models.
|
* Gets all available models.
|
||||||
* @returns A Promise that resolves with an array of all models.
|
* @returns A Promise that resolves with an array of all models.
|
||||||
*/
|
*/
|
||||||
async getConfiguredModels(): Promise<Model[]> {
|
async getConfiguredModels(): Promise<ModelFile[]> {
|
||||||
return this.getModelsMetadata()
|
return this.getModelsMetadata()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -669,7 +650,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
modelBinaryPath: string,
|
modelBinaryPath: string,
|
||||||
modelFolderName: string,
|
modelFolderName: string,
|
||||||
modelFolderPath: string
|
modelFolderPath: string
|
||||||
): Promise<Model> {
|
): Promise<ModelFile> {
|
||||||
const fileStats = await fs.fileStat(modelBinaryPath, true)
|
const fileStats = await fs.fileStat(modelBinaryPath, true)
|
||||||
const binaryFileSize = fileStats.size
|
const binaryFileSize = fileStats.size
|
||||||
|
|
||||||
@ -732,25 +713,21 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
|
|
||||||
await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2))
|
await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2))
|
||||||
|
|
||||||
return model
|
return {
|
||||||
|
...model,
|
||||||
|
file_path: modelFilePath,
|
||||||
|
file_name: JanModelExtension._modelMetadataFileName,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async updateModelInfo(modelInfo: Partial<Model>): Promise<Model> {
|
async updateModelInfo(modelInfo: Partial<ModelFile>): Promise<ModelFile> {
|
||||||
const modelId = modelInfo.id
|
|
||||||
if (modelInfo.id == null) throw new Error('Model ID is required')
|
if (modelInfo.id == null) throw new Error('Model ID is required')
|
||||||
|
|
||||||
const janDataFolderPath = await getJanDataFolderPath()
|
|
||||||
const jsonFilePath = await joinPath([
|
|
||||||
janDataFolderPath,
|
|
||||||
'models',
|
|
||||||
modelId,
|
|
||||||
JanModelExtension._modelMetadataFileName,
|
|
||||||
])
|
|
||||||
const model = JSON.parse(
|
const model = JSON.parse(
|
||||||
await this.readModelMetadata(jsonFilePath)
|
await this.readModelMetadata(modelInfo.file_path)
|
||||||
) as Model
|
) as ModelFile
|
||||||
|
|
||||||
const updatedModel: Model = {
|
const updatedModel: ModelFile = {
|
||||||
...model,
|
...model,
|
||||||
...modelInfo,
|
...modelInfo,
|
||||||
parameters: {
|
parameters: {
|
||||||
@ -765,9 +742,15 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
...model.metadata,
|
...model.metadata,
|
||||||
...modelInfo.metadata,
|
...modelInfo.metadata,
|
||||||
},
|
},
|
||||||
|
// Should not persist file_path & file_name
|
||||||
|
file_path: undefined,
|
||||||
|
file_name: undefined,
|
||||||
}
|
}
|
||||||
|
|
||||||
await fs.writeFileSync(jsonFilePath, JSON.stringify(updatedModel, null, 2))
|
await fs.writeFileSync(
|
||||||
|
modelInfo.file_path,
|
||||||
|
JSON.stringify(updatedModel, null, 2)
|
||||||
|
)
|
||||||
return updatedModel
|
return updatedModel
|
||||||
}
|
}
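A hedged usage sketch for the reworked updateModelInfo above (this assumes updateModelInfo is exposed on the ModelExtension base class, as deleteModel and getDownloadedModels are; the id requirement and the write-back to file_path follow the hunk):

import { ExtensionTypeEnum, ModelExtension, ModelFile } from '@janhq/core'
import { extensionManager } from '@/extension'

// Illustrative only: persist a larger context length back into the model's own
// model.json, which updateModelInfo now locates via file_path.
async function bumpContextLength(model: ModelFile): Promise<ModelFile | undefined> {
  return extensionManager
    .get<ModelExtension>(ExtensionTypeEnum.Model)
    ?.updateModelInfo({
      id: model.id,
      file_path: model.file_path,
      settings: { ...model.settings, ctx_len: 4096 },
    })
}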
|
||||||
|
|
||||||
|
|||||||
@ -10,5 +10,6 @@
|
|||||||
"skipLibCheck": true,
|
"skipLibCheck": true,
|
||||||
"rootDir": "./src"
|
"rootDir": "./src"
|
||||||
},
|
},
|
||||||
"include": ["./src"]
|
"include": ["./src"],
|
||||||
|
"exclude": ["**/*.test.ts"]
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,6 +23,7 @@ import {
|
|||||||
ModelEvent,
|
ModelEvent,
|
||||||
getJanDataFolderPath,
|
getJanDataFolderPath,
|
||||||
SystemInformation,
|
SystemInformation,
|
||||||
|
ModelFile,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -137,7 +138,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
|
|||||||
events.emit(ModelEvent.OnModelsUpdate, {})
|
events.emit(ModelEvent.OnModelsUpdate, {})
|
||||||
}
|
}
|
||||||
|
|
||||||
override async loadModel(model: Model): Promise<void> {
|
override async loadModel(model: ModelFile): Promise<void> {
|
||||||
if ((await this.installationState()) === 'Installed')
|
if ((await this.installationState()) === 'Installed')
|
||||||
return super.loadModel(model)
|
return super.loadModel(model)
|
||||||
|
|
||||||
|
|||||||
@ -97,7 +97,7 @@ function unloadModel(): Promise<void> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (subprocess?.pid) {
|
if (subprocess?.pid) {
|
||||||
log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
|
log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
|
||||||
const pid = subprocess.pid
|
const pid = subprocess.pid
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
terminate(pid, function (err) {
|
terminate(pid, function (err) {
|
||||||
@ -107,7 +107,7 @@ function unloadModel(): Promise<void> {
|
|||||||
return tcpPortUsed
|
return tcpPortUsed
|
||||||
.waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000)
|
.waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000)
|
||||||
.then(() => resolve())
|
.then(() => resolve())
|
||||||
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
|
.then(() => log(`[CORTEX]:: cortex process is terminated`))
|
||||||
.catch(() => {
|
.catch(() => {
|
||||||
killRequest()
|
killRequest()
|
||||||
})
|
})
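The surrounding code implements a kill-then-wait shutdown. A condensed sketch of that pattern (timing values are illustrative; terminate and tcp-port-used are the packages already used above):

import terminate from 'terminate'
import tcpPortUsed from 'tcp-port-used'

// Kill the engine process, then resolve only once its TCP port is free again.
function killAndWaitForPort(pid: number, port: number): Promise<void> {
  return new Promise((resolve, reject) => {
    terminate(pid, (err?: Error) => {
      if (err) return reject(err)
      tcpPortUsed
        .waitUntilFree(port, 300, 5000) // poll every 300 ms, give up after 5 s
        .then(() => resolve())
        .catch(reject)
    })
  })
}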
|
||||||
|
|||||||
@ -12,17 +12,18 @@ import { twMerge } from 'tailwind-merge'
|
|||||||
|
|
||||||
import { MainViewState } from '@/constants/screens'
|
import { MainViewState } from '@/constants/screens'
|
||||||
|
|
||||||
import { localEngines } from '@/utils/modelEngine'
|
|
||||||
|
|
||||||
import { mainViewStateAtom, showLeftPanelAtom } from '@/helpers/atoms/App.atom'
|
import { mainViewStateAtom, showLeftPanelAtom } from '@/helpers/atoms/App.atom'
|
||||||
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
|
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
|
||||||
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
|
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
|
||||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
|
||||||
import {
|
import {
|
||||||
reduceTransparentAtom,
|
reduceTransparentAtom,
|
||||||
selectedSettingAtom,
|
selectedSettingAtom,
|
||||||
} from '@/helpers/atoms/Setting.atom'
|
} from '@/helpers/atoms/Setting.atom'
|
||||||
import { threadsAtom } from '@/helpers/atoms/Thread.atom'
|
import {
|
||||||
|
isDownloadALocalModelAtom,
|
||||||
|
threadsAtom,
|
||||||
|
} from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
export default function RibbonPanel() {
|
export default function RibbonPanel() {
|
||||||
const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom)
|
const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom)
|
||||||
@ -32,8 +33,9 @@ export default function RibbonPanel() {
|
|||||||
const matches = useMediaQuery('(max-width: 880px)')
|
const matches = useMediaQuery('(max-width: 880px)')
|
||||||
const reduceTransparent = useAtomValue(reduceTransparentAtom)
|
const reduceTransparent = useAtomValue(reduceTransparentAtom)
|
||||||
const setSelectedSetting = useSetAtom(selectedSettingAtom)
|
const setSelectedSetting = useSetAtom(selectedSettingAtom)
|
||||||
const downloadedModels = useAtomValue(downloadedModelsAtom)
|
|
||||||
const threads = useAtomValue(threadsAtom)
|
const threads = useAtomValue(threadsAtom)
|
||||||
|
const isDownloadALocalModel = useAtomValue(isDownloadALocalModelAtom)
|
||||||
|
|
||||||
const onMenuClick = (state: MainViewState) => {
|
const onMenuClick = (state: MainViewState) => {
|
||||||
if (mainViewState === state) return
|
if (mainViewState === state) return
|
||||||
@ -43,10 +45,6 @@ export default function RibbonPanel() {
|
|||||||
setEditMessage('')
|
setEditMessage('')
|
||||||
}
|
}
|
||||||
|
|
||||||
const isDownloadALocalModel = downloadedModels.some((x) =>
|
|
||||||
localEngines.includes(x.engine)
|
|
||||||
)
|
|
||||||
|
|
||||||
const RibbonNavMenus = [
|
const RibbonNavMenus = [
|
||||||
{
|
{
|
||||||
name: 'Thread',
|
name: 'Thread',
|
||||||
|
|||||||
@ -23,6 +23,7 @@ import { toaster } from '@/containers/Toast'
|
|||||||
import { MainViewState } from '@/constants/screens'
|
import { MainViewState } from '@/constants/screens'
|
||||||
|
|
||||||
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
||||||
|
import { useStarterScreen } from '@/hooks/useStarterScreen'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
mainViewStateAtom,
|
mainViewStateAtom,
|
||||||
@ -58,6 +59,8 @@ const TopPanel = () => {
|
|||||||
requestCreateNewThread(assistants[0])
|
requestCreateNewThread(assistants[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const { isShowStarterScreen } = useStarterScreen()
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={twMerge(
|
className={twMerge(
|
||||||
@ -93,7 +96,7 @@ const TopPanel = () => {
|
|||||||
)}
|
)}
|
||||||
</Fragment>
|
</Fragment>
|
||||||
)}
|
)}
|
||||||
{mainViewState === MainViewState.Thread && (
|
{mainViewState === MainViewState.Thread && !isShowStarterScreen && (
|
||||||
<Button
|
<Button
|
||||||
data-testid="btn-create-thread"
|
data-testid="btn-create-thread"
|
||||||
onClick={onCreateNewThreadClick}
|
onClick={onCreateNewThreadClick}
|
||||||
|
|||||||
@ -46,7 +46,6 @@ import {
|
|||||||
|
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
|
|
||||||
import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
|
|
||||||
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
|
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
|
||||||
import {
|
import {
|
||||||
configuredModelsAtom,
|
configuredModelsAtom,
|
||||||
@ -91,8 +90,6 @@ const ModelDropdown = ({
|
|||||||
const featuredModel = configuredModels.filter((x) =>
|
const featuredModel = configuredModels.filter((x) =>
|
||||||
x.metadata.tags.includes('Featured')
|
x.metadata.tags.includes('Featured')
|
||||||
)
|
)
|
||||||
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
|
|
||||||
|
|
||||||
const { updateThreadMetadata } = useCreateNewThread()
|
const { updateThreadMetadata } = useCreateNewThread()
|
||||||
|
|
||||||
useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
|
useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
|
||||||
@ -191,27 +188,14 @@ const ModelDropdown = ({
|
|||||||
],
|
],
|
||||||
})
|
})
|
||||||
|
|
||||||
// Default setting ctx_len for the model for a better onboarding experience
|
|
||||||
// TODO: When Cortex support hardware instructions, we should remove this
|
|
||||||
const defaultContextLength = preserveModelSettings
|
|
||||||
? model?.metadata?.default_ctx_len
|
|
||||||
: 2048
|
|
||||||
const defaultMaxTokens = preserveModelSettings
|
|
||||||
? model?.metadata?.default_max_tokens
|
|
||||||
: 2048
|
|
||||||
const overriddenSettings =
|
const overriddenSettings =
|
||||||
model?.settings.ctx_len && model.settings.ctx_len > 2048
|
model?.settings.ctx_len && model.settings.ctx_len > 4096
|
||||||
? { ctx_len: defaultContextLength ?? 2048 }
|
? { ctx_len: 4096 }
|
||||||
: {}
|
|
||||||
const overriddenParameters =
|
|
||||||
model?.parameters.max_tokens && model.parameters.max_tokens
|
|
||||||
? { max_tokens: defaultMaxTokens ?? 2048 }
|
|
||||||
: {}
|
: {}
|
||||||
|
|
||||||
const modelParams = {
|
const modelParams = {
|
||||||
...model?.parameters,
|
...model?.parameters,
|
||||||
...model?.settings,
|
...model?.settings,
|
||||||
...overriddenParameters,
|
|
||||||
...overriddenSettings,
|
...overriddenSettings,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -222,6 +206,7 @@ const ModelDropdown = ({
|
|||||||
if (model)
|
if (model)
|
||||||
updateModelParameter(activeThread, {
|
updateModelParameter(activeThread, {
|
||||||
params: modelParams,
|
params: modelParams,
|
||||||
|
modelPath: model.file_path,
|
||||||
modelId: model.id,
|
modelId: model.id,
|
||||||
engine: model.engine,
|
engine: model.engine,
|
||||||
})
|
})
|
||||||
@ -235,7 +220,6 @@ const ModelDropdown = ({
|
|||||||
setThreadModelParams,
|
setThreadModelParams,
|
||||||
updateModelParameter,
|
updateModelParameter,
|
||||||
updateThreadMetadata,
|
updateThreadMetadata,
|
||||||
preserveModelSettings,
|
|
||||||
]
|
]
|
||||||
)
|
)
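The hunk above replaces the preserveModelSettings-driven defaults with a fixed cap. A small sketch of the new rule (field names follow the hunk; the 4096 threshold is taken from it):

import { ModelFile } from '@janhq/core'

// If the selected model declares a context length above 4096, start the
// thread with 4096 instead; otherwise keep the model's own settings.
const capContextLength = (model?: ModelFile): { ctx_len?: number } =>
  model?.settings.ctx_len && model.settings.ctx_len > 4096 ? { ctx_len: 4096 } : {}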
|
||||||
|
|
||||||
|
|||||||
100
web/containers/ModelLabel/ModelLabel.test.tsx
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
import React from 'react'
|
||||||
|
import { render, waitFor, screen } from '@testing-library/react'
|
||||||
|
import { useAtomValue } from 'jotai'
|
||||||
|
import { useActiveModel } from '@/hooks/useActiveModel'
|
||||||
|
import { useSettings } from '@/hooks/useSettings'
|
||||||
|
import ModelLabel from '@/containers/ModelLabel'
|
||||||
|
|
||||||
|
jest.mock('jotai', () => ({
|
||||||
|
useAtomValue: jest.fn(),
|
||||||
|
atom: jest.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
jest.mock('@/hooks/useActiveModel', () => ({
|
||||||
|
useActiveModel: jest.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
jest.mock('@/hooks/useSettings', () => ({
|
||||||
|
useSettings: jest.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('ModelLabel', () => {
|
||||||
|
const mockUseAtomValue = useAtomValue as jest.Mock
|
||||||
|
const mockUseActiveModel = useActiveModel as jest.Mock
|
||||||
|
const mockUseSettings = useSettings as jest.Mock
|
||||||
|
|
||||||
|
const defaultProps: any = {
|
||||||
|
metadata: {
|
||||||
|
author: 'John Doe', // the mock metadata needs an author value
|
||||||
|
tags: ['8B'],
|
||||||
|
size: 100,
|
||||||
|
},
|
||||||
|
compact: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('renders NotEnoughMemoryLabel when minimumRamModel is greater than totalRam', async () => {
|
||||||
|
mockUseAtomValue
|
||||||
|
.mockReturnValueOnce(0)
|
||||||
|
.mockReturnValueOnce(0)
|
||||||
|
.mockReturnValueOnce(0)
|
||||||
|
mockUseActiveModel.mockReturnValue({
|
||||||
|
activeModel: { metadata: { size: 0 } },
|
||||||
|
})
|
||||||
|
mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
|
||||||
|
|
||||||
|
render(<ModelLabel {...defaultProps} />)
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('Not enough RAM')).toBeDefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('renders SlowOnYourDeviceLabel when minimumRamModel is less than totalRam but greater than availableRam', async () => {
|
||||||
|
mockUseAtomValue
|
||||||
|
.mockReturnValueOnce(100)
|
||||||
|
.mockReturnValueOnce(50)
|
||||||
|
.mockReturnValueOnce(10)
|
||||||
|
mockUseActiveModel.mockReturnValue({
|
||||||
|
activeModel: { metadata: { size: 0 } },
|
||||||
|
})
|
||||||
|
mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
|
||||||
|
|
||||||
|
const props = {
|
||||||
|
...defaultProps,
|
||||||
|
metadata: {
|
||||||
|
...defaultProps.metadata,
|
||||||
|
size: 50,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
render(<ModelLabel {...props} />)
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('Slow on your device')).toBeDefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('renders nothing when minimumRamModel is less than availableRam', () => {
|
||||||
|
mockUseAtomValue
|
||||||
|
.mockReturnValueOnce(100)
|
||||||
|
.mockReturnValueOnce(50)
|
||||||
|
.mockReturnValueOnce(0)
|
||||||
|
mockUseActiveModel.mockReturnValue({
|
||||||
|
activeModel: { metadata: { size: 0 } },
|
||||||
|
})
|
||||||
|
mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
|
||||||
|
|
||||||
|
const props = {
|
||||||
|
...defaultProps,
|
||||||
|
metadata: {
|
||||||
|
...defaultProps.metadata,
|
||||||
|
size: 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
const { container } = render(<ModelLabel {...props} />)
|
||||||
|
expect(container.firstChild).toBeNull()
|
||||||
|
})
|
||||||
|
})
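The three cases above map directly onto a small decision function. A sketch of that logic (names are assumptions; the actual component renders labels rather than returning strings):

type RamVerdict = 'not-enough-ram' | 'slow-on-your-device' | 'ok'

// Compare the model's RAM requirement against total and currently available RAM.
function ramVerdict(
  minimumRamModel: number,
  totalRam: number,
  availableRam: number
): RamVerdict {
  if (minimumRamModel > totalRam) return 'not-enough-ram'
  if (minimumRamModel < totalRam && minimumRamModel > availableRam)
    return 'slow-on-your-device'
  return 'ok'
}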
|
||||||
@ -10,8 +10,6 @@ import { useSettings } from '@/hooks/useSettings'
|
|||||||
|
|
||||||
import NotEnoughMemoryLabel from './NotEnoughMemoryLabel'
|
import NotEnoughMemoryLabel from './NotEnoughMemoryLabel'
|
||||||
|
|
||||||
import RecommendedLabel from './RecommendedLabel'
|
|
||||||
|
|
||||||
import SlowOnYourDeviceLabel from './SlowOnYourDeviceLabel'
|
import SlowOnYourDeviceLabel from './SlowOnYourDeviceLabel'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@ -53,9 +51,7 @@ const ModelLabel = ({ metadata, compact }: Props) => {
|
|||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if (minimumRamModel < availableRam && !compact) {
|
|
||||||
return <RecommendedLabel />
|
|
||||||
}
|
|
||||||
if (minimumRamModel < totalRam && minimumRamModel > availableRam) {
|
if (minimumRamModel < totalRam && minimumRamModel > availableRam) {
|
||||||
return <SlowOnYourDeviceLabel compact={compact} />
|
return <SlowOnYourDeviceLabel compact={compact} />
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,7 +20,7 @@ import { ulid } from 'ulidx'
|
|||||||
|
|
||||||
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
|
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
|
||||||
|
|
||||||
import { toRuntimeParams } from '@/utils/modelParam'
|
import { extractInferenceParams } from '@/utils/modelParam'
|
||||||
|
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
import {
|
import {
|
||||||
@ -256,7 +256,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const runtimeParams = toRuntimeParams(activeModelParamsRef.current)
|
const runtimeParams = extractInferenceParams(activeModelParamsRef.current)
|
||||||
|
|
||||||
const messageRequest: MessageRequest = {
|
const messageRequest: MessageRequest = {
|
||||||
id: msgId,
|
id: msgId,
|
||||||
|
|||||||
@ -87,26 +87,28 @@ const SliderRightPanel = ({
|
|||||||
onValueChanged?.(Number(min))
|
onValueChanged?.(Number(min))
|
||||||
setVal(min.toString())
|
setVal(min.toString())
|
||||||
setShowTooltip({ max: false, min: true })
|
setShowTooltip({ max: false, min: true })
|
||||||
|
} else {
|
||||||
|
setVal(Number(e.target.value).toString()) // normalize inputs such as '.5' to '0.5'
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
onChange={(e) => {
|
onChange={(e) => {
|
||||||
// Should not accept invalid value or NaN
|
|
||||||
// E.g. anything changes that trigger onValueChanged
|
|
||||||
// Which is incorrect
|
|
||||||
if (Number(e.target.value) > Number(max)) {
|
|
||||||
setVal(max.toString())
|
|
||||||
} else if (
|
|
||||||
Number(e.target.value) < Number(min) ||
|
|
||||||
!e.target.value.length
|
|
||||||
) {
|
|
||||||
setVal(min.toString())
|
|
||||||
} else if (Number.isNaN(Number(e.target.value))) return
|
|
||||||
|
|
||||||
onValueChanged?.(Number(e.target.value))
|
|
||||||
// TODO: How to support negative number input?
|
// TODO: How to support negative number input?
|
||||||
|
// Pass the value through here; it is validated again on blur
|
||||||
if (/^\d*\.?\d*$/.test(e.target.value)) {
|
if (/^\d*\.?\d*$/.test(e.target.value)) {
|
||||||
setVal(e.target.value)
|
setVal(e.target.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do not accept out-of-range values or NaN here,
|
||||||
|
// otherwise every keystroke would call onValueChanged
|
||||||
|
// with an incorrect value.
|
||||||
|
if (
|
||||||
|
Number(e.target.value) > Number(max) ||
|
||||||
|
Number(e.target.value) < Number(min) ||
|
||||||
|
Number.isNaN(Number(e.target.value))
|
||||||
|
) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
onValueChanged?.(Number(e.target.value))
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
}
|
}
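A condensed sketch of the input rules the new handlers implement (helper names are made up for illustration):

// onChange: keep the raw text only when it looks numeric, and notify the
// caller only when the parsed value lies inside [min, max].
const looksNumeric = (value: string): boolean => /^\d*\.?\d*$/.test(value)

const emitIfInRange = (
  value: string,
  min: number,
  max: number,
  emit: (n: number) => void
): void => {
  const n = Number(value)
  if (Number.isNaN(n) || n < min || n > max) return
  emit(n)
}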
|
||||||
|
|||||||
@ -7,7 +7,6 @@ const VULKAN_ENABLED = 'vulkanEnabled'
|
|||||||
const IGNORE_SSL = 'ignoreSSLFeature'
|
const IGNORE_SSL = 'ignoreSSLFeature'
|
||||||
const HTTPS_PROXY_FEATURE = 'httpsProxyFeature'
|
const HTTPS_PROXY_FEATURE = 'httpsProxyFeature'
|
||||||
const QUICK_ASK_ENABLED = 'quickAskEnabled'
|
const QUICK_ASK_ENABLED = 'quickAskEnabled'
|
||||||
const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings'
|
|
||||||
|
|
||||||
export const janDataFolderPathAtom = atom('')
|
export const janDataFolderPathAtom = atom('')
|
||||||
|
|
||||||
@ -24,9 +23,3 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false)
|
|||||||
export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false)
|
export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false)
|
||||||
|
|
||||||
export const hostAtom = atom('http://localhost:1337/')
|
export const hostAtom = atom('http://localhost:1337/')
|
||||||
|
|
||||||
// This feature is to allow user to cache model settings on thread creation
|
|
||||||
export const preserveModelSettingsAtom = atomWithStorage(
|
|
||||||
PRESERVE_MODEL_SETTINGS,
|
|
||||||
false
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import { ImportingModel, Model, InferenceEngine } from '@janhq/core'
|
import { ImportingModel, Model, InferenceEngine, ModelFile } from '@janhq/core'
|
||||||
import { atom } from 'jotai'
|
import { atom } from 'jotai'
|
||||||
|
|
||||||
import { localEngines } from '@/utils/modelEngine'
|
import { localEngines } from '@/utils/modelEngine'
|
||||||
@ -32,18 +32,7 @@ export const removeDownloadingModelAtom = atom(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
export const downloadedModelsAtom = atom<Model[]>([])
|
export const downloadedModelsAtom = atom<ModelFile[]>([])
|
||||||
|
|
||||||
export const updateDownloadedModelAtom = atom(
|
|
||||||
null,
|
|
||||||
(get, set, updatedModel: Model) => {
|
|
||||||
const models: Model[] = get(downloadedModelsAtom).map((c) =>
|
|
||||||
c.id === updatedModel.id ? updatedModel : c
|
|
||||||
)
|
|
||||||
|
|
||||||
set(downloadedModelsAtom, models)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
export const removeDownloadedModelAtom = atom(
|
export const removeDownloadedModelAtom = atom(
|
||||||
null,
|
null,
|
||||||
@ -57,7 +46,7 @@ export const removeDownloadedModelAtom = atom(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
export const configuredModelsAtom = atom<Model[]>([])
|
export const configuredModelsAtom = atom<ModelFile[]>([])
|
||||||
|
|
||||||
export const defaultModelAtom = atom<Model | undefined>(undefined)
|
export const defaultModelAtom = atom<Model | undefined>(undefined)
|
||||||
|
|
||||||
@ -144,6 +133,6 @@ export const updateImportingModelAtom = atom(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
export const selectedModelAtom = atom<Model | undefined>(undefined)
|
export const selectedModelAtom = atom<ModelFile | undefined>(undefined)
|
||||||
|
|
||||||
export const showEngineListModelAtom = atom<InferenceEngine[]>(localEngines)
|
export const showEngineListModelAtom = atom<InferenceEngine[]>(localEngines)
|
||||||
|
|||||||
@ -152,3 +152,6 @@ export const modalActionThreadAtom = atom<{
|
|||||||
showModal: undefined,
|
showModal: undefined,
|
||||||
thread: undefined,
|
thread: undefined,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const isDownloadALocalModelAtom = atom<boolean>(false)
|
||||||
|
export const isAnyRemoteModelConfiguredAtom = atom<boolean>(false)
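How these atoms get populated is not shown in this diff; one plausible wiring (an assumption) is a models hook that flips isDownloadALocalModelAtom once the downloaded list is known, replacing the per-render check removed from RibbonPanel:

import { useEffect } from 'react'
import { useSetAtom } from 'jotai'
import { ModelFile } from '@janhq/core'

import { localEngines } from '@/utils/modelEngine'

import { isDownloadALocalModelAtom } from '@/helpers/atoms/Thread.atom'

// Hypothetical: mark that at least one local-engine model has been downloaded.
const useSyncLocalModelFlag = (downloadedModels: ModelFile[]) => {
  const setIsDownloadALocalModel = useSetAtom(isDownloadALocalModelAtom)
  useEffect(() => {
    setIsDownloadALocalModel(
      downloadedModels.some((m) => localEngines.includes(m.engine))
    )
  }, [downloadedModels, setIsDownloadALocalModel])
}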
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import { useCallback, useEffect, useRef } from 'react'
|
import { useCallback, useEffect, useRef } from 'react'
|
||||||
|
|
||||||
import { EngineManager, Model } from '@janhq/core'
|
import { EngineManager, Model, ModelFile } from '@janhq/core'
|
||||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
|
|
||||||
import { toaster } from '@/containers/Toast'
|
import { toaster } from '@/containers/Toast'
|
||||||
@ -11,7 +11,7 @@ import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
|||||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
export const activeModelAtom = atom<Model | undefined>(undefined)
|
export const activeModelAtom = atom<ModelFile | undefined>(undefined)
|
||||||
export const loadModelErrorAtom = atom<string | undefined>(undefined)
|
export const loadModelErrorAtom = atom<string | undefined>(undefined)
|
||||||
|
|
||||||
type ModelState = {
|
type ModelState = {
|
||||||
@ -37,7 +37,7 @@ export function useActiveModel() {
|
|||||||
const [pendingModelLoad, setPendingModelLoad] = useAtom(pendingModelLoadAtom)
|
const [pendingModelLoad, setPendingModelLoad] = useAtom(pendingModelLoadAtom)
|
||||||
const isVulkanEnabled = useAtomValue(vulkanEnabledAtom)
|
const isVulkanEnabled = useAtomValue(vulkanEnabledAtom)
|
||||||
|
|
||||||
const downloadedModelsRef = useRef<Model[]>([])
|
const downloadedModelsRef = useRef<ModelFile[]>([])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
downloadedModelsRef.current = downloadedModels
|
downloadedModelsRef.current = downloadedModels
|
||||||
|
|||||||
@ -7,8 +7,8 @@ import {
|
|||||||
Thread,
|
Thread,
|
||||||
ThreadAssistantInfo,
|
ThreadAssistantInfo,
|
||||||
ThreadState,
|
ThreadState,
|
||||||
Model,
|
|
||||||
AssistantTool,
|
AssistantTool,
|
||||||
|
ModelFile,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
|
|
||||||
@ -26,10 +26,7 @@ import useSetActiveThread from './useSetActiveThread'
|
|||||||
|
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
|
|
||||||
import {
|
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||||
experimentalFeatureEnabledAtom,
|
|
||||||
preserveModelSettingsAtom,
|
|
||||||
} from '@/helpers/atoms/AppConfig.atom'
|
|
||||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import {
|
import {
|
||||||
threadsAtom,
|
threadsAtom,
|
||||||
@ -67,7 +64,6 @@ export const useCreateNewThread = () => {
|
|||||||
const copyOverInstructionEnabled = useAtomValue(
|
const copyOverInstructionEnabled = useAtomValue(
|
||||||
copyOverInstructionEnabledAtom
|
copyOverInstructionEnabledAtom
|
||||||
)
|
)
|
||||||
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
|
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
|
|
||||||
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
|
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
|
||||||
@ -80,7 +76,7 @@ export const useCreateNewThread = () => {
|
|||||||
|
|
||||||
const requestCreateNewThread = async (
|
const requestCreateNewThread = async (
|
||||||
assistant: Assistant,
|
assistant: Assistant,
|
||||||
model?: Model | undefined
|
model?: ModelFile | undefined
|
||||||
) => {
|
) => {
|
||||||
// Stop generating if any
|
// Stop generating if any
|
||||||
setIsGeneratingResponse(false)
|
setIsGeneratingResponse(false)
|
||||||
@ -109,19 +105,13 @@ export const useCreateNewThread = () => {
|
|||||||
enabled: true,
|
enabled: true,
|
||||||
settings: assistant.tools && assistant.tools[0].settings,
|
settings: assistant.tools && assistant.tools[0].settings,
|
||||||
}
|
}
|
||||||
const defaultContextLength = preserveModelSettings
|
|
||||||
? defaultModel?.metadata?.default_ctx_len
|
|
||||||
: 2048
|
|
||||||
const defaultMaxTokens = preserveModelSettings
|
|
||||||
? defaultModel?.metadata?.default_max_tokens
|
|
||||||
: 2048
|
|
||||||
const overriddenSettings =
|
const overriddenSettings =
|
||||||
defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048
|
defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048
|
||||||
? { ctx_len: defaultContextLength ?? 2048 }
|
? { ctx_len: 4096 }
|
||||||
: {}
|
: {}
|
||||||
|
|
||||||
const overriddenParameters = defaultModel?.parameters.max_tokens
|
const overriddenParameters = defaultModel?.parameters.max_tokens
|
||||||
? { max_tokens: defaultMaxTokens ?? 2048 }
|
? { max_tokens: 4096 }
|
||||||
: {}
|
: {}
|
||||||
|
|
||||||
const createdAt = Date.now()
|
const createdAt = Date.now()
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import { useCallback } from 'react'
|
import { useCallback } from 'react'
|
||||||
|
|
||||||
import { ExtensionTypeEnum, ModelExtension, Model } from '@janhq/core'
|
import { ExtensionTypeEnum, ModelExtension, ModelFile } from '@janhq/core'
|
||||||
|
|
||||||
import { useSetAtom } from 'jotai'
|
import { useSetAtom } from 'jotai'
|
||||||
|
|
||||||
@ -13,8 +13,8 @@ export default function useDeleteModel() {
|
|||||||
const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom)
|
const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom)
|
||||||
|
|
||||||
const deleteModel = useCallback(
|
const deleteModel = useCallback(
|
||||||
async (model: Model) => {
|
async (model: ModelFile) => {
|
||||||
await localDeleteModel(model.id)
|
await localDeleteModel(model)
|
||||||
removeDownloadedModel(model.id)
|
removeDownloadedModel(model.id)
|
||||||
toaster({
|
toaster({
|
||||||
title: 'Model Deletion Successful',
|
title: 'Model Deletion Successful',
|
||||||
@ -28,5 +28,7 @@ export default function useDeleteModel() {
|
|||||||
return { deleteModel }
|
return { deleteModel }
|
||||||
}
|
}
|
||||||
|
|
||||||
const localDeleteModel = async (id: string) =>
|
const localDeleteModel = async (model: ModelFile) =>
|
||||||
extensionManager.get<ModelExtension>(ExtensionTypeEnum.Model)?.deleteModel(id)
|
extensionManager
|
||||||
|
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||||
|
?.deleteModel(model)
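An illustrative call site for the updated hook (helper name is made up): deleteModel now needs the full ModelFile so the extension can resolve the model.json location from file_path.

import { ModelFile } from '@janhq/core'

import useDeleteModel from '@/hooks/useDeleteModel'

// Hypothetical helper: returns a callback that deletes the given model.
const useDeleteSelectedModel = (model: ModelFile) => {
  const { deleteModel } = useDeleteModel()
  return () => deleteModel(model)
}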
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import {
|
|||||||
Model,
|
Model,
|
||||||
ModelEvent,
|
ModelEvent,
|
||||||
ModelExtension,
|
ModelExtension,
|
||||||
|
ModelFile,
|
||||||
events,
|
events,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
|
|
||||||
@ -63,12 +64,12 @@ const getLocalDefaultModel = async (): Promise<Model | undefined> =>
|
|||||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||||
?.getDefaultModel()
|
?.getDefaultModel()
|
||||||
|
|
||||||
const getLocalConfiguredModels = async (): Promise<Model[]> =>
|
const getLocalConfiguredModels = async (): Promise<ModelFile[]> =>
|
||||||
extensionManager
|
extensionManager
|
||||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||||
?.getConfiguredModels() ?? []
|
?.getConfiguredModels() ?? []
|
||||||
|
|
||||||
const getLocalDownloadedModels = async (): Promise<Model[]> =>
|
const getLocalDownloadedModels = async (): Promise<ModelFile[]> =>
|
||||||
extensionManager
|
extensionManager
|
||||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||||
?.getDownloadedModels() ?? []
|
?.getDownloadedModels() ?? []
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import { useCallback, useEffect, useState } from 'react'
|
import { useCallback, useEffect, useState } from 'react'
|
||||||
|
|
||||||
import { Model, InferenceEngine } from '@janhq/core'
|
import { Model, InferenceEngine, ModelFile } from '@janhq/core'
|
||||||
|
|
||||||
import { atom, useAtomValue } from 'jotai'
|
import { atom, useAtomValue } from 'jotai'
|
||||||
|
|
||||||
@ -24,12 +24,16 @@ export const LAST_USED_MODEL_ID = 'last-used-model-id'
|
|||||||
*/
|
*/
|
||||||
export default function useRecommendedModel() {
|
export default function useRecommendedModel() {
|
||||||
const activeModel = useAtomValue(activeModelAtom)
|
const activeModel = useAtomValue(activeModelAtom)
|
||||||
const [sortedModels, setSortedModels] = useState<Model[]>([])
|
const [sortedModels, setSortedModels] = useState<ModelFile[]>([])
|
||||||
const [recommendedModel, setRecommendedModel] = useState<Model | undefined>()
|
const [recommendedModel, setRecommendedModel] = useState<
|
||||||
|
ModelFile | undefined
|
||||||
|
>()
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
const downloadedModels = useAtomValue(downloadedModelsAtom)
|
const downloadedModels = useAtomValue(downloadedModelsAtom)
|
||||||
|
|
||||||
const getAndSortDownloadedModels = useCallback(async (): Promise<Model[]> => {
|
const getAndSortDownloadedModels = useCallback(async (): Promise<
|
||||||
|
ModelFile[]
|
||||||
|
> => {
|
||||||
const models = downloadedModels.sort((a, b) =>
|
const models = downloadedModels.sort((a, b) =>
|
||||||
a.engine !== InferenceEngine.nitro && b.engine === InferenceEngine.nitro
|
a.engine !== InferenceEngine.nitro && b.engine === InferenceEngine.nitro
|
||||||
? 1
|
? 1
|
||||||
|
|||||||
@ -23,7 +23,10 @@ import {
|
|||||||
import { Stack } from '@/utils/Stack'
|
import { Stack } from '@/utils/Stack'
|
||||||
import { compressImage, getBase64 } from '@/utils/base64'
|
import { compressImage, getBase64 } from '@/utils/base64'
|
||||||
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
|
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
|
||||||
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
|
import {
|
||||||
|
extractInferenceParams,
|
||||||
|
extractModelLoadParams,
|
||||||
|
} from '@/utils/modelParam'
|
||||||
|
|
||||||
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
|
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
|
||||||
|
|
||||||
@ -189,8 +192,8 @@ export default function useSendChatMessage() {
|
|||||||
|
|
||||||
if (engineParamsUpdate) setReloadModel(true)
|
if (engineParamsUpdate) setReloadModel(true)
|
||||||
|
|
||||||
const runtimeParams = toRuntimeParams(activeModelParams)
|
const runtimeParams = extractInferenceParams(activeModelParams)
|
||||||
const settingParams = toSettingParams(activeModelParams)
|
const settingParams = extractModelLoadParams(activeModelParams)
|
||||||
|
|
||||||
const prompt = message.trim()
|
const prompt = message.trim()
|
||||||
|
|
||||||
|
|||||||
@ -63,14 +63,9 @@ export function useStarterScreen() {
|
|||||||
(x) => x.apiKey.length > 1
|
(x) => x.apiKey.length > 1
|
||||||
)
|
)
|
||||||
|
|
||||||
let isShowStarterScreen
|
const isShowStarterScreen =
|
||||||
|
|
||||||
isShowStarterScreen =
|
|
||||||
!isAnyRemoteModelConfigured && !isDownloadALocalModel && !threads.length
|
!isAnyRemoteModelConfigured && !isDownloadALocalModel && !threads.length
|
||||||
|
|
||||||
// Remove this part when we rework on starter screen
|
|
||||||
isShowStarterScreen = false
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
extensionHasSettings,
|
extensionHasSettings,
|
||||||
isShowStarterScreen,
|
isShowStarterScreen,
|
||||||
|
|||||||
314
web/hooks/useUpdateModelParameters.test.ts
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
import { renderHook, act } from '@testing-library/react'
|
||||||
|
// Mock dependencies
|
||||||
|
jest.mock('ulidx')
|
||||||
|
jest.mock('@/extension')
|
||||||
|
|
||||||
|
import useUpdateModelParameters from './useUpdateModelParameters'
|
||||||
|
import { extensionManager } from '@/extension'
|
||||||
|
|
||||||
|
// Mock data
|
||||||
|
let model: any = {
|
||||||
|
id: 'model-1',
|
||||||
|
engine: 'nitro',
|
||||||
|
}
|
||||||
|
|
||||||
|
let extension: any = {
|
||||||
|
saveThread: jest.fn(),
|
||||||
|
}
|
||||||
|
|
||||||
|
const mockThread: any = {
|
||||||
|
id: 'thread-1',
|
||||||
|
assistants: [
|
||||||
|
{
|
||||||
|
model: {
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
object: 'thread',
|
||||||
|
title: 'New Thread',
|
||||||
|
created: 0,
|
||||||
|
updated: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('useUpdateModelParameters', () => {
|
||||||
|
beforeAll(() => {
|
||||||
|
jest.clearAllMocks()
|
||||||
|
jest.mock('./useRecommendedModel', () => ({
|
||||||
|
useRecommendedModel: () => ({
|
||||||
|
recommendedModel: model,
|
||||||
|
setRecommendedModel: jest.fn(),
|
||||||
|
downloadedModels: [],
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should update model parameters and save thread when params are valid', async () => {
|
||||||
|
const mockValidParameters: any = {
|
||||||
|
params: {
|
||||||
|
// Inference
|
||||||
|
stop: ['<eos>', '<eos2>'],
|
||||||
|
temperature: 0.5,
|
||||||
|
token_limit: 1000,
|
||||||
|
top_k: 0.7,
|
||||||
|
top_p: 0.1,
|
||||||
|
stream: true,
|
||||||
|
max_tokens: 1000,
|
||||||
|
frequency_penalty: 0.3,
|
||||||
|
presence_penalty: 0.2,
|
||||||
|
|
||||||
|
// Load model
|
||||||
|
ctx_len: 1024,
|
||||||
|
ngl: 12,
|
||||||
|
embedding: true,
|
||||||
|
n_parallel: 2,
|
||||||
|
cpu_threads: 4,
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
vision_model: 'vision',
|
||||||
|
text_model: 'text',
|
||||||
|
},
|
||||||
|
modelId: 'model-1',
|
||||||
|
engine: 'nitro',
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spy functions
|
||||||
|
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||||
|
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useUpdateModelParameters())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.updateModelParameter(mockThread, mockValidParameters)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check if the model parameters are valid before persisting
|
||||||
|
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||||
|
assistants: [
|
||||||
|
{
|
||||||
|
model: {
|
||||||
|
parameters: {
|
||||||
|
stop: ['<eos>', '<eos2>'],
|
||||||
|
temperature: 0.5,
|
||||||
|
token_limit: 1000,
|
||||||
|
top_k: 0.7,
|
||||||
|
top_p: 0.1,
|
||||||
|
stream: true,
|
||||||
|
max_tokens: 1000,
|
||||||
|
frequency_penalty: 0.3,
|
||||||
|
presence_penalty: 0.2,
|
||||||
|
},
|
||||||
|
settings: {
|
||||||
|
ctx_len: 1024,
|
||||||
|
ngl: 12,
|
||||||
|
embedding: true,
|
||||||
|
n_parallel: 2,
|
||||||
|
cpu_threads: 4,
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
created: 0,
|
||||||
|
id: 'thread-1',
|
||||||
|
object: 'thread',
|
||||||
|
title: 'New Thread',
|
||||||
|
updated: 0,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not update invalid model parameters', async () => {
|
||||||
|
const mockInvalidParameters: any = {
|
||||||
|
params: {
|
||||||
|
// Inference
|
||||||
|
stop: [1, '<eos>'],
|
||||||
|
temperature: '0.5',
|
||||||
|
token_limit: '1000',
|
||||||
|
top_k: '0.7',
|
||||||
|
top_p: '0.1',
|
||||||
|
stream: 'true',
|
||||||
|
max_tokens: '1000',
|
||||||
|
frequency_penalty: '0.3',
|
||||||
|
presence_penalty: '0.2',
|
||||||
|
|
||||||
|
// Load model
|
||||||
|
ctx_len: '1024',
|
||||||
|
ngl: '12',
|
||||||
|
embedding: 'true',
|
||||||
|
n_parallel: '2',
|
||||||
|
cpu_threads: '4',
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
vision_model: 'vision',
|
||||||
|
text_model: 'text',
|
||||||
|
},
|
||||||
|
modelId: 'model-1',
|
||||||
|
engine: 'nitro',
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spy functions
|
||||||
|
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||||
|
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useUpdateModelParameters())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.updateModelParameter(
|
||||||
|
mockThread,
|
||||||
|
mockInvalidParameters
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check if the model parameters are valid before persisting
|
||||||
|
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||||
|
assistants: [
|
||||||
|
{
|
||||||
|
model: {
|
||||||
|
parameters: {
|
||||||
|
max_tokens: 1000,
|
||||||
|
token_limit: 1000,
|
||||||
|
},
|
||||||
|
settings: {
|
||||||
|
cpu_threads: 4,
|
||||||
|
ctx_len: 1024,
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
n_parallel: 2,
|
||||||
|
ngl: 12,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
created: 0,
|
||||||
|
id: 'thread-1',
|
||||||
|
object: 'thread',
|
||||||
|
title: 'New Thread',
|
||||||
|
updated: 0,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should update valid model parameters only', async () => {
|
||||||
|
const mockInvalidParameters: any = {
|
||||||
|
params: {
|
||||||
|
// Inference
|
||||||
|
stop: ['<eos>'],
|
||||||
|
temperature: -0.5,
|
||||||
|
token_limit: 100.2,
|
||||||
|
top_k: 0.7,
|
||||||
|
top_p: 0.1,
|
||||||
|
stream: true,
|
||||||
|
max_tokens: 1000,
|
||||||
|
frequency_penalty: 1.2,
|
||||||
|
presence_penalty: 0.2,
|
||||||
|
|
||||||
|
// Load model
|
||||||
|
ctx_len: 1024,
|
||||||
|
ngl: 0,
|
||||||
|
embedding: 'true',
|
||||||
|
n_parallel: 2,
|
||||||
|
cpu_threads: 4,
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
vision_model: 'vision',
|
||||||
|
text_model: 'text',
|
||||||
|
},
|
||||||
|
modelId: 'model-1',
|
||||||
|
engine: 'nitro',
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spy functions
|
||||||
|
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||||
|
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useUpdateModelParameters())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.updateModelParameter(
|
||||||
|
mockThread,
|
||||||
|
mockInvalidParameters
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check if the model parameters are valid before persisting
|
||||||
|
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||||
|
assistants: [
|
||||||
|
{
|
||||||
|
model: {
|
||||||
|
parameters: {
|
||||||
|
stop: ['<eos>'],
|
||||||
|
top_k: 0.7,
|
||||||
|
top_p: 0.1,
|
||||||
|
stream: true,
|
||||||
|
token_limit: 100,
|
||||||
|
max_tokens: 1000,
|
||||||
|
presence_penalty: 0.2,
|
||||||
|
},
|
||||||
|
settings: {
|
||||||
|
ctx_len: 1024,
|
||||||
|
ngl: 0,
|
||||||
|
n_parallel: 2,
|
||||||
|
cpu_threads: 4,
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
created: 0,
|
||||||
|
id: 'thread-1',
|
||||||
|
object: 'thread',
|
||||||
|
title: 'New Thread',
|
||||||
|
updated: 0,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle missing modelId and engine gracefully', async () => {
|
||||||
|
const mockParametersWithoutModelIdAndEngine: any = {
|
||||||
|
params: {
|
||||||
|
stop: ['<eos>', '<eos2>'],
|
||||||
|
temperature: 0.5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spy functions
|
||||||
|
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||||
|
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useUpdateModelParameters())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.updateModelParameter(
|
||||||
|
mockThread,
|
||||||
|
mockParametersWithoutModelIdAndEngine
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check if the model parameters are valid before persisting
|
||||||
|
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||||
|
assistants: [
|
||||||
|
{
|
||||||
|
model: {
|
||||||
|
parameters: {
|
||||||
|
stop: ['<eos>', '<eos2>'],
|
||||||
|
temperature: 0.5,
|
||||||
|
},
|
||||||
|
settings: {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
created: 0,
|
||||||
|
id: 'thread-1',
|
||||||
|
object: 'thread',
|
||||||
|
title: 'New Thread',
|
||||||
|
updated: 0,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
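The expectations above boil down to routing each key to either parameters (inference-time) or settings (load-time) and dropping values of the wrong type. A simplified sketch of that split (key lists are taken from the test data; the real extractInferenceParams and extractModelLoadParams also validate types and ranges):

const inferenceKeys = new Set([
  'stop', 'temperature', 'token_limit', 'top_k', 'top_p',
  'stream', 'max_tokens', 'frequency_penalty', 'presence_penalty',
])
const loadKeys = new Set([
  'ctx_len', 'ngl', 'embedding', 'n_parallel', 'cpu_threads',
  'prompt_template', 'llama_model_path', 'mmproj',
])

// Split a flat params object into the two buckets persisted on the thread.
function splitParams(params: Record<string, unknown>) {
  const parameters: Record<string, unknown> = {}
  const settings: Record<string, unknown> = {}
  for (const [key, value] of Object.entries(params)) {
    if (inferenceKeys.has(key)) parameters[key] = value
    else if (loadKeys.has(key)) settings[key] = value
  }
  return { parameters, settings }
}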
|
||||||
@ -4,24 +4,19 @@ import {
|
|||||||
ConversationalExtension,
|
ConversationalExtension,
|
||||||
ExtensionTypeEnum,
|
ExtensionTypeEnum,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
Model,
|
|
||||||
ModelExtension,
|
|
||||||
Thread,
|
Thread,
|
||||||
ThreadAssistantInfo,
|
ThreadAssistantInfo,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
|
|
||||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
 import { useAtom, useAtomValue, useSetAtom } from 'jotai'

-import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
-import useRecommendedModel from './useRecommendedModel'
+import {
+  extractInferenceParams,
+  extractModelLoadParams,
+} from '@/utils/modelParam'

 import { extensionManager } from '@/extension'
-import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
-import {
-  selectedModelAtom,
-  updateDownloadedModelAtom,
-} from '@/helpers/atoms/Model.atom'
+import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
 import {
   ModelParams,
   getActiveThreadModelParamsAtom,

@@ -31,28 +26,31 @@ import {
 export type UpdateModelParameter = {
   params?: ModelParams
   modelId?: string
+  modelPath?: string
   engine?: InferenceEngine
 }

 export default function useUpdateModelParameters() {
   const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
-  const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
+  const [selectedModel] = useAtom(selectedModelAtom)
   const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
-  const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom)
-  const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom)
-  const { recommendedModel, setRecommendedModel } = useRecommendedModel()

   const updateModelParameter = useCallback(
     async (thread: Thread, settings: UpdateModelParameter) => {
       const toUpdateSettings = processStopWords(settings.params ?? {})
       const updatedModelParams = settings.modelId
         ? toUpdateSettings
-        : { ...activeModelParams, ...toUpdateSettings }
+        : {
+            ...selectedModel?.parameters,
+            ...selectedModel?.settings,
+            ...activeModelParams,
+            ...toUpdateSettings,
+          }

       // update the state
       setThreadModelParams(thread.id, updatedModelParams)
-      const runtimeParams = toRuntimeParams(updatedModelParams)
-      const settingParams = toSettingParams(updatedModelParams)
+      const runtimeParams = extractInferenceParams(updatedModelParams)
+      const settingParams = extractModelLoadParams(updatedModelParams)

       const assistants = thread.assistants.map(
         (assistant: ThreadAssistantInfo) => {

@@ -75,50 +73,8 @@ export default function useUpdateModelParameters() {
       await extensionManager
         .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
         ?.saveThread(updatedThread)

-      // Persists default settings to model file
-      // Do not overwrite ctx_len and max_tokens
-      if (preserveModelFeatureEnabled) {
-        const defaultContextLength = settingParams.ctx_len
-        const defaultMaxTokens = runtimeParams.max_tokens
-
-        // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars
-        const { ctx_len, ...toSaveSettings } = settingParams
-        // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars
-        const { max_tokens, ...toSaveParams } = runtimeParams
-
-        const updatedModel = {
-          id: settings.modelId ?? selectedModel?.id,
-          parameters: {
-            ...toSaveSettings,
-          },
-          settings: {
-            ...toSaveParams,
-          },
-          metadata: {
-            default_ctx_len: defaultContextLength,
-            default_max_tokens: defaultMaxTokens,
-          },
-        } as Partial<Model>
-
-        const model = await extensionManager
-          .get<ModelExtension>(ExtensionTypeEnum.Model)
-          ?.updateModelInfo(updatedModel)
-        if (model) updateDownloadedModel(model)
-        if (selectedModel?.id === model?.id) setSelectedModel(model)
-        if (recommendedModel?.id === model?.id) setRecommendedModel(model)
-      }
     },
-    [
-      activeModelParams,
-      selectedModel,
-      setThreadModelParams,
-      preserveModelFeatureEnabled,
-      updateDownloadedModel,
-      setSelectedModel,
-      recommendedModel,
-      setRecommendedModel,
-    ]
+    [activeModelParams, selectedModel, setThreadModelParams]
   )

 const processStopWords = (params: ModelParams): ModelParams => {
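For readers skimming the hook change above: the new object spread replaces the old two-way merge with a four-layer merge. A minimal standalone sketch of the precedence (the values are hypothetical, not from the PR):

// Later spreads win, so thread-level edits override model defaults.
const modelParameters = { temperature: 0.7 }   // stands in for selectedModel?.parameters
const modelSettings = { ctx_len: 4096 }        // stands in for selectedModel?.settings
const threadParams = { temperature: 0.2 }      // stands in for activeModelParams
const edited = { max_tokens: 512 }             // stands in for toUpdateSettings
const updatedModelParams = {
  ...modelParameters,
  ...modelSettings,
  ...threadParams,
  ...edited,
}
// => { temperature: 0.2, ctx_len: 4096, max_tokens: 512 }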
@@ -1,6 +1,6 @@
 import { useCallback } from 'react'

-import { Model } from '@janhq/core'
+import { ModelFile } from '@janhq/core'
 import { Button, Badge, Tooltip } from '@janhq/joi'

 import { useAtomValue, useSetAtom } from 'jotai'

@@ -38,7 +38,7 @@ import {
 } from '@/helpers/atoms/SystemBar.atom'

 type Props = {
-  model: Model
+  model: ModelFile
   onClick: () => void
   open: string
 }
@@ -1,6 +1,6 @@
 import { useState } from 'react'

-import { Model } from '@janhq/core'
+import { ModelFile } from '@janhq/core'
 import { Badge } from '@janhq/joi'

 import { twMerge } from 'tailwind-merge'

@@ -12,7 +12,7 @@ import ModelItemHeader from '@/screens/Hub/ModelList/ModelHeader'
 import { toGibibytes } from '@/utils/converter'

 type Props = {
-  model: Model
+  model: ModelFile
 }

 const ModelItem: React.FC<Props> = ({ model }) => {
@@ -1,6 +1,6 @@
 import { useMemo } from 'react'

-import { Model } from '@janhq/core'
+import { ModelFile } from '@janhq/core'

 import { useAtomValue } from 'jotai'

@@ -9,16 +9,16 @@ import ModelItem from '@/screens/Hub/ModelList/ModelItem'
 import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'

 type Props = {
-  models: Model[]
+  models: ModelFile[]
 }

 const ModelList = ({ models }: Props) => {
   const downloadedModels = useAtomValue(downloadedModelsAtom)
-  const sortedModels: Model[] = useMemo(() => {
-    const featuredModels: Model[] = []
-    const remoteModels: Model[] = []
-    const localModels: Model[] = []
-    const remainingModels: Model[] = []
+  const sortedModels: ModelFile[] = useMemo(() => {
+    const featuredModels: ModelFile[] = []
+    const remoteModels: ModelFile[] = []
+    const localModels: ModelFile[] = []
+    const remainingModels: ModelFile[] = []
     models.forEach((m) => {
       if (m.metadata?.tags?.includes('Featured')) {
         featuredModels.push(m)
@@ -14,7 +14,10 @@ import { loadModelErrorAtom } from '@/hooks/useActiveModel'

 import { getConfigurationsData } from '@/utils/componentSettings'

-import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
+import {
+  extractInferenceParams,
+  extractModelLoadParams,
+} from '@/utils/modelParam'

 import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
 import { selectedModelAtom } from '@/helpers/atoms/Model.atom'

@@ -27,16 +30,18 @@ const LocalServerRightPanel = () => {
   const selectedModel = useAtomValue(selectedModelAtom)

   const [currentModelSettingParams, setCurrentModelSettingParams] = useState(
-    toSettingParams(selectedModel?.settings)
+    extractModelLoadParams(selectedModel?.settings)
   )

   useEffect(() => {
     if (selectedModel) {
-      setCurrentModelSettingParams(toSettingParams(selectedModel?.settings))
+      setCurrentModelSettingParams(
+        extractModelLoadParams(selectedModel?.settings)
+      )
     }
   }, [selectedModel])

-  const modelRuntimeParams = toRuntimeParams(selectedModel?.settings)
+  const modelRuntimeParams = extractInferenceParams(selectedModel?.settings)

   const componentDataRuntimeSetting = getConfigurationsData(
     modelRuntimeParams,
@@ -100,6 +100,7 @@ const DataFolder = () => {
         <div className="flex items-center gap-x-3">
           <div className="relative">
             <Input
+              data-testid="jan-data-folder-input"
               value={janDataFolderPath}
               className="w-full pr-8 sm:w-[240px]"
               disabled
@@ -21,7 +21,11 @@ const FactoryReset = () => {
           recommended only if the application is in a corrupted state.
         </p>
       </div>
-      <Button theme="destructive" onClick={() => setModalValidation(true)}>
+      <Button
+        data-testid="reset-button"
+        theme="destructive"
+        onClick={() => setModalValidation(true)}
+      >
         Reset
       </Button>
       <ModalValidation />
web/screens/Settings/Advanced/index.test.tsx (new file, 154 lines)
@@ -0,0 +1,154 @@
import React from 'react'
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import '@testing-library/jest-dom'
import Advanced from '.'

class ResizeObserverMock {
  observe() {}
  unobserve() {}
  disconnect() {}
}

global.ResizeObserver = ResizeObserverMock
// @ts-ignore
global.window.core = {
  api: {
    getAppConfigurations: () => jest.fn(),
    updateAppConfiguration: () => jest.fn(),
    relaunch: () => jest.fn(),
  },
}

const setSettingsMock = jest.fn()

// Mock useSettings hook
jest.mock('@/hooks/useSettings', () => ({
  __esModule: true,
  useSettings: () => ({
    readSettings: () => ({
      run_mode: 'gpu',
      experimental: false,
      proxy: false,
      gpus: [{ name: 'gpu-1' }, { name: 'gpu-2' }],
      gpus_in_use: ['0'],
      quick_ask: false,
    }),
    setSettings: setSettingsMock,
  }),
}))

import * as toast from '@/containers/Toast'

jest.mock('@/containers/Toast')

jest.mock('@janhq/core', () => ({
  __esModule: true,
  ...jest.requireActual('@janhq/core'),
  fs: {
    rm: jest.fn(),
  },
}))

// Simulate a full advanced settings screen
// @ts-ignore
global.isMac = false
// @ts-ignore
global.isWindows = true

describe('Advanced', () => {
  it('renders the component', async () => {
    render(<Advanced />)
    await waitFor(() => {
      expect(screen.getByText('Experimental Mode')).toBeInTheDocument()
      expect(screen.getByText('HTTPS Proxy')).toBeInTheDocument()
      expect(screen.getByText('Ignore SSL certificates')).toBeInTheDocument()
      expect(screen.getByText('Jan Data Folder')).toBeInTheDocument()
      expect(screen.getByText('Reset to Factory Settings')).toBeInTheDocument()
    })
  })

  it('updates Experimental enabled', async () => {
    render(<Advanced />)
    let experimentalToggle
    await waitFor(() => {
      experimentalToggle = screen.getByTestId(/experimental-switch/i)
      fireEvent.click(experimentalToggle!)
    })
    expect(experimentalToggle).toBeChecked()
  })

  it('updates Experimental disabled', async () => {
    render(<Advanced />)

    let experimentalToggle
    await waitFor(() => {
      experimentalToggle = screen.getByTestId(/experimental-switch/i)
      fireEvent.click(experimentalToggle!)
    })
    expect(experimentalToggle).not.toBeChecked()
  })

  it('clears logs', async () => {
    const jestMock = jest.fn()
    jest.spyOn(toast, 'toaster').mockImplementation(jestMock)

    render(<Advanced />)
    let clearLogsButton
    await waitFor(() => {
      clearLogsButton = screen.getByTestId(/clear-logs/i)
      fireEvent.click(clearLogsButton)
    })
    expect(clearLogsButton).toBeInTheDocument()
    expect(jestMock).toHaveBeenCalled()
  })

  it('toggles proxy enabled', async () => {
    render(<Advanced />)
    let proxyToggle
    await waitFor(() => {
      expect(screen.getByText('HTTPS Proxy')).toBeInTheDocument()
      proxyToggle = screen.getByTestId(/proxy-switch/i)
      fireEvent.click(proxyToggle)
    })
    expect(proxyToggle).toBeChecked()
  })

  it('updates proxy settings', async () => {
    render(<Advanced />)
    let proxyInput
    await waitFor(() => {
      const proxyToggle = screen.getByTestId(/proxy-switch/i)
      fireEvent.click(proxyToggle)
      proxyInput = screen.getByTestId(/proxy-input/i)
      fireEvent.change(proxyInput, { target: { value: 'http://proxy.com' } })
    })
    expect(proxyInput).toHaveValue('http://proxy.com')
  })

  it('toggles ignore SSL certificates', async () => {
    render(<Advanced />)
    let ignoreSslToggle
    await waitFor(() => {
      expect(screen.getByText('Ignore SSL certificates')).toBeInTheDocument()
      ignoreSslToggle = screen.getByTestId(/ignore-ssl-switch/i)
      fireEvent.click(ignoreSslToggle)
    })
    expect(ignoreSslToggle).toBeChecked()
  })

  it('renders DataFolder component', async () => {
    render(<Advanced />)
    await waitFor(() => {
      expect(screen.getByText('Jan Data Folder')).toBeInTheDocument()
      expect(screen.getByTestId(/jan-data-folder-input/i)).toBeInTheDocument()
    })
  })

  it('renders FactoryReset component', async () => {
    render(<Advanced />)
    await waitFor(() => {
      expect(screen.getByText('Reset to Factory Settings')).toBeInTheDocument()
      expect(screen.getByTestId(/reset-button/i)).toBeInTheDocument()
    })
  })
})
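The data-testid attributes added elsewhere in this PR are what make the getByTestId queries above resolve. A minimal sketch of the pattern, assuming the same Testing Library setup as the file above:

import { render, screen, fireEvent } from '@testing-library/react'
import Advanced from '.'

// getByTestId accepts a regex; it only finds the switch because the
// component now renders data-testid="experimental-switch".
render(<Advanced />)
fireEvent.click(screen.getByTestId(/experimental-switch/i))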
@@ -43,19 +43,10 @@ type GPU = {
   name: string
 }

-const test = [
-  {
-    id: 'test a',
-    vram: 2,
-    name: 'nvidia A',
-  },
-  {
-    id: 'test',
-    vram: 2,
-    name: 'nvidia B',
-  },
-]
+/**
+ * Advanced Settings Screen
+ * @returns
+ */

 const Advanced = () => {
   const [experimentalEnabled, setExperimentalEnabled] = useAtom(
     experimentalFeatureEnabledAtom

@@ -69,7 +60,7 @@ const Advanced = () => {

   const [partialProxy, setPartialProxy] = useState<string>(proxy)
   const [gpuEnabled, setGpuEnabled] = useState<boolean>(false)
-  const [gpuList, setGpuList] = useState<GPU[]>(test)
+  const [gpuList, setGpuList] = useState<GPU[]>([])
   const [gpusInUse, setGpusInUse] = useState<string[]>([])
   const [dropdownOptions, setDropdownOptions] = useState<HTMLDivElement | null>(
     null

@@ -87,6 +78,9 @@ const Advanced = () => {
       return y['name']
     })

+  /**
+   * Handle proxy change
+   */
   const onProxyChange = useCallback(
     (event: ChangeEvent<HTMLInputElement>) => {
       const value = event.target.value || ''

@@ -100,6 +94,12 @@ const Advanced = () => {
     [setPartialProxy, setProxy]
   )

+  /**
+   * Update Quick Ask Enabled
+   * @param e
+   * @param relaunch
+   * @returns void
+   */
   const updateQuickAskEnabled = async (
     e: boolean,
     relaunch: boolean = true

@@ -111,6 +111,12 @@ const Advanced = () => {
     if (relaunch) window.core?.api?.relaunch()
   }

+  /**
+   * Update Vulkan Enabled
+   * @param e
+   * @param relaunch
+   * @returns void
+   */
   const updateVulkanEnabled = async (e: boolean, relaunch: boolean = true) => {
     toaster({
       title: 'Reload',

@@ -123,11 +129,19 @@ const Advanced = () => {
     if (relaunch) window.location.reload()
   }

+  /**
+   * Update Experimental Enabled
+   * @param e
+   * @returns
+   */
   const updateExperimentalEnabled = async (
     e: ChangeEvent<HTMLInputElement>
   ) => {
     setExperimentalEnabled(e.target.checked)
-    if (e) return
+    // If it checked, we don't need to do anything else
+    // Otherwise have to reset other settings
+    if (e.target.checked) return

     // It affects other settings, so we need to reset them
     const isRelaunch = quickAskEnabled || vulkanEnabled

@@ -136,6 +150,9 @@ const Advanced = () => {
     if (isRelaunch) window.core?.api?.relaunch()
   }

+  /**
+   * useEffect to set GPU enabled if possible
+   */
   useEffect(() => {
     const setUseGpuIfPossible = async () => {
       const settings = await readSettings()

@@ -149,6 +166,10 @@ const Advanced = () => {
     setUseGpuIfPossible()
   }, [readSettings, setGpuList, setGpuEnabled, setGpusInUse, setVulkanEnabled])

+  /**
+   * Clear logs
+   * @returns
+   */
   const clearLogs = async () => {
     try {
       await fs.rm(`file://logs`)

@@ -163,6 +184,11 @@ const Advanced = () => {
       })
   }

+  /**
+   * Handle GPU Change
+   * @param gpuId
+   * @returns
+   */
   const handleGPUChange = (gpuId: string) => {
     let updatedGpusInUse = [...gpusInUse]
     if (updatedGpusInUse.includes(gpuId)) {

@@ -188,6 +214,9 @@ const Advanced = () => {
   const gpuSelectionPlaceHolder =
     gpuList.length > 0 ? 'Select GPU' : "You don't have any compatible GPU"

+  /**
+   * Handle click outside
+   */
   useClickOutside(() => setOpen(false), null, [dropdownOptions, toggle])

   return (

@@ -204,6 +233,7 @@ const Advanced = () => {
             </p>
           </div>
           <Switch
+            data-testid="experimental-switch"
             checked={experimentalEnabled}
             onChange={updateExperimentalEnabled}
           />

@@ -401,11 +431,13 @@ const Advanced = () => {

           <div className="flex w-full flex-shrink-0 flex-col items-end gap-2 pr-1 sm:w-1/2">
             <Switch
+              data-testid="proxy-switch"
               checked={proxyEnabled}
               onChange={() => setProxyEnabled(!proxyEnabled)}
             />
             <div className="w-full">
               <Input
+                data-testid="proxy-input"
                 placeholder={'http://<user>:<password>@<domain or IP>:<port>'}
                 value={partialProxy}
                 onChange={onProxyChange}

@@ -428,6 +460,7 @@ const Advanced = () => {
             </p>
           </div>
           <Switch
+            data-testid="ignore-ssl-switch"
             checked={ignoreSSL}
             onChange={(e) => setIgnoreSSL(e.target.checked)}
           />

@@ -448,6 +481,7 @@ const Advanced = () => {
             </p>
           </div>
           <Switch
+            data-testid="quick-ask-switch"
             checked={quickAskEnabled}
             onChange={() => {
               toaster({

@@ -471,7 +505,11 @@ const Advanced = () => {
               Clear all logs from Jan app.
             </p>
           </div>
-          <Button theme="destructive" onClick={clearLogs}>
+          <Button
+            data-testid="clear-logs"
+            theme="destructive"
+            onClick={clearLogs}
+          >
             Clear
           </Button>
         </div>
@@ -53,7 +53,7 @@ const ModelDownloadRow: React.FC<Props> = ({
   const { requestCreateNewThread } = useCreateNewThread()
   const setMainViewState = useSetAtom(mainViewStateAtom)
   const assistants = useAtomValue(assistantsAtom)
-  const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null
+  const downloadedModel = downloadedModels.find((md) => md.id === fileName)

   const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom)
   const defaultModel = useAtomValue(defaultModelAtom)

@@ -100,12 +100,12 @@ const ModelDownloadRow: React.FC<Props> = ({
       alert('No assistant available')
       return
     }
-    await requestCreateNewThread(assistants[0], model)
+    await requestCreateNewThread(assistants[0], downloadedModel)
     setMainViewState(MainViewState.Thread)
     setHfImportingStage('NONE')
   }, [
     assistants,
-    model,
+    downloadedModel,
     requestCreateNewThread,
     setMainViewState,
     setHfImportingStage,

@@ -139,7 +139,7 @@ const ModelDownloadRow: React.FC<Props> = ({
       </Badge>
     </div>

-    {isDownloaded ? (
+    {downloadedModel ? (
       <Button
         variant="soft"
         className="min-w-[98px]"
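The ModelDownloadRow change keeps the matched model object instead of collapsing it to a boolean, so the same lookup can drive both the rendered branch and thread creation. Condensed sketch using the identifiers from the diff:

const downloadedModel = downloadedModels.find((md) => md.id === fileName)
if (downloadedModel) {
  // previously this call passed the remote `model` prop even after download
  await requestCreateNewThread(assistants[0], downloadedModel)
}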
@@ -1,6 +1,6 @@
 import { memo, useState } from 'react'

-import { InferenceEngine, Model } from '@janhq/core'
+import { InferenceEngine, ModelFile } from '@janhq/core'
 import { Badge, Button, Tooltip, useClickOutside } from '@janhq/joi'
 import { useAtom } from 'jotai'
 import {

@@ -21,7 +21,7 @@ import { localEngines } from '@/utils/modelEngine'
 import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'

 type Props = {
-  model: Model
+  model: ModelFile
   groupTitle?: string
 }
@@ -58,9 +58,21 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
   const configuredModels = useAtomValue(configuredModelsAtom)
   const setMainViewState = useSetAtom(mainViewStateAtom)

-  const featuredModel = configuredModels.filter(
-    (x) => x.metadata.tags.includes('Featured') && x.metadata.size < 5000000000
-  )
+  const recommendModel = ['gemma-2-2b-it', 'llama3.1-8b-instruct']
+
+  const featuredModel = configuredModels.filter((x) => {
+    const manualRecommendModel = configuredModels.filter((x) =>
+      recommendModel.includes(x.id)
+    )
+
+    if (manualRecommendModel.length === 2) {
+      return x.id === recommendModel[0] || x.id === recommendModel[1]
+    } else {
+      return (
+        x.metadata.tags.includes('Featured') && x.metadata.size < 5000000000
+      )
+    }
+  })

   const remoteModel = configuredModels.filter(
     (x) => !localEngines.includes(x.engine)

@@ -105,7 +117,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
             width={48}
             height={48}
           />
-          <h1 className="text-base font-semibold">Select a model to start</h1>
+          <h1 className="text-base font-medium">Select a model to start</h1>
           <div className="mt-6 w-[320px] md:w-[400px]">
             <Fragment>
               <div className="relative" ref={refDropdown}>

@@ -120,7 +132,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
                 />
                 <div
                   className={twMerge(
-                    'absolute left-0 top-10 max-h-[240px] w-full overflow-x-auto rounded-lg border border-[hsla(var(--app-border))] bg-[hsla(var(--app-bg))]',
+                    'absolute left-0 top-10 z-20 max-h-[240px] w-full overflow-x-auto rounded-lg border border-[hsla(var(--app-border))] bg-[hsla(var(--app-bg))]',
                     !isOpen ? 'invisible' : 'visible'
                   )}
                 >

@@ -205,39 +217,41 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
               return (
                 <div
                   key={featModel.id}
-                  className="my-2 flex items-center justify-between gap-2 border-b border-[hsla(var(--app-border))] py-4 last:border-none"
+                  className="my-2 flex items-center justify-between gap-2 border-b border-[hsla(var(--app-border))] pb-4 pt-1 last:border-none"
                 >
                   <div className="w-full text-left">
-                    <h6>{featModel.name}</h6>
-                    <p className="mt-4 text-[hsla(var(--text-secondary))]">
+                    <h6 className="font-medium">{featModel.name}</h6>
+                    <p className="mt-2 font-medium text-[hsla(var(--text-secondary))]">
                       {featModel.metadata.author}
                     </p>
                   </div>

                   {isDownloading ? (
                     <div className="flex w-full items-center gap-2">
-                      {Object.values(downloadStates).map((item, i) => (
+                      {Object.values(downloadStates)
+                        .filter((x) => x.modelId === featModel.id)
+                        .map((item, i) => (
                           <div
                             className="flex w-full items-center gap-2"
                             key={i}
                           >
                             <Progress
                               className="w-full"
                               value={
                                 formatDownloadPercentage(item?.percent, {
                                   hidePercentage: true,
                                 }) as number
                               }
                             />
                             <div className="flex items-center justify-between gap-x-2">
                               <div className="flex gap-x-2">
                                 <span className="font-medium text-[hsla(var(--primary-bg))]">
                                   {formatDownloadPercentage(item?.percent)}
                                 </span>
                               </div>
                             </div>
                           </div>
                         ))}
                     </div>
                   ) : (
                     <div className="flex flex-col items-end justify-end gap-2">

@@ -248,7 +262,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
                       >
                         Download
                       </Button>
-                      <span className="font-medium text-[hsla(var(--text-secondary))]">
+                      <span className="text-[hsla(var(--text-secondary))]">
                         {toGibibytes(featModel.metadata.size)}
                       </span>
                     </div>

@@ -257,7 +271,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
               )
             })}

-            <div className="mb-4 mt-8 flex items-center justify-between">
+            <div className="mb-2 mt-8 flex items-center justify-between">
               <h2 className="text-[hsla(var(--text-secondary))]">
                 Cloud Models
               </h2>

@@ -268,7 +282,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
               return (
                 <div
                   key={rowIndex}
-                  className="my-2 flex items-center justify-center gap-4 md:gap-10"
+                  className="my-2 flex items-center gap-4 md:gap-10"
                 >
                   {row.map((remoteEngine) => {
                     const engineLogo = getLogoEngine(

@@ -298,7 +312,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
                     />
                   )}

-                  <p>
+                  <p className="font-medium">
                     {getTitleByEngine(
                       remoteEngine as InferenceEngine
                     )}
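The starter-screen filter above prefers two hand-picked model ids and only falls back to the old "Featured under ~5 GB" rule when that pair is not fully available. A self-contained sketch of the same rule (the HubModel type here is illustrative, not the app's real type):

type HubModel = { id: string; metadata: { tags: string[]; size: number } }

const recommendModel = ['gemma-2-2b-it', 'llama3.1-8b-instruct']

const pickFeatured = (configuredModels: HubModel[]): HubModel[] => {
  const manual = configuredModels.filter((x) => recommendModel.includes(x.id))
  return manual.length === 2
    ? manual // both recommended models exist in the hub list
    : configuredModels.filter(
        (x) => x.metadata.tags.includes('Featured') && x.metadata.size < 5000000000
      )
}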
@@ -29,7 +29,10 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'

 import { getConfigurationsData } from '@/utils/componentSettings'
 import { localEngines } from '@/utils/modelEngine'
-import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
+import {
+  extractInferenceParams,
+  extractModelLoadParams,
+} from '@/utils/modelParam'

 import PromptTemplateSetting from './PromptTemplateSetting'
 import Tools from './Tools'

@@ -68,14 +71,26 @@ const ThreadRightPanel = () => {

   const settings = useMemo(() => {
     // runtime setting
-    const modelRuntimeParams = toRuntimeParams(activeModelParams)
+    const modelRuntimeParams = extractInferenceParams(
+      {
+        ...selectedModel?.parameters,
+        ...activeModelParams,
+      },
+      selectedModel?.parameters
+    )
     const componentDataRuntimeSetting = getConfigurationsData(
       modelRuntimeParams,
       selectedModel
     ).filter((x) => x.key !== 'prompt_template')

     // engine setting
-    const modelEngineParams = toSettingParams(activeModelParams)
+    const modelEngineParams = extractModelLoadParams(
+      {
+        ...selectedModel?.settings,
+        ...activeModelParams,
+      },
+      selectedModel?.settings
+    )
     const componentDataEngineSetting = getConfigurationsData(
       modelEngineParams,
       selectedModel

@@ -126,7 +141,10 @@ const ThreadRightPanel = () => {
   }, [activeModelParams, selectedModel])

   const promptTemplateSettings = useMemo(() => {
-    const modelEngineParams = toSettingParams(activeModelParams)
+    const modelEngineParams = extractModelLoadParams({
+      ...selectedModel?.settings,
+      ...activeModelParams,
+    })
     const componentDataEngineSetting = getConfigurationsData(
       modelEngineParams,
       selectedModel
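In ThreadRightPanel the model's own values are now merged in before extraction and also passed as the second argument, which the extractors (see web/utils/modelParam.ts below) use as a fallback when a thread override fails validation. Condensed sketch of the call shape:

const merged = { ...selectedModel?.parameters, ...activeModelParams }
const modelRuntimeParams = extractInferenceParams(merged, selectedModel?.parameters)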
web/screens/Thread/index.test.tsx (new file, 35 lines)
@@ -0,0 +1,35 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import ThreadScreen from './index'
import { useStarterScreen } from '../../hooks/useStarterScreen'
import '@testing-library/jest-dom'

global.ResizeObserver = class {
  observe() {}
  unobserve() {}
  disconnect() {}
}
// Mock the useStarterScreen hook
jest.mock('@/hooks/useStarterScreen')

describe('ThreadScreen', () => {
  it('renders OnDeviceStarterScreen when isShowStarterScreen is true', () => {
    ;(useStarterScreen as jest.Mock).mockReturnValue({
      isShowStarterScreen: true,
      extensionHasSettings: false,
    })

    const { getByText } = render(<ThreadScreen />)
    expect(getByText('Select a model to start')).toBeInTheDocument()
  })

  it('renders Thread panels when isShowStarterScreen is false', () => {
    ;(useStarterScreen as jest.Mock).mockReturnValue({
      isShowStarterScreen: false,
      extensionHasSettings: false,
    })

    const { getByText } = render(<ThreadScreen />)
    expect(getByText('Welcome!')).toBeInTheDocument()
  })
})
web/utils/modelParam.test.ts (new file, 183 lines)
@@ -0,0 +1,183 @@
// web/utils/modelParam.test.ts
import { normalizeValue, validationRules } from './modelParam'

describe('validationRules', () => {
  it('should validate temperature correctly', () => {
    expect(validationRules.temperature(0.5)).toBe(true)
    expect(validationRules.temperature(2)).toBe(true)
    expect(validationRules.temperature(0)).toBe(true)
    expect(validationRules.temperature(-0.1)).toBe(false)
    expect(validationRules.temperature(2.3)).toBe(false)
    expect(validationRules.temperature('0.5')).toBe(false)
  })

  it('should validate token_limit correctly', () => {
    expect(validationRules.token_limit(100)).toBe(true)
    expect(validationRules.token_limit(1)).toBe(true)
    expect(validationRules.token_limit(0)).toBe(true)
    expect(validationRules.token_limit(-1)).toBe(false)
    expect(validationRules.token_limit('100')).toBe(false)
  })

  it('should validate top_k correctly', () => {
    expect(validationRules.top_k(0.5)).toBe(true)
    expect(validationRules.top_k(1)).toBe(true)
    expect(validationRules.top_k(0)).toBe(true)
    expect(validationRules.top_k(-0.1)).toBe(false)
    expect(validationRules.top_k(1.1)).toBe(false)
    expect(validationRules.top_k('0.5')).toBe(false)
  })

  it('should validate top_p correctly', () => {
    expect(validationRules.top_p(0.5)).toBe(true)
    expect(validationRules.top_p(1)).toBe(true)
    expect(validationRules.top_p(0)).toBe(true)
    expect(validationRules.top_p(-0.1)).toBe(false)
    expect(validationRules.top_p(1.1)).toBe(false)
    expect(validationRules.top_p('0.5')).toBe(false)
  })

  it('should validate stream correctly', () => {
    expect(validationRules.stream(true)).toBe(true)
    expect(validationRules.stream(false)).toBe(true)
    expect(validationRules.stream('true')).toBe(false)
    expect(validationRules.stream(1)).toBe(false)
  })

  it('should validate max_tokens correctly', () => {
    expect(validationRules.max_tokens(100)).toBe(true)
    expect(validationRules.max_tokens(1)).toBe(true)
    expect(validationRules.max_tokens(0)).toBe(true)
    expect(validationRules.max_tokens(-1)).toBe(false)
    expect(validationRules.max_tokens('100')).toBe(false)
  })

  it('should validate stop correctly', () => {
    expect(validationRules.stop(['word1', 'word2'])).toBe(true)
    expect(validationRules.stop([])).toBe(true)
    expect(validationRules.stop(['word1', 2])).toBe(false)
    expect(validationRules.stop('word1')).toBe(false)
  })

  it('should validate frequency_penalty correctly', () => {
    expect(validationRules.frequency_penalty(0.5)).toBe(true)
    expect(validationRules.frequency_penalty(1)).toBe(true)
    expect(validationRules.frequency_penalty(0)).toBe(true)
    expect(validationRules.frequency_penalty(-0.1)).toBe(false)
    expect(validationRules.frequency_penalty(1.1)).toBe(false)
    expect(validationRules.frequency_penalty('0.5')).toBe(false)
  })

  it('should validate presence_penalty correctly', () => {
    expect(validationRules.presence_penalty(0.5)).toBe(true)
    expect(validationRules.presence_penalty(1)).toBe(true)
    expect(validationRules.presence_penalty(0)).toBe(true)
    expect(validationRules.presence_penalty(-0.1)).toBe(false)
    expect(validationRules.presence_penalty(1.1)).toBe(false)
    expect(validationRules.presence_penalty('0.5')).toBe(false)
  })

  it('should validate ctx_len correctly', () => {
    expect(validationRules.ctx_len(1024)).toBe(true)
    expect(validationRules.ctx_len(1)).toBe(true)
    expect(validationRules.ctx_len(0)).toBe(true)
    expect(validationRules.ctx_len(-1)).toBe(false)
    expect(validationRules.ctx_len('1024')).toBe(false)
  })

  it('should validate ngl correctly', () => {
    expect(validationRules.ngl(12)).toBe(true)
    expect(validationRules.ngl(1)).toBe(true)
    expect(validationRules.ngl(0)).toBe(true)
    expect(validationRules.ngl(-1)).toBe(false)
    expect(validationRules.ngl('12')).toBe(false)
  })

  it('should validate embedding correctly', () => {
    expect(validationRules.embedding(true)).toBe(true)
    expect(validationRules.embedding(false)).toBe(true)
    expect(validationRules.embedding('true')).toBe(false)
    expect(validationRules.embedding(1)).toBe(false)
  })

  it('should validate n_parallel correctly', () => {
    expect(validationRules.n_parallel(2)).toBe(true)
    expect(validationRules.n_parallel(1)).toBe(true)
    expect(validationRules.n_parallel(0)).toBe(true)
    expect(validationRules.n_parallel(-1)).toBe(false)
    expect(validationRules.n_parallel('2')).toBe(false)
  })

  it('should validate cpu_threads correctly', () => {
    expect(validationRules.cpu_threads(4)).toBe(true)
    expect(validationRules.cpu_threads(1)).toBe(true)
    expect(validationRules.cpu_threads(0)).toBe(true)
    expect(validationRules.cpu_threads(-1)).toBe(false)
    expect(validationRules.cpu_threads('4')).toBe(false)
  })

  it('should validate prompt_template correctly', () => {
    expect(validationRules.prompt_template('template')).toBe(true)
    expect(validationRules.prompt_template('')).toBe(true)
    expect(validationRules.prompt_template(123)).toBe(false)
  })

  it('should validate llama_model_path correctly', () => {
    expect(validationRules.llama_model_path('path')).toBe(true)
    expect(validationRules.llama_model_path('')).toBe(true)
    expect(validationRules.llama_model_path(123)).toBe(false)
  })

  it('should validate mmproj correctly', () => {
    expect(validationRules.mmproj('mmproj')).toBe(true)
    expect(validationRules.mmproj('')).toBe(true)
    expect(validationRules.mmproj(123)).toBe(false)
  })

  it('should validate vision_model correctly', () => {
    expect(validationRules.vision_model(true)).toBe(true)
    expect(validationRules.vision_model(false)).toBe(true)
    expect(validationRules.vision_model('true')).toBe(false)
    expect(validationRules.vision_model(1)).toBe(false)
  })

  it('should validate text_model correctly', () => {
    expect(validationRules.text_model(true)).toBe(true)
    expect(validationRules.text_model(false)).toBe(true)
    expect(validationRules.text_model('true')).toBe(false)
    expect(validationRules.text_model(1)).toBe(false)
  })
})

describe('normalizeValue', () => {
  it('should normalize ctx_len correctly', () => {
    expect(normalizeValue('ctx_len', 100.5)).toBe(100)
    expect(normalizeValue('ctx_len', '2')).toBe(2)
    expect(normalizeValue('ctx_len', 100)).toBe(100)
  })
  it('should normalize token_limit correctly', () => {
    expect(normalizeValue('token_limit', 100.5)).toBe(100)
    expect(normalizeValue('token_limit', '1')).toBe(1)
    expect(normalizeValue('token_limit', 0)).toBe(0)
  })
  it('should normalize max_tokens correctly', () => {
    expect(normalizeValue('max_tokens', 100.5)).toBe(100)
    expect(normalizeValue('max_tokens', '1')).toBe(1)
    expect(normalizeValue('max_tokens', 0)).toBe(0)
  })
  it('should normalize ngl correctly', () => {
    expect(normalizeValue('ngl', 12.5)).toBe(12)
    expect(normalizeValue('ngl', '2')).toBe(2)
    expect(normalizeValue('ngl', 0)).toBe(0)
  })
  it('should normalize n_parallel correctly', () => {
    expect(normalizeValue('n_parallel', 2.5)).toBe(2)
    expect(normalizeValue('n_parallel', '2')).toBe(2)
    expect(normalizeValue('n_parallel', 0)).toBe(0)
  })
  it('should normalize cpu_threads correctly', () => {
    expect(normalizeValue('cpu_threads', 4.5)).toBe(4)
    expect(normalizeValue('cpu_threads', '4')).toBe(4)
    expect(normalizeValue('cpu_threads', 0)).toBe(0)
  })
})
@@ -1,9 +1,69 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
+/* eslint-disable @typescript-eslint/naming-convention */
 import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'

 import { ModelParams } from '@/helpers/atoms/Thread.atom'

-export const toRuntimeParams = (
-  modelParams?: ModelParams
+/**
+ * Validation rules for model parameters
+ */
+export const validationRules: { [key: string]: (value: any) => boolean } = {
+  temperature: (value: any) =>
+    typeof value === 'number' && value >= 0 && value <= 2,
+  token_limit: (value: any) => Number.isInteger(value) && value >= 0,
+  top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
+  top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
+  stream: (value: any) => typeof value === 'boolean',
+  max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
+  stop: (value: any) =>
+    Array.isArray(value) && value.every((v) => typeof v === 'string'),
+  frequency_penalty: (value: any) =>
+    typeof value === 'number' && value >= 0 && value <= 1,
+  presence_penalty: (value: any) =>
+    typeof value === 'number' && value >= 0 && value <= 1,
+
+  ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
+  ngl: (value: any) => Number.isInteger(value) && value >= 0,
+  embedding: (value: any) => typeof value === 'boolean',
+  n_parallel: (value: any) => Number.isInteger(value) && value >= 0,
+  cpu_threads: (value: any) => Number.isInteger(value) && value >= 0,
+  prompt_template: (value: any) => typeof value === 'string',
+  llama_model_path: (value: any) => typeof value === 'string',
+  mmproj: (value: any) => typeof value === 'string',
+  vision_model: (value: any) => typeof value === 'boolean',
+  text_model: (value: any) => typeof value === 'boolean',
+}
+
+/**
+ * There are some parameters that need to be normalized before being sent to the server
+ * E.g. ctx_len should be an integer, but it can be a float from the input field
+ * @param key
+ * @param value
+ * @returns
+ */
+export const normalizeValue = (key: string, value: any) => {
+  if (
+    key === 'token_limit' ||
+    key === 'max_tokens' ||
+    key === 'ctx_len' ||
+    key === 'ngl' ||
+    key === 'n_parallel' ||
+    key === 'cpu_threads'
+  ) {
+    // Convert to integer
+    return Math.floor(Number(value))
+  }
+  return value
+}
+
+/**
+ * Extract inference parameters from flat model parameters
+ * @param modelParams
+ * @returns
+ */
+export const extractInferenceParams = (
+  modelParams?: ModelParams,
+  originParams?: ModelParams
 ): ModelRuntimeParams => {
   if (!modelParams) return {}
   const defaultModelParams: ModelRuntimeParams = {

@@ -22,15 +82,35 @@ export const toRuntimeParams = (

   for (const [key, value] of Object.entries(modelParams)) {
     if (key in defaultModelParams) {
-      Object.assign(runtimeParams, { ...runtimeParams, [key]: value })
+      const validate = validationRules[key]
+      if (validate && !validate(normalizeValue(key, value))) {
+        // Invalid value - fall back to origin value
+        if (originParams && key in originParams) {
+          Object.assign(runtimeParams, {
+            ...runtimeParams,
+            [key]: originParams[key as keyof typeof originParams],
+          })
+        }
+      } else {
+        Object.assign(runtimeParams, {
+          ...runtimeParams,
+          [key]: normalizeValue(key, value),
+        })
+      }
     }
   }

   return runtimeParams
 }

-export const toSettingParams = (
-  modelParams?: ModelParams
+/**
+ * Extract model load parameters from flat model parameters
+ * @param modelParams
+ * @returns
+ */
+export const extractModelLoadParams = (
+  modelParams?: ModelParams,
+  originParams?: ModelParams
 ): ModelSettingParams => {
   if (!modelParams) return {}
   const defaultSettingParams: ModelSettingParams = {

@@ -49,7 +129,21 @@ export const toSettingParams = (

   for (const [key, value] of Object.entries(modelParams)) {
     if (key in defaultSettingParams) {
-      Object.assign(settingParams, { ...settingParams, [key]: value })
+      const validate = validationRules[key]
+      if (validate && !validate(normalizeValue(key, value))) {
+        // Invalid value - fall back to origin value
+        if (originParams && key in originParams) {
+          Object.assign(modelParams, {
+            ...modelParams,
+            [key]: originParams[key as keyof typeof originParams],
+          })
+        }
+      } else {
+        Object.assign(settingParams, {
+          ...settingParams,
+          [key]: normalizeValue(key, value),
+        })
+      }
     }
   }
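A usage sketch of the new extractor behaviour (the values are made up): an out-of-range override falls back to the original value, and float inputs for integer fields are floored.

import { extractInferenceParams } from '@/utils/modelParam'

const origin = { temperature: 0.7, max_tokens: 2048 }
const edited = { temperature: 9, max_tokens: 512.7 }

const runtime = extractInferenceParams(edited, origin)
// temperature: 9 fails validationRules.temperature (range 0..2) -> falls back to 0.7
// max_tokens: 512.7 is normalized via Math.floor -> 512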