Merge branch 'dev' into copyfix

This commit is contained in:
Nicole Zhu 2024-09-18 17:03:11 +08:00 committed by GitHub
commit 338232b173
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 2044 additions and 479 deletions

1
.gitignore vendored
View File

@ -45,3 +45,4 @@ core/test_results.html
coverage coverage
.yarn .yarn
.yarnrc .yarnrc
*.tsbuildinfo

View File

@ -1,98 +1,109 @@
import { openExternalUrl } from './core'; import { openExternalUrl } from './core'
import { joinPath } from './core'; import { joinPath } from './core'
import { openFileExplorer } from './core'; import { openFileExplorer } from './core'
import { getJanDataFolderPath } from './core'; import { getJanDataFolderPath } from './core'
import { abortDownload } from './core'; import { abortDownload } from './core'
import { getFileSize } from './core'; import { getFileSize } from './core'
import { executeOnMain } from './core'; import { executeOnMain } from './core'
it('should open external url', async () => { describe('test core apis', () => {
const url = 'http://example.com'; it('should open external url', async () => {
const url = 'http://example.com'
globalThis.core = { globalThis.core = {
api: { api: {
openExternalUrl: jest.fn().mockResolvedValue('opened') openExternalUrl: jest.fn().mockResolvedValue('opened'),
},
} }
}; const result = await openExternalUrl(url)
const result = await openExternalUrl(url); expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url)
expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url); expect(result).toBe('opened')
expect(result).toBe('opened'); })
});
it('should join paths', async () => {
it('should join paths', async () => { const paths = ['/path/one', '/path/two']
const paths = ['/path/one', '/path/two'];
globalThis.core = { globalThis.core = {
api: { api: {
joinPath: jest.fn().mockResolvedValue('/path/one/path/two') joinPath: jest.fn().mockResolvedValue('/path/one/path/two'),
},
} }
}; const result = await joinPath(paths)
const result = await joinPath(paths); expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths)
expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths); expect(result).toBe('/path/one/path/two')
expect(result).toBe('/path/one/path/two'); })
});
it('should open file explorer', async () => {
it('should open file explorer', async () => { const path = '/path/to/open'
const path = '/path/to/open';
globalThis.core = { globalThis.core = {
api: { api: {
openFileExplorer: jest.fn().mockResolvedValue('opened') openFileExplorer: jest.fn().mockResolvedValue('opened'),
},
} }
}; const result = await openFileExplorer(path)
const result = await openFileExplorer(path); expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path)
expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path); expect(result).toBe('opened')
expect(result).toBe('opened'); })
});
it('should get jan data folder path', async () => {
it('should get jan data folder path', async () => {
globalThis.core = { globalThis.core = {
api: { api: {
getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data') getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data'),
},
} }
}; const result = await getJanDataFolderPath()
const result = await getJanDataFolderPath(); expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled()
expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled(); expect(result).toBe('/path/to/jan/data')
expect(result).toBe('/path/to/jan/data'); })
});
it('should abort download', async () => {
it('should abort download', async () => { const fileName = 'testFile'
const fileName = 'testFile';
globalThis.core = { globalThis.core = {
api: { api: {
abortDownload: jest.fn().mockResolvedValue('aborted') abortDownload: jest.fn().mockResolvedValue('aborted'),
},
} }
}; const result = await abortDownload(fileName)
const result = await abortDownload(fileName); expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName)
expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName); expect(result).toBe('aborted')
expect(result).toBe('aborted'); })
});
it('should get file size', async () => {
it('should get file size', async () => { const url = 'http://example.com/file'
const url = 'http://example.com/file';
globalThis.core = { globalThis.core = {
api: { api: {
getFileSize: jest.fn().mockResolvedValue(1024) getFileSize: jest.fn().mockResolvedValue(1024),
},
} }
}; const result = await getFileSize(url)
const result = await getFileSize(url); expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url)
expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url); expect(result).toBe(1024)
expect(result).toBe(1024); })
});
it('should execute function on main process', async () => {
it('should execute function on main process', async () => { const extension = 'testExtension'
const extension = 'testExtension'; const method = 'testMethod'
const method = 'testMethod'; const args = ['arg1', 'arg2']
const args = ['arg1', 'arg2'];
globalThis.core = { globalThis.core = {
api: { api: {
invokeExtensionFunc: jest.fn().mockResolvedValue('result') invokeExtensionFunc: jest.fn().mockResolvedValue('result'),
},
} }
}; const result = await executeOnMain(extension, method, ...args)
const result = await executeOnMain(extension, method, ...args); expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args)
expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args); expect(result).toBe('result')
expect(result).toBe('result'); })
}); })
describe('dirName - just a pass thru api', () => {
it('should retrieve the directory name from a file path', async () => {
const mockDirName = jest.fn()
globalThis.core = {
api: {
dirName: mockDirName.mockResolvedValue('/path/to'),
},
}
// Normal file path with extension
const path = '/path/to/file.txt'
await globalThis.core.api.dirName(path)
expect(mockDirName).toHaveBeenCalledWith(path)
})
})

View File

@ -68,6 +68,13 @@ const openFileExplorer: (path: string) => Promise<any> = (path) =>
const joinPath: (paths: string[]) => Promise<string> = (paths) => const joinPath: (paths: string[]) => Promise<string> = (paths) =>
globalThis.core.api?.joinPath(paths) globalThis.core.api?.joinPath(paths)
/**
* Get dirname of a file path.
* @param path - The file path to retrieve dirname.
* @returns {Promise<string>} A promise that resolves the dirname.
*/
const dirName: (path: string) => Promise<string> = (path) => globalThis.core.api?.dirName(path)
/** /**
* Retrieve the basename from an url. * Retrieve the basename from an url.
* @param path - The path to retrieve. * @param path - The path to retrieve.
@ -161,5 +168,6 @@ export {
systemInformation, systemInformation,
showToast, showToast,
getFileSize, getFileSize,
dirName,
FileStat, FileStat,
} }

View File

@ -2,7 +2,7 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events' import { events } from '../../events'
import { BaseExtension } from '../../extension' import { BaseExtension } from '../../extension'
import { fs } from '../../fs' import { fs } from '../../fs'
import { MessageRequest, Model, ModelEvent } from '../../../types' import { MessageRequest, Model, ModelEvent, ModelFile } from '../../../types'
import { EngineManager } from './EngineManager' import { EngineManager } from './EngineManager'
/** /**
@ -21,7 +21,7 @@ export abstract class AIEngine extends BaseExtension {
override onLoad() { override onLoad() {
this.registerEngine() this.registerEngine()
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
} }
@ -78,7 +78,7 @@ export abstract class AIEngine extends BaseExtension {
/** /**
* Loads the model. * Loads the model.
*/ */
async loadModel(model: Model): Promise<any> { async loadModel(model: ModelFile): Promise<any> {
if (model.engine.toString() !== this.provider) return Promise.resolve() if (model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelReady, model) events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve() return Promise.resolve()

View File

@ -1,6 +1,6 @@
import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core' import { executeOnMain, systemInformation, dirName } from '../../core'
import { events } from '../../events' import { events } from '../../events'
import { Model, ModelEvent } from '../../../types' import { Model, ModelEvent, ModelFile } from '../../../types'
import { OAIEngine } from './OAIEngine' import { OAIEngine } from './OAIEngine'
/** /**
@ -14,22 +14,24 @@ export abstract class LocalOAIEngine extends OAIEngine {
unloadModelFunctionName: string = 'unloadModel' unloadModelFunctionName: string = 'unloadModel'
/** /**
* On extension load, subscribe to events. * This class represents a base for local inference providers in the OpenAI architecture.
* It extends the OAIEngine class and provides the implementation of loading and unloading models locally.
* The loadModel function subscribes to the ModelEvent.OnModelInit event, loading models when initiated.
* The unloadModel function subscribes to the ModelEvent.OnModelStop event, unloading models when stopped.
*/ */
override onLoad() { override onLoad() {
super.onLoad() super.onLoad()
// These events are applicable to local inference providers // These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
} }
/** /**
* Load the model. * Load the model.
*/ */
override async loadModel(model: Model): Promise<void> { override async loadModel(model: ModelFile): Promise<void> {
if (model.engine.toString() !== this.provider) return if (model.engine.toString() !== this.provider) return
const modelFolderName = 'models' const modelFolder = await dirName(model.file_path)
const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
const systemInfo = await systemInformation() const systemInfo = await systemInformation()
const res = await executeOnMain( const res = await executeOnMain(
this.nodeModule, this.nodeModule,

View File

@ -4,6 +4,7 @@ import {
HuggingFaceRepoData, HuggingFaceRepoData,
ImportingModel, ImportingModel,
Model, Model,
ModelFile,
ModelInterface, ModelInterface,
OptionType, OptionType,
} from '../../types' } from '../../types'
@ -25,12 +26,11 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
network?: { proxy: string; ignoreSSL?: boolean } network?: { proxy: string; ignoreSSL?: boolean }
): Promise<void> ): Promise<void>
abstract cancelModelDownload(modelId: string): Promise<void> abstract cancelModelDownload(modelId: string): Promise<void>
abstract deleteModel(modelId: string): Promise<void> abstract deleteModel(model: ModelFile): Promise<void>
abstract saveModel(model: Model): Promise<void> abstract getDownloadedModels(): Promise<ModelFile[]>
abstract getDownloadedModels(): Promise<Model[]> abstract getConfiguredModels(): Promise<ModelFile[]>
abstract getConfiguredModels(): Promise<Model[]>
abstract importModels(models: ImportingModel[], optionType: OptionType): Promise<void> abstract importModels(models: ImportingModel[], optionType: OptionType): Promise<void>
abstract updateModelInfo(modelInfo: Partial<Model>): Promise<Model> abstract updateModelInfo(modelInfo: Partial<ModelFile>): Promise<ModelFile>
abstract fetchHuggingFaceRepoData(repoId: string): Promise<HuggingFaceRepoData> abstract fetchHuggingFaceRepoData(repoId: string): Promise<HuggingFaceRepoData>
abstract getDefaultModel(): Promise<Model> abstract getDefaultModel(): Promise<Model>
} }

View File

@ -1,40 +1,57 @@
import { App } from './app'; jest.mock('../../helper', () => ({
...jest.requireActual('../../helper'),
getJanDataFolderPath: () => './app',
}))
import { dirname } from 'path'
import { App } from './app'
it('should call stopServer', () => { it('should call stopServer', () => {
const app = new App(); const app = new App()
const stopServerMock = jest.fn().mockResolvedValue('Server stopped'); const stopServerMock = jest.fn().mockResolvedValue('Server stopped')
jest.mock('@janhq/server', () => ({ jest.mock('@janhq/server', () => ({
stopServer: stopServerMock stopServer: stopServerMock,
})); }))
const result = app.stopServer(); app.stopServer()
expect(stopServerMock).toHaveBeenCalled(); expect(stopServerMock).toHaveBeenCalled()
}); })
it('should correctly retrieve basename', () => { it('should correctly retrieve basename', () => {
const app = new App(); const app = new App()
const result = app.baseName('/path/to/file.txt'); const result = app.baseName('/path/to/file.txt')
expect(result).toBe('file.txt'); expect(result).toBe('file.txt')
}); })
it('should correctly identify subdirectories', () => { it('should correctly identify subdirectories', () => {
const app = new App(); const app = new App()
const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to'; const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to'
const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir'; const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir'
const result = app.isSubdirectory(basePath, subPath); const result = app.isSubdirectory(basePath, subPath)
expect(result).toBe(true); expect(result).toBe(true)
}); })
it('should correctly join multiple paths', () => { it('should correctly join multiple paths', () => {
const app = new App(); const app = new App()
const result = app.joinPath(['path', 'to', 'file']); const result = app.joinPath(['path', 'to', 'file'])
const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file'; const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file'
expect(result).toBe(expectedPath); expect(result).toBe(expectedPath)
}); })
it('should call correct function with provided arguments using process method', () => { it('should call correct function with provided arguments using process method', () => {
const app = new App(); const app = new App()
const mockFunc = jest.fn(); const mockFunc = jest.fn()
app.joinPath = mockFunc; app.joinPath = mockFunc
app.process('joinPath', ['path1', 'path2']); app.process('joinPath', ['path1', 'path2'])
expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']); expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2'])
}); })
it('should retrieve the directory name from a file path (Unix/Windows)', async () => {
const app = new App()
const path = 'C:/Users/John Doe/Desktop/file.txt'
expect(await app.dirName(path)).toBe('C:/Users/John Doe/Desktop')
})
it('should retrieve the directory name when using file protocol', async () => {
const app = new App()
const path = 'file:/models/file.txt'
expect(await app.dirName(path)).toBe(process.platform === 'win32' ? 'app\\models' : 'app/models')
})

View File

@ -1,4 +1,4 @@
import { basename, isAbsolute, join, relative } from 'path' import { basename, dirname, isAbsolute, join, relative } from 'path'
import { Processor } from './Processor' import { Processor } from './Processor'
import { import {
@ -6,6 +6,8 @@ import {
appResourcePath, appResourcePath,
getAppConfigurations as appConfiguration, getAppConfigurations as appConfiguration,
updateAppConfiguration, updateAppConfiguration,
normalizeFilePath,
getJanDataFolderPath,
} from '../../helper' } from '../../helper'
export class App implements Processor { export class App implements Processor {
@ -28,6 +30,18 @@ export class App implements Processor {
return join(...args) return join(...args)
} }
/**
* Get dirname of a file path.
* @param path - The file path to retrieve dirname.
*/
dirName(path: string) {
const arg =
path.startsWith(`file:/`) || path.startsWith(`file:\\`)
? join(getJanDataFolderPath(), normalizeFilePath(path))
: path
return dirname(arg)
}
/** /**
* Checks if the given path is a subdirectory of the given directory. * Checks if the given path is a subdirectory of the given directory.
* *

View File

@ -37,6 +37,7 @@ export enum AppRoute {
getAppConfigurations = 'getAppConfigurations', getAppConfigurations = 'getAppConfigurations',
updateAppConfiguration = 'updateAppConfiguration', updateAppConfiguration = 'updateAppConfiguration',
joinPath = 'joinPath', joinPath = 'joinPath',
dirName = 'dirName',
isSubdirectory = 'isSubdirectory', isSubdirectory = 'isSubdirectory',
baseName = 'baseName', baseName = 'baseName',
startServer = 'startServer', startServer = 'startServer',

View File

@ -52,3 +52,18 @@ type DownloadSize = {
total: number total: number
transferred: number transferred: number
} }
/**
* The file metadata
*/
export type FileMetadata = {
/**
* The origin file path.
*/
file_path: string
/**
* The file name.
*/
file_name: string
}

View File

@ -1,3 +1,5 @@
import { FileMetadata } from '../file'
/** /**
* Represents the information about a model. * Represents the information about a model.
* @stored * @stored
@ -151,3 +153,8 @@ export type ModelRuntimeParams = {
export type ModelInitFailed = Model & { export type ModelInitFailed = Model & {
error: Error error: Error
} }
/**
* ModelFile is the model.json entity and it's file metadata
*/
export type ModelFile = Model & FileMetadata

View File

@ -1,5 +1,5 @@
import { GpuSetting } from '../miscellaneous' import { GpuSetting } from '../miscellaneous'
import { Model } from './modelEntity' import { Model, ModelFile } from './modelEntity'
/** /**
* Model extension for managing models. * Model extension for managing models.
@ -29,14 +29,7 @@ export interface ModelInterface {
* @param modelId - The ID of the model to delete. * @param modelId - The ID of the model to delete.
* @returns A Promise that resolves when the model has been deleted. * @returns A Promise that resolves when the model has been deleted.
*/ */
deleteModel(modelId: string): Promise<void> deleteModel(model: ModelFile): Promise<void>
/**
* Saves a model.
* @param model - The model to save.
* @returns A Promise that resolves when the model has been saved.
*/
saveModel(model: Model): Promise<void>
/** /**
* Gets a list of downloaded models. * Gets a list of downloaded models.

View File

@ -1,32 +1,29 @@
import { expect } from '@playwright/test' import { expect } from '@playwright/test'
import { page, test, TIMEOUT } from '../config/fixtures' import { page, test, TIMEOUT } from '../config/fixtures'
test('Select GPT model from Hub and Chat with Invalid API Key', async ({ hubPage }) => { test('Select GPT model from Hub and Chat with Invalid API Key', async ({
hubPage,
}) => {
await hubPage.navigateByMenu() await hubPage.navigateByMenu()
await hubPage.verifyContainerVisible() await hubPage.verifyContainerVisible()
// Select the first GPT model // Select the first GPT model
await page await page
.locator('[data-testid^="use-model-btn"][data-testid*="gpt"]') .locator('[data-testid^="use-model-btn"][data-testid*="gpt"]')
.first().click() .first()
// Attempt to create thread and chat in Thread page
await page
.getByTestId('btn-create-thread')
.click() .click()
await page await page.getByTestId('txt-input-chat').fill('dummy value')
.getByTestId('txt-input-chat')
.fill('dummy value')
await page await page.getByTestId('btn-send-chat').click()
.getByTestId('btn-send-chat')
.click()
await page.waitForFunction(() => { await page.waitForFunction(
const loaders = document.querySelectorAll('[data-testid$="loader"]'); () => {
return !loaders.length; const loaders = document.querySelectorAll('[data-testid$="loader"]')
}, { timeout: TIMEOUT }); return !loaders.length
},
{ timeout: TIMEOUT }
)
const APIKeyError = page.getByTestId('invalid-API-key-error') const APIKeyError = page.getByTestId('invalid-API-key-error')
await expect(APIKeyError).toBeVisible({ await expect(APIKeyError).toBeVisible({

View File

@ -22,6 +22,7 @@ import {
downloadFile, downloadFile,
DownloadState, DownloadState,
DownloadEvent, DownloadEvent,
ModelFile,
} from '@janhq/core' } from '@janhq/core'
declare const CUDA_DOWNLOAD_URL: string declare const CUDA_DOWNLOAD_URL: string
@ -94,7 +95,7 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
this.nitroProcessInfo = health this.nitroProcessInfo = health
} }
override loadModel(model: Model): Promise<void> { override loadModel(model: ModelFile): Promise<void> {
if (model.engine !== this.provider) return Promise.resolve() if (model.engine !== this.provider) return Promise.resolve()
this.getNitroProcessHealthIntervalId = setInterval( this.getNitroProcessHealthIntervalId = setInterval(
() => this.periodicallyGetNitroHealth(), () => this.periodicallyGetNitroHealth(),

View File

@ -6,12 +6,12 @@ import fetchRT from 'fetch-retry'
import { import {
log, log,
getSystemResourceInfo, getSystemResourceInfo,
Model,
InferenceEngine, InferenceEngine,
ModelSettingParams, ModelSettingParams,
PromptTemplate, PromptTemplate,
SystemInformation, SystemInformation,
getJanDataFolderPath, getJanDataFolderPath,
ModelFile,
} from '@janhq/core/node' } from '@janhq/core/node'
import { executableNitroFile } from './execute' import { executableNitroFile } from './execute'
import terminate from 'terminate' import terminate from 'terminate'
@ -25,7 +25,7 @@ const fetchRetry = fetchRT(fetch)
*/ */
interface ModelInitOptions { interface ModelInitOptions {
modelFolder: string modelFolder: string
model: Model model: ModelFile
} }
// The PORT to use for the Nitro subprocess // The PORT to use for the Nitro subprocess
const PORT = 3928 const PORT = 3928
@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise<Response> {
if (!settings?.ngl) { if (!settings?.ngl) {
settings.ngl = 100 settings.ngl = 100
} }
log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`) log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`)
return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, { return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, {
method: 'POST', method: 'POST',
headers: { headers: {
@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise<Response> {
}) })
.then((res) => { .then((res) => {
log( log(
`[CORTEX]::Debug: Load model success with response ${JSON.stringify( `[CORTEX]:: Load model success with response ${JSON.stringify(
res res
)}` )}`
) )
@ -260,7 +260,7 @@ function loadLLMModel(settings: any): Promise<Response> {
async function validateModelStatus(modelId: string): Promise<void> { async function validateModelStatus(modelId: string): Promise<void> {
// Send a GET request to the validation URL. // Send a GET request to the validation URL.
// Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries. // Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries.
log(`[CORTEX]::Debug: Validating model ${modelId}`) log(`[CORTEX]:: Validating model ${modelId}`)
return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, { return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, {
method: 'POST', method: 'POST',
body: JSON.stringify({ body: JSON.stringify({
@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
retryDelay: 300, retryDelay: 300,
}).then(async (res: Response) => { }).then(async (res: Response) => {
log( log(
`[CORTEX]::Debug: Validate model state with response ${JSON.stringify( `[CORTEX]:: Validate model state with response ${JSON.stringify(
res.status res.status
)}` )}`
) )
@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
// Otherwise, return an object with an error message. // Otherwise, return an object with an error message.
if (body.model_loaded) { if (body.model_loaded) {
log( log(
`[CORTEX]::Debug: Validate model state success with response ${JSON.stringify( `[CORTEX]:: Validate model state success with response ${JSON.stringify(
body body
)}` )}`
) )
@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
} }
const errorBody = await res.text() const errorBody = await res.text()
log( log(
`[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify( `[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
res.statusText res.statusText
)}` )}`
) )
@ -310,7 +310,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
async function killSubprocess(): Promise<void> { async function killSubprocess(): Promise<void> {
const controller = new AbortController() const controller = new AbortController()
setTimeout(() => controller.abort(), 5000) setTimeout(() => controller.abort(), 5000)
log(`[CORTEX]::Debug: Request to kill cortex`) log(`[CORTEX]:: Request to kill cortex`)
const killRequest = () => { const killRequest = () => {
return fetch(NITRO_HTTP_KILL_URL, { return fetch(NITRO_HTTP_KILL_URL, {
@ -321,17 +321,17 @@ async function killSubprocess(): Promise<void> {
.then(() => .then(() =>
tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
) )
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) .then(() => log(`[CORTEX]:: cortex process is terminated`))
.catch((err) => { .catch((err) => {
log( log(
`[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}` `[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
) )
throw 'PORT_NOT_AVAILABLE' throw 'PORT_NOT_AVAILABLE'
}) })
} }
if (subprocess?.pid && process.platform !== 'darwin') { if (subprocess?.pid && process.platform !== 'darwin') {
log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
const pid = subprocess.pid const pid = subprocess.pid
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
terminate(pid, function (err) { terminate(pid, function (err) {
@ -341,7 +341,7 @@ async function killSubprocess(): Promise<void> {
} else { } else {
tcpPortUsed tcpPortUsed
.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) .waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) .then(() => log(`[CORTEX]:: cortex process is terminated`))
.then(() => resolve()) .then(() => resolve())
.catch(() => { .catch(() => {
log( log(
@ -362,7 +362,7 @@ async function killSubprocess(): Promise<void> {
* @returns A promise that resolves when the Nitro subprocess is started. * @returns A promise that resolves when the Nitro subprocess is started.
*/ */
function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> { function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
log(`[CORTEX]::Debug: Spawning cortex subprocess...`) log(`[CORTEX]:: Spawning cortex subprocess...`)
return new Promise<void>(async (resolve, reject) => { return new Promise<void>(async (resolve, reject) => {
let executableOptions = executableNitroFile( let executableOptions = executableNitroFile(
@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
const args: string[] = ['1', LOCAL_HOST, PORT.toString()] const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
// Execute the binary // Execute the binary
log( log(
`[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}` `[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
) )
log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`) log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`)
@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
// Handle subprocess output // Handle subprocess output
subprocess.stdout.on('data', (data: any) => { subprocess.stdout.on('data', (data: any) => {
log(`[CORTEX]::Debug: ${data}`) log(`[CORTEX]:: ${data}`)
}) })
subprocess.stderr.on('data', (data: any) => { subprocess.stderr.on('data', (data: any) => {
@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
}) })
subprocess.on('close', (code: any) => { subprocess.on('close', (code: any) => {
log(`[CORTEX]::Debug: cortex exited with code: ${code}`) log(`[CORTEX]:: cortex exited with code: ${code}`)
subprocess = undefined subprocess = undefined
reject(`child process exited with code ${code}`) reject(`child process exited with code ${code}`)
}) })
@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
tcpPortUsed tcpPortUsed
.waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000) .waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000)
.then(() => { .then(() => {
log(`[CORTEX]::Debug: cortex is ready`) log(`[CORTEX]:: cortex is ready`)
resolve() resolve()
}) })
}) })

View File

@ -119,5 +119,65 @@
] ]
}, },
"engine": "openai" "engine": "openai"
},
{
"sources": [
{
"url": "https://openai.com"
}
],
"id": "o1-preview",
"object": "model",
"name": "OpenAI o1-preview",
"version": "1.0",
"description": "OpenAI o1-preview is a new model with complex reasoning",
"format": "api",
"settings": {},
"parameters": {
"max_tokens": 4096,
"temperature": 0.7,
"top_p": 0.95,
"stream": true,
"stop": [],
"frequency_penalty": 0,
"presence_penalty": 0
},
"metadata": {
"author": "OpenAI",
"tags": [
"General"
]
},
"engine": "openai"
},
{
"sources": [
{
"url": "https://openai.com"
}
],
"id": "o1-mini",
"object": "model",
"name": "OpenAI o1-mini",
"version": "1.0",
"description": "OpenAI o1-mini is a lightweight reasoning model",
"format": "api",
"settings": {},
"parameters": {
"max_tokens": 4096,
"temperature": 0.7,
"top_p": 0.95,
"stream": true,
"stop": [],
"frequency_penalty": 0,
"presence_penalty": 0
},
"metadata": {
"author": "OpenAI",
"tags": [
"General"
]
},
"engine": "openai"
} }
] ]

View File

@ -0,0 +1,9 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
transform: {
'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
},
transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
}

View File

@ -8,6 +8,7 @@
"author": "Jan <service@jan.ai>", "author": "Jan <service@jan.ai>",
"license": "AGPL-3.0", "license": "AGPL-3.0",
"scripts": { "scripts": {
"test": "jest",
"build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs", "build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs",
"build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install"
}, },

View File

@ -27,7 +27,7 @@ export default [
// Allow json resolution // Allow json resolution
json(), json(),
// Compile TypeScript files // Compile TypeScript files
typescript({ useTsconfigDeclarationDir: true }), typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
// Compile TypeScript files // Compile TypeScript files
// Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
// commonjs(), // commonjs(),
@ -62,7 +62,7 @@ export default [
// Allow json resolution // Allow json resolution
json(), json(),
// Compile TypeScript files // Compile TypeScript files
typescript({ useTsconfigDeclarationDir: true }), typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
// Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
commonjs(), commonjs(),
// Allow node_modules resolution, so you can use 'external' to control // Allow node_modules resolution, so you can use 'external' to control

View File

@ -0,0 +1,564 @@
const readDirSyncMock = jest.fn()
const existMock = jest.fn()
const readFileSyncMock = jest.fn()
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
fs: {
existsSync: existMock,
readdirSync: readDirSyncMock,
readFileSync: readFileSyncMock,
fileStat: () => ({
isDirectory: false,
}),
},
dirName: jest.fn(),
joinPath: (paths) => paths.join('/'),
ModelExtension: jest.fn(),
}))
import JanModelExtension from '.'
import { fs, dirName } from '@janhq/core'
describe('JanModelExtension', () => {
let sut: JanModelExtension
beforeAll(() => {
// @ts-ignore
sut = new JanModelExtension()
})
afterEach(() => {
jest.clearAllMocks()
})
describe('getConfiguredModels', () => {
describe("when there's no models are pre-populated", () => {
it('should return empty array', async () => {
// Mock configured models data
const configuredModels = []
existMock.mockReturnValue(true)
readDirSyncMock.mockReturnValue([])
const result = await sut.getConfiguredModels()
expect(result).toEqual([])
})
})
describe("when there's are pre-populated models - all flattened", () => {
it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2']
else return ['model.json']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else return JSON.stringify(configuredModels[1])
})
const result = await sut.getConfiguredModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
expect.objectContaining({
file_path: 'file://models/model2/model.json',
id: '2',
}),
])
)
})
})
describe("when there's are pre-populated models - there are nested folders", () => {
it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2/model2-1']
else return ['model.json']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else if (path.includes('model2/model2-1'))
return JSON.stringify(configuredModels[1])
})
const result = await sut.getConfiguredModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
expect.objectContaining({
file_path: 'file://models/model2/model2-1/model.json',
id: '2',
}),
])
)
})
})
})
describe('getDownloadedModels', () => {
describe('no models downloaded', () => {
it('should return empty array', async () => {
// Mock downloaded models data
const downloadedModels = []
existMock.mockReturnValue(true)
readDirSyncMock.mockReturnValue([])
const result = await sut.getDownloadedModels()
expect(result).toEqual([])
})
})
describe('only one model is downloaded', () => {
describe('flatten folder', () => {
it('returns downloaded models - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2']
else if (path === 'file://models/model1')
return ['model.json', 'test.gguf']
else return ['model.json']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else return JSON.stringify(configuredModels[1])
})
const result = await sut.getDownloadedModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
])
)
})
})
})
describe('all models are downloaded', () => {
describe('nested folders', () => {
it('returns downloaded models - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2/model2-1']
else return ['model.json', 'test.gguf']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else return JSON.stringify(configuredModels[1])
})
const result = await sut.getDownloadedModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
expect.objectContaining({
file_path: 'file://models/model2/model2-1/model.json',
id: '2',
}),
])
)
})
})
})
describe('all models are downloaded with uppercased GGUF files', () => {
it('returns downloaded models - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2/model2-1']
else if (path === 'file://models/model1')
return ['model.json', 'test.GGUF']
else return ['model.json', 'test.gguf']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else return JSON.stringify(configuredModels[1])
})
const result = await sut.getDownloadedModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
expect.objectContaining({
file_path: 'file://models/model2/model2-1/model.json',
id: '2',
}),
])
)
})
})
describe('all models are downloaded - GGUF & Tensort RT', () => {
it('returns downloaded models - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2/model2-1']
else if (path === 'file://models/model1')
return ['model.json', 'test.gguf']
else return ['model.json', 'test.engine']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else return JSON.stringify(configuredModels[1])
})
const result = await sut.getDownloadedModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
expect.objectContaining({
file_path: 'file://models/model2/model2-1/model.json',
id: '2',
}),
])
)
})
})
})
describe('deleteModel', () => {
describe('model is a GGUF model', () => {
it('should delete the GGUF file', async () => {
fs.unlinkSync = jest.fn()
const dirMock = dirName as jest.Mock
dirMock.mockReturnValue('file://models/model1')
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
readDirSyncMock.mockImplementation((path) => {
return ['model.json', 'test.gguf']
})
existMock.mockReturnValue(true)
await sut.deleteModel({
file_path: 'file://models/model1/model.json',
} as any)
expect(fs.unlinkSync).toHaveBeenCalledWith(
'file://models/model1/test.gguf'
)
})
it('no gguf file presented', async () => {
fs.unlinkSync = jest.fn()
const dirMock = dirName as jest.Mock
dirMock.mockReturnValue('file://models/model1')
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
readDirSyncMock.mockReturnValue(['model.json'])
existMock.mockReturnValue(true)
await sut.deleteModel({
file_path: 'file://models/model1/model.json',
} as any)
expect(fs.unlinkSync).toHaveBeenCalledTimes(0)
})
it('delete an imported model', async () => {
fs.rm = jest.fn()
const dirMock = dirName as jest.Mock
dirMock.mockReturnValue('file://models/model1')
readDirSyncMock.mockReturnValue(['model.json', 'test.gguf'])
// MARK: This is a tricky logic implement?
// I will just add test for now but will align on the legacy implementation
fs.readFileSync = jest.fn().mockReturnValue(
JSON.stringify({
metadata: {
author: 'user',
},
})
)
existMock.mockReturnValue(true)
await sut.deleteModel({
file_path: 'file://models/model1/model.json',
} as any)
expect(fs.rm).toHaveBeenCalledWith('file://models/model1')
})
it('delete tensorrt-models', async () => {
fs.rm = jest.fn()
const dirMock = dirName as jest.Mock
dirMock.mockReturnValue('file://models/model1')
readDirSyncMock.mockReturnValue(['model.json', 'test.engine'])
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
existMock.mockReturnValue(true)
await sut.deleteModel({
file_path: 'file://models/model1/model.json',
} as any)
expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine')
})
})
})
})

View File

@ -22,6 +22,8 @@ import {
getFileSize, getFileSize,
AllQuantizations, AllQuantizations,
ModelEvent, ModelEvent,
ModelFile,
dirName,
} from '@janhq/core' } from '@janhq/core'
import { extractFileName } from './helpers/path' import { extractFileName } from './helpers/path'
@ -48,16 +50,7 @@ export default class JanModelExtension extends ModelExtension {
] ]
private static readonly _tensorRtEngineFormat = '.engine' private static readonly _tensorRtEngineFormat = '.engine'
private static readonly _supportedGpuArch = ['ampere', 'ada'] private static readonly _supportedGpuArch = ['ampere', 'ada']
private static readonly _safetensorsRegexs = [
/model\.safetensors$/,
/model-[0-9]+-of-[0-9]+\.safetensors$/,
]
private static readonly _pytorchRegexs = [
/pytorch_model\.bin$/,
/consolidated\.[0-9]+\.pth$/,
/pytorch_model-[0-9]+-of-[0-9]+\.bin$/,
/.*\.pt$/,
]
interrupted = false interrupted = false
/** /**
@ -319,9 +312,9 @@ export default class JanModelExtension extends ModelExtension {
* @param filePath - The path to the model file to delete. * @param filePath - The path to the model file to delete.
* @returns A Promise that resolves when the model is deleted. * @returns A Promise that resolves when the model is deleted.
*/ */
async deleteModel(modelId: string): Promise<void> { async deleteModel(model: ModelFile): Promise<void> {
try { try {
const dirPath = await joinPath([JanModelExtension._homeDir, modelId]) const dirPath = await dirName(model.file_path)
const jsonFilePath = await joinPath([ const jsonFilePath = await joinPath([
dirPath, dirPath,
JanModelExtension._modelMetadataFileName, JanModelExtension._modelMetadataFileName,
@ -330,6 +323,8 @@ export default class JanModelExtension extends ModelExtension {
await this.readModelMetadata(jsonFilePath) await this.readModelMetadata(jsonFilePath)
) as Model ) as Model
// TODO: This is so tricky?
// Should depend on sources?
const isUserImportModel = const isUserImportModel =
modelInfo.metadata?.author?.toLowerCase() === 'user' modelInfo.metadata?.author?.toLowerCase() === 'user'
if (isUserImportModel) { if (isUserImportModel) {
@ -350,30 +345,11 @@ export default class JanModelExtension extends ModelExtension {
} }
} }
/**
* Saves a model file.
* @param model - The model to save.
* @returns A Promise that resolves when the model is saved.
*/
async saveModel(model: Model): Promise<void> {
const jsonFilePath = await joinPath([
JanModelExtension._homeDir,
model.id,
JanModelExtension._modelMetadataFileName,
])
try {
await fs.writeFileSync(jsonFilePath, JSON.stringify(model, null, 2))
} catch (err) {
console.error(err)
}
}
/** /**
* Gets all downloaded models. * Gets all downloaded models.
* @returns A Promise that resolves with an array of all models. * @returns A Promise that resolves with an array of all models.
*/ */
async getDownloadedModels(): Promise<Model[]> { async getDownloadedModels(): Promise<ModelFile[]> {
return await this.getModelsMetadata( return await this.getModelsMetadata(
async (modelDir: string, model: Model) => { async (modelDir: string, model: Model) => {
if (!JanModelExtension._offlineInferenceEngine.includes(model.engine)) if (!JanModelExtension._offlineInferenceEngine.includes(model.engine))
@ -425,8 +401,10 @@ export default class JanModelExtension extends ModelExtension {
): Promise<string | undefined> { ): Promise<string | undefined> {
// try to find model.json recursively inside each folder // try to find model.json recursively inside each folder
if (!(await fs.existsSync(folderFullPath))) return undefined if (!(await fs.existsSync(folderFullPath))) return undefined
const files: string[] = await fs.readdirSync(folderFullPath) const files: string[] = await fs.readdirSync(folderFullPath)
if (files.length === 0) return undefined if (files.length === 0) return undefined
if (files.includes(JanModelExtension._modelMetadataFileName)) { if (files.includes(JanModelExtension._modelMetadataFileName)) {
return joinPath([ return joinPath([
folderFullPath, folderFullPath,
@ -446,7 +424,7 @@ export default class JanModelExtension extends ModelExtension {
private async getModelsMetadata( private async getModelsMetadata(
selector?: (path: string, model: Model) => Promise<boolean> selector?: (path: string, model: Model) => Promise<boolean>
): Promise<Model[]> { ): Promise<ModelFile[]> {
try { try {
if (!(await fs.existsSync(JanModelExtension._homeDir))) { if (!(await fs.existsSync(JanModelExtension._homeDir))) {
console.debug('Model folder not found') console.debug('Model folder not found')
@ -469,6 +447,7 @@ export default class JanModelExtension extends ModelExtension {
JanModelExtension._homeDir, JanModelExtension._homeDir,
dirName, dirName,
]) ])
const jsonPath = await this.getModelJsonPath(folderFullPath) const jsonPath = await this.getModelJsonPath(folderFullPath)
if (await fs.existsSync(jsonPath)) { if (await fs.existsSync(jsonPath)) {
@ -486,6 +465,8 @@ export default class JanModelExtension extends ModelExtension {
}, },
] ]
} }
model.file_path = jsonPath
model.file_name = JanModelExtension._modelMetadataFileName
if (selector && !(await selector?.(dirName, model))) { if (selector && !(await selector?.(dirName, model))) {
return return
@ -506,7 +487,7 @@ export default class JanModelExtension extends ModelExtension {
typeof result.value === 'object' typeof result.value === 'object'
? result.value ? result.value
: JSON.parse(result.value) : JSON.parse(result.value)
return model as Model return model as ModelFile
} catch { } catch {
console.debug(`Unable to parse model metadata: ${result.value}`) console.debug(`Unable to parse model metadata: ${result.value}`)
} }
@ -637,7 +618,7 @@ export default class JanModelExtension extends ModelExtension {
* Gets all available models. * Gets all available models.
* @returns A Promise that resolves with an array of all models. * @returns A Promise that resolves with an array of all models.
*/ */
async getConfiguredModels(): Promise<Model[]> { async getConfiguredModels(): Promise<ModelFile[]> {
return this.getModelsMetadata() return this.getModelsMetadata()
} }
@ -669,7 +650,7 @@ export default class JanModelExtension extends ModelExtension {
modelBinaryPath: string, modelBinaryPath: string,
modelFolderName: string, modelFolderName: string,
modelFolderPath: string modelFolderPath: string
): Promise<Model> { ): Promise<ModelFile> {
const fileStats = await fs.fileStat(modelBinaryPath, true) const fileStats = await fs.fileStat(modelBinaryPath, true)
const binaryFileSize = fileStats.size const binaryFileSize = fileStats.size
@ -732,25 +713,21 @@ export default class JanModelExtension extends ModelExtension {
await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2)) await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2))
return model return {
...model,
file_path: modelFilePath,
file_name: JanModelExtension._modelMetadataFileName,
}
} }
async updateModelInfo(modelInfo: Partial<Model>): Promise<Model> { async updateModelInfo(modelInfo: Partial<ModelFile>): Promise<ModelFile> {
const modelId = modelInfo.id
if (modelInfo.id == null) throw new Error('Model ID is required') if (modelInfo.id == null) throw new Error('Model ID is required')
const janDataFolderPath = await getJanDataFolderPath()
const jsonFilePath = await joinPath([
janDataFolderPath,
'models',
modelId,
JanModelExtension._modelMetadataFileName,
])
const model = JSON.parse( const model = JSON.parse(
await this.readModelMetadata(jsonFilePath) await this.readModelMetadata(modelInfo.file_path)
) as Model ) as ModelFile
const updatedModel: Model = { const updatedModel: ModelFile = {
...model, ...model,
...modelInfo, ...modelInfo,
parameters: { parameters: {
@ -765,9 +742,15 @@ export default class JanModelExtension extends ModelExtension {
...model.metadata, ...model.metadata,
...modelInfo.metadata, ...modelInfo.metadata,
}, },
// Should not persist file_path & file_name
file_path: undefined,
file_name: undefined,
} }
await fs.writeFileSync(jsonFilePath, JSON.stringify(updatedModel, null, 2)) await fs.writeFileSync(
modelInfo.file_path,
JSON.stringify(updatedModel, null, 2)
)
return updatedModel return updatedModel
} }

View File

@ -10,5 +10,6 @@
"skipLibCheck": true, "skipLibCheck": true,
"rootDir": "./src" "rootDir": "./src"
}, },
"include": ["./src"] "include": ["./src"],
"exclude": ["**/*.test.ts"]
} }

View File

@ -23,6 +23,7 @@ import {
ModelEvent, ModelEvent,
getJanDataFolderPath, getJanDataFolderPath,
SystemInformation, SystemInformation,
ModelFile,
} from '@janhq/core' } from '@janhq/core'
/** /**
@ -137,7 +138,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
events.emit(ModelEvent.OnModelsUpdate, {}) events.emit(ModelEvent.OnModelsUpdate, {})
} }
override async loadModel(model: Model): Promise<void> { override async loadModel(model: ModelFile): Promise<void> {
if ((await this.installationState()) === 'Installed') if ((await this.installationState()) === 'Installed')
return super.loadModel(model) return super.loadModel(model)

View File

@ -97,7 +97,7 @@ function unloadModel(): Promise<void> {
} }
if (subprocess?.pid) { if (subprocess?.pid) {
log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
const pid = subprocess.pid const pid = subprocess.pid
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
terminate(pid, function (err) { terminate(pid, function (err) {
@ -107,7 +107,7 @@ function unloadModel(): Promise<void> {
return tcpPortUsed return tcpPortUsed
.waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000) .waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000)
.then(() => resolve()) .then(() => resolve())
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) .then(() => log(`[CORTEX]:: cortex process is terminated`))
.catch(() => { .catch(() => {
killRequest() killRequest()
}) })

View File

@ -12,17 +12,18 @@ import { twMerge } from 'tailwind-merge'
import { MainViewState } from '@/constants/screens' import { MainViewState } from '@/constants/screens'
import { localEngines } from '@/utils/modelEngine'
import { mainViewStateAtom, showLeftPanelAtom } from '@/helpers/atoms/App.atom' import { mainViewStateAtom, showLeftPanelAtom } from '@/helpers/atoms/App.atom'
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom' import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { import {
reduceTransparentAtom, reduceTransparentAtom,
selectedSettingAtom, selectedSettingAtom,
} from '@/helpers/atoms/Setting.atom' } from '@/helpers/atoms/Setting.atom'
import { threadsAtom } from '@/helpers/atoms/Thread.atom' import {
isDownloadALocalModelAtom,
threadsAtom,
} from '@/helpers/atoms/Thread.atom'
export default function RibbonPanel() { export default function RibbonPanel() {
const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom) const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom)
@ -32,8 +33,9 @@ export default function RibbonPanel() {
const matches = useMediaQuery('(max-width: 880px)') const matches = useMediaQuery('(max-width: 880px)')
const reduceTransparent = useAtomValue(reduceTransparentAtom) const reduceTransparent = useAtomValue(reduceTransparentAtom)
const setSelectedSetting = useSetAtom(selectedSettingAtom) const setSelectedSetting = useSetAtom(selectedSettingAtom)
const downloadedModels = useAtomValue(downloadedModelsAtom)
const threads = useAtomValue(threadsAtom) const threads = useAtomValue(threadsAtom)
const isDownloadALocalModel = useAtomValue(isDownloadALocalModelAtom)
const onMenuClick = (state: MainViewState) => { const onMenuClick = (state: MainViewState) => {
if (mainViewState === state) return if (mainViewState === state) return
@ -43,10 +45,6 @@ export default function RibbonPanel() {
setEditMessage('') setEditMessage('')
} }
const isDownloadALocalModel = downloadedModels.some((x) =>
localEngines.includes(x.engine)
)
const RibbonNavMenus = [ const RibbonNavMenus = [
{ {
name: 'Thread', name: 'Thread',

View File

@ -23,6 +23,7 @@ import { toaster } from '@/containers/Toast'
import { MainViewState } from '@/constants/screens' import { MainViewState } from '@/constants/screens'
import { useCreateNewThread } from '@/hooks/useCreateNewThread' import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import { useStarterScreen } from '@/hooks/useStarterScreen'
import { import {
mainViewStateAtom, mainViewStateAtom,
@ -58,6 +59,8 @@ const TopPanel = () => {
requestCreateNewThread(assistants[0]) requestCreateNewThread(assistants[0])
} }
const { isShowStarterScreen } = useStarterScreen()
return ( return (
<div <div
className={twMerge( className={twMerge(
@ -93,7 +96,7 @@ const TopPanel = () => {
)} )}
</Fragment> </Fragment>
)} )}
{mainViewState === MainViewState.Thread && ( {mainViewState === MainViewState.Thread && !isShowStarterScreen && (
<Button <Button
data-testid="btn-create-thread" data-testid="btn-create-thread"
onClick={onCreateNewThreadClick} onClick={onCreateNewThreadClick}

View File

@ -46,7 +46,6 @@ import {
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom' import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
import { import {
configuredModelsAtom, configuredModelsAtom,
@ -91,8 +90,6 @@ const ModelDropdown = ({
const featuredModel = configuredModels.filter((x) => const featuredModel = configuredModels.filter((x) =>
x.metadata.tags.includes('Featured') x.metadata.tags.includes('Featured')
) )
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
const { updateThreadMetadata } = useCreateNewThread() const { updateThreadMetadata } = useCreateNewThread()
useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [ useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
@ -191,27 +188,14 @@ const ModelDropdown = ({
], ],
}) })
// Default setting ctx_len for the model for a better onboarding experience
// TODO: When Cortex support hardware instructions, we should remove this
const defaultContextLength = preserveModelSettings
? model?.metadata?.default_ctx_len
: 2048
const defaultMaxTokens = preserveModelSettings
? model?.metadata?.default_max_tokens
: 2048
const overriddenSettings = const overriddenSettings =
model?.settings.ctx_len && model.settings.ctx_len > 2048 model?.settings.ctx_len && model.settings.ctx_len > 4096
? { ctx_len: defaultContextLength ?? 2048 } ? { ctx_len: 4096 }
: {}
const overriddenParameters =
model?.parameters.max_tokens && model.parameters.max_tokens
? { max_tokens: defaultMaxTokens ?? 2048 }
: {} : {}
const modelParams = { const modelParams = {
...model?.parameters, ...model?.parameters,
...model?.settings, ...model?.settings,
...overriddenParameters,
...overriddenSettings, ...overriddenSettings,
} }
@ -222,6 +206,7 @@ const ModelDropdown = ({
if (model) if (model)
updateModelParameter(activeThread, { updateModelParameter(activeThread, {
params: modelParams, params: modelParams,
modelPath: model.file_path,
modelId: model.id, modelId: model.id,
engine: model.engine, engine: model.engine,
}) })
@ -235,7 +220,6 @@ const ModelDropdown = ({
setThreadModelParams, setThreadModelParams,
updateModelParameter, updateModelParameter,
updateThreadMetadata, updateThreadMetadata,
preserveModelSettings,
] ]
) )

View File

@ -0,0 +1,100 @@
import React from 'react'
import { render, waitFor, screen } from '@testing-library/react'
import { useAtomValue } from 'jotai'
import { useActiveModel } from '@/hooks/useActiveModel'
import { useSettings } from '@/hooks/useSettings'
import ModelLabel from '@/containers/ModelLabel'
jest.mock('jotai', () => ({
useAtomValue: jest.fn(),
atom: jest.fn(),
}))
jest.mock('@/hooks/useActiveModel', () => ({
useActiveModel: jest.fn(),
}))
jest.mock('@/hooks/useSettings', () => ({
useSettings: jest.fn(),
}))
describe('ModelLabel', () => {
const mockUseAtomValue = useAtomValue as jest.Mock
const mockUseActiveModel = useActiveModel as jest.Mock
const mockUseSettings = useSettings as jest.Mock
const defaultProps: any = {
metadata: {
author: 'John Doe', // Add the 'author' property with a value
tags: ['8B'],
size: 100,
},
compact: false,
}
beforeEach(() => {
jest.clearAllMocks()
})
it('renders NotEnoughMemoryLabel when minimumRamModel is greater than totalRam', async () => {
mockUseAtomValue
.mockReturnValueOnce(0)
.mockReturnValueOnce(0)
.mockReturnValueOnce(0)
mockUseActiveModel.mockReturnValue({
activeModel: { metadata: { size: 0 } },
})
mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
render(<ModelLabel {...defaultProps} />)
await waitFor(() => {
expect(screen.getByText('Not enough RAM')).toBeDefined()
})
})
it('renders SlowOnYourDeviceLabel when minimumRamModel is less than totalRam but greater than availableRam', async () => {
mockUseAtomValue
.mockReturnValueOnce(100)
.mockReturnValueOnce(50)
.mockReturnValueOnce(10)
mockUseActiveModel.mockReturnValue({
activeModel: { metadata: { size: 0 } },
})
mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
const props = {
...defaultProps,
metadata: {
...defaultProps.metadata,
size: 50,
},
}
render(<ModelLabel {...props} />)
await waitFor(() => {
expect(screen.getByText('Slow on your device')).toBeDefined()
})
})
it('renders nothing when minimumRamModel is less than availableRam', () => {
mockUseAtomValue
.mockReturnValueOnce(100)
.mockReturnValueOnce(50)
.mockReturnValueOnce(0)
mockUseActiveModel.mockReturnValue({
activeModel: { metadata: { size: 0 } },
})
mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
const props = {
...defaultProps,
metadata: {
...defaultProps.metadata,
size: 10,
},
}
const { container } = render(<ModelLabel {...props} />)
expect(container.firstChild).toBeNull()
})
})

View File

@ -10,8 +10,6 @@ import { useSettings } from '@/hooks/useSettings'
import NotEnoughMemoryLabel from './NotEnoughMemoryLabel' import NotEnoughMemoryLabel from './NotEnoughMemoryLabel'
import RecommendedLabel from './RecommendedLabel'
import SlowOnYourDeviceLabel from './SlowOnYourDeviceLabel' import SlowOnYourDeviceLabel from './SlowOnYourDeviceLabel'
import { import {
@ -53,9 +51,7 @@ const ModelLabel = ({ metadata, compact }: Props) => {
/> />
) )
} }
if (minimumRamModel < availableRam && !compact) {
return <RecommendedLabel />
}
if (minimumRamModel < totalRam && minimumRamModel > availableRam) { if (minimumRamModel < totalRam && minimumRamModel > availableRam) {
return <SlowOnYourDeviceLabel compact={compact} /> return <SlowOnYourDeviceLabel compact={compact} />
} }

View File

@ -20,7 +20,7 @@ import { ulid } from 'ulidx'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel' import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
import { toRuntimeParams } from '@/utils/modelParam' import { extractInferenceParams } from '@/utils/modelParam'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { import {
@ -256,7 +256,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
}, },
] ]
const runtimeParams = toRuntimeParams(activeModelParamsRef.current) const runtimeParams = extractInferenceParams(activeModelParamsRef.current)
const messageRequest: MessageRequest = { const messageRequest: MessageRequest = {
id: msgId, id: msgId,

View File

@ -87,26 +87,28 @@ const SliderRightPanel = ({
onValueChanged?.(Number(min)) onValueChanged?.(Number(min))
setVal(min.toString()) setVal(min.toString())
setShowTooltip({ max: false, min: true }) setShowTooltip({ max: false, min: true })
} else {
setVal(Number(e.target.value).toString()) // There is a case .5 but not 0.5
} }
}} }}
onChange={(e) => { onChange={(e) => {
// Should not accept invalid value or NaN
// E.g. anything changes that trigger onValueChanged
// Which is incorrect
if (Number(e.target.value) > Number(max)) {
setVal(max.toString())
} else if (
Number(e.target.value) < Number(min) ||
!e.target.value.length
) {
setVal(min.toString())
} else if (Number.isNaN(Number(e.target.value))) return
onValueChanged?.(Number(e.target.value))
// TODO: How to support negative number input? // TODO: How to support negative number input?
// Passthru since it validates again onBlur
if (/^\d*\.?\d*$/.test(e.target.value)) { if (/^\d*\.?\d*$/.test(e.target.value)) {
setVal(e.target.value) setVal(e.target.value)
} }
// Should not accept invalid value or NaN
// E.g. anything changes that trigger onValueChanged
// Which is incorrect
if (
Number(e.target.value) > Number(max) ||
Number(e.target.value) < Number(min) ||
Number.isNaN(Number(e.target.value))
) {
return
}
onValueChanged?.(Number(e.target.value))
}} }}
/> />
} }

View File

@ -7,7 +7,6 @@ const VULKAN_ENABLED = 'vulkanEnabled'
const IGNORE_SSL = 'ignoreSSLFeature' const IGNORE_SSL = 'ignoreSSLFeature'
const HTTPS_PROXY_FEATURE = 'httpsProxyFeature' const HTTPS_PROXY_FEATURE = 'httpsProxyFeature'
const QUICK_ASK_ENABLED = 'quickAskEnabled' const QUICK_ASK_ENABLED = 'quickAskEnabled'
const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings'
export const janDataFolderPathAtom = atom('') export const janDataFolderPathAtom = atom('')
@ -24,9 +23,3 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false)
export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false) export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false)
export const hostAtom = atom('http://localhost:1337/') export const hostAtom = atom('http://localhost:1337/')
// This feature is to allow user to cache model settings on thread creation
export const preserveModelSettingsAtom = atomWithStorage(
PRESERVE_MODEL_SETTINGS,
false
)

View File

@ -1,4 +1,4 @@
import { ImportingModel, Model, InferenceEngine } from '@janhq/core' import { ImportingModel, Model, InferenceEngine, ModelFile } from '@janhq/core'
import { atom } from 'jotai' import { atom } from 'jotai'
import { localEngines } from '@/utils/modelEngine' import { localEngines } from '@/utils/modelEngine'
@ -32,18 +32,7 @@ export const removeDownloadingModelAtom = atom(
} }
) )
export const downloadedModelsAtom = atom<Model[]>([]) export const downloadedModelsAtom = atom<ModelFile[]>([])
export const updateDownloadedModelAtom = atom(
null,
(get, set, updatedModel: Model) => {
const models: Model[] = get(downloadedModelsAtom).map((c) =>
c.id === updatedModel.id ? updatedModel : c
)
set(downloadedModelsAtom, models)
}
)
export const removeDownloadedModelAtom = atom( export const removeDownloadedModelAtom = atom(
null, null,
@ -57,7 +46,7 @@ export const removeDownloadedModelAtom = atom(
} }
) )
export const configuredModelsAtom = atom<Model[]>([]) export const configuredModelsAtom = atom<ModelFile[]>([])
export const defaultModelAtom = atom<Model | undefined>(undefined) export const defaultModelAtom = atom<Model | undefined>(undefined)
@ -144,6 +133,6 @@ export const updateImportingModelAtom = atom(
} }
) )
export const selectedModelAtom = atom<Model | undefined>(undefined) export const selectedModelAtom = atom<ModelFile | undefined>(undefined)
export const showEngineListModelAtom = atom<InferenceEngine[]>(localEngines) export const showEngineListModelAtom = atom<InferenceEngine[]>(localEngines)

View File

@ -152,3 +152,6 @@ export const modalActionThreadAtom = atom<{
showModal: undefined, showModal: undefined,
thread: undefined, thread: undefined,
}) })
export const isDownloadALocalModelAtom = atom<boolean>(false)
export const isAnyRemoteModelConfiguredAtom = atom<boolean>(false)

View File

@ -1,6 +1,6 @@
import { useCallback, useEffect, useRef } from 'react' import { useCallback, useEffect, useRef } from 'react'
import { EngineManager, Model } from '@janhq/core' import { EngineManager, Model, ModelFile } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { toaster } from '@/containers/Toast' import { toaster } from '@/containers/Toast'
@ -11,7 +11,7 @@ import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
export const activeModelAtom = atom<Model | undefined>(undefined) export const activeModelAtom = atom<ModelFile | undefined>(undefined)
export const loadModelErrorAtom = atom<string | undefined>(undefined) export const loadModelErrorAtom = atom<string | undefined>(undefined)
type ModelState = { type ModelState = {
@ -37,7 +37,7 @@ export function useActiveModel() {
const [pendingModelLoad, setPendingModelLoad] = useAtom(pendingModelLoadAtom) const [pendingModelLoad, setPendingModelLoad] = useAtom(pendingModelLoadAtom)
const isVulkanEnabled = useAtomValue(vulkanEnabledAtom) const isVulkanEnabled = useAtomValue(vulkanEnabledAtom)
const downloadedModelsRef = useRef<Model[]>([]) const downloadedModelsRef = useRef<ModelFile[]>([])
useEffect(() => { useEffect(() => {
downloadedModelsRef.current = downloadedModels downloadedModelsRef.current = downloadedModels

View File

@ -7,8 +7,8 @@ import {
Thread, Thread,
ThreadAssistantInfo, ThreadAssistantInfo,
ThreadState, ThreadState,
Model,
AssistantTool, AssistantTool,
ModelFile,
} from '@janhq/core' } from '@janhq/core'
import { atom, useAtomValue, useSetAtom } from 'jotai' import { atom, useAtomValue, useSetAtom } from 'jotai'
@ -26,10 +26,7 @@ import useSetActiveThread from './useSetActiveThread'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
experimentalFeatureEnabledAtom,
preserveModelSettingsAtom,
} from '@/helpers/atoms/AppConfig.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { import {
threadsAtom, threadsAtom,
@ -67,7 +64,6 @@ export const useCreateNewThread = () => {
const copyOverInstructionEnabled = useAtomValue( const copyOverInstructionEnabled = useAtomValue(
copyOverInstructionEnabledAtom copyOverInstructionEnabledAtom
) )
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom) const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
@ -80,7 +76,7 @@ export const useCreateNewThread = () => {
const requestCreateNewThread = async ( const requestCreateNewThread = async (
assistant: Assistant, assistant: Assistant,
model?: Model | undefined model?: ModelFile | undefined
) => { ) => {
// Stop generating if any // Stop generating if any
setIsGeneratingResponse(false) setIsGeneratingResponse(false)
@ -109,19 +105,13 @@ export const useCreateNewThread = () => {
enabled: true, enabled: true,
settings: assistant.tools && assistant.tools[0].settings, settings: assistant.tools && assistant.tools[0].settings,
} }
const defaultContextLength = preserveModelSettings
? defaultModel?.metadata?.default_ctx_len
: 2048
const defaultMaxTokens = preserveModelSettings
? defaultModel?.metadata?.default_max_tokens
: 2048
const overriddenSettings = const overriddenSettings =
defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048 defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048
? { ctx_len: defaultContextLength ?? 2048 } ? { ctx_len: 4096 }
: {} : {}
const overriddenParameters = defaultModel?.parameters.max_tokens const overriddenParameters = defaultModel?.parameters.max_tokens
? { max_tokens: defaultMaxTokens ?? 2048 } ? { max_tokens: 4096 }
: {} : {}
const createdAt = Date.now() const createdAt = Date.now()

View File

@ -1,6 +1,6 @@
import { useCallback } from 'react' import { useCallback } from 'react'
import { ExtensionTypeEnum, ModelExtension, Model } from '@janhq/core' import { ExtensionTypeEnum, ModelExtension, ModelFile } from '@janhq/core'
import { useSetAtom } from 'jotai' import { useSetAtom } from 'jotai'
@ -13,8 +13,8 @@ export default function useDeleteModel() {
const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom) const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom)
const deleteModel = useCallback( const deleteModel = useCallback(
async (model: Model) => { async (model: ModelFile) => {
await localDeleteModel(model.id) await localDeleteModel(model)
removeDownloadedModel(model.id) removeDownloadedModel(model.id)
toaster({ toaster({
title: 'Model Deletion Successful', title: 'Model Deletion Successful',
@ -28,5 +28,7 @@ export default function useDeleteModel() {
return { deleteModel } return { deleteModel }
} }
const localDeleteModel = async (id: string) => const localDeleteModel = async (model: ModelFile) =>
extensionManager.get<ModelExtension>(ExtensionTypeEnum.Model)?.deleteModel(id) extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.deleteModel(model)

View File

@ -5,6 +5,7 @@ import {
Model, Model,
ModelEvent, ModelEvent,
ModelExtension, ModelExtension,
ModelFile,
events, events,
} from '@janhq/core' } from '@janhq/core'
@ -63,12 +64,12 @@ const getLocalDefaultModel = async (): Promise<Model | undefined> =>
.get<ModelExtension>(ExtensionTypeEnum.Model) .get<ModelExtension>(ExtensionTypeEnum.Model)
?.getDefaultModel() ?.getDefaultModel()
const getLocalConfiguredModels = async (): Promise<Model[]> => const getLocalConfiguredModels = async (): Promise<ModelFile[]> =>
extensionManager extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model) .get<ModelExtension>(ExtensionTypeEnum.Model)
?.getConfiguredModels() ?? [] ?.getConfiguredModels() ?? []
const getLocalDownloadedModels = async (): Promise<Model[]> => const getLocalDownloadedModels = async (): Promise<ModelFile[]> =>
extensionManager extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model) .get<ModelExtension>(ExtensionTypeEnum.Model)
?.getDownloadedModels() ?? [] ?.getDownloadedModels() ?? []

View File

@ -1,6 +1,6 @@
import { useCallback, useEffect, useState } from 'react' import { useCallback, useEffect, useState } from 'react'
import { Model, InferenceEngine } from '@janhq/core' import { Model, InferenceEngine, ModelFile } from '@janhq/core'
import { atom, useAtomValue } from 'jotai' import { atom, useAtomValue } from 'jotai'
@ -24,12 +24,16 @@ export const LAST_USED_MODEL_ID = 'last-used-model-id'
*/ */
export default function useRecommendedModel() { export default function useRecommendedModel() {
const activeModel = useAtomValue(activeModelAtom) const activeModel = useAtomValue(activeModelAtom)
const [sortedModels, setSortedModels] = useState<Model[]>([]) const [sortedModels, setSortedModels] = useState<ModelFile[]>([])
const [recommendedModel, setRecommendedModel] = useState<Model | undefined>() const [recommendedModel, setRecommendedModel] = useState<
ModelFile | undefined
>()
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const downloadedModels = useAtomValue(downloadedModelsAtom) const downloadedModels = useAtomValue(downloadedModelsAtom)
const getAndSortDownloadedModels = useCallback(async (): Promise<Model[]> => { const getAndSortDownloadedModels = useCallback(async (): Promise<
ModelFile[]
> => {
const models = downloadedModels.sort((a, b) => const models = downloadedModels.sort((a, b) =>
a.engine !== InferenceEngine.nitro && b.engine === InferenceEngine.nitro a.engine !== InferenceEngine.nitro && b.engine === InferenceEngine.nitro
? 1 ? 1

View File

@ -23,7 +23,10 @@ import {
import { Stack } from '@/utils/Stack' import { Stack } from '@/utils/Stack'
import { compressImage, getBase64 } from '@/utils/base64' import { compressImage, getBase64 } from '@/utils/base64'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder' import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' import {
extractInferenceParams,
extractModelLoadParams,
} from '@/utils/modelParam'
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder' import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
@ -189,8 +192,8 @@ export default function useSendChatMessage() {
if (engineParamsUpdate) setReloadModel(true) if (engineParamsUpdate) setReloadModel(true)
const runtimeParams = toRuntimeParams(activeModelParams) const runtimeParams = extractInferenceParams(activeModelParams)
const settingParams = toSettingParams(activeModelParams) const settingParams = extractModelLoadParams(activeModelParams)
const prompt = message.trim() const prompt = message.trim()

View File

@ -63,14 +63,9 @@ export function useStarterScreen() {
(x) => x.apiKey.length > 1 (x) => x.apiKey.length > 1
) )
let isShowStarterScreen const isShowStarterScreen =
isShowStarterScreen =
!isAnyRemoteModelConfigured && !isDownloadALocalModel && !threads.length !isAnyRemoteModelConfigured && !isDownloadALocalModel && !threads.length
// Remove this part when we rework on starter screen
isShowStarterScreen = false
return { return {
extensionHasSettings, extensionHasSettings,
isShowStarterScreen, isShowStarterScreen,

View File

@ -0,0 +1,314 @@
import { renderHook, act } from '@testing-library/react'
// Mock dependencies
jest.mock('ulidx')
jest.mock('@/extension')
import useUpdateModelParameters from './useUpdateModelParameters'
import { extensionManager } from '@/extension'
// Mock data
let model: any = {
id: 'model-1',
engine: 'nitro',
}
let extension: any = {
saveThread: jest.fn(),
}
const mockThread: any = {
id: 'thread-1',
assistants: [
{
model: {
parameters: {},
settings: {},
},
},
],
object: 'thread',
title: 'New Thread',
created: 0,
updated: 0,
}
describe('useUpdateModelParameters', () => {
beforeAll(() => {
jest.clearAllMocks()
jest.mock('./useRecommendedModel', () => ({
useRecommendedModel: () => ({
recommendedModel: model,
setRecommendedModel: jest.fn(),
downloadedModels: [],
}),
}))
})
it('should update model parameters and save thread when params are valid', async () => {
const mockValidParameters: any = {
params: {
// Inference
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
token_limit: 1000,
top_k: 0.7,
top_p: 0.1,
stream: true,
max_tokens: 1000,
frequency_penalty: 0.3,
presence_penalty: 0.2,
// Load model
ctx_len: 1024,
ngl: 12,
embedding: true,
n_parallel: 2,
cpu_threads: 4,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
vision_model: 'vision',
text_model: 'text',
},
modelId: 'model-1',
engine: 'nitro',
}
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => {
await result.current.updateModelParameter(mockThread, mockValidParameters)
})
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
model: {
parameters: {
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
token_limit: 1000,
top_k: 0.7,
top_p: 0.1,
stream: true,
max_tokens: 1000,
frequency_penalty: 0.3,
presence_penalty: 0.2,
},
settings: {
ctx_len: 1024,
ngl: 12,
embedding: true,
n_parallel: 2,
cpu_threads: 4,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
it('should not update invalid model parameters', async () => {
const mockInvalidParameters: any = {
params: {
// Inference
stop: [1, '<eos>'],
temperature: '0.5',
token_limit: '1000',
top_k: '0.7',
top_p: '0.1',
stream: 'true',
max_tokens: '1000',
frequency_penalty: '0.3',
presence_penalty: '0.2',
// Load model
ctx_len: '1024',
ngl: '12',
embedding: 'true',
n_parallel: '2',
cpu_threads: '4',
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
vision_model: 'vision',
text_model: 'text',
},
modelId: 'model-1',
engine: 'nitro',
}
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => {
await result.current.updateModelParameter(
mockThread,
mockInvalidParameters
)
})
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
model: {
parameters: {
max_tokens: 1000,
token_limit: 1000,
},
settings: {
cpu_threads: 4,
ctx_len: 1024,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
n_parallel: 2,
ngl: 12,
},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
it('should update valid model parameters only', async () => {
const mockInvalidParameters: any = {
params: {
// Inference
stop: ['<eos>'],
temperature: -0.5,
token_limit: 100.2,
top_k: 0.7,
top_p: 0.1,
stream: true,
max_tokens: 1000,
frequency_penalty: 1.2,
presence_penalty: 0.2,
// Load model
ctx_len: 1024,
ngl: 0,
embedding: 'true',
n_parallel: 2,
cpu_threads: 4,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
vision_model: 'vision',
text_model: 'text',
},
modelId: 'model-1',
engine: 'nitro',
}
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => {
await result.current.updateModelParameter(
mockThread,
mockInvalidParameters
)
})
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
model: {
parameters: {
stop: ['<eos>'],
top_k: 0.7,
top_p: 0.1,
stream: true,
token_limit: 100,
max_tokens: 1000,
presence_penalty: 0.2,
},
settings: {
ctx_len: 1024,
ngl: 0,
n_parallel: 2,
cpu_threads: 4,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
it('should handle missing modelId and engine gracefully', async () => {
const mockParametersWithoutModelIdAndEngine: any = {
params: {
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
},
}
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => {
await result.current.updateModelParameter(
mockThread,
mockParametersWithoutModelIdAndEngine
)
})
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
model: {
parameters: {
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
},
settings: {},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
})

View File

@ -4,24 +4,19 @@ import {
ConversationalExtension, ConversationalExtension,
ExtensionTypeEnum, ExtensionTypeEnum,
InferenceEngine, InferenceEngine,
Model,
ModelExtension,
Thread, Thread,
ThreadAssistantInfo, ThreadAssistantInfo,
} from '@janhq/core' } from '@janhq/core'
import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { useAtom, useAtomValue, useSetAtom } from 'jotai'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' import {
extractInferenceParams,
import useRecommendedModel from './useRecommendedModel' extractModelLoadParams,
} from '@/utils/modelParam'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
selectedModelAtom,
updateDownloadedModelAtom,
} from '@/helpers/atoms/Model.atom'
import { import {
ModelParams, ModelParams,
getActiveThreadModelParamsAtom, getActiveThreadModelParamsAtom,
@ -31,28 +26,31 @@ import {
export type UpdateModelParameter = { export type UpdateModelParameter = {
params?: ModelParams params?: ModelParams
modelId?: string modelId?: string
modelPath?: string
engine?: InferenceEngine engine?: InferenceEngine
} }
export default function useUpdateModelParameters() { export default function useUpdateModelParameters() {
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom) const [selectedModel] = useAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom)
const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom)
const { recommendedModel, setRecommendedModel } = useRecommendedModel()
const updateModelParameter = useCallback( const updateModelParameter = useCallback(
async (thread: Thread, settings: UpdateModelParameter) => { async (thread: Thread, settings: UpdateModelParameter) => {
const toUpdateSettings = processStopWords(settings.params ?? {}) const toUpdateSettings = processStopWords(settings.params ?? {})
const updatedModelParams = settings.modelId const updatedModelParams = settings.modelId
? toUpdateSettings ? toUpdateSettings
: { ...activeModelParams, ...toUpdateSettings } : {
...selectedModel?.parameters,
...selectedModel?.settings,
...activeModelParams,
...toUpdateSettings,
}
// update the state // update the state
setThreadModelParams(thread.id, updatedModelParams) setThreadModelParams(thread.id, updatedModelParams)
const runtimeParams = toRuntimeParams(updatedModelParams) const runtimeParams = extractInferenceParams(updatedModelParams)
const settingParams = toSettingParams(updatedModelParams) const settingParams = extractModelLoadParams(updatedModelParams)
const assistants = thread.assistants.map( const assistants = thread.assistants.map(
(assistant: ThreadAssistantInfo) => { (assistant: ThreadAssistantInfo) => {
@ -75,50 +73,8 @@ export default function useUpdateModelParameters() {
await extensionManager await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread(updatedThread) ?.saveThread(updatedThread)
// Persists default settings to model file
// Do not overwrite ctx_len and max_tokens
if (preserveModelFeatureEnabled) {
const defaultContextLength = settingParams.ctx_len
const defaultMaxTokens = runtimeParams.max_tokens
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars
const { ctx_len, ...toSaveSettings } = settingParams
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars
const { max_tokens, ...toSaveParams } = runtimeParams
const updatedModel = {
id: settings.modelId ?? selectedModel?.id,
parameters: {
...toSaveSettings,
}, },
settings: { [activeModelParams, selectedModel, setThreadModelParams]
...toSaveParams,
},
metadata: {
default_ctx_len: defaultContextLength,
default_max_tokens: defaultMaxTokens,
},
} as Partial<Model>
const model = await extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.updateModelInfo(updatedModel)
if (model) updateDownloadedModel(model)
if (selectedModel?.id === model?.id) setSelectedModel(model)
if (recommendedModel?.id === model?.id) setRecommendedModel(model)
}
},
[
activeModelParams,
selectedModel,
setThreadModelParams,
preserveModelFeatureEnabled,
updateDownloadedModel,
setSelectedModel,
recommendedModel,
setRecommendedModel,
]
) )
const processStopWords = (params: ModelParams): ModelParams => { const processStopWords = (params: ModelParams): ModelParams => {

View File

@ -1,6 +1,6 @@
import { useCallback } from 'react' import { useCallback } from 'react'
import { Model } from '@janhq/core' import { ModelFile } from '@janhq/core'
import { Button, Badge, Tooltip } from '@janhq/joi' import { Button, Badge, Tooltip } from '@janhq/joi'
import { useAtomValue, useSetAtom } from 'jotai' import { useAtomValue, useSetAtom } from 'jotai'
@ -38,7 +38,7 @@ import {
} from '@/helpers/atoms/SystemBar.atom' } from '@/helpers/atoms/SystemBar.atom'
type Props = { type Props = {
model: Model model: ModelFile
onClick: () => void onClick: () => void
open: string open: string
} }

View File

@ -1,6 +1,6 @@
import { useState } from 'react' import { useState } from 'react'
import { Model } from '@janhq/core' import { ModelFile } from '@janhq/core'
import { Badge } from '@janhq/joi' import { Badge } from '@janhq/joi'
import { twMerge } from 'tailwind-merge' import { twMerge } from 'tailwind-merge'
@ -12,7 +12,7 @@ import ModelItemHeader from '@/screens/Hub/ModelList/ModelHeader'
import { toGibibytes } from '@/utils/converter' import { toGibibytes } from '@/utils/converter'
type Props = { type Props = {
model: Model model: ModelFile
} }
const ModelItem: React.FC<Props> = ({ model }) => { const ModelItem: React.FC<Props> = ({ model }) => {

View File

@ -1,6 +1,6 @@
import { useMemo } from 'react' import { useMemo } from 'react'
import { Model } from '@janhq/core' import { ModelFile } from '@janhq/core'
import { useAtomValue } from 'jotai' import { useAtomValue } from 'jotai'
@ -9,16 +9,16 @@ import ModelItem from '@/screens/Hub/ModelList/ModelItem'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
type Props = { type Props = {
models: Model[] models: ModelFile[]
} }
const ModelList = ({ models }: Props) => { const ModelList = ({ models }: Props) => {
const downloadedModels = useAtomValue(downloadedModelsAtom) const downloadedModels = useAtomValue(downloadedModelsAtom)
const sortedModels: Model[] = useMemo(() => { const sortedModels: ModelFile[] = useMemo(() => {
const featuredModels: Model[] = [] const featuredModels: ModelFile[] = []
const remoteModels: Model[] = [] const remoteModels: ModelFile[] = []
const localModels: Model[] = [] const localModels: ModelFile[] = []
const remainingModels: Model[] = [] const remainingModels: ModelFile[] = []
models.forEach((m) => { models.forEach((m) => {
if (m.metadata?.tags?.includes('Featured')) { if (m.metadata?.tags?.includes('Featured')) {
featuredModels.push(m) featuredModels.push(m)

View File

@ -14,7 +14,10 @@ import { loadModelErrorAtom } from '@/hooks/useActiveModel'
import { getConfigurationsData } from '@/utils/componentSettings' import { getConfigurationsData } from '@/utils/componentSettings'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' import {
extractInferenceParams,
extractModelLoadParams,
} from '@/utils/modelParam'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
@ -27,16 +30,18 @@ const LocalServerRightPanel = () => {
const selectedModel = useAtomValue(selectedModelAtom) const selectedModel = useAtomValue(selectedModelAtom)
const [currentModelSettingParams, setCurrentModelSettingParams] = useState( const [currentModelSettingParams, setCurrentModelSettingParams] = useState(
toSettingParams(selectedModel?.settings) extractModelLoadParams(selectedModel?.settings)
) )
useEffect(() => { useEffect(() => {
if (selectedModel) { if (selectedModel) {
setCurrentModelSettingParams(toSettingParams(selectedModel?.settings)) setCurrentModelSettingParams(
extractModelLoadParams(selectedModel?.settings)
)
} }
}, [selectedModel]) }, [selectedModel])
const modelRuntimeParams = toRuntimeParams(selectedModel?.settings) const modelRuntimeParams = extractInferenceParams(selectedModel?.settings)
const componentDataRuntimeSetting = getConfigurationsData( const componentDataRuntimeSetting = getConfigurationsData(
modelRuntimeParams, modelRuntimeParams,

View File

@ -100,6 +100,7 @@ const DataFolder = () => {
<div className="flex items-center gap-x-3"> <div className="flex items-center gap-x-3">
<div className="relative"> <div className="relative">
<Input <Input
data-testid="jan-data-folder-input"
value={janDataFolderPath} value={janDataFolderPath}
className="w-full pr-8 sm:w-[240px]" className="w-full pr-8 sm:w-[240px]"
disabled disabled

View File

@ -21,7 +21,11 @@ const FactoryReset = () => {
recommended only if the application is in a corrupted state. recommended only if the application is in a corrupted state.
</p> </p>
</div> </div>
<Button theme="destructive" onClick={() => setModalValidation(true)}> <Button
data-testid="reset-button"
theme="destructive"
onClick={() => setModalValidation(true)}
>
Reset Reset
</Button> </Button>
<ModalValidation /> <ModalValidation />

View File

@ -0,0 +1,154 @@
import React from 'react'
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import '@testing-library/jest-dom'
import Advanced from '.'
class ResizeObserverMock {
observe() {}
unobserve() {}
disconnect() {}
}
global.ResizeObserver = ResizeObserverMock
// @ts-ignore
global.window.core = {
api: {
getAppConfigurations: () => jest.fn(),
updateAppConfiguration: () => jest.fn(),
relaunch: () => jest.fn(),
},
}
const setSettingsMock = jest.fn()
// Mock useSettings hook
jest.mock('@/hooks/useSettings', () => ({
__esModule: true,
useSettings: () => ({
readSettings: () => ({
run_mode: 'gpu',
experimental: false,
proxy: false,
gpus: [{ name: 'gpu-1' }, { name: 'gpu-2' }],
gpus_in_use: ['0'],
quick_ask: false,
}),
setSettings: setSettingsMock,
}),
}))
import * as toast from '@/containers/Toast'
jest.mock('@/containers/Toast')
jest.mock('@janhq/core', () => ({
__esModule: true,
...jest.requireActual('@janhq/core'),
fs: {
rm: jest.fn(),
},
}))
// Simulate a full advanced settings screen
// @ts-ignore
global.isMac = false
// @ts-ignore
global.isWindows = true
describe('Advanced', () => {
it('renders the component', async () => {
render(<Advanced />)
await waitFor(() => {
expect(screen.getByText('Experimental Mode')).toBeInTheDocument()
expect(screen.getByText('HTTPS Proxy')).toBeInTheDocument()
expect(screen.getByText('Ignore SSL certificates')).toBeInTheDocument()
expect(screen.getByText('Jan Data Folder')).toBeInTheDocument()
expect(screen.getByText('Reset to Factory Settings')).toBeInTheDocument()
})
})
it('updates Experimental enabled', async () => {
render(<Advanced />)
let experimentalToggle
await waitFor(() => {
experimentalToggle = screen.getByTestId(/experimental-switch/i)
fireEvent.click(experimentalToggle!)
})
expect(experimentalToggle).toBeChecked()
})
it('updates Experimental disabled', async () => {
render(<Advanced />)
let experimentalToggle
await waitFor(() => {
experimentalToggle = screen.getByTestId(/experimental-switch/i)
fireEvent.click(experimentalToggle!)
})
expect(experimentalToggle).not.toBeChecked()
})
it('clears logs', async () => {
const jestMock = jest.fn()
jest.spyOn(toast, 'toaster').mockImplementation(jestMock)
render(<Advanced />)
let clearLogsButton
await waitFor(() => {
clearLogsButton = screen.getByTestId(/clear-logs/i)
fireEvent.click(clearLogsButton)
})
expect(clearLogsButton).toBeInTheDocument()
expect(jestMock).toHaveBeenCalled()
})
it('toggles proxy enabled', async () => {
render(<Advanced />)
let proxyToggle
await waitFor(() => {
expect(screen.getByText('HTTPS Proxy')).toBeInTheDocument()
proxyToggle = screen.getByTestId(/proxy-switch/i)
fireEvent.click(proxyToggle)
})
expect(proxyToggle).toBeChecked()
})
it('updates proxy settings', async () => {
render(<Advanced />)
let proxyInput
await waitFor(() => {
const proxyToggle = screen.getByTestId(/proxy-switch/i)
fireEvent.click(proxyToggle)
proxyInput = screen.getByTestId(/proxy-input/i)
fireEvent.change(proxyInput, { target: { value: 'http://proxy.com' } })
})
expect(proxyInput).toHaveValue('http://proxy.com')
})
it('toggles ignore SSL certificates', async () => {
render(<Advanced />)
let ignoreSslToggle
await waitFor(() => {
expect(screen.getByText('Ignore SSL certificates')).toBeInTheDocument()
ignoreSslToggle = screen.getByTestId(/ignore-ssl-switch/i)
fireEvent.click(ignoreSslToggle)
})
expect(ignoreSslToggle).toBeChecked()
})
it('renders DataFolder component', async () => {
render(<Advanced />)
await waitFor(() => {
expect(screen.getByText('Jan Data Folder')).toBeInTheDocument()
expect(screen.getByTestId(/jan-data-folder-input/i)).toBeInTheDocument()
})
})
it('renders FactoryReset component', async () => {
render(<Advanced />)
await waitFor(() => {
expect(screen.getByText('Reset to Factory Settings')).toBeInTheDocument()
expect(screen.getByTestId(/reset-button/i)).toBeInTheDocument()
})
})
})

View File

@ -43,19 +43,10 @@ type GPU = {
name: string name: string
} }
const test = [ /**
{ * Advanced Settings Screen
id: 'test a', * @returns
vram: 2, */
name: 'nvidia A',
},
{
id: 'test',
vram: 2,
name: 'nvidia B',
},
]
const Advanced = () => { const Advanced = () => {
const [experimentalEnabled, setExperimentalEnabled] = useAtom( const [experimentalEnabled, setExperimentalEnabled] = useAtom(
experimentalFeatureEnabledAtom experimentalFeatureEnabledAtom
@ -69,7 +60,7 @@ const Advanced = () => {
const [partialProxy, setPartialProxy] = useState<string>(proxy) const [partialProxy, setPartialProxy] = useState<string>(proxy)
const [gpuEnabled, setGpuEnabled] = useState<boolean>(false) const [gpuEnabled, setGpuEnabled] = useState<boolean>(false)
const [gpuList, setGpuList] = useState<GPU[]>(test) const [gpuList, setGpuList] = useState<GPU[]>([])
const [gpusInUse, setGpusInUse] = useState<string[]>([]) const [gpusInUse, setGpusInUse] = useState<string[]>([])
const [dropdownOptions, setDropdownOptions] = useState<HTMLDivElement | null>( const [dropdownOptions, setDropdownOptions] = useState<HTMLDivElement | null>(
null null
@ -87,6 +78,9 @@ const Advanced = () => {
return y['name'] return y['name']
}) })
/**
* Handle proxy change
*/
const onProxyChange = useCallback( const onProxyChange = useCallback(
(event: ChangeEvent<HTMLInputElement>) => { (event: ChangeEvent<HTMLInputElement>) => {
const value = event.target.value || '' const value = event.target.value || ''
@ -100,6 +94,12 @@ const Advanced = () => {
[setPartialProxy, setProxy] [setPartialProxy, setProxy]
) )
/**
* Update Quick Ask Enabled
* @param e
* @param relaunch
* @returns void
*/
const updateQuickAskEnabled = async ( const updateQuickAskEnabled = async (
e: boolean, e: boolean,
relaunch: boolean = true relaunch: boolean = true
@ -111,6 +111,12 @@ const Advanced = () => {
if (relaunch) window.core?.api?.relaunch() if (relaunch) window.core?.api?.relaunch()
} }
/**
* Update Vulkan Enabled
* @param e
* @param relaunch
* @returns void
*/
const updateVulkanEnabled = async (e: boolean, relaunch: boolean = true) => { const updateVulkanEnabled = async (e: boolean, relaunch: boolean = true) => {
toaster({ toaster({
title: 'Reload', title: 'Reload',
@ -123,11 +129,19 @@ const Advanced = () => {
if (relaunch) window.location.reload() if (relaunch) window.location.reload()
} }
/**
* Update Experimental Enabled
* @param e
* @returns
*/
const updateExperimentalEnabled = async ( const updateExperimentalEnabled = async (
e: ChangeEvent<HTMLInputElement> e: ChangeEvent<HTMLInputElement>
) => { ) => {
setExperimentalEnabled(e.target.checked) setExperimentalEnabled(e.target.checked)
if (e) return
// If it checked, we don't need to do anything else
// Otherwise have to reset other settings
if (e.target.checked) return
// It affects other settings, so we need to reset them // It affects other settings, so we need to reset them
const isRelaunch = quickAskEnabled || vulkanEnabled const isRelaunch = quickAskEnabled || vulkanEnabled
@ -136,6 +150,9 @@ const Advanced = () => {
if (isRelaunch) window.core?.api?.relaunch() if (isRelaunch) window.core?.api?.relaunch()
} }
/**
* useEffect to set GPU enabled if possible
*/
useEffect(() => { useEffect(() => {
const setUseGpuIfPossible = async () => { const setUseGpuIfPossible = async () => {
const settings = await readSettings() const settings = await readSettings()
@ -149,6 +166,10 @@ const Advanced = () => {
setUseGpuIfPossible() setUseGpuIfPossible()
}, [readSettings, setGpuList, setGpuEnabled, setGpusInUse, setVulkanEnabled]) }, [readSettings, setGpuList, setGpuEnabled, setGpusInUse, setVulkanEnabled])
/**
* Clear logs
* @returns
*/
const clearLogs = async () => { const clearLogs = async () => {
try { try {
await fs.rm(`file://logs`) await fs.rm(`file://logs`)
@ -163,6 +184,11 @@ const Advanced = () => {
}) })
} }
/**
* Handle GPU Change
* @param gpuId
* @returns
*/
const handleGPUChange = (gpuId: string) => { const handleGPUChange = (gpuId: string) => {
let updatedGpusInUse = [...gpusInUse] let updatedGpusInUse = [...gpusInUse]
if (updatedGpusInUse.includes(gpuId)) { if (updatedGpusInUse.includes(gpuId)) {
@ -188,6 +214,9 @@ const Advanced = () => {
const gpuSelectionPlaceHolder = const gpuSelectionPlaceHolder =
gpuList.length > 0 ? 'Select GPU' : "You don't have any compatible GPU" gpuList.length > 0 ? 'Select GPU' : "You don't have any compatible GPU"
/**
* Handle click outside
*/
useClickOutside(() => setOpen(false), null, [dropdownOptions, toggle]) useClickOutside(() => setOpen(false), null, [dropdownOptions, toggle])
return ( return (
@ -204,6 +233,7 @@ const Advanced = () => {
</p> </p>
</div> </div>
<Switch <Switch
data-testid="experimental-switch"
checked={experimentalEnabled} checked={experimentalEnabled}
onChange={updateExperimentalEnabled} onChange={updateExperimentalEnabled}
/> />
@ -401,11 +431,13 @@ const Advanced = () => {
<div className="flex w-full flex-shrink-0 flex-col items-end gap-2 pr-1 sm:w-1/2"> <div className="flex w-full flex-shrink-0 flex-col items-end gap-2 pr-1 sm:w-1/2">
<Switch <Switch
data-testid="proxy-switch"
checked={proxyEnabled} checked={proxyEnabled}
onChange={() => setProxyEnabled(!proxyEnabled)} onChange={() => setProxyEnabled(!proxyEnabled)}
/> />
<div className="w-full"> <div className="w-full">
<Input <Input
data-testid="proxy-input"
placeholder={'http://<user>:<password>@<domain or IP>:<port>'} placeholder={'http://<user>:<password>@<domain or IP>:<port>'}
value={partialProxy} value={partialProxy}
onChange={onProxyChange} onChange={onProxyChange}
@ -428,6 +460,7 @@ const Advanced = () => {
</p> </p>
</div> </div>
<Switch <Switch
data-testid="ignore-ssl-switch"
checked={ignoreSSL} checked={ignoreSSL}
onChange={(e) => setIgnoreSSL(e.target.checked)} onChange={(e) => setIgnoreSSL(e.target.checked)}
/> />
@ -448,6 +481,7 @@ const Advanced = () => {
</p> </p>
</div> </div>
<Switch <Switch
data-testid="quick-ask-switch"
checked={quickAskEnabled} checked={quickAskEnabled}
onChange={() => { onChange={() => {
toaster({ toaster({
@ -471,7 +505,11 @@ const Advanced = () => {
Clear all logs from Jan app. Clear all logs from Jan app.
</p> </p>
</div> </div>
<Button theme="destructive" onClick={clearLogs}> <Button
data-testid="clear-logs"
theme="destructive"
onClick={clearLogs}
>
Clear Clear
</Button> </Button>
</div> </div>

View File

@ -53,7 +53,7 @@ const ModelDownloadRow: React.FC<Props> = ({
const { requestCreateNewThread } = useCreateNewThread() const { requestCreateNewThread } = useCreateNewThread()
const setMainViewState = useSetAtom(mainViewStateAtom) const setMainViewState = useSetAtom(mainViewStateAtom)
const assistants = useAtomValue(assistantsAtom) const assistants = useAtomValue(assistantsAtom)
const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null const downloadedModel = downloadedModels.find((md) => md.id === fileName)
const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom) const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom)
const defaultModel = useAtomValue(defaultModelAtom) const defaultModel = useAtomValue(defaultModelAtom)
@ -100,12 +100,12 @@ const ModelDownloadRow: React.FC<Props> = ({
alert('No assistant available') alert('No assistant available')
return return
} }
await requestCreateNewThread(assistants[0], model) await requestCreateNewThread(assistants[0], downloadedModel)
setMainViewState(MainViewState.Thread) setMainViewState(MainViewState.Thread)
setHfImportingStage('NONE') setHfImportingStage('NONE')
}, [ }, [
assistants, assistants,
model, downloadedModel,
requestCreateNewThread, requestCreateNewThread,
setMainViewState, setMainViewState,
setHfImportingStage, setHfImportingStage,
@ -139,7 +139,7 @@ const ModelDownloadRow: React.FC<Props> = ({
</Badge> </Badge>
</div> </div>
{isDownloaded ? ( {downloadedModel ? (
<Button <Button
variant="soft" variant="soft"
className="min-w-[98px]" className="min-w-[98px]"

View File

@ -1,6 +1,6 @@
import { memo, useState } from 'react' import { memo, useState } from 'react'
import { InferenceEngine, Model } from '@janhq/core' import { InferenceEngine, ModelFile } from '@janhq/core'
import { Badge, Button, Tooltip, useClickOutside } from '@janhq/joi' import { Badge, Button, Tooltip, useClickOutside } from '@janhq/joi'
import { useAtom } from 'jotai' import { useAtom } from 'jotai'
import { import {
@ -21,7 +21,7 @@ import { localEngines } from '@/utils/modelEngine'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
type Props = { type Props = {
model: Model model: ModelFile
groupTitle?: string groupTitle?: string
} }

View File

@ -58,10 +58,22 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
const configuredModels = useAtomValue(configuredModelsAtom) const configuredModels = useAtomValue(configuredModelsAtom)
const setMainViewState = useSetAtom(mainViewStateAtom) const setMainViewState = useSetAtom(mainViewStateAtom)
const featuredModel = configuredModels.filter( const recommendModel = ['gemma-2-2b-it', 'llama3.1-8b-instruct']
(x) => x.metadata.tags.includes('Featured') && x.metadata.size < 5000000000
const featuredModel = configuredModels.filter((x) => {
const manualRecommendModel = configuredModels.filter((x) =>
recommendModel.includes(x.id)
) )
if (manualRecommendModel.length === 2) {
return x.id === recommendModel[0] || x.id === recommendModel[1]
} else {
return (
x.metadata.tags.includes('Featured') && x.metadata.size < 5000000000
)
}
})
const remoteModel = configuredModels.filter( const remoteModel = configuredModels.filter(
(x) => !localEngines.includes(x.engine) (x) => !localEngines.includes(x.engine)
) )
@ -105,7 +117,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
width={48} width={48}
height={48} height={48}
/> />
<h1 className="text-base font-semibold">Select a model to start</h1> <h1 className="text-base font-medium">Select a model to start</h1>
<div className="mt-6 w-[320px] md:w-[400px]"> <div className="mt-6 w-[320px] md:w-[400px]">
<Fragment> <Fragment>
<div className="relative" ref={refDropdown}> <div className="relative" ref={refDropdown}>
@ -120,7 +132,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
/> />
<div <div
className={twMerge( className={twMerge(
'absolute left-0 top-10 max-h-[240px] w-full overflow-x-auto rounded-lg border border-[hsla(var(--app-border))] bg-[hsla(var(--app-bg))]', 'absolute left-0 top-10 z-20 max-h-[240px] w-full overflow-x-auto rounded-lg border border-[hsla(var(--app-border))] bg-[hsla(var(--app-bg))]',
!isOpen ? 'invisible' : 'visible' !isOpen ? 'invisible' : 'visible'
)} )}
> >
@ -205,18 +217,20 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
return ( return (
<div <div
key={featModel.id} key={featModel.id}
className="my-2 flex items-center justify-between gap-2 border-b border-[hsla(var(--app-border))] py-4 last:border-none" className="my-2 flex items-center justify-between gap-2 border-b border-[hsla(var(--app-border))] pb-4 pt-1 last:border-none"
> >
<div className="w-full text-left"> <div className="w-full text-left">
<h6>{featModel.name}</h6> <h6 className="font-medium">{featModel.name}</h6>
<p className="mt-4 text-[hsla(var(--text-secondary))]"> <p className="mt-2 font-medium text-[hsla(var(--text-secondary))]">
{featModel.metadata.author} {featModel.metadata.author}
</p> </p>
</div> </div>
{isDownloading ? ( {isDownloading ? (
<div className="flex w-full items-center gap-2"> <div className="flex w-full items-center gap-2">
{Object.values(downloadStates).map((item, i) => ( {Object.values(downloadStates)
.filter((x) => x.modelId === featModel.id)
.map((item, i) => (
<div <div
className="flex w-full items-center gap-2" className="flex w-full items-center gap-2"
key={i} key={i}
@ -248,7 +262,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
> >
Download Download
</Button> </Button>
<span className="font-medium text-[hsla(var(--text-secondary))]"> <span className="text-[hsla(var(--text-secondary))]">
{toGibibytes(featModel.metadata.size)} {toGibibytes(featModel.metadata.size)}
</span> </span>
</div> </div>
@ -257,7 +271,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
) )
})} })}
<div className="mb-4 mt-8 flex items-center justify-between"> <div className="mb-2 mt-8 flex items-center justify-between">
<h2 className="text-[hsla(var(--text-secondary))]"> <h2 className="text-[hsla(var(--text-secondary))]">
Cloud Models Cloud Models
</h2> </h2>
@ -268,7 +282,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
return ( return (
<div <div
key={rowIndex} key={rowIndex}
className="my-2 flex items-center justify-center gap-4 md:gap-10" className="my-2 flex items-center gap-4 md:gap-10"
> >
{row.map((remoteEngine) => { {row.map((remoteEngine) => {
const engineLogo = getLogoEngine( const engineLogo = getLogoEngine(
@ -298,7 +312,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
/> />
)} )}
<p> <p className="font-medium">
{getTitleByEngine( {getTitleByEngine(
remoteEngine as InferenceEngine remoteEngine as InferenceEngine
)} )}

View File

@ -29,7 +29,10 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
import { getConfigurationsData } from '@/utils/componentSettings' import { getConfigurationsData } from '@/utils/componentSettings'
import { localEngines } from '@/utils/modelEngine' import { localEngines } from '@/utils/modelEngine'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' import {
extractInferenceParams,
extractModelLoadParams,
} from '@/utils/modelParam'
import PromptTemplateSetting from './PromptTemplateSetting' import PromptTemplateSetting from './PromptTemplateSetting'
import Tools from './Tools' import Tools from './Tools'
@ -68,14 +71,26 @@ const ThreadRightPanel = () => {
const settings = useMemo(() => { const settings = useMemo(() => {
// runtime setting // runtime setting
const modelRuntimeParams = toRuntimeParams(activeModelParams) const modelRuntimeParams = extractInferenceParams(
{
...selectedModel?.parameters,
...activeModelParams,
},
selectedModel?.parameters
)
const componentDataRuntimeSetting = getConfigurationsData( const componentDataRuntimeSetting = getConfigurationsData(
modelRuntimeParams, modelRuntimeParams,
selectedModel selectedModel
).filter((x) => x.key !== 'prompt_template') ).filter((x) => x.key !== 'prompt_template')
// engine setting // engine setting
const modelEngineParams = toSettingParams(activeModelParams) const modelEngineParams = extractModelLoadParams(
{
...selectedModel?.settings,
...activeModelParams,
},
selectedModel?.settings
)
const componentDataEngineSetting = getConfigurationsData( const componentDataEngineSetting = getConfigurationsData(
modelEngineParams, modelEngineParams,
selectedModel selectedModel
@ -126,7 +141,10 @@ const ThreadRightPanel = () => {
}, [activeModelParams, selectedModel]) }, [activeModelParams, selectedModel])
const promptTemplateSettings = useMemo(() => { const promptTemplateSettings = useMemo(() => {
const modelEngineParams = toSettingParams(activeModelParams) const modelEngineParams = extractModelLoadParams({
...selectedModel?.settings,
...activeModelParams,
})
const componentDataEngineSetting = getConfigurationsData( const componentDataEngineSetting = getConfigurationsData(
modelEngineParams, modelEngineParams,
selectedModel selectedModel

View File

@ -0,0 +1,35 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import ThreadScreen from './index'
import { useStarterScreen } from '../../hooks/useStarterScreen'
import '@testing-library/jest-dom'
global.ResizeObserver = class {
observe() {}
unobserve() {}
disconnect() {}
}
// Mock the useStarterScreen hook
jest.mock('@/hooks/useStarterScreen')
describe('ThreadScreen', () => {
it('renders OnDeviceStarterScreen when isShowStarterScreen is true', () => {
;(useStarterScreen as jest.Mock).mockReturnValue({
isShowStarterScreen: true,
extensionHasSettings: false,
})
const { getByText } = render(<ThreadScreen />)
expect(getByText('Select a model to start')).toBeInTheDocument()
})
it('renders Thread panels when isShowStarterScreen is false', () => {
;(useStarterScreen as jest.Mock).mockReturnValue({
isShowStarterScreen: false,
extensionHasSettings: false,
})
const { getByText } = render(<ThreadScreen />)
expect(getByText('Welcome!')).toBeInTheDocument()
})
})

View File

@ -0,0 +1,183 @@
// web/utils/modelParam.test.ts
import { normalizeValue, validationRules } from './modelParam'
describe('validationRules', () => {
it('should validate temperature correctly', () => {
expect(validationRules.temperature(0.5)).toBe(true)
expect(validationRules.temperature(2)).toBe(true)
expect(validationRules.temperature(0)).toBe(true)
expect(validationRules.temperature(-0.1)).toBe(false)
expect(validationRules.temperature(2.3)).toBe(false)
expect(validationRules.temperature('0.5')).toBe(false)
})
it('should validate token_limit correctly', () => {
expect(validationRules.token_limit(100)).toBe(true)
expect(validationRules.token_limit(1)).toBe(true)
expect(validationRules.token_limit(0)).toBe(true)
expect(validationRules.token_limit(-1)).toBe(false)
expect(validationRules.token_limit('100')).toBe(false)
})
it('should validate top_k correctly', () => {
expect(validationRules.top_k(0.5)).toBe(true)
expect(validationRules.top_k(1)).toBe(true)
expect(validationRules.top_k(0)).toBe(true)
expect(validationRules.top_k(-0.1)).toBe(false)
expect(validationRules.top_k(1.1)).toBe(false)
expect(validationRules.top_k('0.5')).toBe(false)
})
it('should validate top_p correctly', () => {
expect(validationRules.top_p(0.5)).toBe(true)
expect(validationRules.top_p(1)).toBe(true)
expect(validationRules.top_p(0)).toBe(true)
expect(validationRules.top_p(-0.1)).toBe(false)
expect(validationRules.top_p(1.1)).toBe(false)
expect(validationRules.top_p('0.5')).toBe(false)
})
it('should validate stream correctly', () => {
expect(validationRules.stream(true)).toBe(true)
expect(validationRules.stream(false)).toBe(true)
expect(validationRules.stream('true')).toBe(false)
expect(validationRules.stream(1)).toBe(false)
})
it('should validate max_tokens correctly', () => {
expect(validationRules.max_tokens(100)).toBe(true)
expect(validationRules.max_tokens(1)).toBe(true)
expect(validationRules.max_tokens(0)).toBe(true)
expect(validationRules.max_tokens(-1)).toBe(false)
expect(validationRules.max_tokens('100')).toBe(false)
})
it('should validate stop correctly', () => {
expect(validationRules.stop(['word1', 'word2'])).toBe(true)
expect(validationRules.stop([])).toBe(true)
expect(validationRules.stop(['word1', 2])).toBe(false)
expect(validationRules.stop('word1')).toBe(false)
})
it('should validate frequency_penalty correctly', () => {
expect(validationRules.frequency_penalty(0.5)).toBe(true)
expect(validationRules.frequency_penalty(1)).toBe(true)
expect(validationRules.frequency_penalty(0)).toBe(true)
expect(validationRules.frequency_penalty(-0.1)).toBe(false)
expect(validationRules.frequency_penalty(1.1)).toBe(false)
expect(validationRules.frequency_penalty('0.5')).toBe(false)
})
it('should validate presence_penalty correctly', () => {
expect(validationRules.presence_penalty(0.5)).toBe(true)
expect(validationRules.presence_penalty(1)).toBe(true)
expect(validationRules.presence_penalty(0)).toBe(true)
expect(validationRules.presence_penalty(-0.1)).toBe(false)
expect(validationRules.presence_penalty(1.1)).toBe(false)
expect(validationRules.presence_penalty('0.5')).toBe(false)
})
it('should validate ctx_len correctly', () => {
expect(validationRules.ctx_len(1024)).toBe(true)
expect(validationRules.ctx_len(1)).toBe(true)
expect(validationRules.ctx_len(0)).toBe(true)
expect(validationRules.ctx_len(-1)).toBe(false)
expect(validationRules.ctx_len('1024')).toBe(false)
})
it('should validate ngl correctly', () => {
expect(validationRules.ngl(12)).toBe(true)
expect(validationRules.ngl(1)).toBe(true)
expect(validationRules.ngl(0)).toBe(true)
expect(validationRules.ngl(-1)).toBe(false)
expect(validationRules.ngl('12')).toBe(false)
})
it('should validate embedding correctly', () => {
expect(validationRules.embedding(true)).toBe(true)
expect(validationRules.embedding(false)).toBe(true)
expect(validationRules.embedding('true')).toBe(false)
expect(validationRules.embedding(1)).toBe(false)
})
it('should validate n_parallel correctly', () => {
expect(validationRules.n_parallel(2)).toBe(true)
expect(validationRules.n_parallel(1)).toBe(true)
expect(validationRules.n_parallel(0)).toBe(true)
expect(validationRules.n_parallel(-1)).toBe(false)
expect(validationRules.n_parallel('2')).toBe(false)
})
it('should validate cpu_threads correctly', () => {
expect(validationRules.cpu_threads(4)).toBe(true)
expect(validationRules.cpu_threads(1)).toBe(true)
expect(validationRules.cpu_threads(0)).toBe(true)
expect(validationRules.cpu_threads(-1)).toBe(false)
expect(validationRules.cpu_threads('4')).toBe(false)
})
it('should validate prompt_template correctly', () => {
expect(validationRules.prompt_template('template')).toBe(true)
expect(validationRules.prompt_template('')).toBe(true)
expect(validationRules.prompt_template(123)).toBe(false)
})
it('should validate llama_model_path correctly', () => {
expect(validationRules.llama_model_path('path')).toBe(true)
expect(validationRules.llama_model_path('')).toBe(true)
expect(validationRules.llama_model_path(123)).toBe(false)
})
it('should validate mmproj correctly', () => {
expect(validationRules.mmproj('mmproj')).toBe(true)
expect(validationRules.mmproj('')).toBe(true)
expect(validationRules.mmproj(123)).toBe(false)
})
it('should validate vision_model correctly', () => {
expect(validationRules.vision_model(true)).toBe(true)
expect(validationRules.vision_model(false)).toBe(true)
expect(validationRules.vision_model('true')).toBe(false)
expect(validationRules.vision_model(1)).toBe(false)
})
it('should validate text_model correctly', () => {
expect(validationRules.text_model(true)).toBe(true)
expect(validationRules.text_model(false)).toBe(true)
expect(validationRules.text_model('true')).toBe(false)
expect(validationRules.text_model(1)).toBe(false)
})
})
describe('normalizeValue', () => {
it('should normalize ctx_len correctly', () => {
expect(normalizeValue('ctx_len', 100.5)).toBe(100)
expect(normalizeValue('ctx_len', '2')).toBe(2)
expect(normalizeValue('ctx_len', 100)).toBe(100)
})
it('should normalize token_limit correctly', () => {
expect(normalizeValue('token_limit', 100.5)).toBe(100)
expect(normalizeValue('token_limit', '1')).toBe(1)
expect(normalizeValue('token_limit', 0)).toBe(0)
})
it('should normalize max_tokens correctly', () => {
expect(normalizeValue('max_tokens', 100.5)).toBe(100)
expect(normalizeValue('max_tokens', '1')).toBe(1)
expect(normalizeValue('max_tokens', 0)).toBe(0)
})
it('should normalize ngl correctly', () => {
expect(normalizeValue('ngl', 12.5)).toBe(12)
expect(normalizeValue('ngl', '2')).toBe(2)
expect(normalizeValue('ngl', 0)).toBe(0)
})
it('should normalize n_parallel correctly', () => {
expect(normalizeValue('n_parallel', 2.5)).toBe(2)
expect(normalizeValue('n_parallel', '2')).toBe(2)
expect(normalizeValue('n_parallel', 0)).toBe(0)
})
it('should normalize cpu_threads correctly', () => {
expect(normalizeValue('cpu_threads', 4.5)).toBe(4)
expect(normalizeValue('cpu_threads', '4')).toBe(4)
expect(normalizeValue('cpu_threads', 0)).toBe(0)
})
})

View File

@ -1,9 +1,69 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/naming-convention */
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core' import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
import { ModelParams } from '@/helpers/atoms/Thread.atom' import { ModelParams } from '@/helpers/atoms/Thread.atom'
export const toRuntimeParams = ( /**
modelParams?: ModelParams * Validation rules for model parameters
*/
export const validationRules: { [key: string]: (value: any) => boolean } = {
temperature: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 2,
token_limit: (value: any) => Number.isInteger(value) && value >= 0,
top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
stream: (value: any) => typeof value === 'boolean',
max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
stop: (value: any) =>
Array.isArray(value) && value.every((v) => typeof v === 'string'),
frequency_penalty: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 1,
presence_penalty: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 1,
ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
ngl: (value: any) => Number.isInteger(value) && value >= 0,
embedding: (value: any) => typeof value === 'boolean',
n_parallel: (value: any) => Number.isInteger(value) && value >= 0,
cpu_threads: (value: any) => Number.isInteger(value) && value >= 0,
prompt_template: (value: any) => typeof value === 'string',
llama_model_path: (value: any) => typeof value === 'string',
mmproj: (value: any) => typeof value === 'string',
vision_model: (value: any) => typeof value === 'boolean',
text_model: (value: any) => typeof value === 'boolean',
}
/**
* There are some parameters that need to be normalized before being sent to the server
* E.g. ctx_len should be an integer, but it can be a float from the input field
* @param key
* @param value
* @returns
*/
export const normalizeValue = (key: string, value: any) => {
if (
key === 'token_limit' ||
key === 'max_tokens' ||
key === 'ctx_len' ||
key === 'ngl' ||
key === 'n_parallel' ||
key === 'cpu_threads'
) {
// Convert to integer
return Math.floor(Number(value))
}
return value
}
/**
* Extract inference parameters from flat model parameters
* @param modelParams
* @returns
*/
export const extractInferenceParams = (
modelParams?: ModelParams,
originParams?: ModelParams
): ModelRuntimeParams => { ): ModelRuntimeParams => {
if (!modelParams) return {} if (!modelParams) return {}
const defaultModelParams: ModelRuntimeParams = { const defaultModelParams: ModelRuntimeParams = {
@ -22,15 +82,35 @@ export const toRuntimeParams = (
for (const [key, value] of Object.entries(modelParams)) { for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultModelParams) { if (key in defaultModelParams) {
Object.assign(runtimeParams, { ...runtimeParams, [key]: value }) const validate = validationRules[key]
if (validate && !validate(normalizeValue(key, value))) {
// Invalid value - fall back to origin value
if (originParams && key in originParams) {
Object.assign(runtimeParams, {
...runtimeParams,
[key]: originParams[key as keyof typeof originParams],
})
}
} else {
Object.assign(runtimeParams, {
...runtimeParams,
[key]: normalizeValue(key, value),
})
}
} }
} }
return runtimeParams return runtimeParams
} }
export const toSettingParams = ( /**
modelParams?: ModelParams * Extract model load parameters from flat model parameters
* @param modelParams
* @returns
*/
export const extractModelLoadParams = (
modelParams?: ModelParams,
originParams?: ModelParams
): ModelSettingParams => { ): ModelSettingParams => {
if (!modelParams) return {} if (!modelParams) return {}
const defaultSettingParams: ModelSettingParams = { const defaultSettingParams: ModelSettingParams = {
@ -49,7 +129,21 @@ export const toSettingParams = (
for (const [key, value] of Object.entries(modelParams)) { for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultSettingParams) { if (key in defaultSettingParams) {
Object.assign(settingParams, { ...settingParams, [key]: value }) const validate = validationRules[key]
if (validate && !validate(normalizeValue(key, value))) {
// Invalid value - fall back to origin value
if (originParams && key in originParams) {
Object.assign(modelParams, {
...modelParams,
[key]: originParams[key as keyof typeof originParams],
})
}
} else {
Object.assign(settingParams, {
...settingParams,
[key]: normalizeValue(key, value),
})
}
} }
} }