diff --git a/.gitignore b/.gitignore index 646e6842a..eaee28a62 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ core/test_results.html coverage .yarn .yarnrc +*.tsbuildinfo diff --git a/core/src/browser/core.test.ts b/core/src/browser/core.test.ts index 84250888e..f38cc0b40 100644 --- a/core/src/browser/core.test.ts +++ b/core/src/browser/core.test.ts @@ -1,98 +1,109 @@ -import { openExternalUrl } from './core'; -import { joinPath } from './core'; -import { openFileExplorer } from './core'; -import { getJanDataFolderPath } from './core'; -import { abortDownload } from './core'; -import { getFileSize } from './core'; -import { executeOnMain } from './core'; +import { openExternalUrl } from './core' +import { joinPath } from './core' +import { openFileExplorer } from './core' +import { getJanDataFolderPath } from './core' +import { abortDownload } from './core' +import { getFileSize } from './core' +import { executeOnMain } from './core' -it('should open external url', async () => { - const url = 'http://example.com'; - globalThis.core = { - api: { - openExternalUrl: jest.fn().mockResolvedValue('opened') +describe('test core apis', () => { + it('should open external url', async () => { + const url = 'http://example.com' + globalThis.core = { + api: { + openExternalUrl: jest.fn().mockResolvedValue('opened'), + }, } - }; - const result = await openExternalUrl(url); - expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url); - expect(result).toBe('opened'); -}); + const result = await openExternalUrl(url) + expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url) + expect(result).toBe('opened') + }) - -it('should join paths', async () => { - const paths = ['/path/one', '/path/two']; - globalThis.core = { - api: { - joinPath: jest.fn().mockResolvedValue('/path/one/path/two') + it('should join paths', async () => { + const paths = ['/path/one', '/path/two'] + globalThis.core = { + api: { + joinPath: jest.fn().mockResolvedValue('/path/one/path/two'), + }, } - }; - const result = await joinPath(paths); - expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths); - expect(result).toBe('/path/one/path/two'); -}); + const result = await joinPath(paths) + expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths) + expect(result).toBe('/path/one/path/two') + }) - -it('should open file explorer', async () => { - const path = '/path/to/open'; - globalThis.core = { - api: { - openFileExplorer: jest.fn().mockResolvedValue('opened') + it('should open file explorer', async () => { + const path = '/path/to/open' + globalThis.core = { + api: { + openFileExplorer: jest.fn().mockResolvedValue('opened'), + }, } - }; - const result = await openFileExplorer(path); - expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path); - expect(result).toBe('opened'); -}); + const result = await openFileExplorer(path) + expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path) + expect(result).toBe('opened') + }) - -it('should get jan data folder path', async () => { - globalThis.core = { - api: { - getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data') + it('should get jan data folder path', async () => { + globalThis.core = { + api: { + getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data'), + }, } - }; - const result = await getJanDataFolderPath(); - expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled(); - expect(result).toBe('/path/to/jan/data'); -}); + const result = await getJanDataFolderPath() + 
expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled() + expect(result).toBe('/path/to/jan/data') + }) - -it('should abort download', async () => { - const fileName = 'testFile'; - globalThis.core = { - api: { - abortDownload: jest.fn().mockResolvedValue('aborted') + it('should abort download', async () => { + const fileName = 'testFile' + globalThis.core = { + api: { + abortDownload: jest.fn().mockResolvedValue('aborted'), + }, } - }; - const result = await abortDownload(fileName); - expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName); - expect(result).toBe('aborted'); -}); + const result = await abortDownload(fileName) + expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName) + expect(result).toBe('aborted') + }) - -it('should get file size', async () => { - const url = 'http://example.com/file'; - globalThis.core = { - api: { - getFileSize: jest.fn().mockResolvedValue(1024) + it('should get file size', async () => { + const url = 'http://example.com/file' + globalThis.core = { + api: { + getFileSize: jest.fn().mockResolvedValue(1024), + }, } - }; - const result = await getFileSize(url); - expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url); - expect(result).toBe(1024); -}); + const result = await getFileSize(url) + expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url) + expect(result).toBe(1024) + }) - -it('should execute function on main process', async () => { - const extension = 'testExtension'; - const method = 'testMethod'; - const args = ['arg1', 'arg2']; - globalThis.core = { - api: { - invokeExtensionFunc: jest.fn().mockResolvedValue('result') + it('should execute function on main process', async () => { + const extension = 'testExtension' + const method = 'testMethod' + const args = ['arg1', 'arg2'] + globalThis.core = { + api: { + invokeExtensionFunc: jest.fn().mockResolvedValue('result'), + }, } - }; - const result = await executeOnMain(extension, method, ...args); - expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args); - expect(result).toBe('result'); -}); + const result = await executeOnMain(extension, method, ...args) + expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args) + expect(result).toBe('result') + }) +}) + +describe('dirName - just a pass thru api', () => { + it('should retrieve the directory name from a file path', async () => { + const mockDirName = jest.fn() + globalThis.core = { + api: { + dirName: mockDirName.mockResolvedValue('/path/to'), + }, + } + // Normal file path with extension + const path = '/path/to/file.txt' + await globalThis.core.api.dirName(path) + expect(mockDirName).toHaveBeenCalledWith(path) + }) +}) diff --git a/core/src/browser/core.ts b/core/src/browser/core.ts index fdbceb06b..b19e0b339 100644 --- a/core/src/browser/core.ts +++ b/core/src/browser/core.ts @@ -68,6 +68,13 @@ const openFileExplorer: (path: string) => Promise = (path) => const joinPath: (paths: string[]) => Promise = (paths) => globalThis.core.api?.joinPath(paths) +/** + * Get dirname of a file path. + * @param path - The file path to retrieve dirname. + * @returns {Promise} A promise that resolves the dirname. + */ +const dirName: (path: string) => Promise = (path) => globalThis.core.api?.dirName(path) + /** * Retrieve the basename from an url. * @param path - The path to retrieve. 
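[Editor note] The `dirName` bridge added above is what lets engine extensions locate a model's folder from the `model.json` path now carried on `ModelFile`, instead of rebuilding it from the Jan data folder and the model id (see the `LocalOAIEngine.loadModel` change further down in this diff). A minimal sketch of the intended call pattern follows; the `getModelFolder` helper name is illustrative only and is not part of this PR.

// Sketch (not part of the diff): intended usage of the new dirName API.
import { dirName, ModelFile } from '@janhq/core'

// Hypothetical helper; the real call sites in this PR are
// LocalOAIEngine.loadModel and JanModelExtension.deleteModel.
const getModelFolder = async (model: ModelFile): Promise<string> => {
  // Before: joinPath([await getJanDataFolderPath(), 'models', model.id])
  // After: the folder is wherever the model.json actually lives.
  return dirName(model.file_path)
}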
@@ -161,5 +168,6 @@ export { systemInformation, showToast, getFileSize, + dirName, FileStat, } diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 7cd9f513e..75354de88 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -2,7 +2,7 @@ import { getJanDataFolderPath, joinPath } from '../../core' import { events } from '../../events' import { BaseExtension } from '../../extension' import { fs } from '../../fs' -import { MessageRequest, Model, ModelEvent } from '../../../types' +import { MessageRequest, Model, ModelEvent, ModelFile } from '../../../types' import { EngineManager } from './EngineManager' /** @@ -21,7 +21,7 @@ export abstract class AIEngine extends BaseExtension { override onLoad() { this.registerEngine() - events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model)) events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) } @@ -78,7 +78,7 @@ export abstract class AIEngine extends BaseExtension { /** * Loads the model. */ - async loadModel(model: Model): Promise { + async loadModel(model: ModelFile): Promise { if (model.engine.toString() !== this.provider) return Promise.resolve() events.emit(ModelEvent.OnModelReady, model) return Promise.resolve() diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts index fb9e4962c..123b9a593 100644 --- a/core/src/browser/extensions/engines/LocalOAIEngine.ts +++ b/core/src/browser/extensions/engines/LocalOAIEngine.ts @@ -1,6 +1,6 @@ -import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core' +import { executeOnMain, systemInformation, dirName } from '../../core' import { events } from '../../events' -import { Model, ModelEvent } from '../../../types' +import { Model, ModelEvent, ModelFile } from '../../../types' import { OAIEngine } from './OAIEngine' /** @@ -14,22 +14,24 @@ export abstract class LocalOAIEngine extends OAIEngine { unloadModelFunctionName: string = 'unloadModel' /** - * On extension load, subscribe to events. + * This class represents a base for local inference providers in the OpenAI architecture. + * It extends the OAIEngine class and provides the implementation of loading and unloading models locally. + * The loadModel function subscribes to the ModelEvent.OnModelInit event, loading models when initiated. + * The unloadModel function subscribes to the ModelEvent.OnModelStop event, unloading models when stopped. */ override onLoad() { super.onLoad() // These events are applicable to local inference providers - events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model)) events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) } /** * Load the model. 
*/ - override async loadModel(model: Model): Promise { + override async loadModel(model: ModelFile): Promise { if (model.engine.toString() !== this.provider) return - const modelFolderName = 'models' - const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id]) + const modelFolder = await dirName(model.file_path) const systemInfo = await systemInformation() const res = await executeOnMain( this.nodeModule, diff --git a/core/src/browser/extensions/model.ts b/core/src/browser/extensions/model.ts index 5b3089403..040542927 100644 --- a/core/src/browser/extensions/model.ts +++ b/core/src/browser/extensions/model.ts @@ -4,6 +4,7 @@ import { HuggingFaceRepoData, ImportingModel, Model, + ModelFile, ModelInterface, OptionType, } from '../../types' @@ -25,12 +26,11 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter network?: { proxy: string; ignoreSSL?: boolean } ): Promise abstract cancelModelDownload(modelId: string): Promise - abstract deleteModel(modelId: string): Promise - abstract saveModel(model: Model): Promise - abstract getDownloadedModels(): Promise - abstract getConfiguredModels(): Promise + abstract deleteModel(model: ModelFile): Promise + abstract getDownloadedModels(): Promise + abstract getConfiguredModels(): Promise abstract importModels(models: ImportingModel[], optionType: OptionType): Promise - abstract updateModelInfo(modelInfo: Partial): Promise + abstract updateModelInfo(modelInfo: Partial): Promise abstract fetchHuggingFaceRepoData(repoId: string): Promise abstract getDefaultModel(): Promise } diff --git a/core/src/node/api/processors/app.test.ts b/core/src/node/api/processors/app.test.ts index 3ada5df1e..5c4daef29 100644 --- a/core/src/node/api/processors/app.test.ts +++ b/core/src/node/api/processors/app.test.ts @@ -1,40 +1,57 @@ -import { App } from './app'; +jest.mock('../../helper', () => ({ + ...jest.requireActual('../../helper'), + getJanDataFolderPath: () => './app', +})) +import { dirname } from 'path' +import { App } from './app' it('should call stopServer', () => { - const app = new App(); - const stopServerMock = jest.fn().mockResolvedValue('Server stopped'); + const app = new App() + const stopServerMock = jest.fn().mockResolvedValue('Server stopped') jest.mock('@janhq/server', () => ({ - stopServer: stopServerMock - })); - const result = app.stopServer(); - expect(stopServerMock).toHaveBeenCalled(); -}); + stopServer: stopServerMock, + })) + app.stopServer() + expect(stopServerMock).toHaveBeenCalled() +}) it('should correctly retrieve basename', () => { - const app = new App(); - const result = app.baseName('/path/to/file.txt'); - expect(result).toBe('file.txt'); -}); + const app = new App() + const result = app.baseName('/path/to/file.txt') + expect(result).toBe('file.txt') +}) it('should correctly identify subdirectories', () => { - const app = new App(); - const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to'; - const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir'; - const result = app.isSubdirectory(basePath, subPath); - expect(result).toBe(true); -}); + const app = new App() + const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to' + const subPath = process.platform === 'win32' ? 
'C:\\path\\to\\subdir' : '/path/to/subdir' + const result = app.isSubdirectory(basePath, subPath) + expect(result).toBe(true) +}) it('should correctly join multiple paths', () => { - const app = new App(); - const result = app.joinPath(['path', 'to', 'file']); - const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file'; - expect(result).toBe(expectedPath); -}); + const app = new App() + const result = app.joinPath(['path', 'to', 'file']) + const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file' + expect(result).toBe(expectedPath) +}) it('should call correct function with provided arguments using process method', () => { - const app = new App(); - const mockFunc = jest.fn(); - app.joinPath = mockFunc; - app.process('joinPath', ['path1', 'path2']); - expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']); -}); + const app = new App() + const mockFunc = jest.fn() + app.joinPath = mockFunc + app.process('joinPath', ['path1', 'path2']) + expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']) +}) + +it('should retrieve the directory name from a file path (Unix/Windows)', async () => { + const app = new App() + const path = 'C:/Users/John Doe/Desktop/file.txt' + expect(await app.dirName(path)).toBe('C:/Users/John Doe/Desktop') +}) + +it('should retrieve the directory name when using file protocol', async () => { + const app = new App() + const path = 'file:/models/file.txt' + expect(await app.dirName(path)).toBe(process.platform === 'win32' ? 'app\\models' : 'app/models') +}) diff --git a/core/src/node/api/processors/app.ts b/core/src/node/api/processors/app.ts index 15460ba56..a0808c5ac 100644 --- a/core/src/node/api/processors/app.ts +++ b/core/src/node/api/processors/app.ts @@ -1,4 +1,4 @@ -import { basename, isAbsolute, join, relative } from 'path' +import { basename, dirname, isAbsolute, join, relative } from 'path' import { Processor } from './Processor' import { @@ -6,6 +6,8 @@ import { appResourcePath, getAppConfigurations as appConfiguration, updateAppConfiguration, + normalizeFilePath, + getJanDataFolderPath, } from '../../helper' export class App implements Processor { @@ -28,6 +30,18 @@ export class App implements Processor { return join(...args) } + /** + * Get dirname of a file path. + * @param path - The file path to retrieve dirname. + */ + dirName(path: string) { + const arg = + path.startsWith(`file:/`) || path.startsWith(`file:\\`) + ? join(getJanDataFolderPath(), normalizeFilePath(path)) + : path + return dirname(arg) + } + /** * Checks if the given path is a subdirectory of the given directory. * diff --git a/core/src/types/api/index.ts b/core/src/types/api/index.ts index bca11c0a8..8f1ff70bf 100644 --- a/core/src/types/api/index.ts +++ b/core/src/types/api/index.ts @@ -37,6 +37,7 @@ export enum AppRoute { getAppConfigurations = 'getAppConfigurations', updateAppConfiguration = 'updateAppConfiguration', joinPath = 'joinPath', + dirName = 'dirName', isSubdirectory = 'isSubdirectory', baseName = 'baseName', startServer = 'startServer', diff --git a/core/src/types/file/index.ts b/core/src/types/file/index.ts index 1b36a5777..4db956b1e 100644 --- a/core/src/types/file/index.ts +++ b/core/src/types/file/index.ts @@ -52,3 +52,18 @@ type DownloadSize = { total: number transferred: number } + +/** + * The file metadata + */ +export type FileMetadata = { + /** + * The origin file path. + */ + file_path: string + + /** + * The file name. 
+ */ + file_name: string +} diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts index f154f7f04..933c698c3 100644 --- a/core/src/types/model/modelEntity.ts +++ b/core/src/types/model/modelEntity.ts @@ -1,3 +1,5 @@ +import { FileMetadata } from '../file' + /** * Represents the information about a model. * @stored @@ -151,3 +153,8 @@ export type ModelRuntimeParams = { export type ModelInitFailed = Model & { error: Error } + +/** + * ModelFile is the model.json entity and it's file metadata + */ +export type ModelFile = Model & FileMetadata diff --git a/core/src/types/model/modelInterface.ts b/core/src/types/model/modelInterface.ts index 639c7c8d3..5b5856231 100644 --- a/core/src/types/model/modelInterface.ts +++ b/core/src/types/model/modelInterface.ts @@ -1,5 +1,5 @@ import { GpuSetting } from '../miscellaneous' -import { Model } from './modelEntity' +import { Model, ModelFile } from './modelEntity' /** * Model extension for managing models. @@ -29,14 +29,7 @@ export interface ModelInterface { * @param modelId - The ID of the model to delete. * @returns A Promise that resolves when the model has been deleted. */ - deleteModel(modelId: string): Promise - - /** - * Saves a model. - * @param model - The model to save. - * @returns A Promise that resolves when the model has been saved. - */ - saveModel(model: Model): Promise + deleteModel(model: ModelFile): Promise /** * Gets a list of downloaded models. diff --git a/electron/tests/e2e/thread.e2e.spec.ts b/electron/tests/e2e/thread.e2e.spec.ts index c13e91119..5d7328053 100644 --- a/electron/tests/e2e/thread.e2e.spec.ts +++ b/electron/tests/e2e/thread.e2e.spec.ts @@ -1,32 +1,29 @@ import { expect } from '@playwright/test' import { page, test, TIMEOUT } from '../config/fixtures' -test('Select GPT model from Hub and Chat with Invalid API Key', async ({ hubPage }) => { +test('Select GPT model from Hub and Chat with Invalid API Key', async ({ + hubPage, +}) => { await hubPage.navigateByMenu() await hubPage.verifyContainerVisible() // Select the first GPT model await page .locator('[data-testid^="use-model-btn"][data-testid*="gpt"]') - .first().click() - - // Attempt to create thread and chat in Thread page - await page - .getByTestId('btn-create-thread') + .first() .click() - await page - .getByTestId('txt-input-chat') - .fill('dummy value') + await page.getByTestId('txt-input-chat').fill('dummy value') - await page - .getByTestId('btn-send-chat') - .click() + await page.getByTestId('btn-send-chat').click() - await page.waitForFunction(() => { - const loaders = document.querySelectorAll('[data-testid$="loader"]'); - return !loaders.length; - }, { timeout: TIMEOUT }); + await page.waitForFunction( + () => { + const loaders = document.querySelectorAll('[data-testid$="loader"]') + return !loaders.length + }, + { timeout: TIMEOUT } + ) const APIKeyError = page.getByTestId('invalid-API-key-error') await expect(APIKeyError).toBeVisible({ diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index d79e076d4..6e825e8fd 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -22,6 +22,7 @@ import { downloadFile, DownloadState, DownloadEvent, + ModelFile, } from '@janhq/core' declare const CUDA_DOWNLOAD_URL: string @@ -94,7 +95,7 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine { this.nitroProcessInfo = health } - override loadModel(model: Model): Promise { + override 
loadModel(model: ModelFile): Promise { if (model.engine !== this.provider) return Promise.resolve() this.getNitroProcessHealthIntervalId = setInterval( () => this.periodicallyGetNitroHealth(), diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts index edc2d013d..98ca4572f 100644 --- a/extensions/inference-nitro-extension/src/node/index.ts +++ b/extensions/inference-nitro-extension/src/node/index.ts @@ -6,12 +6,12 @@ import fetchRT from 'fetch-retry' import { log, getSystemResourceInfo, - Model, InferenceEngine, ModelSettingParams, PromptTemplate, SystemInformation, getJanDataFolderPath, + ModelFile, } from '@janhq/core/node' import { executableNitroFile } from './execute' import terminate from 'terminate' @@ -25,7 +25,7 @@ const fetchRetry = fetchRT(fetch) */ interface ModelInitOptions { modelFolder: string - model: Model + model: ModelFile } // The PORT to use for the Nitro subprocess const PORT = 3928 @@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise { if (!settings?.ngl) { settings.ngl = 100 } - log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`) + log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`) return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, { method: 'POST', headers: { @@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise { }) .then((res) => { log( - `[CORTEX]::Debug: Load model success with response ${JSON.stringify( + `[CORTEX]:: Load model success with response ${JSON.stringify( res )}` ) @@ -260,7 +260,7 @@ function loadLLMModel(settings: any): Promise { async function validateModelStatus(modelId: string): Promise { // Send a GET request to the validation URL. // Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries. - log(`[CORTEX]::Debug: Validating model ${modelId}`) + log(`[CORTEX]:: Validating model ${modelId}`) return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, { method: 'POST', body: JSON.stringify({ @@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise { retryDelay: 300, }).then(async (res: Response) => { log( - `[CORTEX]::Debug: Validate model state with response ${JSON.stringify( + `[CORTEX]:: Validate model state with response ${JSON.stringify( res.status )}` ) @@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise { // Otherwise, return an object with an error message. 
if (body.model_loaded) { log( - `[CORTEX]::Debug: Validate model state success with response ${JSON.stringify( + `[CORTEX]:: Validate model state success with response ${JSON.stringify( body )}` ) @@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise { } const errorBody = await res.text() log( - `[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify( + `[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify( res.statusText )}` ) @@ -310,7 +310,7 @@ async function validateModelStatus(modelId: string): Promise { async function killSubprocess(): Promise { const controller = new AbortController() setTimeout(() => controller.abort(), 5000) - log(`[CORTEX]::Debug: Request to kill cortex`) + log(`[CORTEX]:: Request to kill cortex`) const killRequest = () => { return fetch(NITRO_HTTP_KILL_URL, { @@ -321,17 +321,17 @@ async function killSubprocess(): Promise { .then(() => tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) ) - .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .catch((err) => { log( - `[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}` + `[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}` ) throw 'PORT_NOT_AVAILABLE' }) } if (subprocess?.pid && process.platform !== 'darwin') { - log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) + log(`[CORTEX]:: Killing PID ${subprocess.pid}`) const pid = subprocess.pid return new Promise((resolve, reject) => { terminate(pid, function (err) { @@ -341,7 +341,7 @@ async function killSubprocess(): Promise { } else { tcpPortUsed .waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) - .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .then(() => resolve()) .catch(() => { log( @@ -362,7 +362,7 @@ async function killSubprocess(): Promise { * @returns A promise that resolves when the Nitro subprocess is started. 
*/ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { - log(`[CORTEX]::Debug: Spawning cortex subprocess...`) + log(`[CORTEX]:: Spawning cortex subprocess...`) return new Promise(async (resolve, reject) => { let executableOptions = executableNitroFile( @@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { const args: string[] = ['1', LOCAL_HOST, PORT.toString()] // Execute the binary log( - `[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}` + `[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}` ) log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`) @@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { // Handle subprocess output subprocess.stdout.on('data', (data: any) => { - log(`[CORTEX]::Debug: ${data}`) + log(`[CORTEX]:: ${data}`) }) subprocess.stderr.on('data', (data: any) => { @@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { }) subprocess.on('close', (code: any) => { - log(`[CORTEX]::Debug: cortex exited with code: ${code}`) + log(`[CORTEX]:: cortex exited with code: ${code}`) subprocess = undefined reject(`child process exited with code ${code}`) }) @@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { tcpPortUsed .waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000) .then(() => { - log(`[CORTEX]::Debug: cortex is ready`) + log(`[CORTEX]:: cortex is ready`) resolve() }) }) diff --git a/extensions/inference-openai-extension/resources/models.json b/extensions/inference-openai-extension/resources/models.json index 6852a1892..72517d540 100644 --- a/extensions/inference-openai-extension/resources/models.json +++ b/extensions/inference-openai-extension/resources/models.json @@ -119,5 +119,65 @@ ] }, "engine": "openai" + }, + { + "sources": [ + { + "url": "https://openai.com" + } + ], + "id": "o1-preview", + "object": "model", + "name": "OpenAI o1-preview", + "version": "1.0", + "description": "OpenAI o1-preview is a new model with complex reasoning", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "OpenAI", + "tags": [ + "General" + ] + }, + "engine": "openai" + }, + { + "sources": [ + { + "url": "https://openai.com" + } + ], + "id": "o1-mini", + "object": "model", + "name": "OpenAI o1-mini", + "version": "1.0", + "description": "OpenAI o1-mini is a lightweight reasoning model", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "OpenAI", + "tags": [ + "General" + ] + }, + "engine": "openai" } ] diff --git a/extensions/model-extension/jest.config.js b/extensions/model-extension/jest.config.js new file mode 100644 index 000000000..3e32adceb --- /dev/null +++ b/extensions/model-extension/jest.config.js @@ -0,0 +1,9 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', + transform: { + 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest', + }, + transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'], +} diff --git a/extensions/model-extension/package.json 
b/extensions/model-extension/package.json index 4a2c61b71..9a406dcf4 100644 --- a/extensions/model-extension/package.json +++ b/extensions/model-extension/package.json @@ -8,6 +8,7 @@ "author": "Jan ", "license": "AGPL-3.0", "scripts": { + "test": "jest", "build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs", "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" }, diff --git a/extensions/model-extension/rollup.config.ts b/extensions/model-extension/rollup.config.ts index c3f3acc77..d36d8ffac 100644 --- a/extensions/model-extension/rollup.config.ts +++ b/extensions/model-extension/rollup.config.ts @@ -27,7 +27,7 @@ export default [ // Allow json resolution json(), // Compile TypeScript files - typescript({ useTsconfigDeclarationDir: true }), + typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }), // Compile TypeScript files // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) // commonjs(), @@ -62,7 +62,7 @@ export default [ // Allow json resolution json(), // Compile TypeScript files - typescript({ useTsconfigDeclarationDir: true }), + typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }), // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) commonjs(), // Allow node_modules resolution, so you can use 'external' to control diff --git a/extensions/model-extension/src/index.test.ts b/extensions/model-extension/src/index.test.ts new file mode 100644 index 000000000..6816d7101 --- /dev/null +++ b/extensions/model-extension/src/index.test.ts @@ -0,0 +1,564 @@ +const readDirSyncMock = jest.fn() +const existMock = jest.fn() +const readFileSyncMock = jest.fn() + +jest.mock('@janhq/core', () => ({ + ...jest.requireActual('@janhq/core/node'), + fs: { + existsSync: existMock, + readdirSync: readDirSyncMock, + readFileSync: readFileSyncMock, + fileStat: () => ({ + isDirectory: false, + }), + }, + dirName: jest.fn(), + joinPath: (paths) => paths.join('/'), + ModelExtension: jest.fn(), +})) + +import JanModelExtension from '.' 
+import { fs, dirName } from '@janhq/core' + +describe('JanModelExtension', () => { + let sut: JanModelExtension + + beforeAll(() => { + // @ts-ignore + sut = new JanModelExtension() + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + describe('getConfiguredModels', () => { + describe("when there's no models are pre-populated", () => { + it('should return empty array', async () => { + // Mock configured models data + const configuredModels = [] + existMock.mockReturnValue(true) + readDirSyncMock.mockReturnValue([]) + + const result = await sut.getConfiguredModels() + expect(result).toEqual([]) + }) + }) + + describe("when there's are pre-populated models - all flattened", () => { + it('returns configured models data - flatten folder - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getConfiguredModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model.json', + id: '2', + }), + ]) + ) + }) + }) + + describe("when there's are pre-populated models - there are nested folders", () => { + it('returns configured models data - flatten folder - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else if (path.includes('model2/model2-1')) + return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getConfiguredModels() + expect(result).toEqual( + expect.arrayContaining([ + 
expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('getDownloadedModels', () => { + describe('no models downloaded', () => { + it('should return empty array', async () => { + // Mock downloaded models data + const downloadedModels = [] + existMock.mockReturnValue(true) + readDirSyncMock.mockReturnValue([]) + + const result = await sut.getDownloadedModels() + expect(result).toEqual([]) + }) + }) + describe('only one model is downloaded', () => { + describe('flatten folder', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2'] + else if (path === 'file://models/model1') + return ['model.json', 'test.gguf'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + ]) + ) + }) + }) + }) + + describe('all models are downloaded', () => { + describe('nested folders', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else return ['model.json', 'test.gguf'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + 
expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('all models are downloaded with uppercased GGUF files', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else if (path === 'file://models/model1') + return ['model.json', 'test.GGUF'] + else return ['model.json', 'test.gguf'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + + describe('all models are downloaded - GGUF & Tensort RT', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else if (path === 'file://models/model1') + return ['model.json', 'test.gguf'] + else return ['model.json', 'test.engine'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('deleteModel', () => { + describe('model is a GGUF model', () => { + it('should delete the GGUF file', async () => { + fs.unlinkSync = jest.fn() + const dirMock 
= dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + readDirSyncMock.mockImplementation((path) => { + return ['model.json', 'test.gguf'] + }) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledWith( + 'file://models/model1/test.gguf' + ) + }) + + it('no gguf file presented', async () => { + fs.unlinkSync = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + readDirSyncMock.mockReturnValue(['model.json']) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledTimes(0) + }) + + it('delete an imported model', async () => { + fs.rm = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + readDirSyncMock.mockReturnValue(['model.json', 'test.gguf']) + + // MARK: This is a tricky logic implement? + // I will just add test for now but will align on the legacy implementation + fs.readFileSync = jest.fn().mockReturnValue( + JSON.stringify({ + metadata: { + author: 'user', + }, + }) + ) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.rm).toHaveBeenCalledWith('file://models/model1') + }) + + it('delete tensorrt-models', async () => { + fs.rm = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + readDirSyncMock.mockReturnValue(['model.json', 'test.engine']) + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine') + }) + }) + }) +}) diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index e2f68a58c..ac9b06a09 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -22,6 +22,8 @@ import { getFileSize, AllQuantizations, ModelEvent, + ModelFile, + dirName, } from '@janhq/core' import { extractFileName } from './helpers/path' @@ -48,16 +50,7 @@ export default class JanModelExtension extends ModelExtension { ] private static readonly _tensorRtEngineFormat = '.engine' private static readonly _supportedGpuArch = ['ampere', 'ada'] - private static readonly _safetensorsRegexs = [ - /model\.safetensors$/, - /model-[0-9]+-of-[0-9]+\.safetensors$/, - ] - private static readonly _pytorchRegexs = [ - /pytorch_model\.bin$/, - /consolidated\.[0-9]+\.pth$/, - /pytorch_model-[0-9]+-of-[0-9]+\.bin$/, - /.*\.pt$/, - ] + interrupted = false /** @@ -319,9 +312,9 @@ export default class JanModelExtension extends ModelExtension { * @param filePath - The path to the model file to delete. * @returns A Promise that resolves when the model is deleted. 
*/ - async deleteModel(modelId: string): Promise { + async deleteModel(model: ModelFile): Promise { try { - const dirPath = await joinPath([JanModelExtension._homeDir, modelId]) + const dirPath = await dirName(model.file_path) const jsonFilePath = await joinPath([ dirPath, JanModelExtension._modelMetadataFileName, @@ -330,9 +323,11 @@ export default class JanModelExtension extends ModelExtension { await this.readModelMetadata(jsonFilePath) ) as Model + // TODO: This is so tricky? + // Should depend on sources? const isUserImportModel = modelInfo.metadata?.author?.toLowerCase() === 'user' - if (isUserImportModel) { + if (isUserImportModel) { // just delete the folder return fs.rm(dirPath) } @@ -350,30 +345,11 @@ export default class JanModelExtension extends ModelExtension { } } - /** - * Saves a model file. - * @param model - The model to save. - * @returns A Promise that resolves when the model is saved. - */ - async saveModel(model: Model): Promise { - const jsonFilePath = await joinPath([ - JanModelExtension._homeDir, - model.id, - JanModelExtension._modelMetadataFileName, - ]) - - try { - await fs.writeFileSync(jsonFilePath, JSON.stringify(model, null, 2)) - } catch (err) { - console.error(err) - } - } - /** * Gets all downloaded models. * @returns A Promise that resolves with an array of all models. */ - async getDownloadedModels(): Promise { + async getDownloadedModels(): Promise { return await this.getModelsMetadata( async (modelDir: string, model: Model) => { if (!JanModelExtension._offlineInferenceEngine.includes(model.engine)) @@ -425,8 +401,10 @@ export default class JanModelExtension extends ModelExtension { ): Promise { // try to find model.json recursively inside each folder if (!(await fs.existsSync(folderFullPath))) return undefined + const files: string[] = await fs.readdirSync(folderFullPath) if (files.length === 0) return undefined + if (files.includes(JanModelExtension._modelMetadataFileName)) { return joinPath([ folderFullPath, @@ -446,7 +424,7 @@ export default class JanModelExtension extends ModelExtension { private async getModelsMetadata( selector?: (path: string, model: Model) => Promise - ): Promise { + ): Promise { try { if (!(await fs.existsSync(JanModelExtension._homeDir))) { console.debug('Model folder not found') @@ -469,6 +447,7 @@ export default class JanModelExtension extends ModelExtension { JanModelExtension._homeDir, dirName, ]) + const jsonPath = await this.getModelJsonPath(folderFullPath) if (await fs.existsSync(jsonPath)) { @@ -486,6 +465,8 @@ export default class JanModelExtension extends ModelExtension { }, ] } + model.file_path = jsonPath + model.file_name = JanModelExtension._modelMetadataFileName if (selector && !(await selector?.(dirName, model))) { return @@ -506,7 +487,7 @@ export default class JanModelExtension extends ModelExtension { typeof result.value === 'object' ? result.value : JSON.parse(result.value) - return model as Model + return model as ModelFile } catch { console.debug(`Unable to parse model metadata: ${result.value}`) } @@ -637,7 +618,7 @@ export default class JanModelExtension extends ModelExtension { * Gets all available models. * @returns A Promise that resolves with an array of all models. 
*/ - async getConfiguredModels(): Promise { + async getConfiguredModels(): Promise { return this.getModelsMetadata() } @@ -669,7 +650,7 @@ export default class JanModelExtension extends ModelExtension { modelBinaryPath: string, modelFolderName: string, modelFolderPath: string - ): Promise { + ): Promise { const fileStats = await fs.fileStat(modelBinaryPath, true) const binaryFileSize = fileStats.size @@ -732,25 +713,21 @@ export default class JanModelExtension extends ModelExtension { await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2)) - return model + return { + ...model, + file_path: modelFilePath, + file_name: JanModelExtension._modelMetadataFileName, + } } - async updateModelInfo(modelInfo: Partial): Promise { - const modelId = modelInfo.id + async updateModelInfo(modelInfo: Partial): Promise { if (modelInfo.id == null) throw new Error('Model ID is required') - const janDataFolderPath = await getJanDataFolderPath() - const jsonFilePath = await joinPath([ - janDataFolderPath, - 'models', - modelId, - JanModelExtension._modelMetadataFileName, - ]) const model = JSON.parse( - await this.readModelMetadata(jsonFilePath) - ) as Model + await this.readModelMetadata(modelInfo.file_path) + ) as ModelFile - const updatedModel: Model = { + const updatedModel: ModelFile = { ...model, ...modelInfo, parameters: { @@ -765,9 +742,15 @@ export default class JanModelExtension extends ModelExtension { ...model.metadata, ...modelInfo.metadata, }, + // Should not persist file_path & file_name + file_path: undefined, + file_name: undefined, } - await fs.writeFileSync(jsonFilePath, JSON.stringify(updatedModel, null, 2)) + await fs.writeFileSync( + modelInfo.file_path, + JSON.stringify(updatedModel, null, 2) + ) return updatedModel } diff --git a/extensions/model-extension/tsconfig.json b/extensions/model-extension/tsconfig.json index addd8e127..0d3252934 100644 --- a/extensions/model-extension/tsconfig.json +++ b/extensions/model-extension/tsconfig.json @@ -10,5 +10,6 @@ "skipLibCheck": true, "rootDir": "./src" }, - "include": ["./src"] + "include": ["./src"], + "exclude": ["**/*.test.ts"] } diff --git a/extensions/tensorrt-llm-extension/src/index.ts b/extensions/tensorrt-llm-extension/src/index.ts index 189abc706..7f68c43bd 100644 --- a/extensions/tensorrt-llm-extension/src/index.ts +++ b/extensions/tensorrt-llm-extension/src/index.ts @@ -23,6 +23,7 @@ import { ModelEvent, getJanDataFolderPath, SystemInformation, + ModelFile, } from '@janhq/core' /** @@ -137,7 +138,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { events.emit(ModelEvent.OnModelsUpdate, {}) } - override async loadModel(model: Model): Promise { + override async loadModel(model: ModelFile): Promise { if ((await this.installationState()) === 'Installed') return super.loadModel(model) diff --git a/extensions/tensorrt-llm-extension/src/node/index.ts b/extensions/tensorrt-llm-extension/src/node/index.ts index c8bc48459..77003389f 100644 --- a/extensions/tensorrt-llm-extension/src/node/index.ts +++ b/extensions/tensorrt-llm-extension/src/node/index.ts @@ -97,7 +97,7 @@ function unloadModel(): Promise { } if (subprocess?.pid) { - log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) + log(`[CORTEX]:: Killing PID ${subprocess.pid}`) const pid = subprocess.pid return new Promise((resolve, reject) => { terminate(pid, function (err) { @@ -107,7 +107,7 @@ function unloadModel(): Promise { return tcpPortUsed .waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000) .then(() => resolve()) - .then(() => 
log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .catch(() => { killRequest() }) diff --git a/web/containers/Layout/RibbonPanel/index.tsx b/web/containers/Layout/RibbonPanel/index.tsx index 6bed2b424..7613584e0 100644 --- a/web/containers/Layout/RibbonPanel/index.tsx +++ b/web/containers/Layout/RibbonPanel/index.tsx @@ -12,17 +12,18 @@ import { twMerge } from 'tailwind-merge' import { MainViewState } from '@/constants/screens' -import { localEngines } from '@/utils/modelEngine' - import { mainViewStateAtom, showLeftPanelAtom } from '@/helpers/atoms/App.atom' import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom' import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' + import { reduceTransparentAtom, selectedSettingAtom, } from '@/helpers/atoms/Setting.atom' -import { threadsAtom } from '@/helpers/atoms/Thread.atom' +import { + isDownloadALocalModelAtom, + threadsAtom, +} from '@/helpers/atoms/Thread.atom' export default function RibbonPanel() { const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom) @@ -32,8 +33,9 @@ export default function RibbonPanel() { const matches = useMediaQuery('(max-width: 880px)') const reduceTransparent = useAtomValue(reduceTransparentAtom) const setSelectedSetting = useSetAtom(selectedSettingAtom) - const downloadedModels = useAtomValue(downloadedModelsAtom) + const threads = useAtomValue(threadsAtom) + const isDownloadALocalModel = useAtomValue(isDownloadALocalModelAtom) const onMenuClick = (state: MainViewState) => { if (mainViewState === state) return @@ -43,10 +45,6 @@ export default function RibbonPanel() { setEditMessage('') } - const isDownloadALocalModel = downloadedModels.some((x) => - localEngines.includes(x.engine) - ) - const RibbonNavMenus = [ { name: 'Thread', diff --git a/web/containers/Layout/TopPanel/index.tsx b/web/containers/Layout/TopPanel/index.tsx index 213f7dfa9..aff616973 100644 --- a/web/containers/Layout/TopPanel/index.tsx +++ b/web/containers/Layout/TopPanel/index.tsx @@ -23,6 +23,7 @@ import { toaster } from '@/containers/Toast' import { MainViewState } from '@/constants/screens' import { useCreateNewThread } from '@/hooks/useCreateNewThread' +import { useStarterScreen } from '@/hooks/useStarterScreen' import { mainViewStateAtom, @@ -58,6 +59,8 @@ const TopPanel = () => { requestCreateNewThread(assistants[0]) } + const { isShowStarterScreen } = useStarterScreen() + return (
{ )} )} - {mainViewState === MainViewState.Thread && ( + {mainViewState === MainViewState.Thread && !isShowStarterScreen && ( diff --git a/web/screens/Settings/Advanced/index.test.tsx b/web/screens/Settings/Advanced/index.test.tsx new file mode 100644 index 000000000..10ea810b1 --- /dev/null +++ b/web/screens/Settings/Advanced/index.test.tsx @@ -0,0 +1,154 @@ +import React from 'react' +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import '@testing-library/jest-dom' +import Advanced from '.' + +class ResizeObserverMock { + observe() {} + unobserve() {} + disconnect() {} +} + +global.ResizeObserver = ResizeObserverMock +// @ts-ignore +global.window.core = { + api: { + getAppConfigurations: () => jest.fn(), + updateAppConfiguration: () => jest.fn(), + relaunch: () => jest.fn(), + }, +} + +const setSettingsMock = jest.fn() + +// Mock useSettings hook +jest.mock('@/hooks/useSettings', () => ({ + __esModule: true, + useSettings: () => ({ + readSettings: () => ({ + run_mode: 'gpu', + experimental: false, + proxy: false, + gpus: [{ name: 'gpu-1' }, { name: 'gpu-2' }], + gpus_in_use: ['0'], + quick_ask: false, + }), + setSettings: setSettingsMock, + }), +})) + +import * as toast from '@/containers/Toast' + +jest.mock('@/containers/Toast') + +jest.mock('@janhq/core', () => ({ + __esModule: true, + ...jest.requireActual('@janhq/core'), + fs: { + rm: jest.fn(), + }, +})) + +// Simulate a full advanced settings screen +// @ts-ignore +global.isMac = false +// @ts-ignore +global.isWindows = true + +describe('Advanced', () => { + it('renders the component', async () => { + render() + await waitFor(() => { + expect(screen.getByText('Experimental Mode')).toBeInTheDocument() + expect(screen.getByText('HTTPS Proxy')).toBeInTheDocument() + expect(screen.getByText('Ignore SSL certificates')).toBeInTheDocument() + expect(screen.getByText('Jan Data Folder')).toBeInTheDocument() + expect(screen.getByText('Reset to Factory Settings')).toBeInTheDocument() + }) + }) + + it('updates Experimental enabled', async () => { + render() + let experimentalToggle + await waitFor(() => { + experimentalToggle = screen.getByTestId(/experimental-switch/i) + fireEvent.click(experimentalToggle!) + }) + expect(experimentalToggle).toBeChecked() + }) + + it('updates Experimental disabled', async () => { + render() + + let experimentalToggle + await waitFor(() => { + experimentalToggle = screen.getByTestId(/experimental-switch/i) + fireEvent.click(experimentalToggle!) 
+ }) + expect(experimentalToggle).not.toBeChecked() + }) + + it('clears logs', async () => { + const jestMock = jest.fn() + jest.spyOn(toast, 'toaster').mockImplementation(jestMock) + + render() + let clearLogsButton + await waitFor(() => { + clearLogsButton = screen.getByTestId(/clear-logs/i) + fireEvent.click(clearLogsButton) + }) + expect(clearLogsButton).toBeInTheDocument() + expect(jestMock).toHaveBeenCalled() + }) + + it('toggles proxy enabled', async () => { + render() + let proxyToggle + await waitFor(() => { + expect(screen.getByText('HTTPS Proxy')).toBeInTheDocument() + proxyToggle = screen.getByTestId(/proxy-switch/i) + fireEvent.click(proxyToggle) + }) + expect(proxyToggle).toBeChecked() + }) + + it('updates proxy settings', async () => { + render() + let proxyInput + await waitFor(() => { + const proxyToggle = screen.getByTestId(/proxy-switch/i) + fireEvent.click(proxyToggle) + proxyInput = screen.getByTestId(/proxy-input/i) + fireEvent.change(proxyInput, { target: { value: 'http://proxy.com' } }) + }) + expect(proxyInput).toHaveValue('http://proxy.com') + }) + + it('toggles ignore SSL certificates', async () => { + render() + let ignoreSslToggle + await waitFor(() => { + expect(screen.getByText('Ignore SSL certificates')).toBeInTheDocument() + ignoreSslToggle = screen.getByTestId(/ignore-ssl-switch/i) + fireEvent.click(ignoreSslToggle) + }) + expect(ignoreSslToggle).toBeChecked() + }) + + it('renders DataFolder component', async () => { + render() + await waitFor(() => { + expect(screen.getByText('Jan Data Folder')).toBeInTheDocument() + expect(screen.getByTestId(/jan-data-folder-input/i)).toBeInTheDocument() + }) + }) + + it('renders FactoryReset component', async () => { + render() + await waitFor(() => { + expect(screen.getByText('Reset to Factory Settings')).toBeInTheDocument() + expect(screen.getByTestId(/reset-button/i)).toBeInTheDocument() + }) + }) +}) diff --git a/web/screens/Settings/Advanced/index.tsx b/web/screens/Settings/Advanced/index.tsx index f132f81e7..1384f5688 100644 --- a/web/screens/Settings/Advanced/index.tsx +++ b/web/screens/Settings/Advanced/index.tsx @@ -43,19 +43,10 @@ type GPU = { name: string } -const test = [ - { - id: 'test a', - vram: 2, - name: 'nvidia A', - }, - { - id: 'test', - vram: 2, - name: 'nvidia B', - }, -] - +/** + * Advanced Settings Screen + * @returns + */ const Advanced = () => { const [experimentalEnabled, setExperimentalEnabled] = useAtom( experimentalFeatureEnabledAtom @@ -69,7 +60,7 @@ const Advanced = () => { const [partialProxy, setPartialProxy] = useState(proxy) const [gpuEnabled, setGpuEnabled] = useState(false) - const [gpuList, setGpuList] = useState(test) + const [gpuList, setGpuList] = useState([]) const [gpusInUse, setGpusInUse] = useState([]) const [dropdownOptions, setDropdownOptions] = useState( null @@ -87,6 +78,9 @@ const Advanced = () => { return y['name'] }) + /** + * Handle proxy change + */ const onProxyChange = useCallback( (event: ChangeEvent) => { const value = event.target.value || '' @@ -100,6 +94,12 @@ const Advanced = () => { [setPartialProxy, setProxy] ) + /** + * Update Quick Ask Enabled + * @param e + * @param relaunch + * @returns void + */ const updateQuickAskEnabled = async ( e: boolean, relaunch: boolean = true @@ -111,6 +111,12 @@ const Advanced = () => { if (relaunch) window.core?.api?.relaunch() } + /** + * Update Vulkan Enabled + * @param e + * @param relaunch + * @returns void + */ const updateVulkanEnabled = async (e: boolean, relaunch: boolean = true) => { toaster({ title: 'Reload', @@ 
-123,11 +129,19 @@ const Advanced = () => { if (relaunch) window.location.reload() } + /** + * Update Experimental Enabled + * @param e + * @returns + */ const updateExperimentalEnabled = async ( e: ChangeEvent ) => { setExperimentalEnabled(e.target.checked) - if (e) return + + // If it is checked, we don't need to do anything else + // Otherwise we have to reset the other settings + if (e.target.checked) return // It affects other settings, so we need to reset them const isRelaunch = quickAskEnabled || vulkanEnabled @@ -136,6 +150,9 @@ const Advanced = () => { if (isRelaunch) window.core?.api?.relaunch() } + /** + * useEffect to set GPU enabled if possible + */ useEffect(() => { const setUseGpuIfPossible = async () => { const settings = await readSettings() @@ -149,6 +166,10 @@ const Advanced = () => { setUseGpuIfPossible() }, [readSettings, setGpuList, setGpuEnabled, setGpusInUse, setVulkanEnabled]) + /** + * Clear logs + * @returns + */ const clearLogs = async () => { try { await fs.rm(`file://logs`) @@ -163,6 +184,11 @@ const Advanced = () => { }) } + /** + * Handle GPU Change + * @param gpuId + * @returns + */ const handleGPUChange = (gpuId: string) => { let updatedGpusInUse = [...gpusInUse] if (updatedGpusInUse.includes(gpuId)) { @@ -188,6 +214,9 @@ const Advanced = () => { const gpuSelectionPlaceHolder = gpuList.length > 0 ? 'Select GPU' : "You don't have any compatible GPU" + /** + * Handle click outside + */ useClickOutside(() => setOpen(false), null, [dropdownOptions, toggle]) return ( @@ -204,6 +233,7 @@ const Advanced = () => {

@@ -401,11 +431,13 @@ const Advanced = () => {
setProxyEnabled(!proxyEnabled)} />
:@:'} value={partialProxy} onChange={onProxyChange} @@ -428,6 +460,7 @@ const Advanced = () => {

setIgnoreSSL(e.target.checked)} /> @@ -448,6 +481,7 @@ const Advanced = () => {

{ toaster({ @@ -471,7 +505,11 @@ const Advanced = () => { Clear all logs from Jan app.

- diff --git a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx index 951a11d59..c3f09f171 100644 --- a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx +++ b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx @@ -53,7 +53,7 @@ const ModelDownloadRow: React.FC = ({ const { requestCreateNewThread } = useCreateNewThread() const setMainViewState = useSetAtom(mainViewStateAtom) const assistants = useAtomValue(assistantsAtom) - const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null + const downloadedModel = downloadedModels.find((md) => md.id === fileName) const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom) const defaultModel = useAtomValue(defaultModelAtom) @@ -100,12 +100,12 @@ const ModelDownloadRow: React.FC = ({ alert('No assistant available') return } - await requestCreateNewThread(assistants[0], model) + await requestCreateNewThread(assistants[0], downloadedModel) setMainViewState(MainViewState.Thread) setHfImportingStage('NONE') }, [ assistants, - model, + downloadedModel, requestCreateNewThread, setMainViewState, setHfImportingStage, @@ -139,7 +139,7 @@ const ModelDownloadRow: React.FC = ({ - {isDownloaded ? ( + {downloadedModel ? ( - + {toGibibytes(featModel.metadata.size)} @@ -257,7 +271,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => { ) })} -
+

Cloud Models

@@ -268,7 +282,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => { return (
{row.map((remoteEngine) => { const engineLogo = getLogoEngine( @@ -298,7 +312,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => { /> )} -

+

{getTitleByEngine( remoteEngine as InferenceEngine )} diff --git a/web/screens/Thread/ThreadRightPanel/index.tsx b/web/screens/Thread/ThreadRightPanel/index.tsx index 9e7cdf7d8..e7d0a27b9 100644 --- a/web/screens/Thread/ThreadRightPanel/index.tsx +++ b/web/screens/Thread/ThreadRightPanel/index.tsx @@ -29,7 +29,10 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters' import { getConfigurationsData } from '@/utils/componentSettings' import { localEngines } from '@/utils/modelEngine' -import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' +import { + extractInferenceParams, + extractModelLoadParams, +} from '@/utils/modelParam' import PromptTemplateSetting from './PromptTemplateSetting' import Tools from './Tools' @@ -68,14 +71,26 @@ const ThreadRightPanel = () => { const settings = useMemo(() => { // runtime setting - const modelRuntimeParams = toRuntimeParams(activeModelParams) + const modelRuntimeParams = extractInferenceParams( + { + ...selectedModel?.parameters, + ...activeModelParams, + }, + selectedModel?.parameters + ) const componentDataRuntimeSetting = getConfigurationsData( modelRuntimeParams, selectedModel ).filter((x) => x.key !== 'prompt_template') // engine setting - const modelEngineParams = toSettingParams(activeModelParams) + const modelEngineParams = extractModelLoadParams( + { + ...selectedModel?.settings, + ...activeModelParams, + }, + selectedModel?.settings + ) const componentDataEngineSetting = getConfigurationsData( modelEngineParams, selectedModel @@ -126,7 +141,10 @@ const ThreadRightPanel = () => { }, [activeModelParams, selectedModel]) const promptTemplateSettings = useMemo(() => { - const modelEngineParams = toSettingParams(activeModelParams) + const modelEngineParams = extractModelLoadParams({ + ...selectedModel?.settings, + ...activeModelParams, + }) const componentDataEngineSetting = getConfigurationsData( modelEngineParams, selectedModel diff --git a/web/screens/Thread/index.test.tsx b/web/screens/Thread/index.test.tsx new file mode 100644 index 000000000..01af0ffc5 --- /dev/null +++ b/web/screens/Thread/index.test.tsx @@ -0,0 +1,35 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import ThreadScreen from './index' +import { useStarterScreen } from '../../hooks/useStarterScreen' +import '@testing-library/jest-dom' + +global.ResizeObserver = class { + observe() {} + unobserve() {} + disconnect() {} +} +// Mock the useStarterScreen hook +jest.mock('@/hooks/useStarterScreen') + +describe('ThreadScreen', () => { + it('renders OnDeviceStarterScreen when isShowStarterScreen is true', () => { + ;(useStarterScreen as jest.Mock).mockReturnValue({ + isShowStarterScreen: true, + extensionHasSettings: false, + }) + + const { getByText } = render() + expect(getByText('Select a model to start')).toBeInTheDocument() + }) + + it('renders Thread panels when isShowStarterScreen is false', () => { + ;(useStarterScreen as jest.Mock).mockReturnValue({ + isShowStarterScreen: false, + extensionHasSettings: false, + }) + + const { getByText } = render() + expect(getByText('Welcome!')).toBeInTheDocument() + }) +}) diff --git a/web/utils/modelParam.test.ts b/web/utils/modelParam.test.ts new file mode 100644 index 000000000..f1b858955 --- /dev/null +++ b/web/utils/modelParam.test.ts @@ -0,0 +1,183 @@ +// web/utils/modelParam.test.ts +import { normalizeValue, validationRules } from './modelParam' + +describe('validationRules', () => { + it('should validate temperature correctly', () => { + 
expect(validationRules.temperature(0.5)).toBe(true) + expect(validationRules.temperature(2)).toBe(true) + expect(validationRules.temperature(0)).toBe(true) + expect(validationRules.temperature(-0.1)).toBe(false) + expect(validationRules.temperature(2.3)).toBe(false) + expect(validationRules.temperature('0.5')).toBe(false) + }) + + it('should validate token_limit correctly', () => { + expect(validationRules.token_limit(100)).toBe(true) + expect(validationRules.token_limit(1)).toBe(true) + expect(validationRules.token_limit(0)).toBe(true) + expect(validationRules.token_limit(-1)).toBe(false) + expect(validationRules.token_limit('100')).toBe(false) + }) + + it('should validate top_k correctly', () => { + expect(validationRules.top_k(0.5)).toBe(true) + expect(validationRules.top_k(1)).toBe(true) + expect(validationRules.top_k(0)).toBe(true) + expect(validationRules.top_k(-0.1)).toBe(false) + expect(validationRules.top_k(1.1)).toBe(false) + expect(validationRules.top_k('0.5')).toBe(false) + }) + + it('should validate top_p correctly', () => { + expect(validationRules.top_p(0.5)).toBe(true) + expect(validationRules.top_p(1)).toBe(true) + expect(validationRules.top_p(0)).toBe(true) + expect(validationRules.top_p(-0.1)).toBe(false) + expect(validationRules.top_p(1.1)).toBe(false) + expect(validationRules.top_p('0.5')).toBe(false) + }) + + it('should validate stream correctly', () => { + expect(validationRules.stream(true)).toBe(true) + expect(validationRules.stream(false)).toBe(true) + expect(validationRules.stream('true')).toBe(false) + expect(validationRules.stream(1)).toBe(false) + }) + + it('should validate max_tokens correctly', () => { + expect(validationRules.max_tokens(100)).toBe(true) + expect(validationRules.max_tokens(1)).toBe(true) + expect(validationRules.max_tokens(0)).toBe(true) + expect(validationRules.max_tokens(-1)).toBe(false) + expect(validationRules.max_tokens('100')).toBe(false) + }) + + it('should validate stop correctly', () => { + expect(validationRules.stop(['word1', 'word2'])).toBe(true) + expect(validationRules.stop([])).toBe(true) + expect(validationRules.stop(['word1', 2])).toBe(false) + expect(validationRules.stop('word1')).toBe(false) + }) + + it('should validate frequency_penalty correctly', () => { + expect(validationRules.frequency_penalty(0.5)).toBe(true) + expect(validationRules.frequency_penalty(1)).toBe(true) + expect(validationRules.frequency_penalty(0)).toBe(true) + expect(validationRules.frequency_penalty(-0.1)).toBe(false) + expect(validationRules.frequency_penalty(1.1)).toBe(false) + expect(validationRules.frequency_penalty('0.5')).toBe(false) + }) + + it('should validate presence_penalty correctly', () => { + expect(validationRules.presence_penalty(0.5)).toBe(true) + expect(validationRules.presence_penalty(1)).toBe(true) + expect(validationRules.presence_penalty(0)).toBe(true) + expect(validationRules.presence_penalty(-0.1)).toBe(false) + expect(validationRules.presence_penalty(1.1)).toBe(false) + expect(validationRules.presence_penalty('0.5')).toBe(false) + }) + + it('should validate ctx_len correctly', () => { + expect(validationRules.ctx_len(1024)).toBe(true) + expect(validationRules.ctx_len(1)).toBe(true) + expect(validationRules.ctx_len(0)).toBe(true) + expect(validationRules.ctx_len(-1)).toBe(false) + expect(validationRules.ctx_len('1024')).toBe(false) + }) + + it('should validate ngl correctly', () => { + expect(validationRules.ngl(12)).toBe(true) + expect(validationRules.ngl(1)).toBe(true) + expect(validationRules.ngl(0)).toBe(true) + 
expect(validationRules.ngl(-1)).toBe(false) + expect(validationRules.ngl('12')).toBe(false) + }) + + it('should validate embedding correctly', () => { + expect(validationRules.embedding(true)).toBe(true) + expect(validationRules.embedding(false)).toBe(true) + expect(validationRules.embedding('true')).toBe(false) + expect(validationRules.embedding(1)).toBe(false) + }) + + it('should validate n_parallel correctly', () => { + expect(validationRules.n_parallel(2)).toBe(true) + expect(validationRules.n_parallel(1)).toBe(true) + expect(validationRules.n_parallel(0)).toBe(true) + expect(validationRules.n_parallel(-1)).toBe(false) + expect(validationRules.n_parallel('2')).toBe(false) + }) + + it('should validate cpu_threads correctly', () => { + expect(validationRules.cpu_threads(4)).toBe(true) + expect(validationRules.cpu_threads(1)).toBe(true) + expect(validationRules.cpu_threads(0)).toBe(true) + expect(validationRules.cpu_threads(-1)).toBe(false) + expect(validationRules.cpu_threads('4')).toBe(false) + }) + + it('should validate prompt_template correctly', () => { + expect(validationRules.prompt_template('template')).toBe(true) + expect(validationRules.prompt_template('')).toBe(true) + expect(validationRules.prompt_template(123)).toBe(false) + }) + + it('should validate llama_model_path correctly', () => { + expect(validationRules.llama_model_path('path')).toBe(true) + expect(validationRules.llama_model_path('')).toBe(true) + expect(validationRules.llama_model_path(123)).toBe(false) + }) + + it('should validate mmproj correctly', () => { + expect(validationRules.mmproj('mmproj')).toBe(true) + expect(validationRules.mmproj('')).toBe(true) + expect(validationRules.mmproj(123)).toBe(false) + }) + + it('should validate vision_model correctly', () => { + expect(validationRules.vision_model(true)).toBe(true) + expect(validationRules.vision_model(false)).toBe(true) + expect(validationRules.vision_model('true')).toBe(false) + expect(validationRules.vision_model(1)).toBe(false) + }) + + it('should validate text_model correctly', () => { + expect(validationRules.text_model(true)).toBe(true) + expect(validationRules.text_model(false)).toBe(true) + expect(validationRules.text_model('true')).toBe(false) + expect(validationRules.text_model(1)).toBe(false) + }) +}) + +describe('normalizeValue', () => { + it('should normalize ctx_len correctly', () => { + expect(normalizeValue('ctx_len', 100.5)).toBe(100) + expect(normalizeValue('ctx_len', '2')).toBe(2) + expect(normalizeValue('ctx_len', 100)).toBe(100) + }) + it('should normalize token_limit correctly', () => { + expect(normalizeValue('token_limit', 100.5)).toBe(100) + expect(normalizeValue('token_limit', '1')).toBe(1) + expect(normalizeValue('token_limit', 0)).toBe(0) + }) + it('should normalize max_tokens correctly', () => { + expect(normalizeValue('max_tokens', 100.5)).toBe(100) + expect(normalizeValue('max_tokens', '1')).toBe(1) + expect(normalizeValue('max_tokens', 0)).toBe(0) + }) + it('should normalize ngl correctly', () => { + expect(normalizeValue('ngl', 12.5)).toBe(12) + expect(normalizeValue('ngl', '2')).toBe(2) + expect(normalizeValue('ngl', 0)).toBe(0) + }) + it('should normalize n_parallel correctly', () => { + expect(normalizeValue('n_parallel', 2.5)).toBe(2) + expect(normalizeValue('n_parallel', '2')).toBe(2) + expect(normalizeValue('n_parallel', 0)).toBe(0) + }) + it('should normalize cpu_threads correctly', () => { + expect(normalizeValue('cpu_threads', 4.5)).toBe(4) + expect(normalizeValue('cpu_threads', '4')).toBe(4) + 
expect(normalizeValue('cpu_threads', 0)).toBe(0) + }) +}) diff --git a/web/utils/modelParam.ts b/web/utils/modelParam.ts index a6d144c3e..dda9cf761 100644 --- a/web/utils/modelParam.ts +++ b/web/utils/modelParam.ts @@ -1,9 +1,69 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +/* eslint-disable @typescript-eslint/naming-convention */ import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core' import { ModelParams } from '@/helpers/atoms/Thread.atom' -export const toRuntimeParams = ( - modelParams?: ModelParams +/** + * Validation rules for model parameters + */ +export const validationRules: { [key: string]: (value: any) => boolean } = { + temperature: (value: any) => + typeof value === 'number' && value >= 0 && value <= 2, + token_limit: (value: any) => Number.isInteger(value) && value >= 0, + top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1, + top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1, + stream: (value: any) => typeof value === 'boolean', + max_tokens: (value: any) => Number.isInteger(value) && value >= 0, + stop: (value: any) => + Array.isArray(value) && value.every((v) => typeof v === 'string'), + frequency_penalty: (value: any) => + typeof value === 'number' && value >= 0 && value <= 1, + presence_penalty: (value: any) => + typeof value === 'number' && value >= 0 && value <= 1, + + ctx_len: (value: any) => Number.isInteger(value) && value >= 0, + ngl: (value: any) => Number.isInteger(value) && value >= 0, + embedding: (value: any) => typeof value === 'boolean', + n_parallel: (value: any) => Number.isInteger(value) && value >= 0, + cpu_threads: (value: any) => Number.isInteger(value) && value >= 0, + prompt_template: (value: any) => typeof value === 'string', + llama_model_path: (value: any) => typeof value === 'string', + mmproj: (value: any) => typeof value === 'string', + vision_model: (value: any) => typeof value === 'boolean', + text_model: (value: any) => typeof value === 'boolean', +} + +/** + * There are some parameters that need to be normalized before being sent to the server + * E.g. 
ctx_len should be an integer, but it can be a float from the input field + * @param key + * @param value + * @returns + */ +export const normalizeValue = (key: string, value: any) => { + if ( + key === 'token_limit' || + key === 'max_tokens' || + key === 'ctx_len' || + key === 'ngl' || + key === 'n_parallel' || + key === 'cpu_threads' + ) { + // Convert to integer + return Math.floor(Number(value)) + } + return value +} + +/** + * Extract inference parameters from flat model parameters + * @param modelParams + * @returns + */ +export const extractInferenceParams = ( + modelParams?: ModelParams, + originParams?: ModelParams ): ModelRuntimeParams => { if (!modelParams) return {} const defaultModelParams: ModelRuntimeParams = { @@ -22,15 +82,35 @@ export const toRuntimeParams = ( for (const [key, value] of Object.entries(modelParams)) { if (key in defaultModelParams) { - Object.assign(runtimeParams, { ...runtimeParams, [key]: value }) + const validate = validationRules[key] + if (validate && !validate(normalizeValue(key, value))) { + // Invalid value - fall back to origin value + if (originParams && key in originParams) { + Object.assign(runtimeParams, { + ...runtimeParams, + [key]: originParams[key as keyof typeof originParams], + }) + } + } else { + Object.assign(runtimeParams, { + ...runtimeParams, + [key]: normalizeValue(key, value), + }) + } } } return runtimeParams } -export const toSettingParams = ( - modelParams?: ModelParams +/** + * Extract model load parameters from flat model parameters + * @param modelParams + * @returns + */ +export const extractModelLoadParams = ( + modelParams?: ModelParams, + originParams?: ModelParams ): ModelSettingParams => { if (!modelParams) return {} const defaultSettingParams: ModelSettingParams = { @@ -49,7 +129,21 @@ export const toSettingParams = ( for (const [key, value] of Object.entries(modelParams)) { if (key in defaultSettingParams) { - Object.assign(settingParams, { ...settingParams, [key]: value }) + const validate = validationRules[key] + if (validate && !validate(normalizeValue(key, value))) { + // Invalid value - fall back to origin value + if (originParams && key in originParams) { + Object.assign(settingParams, { + ...settingParams, + [key]: originParams[key as keyof typeof originParams], + }) + } + } else { + Object.assign(settingParams, { + ...settingParams, + [key]: normalizeValue(key, value), + }) + } } }
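To make the validate/normalize/fallback path above concrete, here is a minimal Jest-style sketch of how extractInferenceParams is expected to behave after this change. It assumes temperature and max_tokens are among the keys of defaultModelParams (that block is not shown in this hunk) and that the returned params object starts out empty; it is an illustration under those assumptions, not part of the change set.

// Illustrative only: exercises the validation, normalization and origin-fallback logic above.
import { extractInferenceParams } from './modelParam'

it('normalizes valid values and falls back to origin values for invalid ones', () => {
  const result = extractInferenceParams(
    // temperature 5 fails validationRules.temperature (allowed range 0-2);
    // max_tokens 100.7 is floored by normalizeValue before validation
    { temperature: 5, max_tokens: 100.7 },
    { temperature: 0.7 }
  )
  expect(result.temperature).toBe(0.7) // invalid input replaced by the origin value
  expect(result.max_tokens).toBe(100) // valid input kept, normalized to an integer
})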