diff --git a/.github/workflows/auto-trigger-jan-docs.yaml b/.github/workflows/auto-trigger-jan-docs.yaml new file mode 100644 index 000000000..a3001a9e0 --- /dev/null +++ b/.github/workflows/auto-trigger-jan-docs.yaml @@ -0,0 +1,25 @@ +name: Trigger Docs Workflow + +on: + release: + types: + - published + workflow_dispatch: + push: + branches: + - ci/auto-trigger-jan-docs-for-new-release + +jobs: + trigger_docs_workflow: + runs-on: ubuntu-latest + + steps: + - name: Trigger external workflow using GitHub API + env: + GITHUB_TOKEN: ${{ secrets.PAT_SERVICE_ACCOUNT }} + run: | + curl -X POST \ + -H "Accept: application/vnd.github.v3+json" \ + -H "Authorization: token $GITHUB_TOKEN" \ + https://api.github.com/repos/janhq/docs/actions/workflows/jan-docs.yml/dispatches \ + -d '{"ref":"main"}' diff --git a/.gitignore b/.gitignore index 646e6842a..f28d152d9 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,5 @@ core/test_results.html coverage .yarn .yarnrc +test_results.html +*.tsbuildinfo diff --git a/core/jest.config.js b/core/jest.config.js index 6c805f1c9..9b1dd2ade 100644 --- a/core/jest.config.js +++ b/core/jest.config.js @@ -1,8 +1,17 @@ module.exports = { preset: 'ts-jest', testEnvironment: 'node', + collectCoverageFrom: ['src/**/*.{ts,tsx}'], moduleNameMapper: { '@/(.*)': '/src/$1', }, runner: './testRunner.js', + transform: { + "^.+\\.tsx?$": [ + "ts-jest", + { + diagnostics: false, + }, + ], + }, } diff --git a/core/src/browser/core.test.ts b/core/src/browser/core.test.ts index 84250888e..f38cc0b40 100644 --- a/core/src/browser/core.test.ts +++ b/core/src/browser/core.test.ts @@ -1,98 +1,109 @@ -import { openExternalUrl } from './core'; -import { joinPath } from './core'; -import { openFileExplorer } from './core'; -import { getJanDataFolderPath } from './core'; -import { abortDownload } from './core'; -import { getFileSize } from './core'; -import { executeOnMain } from './core'; +import { openExternalUrl } from './core' +import { joinPath } from './core' +import { openFileExplorer } from './core' +import { getJanDataFolderPath } from './core' +import { abortDownload } from './core' +import { getFileSize } from './core' +import { executeOnMain } from './core' -it('should open external url', async () => { - const url = 'http://example.com'; - globalThis.core = { - api: { - openExternalUrl: jest.fn().mockResolvedValue('opened') +describe('test core apis', () => { + it('should open external url', async () => { + const url = 'http://example.com' + globalThis.core = { + api: { + openExternalUrl: jest.fn().mockResolvedValue('opened'), + }, } - }; - const result = await openExternalUrl(url); - expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url); - expect(result).toBe('opened'); -}); + const result = await openExternalUrl(url) + expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url) + expect(result).toBe('opened') + }) - -it('should join paths', async () => { - const paths = ['/path/one', '/path/two']; - globalThis.core = { - api: { - joinPath: jest.fn().mockResolvedValue('/path/one/path/two') + it('should join paths', async () => { + const paths = ['/path/one', '/path/two'] + globalThis.core = { + api: { + joinPath: jest.fn().mockResolvedValue('/path/one/path/two'), + }, } - }; - const result = await joinPath(paths); - expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths); - expect(result).toBe('/path/one/path/two'); -}); + const result = await joinPath(paths) + expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths) + 
expect(result).toBe('/path/one/path/two') + }) - -it('should open file explorer', async () => { - const path = '/path/to/open'; - globalThis.core = { - api: { - openFileExplorer: jest.fn().mockResolvedValue('opened') + it('should open file explorer', async () => { + const path = '/path/to/open' + globalThis.core = { + api: { + openFileExplorer: jest.fn().mockResolvedValue('opened'), + }, } - }; - const result = await openFileExplorer(path); - expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path); - expect(result).toBe('opened'); -}); + const result = await openFileExplorer(path) + expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path) + expect(result).toBe('opened') + }) - -it('should get jan data folder path', async () => { - globalThis.core = { - api: { - getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data') + it('should get jan data folder path', async () => { + globalThis.core = { + api: { + getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data'), + }, } - }; - const result = await getJanDataFolderPath(); - expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled(); - expect(result).toBe('/path/to/jan/data'); -}); + const result = await getJanDataFolderPath() + expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled() + expect(result).toBe('/path/to/jan/data') + }) - -it('should abort download', async () => { - const fileName = 'testFile'; - globalThis.core = { - api: { - abortDownload: jest.fn().mockResolvedValue('aborted') + it('should abort download', async () => { + const fileName = 'testFile' + globalThis.core = { + api: { + abortDownload: jest.fn().mockResolvedValue('aborted'), + }, } - }; - const result = await abortDownload(fileName); - expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName); - expect(result).toBe('aborted'); -}); + const result = await abortDownload(fileName) + expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName) + expect(result).toBe('aborted') + }) - -it('should get file size', async () => { - const url = 'http://example.com/file'; - globalThis.core = { - api: { - getFileSize: jest.fn().mockResolvedValue(1024) + it('should get file size', async () => { + const url = 'http://example.com/file' + globalThis.core = { + api: { + getFileSize: jest.fn().mockResolvedValue(1024), + }, } - }; - const result = await getFileSize(url); - expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url); - expect(result).toBe(1024); -}); + const result = await getFileSize(url) + expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url) + expect(result).toBe(1024) + }) - -it('should execute function on main process', async () => { - const extension = 'testExtension'; - const method = 'testMethod'; - const args = ['arg1', 'arg2']; - globalThis.core = { - api: { - invokeExtensionFunc: jest.fn().mockResolvedValue('result') + it('should execute function on main process', async () => { + const extension = 'testExtension' + const method = 'testMethod' + const args = ['arg1', 'arg2'] + globalThis.core = { + api: { + invokeExtensionFunc: jest.fn().mockResolvedValue('result'), + }, } - }; - const result = await executeOnMain(extension, method, ...args); - expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args); - expect(result).toBe('result'); -}); + const result = await executeOnMain(extension, method, ...args) + expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args) + 
expect(result).toBe('result') + }) +}) + +describe('dirName - just a pass thru api', () => { + it('should retrieve the directory name from a file path', async () => { + const mockDirName = jest.fn() + globalThis.core = { + api: { + dirName: mockDirName.mockResolvedValue('/path/to'), + }, + } + // Normal file path with extension + const path = '/path/to/file.txt' + await globalThis.core.api.dirName(path) + expect(mockDirName).toHaveBeenCalledWith(path) + }) +}) diff --git a/core/src/browser/core.ts b/core/src/browser/core.ts index fdbceb06b..b19e0b339 100644 --- a/core/src/browser/core.ts +++ b/core/src/browser/core.ts @@ -68,6 +68,13 @@ const openFileExplorer: (path: string) => Promise = (path) => const joinPath: (paths: string[]) => Promise = (paths) => globalThis.core.api?.joinPath(paths) +/** + * Get dirname of a file path. + * @param path - The file path to retrieve dirname. + * @returns {Promise} A promise that resolves the dirname. + */ +const dirName: (path: string) => Promise = (path) => globalThis.core.api?.dirName(path) + /** * Retrieve the basename from an url. * @param path - The path to retrieve. @@ -161,5 +168,6 @@ export { systemInformation, showToast, getFileSize, + dirName, FileStat, } diff --git a/core/src/browser/extension.test.ts b/core/src/browser/extension.test.ts index 6c1cd8579..2db14a24e 100644 --- a/core/src/browser/extension.test.ts +++ b/core/src/browser/extension.test.ts @@ -1,4 +1,9 @@ import { BaseExtension } from './extension' +import { SettingComponentProps } from '../types' +import { getJanDataFolderPath, joinPath } from './core' +import { fs } from './fs' +jest.mock('./core') +jest.mock('./fs') class TestBaseExtension extends BaseExtension { onLoad(): void {} @@ -44,3 +49,103 @@ describe('BaseExtension', () => { // Add your assertions here }) }) + +describe('BaseExtension', () => { + class TestBaseExtension extends BaseExtension { + onLoad(): void {} + onUnload(): void {} + } + + let baseExtension: TestBaseExtension + + beforeEach(() => { + baseExtension = new TestBaseExtension('https://example.com', 'TestExtension') + }) + + afterEach(() => { + jest.resetAllMocks() + }) + + it('should have the correct properties', () => { + expect(baseExtension.name).toBe('TestExtension') + expect(baseExtension.productName).toBeUndefined() + expect(baseExtension.url).toBe('https://example.com') + expect(baseExtension.active).toBeUndefined() + expect(baseExtension.description).toBeUndefined() + expect(baseExtension.version).toBeUndefined() + }) + + it('should return undefined for type()', () => { + expect(baseExtension.type()).toBeUndefined() + }) + + it('should have abstract methods onLoad() and onUnload()', () => { + expect(baseExtension.onLoad).toBeDefined() + expect(baseExtension.onUnload).toBeDefined() + }) + + it('should have installationState() return "NotRequired"', async () => { + const installationState = await baseExtension.installationState() + expect(installationState).toBe('NotRequired') + }) + + it('should install the extension', async () => { + await baseExtension.install() + // Add your assertions here + }) + + it('should register settings', async () => { + const settings: SettingComponentProps[] = [ + { key: 'setting1', controllerProps: { value: 'value1' } } as any, + { key: 'setting2', controllerProps: { value: 'value2' } } as any, + ] + + ;(getJanDataFolderPath as jest.Mock).mockResolvedValue('/data') + ;(joinPath as jest.Mock).mockResolvedValue('/data/settings/TestExtension') + ;(fs.existsSync as jest.Mock).mockResolvedValue(false) + ;(fs.mkdir as 
jest.Mock).mockResolvedValue(undefined) + ;(fs.writeFileSync as jest.Mock).mockResolvedValue(undefined) + + await baseExtension.registerSettings(settings) + + expect(fs.mkdir).toHaveBeenCalledWith('/data/settings/TestExtension') + expect(fs.writeFileSync).toHaveBeenCalledWith( + '/data/settings/TestExtension', + JSON.stringify(settings, null, 2) + ) + }) + + it('should get setting with default value', async () => { + const settings: SettingComponentProps[] = [ + { key: 'setting1', controllerProps: { value: 'value1' } } as any, + ] + + jest.spyOn(baseExtension, 'getSettings').mockResolvedValue(settings) + + const value = await baseExtension.getSetting('setting1', 'defaultValue') + expect(value).toBe('value1') + + const defaultValue = await baseExtension.getSetting('setting2', 'defaultValue') + expect(defaultValue).toBe('defaultValue') + }) + + it('should update settings', async () => { + const settings: SettingComponentProps[] = [ + { key: 'setting1', controllerProps: { value: 'value1' } } as any, + ] + + jest.spyOn(baseExtension, 'getSettings').mockResolvedValue(settings) + ;(getJanDataFolderPath as jest.Mock).mockResolvedValue('/data') + ;(joinPath as jest.Mock).mockResolvedValue('/data/settings/TestExtension/settings.json') + ;(fs.writeFileSync as jest.Mock).mockResolvedValue(undefined) + + await baseExtension.updateSettings([ + { key: 'setting1', controllerProps: { value: 'newValue' } } as any, + ]) + + expect(fs.writeFileSync).toHaveBeenCalledWith( + '/data/settings/TestExtension/settings.json', + JSON.stringify([{ key: 'setting1', controllerProps: { value: 'newValue' } }], null, 2) + ) + }) +}) diff --git a/core/src/browser/extensions/assistant.test.ts b/core/src/browser/extensions/assistant.test.ts new file mode 100644 index 000000000..ae81b0985 --- /dev/null +++ b/core/src/browser/extensions/assistant.test.ts @@ -0,0 +1,8 @@ + +import { AssistantExtension } from './assistant'; +import { ExtensionTypeEnum } from '../extension'; + +it('should return the correct type', () => { + const extension = new AssistantExtension(); + expect(extension.type()).toBe(ExtensionTypeEnum.Assistant); +}); diff --git a/core/src/browser/extensions/engines/AIEngine.test.ts b/core/src/browser/extensions/engines/AIEngine.test.ts new file mode 100644 index 000000000..59dad280f --- /dev/null +++ b/core/src/browser/extensions/engines/AIEngine.test.ts @@ -0,0 +1,59 @@ +import { AIEngine } from './AIEngine' +import { events } from '../../events' +import { ModelEvent, Model, ModelFile, InferenceEngine } from '../../../types' +import { EngineManager } from './EngineManager' +import { fs } from '../../fs' + +jest.mock('../../events') +jest.mock('./EngineManager') +jest.mock('../../fs') + +class TestAIEngine extends AIEngine { + onUnload(): void {} + provider = 'test-provider' + + inference(data: any) {} + + stopInference() {} +} + +describe('AIEngine', () => { + let engine: TestAIEngine + + beforeEach(() => { + engine = new TestAIEngine('', '') + jest.clearAllMocks() + }) + + it('should load model if provider matches', async () => { + const model: ModelFile = { id: 'model1', engine: 'test-provider' } as any + + await engine.loadModel(model) + + expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelReady, model) + }) + + it('should not load model if provider does not match', async () => { + const model: ModelFile = { id: 'model1', engine: 'other-provider' } as any + + await engine.loadModel(model) + + expect(events.emit).not.toHaveBeenCalledWith(ModelEvent.OnModelReady, model) + }) + + it('should unload model if 
provider matches', async () => { + const model: Model = { id: 'model1', version: '1.0', engine: 'test-provider' } as any + + await engine.unloadModel(model) + + expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelStopped, model) + }) + + it('should not unload model if provider does not match', async () => { + const model: Model = { id: 'model1', version: '1.0', engine: 'other-provider' } as any + + await engine.unloadModel(model) + + expect(events.emit).not.toHaveBeenCalledWith(ModelEvent.OnModelStopped, model) + }) +}) diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 7cd9f513e..75354de88 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -2,7 +2,7 @@ import { getJanDataFolderPath, joinPath } from '../../core' import { events } from '../../events' import { BaseExtension } from '../../extension' import { fs } from '../../fs' -import { MessageRequest, Model, ModelEvent } from '../../../types' +import { MessageRequest, Model, ModelEvent, ModelFile } from '../../../types' import { EngineManager } from './EngineManager' /** @@ -21,7 +21,7 @@ export abstract class AIEngine extends BaseExtension { override onLoad() { this.registerEngine() - events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model)) events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) } @@ -78,7 +78,7 @@ export abstract class AIEngine extends BaseExtension { /** * Loads the model. */ - async loadModel(model: Model): Promise { + async loadModel(model: ModelFile): Promise { if (model.engine.toString() !== this.provider) return Promise.resolve() events.emit(ModelEvent.OnModelReady, model) return Promise.resolve() diff --git a/core/src/browser/extensions/engines/EngineManager.test.ts b/core/src/browser/extensions/engines/EngineManager.test.ts new file mode 100644 index 000000000..c1f1fcb71 --- /dev/null +++ b/core/src/browser/extensions/engines/EngineManager.test.ts @@ -0,0 +1,43 @@ +/** + * @jest-environment jsdom + */ +import { EngineManager } from './EngineManager' +import { AIEngine } from './AIEngine' + +// @ts-ignore +class MockAIEngine implements AIEngine { + provider: string + constructor(provider: string) { + this.provider = provider + } +} + +describe('EngineManager', () => { + let engineManager: EngineManager + + beforeEach(() => { + engineManager = new EngineManager() + }) + + test('should register an engine', () => { + const engine = new MockAIEngine('testProvider') + // @ts-ignore + engineManager.register(engine) + expect(engineManager.engines.get('testProvider')).toBe(engine) + }) + + test('should retrieve a registered engine by provider', () => { + const engine = new MockAIEngine('testProvider') + // @ts-ignore + engineManager.register(engine) + // @ts-ignore + const retrievedEngine = engineManager.get('testProvider') + expect(retrievedEngine).toBe(engine) + }) + + test('should return undefined for an unregistered provider', () => { + // @ts-ignore + const retrievedEngine = engineManager.get('nonExistentProvider') + expect(retrievedEngine).toBeUndefined() + }) +}) diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.test.ts b/core/src/browser/extensions/engines/LocalOAIEngine.test.ts new file mode 100644 index 000000000..4ae81496f --- /dev/null +++ b/core/src/browser/extensions/engines/LocalOAIEngine.test.ts @@ -0,0 +1,100 @@ +/** + * 
@jest-environment jsdom + */ +import { LocalOAIEngine } from './LocalOAIEngine' +import { events } from '../../events' +import { ModelEvent, ModelFile, Model } from '../../../types' +import { executeOnMain, systemInformation, dirName } from '../../core' + +jest.mock('../../core', () => ({ + executeOnMain: jest.fn(), + systemInformation: jest.fn(), + dirName: jest.fn(), +})) + +jest.mock('../../events', () => ({ + events: { + on: jest.fn(), + emit: jest.fn(), + }, +})) + +class TestLocalOAIEngine extends LocalOAIEngine { + inferenceUrl = '' + nodeModule = 'testNodeModule' + provider = 'testProvider' +} + +describe('LocalOAIEngine', () => { + let engine: TestLocalOAIEngine + + beforeEach(() => { + engine = new TestLocalOAIEngine('', '') + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + it('should subscribe to events on load', () => { + engine.onLoad() + expect(events.on).toHaveBeenCalledWith(ModelEvent.OnModelInit, expect.any(Function)) + expect(events.on).toHaveBeenCalledWith(ModelEvent.OnModelStop, expect.any(Function)) + }) + + it('should load model correctly', async () => { + const model: ModelFile = { engine: 'testProvider', file_path: 'path/to/model' } as any + const modelFolder = 'path/to' + const systemInfo = { os: 'testOS' } + const res = { error: null } + + ;(dirName as jest.Mock).mockResolvedValue(modelFolder) + ;(systemInformation as jest.Mock).mockResolvedValue(systemInfo) + ;(executeOnMain as jest.Mock).mockResolvedValue(res) + + await engine.loadModel(model) + + expect(dirName).toHaveBeenCalledWith(model.file_path) + expect(systemInformation).toHaveBeenCalled() + expect(executeOnMain).toHaveBeenCalledWith( + engine.nodeModule, + engine.loadModelFunctionName, + { modelFolder, model }, + systemInfo + ) + expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelReady, model) + }) + + it('should handle load model error', async () => { + const model: ModelFile = { engine: 'testProvider', file_path: 'path/to/model' } as any + const modelFolder = 'path/to' + const systemInfo = { os: 'testOS' } + const res = { error: 'load error' } + + ;(dirName as jest.Mock).mockResolvedValue(modelFolder) + ;(systemInformation as jest.Mock).mockResolvedValue(systemInfo) + ;(executeOnMain as jest.Mock).mockResolvedValue(res) + + await expect(engine.loadModel(model)).rejects.toEqual('load error') + + expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelFail, { error: res.error }) + }) + + it('should unload model correctly', async () => { + const model: Model = { engine: 'testProvider' } as any + + await engine.unloadModel(model) + + expect(executeOnMain).toHaveBeenCalledWith(engine.nodeModule, engine.unloadModelFunctionName) + expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelStopped, {}) + }) + + it('should not unload model if engine does not match', async () => { + const model: Model = { engine: 'otherProvider' } as any + + await engine.unloadModel(model) + + expect(executeOnMain).not.toHaveBeenCalled() + expect(events.emit).not.toHaveBeenCalledWith(ModelEvent.OnModelStopped, {}) + }) +}) diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts index fb9e4962c..123b9a593 100644 --- a/core/src/browser/extensions/engines/LocalOAIEngine.ts +++ b/core/src/browser/extensions/engines/LocalOAIEngine.ts @@ -1,6 +1,6 @@ -import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core' +import { executeOnMain, systemInformation, dirName } from '../../core' import { events } from 
'../../events' -import { Model, ModelEvent } from '../../../types' +import { Model, ModelEvent, ModelFile } from '../../../types' import { OAIEngine } from './OAIEngine' /** @@ -14,22 +14,24 @@ export abstract class LocalOAIEngine extends OAIEngine { unloadModelFunctionName: string = 'unloadModel' /** - * On extension load, subscribe to events. + * This class represents a base for local inference providers in the OpenAI architecture. + * It extends the OAIEngine class and provides the implementation of loading and unloading models locally. + * The loadModel function subscribes to the ModelEvent.OnModelInit event, loading models when initiated. + * The unloadModel function subscribes to the ModelEvent.OnModelStop event, unloading models when stopped. */ override onLoad() { super.onLoad() // These events are applicable to local inference providers - events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model)) events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) } /** * Load the model. */ - override async loadModel(model: Model): Promise { + override async loadModel(model: ModelFile): Promise { if (model.engine.toString() !== this.provider) return - const modelFolderName = 'models' - const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id]) + const modelFolder = await dirName(model.file_path) const systemInfo = await systemInformation() const res = await executeOnMain( this.nodeModule, diff --git a/core/src/browser/extensions/engines/OAIEngine.test.ts b/core/src/browser/extensions/engines/OAIEngine.test.ts new file mode 100644 index 000000000..81348786c --- /dev/null +++ b/core/src/browser/extensions/engines/OAIEngine.test.ts @@ -0,0 +1,119 @@ +/** + * @jest-environment jsdom + */ +import { OAIEngine } from './OAIEngine' +import { events } from '../../events' +import { + MessageEvent, + InferenceEvent, + MessageRequest, + MessageRequestType, + MessageStatus, + ChatCompletionRole, + ContentType, +} from '../../../types' +import { requestInference } from './helpers/sse' +import { ulid } from 'ulidx' + +jest.mock('./helpers/sse') +jest.mock('ulidx') +jest.mock('../../events') + +class TestOAIEngine extends OAIEngine { + inferenceUrl = 'http://test-inference-url' + provider = 'test-provider' + + async headers() { + return { Authorization: 'Bearer test-token' } + } +} + +describe('OAIEngine', () => { + let engine: TestOAIEngine + + beforeEach(() => { + engine = new TestOAIEngine('', '') + jest.clearAllMocks() + }) + + it('should subscribe to events on load', () => { + engine.onLoad() + expect(events.on).toHaveBeenCalledWith(MessageEvent.OnMessageSent, expect.any(Function)) + expect(events.on).toHaveBeenCalledWith(InferenceEvent.OnInferenceStopped, expect.any(Function)) + }) + + it('should handle inference request', async () => { + const data: MessageRequest = { + model: { engine: 'test-provider', id: 'test-model' } as any, + threadId: 'test-thread', + type: MessageRequestType.Thread, + assistantId: 'test-assistant', + messages: [{ role: ChatCompletionRole.User, content: 'Hello' }], + } + + ;(ulid as jest.Mock).mockReturnValue('test-id') + ;(requestInference as jest.Mock).mockReturnValue({ + subscribe: ({ next, complete }: any) => { + next('test response') + complete() + }, + }) + + await engine.inference(data) + + expect(requestInference).toHaveBeenCalledWith( + 'http://test-inference-url', + expect.objectContaining({ model: 'test-model' }), + 
expect.any(Object), + expect.any(AbortController), + { Authorization: 'Bearer test-token' }, + undefined + ) + + expect(events.emit).toHaveBeenCalledWith( + MessageEvent.OnMessageResponse, + expect.objectContaining({ id: 'test-id' }) + ) + expect(events.emit).toHaveBeenCalledWith( + MessageEvent.OnMessageUpdate, + expect.objectContaining({ + content: [{ type: ContentType.Text, text: { value: 'test response', annotations: [] } }], + status: MessageStatus.Ready, + }) + ) + }) + + it('should handle inference error', async () => { + const data: MessageRequest = { + model: { engine: 'test-provider', id: 'test-model' } as any, + threadId: 'test-thread', + type: MessageRequestType.Thread, + assistantId: 'test-assistant', + messages: [{ role: ChatCompletionRole.User, content: 'Hello' }], + } + + ;(ulid as jest.Mock).mockReturnValue('test-id') + ;(requestInference as jest.Mock).mockReturnValue({ + subscribe: ({ error }: any) => { + error({ message: 'test error', code: 500 }) + }, + }) + + await engine.inference(data) + + expect(events.emit).toHaveBeenCalledWith( + MessageEvent.OnMessageUpdate, + expect.objectContaining({ + content: [{ type: ContentType.Text, text: { value: 'test error', annotations: [] } }], + status: MessageStatus.Error, + error_code: 500, + }) + ) + }) + + it('should stop inference', () => { + engine.stopInference() + expect(engine.isCancelled).toBe(true) + expect(engine.controller.signal.aborted).toBe(true) + }) +}) diff --git a/core/src/browser/extensions/engines/RemoteOAIEngine.test.ts b/core/src/browser/extensions/engines/RemoteOAIEngine.test.ts new file mode 100644 index 000000000..871499f45 --- /dev/null +++ b/core/src/browser/extensions/engines/RemoteOAIEngine.test.ts @@ -0,0 +1,43 @@ +/** + * @jest-environment jsdom + */ +import { RemoteOAIEngine } from './' + +class TestRemoteOAIEngine extends RemoteOAIEngine { + inferenceUrl: string = '' + provider: string = 'TestRemoteOAIEngine' +} + +describe('RemoteOAIEngine', () => { + let engine: TestRemoteOAIEngine + + beforeEach(() => { + engine = new TestRemoteOAIEngine('', '') + }) + + test('should call onLoad and super.onLoad', () => { + const onLoadSpy = jest.spyOn(engine, 'onLoad') + const superOnLoadSpy = jest.spyOn(Object.getPrototypeOf(RemoteOAIEngine.prototype), 'onLoad') + engine.onLoad() + + expect(onLoadSpy).toHaveBeenCalled() + expect(superOnLoadSpy).toHaveBeenCalled() + }) + + test('should return headers with apiKey', async () => { + engine.apiKey = 'test-api-key' + const headers = await engine.headers() + + expect(headers).toEqual({ + 'Authorization': 'Bearer test-api-key', + 'api-key': 'test-api-key', + }) + }) + + test('should return empty headers when apiKey is not set', async () => { + engine.apiKey = undefined + const headers = await engine.headers() + + expect(headers).toEqual({}) + }) +}) diff --git a/core/src/browser/extensions/engines/helpers/sse.test.ts b/core/src/browser/extensions/engines/helpers/sse.test.ts index cff5b93b3..0b78aa9b5 100644 --- a/core/src/browser/extensions/engines/helpers/sse.test.ts +++ b/core/src/browser/extensions/engines/helpers/sse.test.ts @@ -1,6 +1,7 @@ import { lastValueFrom, Observable } from 'rxjs' import { requestInference } from './sse' +import { ReadableStream } from 'stream/web'; describe('requestInference', () => { it('should send a request to the inference server and return an Observable', () => { // Mock the fetch function @@ -58,3 +59,66 @@ describe('requestInference', () => { expect(lastValueFrom(result)).rejects.toEqual({ message: 'Wrong API Key', code: 
'invalid_api_key' }) }) }) + + it('should handle a successful response with a transformResponse function', () => { + // Mock the fetch function + const mockFetch: any = jest.fn(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve({ choices: [{ message: { content: 'Generated response' } }] }), + headers: new Headers(), + redirected: false, + status: 200, + statusText: 'OK', + }) + ) + jest.spyOn(global, 'fetch').mockImplementation(mockFetch) + + // Define the test inputs + const inferenceUrl = 'https://inference-server.com' + const requestBody = { message: 'Hello' } + const model = { id: 'model-id', parameters: { stream: false } } + const transformResponse = (data: any) => data.choices[0].message.content.toUpperCase() + + // Call the function + const result = requestInference(inferenceUrl, requestBody, model, undefined, undefined, transformResponse) + + // Assert the expected behavior + expect(result).toBeInstanceOf(Observable) + expect(lastValueFrom(result)).resolves.toEqual('GENERATED RESPONSE') + }) + + + it('should handle a successful response with streaming enabled', () => { + // Mock the fetch function + const mockFetch: any = jest.fn(() => + Promise.resolve({ + ok: true, + body: new ReadableStream({ + start(controller) { + controller.enqueue(new TextEncoder().encode('data: {"choices": [{"delta": {"content": "Streamed"}}]}')); + controller.enqueue(new TextEncoder().encode('data: [DONE]')); + controller.close(); + } + }), + headers: new Headers(), + redirected: false, + status: 200, + statusText: 'OK', + }) + ); + jest.spyOn(global, 'fetch').mockImplementation(mockFetch); + + // Define the test inputs + const inferenceUrl = 'https://inference-server.com'; + const requestBody = { message: 'Hello' }; + const model = { id: 'model-id', parameters: { stream: true } }; + + // Call the function + const result = requestInference(inferenceUrl, requestBody, model); + + // Assert the expected behavior + expect(result).toBeInstanceOf(Observable); + expect(lastValueFrom(result)).resolves.toEqual('Streamed'); + }); + diff --git a/core/src/browser/extensions/engines/index.test.ts b/core/src/browser/extensions/engines/index.test.ts new file mode 100644 index 000000000..4c0ef11d8 --- /dev/null +++ b/core/src/browser/extensions/engines/index.test.ts @@ -0,0 +1,6 @@ + +import { expect } from '@jest/globals'; + +it('should re-export all exports from ./AIEngine', () => { + expect(require('./index')).toHaveProperty('AIEngine'); +}); diff --git a/core/src/browser/extensions/index.test.ts b/core/src/browser/extensions/index.test.ts new file mode 100644 index 000000000..26cbda8c5 --- /dev/null +++ b/core/src/browser/extensions/index.test.ts @@ -0,0 +1,32 @@ +import { ConversationalExtension } from './index'; +import { InferenceExtension } from './index'; +import { MonitoringExtension } from './index'; +import { AssistantExtension } from './index'; +import { ModelExtension } from './index'; +import * as Engines from './index'; + +describe('index.ts exports', () => { + test('should export ConversationalExtension', () => { + expect(ConversationalExtension).toBeDefined(); + }); + + test('should export InferenceExtension', () => { + expect(InferenceExtension).toBeDefined(); + }); + + test('should export MonitoringExtension', () => { + expect(MonitoringExtension).toBeDefined(); + }); + + test('should export AssistantExtension', () => { + expect(AssistantExtension).toBeDefined(); + }); + + test('should export ModelExtension', () => { + expect(ModelExtension).toBeDefined(); + }); + + test('should export 
Engines', () => { + expect(Engines).toBeDefined(); + }); +}); \ No newline at end of file diff --git a/core/src/browser/extensions/inference.test.ts b/core/src/browser/extensions/inference.test.ts new file mode 100644 index 000000000..45ec9d172 --- /dev/null +++ b/core/src/browser/extensions/inference.test.ts @@ -0,0 +1,45 @@ +import { MessageRequest, ThreadMessage } from '../../types' +import { BaseExtension, ExtensionTypeEnum } from '../extension' +import { InferenceExtension } from './' + +// Mock the MessageRequest and ThreadMessage types +type MockMessageRequest = { + text: string +} + +type MockThreadMessage = { + text: string + userId: string +} + +// Mock the BaseExtension class +class MockBaseExtension extends BaseExtension { + type(): ExtensionTypeEnum | undefined { + return ExtensionTypeEnum.Base + } +} + +// Create a mock implementation of InferenceExtension +class MockInferenceExtension extends InferenceExtension { + async inference(data: MessageRequest): Promise { + return { text: 'Mock response', userId: '123' } as unknown as ThreadMessage + } +} + +describe('InferenceExtension', () => { + let inferenceExtension: InferenceExtension + + beforeEach(() => { + inferenceExtension = new MockInferenceExtension() + }) + + it('should have the correct type', () => { + expect(inferenceExtension.type()).toBe(ExtensionTypeEnum.Inference) + }) + + it('should implement the inference method', async () => { + const messageRequest: MessageRequest = { text: 'Hello' } as unknown as MessageRequest + const result = await inferenceExtension.inference(messageRequest) + expect(result).toEqual({ text: 'Mock response', userId: '123' } as unknown as ThreadMessage) + }) +}) diff --git a/core/src/browser/extensions/model.ts b/core/src/browser/extensions/model.ts index 5b3089403..040542927 100644 --- a/core/src/browser/extensions/model.ts +++ b/core/src/browser/extensions/model.ts @@ -4,6 +4,7 @@ import { HuggingFaceRepoData, ImportingModel, Model, + ModelFile, ModelInterface, OptionType, } from '../../types' @@ -25,12 +26,11 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter network?: { proxy: string; ignoreSSL?: boolean } ): Promise abstract cancelModelDownload(modelId: string): Promise - abstract deleteModel(modelId: string): Promise - abstract saveModel(model: Model): Promise - abstract getDownloadedModels(): Promise - abstract getConfiguredModels(): Promise + abstract deleteModel(model: ModelFile): Promise + abstract getDownloadedModels(): Promise + abstract getConfiguredModels(): Promise abstract importModels(models: ImportingModel[], optionType: OptionType): Promise - abstract updateModelInfo(modelInfo: Partial): Promise + abstract updateModelInfo(modelInfo: Partial): Promise abstract fetchHuggingFaceRepoData(repoId: string): Promise abstract getDefaultModel(): Promise } diff --git a/core/src/browser/extensions/monitoring.test.ts b/core/src/browser/extensions/monitoring.test.ts new file mode 100644 index 000000000..9bba89a8c --- /dev/null +++ b/core/src/browser/extensions/monitoring.test.ts @@ -0,0 +1,42 @@ + +import { ExtensionTypeEnum } from '../extension'; +import { MonitoringExtension } from './monitoring'; + +it('should have the correct type', () => { + class TestMonitoringExtension extends MonitoringExtension { + getGpuSetting(): Promise { + throw new Error('Method not implemented.'); + } + getResourcesInfo(): Promise { + throw new Error('Method not implemented.'); + } + getCurrentLoad(): Promise { + throw new Error('Method not implemented.'); + } + 
getOsInfo(): Promise { + throw new Error('Method not implemented.'); + } + } + const monitoringExtension = new TestMonitoringExtension(); + expect(monitoringExtension.type()).toBe(ExtensionTypeEnum.SystemMonitoring); +}); + + +it('should create an instance of MonitoringExtension', () => { + class TestMonitoringExtension extends MonitoringExtension { + getGpuSetting(): Promise { + throw new Error('Method not implemented.'); + } + getResourcesInfo(): Promise { + throw new Error('Method not implemented.'); + } + getCurrentLoad(): Promise { + throw new Error('Method not implemented.'); + } + getOsInfo(): Promise { + throw new Error('Method not implemented.'); + } + } + const monitoringExtension = new TestMonitoringExtension(); + expect(monitoringExtension).toBeInstanceOf(MonitoringExtension); +}); diff --git a/core/src/browser/fs.test.ts b/core/src/browser/fs.test.ts new file mode 100644 index 000000000..21da54874 --- /dev/null +++ b/core/src/browser/fs.test.ts @@ -0,0 +1,97 @@ +import { fs } from './fs' + +describe('fs module', () => { + beforeEach(() => { + globalThis.core = { + api: { + writeFileSync: jest.fn(), + writeBlob: jest.fn(), + readFileSync: jest.fn(), + existsSync: jest.fn(), + readdirSync: jest.fn(), + mkdir: jest.fn(), + rm: jest.fn(), + unlinkSync: jest.fn(), + appendFileSync: jest.fn(), + copyFile: jest.fn(), + getGgufFiles: jest.fn(), + fileStat: jest.fn(), + }, + } + }) + + it('should call writeFileSync with correct arguments', () => { + const args = ['path/to/file', 'data'] + fs.writeFileSync(...args) + expect(globalThis.core.api.writeFileSync).toHaveBeenCalledWith(...args) + }) + + it('should call writeBlob with correct arguments', async () => { + const path = 'path/to/file' + const data = 'blob data' + await fs.writeBlob(path, data) + expect(globalThis.core.api.writeBlob).toHaveBeenCalledWith(path, data) + }) + + it('should call readFileSync with correct arguments', () => { + const args = ['path/to/file'] + fs.readFileSync(...args) + expect(globalThis.core.api.readFileSync).toHaveBeenCalledWith(...args) + }) + + it('should call existsSync with correct arguments', () => { + const args = ['path/to/file'] + fs.existsSync(...args) + expect(globalThis.core.api.existsSync).toHaveBeenCalledWith(...args) + }) + + it('should call readdirSync with correct arguments', () => { + const args = ['path/to/directory'] + fs.readdirSync(...args) + expect(globalThis.core.api.readdirSync).toHaveBeenCalledWith(...args) + }) + + it('should call mkdir with correct arguments', () => { + const args = ['path/to/directory'] + fs.mkdir(...args) + expect(globalThis.core.api.mkdir).toHaveBeenCalledWith(...args) + }) + + it('should call rm with correct arguments', () => { + const args = ['path/to/directory'] + fs.rm(...args) + expect(globalThis.core.api.rm).toHaveBeenCalledWith(...args, { recursive: true, force: true }) + }) + + it('should call unlinkSync with correct arguments', () => { + const args = ['path/to/file'] + fs.unlinkSync(...args) + expect(globalThis.core.api.unlinkSync).toHaveBeenCalledWith(...args) + }) + + it('should call appendFileSync with correct arguments', () => { + const args = ['path/to/file', 'data'] + fs.appendFileSync(...args) + expect(globalThis.core.api.appendFileSync).toHaveBeenCalledWith(...args) + }) + + it('should call copyFile with correct arguments', async () => { + const src = 'path/to/src' + const dest = 'path/to/dest' + await fs.copyFile(src, dest) + expect(globalThis.core.api.copyFile).toHaveBeenCalledWith(src, dest) + }) + + it('should call getGgufFiles with 
correct arguments', async () => { + const paths = ['path/to/file1', 'path/to/file2'] + await fs.getGgufFiles(paths) + expect(globalThis.core.api.getGgufFiles).toHaveBeenCalledWith(paths) + }) + + it('should call fileStat with correct arguments', async () => { + const path = 'path/to/file' + const outsideJanDataFolder = true + await fs.fileStat(path, outsideJanDataFolder) + expect(globalThis.core.api.fileStat).toHaveBeenCalledWith(path, outsideJanDataFolder) + }) +}) diff --git a/core/src/browser/tools/index.test.ts b/core/src/browser/tools/index.test.ts new file mode 100644 index 000000000..8a24d3bb6 --- /dev/null +++ b/core/src/browser/tools/index.test.ts @@ -0,0 +1,5 @@ + + +it('should not throw any errors when imported', () => { + expect(() => require('./index')).not.toThrow(); +}) diff --git a/core/src/browser/tools/tool.test.ts b/core/src/browser/tools/tool.test.ts new file mode 100644 index 000000000..dcb478478 --- /dev/null +++ b/core/src/browser/tools/tool.test.ts @@ -0,0 +1,63 @@ +import { ToolManager } from '../../browser/tools/manager' +import { InferenceTool } from '../../browser/tools/tool' +import { AssistantTool, MessageRequest } from '../../types' + +class MockInferenceTool implements InferenceTool { + name = 'mockTool' + process(request: MessageRequest, tool: AssistantTool): Promise { + return Promise.resolve(request) + } +} + +it('should register a tool', () => { + const manager = new ToolManager() + const tool = new MockInferenceTool() + manager.register(tool) + expect(manager.get(tool.name)).toBe(tool) +}) + +it('should retrieve a tool by its name', () => { + const manager = new ToolManager() + const tool = new MockInferenceTool() + manager.register(tool) + const retrievedTool = manager.get(tool.name) + expect(retrievedTool).toBe(tool) +}) + +it('should return undefined for a non-existent tool', () => { + const manager = new ToolManager() + const retrievedTool = manager.get('nonExistentTool') + expect(retrievedTool).toBeUndefined() +}) + +it('should process the message request with enabled tools', async () => { + const manager = new ToolManager() + const tool = new MockInferenceTool() + manager.register(tool) + + const request: MessageRequest = { message: 'test' } as any + const tools: AssistantTool[] = [{ type: 'mockTool', enabled: true }] as any + + const result = await manager.process(request, tools) + expect(result).toBe(request) +}) + +it('should skip processing for disabled tools', async () => { + const manager = new ToolManager() + const tool = new MockInferenceTool() + manager.register(tool) + + const request: MessageRequest = { message: 'test' } as any + const tools: AssistantTool[] = [{ type: 'mockTool', enabled: false }] as any + + const result = await manager.process(request, tools) + expect(result).toBe(request) +}) + +it('should throw an error when process is called without implementation', () => { + class TestTool extends InferenceTool { + name = 'testTool' + } + const tool = new TestTool() + expect(() => tool.process({} as MessageRequest)).toThrowError() +}) diff --git a/core/src/index.test.ts b/core/src/index.test.ts new file mode 100644 index 000000000..a1bd7c6b9 --- /dev/null +++ b/core/src/index.test.ts @@ -0,0 +1,7 @@ + + +it('should declare global object core when importing the module and then deleting it', () => { + import('./index'); + delete globalThis.core; + expect(typeof globalThis.core).toBe('undefined'); +}); diff --git a/core/src/node/api/index.test.ts b/core/src/node/api/index.test.ts new file mode 100644 index 000000000..c35d6e792 --- 
/dev/null +++ b/core/src/node/api/index.test.ts @@ -0,0 +1,7 @@ + +import * as restfulV1 from './restful/v1'; + +it('should re-export from restful/v1', () => { + const restfulV1Exports = require('./restful/v1'); + expect(restfulV1Exports).toBeDefined(); +}) diff --git a/core/src/node/api/processors/Processor.test.ts b/core/src/node/api/processors/Processor.test.ts new file mode 100644 index 000000000..fd913c481 --- /dev/null +++ b/core/src/node/api/processors/Processor.test.ts @@ -0,0 +1,6 @@ + +import { Processor } from './Processor'; + +it('should be defined', () => { + expect(Processor).toBeDefined(); +}); diff --git a/core/src/node/api/processors/app.test.ts b/core/src/node/api/processors/app.test.ts index 3ada5df1e..5c4daef29 100644 --- a/core/src/node/api/processors/app.test.ts +++ b/core/src/node/api/processors/app.test.ts @@ -1,40 +1,57 @@ -import { App } from './app'; +jest.mock('../../helper', () => ({ + ...jest.requireActual('../../helper'), + getJanDataFolderPath: () => './app', +})) +import { dirname } from 'path' +import { App } from './app' it('should call stopServer', () => { - const app = new App(); - const stopServerMock = jest.fn().mockResolvedValue('Server stopped'); + const app = new App() + const stopServerMock = jest.fn().mockResolvedValue('Server stopped') jest.mock('@janhq/server', () => ({ - stopServer: stopServerMock - })); - const result = app.stopServer(); - expect(stopServerMock).toHaveBeenCalled(); -}); + stopServer: stopServerMock, + })) + app.stopServer() + expect(stopServerMock).toHaveBeenCalled() +}) it('should correctly retrieve basename', () => { - const app = new App(); - const result = app.baseName('/path/to/file.txt'); - expect(result).toBe('file.txt'); -}); + const app = new App() + const result = app.baseName('/path/to/file.txt') + expect(result).toBe('file.txt') +}) it('should correctly identify subdirectories', () => { - const app = new App(); - const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to'; - const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir'; - const result = app.isSubdirectory(basePath, subPath); - expect(result).toBe(true); -}); + const app = new App() + const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to' + const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir' + const result = app.isSubdirectory(basePath, subPath) + expect(result).toBe(true) +}) it('should correctly join multiple paths', () => { - const app = new App(); - const result = app.joinPath(['path', 'to', 'file']); - const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file'; - expect(result).toBe(expectedPath); -}); + const app = new App() + const result = app.joinPath(['path', 'to', 'file']) + const expectedPath = process.platform === 'win32' ? 
'path\\to\\file' : 'path/to/file' + expect(result).toBe(expectedPath) +}) it('should call correct function with provided arguments using process method', () => { - const app = new App(); - const mockFunc = jest.fn(); - app.joinPath = mockFunc; - app.process('joinPath', ['path1', 'path2']); - expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']); -}); + const app = new App() + const mockFunc = jest.fn() + app.joinPath = mockFunc + app.process('joinPath', ['path1', 'path2']) + expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']) +}) + +it('should retrieve the directory name from a file path (Unix/Windows)', async () => { + const app = new App() + const path = 'C:/Users/John Doe/Desktop/file.txt' + expect(await app.dirName(path)).toBe('C:/Users/John Doe/Desktop') +}) + +it('should retrieve the directory name when using file protocol', async () => { + const app = new App() + const path = 'file:/models/file.txt' + expect(await app.dirName(path)).toBe(process.platform === 'win32' ? 'app\\models' : 'app/models') +}) diff --git a/core/src/node/api/processors/app.ts b/core/src/node/api/processors/app.ts index 15460ba56..a0808c5ac 100644 --- a/core/src/node/api/processors/app.ts +++ b/core/src/node/api/processors/app.ts @@ -1,4 +1,4 @@ -import { basename, isAbsolute, join, relative } from 'path' +import { basename, dirname, isAbsolute, join, relative } from 'path' import { Processor } from './Processor' import { @@ -6,6 +6,8 @@ import { appResourcePath, getAppConfigurations as appConfiguration, updateAppConfiguration, + normalizeFilePath, + getJanDataFolderPath, } from '../../helper' export class App implements Processor { @@ -28,6 +30,18 @@ export class App implements Processor { return join(...args) } + /** + * Get dirname of a file path. + * @param path - The file path to retrieve dirname. + */ + dirName(path: string) { + const arg = + path.startsWith(`file:/`) || path.startsWith(`file:\\`) + ? join(getJanDataFolderPath(), normalizeFilePath(path)) + : path + return dirname(arg) + } + /** * Checks if the given path is a subdirectory of the given directory. * diff --git a/core/src/node/api/processors/download.test.ts b/core/src/node/api/processors/download.test.ts index 1dc0eefb8..370f1746f 100644 --- a/core/src/node/api/processors/download.test.ts +++ b/core/src/node/api/processors/download.test.ts @@ -1,59 +1,131 @@ -import { Downloader } from './download'; -import { DownloadEvent } from '../../../types/api'; -import { DownloadManager } from '../../helper/download'; +import { Downloader } from './download' +import { DownloadEvent } from '../../../types/api' +import { DownloadManager } from '../../helper/download' -it('should handle getFileSize errors correctly', async () => { - const observer = jest.fn(); - const url = 'http://example.com/file'; +jest.mock('../../helper', () => ({ + getJanDataFolderPath: jest.fn().mockReturnValue('path/to/folder'), +})) - const downloader = new Downloader(observer); - const requestMock = jest.fn((options, callback) => { - callback(new Error('Test error'), null); - }); - jest.mock('request', () => requestMock); +jest.mock('../../helper/path', () => ({ + validatePath: jest.fn().mockReturnValue('path/to/folder'), + normalizeFilePath: () => process.platform === 'win32' ? 
'C:\\Users\path\\to\\file.gguf' : '/Users/path/to/file.gguf', +})) - await expect(downloader.getFileSize(observer, url)).rejects.toThrow('Test error'); -}); +jest.mock( + 'request', + jest.fn().mockReturnValue(() => ({ + on: jest.fn(), + })) +) +jest.mock('fs', () => ({ + createWriteStream: jest.fn(), +})) -it('should pause download correctly', () => { - const observer = jest.fn(); - const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file'; +jest.mock('request-progress', () => { + return jest.fn().mockImplementation(() => { + return { + on: jest.fn().mockImplementation((event, callback) => { + if (event === 'error') { + callback(new Error('Download failed')) + } + return { + on: jest.fn().mockImplementation((event, callback) => { + if (event === 'error') { + callback(new Error('Download failed')) + } + return { + on: jest.fn().mockImplementation((event, callback) => { + if (event === 'error') { + callback(new Error('Download failed')) + } + return { pipe: jest.fn() } + }), + } + }), + } + }), + } + }) +}) - const downloader = new Downloader(observer); - const pauseMock = jest.fn(); - DownloadManager.instance.networkRequests[fileName] = { pause: pauseMock }; +describe('Downloader', () => { + beforeEach(() => { + jest.resetAllMocks() + }) + it('should handle getFileSize errors correctly', async () => { + const observer = jest.fn() + const url = 'http://example.com/file' - downloader.pauseDownload(observer, fileName); + const downloader = new Downloader(observer) + const requestMock = jest.fn((options, callback) => { + callback(new Error('Test error'), null) + }) + jest.mock('request', () => requestMock) - expect(pauseMock).toHaveBeenCalled(); -}); + await expect(downloader.getFileSize(observer, url)).rejects.toThrow('Test error') + }) -it('should resume download correctly', () => { - const observer = jest.fn(); - const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file'; + it('should pause download correctly', () => { + const observer = jest.fn() + const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file' - const downloader = new Downloader(observer); - const resumeMock = jest.fn(); - DownloadManager.instance.networkRequests[fileName] = { resume: resumeMock }; + const downloader = new Downloader(observer) + const pauseMock = jest.fn() + DownloadManager.instance.networkRequests[fileName] = { pause: pauseMock } - downloader.resumeDownload(observer, fileName); + downloader.pauseDownload(observer, fileName) - expect(resumeMock).toHaveBeenCalled(); -}); + expect(pauseMock).toHaveBeenCalled() + }) -it('should handle aborting a download correctly', () => { - const observer = jest.fn(); - const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file'; + it('should resume download correctly', () => { + const observer = jest.fn() + const fileName = process.platform === 'win32' ? 
'C:\\path\\to\\file' : 'path/to/file' - const downloader = new Downloader(observer); - const abortMock = jest.fn(); - DownloadManager.instance.networkRequests[fileName] = { abort: abortMock }; + const downloader = new Downloader(observer) + const resumeMock = jest.fn() + DownloadManager.instance.networkRequests[fileName] = { resume: resumeMock } - downloader.abortDownload(observer, fileName); + downloader.resumeDownload(observer, fileName) - expect(abortMock).toHaveBeenCalled(); - expect(observer).toHaveBeenCalledWith(DownloadEvent.onFileDownloadError, expect.objectContaining({ - error: 'aborted' - })); -}); + expect(resumeMock).toHaveBeenCalled() + }) + + it('should handle aborting a download correctly', () => { + const observer = jest.fn() + const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file' + + const downloader = new Downloader(observer) + const abortMock = jest.fn() + DownloadManager.instance.networkRequests[fileName] = { abort: abortMock } + + downloader.abortDownload(observer, fileName) + + expect(abortMock).toHaveBeenCalled() + expect(observer).toHaveBeenCalledWith( + DownloadEvent.onFileDownloadError, + expect.objectContaining({ + error: 'aborted', + }) + ) + }) + + it('should handle download fail correctly', () => { + const observer = jest.fn() + const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file.gguf' + + const downloader = new Downloader(observer) + + downloader.downloadFile(observer, { + localPath: fileName, + url: 'http://127.0.0.1', + }) + expect(observer).toHaveBeenCalledWith( + DownloadEvent.onFileDownloadError, + expect.objectContaining({ + error: expect.anything(), + }) + ) + }) +}) diff --git a/core/src/node/api/processors/download.ts b/core/src/node/api/processors/download.ts index 07486bdf8..5db18a53a 100644 --- a/core/src/node/api/processors/download.ts +++ b/core/src/node/api/processors/download.ts @@ -34,7 +34,7 @@ export class Downloader implements Processor { } const array = normalizedPath.split(sep) const fileName = array.pop() ?? '' - const modelId = array.pop() ?? '' + const modelId = downloadRequest.modelId ?? array.pop() ?? 
'' const destination = resolve(getJanDataFolderPath(), normalizedPath) validatePath(destination) @@ -100,7 +100,11 @@ export class Downloader implements Processor { }) .on('end', () => { const currentDownloadState = DownloadManager.instance.downloadProgressMap[modelId] - if (currentDownloadState && DownloadManager.instance.networkRequests[normalizedPath]) { + if ( + currentDownloadState && + DownloadManager.instance.networkRequests[normalizedPath] && + DownloadManager.instance.downloadProgressMap[modelId]?.downloadState !== 'error' + ) { // Finished downloading, rename temp file to actual file renameSync(downloadingTempFile, destination) const downloadState: DownloadState = { diff --git a/core/src/node/api/processors/extension.test.ts b/core/src/node/api/processors/extension.test.ts index 917883499..2067c5c42 100644 --- a/core/src/node/api/processors/extension.test.ts +++ b/core/src/node/api/processors/extension.test.ts @@ -7,3 +7,34 @@ it('should call function associated with key in process method', () => { extension.process('testKey', 'arg1', 'arg2'); expect(mockFunc).toHaveBeenCalledWith('arg1', 'arg2'); }); + + +it('should_handle_empty_extension_list_for_install', async () => { + jest.mock('../../extension/store', () => ({ + installExtensions: jest.fn(() => Promise.resolve([])), + })); + const extension = new Extension(); + const result = await extension.installExtension([]); + expect(result).toEqual([]); +}); + + +it('should_handle_empty_extension_list_for_update', async () => { + jest.mock('../../extension/store', () => ({ + getExtension: jest.fn(() => ({ update: jest.fn(() => Promise.resolve(true)) })), + })); + const extension = new Extension(); + const result = await extension.updateExtension([]); + expect(result).toEqual([]); +}); + + +it('should_handle_empty_extension_list', async () => { + jest.mock('../../extension/store', () => ({ + getExtension: jest.fn(() => ({ uninstall: jest.fn(() => Promise.resolve(true)) })), + removeExtension: jest.fn(), + })); + const extension = new Extension(); + const result = await extension.uninstallExtension([]); + expect(result).toBe(true); +}); diff --git a/core/src/node/api/processors/processor.test.ts b/core/src/node/api/processors/processor.test.ts deleted file mode 100644 index e69de29bb..000000000 diff --git a/core/src/node/api/restful/helper/builder.test.ts b/core/src/node/api/restful/helper/builder.test.ts new file mode 100644 index 000000000..eb21e9401 --- /dev/null +++ b/core/src/node/api/restful/helper/builder.test.ts @@ -0,0 +1,305 @@ +import { + existsSync, + readdirSync, + readFileSync, + writeFileSync, + mkdirSync, + appendFileSync, + rmdirSync, +} from 'fs' +import { join } from 'path' +import { + getBuilder, + retrieveBuilder, + deleteBuilder, + getMessages, + retrieveMessage, + createThread, + updateThread, + createMessage, + downloadModel, + chatCompletions, +} from './builder' +import { RouteConfiguration } from './configuration' + +jest.mock('fs') +jest.mock('path') +jest.mock('../../../helper', () => ({ + getEngineConfiguration: jest.fn(), + getJanDataFolderPath: jest.fn().mockReturnValue('/mock/path'), +})) +jest.mock('request') +jest.mock('request-progress') +jest.mock('node-fetch') + +describe('builder helper functions', () => { + const mockConfiguration: RouteConfiguration = { + dirName: 'mockDir', + metadataFileName: 'metadata.json', + delete: { + object: 'mockObject', + }, + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('getBuilder', () => { + it('should return an empty array if directory does not 
exist', async () => { + ;(existsSync as jest.Mock).mockReturnValue(false) + const result = await getBuilder(mockConfiguration) + expect(result).toEqual([]) + }) + + it('should return model data if directory exists', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['file1']) + ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' })) + + const result = await getBuilder(mockConfiguration) + expect(result).toEqual([{ id: 'model1' }]) + }) + }) + + describe('retrieveBuilder', () => { + it('should return undefined if no data matches the id', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['file1']) + ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' })) + + const result = await retrieveBuilder(mockConfiguration, 'nonexistentId') + expect(result).toBeUndefined() + }) + + it('should return the matching data', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['file1']) + ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' })) + + const result = await retrieveBuilder(mockConfiguration, 'model1') + expect(result).toEqual({ id: 'model1' }) + }) + }) + + describe('deleteBuilder', () => { + it('should return a message if trying to delete Jan assistant', async () => { + const result = await deleteBuilder({ ...mockConfiguration, dirName: 'assistants' }, 'jan') + expect(result).toEqual({ message: 'Cannot delete Jan assistant' }) + }) + + it('should return a message if data is not found', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['file1']) + ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' })) + + const result = await deleteBuilder(mockConfiguration, 'nonexistentId') + expect(result).toEqual({ message: 'Not found' }) + }) + + it('should delete the directory and return success message', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['file1']) + ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' })) + + const result = await deleteBuilder(mockConfiguration, 'model1') + expect(rmdirSync).toHaveBeenCalledWith(join('/mock/path', 'mockDir', 'model1'), { + recursive: true, + }) + expect(result).toEqual({ id: 'model1', object: 'mockObject', deleted: true }) + }) + }) + + describe('getMessages', () => { + it('should return an empty array if message file does not exist', async () => { + ;(existsSync as jest.Mock).mockReturnValue(false) + + const result = await getMessages('thread1') + expect(result).toEqual([]) + }) + + it('should return messages if message file exists', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['messages.jsonl']) + ;(readFileSync as jest.Mock).mockReturnValue('{"id":"msg1"}\n{"id":"msg2"}\n') + + const result = await getMessages('thread1') + expect(result).toEqual([{ id: 'msg1' }, { id: 'msg2' }]) + }) + }) + + describe('retrieveMessage', () => { + it('should return a message if no messages match the id', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['messages.jsonl']) + ;(readFileSync as jest.Mock).mockReturnValue('{"id":"msg1"}\n') + + const result = await retrieveMessage('thread1', 'nonexistentId') + expect(result).toEqual({ 
message: 'Not found' }) + }) + + it('should return the matching message', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['messages.jsonl']) + ;(readFileSync as jest.Mock).mockReturnValue('{"id":"msg1"}\n') + + const result = await retrieveMessage('thread1', 'msg1') + expect(result).toEqual({ id: 'msg1' }) + }) + }) + + describe('createThread', () => { + it('should return a message if thread has no assistants', async () => { + const result = await createThread({}) + expect(result).toEqual({ message: 'Thread must have at least one assistant' }) + }) + + it('should create a thread and return the updated thread', async () => { + ;(existsSync as jest.Mock).mockReturnValue(false) + + const thread = { assistants: [{ assistant_id: 'assistant1' }] } + const result = await createThread(thread) + expect(mkdirSync).toHaveBeenCalled() + expect(writeFileSync).toHaveBeenCalled() + expect(result.id).toBeDefined() + }) + }) + + describe('updateThread', () => { + it('should return a message if thread is not found', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['file1']) + ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' })) + + const result = await updateThread('nonexistentId', {}) + expect(result).toEqual({ message: 'Thread not found' }) + }) + + it('should update the thread and return the updated thread', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['file1']) + ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' })) + + const result = await updateThread('model1', { name: 'updatedName' }) + expect(writeFileSync).toHaveBeenCalled() + expect(result.name).toEqual('updatedName') + }) + }) + + describe('createMessage', () => { + it('should create a message and return the created message', async () => { + ;(existsSync as jest.Mock).mockReturnValue(false) + const message = { role: 'user', content: 'Hello' } + + const result = (await createMessage('thread1', message)) as any + expect(mkdirSync).toHaveBeenCalled() + expect(appendFileSync).toHaveBeenCalled() + expect(result.id).toBeDefined() + }) + }) + + describe('downloadModel', () => { + it('should return a message if model is not found', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['file1']) + ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' })) + + const result = await downloadModel('nonexistentId') + expect(result).toEqual({ message: 'Model not found' }) + }) + + it('should start downloading the model', async () => { + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['file1']) + ;(readFileSync as jest.Mock).mockReturnValue( + JSON.stringify({ id: 'model1', object: 'model', sources: ['http://example.com'] }) + ) + const result = await downloadModel('model1') + expect(result).toEqual({ message: 'Starting download model1' }) + }) + }) + + describe('chatCompletions', () => { + it('should return an error if model is not found', async () => { + const request = { body: { model: 'nonexistentModel' } } + const reply = { code: jest.fn().mockReturnThis(), send: jest.fn() } + + await chatCompletions(request, reply) + expect(reply.code).toHaveBeenCalledWith(404) + expect(reply.send).toHaveBeenCalledWith({ + error: { + message: 'The model nonexistentModel does not exist', + type: 
'invalid_request_error', + param: null, + code: 'model_not_found', + }, + }) + }) + + it('should return the error on status not ok', async () => { + const request = { body: { model: 'model1' } } + const mockSend = jest.fn() + const reply = { + code: jest.fn().mockReturnThis(), + send: jest.fn(), + headers: jest.fn().mockReturnValue({ + send: mockSend, + }), + raw: { + writeHead: jest.fn(), + pipe: jest.fn(), + }, + } + + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['file1']) + ;(readFileSync as jest.Mock).mockReturnValue( + JSON.stringify({ id: 'model1', engine: 'openai' }) + ) + + // Mock fetch + const fetch = require('node-fetch') + fetch.mockResolvedValue({ + status: 400, + headers: new Map([ + ['content-type', 'application/json'], + ['x-request-id', '123456'], + ]), + body: { pipe: jest.fn() }, + text: jest.fn().mockResolvedValue({ error: 'Mock error response' }), + }) + await chatCompletions(request, reply) + expect(reply.code).toHaveBeenCalledWith(400) + expect(mockSend).toHaveBeenCalledWith( + expect.objectContaining({ + error: 'Mock error response', + }) + ) + }) + + it('should return the chat completions', async () => { + const request = { body: { model: 'model1' } } + const reply = { + code: jest.fn().mockReturnThis(), + send: jest.fn(), + raw: { writeHead: jest.fn(), pipe: jest.fn() }, + } + + ;(existsSync as jest.Mock).mockReturnValue(true) + ;(readdirSync as jest.Mock).mockReturnValue(['file1']) + ;(readFileSync as jest.Mock).mockReturnValue( + JSON.stringify({ id: 'model1', engine: 'openai' }) + ) + + // Mock fetch + const fetch = require('node-fetch') + fetch.mockResolvedValue({ + status: 200, + body: { pipe: jest.fn() }, + json: jest.fn().mockResolvedValue({ completions: ['completion1'] }), + }) + await chatCompletions(request, reply) + expect(reply.raw.writeHead).toHaveBeenCalledWith(200, expect.any(Object)) + }) + }) +}) diff --git a/core/src/node/api/restful/helper/builder.ts b/core/src/node/api/restful/helper/builder.ts index 08da0ff33..db2000d69 100644 --- a/core/src/node/api/restful/helper/builder.ts +++ b/core/src/node/api/restful/helper/builder.ts @@ -280,13 +280,13 @@ export const downloadModel = async ( for (const source of model.sources) { const rq = request({ url: source, strictSSL, proxy }) progress(rq, {}) - .on('progress', function (state: any) { + ?.on('progress', function (state: any) { console.debug('progress', JSON.stringify(state, null, 2)) }) - .on('error', function (err: Error) { + ?.on('error', function (err: Error) { console.error('error', err) }) - .on('end', function () { + ?.on('end', function () { console.debug('end') }) .pipe(createWriteStream(modelBinaryPath)) @@ -353,8 +353,10 @@ export const chatCompletions = async (request: any, reply: any) => { body: JSON.stringify(request.body), }) if (response.status !== 200) { - console.error(response) - reply.code(400).send(response) + // Forward the error response to client via reply + const responseBody = await response.text() + const responseHeaders = Object.fromEntries(response.headers) + reply.code(response.status).headers(responseHeaders).send(responseBody) } else { reply.raw.writeHead(200, { 'Content-Type': request.body.stream === true ? 
'text/event-stream' : 'application/json', diff --git a/core/src/node/api/restful/helper/consts.test.ts b/core/src/node/api/restful/helper/consts.test.ts new file mode 100644 index 000000000..34d42dcf0 --- /dev/null +++ b/core/src/node/api/restful/helper/consts.test.ts @@ -0,0 +1,6 @@ + +import { NITRO_DEFAULT_PORT } from './consts'; + +it('should test NITRO_DEFAULT_PORT', () => { + expect(NITRO_DEFAULT_PORT).toBe(3928); +}); diff --git a/core/src/node/api/restful/helper/startStopModel.test.ts b/core/src/node/api/restful/helper/startStopModel.test.ts new file mode 100644 index 000000000..a5475cc28 --- /dev/null +++ b/core/src/node/api/restful/helper/startStopModel.test.ts @@ -0,0 +1,16 @@ + + + import { startModel } from './startStopModel' + + describe('startModel', () => { + it('test_startModel_error', async () => { + const modelId = 'testModelId' + const settingParams = undefined + + const result = await startModel(modelId, settingParams) + + expect(result).toEqual({ + error: expect.any(Error), + }) + }) + }) diff --git a/core/src/node/extension/index.test.ts b/core/src/node/extension/index.test.ts new file mode 100644 index 000000000..ce9cb0d0a --- /dev/null +++ b/core/src/node/extension/index.test.ts @@ -0,0 +1,7 @@ + + + import { useExtensions } from './index' + + test('testUseExtensionsMissingPath', () => { + expect(() => useExtensions(undefined as any)).toThrowError('A path to the extensions folder is required to use extensions') + }) diff --git a/core/src/node/helper/config.test.ts b/core/src/node/helper/config.test.ts index 201a98141..d46750d5f 100644 --- a/core/src/node/helper/config.test.ts +++ b/core/src/node/helper/config.test.ts @@ -1,6 +1,8 @@ import { getEngineConfiguration } from './config'; import { getAppConfigurations, defaultAppConfig } from './config'; +import { getJanExtensionsPath } from './config'; +import { getJanDataFolderPath } from './config'; it('should return undefined for invalid engine ID', async () => { const config = await getEngineConfiguration('invalid_engine'); expect(config).toBeUndefined(); @@ -12,3 +14,15 @@ it('should return default config when CI is e2e', () => { const config = getAppConfigurations(); expect(config).toEqual(defaultAppConfig()); }); + + +it('should return extensions path when retrieved successfully', () => { + const extensionsPath = getJanExtensionsPath(); + expect(extensionsPath).not.toBeUndefined(); +}); + + +it('should return data folder path when retrieved successfully', () => { + const dataFolderPath = getJanDataFolderPath(); + expect(dataFolderPath).not.toBeUndefined(); +}); diff --git a/core/src/types/api/index.test.ts b/core/src/types/api/index.test.ts new file mode 100644 index 000000000..6f2f2dcdb --- /dev/null +++ b/core/src/types/api/index.test.ts @@ -0,0 +1,24 @@ + + +import { NativeRoute } from '../index'; + +test('testNativeRouteEnum', () => { + expect(NativeRoute.openExternalUrl).toBe('openExternalUrl'); + expect(NativeRoute.openAppDirectory).toBe('openAppDirectory'); + expect(NativeRoute.openFileExplore).toBe('openFileExplorer'); + expect(NativeRoute.selectDirectory).toBe('selectDirectory'); + expect(NativeRoute.selectFiles).toBe('selectFiles'); + expect(NativeRoute.relaunch).toBe('relaunch'); + expect(NativeRoute.setNativeThemeLight).toBe('setNativeThemeLight'); + expect(NativeRoute.setNativeThemeDark).toBe('setNativeThemeDark'); + expect(NativeRoute.setMinimizeApp).toBe('setMinimizeApp'); + expect(NativeRoute.setCloseApp).toBe('setCloseApp'); + expect(NativeRoute.setMaximizeApp).toBe('setMaximizeApp'); + 
expect(NativeRoute.showOpenMenu).toBe('showOpenMenu'); + expect(NativeRoute.hideQuickAskWindow).toBe('hideQuickAskWindow'); + expect(NativeRoute.sendQuickAskInput).toBe('sendQuickAskInput'); + expect(NativeRoute.hideMainWindow).toBe('hideMainWindow'); + expect(NativeRoute.showMainWindow).toBe('showMainWindow'); + expect(NativeRoute.quickAskSizeUpdated).toBe('quickAskSizeUpdated'); + expect(NativeRoute.ackDeepLink).toBe('ackDeepLink'); +}); diff --git a/core/src/types/api/index.ts b/core/src/types/api/index.ts index bca11c0a8..8f1ff70bf 100644 --- a/core/src/types/api/index.ts +++ b/core/src/types/api/index.ts @@ -37,6 +37,7 @@ export enum AppRoute { getAppConfigurations = 'getAppConfigurations', updateAppConfiguration = 'updateAppConfiguration', joinPath = 'joinPath', + dirName = 'dirName', isSubdirectory = 'isSubdirectory', baseName = 'baseName', startServer = 'startServer', diff --git a/core/src/types/config/appConfigEvent.test.ts b/core/src/types/config/appConfigEvent.test.ts new file mode 100644 index 000000000..6000156c7 --- /dev/null +++ b/core/src/types/config/appConfigEvent.test.ts @@ -0,0 +1,9 @@ + + + import { AppConfigurationEventName } from './appConfigEvent'; + + describe('AppConfigurationEventName', () => { + it('should have the correct value for OnConfigurationUpdate', () => { + expect(AppConfigurationEventName.OnConfigurationUpdate).toBe('OnConfigurationUpdate'); + }); + }); diff --git a/core/src/types/file/index.ts b/core/src/types/file/index.ts index 1b36a5777..9f3e32b3e 100644 --- a/core/src/types/file/index.ts +++ b/core/src/types/file/index.ts @@ -40,6 +40,14 @@ export type DownloadRequest = { */ extensionId?: string + /** + * The model ID of the model that initiated the download. + */ + modelId?: string + + /** + * The download type. + */ downloadType?: DownloadType | string } @@ -52,3 +60,18 @@ type DownloadSize = { total: number transferred: number } + +/** + * The file metadata + */ +export type FileMetadata = { + /** + * The origin file path. + */ + file_path: string + + /** + * The file name. 
+ */ + file_name: string +} diff --git a/core/src/types/huggingface/huggingfaceEntity.test.ts b/core/src/types/huggingface/huggingfaceEntity.test.ts new file mode 100644 index 000000000..d57b484be --- /dev/null +++ b/core/src/types/huggingface/huggingfaceEntity.test.ts @@ -0,0 +1,28 @@ + + + import { AllQuantizations } from './huggingfaceEntity'; + + test('testAllQuantizationsArray', () => { + expect(AllQuantizations).toEqual([ + 'Q3_K_S', + 'Q3_K_M', + 'Q3_K_L', + 'Q4_K_S', + 'Q4_K_M', + 'Q5_K_S', + 'Q5_K_M', + 'Q4_0', + 'Q4_1', + 'Q5_0', + 'Q5_1', + 'IQ2_XXS', + 'IQ2_XS', + 'Q2_K', + 'Q2_K_S', + 'Q6_K', + 'Q8_0', + 'F16', + 'F32', + 'COPY', + ]); + }); diff --git a/core/src/types/huggingface/index.test.ts b/core/src/types/huggingface/index.test.ts new file mode 100644 index 000000000..9cb80a08f --- /dev/null +++ b/core/src/types/huggingface/index.test.ts @@ -0,0 +1,8 @@ + + + import * as huggingfaceEntity from './huggingfaceEntity'; + import * as index from './index'; + + test('test_exports_from_huggingfaceEntity', () => { + expect(index).toEqual(huggingfaceEntity); + }); diff --git a/core/src/types/index.test.ts b/core/src/types/index.test.ts new file mode 100644 index 000000000..9dc001c4d --- /dev/null +++ b/core/src/types/index.test.ts @@ -0,0 +1,28 @@ + +import * as assistant from './assistant'; +import * as model from './model'; +import * as thread from './thread'; +import * as message from './message'; +import * as inference from './inference'; +import * as monitoring from './monitoring'; +import * as file from './file'; +import * as config from './config'; +import * as huggingface from './huggingface'; +import * as miscellaneous from './miscellaneous'; +import * as api from './api'; +import * as setting from './setting'; + + test('test_module_exports', () => { + expect(assistant).toBeDefined(); + expect(model).toBeDefined(); + expect(thread).toBeDefined(); + expect(message).toBeDefined(); + expect(inference).toBeDefined(); + expect(monitoring).toBeDefined(); + expect(file).toBeDefined(); + expect(config).toBeDefined(); + expect(huggingface).toBeDefined(); + expect(miscellaneous).toBeDefined(); + expect(api).toBeDefined(); + expect(setting).toBeDefined(); + }); diff --git a/core/src/types/inference/inferenceEntity.test.ts b/core/src/types/inference/inferenceEntity.test.ts new file mode 100644 index 000000000..a2c06e32b --- /dev/null +++ b/core/src/types/inference/inferenceEntity.test.ts @@ -0,0 +1,13 @@ + + + import { ChatCompletionMessage, ChatCompletionRole } from './inferenceEntity'; + + test('test_chatCompletionMessage_withStringContent_andSystemRole', () => { + const message: ChatCompletionMessage = { + content: 'Hello, world!', + role: ChatCompletionRole.System, + }; + + expect(message.content).toBe('Hello, world!'); + expect(message.role).toBe(ChatCompletionRole.System); + }); diff --git a/core/src/types/inference/inferenceEvent.test.ts b/core/src/types/inference/inferenceEvent.test.ts new file mode 100644 index 000000000..1cb44fdbb --- /dev/null +++ b/core/src/types/inference/inferenceEvent.test.ts @@ -0,0 +1,7 @@ + + + import { InferenceEvent } from './inferenceEvent'; + + test('testInferenceEventEnumContainsOnInferenceStopped', () => { + expect(InferenceEvent.OnInferenceStopped).toBe('OnInferenceStopped'); + }); diff --git a/core/src/types/message/messageEntity.test.ts b/core/src/types/message/messageEntity.test.ts new file mode 100644 index 000000000..1d41d129a --- /dev/null +++ b/core/src/types/message/messageEntity.test.ts @@ -0,0 +1,9 @@ + +import { MessageStatus } 
from './messageEntity'; + +it('should have correct values', () => { + expect(MessageStatus.Ready).toBe('ready'); + expect(MessageStatus.Pending).toBe('pending'); + expect(MessageStatus.Error).toBe('error'); + expect(MessageStatus.Stopped).toBe('stopped'); +}) diff --git a/core/src/types/message/messageEvent.test.ts b/core/src/types/message/messageEvent.test.ts new file mode 100644 index 000000000..80a943bb1 --- /dev/null +++ b/core/src/types/message/messageEvent.test.ts @@ -0,0 +1,7 @@ + + + import { MessageEvent } from './messageEvent'; + + test('testOnMessageSentValue', () => { + expect(MessageEvent.OnMessageSent).toBe('OnMessageSent'); + }); diff --git a/core/src/types/message/messageRequestType.test.ts b/core/src/types/message/messageRequestType.test.ts new file mode 100644 index 000000000..41f53b2e0 --- /dev/null +++ b/core/src/types/message/messageRequestType.test.ts @@ -0,0 +1,7 @@ + + + import { MessageRequestType } from './messageRequestType'; + + test('testMessageRequestTypeEnumContainsThread', () => { + expect(MessageRequestType.Thread).toBe('Thread'); + }); diff --git a/core/src/types/miscellaneous/systemResourceInfo.test.ts b/core/src/types/miscellaneous/systemResourceInfo.test.ts new file mode 100644 index 000000000..35a459f0e --- /dev/null +++ b/core/src/types/miscellaneous/systemResourceInfo.test.ts @@ -0,0 +1,6 @@ + +import { SupportedPlatforms } from './systemResourceInfo'; + +it('should contain the correct values', () => { + expect(SupportedPlatforms).toEqual(['win32', 'linux', 'darwin']); +}); diff --git a/core/src/types/model/modelEntity.test.ts b/core/src/types/model/modelEntity.test.ts new file mode 100644 index 000000000..306316ac4 --- /dev/null +++ b/core/src/types/model/modelEntity.test.ts @@ -0,0 +1,30 @@ + + + import { Model, ModelSettingParams, ModelRuntimeParams, InferenceEngine } from '../model' + + test('testValidModelCreation', () => { + const model: Model = { + object: 'model', + version: '1.0', + format: 'format1', + sources: [{ filename: 'model.bin', url: 'http://example.com/model.bin' }], + id: 'model1', + name: 'Test Model', + created: Date.now(), + description: 'A cool model from Huggingface', + settings: { ctx_len: 100, ngl: 50, embedding: true }, + parameters: { temperature: 0.5, token_limit: 100, top_k: 10 }, + metadata: { author: 'Author', tags: ['tag1', 'tag2'], size: 100 }, + engine: InferenceEngine.anthropic + }; + + expect(model).toBeDefined(); + expect(model.object).toBe('model'); + expect(model.version).toBe('1.0'); + expect(model.sources).toHaveLength(1); + expect(model.sources[0].filename).toBe('model.bin'); + expect(model.settings).toBeDefined(); + expect(model.parameters).toBeDefined(); + expect(model.metadata).toBeDefined(); + expect(model.engine).toBe(InferenceEngine.anthropic); + }); diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts index f154f7f04..933c698c3 100644 --- a/core/src/types/model/modelEntity.ts +++ b/core/src/types/model/modelEntity.ts @@ -1,3 +1,5 @@ +import { FileMetadata } from '../file' + /** * Represents the information about a model. 
* @stored @@ -151,3 +153,8 @@ export type ModelRuntimeParams = { export type ModelInitFailed = Model & { error: Error } + +/** + * ModelFile is the model.json entity and it's file metadata + */ +export type ModelFile = Model & FileMetadata diff --git a/core/src/types/model/modelEvent.test.ts b/core/src/types/model/modelEvent.test.ts new file mode 100644 index 000000000..f9fa8cc6a --- /dev/null +++ b/core/src/types/model/modelEvent.test.ts @@ -0,0 +1,7 @@ + + + import { ModelEvent } from './modelEvent'; + + test('testOnModelInit', () => { + expect(ModelEvent.OnModelInit).toBe('OnModelInit'); + }); diff --git a/core/src/types/model/modelInterface.ts b/core/src/types/model/modelInterface.ts index 639c7c8d3..08d456b7e 100644 --- a/core/src/types/model/modelInterface.ts +++ b/core/src/types/model/modelInterface.ts @@ -1,5 +1,5 @@ import { GpuSetting } from '../miscellaneous' -import { Model } from './modelEntity' +import { Model, ModelFile } from './modelEntity' /** * Model extension for managing models. @@ -12,7 +12,7 @@ export interface ModelInterface { * @returns A Promise that resolves when the model has been downloaded. */ downloadModel( - model: Model, + model: ModelFile, gpuSettings?: GpuSetting, network?: { ignoreSSL?: boolean; proxy?: string } ): Promise @@ -29,24 +29,17 @@ export interface ModelInterface { * @param modelId - The ID of the model to delete. * @returns A Promise that resolves when the model has been deleted. */ - deleteModel(modelId: string): Promise - - /** - * Saves a model. - * @param model - The model to save. - * @returns A Promise that resolves when the model has been saved. - */ - saveModel(model: Model): Promise + deleteModel(model: ModelFile): Promise /** * Gets a list of downloaded models. * @returns A Promise that resolves with an array of downloaded models. */ - getDownloadedModels(): Promise + getDownloadedModels(): Promise /** * Gets a list of configured models. * @returns A Promise that resolves with an array of configured models. 
*/ - getConfiguredModels(): Promise + getConfiguredModels(): Promise } diff --git a/core/src/types/monitoring/index.test.ts b/core/src/types/monitoring/index.test.ts new file mode 100644 index 000000000..010fcb97a --- /dev/null +++ b/core/src/types/monitoring/index.test.ts @@ -0,0 +1,13 @@ + + import * as index from './index'; + import * as monitoringInterface from './monitoringInterface'; + import * as resourceInfo from './resourceInfo'; + + it('should re-export all symbols from monitoringInterface and resourceInfo', () => { + for (const key in monitoringInterface) { + expect(index[key]).toBe(monitoringInterface[key]); + } + for (const key in resourceInfo) { + expect(index[key]).toBe(resourceInfo[key]); + } + }); diff --git a/core/src/types/setting/index.test.ts b/core/src/types/setting/index.test.ts new file mode 100644 index 000000000..699adfe4f --- /dev/null +++ b/core/src/types/setting/index.test.ts @@ -0,0 +1,5 @@ + + +it('should not throw any errors', () => { + expect(() => require('./index')).not.toThrow(); +}); diff --git a/core/src/types/setting/settingComponent.test.ts b/core/src/types/setting/settingComponent.test.ts new file mode 100644 index 000000000..c56550e19 --- /dev/null +++ b/core/src/types/setting/settingComponent.test.ts @@ -0,0 +1,19 @@ + +import { createSettingComponent } from './settingComponent'; + + it('should throw an error when creating a setting component with invalid controller type', () => { + const props: SettingComponentProps = { + key: 'invalidControllerKey', + title: 'Invalid Controller Title', + description: 'Invalid Controller Description', + controllerType: 'invalid' as any, + controllerProps: { + placeholder: 'Enter text', + value: 'Initial Value', + type: 'text', + textAlign: 'left', + inputActions: ['unobscure'], + }, + }; + expect(() => createSettingComponent(props)).toThrowError(); + }); diff --git a/core/src/types/thread/threadEvent.test.ts b/core/src/types/thread/threadEvent.test.ts new file mode 100644 index 000000000..f892f1050 --- /dev/null +++ b/core/src/types/thread/threadEvent.test.ts @@ -0,0 +1,6 @@ + +import { ThreadEvent } from './threadEvent'; + +it('should have the correct values', () => { + expect(ThreadEvent.OnThreadStarted).toBe('OnThreadStarted'); +}); diff --git a/electron/jest.config.js b/electron/jest.config.js new file mode 100644 index 000000000..ec5968ccd --- /dev/null +++ b/electron/jest.config.js @@ -0,0 +1,18 @@ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', + collectCoverageFrom: ['src/**/*.{ts,tsx}'], + modulePathIgnorePatterns: ['/tests'], + moduleNameMapper: { + '@/(.*)': '/src/$1', + }, + runner: './testRunner.js', + transform: { + '^.+\\.tsx?$': [ + 'ts-jest', + { + diagnostics: false, + }, + ], + }, +} diff --git a/electron/testRunner.js b/electron/testRunner.js new file mode 100644 index 000000000..b0d108160 --- /dev/null +++ b/electron/testRunner.js @@ -0,0 +1,10 @@ +const jestRunner = require('jest-runner'); + +class EmptyTestFileRunner extends jestRunner.default { + async runTests(tests, watcher, onStart, onResult, onFailure, options) { + const nonEmptyTests = tests.filter(test => test.context.hasteFS.getSize(test.path) > 0); + return super.runTests(nonEmptyTests, watcher, onStart, onResult, onFailure, options); + } +} + +module.exports = EmptyTestFileRunner; \ No newline at end of file diff --git a/electron/tests/e2e/thread.e2e.spec.ts b/electron/tests/e2e/thread.e2e.spec.ts 
index c13e91119..5d7328053 100644 --- a/electron/tests/e2e/thread.e2e.spec.ts +++ b/electron/tests/e2e/thread.e2e.spec.ts @@ -1,32 +1,29 @@ import { expect } from '@playwright/test' import { page, test, TIMEOUT } from '../config/fixtures' -test('Select GPT model from Hub and Chat with Invalid API Key', async ({ hubPage }) => { +test('Select GPT model from Hub and Chat with Invalid API Key', async ({ + hubPage, +}) => { await hubPage.navigateByMenu() await hubPage.verifyContainerVisible() // Select the first GPT model await page .locator('[data-testid^="use-model-btn"][data-testid*="gpt"]') - .first().click() - - // Attempt to create thread and chat in Thread page - await page - .getByTestId('btn-create-thread') + .first() .click() - await page - .getByTestId('txt-input-chat') - .fill('dummy value') + await page.getByTestId('txt-input-chat').fill('dummy value') - await page - .getByTestId('btn-send-chat') - .click() + await page.getByTestId('btn-send-chat').click() - await page.waitForFunction(() => { - const loaders = document.querySelectorAll('[data-testid$="loader"]'); - return !loaders.length; - }, { timeout: TIMEOUT }); + await page.waitForFunction( + () => { + const loaders = document.querySelectorAll('[data-testid$="loader"]') + return !loaders.length + }, + { timeout: TIMEOUT } + ) const APIKeyError = page.getByTestId('invalid-API-key-error') await expect(APIKeyError).toBeVisible({ diff --git a/extensions/conversational-extension/jest.config.js b/extensions/conversational-extension/jest.config.js new file mode 100644 index 000000000..8bb37208d --- /dev/null +++ b/extensions/conversational-extension/jest.config.js @@ -0,0 +1,5 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', +} diff --git a/extensions/conversational-extension/package.json b/extensions/conversational-extension/package.json index d062ce9c3..036fcfab2 100644 --- a/extensions/conversational-extension/package.json +++ b/extensions/conversational-extension/package.json @@ -7,6 +7,7 @@ "author": "Jan ", "license": "MIT", "scripts": { + "test": "jest", "build": "tsc -b . && webpack --config webpack.config.js", "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" }, diff --git a/extensions/conversational-extension/src/Conversational.test.ts b/extensions/conversational-extension/src/Conversational.test.ts new file mode 100644 index 000000000..3d1d6fc60 --- /dev/null +++ b/extensions/conversational-extension/src/Conversational.test.ts @@ -0,0 +1,408 @@ +/** + * @jest-environment jsdom + */ +jest.mock('@janhq/core', () => ({ + ...jest.requireActual('@janhq/core/node'), + fs: { + existsSync: jest.fn(), + mkdir: jest.fn(), + writeFileSync: jest.fn(), + readdirSync: jest.fn(), + readFileSync: jest.fn(), + appendFileSync: jest.fn(), + rm: jest.fn(), + writeBlob: jest.fn(), + joinPath: jest.fn(), + fileStat: jest.fn(), + }, + joinPath: jest.fn(), + ConversationalExtension: jest.fn(), +})) + +import { fs } from '@janhq/core' + +import JSONConversationalExtension from '.' 
+ +describe('JSONConversationalExtension Tests', () => { + let extension: JSONConversationalExtension + + beforeEach(() => { + // @ts-ignore + extension = new JSONConversationalExtension() + }) + + it('should create thread folder on load if it does not exist', async () => { + // @ts-ignore + jest.spyOn(fs, 'existsSync').mockResolvedValue(false) + const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) + + await extension.onLoad() + + expect(mkdirSpy).toHaveBeenCalledWith('file://threads') + }) + + it('should log message on unload', () => { + const consoleSpy = jest.spyOn(console, 'debug').mockImplementation() + + extension.onUnload() + + expect(consoleSpy).toHaveBeenCalledWith( + 'JSONConversationalExtension unloaded' + ) + }) + + it('should return sorted threads', async () => { + jest + .spyOn(extension, 'getValidThreadDirs') + .mockResolvedValue(['dir1', 'dir2']) + jest + .spyOn(extension, 'readThread') + .mockResolvedValueOnce({ updated: '2023-01-01' }) + .mockResolvedValueOnce({ updated: '2023-01-02' }) + + const threads = await extension.getThreads() + + expect(threads).toEqual([ + { updated: '2023-01-02' }, + { updated: '2023-01-01' }, + ]) + }) + + it('should ignore broken threads', async () => { + jest + .spyOn(extension, 'getValidThreadDirs') + .mockResolvedValue(['dir1', 'dir2']) + jest + .spyOn(extension, 'readThread') + .mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' })) + .mockResolvedValueOnce('this_is_an_invalid_json_content') + + const threads = await extension.getThreads() + + expect(threads).toEqual([{ updated: '2023-01-01' }]) + }) + + it('should save thread', async () => { + // @ts-ignore + jest.spyOn(fs, 'existsSync').mockResolvedValue(false) + const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) + const writeFileSyncSpy = jest + .spyOn(fs, 'writeFileSync') + .mockResolvedValue({}) + + const thread = { id: '1', updated: '2023-01-01' } as any + await extension.saveThread(thread) + + expect(mkdirSpy).toHaveBeenCalled() + expect(writeFileSyncSpy).toHaveBeenCalled() + }) + + it('should delete thread', async () => { + const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({}) + + await extension.deleteThread('1') + + expect(rmSpy).toHaveBeenCalled() + }) + + it('should add new message', async () => { + // @ts-ignore + jest.spyOn(fs, 'existsSync').mockResolvedValue(false) + const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) + const appendFileSyncSpy = jest + .spyOn(fs, 'appendFileSync') + .mockResolvedValue({}) + + const message = { + thread_id: '1', + content: [{ type: 'text', text: { annotations: [] } }], + } as any + await extension.addNewMessage(message) + + expect(mkdirSpy).toHaveBeenCalled() + expect(appendFileSyncSpy).toHaveBeenCalled() + }) + + it('should store image', async () => { + const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) + + await extension.storeImage( + 'data:image/png;base64,abcd', + 'path/to/image.png' + ) + + expect(writeBlobSpy).toHaveBeenCalled() + }) + + it('should store file', async () => { + const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) + + await extension.storeFile( + 'data:application/pdf;base64,abcd', + 'path/to/file.pdf' + ) + + expect(writeBlobSpy).toHaveBeenCalled() + }) + + it('should write messages', async () => { + // @ts-ignore + jest.spyOn(fs, 'existsSync').mockResolvedValue(false) + const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) + const writeFileSyncSpy = jest + .spyOn(fs, 'writeFileSync') + .mockResolvedValue({}) + + const messages 
= [{ id: '1', thread_id: '1', content: [] }] as any + await extension.writeMessages('1', messages) + + expect(mkdirSpy).toHaveBeenCalled() + expect(writeFileSyncSpy).toHaveBeenCalled() + }) + + it('should get all messages on string response', async () => { + jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl']) + jest.spyOn(fs, 'readFileSync').mockResolvedValue('{"id":"1"}\n{"id":"2"}\n') + + const messages = await extension.getAllMessages('1') + + expect(messages).toEqual([{ id: '1' }, { id: '2' }]) + }) + + it('should get all messages on object response', async () => { + jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl']) + jest.spyOn(fs, 'readFileSync').mockResolvedValue({ id: 1 }) + + const messages = await extension.getAllMessages('1') + + expect(messages).toEqual([{ id: 1 }]) + }) + + it('get all messages return empty on error', async () => { + jest.spyOn(fs, 'readdirSync').mockRejectedValue(['messages.jsonl']) + + const messages = await extension.getAllMessages('1') + + expect(messages).toEqual([]) + }) + + it('return empty messages on no messages file', async () => { + jest.spyOn(fs, 'readdirSync').mockResolvedValue([]) + + const messages = await extension.getAllMessages('1') + + expect(messages).toEqual([]) + }) + + it('should ignore error message', async () => { + jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl']) + jest + .spyOn(fs, 'readFileSync') + .mockResolvedValue('{"id":"1"}\nyolo\n{"id":"2"}\n') + + const messages = await extension.getAllMessages('1') + + expect(messages).toEqual([{ id: '1' }, { id: '2' }]) + }) + + it('should create thread folder on load if it does not exist', async () => { + // @ts-ignore + jest.spyOn(fs, 'existsSync').mockResolvedValue(false) + const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) + + await extension.onLoad() + + expect(mkdirSpy).toHaveBeenCalledWith('file://threads') + }) + + it('should log message on unload', () => { + const consoleSpy = jest.spyOn(console, 'debug').mockImplementation() + + extension.onUnload() + + expect(consoleSpy).toHaveBeenCalledWith( + 'JSONConversationalExtension unloaded' + ) + }) + + it('should return sorted threads', async () => { + jest + .spyOn(extension, 'getValidThreadDirs') + .mockResolvedValue(['dir1', 'dir2']) + jest + .spyOn(extension, 'readThread') + .mockResolvedValueOnce({ updated: '2023-01-01' }) + .mockResolvedValueOnce({ updated: '2023-01-02' }) + + const threads = await extension.getThreads() + + expect(threads).toEqual([ + { updated: '2023-01-02' }, + { updated: '2023-01-01' }, + ]) + }) + + it('should ignore broken threads', async () => { + jest + .spyOn(extension, 'getValidThreadDirs') + .mockResolvedValue(['dir1', 'dir2']) + jest + .spyOn(extension, 'readThread') + .mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' })) + .mockResolvedValueOnce('this_is_an_invalid_json_content') + + const threads = await extension.getThreads() + + expect(threads).toEqual([{ updated: '2023-01-01' }]) + }) + + it('should save thread', async () => { + // @ts-ignore + jest.spyOn(fs, 'existsSync').mockResolvedValue(false) + const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) + const writeFileSyncSpy = jest + .spyOn(fs, 'writeFileSync') + .mockResolvedValue({}) + + const thread = { id: '1', updated: '2023-01-01' } as any + await extension.saveThread(thread) + + expect(mkdirSpy).toHaveBeenCalled() + expect(writeFileSyncSpy).toHaveBeenCalled() + }) + + it('should delete thread', async () => { + const rmSpy = jest.spyOn(fs, 
'rm').mockResolvedValue({}) + + await extension.deleteThread('1') + + expect(rmSpy).toHaveBeenCalled() + }) + + it('should add new message', async () => { + // @ts-ignore + jest.spyOn(fs, 'existsSync').mockResolvedValue(false) + const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) + const appendFileSyncSpy = jest + .spyOn(fs, 'appendFileSync') + .mockResolvedValue({}) + + const message = { + thread_id: '1', + content: [{ type: 'text', text: { annotations: [] } }], + } as any + await extension.addNewMessage(message) + + expect(mkdirSpy).toHaveBeenCalled() + expect(appendFileSyncSpy).toHaveBeenCalled() + }) + + it('should add new image message', async () => { + jest + .spyOn(fs, 'existsSync') + // @ts-ignore + .mockResolvedValueOnce(false) + // @ts-ignore + .mockResolvedValueOnce(false) + // @ts-ignore + .mockResolvedValueOnce(true) + const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) + const appendFileSyncSpy = jest + .spyOn(fs, 'appendFileSync') + .mockResolvedValue({}) + jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) + + const message = { + thread_id: '1', + content: [ + { type: 'image', text: { annotations: ['data:image;base64,hehe'] } }, + ], + } as any + await extension.addNewMessage(message) + + expect(mkdirSpy).toHaveBeenCalled() + expect(appendFileSyncSpy).toHaveBeenCalled() + }) + + it('should add new pdf message', async () => { + jest + .spyOn(fs, 'existsSync') + // @ts-ignore + .mockResolvedValueOnce(false) + // @ts-ignore + .mockResolvedValueOnce(false) + // @ts-ignore + .mockResolvedValueOnce(true) + const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) + const appendFileSyncSpy = jest + .spyOn(fs, 'appendFileSync') + .mockResolvedValue({}) + jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) + + const message = { + thread_id: '1', + content: [ + { type: 'pdf', text: { annotations: ['data:pdf;base64,hehe'] } }, + ], + } as any + await extension.addNewMessage(message) + + expect(mkdirSpy).toHaveBeenCalled() + expect(appendFileSyncSpy).toHaveBeenCalled() + }) + + it('should store image', async () => { + const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) + + await extension.storeImage( + 'data:image/png;base64,abcd', + 'path/to/image.png' + ) + + expect(writeBlobSpy).toHaveBeenCalled() + }) + + it('should store file', async () => { + const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) + + await extension.storeFile( + 'data:application/pdf;base64,abcd', + 'path/to/file.pdf' + ) + + expect(writeBlobSpy).toHaveBeenCalled() + }) +}) + +describe('test readThread', () => { + let extension: JSONConversationalExtension + + beforeEach(() => { + // @ts-ignore + extension = new JSONConversationalExtension() + }) + + it('should read thread', async () => { + jest + .spyOn(fs, 'readFileSync') + .mockResolvedValue(JSON.stringify({ id: '1' })) + const thread = await extension.readThread('1') + expect(thread).toEqual(`{"id":"1"}`) + }) + + it('getValidThreadDirs should return valid thread directories', async () => { + jest + .spyOn(fs, 'readdirSync') + .mockResolvedValueOnce(['1', '2', '3']) + .mockResolvedValueOnce(['thread.json']) + .mockResolvedValueOnce(['thread.json']) + .mockResolvedValueOnce([]) + // @ts-ignore + jest.spyOn(fs, 'existsSync').mockResolvedValue(true) + jest.spyOn(fs, 'fileStat').mockResolvedValue({ + isDirectory: true, + } as any) + const validThreadDirs = await extension.getValidThreadDirs() + expect(validThreadDirs).toEqual(['1', '2']) + }) +}) diff --git 
a/extensions/conversational-extension/src/index.ts b/extensions/conversational-extension/src/index.ts index 1bca75347..b34f09181 100644 --- a/extensions/conversational-extension/src/index.ts +++ b/extensions/conversational-extension/src/index.ts @@ -5,6 +5,7 @@ import { Thread, ThreadMessage, } from '@janhq/core' +import { safelyParseJSON } from './jsonUtil' /** * JSONConversationalExtension is a ConversationalExtension implementation that provides @@ -45,10 +46,11 @@ export default class JSONConversationalExtension extends ConversationalExtension if (result.status === 'fulfilled') { return typeof result.value === 'object' ? result.value - : JSON.parse(result.value) + : safelyParseJSON(result.value) } + return undefined }) - .filter((convo) => convo != null) + .filter((convo) => !!convo) convos.sort( (a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime() ) @@ -195,7 +197,7 @@ export default class JSONConversationalExtension extends ConversationalExtension * @param threadDirName the thread dir we are reading from. * @returns data of the thread */ - private async readThread(threadDirName: string): Promise { + async readThread(threadDirName: string): Promise { return fs.readFileSync( await joinPath([ JSONConversationalExtension._threadFolder, @@ -210,7 +212,7 @@ export default class JSONConversationalExtension extends ConversationalExtension * Returns a Promise that resolves to an array of thread directories. * @private */ - private async getValidThreadDirs(): Promise { + async getValidThreadDirs(): Promise { const fileInsideThread: string[] = await fs.readdirSync( JSONConversationalExtension._threadFolder ) @@ -266,7 +268,8 @@ export default class JSONConversationalExtension extends ConversationalExtension const messages: ThreadMessage[] = [] result.forEach((line: string) => { - messages.push(JSON.parse(line)) + const message = safelyParseJSON(line) + if (message) messages.push(safelyParseJSON(line)) }) return messages } catch (err) { diff --git a/extensions/conversational-extension/src/jsonUtil.ts b/extensions/conversational-extension/src/jsonUtil.ts new file mode 100644 index 000000000..7f83cadce --- /dev/null +++ b/extensions/conversational-extension/src/jsonUtil.ts @@ -0,0 +1,14 @@ +// Note about performance +// The v8 JavaScript engine used by Node.js cannot optimise functions which contain a try/catch block. +// v8 4.5 and above can optimise try/catch +export function safelyParseJSON(json) { + // This function cannot be optimised, it's best to + // keep it small! + var parsed + try { + parsed = JSON.parse(json) + } catch (e) { + return undefined + } + return parsed // Could be undefined! 
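+ // Illustrative usage (editorial sketch, not part of the original patch):
+ //   safelyParseJSON('{"updated":"2023-01-01"}')        -> { updated: '2023-01-01' }
+ //   safelyParseJSON('this_is_an_invalid_json_content') -> undefined, instead of throwing
+ // Callers such as getThreads() and getAllMessages() in index.ts can then filter out the
+ // undefined results rather than wrapping each JSON.parse call in its own try/catch.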
+} diff --git a/extensions/conversational-extension/tsconfig.json b/extensions/conversational-extension/tsconfig.json index 2477d58ce..8427123e7 100644 --- a/extensions/conversational-extension/tsconfig.json +++ b/extensions/conversational-extension/tsconfig.json @@ -10,5 +10,6 @@ "skipLibCheck": true, "rootDir": "./src" }, - "include": ["./src"] + "include": ["./src"], + "exclude": ["src/**/*.test.ts"] } diff --git a/extensions/inference-anthropic-extension/jest.config.js b/extensions/inference-anthropic-extension/jest.config.js new file mode 100644 index 000000000..3e32adceb --- /dev/null +++ b/extensions/inference-anthropic-extension/jest.config.js @@ -0,0 +1,9 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', + transform: { + 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest', + }, + transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'], +} diff --git a/extensions/inference-anthropic-extension/package.json b/extensions/inference-anthropic-extension/package.json index a9d30a8e5..19c0df5e8 100644 --- a/extensions/inference-anthropic-extension/package.json +++ b/extensions/inference-anthropic-extension/package.json @@ -9,6 +9,7 @@ "author": "Jan ", "license": "AGPL-3.0", "scripts": { + "test": "jest test", "build": "tsc -b . && webpack --config webpack.config.js", "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install", "sync:core": "cd ../.. && yarn build:core && cd extensions && rm yarn.lock && cd inference-anthropic-extension && yarn && yarn build:publish" diff --git a/extensions/inference-anthropic-extension/src/anthropic.test.ts b/extensions/inference-anthropic-extension/src/anthropic.test.ts new file mode 100644 index 000000000..703ead0fb --- /dev/null +++ b/extensions/inference-anthropic-extension/src/anthropic.test.ts @@ -0,0 +1,77 @@ +// Import necessary modules +import JanInferenceAnthropicExtension, { Settings } from '.' 
+import { PayloadType, ChatCompletionRole } from '@janhq/core' + +// Mocks +jest.mock('@janhq/core', () => ({ + RemoteOAIEngine: jest.fn().mockImplementation(() => ({ + registerSettings: jest.fn(), + registerModels: jest.fn(), + getSetting: jest.fn(), + onChange: jest.fn(), + onSettingUpdate: jest.fn(), + onLoad: jest.fn(), + headers: jest.fn(), + })), + PayloadType: jest.fn(), + ChatCompletionRole: { + User: 'user' as const, + Assistant: 'assistant' as const, + System: 'system' as const, + }, +})) + +// Helper functions +const createMockPayload = (): PayloadType => ({ + messages: [ + { role: ChatCompletionRole.System, content: 'Meow' }, + { role: ChatCompletionRole.User, content: 'Hello' }, + { role: ChatCompletionRole.Assistant, content: 'Hi there' }, + ], + model: 'claude-v1', + stream: false, +}) + +describe('JanInferenceAnthropicExtension', () => { + let extension: JanInferenceAnthropicExtension + + beforeEach(() => { + extension = new JanInferenceAnthropicExtension('', '') + extension.apiKey = 'mock-api-key' + extension.inferenceUrl = 'mock-endpoint' + jest.clearAllMocks() + }) + + it('should initialize with correct settings', async () => { + await extension.onLoad() + expect(extension.apiKey).toBe('mock-api-key') + expect(extension.inferenceUrl).toBe('mock-endpoint') + }) + + it('should transform payload correctly', () => { + const payload = createMockPayload() + const transformedPayload = extension.transformPayload(payload) + + expect(transformedPayload).toEqual({ + max_tokens: 4096, + model: 'claude-v1', + stream: false, + system: 'Meow', + messages: [ + { role: 'user', content: 'Hello' }, + { role: 'assistant', content: 'Hi there' }, + ], + }) + }) + + it('should transform response correctly', () => { + const nonStreamResponse = { content: [{ text: 'Test response' }] } + const streamResponse = + 'data: {"type":"content_block_delta","delta":{"text":"Hello"}}' + + expect(extension.transformResponse(nonStreamResponse)).toBe('Test response') + expect(extension.transformResponse(streamResponse)).toBe('Hello') + expect(extension.transformResponse('')).toBe('') + expect(extension.transformResponse('event: something')).toBe('') + }) +}) diff --git a/extensions/inference-anthropic-extension/src/index.ts b/extensions/inference-anthropic-extension/src/index.ts index f28a584f2..94da26d94 100644 --- a/extensions/inference-anthropic-extension/src/index.ts +++ b/extensions/inference-anthropic-extension/src/index.ts @@ -13,7 +13,7 @@ import { ChatCompletionRole } from '@janhq/core' declare const SETTINGS: Array declare const MODELS: Array -enum Settings { +export enum Settings { apiKey = 'anthropic-api-key', chatCompletionsEndPoint = 'chat-completions-endpoint', } @@ -23,6 +23,7 @@ type AnthropicPayloadType = { model?: string max_tokens?: number messages?: Array<{ role: string; content: string }> + system?: string } /** @@ -113,6 +114,10 @@ export default class JanInferenceAnthropicExtension extends RemoteOAIEngine { role: 'assistant', content: item.content as string, }) + } else if (item.role === ChatCompletionRole.System) { + // When using Claude, you can dramatically improve its performance by using the system parameter to give it a role. + // This technique, known as role prompting, is the most powerful way to use system prompts with Claude. 
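+ // Editorial example of the resulting shape (values taken from the accompanying test, not
+ // produced by this code at runtime): a payload whose messages are
+ //   [{ role: 'system', content: 'Meow' }, { role: 'user', content: 'Hello' }]
+ // is transformed into { system: 'Meow', messages: [{ role: 'user', content: 'Hello' }], ... },
+ // so the system prompt travels in Anthropic's dedicated `system` parameter.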
+ convertedData.system = item.content as string } }) diff --git a/extensions/inference-anthropic-extension/tsconfig.json b/extensions/inference-anthropic-extension/tsconfig.json index 2477d58ce..6db951c9e 100644 --- a/extensions/inference-anthropic-extension/tsconfig.json +++ b/extensions/inference-anthropic-extension/tsconfig.json @@ -10,5 +10,6 @@ "skipLibCheck": true, "rootDir": "./src" }, - "include": ["./src"] + "include": ["./src"], + "exclude": ["**/*.test.ts"] } diff --git a/extensions/inference-nitro-extension/package.json b/extensions/inference-nitro-extension/package.json index ac3ed180a..f484b4511 100644 --- a/extensions/inference-nitro-extension/package.json +++ b/extensions/inference-nitro-extension/package.json @@ -1,7 +1,7 @@ { "name": "@janhq/inference-cortex-extension", "productName": "Cortex Inference Engine", - "version": "1.0.17", + "version": "1.0.18", "description": "This extension embeds cortex.cpp, a lightweight inference engine written in C++. See https://jan.ai.\nAdditional dependencies could be installed to run without Cuda Toolkit installation.", "main": "dist/index.js", "node": "dist/node/index.cjs.js", diff --git a/extensions/inference-nitro-extension/resources/models/deepseek-coder-1.3b/model.json b/extensions/inference-nitro-extension/resources/models/deepseek-coder-1.3b/model.json index 36fceaad2..4d825cfeb 100644 --- a/extensions/inference-nitro-extension/resources/models/deepseek-coder-1.3b/model.json +++ b/extensions/inference-nitro-extension/resources/models/deepseek-coder-1.3b/model.json @@ -8,7 +8,7 @@ "id": "deepseek-coder-1.3b", "object": "model", "name": "Deepseek Coder 1.3B Instruct Q8", - "version": "1.3", + "version": "1.4", "description": "Deepseek Coder excelled in project-level code completion with advanced capabilities across multiple programming languages.", "format": "gguf", "settings": { @@ -22,13 +22,13 @@ "top_p": 0.95, "stream": true, "max_tokens": 16384, - "stop": [], + "stop": ["<|EOT|>"], "frequency_penalty": 0, "presence_penalty": 0 }, "metadata": { "author": "Deepseek, The Bloke", - "tags": ["Tiny", "Foundational Model"], + "tags": ["Tiny"], "size": 1430000000 }, "engine": "nitro" diff --git a/extensions/inference-nitro-extension/resources/models/deepseek-coder-34b/model.json b/extensions/inference-nitro-extension/resources/models/deepseek-coder-34b/model.json index 103c4cbcb..e87d6a643 100644 --- a/extensions/inference-nitro-extension/resources/models/deepseek-coder-34b/model.json +++ b/extensions/inference-nitro-extension/resources/models/deepseek-coder-34b/model.json @@ -2,13 +2,13 @@ "sources": [ { "filename": "deepseek-coder-33b-instruct.Q4_K_M.gguf", - "url": "https://huggingface.co/TheBloke/deepseek-coder-33B-instruct-GGUF/resolve/main/deepseek-coder-33b-instruct.Q4_K_M.gguf" + "url": "https://huggingface.co/mradermacher/deepseek-coder-33b-instruct-GGUF/resolve/main/deepseek-coder-33b-instruct.Q4_K_M.gguf" } ], "id": "deepseek-coder-34b", "object": "model", "name": "Deepseek Coder 33B Instruct Q4", - "version": "1.3", + "version": "1.4", "description": "Deepseek Coder excelled in project-level code completion with advanced capabilities across multiple programming languages.", "format": "gguf", "settings": { @@ -22,13 +22,13 @@ "top_p": 0.95, "stream": true, "max_tokens": 16384, - "stop": [], + "stop": ["<|EOT|>"], "frequency_penalty": 0, "presence_penalty": 0 }, "metadata": { - "author": "Deepseek, The Bloke", - "tags": ["34B", "Foundational Model"], + "author": "Deepseek", + "tags": ["33B"], "size": 19940000000 }, 
"engine": "nitro" diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index d79e076d4..6e825e8fd 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -22,6 +22,7 @@ import { downloadFile, DownloadState, DownloadEvent, + ModelFile, } from '@janhq/core' declare const CUDA_DOWNLOAD_URL: string @@ -94,7 +95,7 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine { this.nitroProcessInfo = health } - override loadModel(model: Model): Promise { + override loadModel(model: ModelFile): Promise { if (model.engine !== this.provider) return Promise.resolve() this.getNitroProcessHealthIntervalId = setInterval( () => this.periodicallyGetNitroHealth(), diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts index edc2d013d..98ca4572f 100644 --- a/extensions/inference-nitro-extension/src/node/index.ts +++ b/extensions/inference-nitro-extension/src/node/index.ts @@ -6,12 +6,12 @@ import fetchRT from 'fetch-retry' import { log, getSystemResourceInfo, - Model, InferenceEngine, ModelSettingParams, PromptTemplate, SystemInformation, getJanDataFolderPath, + ModelFile, } from '@janhq/core/node' import { executableNitroFile } from './execute' import terminate from 'terminate' @@ -25,7 +25,7 @@ const fetchRetry = fetchRT(fetch) */ interface ModelInitOptions { modelFolder: string - model: Model + model: ModelFile } // The PORT to use for the Nitro subprocess const PORT = 3928 @@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise { if (!settings?.ngl) { settings.ngl = 100 } - log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`) + log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`) return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, { method: 'POST', headers: { @@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise { }) .then((res) => { log( - `[CORTEX]::Debug: Load model success with response ${JSON.stringify( + `[CORTEX]:: Load model success with response ${JSON.stringify( res )}` ) @@ -260,7 +260,7 @@ function loadLLMModel(settings: any): Promise { async function validateModelStatus(modelId: string): Promise { // Send a GET request to the validation URL. // Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries. - log(`[CORTEX]::Debug: Validating model ${modelId}`) + log(`[CORTEX]:: Validating model ${modelId}`) return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, { method: 'POST', body: JSON.stringify({ @@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise { retryDelay: 300, }).then(async (res: Response) => { log( - `[CORTEX]::Debug: Validate model state with response ${JSON.stringify( + `[CORTEX]:: Validate model state with response ${JSON.stringify( res.status )}` ) @@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise { // Otherwise, return an object with an error message. 
if (body.model_loaded) { log( - `[CORTEX]::Debug: Validate model state success with response ${JSON.stringify( + `[CORTEX]:: Validate model state success with response ${JSON.stringify( body )}` ) @@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise { } const errorBody = await res.text() log( - `[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify( + `[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify( res.statusText )}` ) @@ -310,7 +310,7 @@ async function validateModelStatus(modelId: string): Promise { async function killSubprocess(): Promise { const controller = new AbortController() setTimeout(() => controller.abort(), 5000) - log(`[CORTEX]::Debug: Request to kill cortex`) + log(`[CORTEX]:: Request to kill cortex`) const killRequest = () => { return fetch(NITRO_HTTP_KILL_URL, { @@ -321,17 +321,17 @@ async function killSubprocess(): Promise { .then(() => tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) ) - .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .catch((err) => { log( - `[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}` + `[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}` ) throw 'PORT_NOT_AVAILABLE' }) } if (subprocess?.pid && process.platform !== 'darwin') { - log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) + log(`[CORTEX]:: Killing PID ${subprocess.pid}`) const pid = subprocess.pid return new Promise((resolve, reject) => { terminate(pid, function (err) { @@ -341,7 +341,7 @@ async function killSubprocess(): Promise { } else { tcpPortUsed .waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) - .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .then(() => resolve()) .catch(() => { log( @@ -362,7 +362,7 @@ async function killSubprocess(): Promise { * @returns A promise that resolves when the Nitro subprocess is started. 
*/ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { - log(`[CORTEX]::Debug: Spawning cortex subprocess...`) + log(`[CORTEX]:: Spawning cortex subprocess...`) return new Promise(async (resolve, reject) => { let executableOptions = executableNitroFile( @@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { const args: string[] = ['1', LOCAL_HOST, PORT.toString()] // Execute the binary log( - `[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}` + `[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}` ) log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`) @@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { // Handle subprocess output subprocess.stdout.on('data', (data: any) => { - log(`[CORTEX]::Debug: ${data}`) + log(`[CORTEX]:: ${data}`) }) subprocess.stderr.on('data', (data: any) => { @@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { }) subprocess.on('close', (code: any) => { - log(`[CORTEX]::Debug: cortex exited with code: ${code}`) + log(`[CORTEX]:: cortex exited with code: ${code}`) subprocess = undefined reject(`child process exited with code ${code}`) }) @@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { tcpPortUsed .waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000) .then(() => { - log(`[CORTEX]::Debug: cortex is ready`) + log(`[CORTEX]:: cortex is ready`) resolve() }) }) diff --git a/extensions/inference-openai-extension/jest.config.js b/extensions/inference-openai-extension/jest.config.js new file mode 100644 index 000000000..3e32adceb --- /dev/null +++ b/extensions/inference-openai-extension/jest.config.js @@ -0,0 +1,9 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', + transform: { + 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest', + }, + transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'], +} diff --git a/extensions/inference-openai-extension/resources/models.json b/extensions/inference-openai-extension/resources/models.json index 6852a1892..72517d540 100644 --- a/extensions/inference-openai-extension/resources/models.json +++ b/extensions/inference-openai-extension/resources/models.json @@ -119,5 +119,65 @@ ] }, "engine": "openai" + }, + { + "sources": [ + { + "url": "https://openai.com" + } + ], + "id": "o1-preview", + "object": "model", + "name": "OpenAI o1-preview", + "version": "1.0", + "description": "OpenAI o1-preview is a new model with complex reasoning", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "OpenAI", + "tags": [ + "General" + ] + }, + "engine": "openai" + }, + { + "sources": [ + { + "url": "https://openai.com" + } + ], + "id": "o1-mini", + "object": "model", + "name": "OpenAI o1-mini", + "version": "1.0", + "description": "OpenAI o1-mini is a lightweight reasoning model", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "OpenAI", + "tags": [ + "General" + ] + }, + "engine": "openai" } ] diff --git 
a/extensions/inference-openai-extension/src/OpenAIExtension.test.ts b/extensions/inference-openai-extension/src/OpenAIExtension.test.ts new file mode 100644 index 000000000..4d46bc007 --- /dev/null +++ b/extensions/inference-openai-extension/src/OpenAIExtension.test.ts @@ -0,0 +1,54 @@ +/** + * @jest-environment jsdom + */ +jest.mock('@janhq/core', () => ({ + ...jest.requireActual('@janhq/core/node'), + RemoteOAIEngine: jest.fn().mockImplementation(() => ({ + onLoad: jest.fn(), + registerSettings: jest.fn(), + registerModels: jest.fn(), + getSetting: jest.fn(), + onSettingUpdate: jest.fn(), + })), +})) +import JanInferenceOpenAIExtension, { Settings } from '.' + +describe('JanInferenceOpenAIExtension', () => { + let extension: JanInferenceOpenAIExtension + + beforeEach(() => { + // @ts-ignore + extension = new JanInferenceOpenAIExtension() + }) + + it('should initialize with settings and models', async () => { + await extension.onLoad() + // Assuming there are some default SETTINGS and MODELS being registered + expect(extension.apiKey).toBe(undefined) + expect(extension.inferenceUrl).toBe('') + }) + + it('should transform the payload for preview models', () => { + const payload: any = { + max_tokens: 100, + model: 'o1-mini', + // Add other required properties... + } + + const transformedPayload = extension.transformPayload(payload) + expect(transformedPayload.max_completion_tokens).toBe(payload.max_tokens) + expect(transformedPayload).not.toHaveProperty('max_tokens') + expect(transformedPayload).toHaveProperty('max_completion_tokens') + }) + + it('should not transform the payload for non-preview models', () => { + const payload: any = { + max_tokens: 100, + model: 'non-preview-model', + // Add other required properties... + } + + const transformedPayload = extension.transformPayload(payload) + expect(transformedPayload).toEqual(payload) + }) +}) diff --git a/extensions/inference-openai-extension/src/index.ts b/extensions/inference-openai-extension/src/index.ts index 60446ccce..44c243adf 100644 --- a/extensions/inference-openai-extension/src/index.ts +++ b/extensions/inference-openai-extension/src/index.ts @@ -6,16 +6,17 @@ * @module inference-openai-extension/src/index */ -import { RemoteOAIEngine, SettingComponentProps } from '@janhq/core' +import { ModelRuntimeParams, PayloadType, RemoteOAIEngine } from '@janhq/core' declare const SETTINGS: Array declare const MODELS: Array -enum Settings { +export enum Settings { apiKey = 'openai-api-key', chatCompletionsEndPoint = 'chat-completions-endpoint', } - +type OpenAIPayloadType = PayloadType & + ModelRuntimeParams & { max_completion_tokens: number } /** * A class that implements the InferenceExtension interface from the @janhq/core package. * The class provides methods for initializing and stopping a model, and for making inference requests. @@ -24,6 +25,7 @@ enum Settings { export default class JanInferenceOpenAIExtension extends RemoteOAIEngine { inferenceUrl: string = '' provider: string = 'openai' + previewModels = ['o1-mini', 'o1-preview'] override async onLoad(): Promise { super.onLoad() @@ -63,4 +65,24 @@ export default class JanInferenceOpenAIExtension extends RemoteOAIEngine { } } } + + /** + * Transform the payload before sending it to the inference endpoint. + * The newer preview models (o1-mini, o1-preview) replace the max_tokens parameter with max_completion_tokens; + * other models do not.
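 + * As an illustration (not an exhaustive payload), { model: 'o1-mini', max_tokens: 100 } would be sent + * as { model: 'o1-mini', max_completion_tokens: 100 }, while non-preview models keep max_tokens unchanged.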
+ * @param payload + * @returns + */ + transformPayload = (payload: OpenAIPayloadType): OpenAIPayloadType => { + // Transform the payload for preview models + if (this.previewModels.includes(payload.model)) { + const { max_tokens, ...params } = payload + return { + ...params, + max_completion_tokens: max_tokens, + } + } + // Pass through for non-preview models + return payload + } } diff --git a/extensions/inference-openai-extension/tsconfig.json b/extensions/inference-openai-extension/tsconfig.json index 2477d58ce..6db951c9e 100644 --- a/extensions/inference-openai-extension/tsconfig.json +++ b/extensions/inference-openai-extension/tsconfig.json @@ -10,5 +10,6 @@ "skipLibCheck": true, "rootDir": "./src" }, - "include": ["./src"] + "include": ["./src"], + "exclude": ["**/*.test.ts"] } diff --git a/extensions/model-extension/jest.config.js b/extensions/model-extension/jest.config.js new file mode 100644 index 000000000..3e32adceb --- /dev/null +++ b/extensions/model-extension/jest.config.js @@ -0,0 +1,9 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', + transform: { + 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest', + }, + transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'], +} diff --git a/extensions/model-extension/package.json b/extensions/model-extension/package.json index 4a2c61b71..9a406dcf4 100644 --- a/extensions/model-extension/package.json +++ b/extensions/model-extension/package.json @@ -8,6 +8,7 @@ "author": "Jan ", "license": "AGPL-3.0", "scripts": { + "test": "jest", "build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs", "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" }, diff --git a/extensions/model-extension/rollup.config.ts b/extensions/model-extension/rollup.config.ts index c3f3acc77..d36d8ffac 100644 --- a/extensions/model-extension/rollup.config.ts +++ b/extensions/model-extension/rollup.config.ts @@ -27,7 +27,7 @@ export default [ // Allow json resolution json(), // Compile TypeScript files - typescript({ useTsconfigDeclarationDir: true }), + typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }), // Compile TypeScript files // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) // commonjs(), @@ -62,7 +62,7 @@ export default [ // Allow json resolution json(), // Compile TypeScript files - typescript({ useTsconfigDeclarationDir: true }), + typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }), // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) commonjs(), // Allow node_modules resolution, so you can use 'external' to control diff --git a/extensions/model-extension/src/helpers/path.test.ts b/extensions/model-extension/src/helpers/path.test.ts new file mode 100644 index 000000000..64ca65d8a --- /dev/null +++ b/extensions/model-extension/src/helpers/path.test.ts @@ -0,0 +1,87 @@ +import { extractFileName } from './path'; + +describe('extractFileName Function', () => { + it('should correctly extract the file name with the provided file extension', () => { + const url = 'http://example.com/some/path/to/file.ext'; + const fileExtension = '.ext'; + const fileName = extractFileName(url, fileExtension); + expect(fileName).toBe('file.ext'); + }); + + it('should correctly append the file extension if it does not already exist in the file name', () => { + 
const url = 'http://example.com/some/path/to/file'; + const fileExtension = '.txt'; + const fileName = extractFileName(url, fileExtension); + expect(fileName).toBe('file.txt'); + }); + + it('should handle cases where the URL does not have a file extension correctly', () => { + const url = 'http://example.com/some/path/to/file'; + const fileExtension = '.jpg'; + const fileName = extractFileName(url, fileExtension); + expect(fileName).toBe('file.jpg'); + }); + + it('should correctly handle URLs without a trailing slash', () => { + const url = 'http://example.com/some/path/tofile'; + const fileExtension = '.txt'; + const fileName = extractFileName(url, fileExtension); + expect(fileName).toBe('tofile.txt'); + }); + + it('should correctly handle URLs with multiple file extensions', () => { + const url = 'http://example.com/some/path/tofile.tar.gz'; + const fileExtension = '.gz'; + const fileName = extractFileName(url, fileExtension); + expect(fileName).toBe('tofile.tar.gz'); + }); + + it('should correctly handle URLs with special characters', () => { + const url = 'http://example.com/some/path/tófílë.extë'; + const fileExtension = '.extë'; + const fileName = extractFileName(url, fileExtension); + expect(fileName).toBe('tófílë.extë'); + }); + + it('should correctly handle URLs that are just a file with no path', () => { + const url = 'http://example.com/file.txt'; + const fileExtension = '.txt'; + const fileName = extractFileName(url, fileExtension); + expect(fileName).toBe('file.txt'); + }); + + it('should correctly handle URLs that have special query parameters', () => { + const url = 'http://example.com/some/path/tofile.ext?query=1'; + const fileExtension = '.ext'; + const fileName = extractFileName(url.split('?')[0], fileExtension); + expect(fileName).toBe('tofile.ext'); + }); + + it('should correctly handle URLs that have uppercase characters', () => { + const url = 'http://EXAMPLE.COM/PATH/TO/FILE.EXT'; + const fileExtension = '.ext'; + const fileName = extractFileName(url, fileExtension); + expect(fileName).toBe('FILE.EXT'); + }); + + it('should correctly handle invalid URLs', () => { + const url = 'invalid-url'; + const fileExtension = '.txt'; + const fileName = extractFileName(url, fileExtension); + expect(fileName).toBe('invalid-url.txt'); + }); + + it('should correctly handle empty URLs', () => { + const url = ''; + const fileExtension = '.txt'; + const fileName = extractFileName(url, fileExtension); + expect(fileName).toBe('.txt'); + }); + + it('should correctly handle undefined URLs', () => { + const url = undefined; + const fileExtension = '.txt'; + const fileName = extractFileName(url as any, fileExtension); + expect(fileName).toBe('.txt'); + }); +}); diff --git a/extensions/model-extension/src/helpers/path.ts b/extensions/model-extension/src/helpers/path.ts index cbb151aa6..6091005b8 100644 --- a/extensions/model-extension/src/helpers/path.ts +++ b/extensions/model-extension/src/helpers/path.ts @@ -3,6 +3,8 @@ */ export function extractFileName(url: string, fileExtension: string): string { + if(!url) return fileExtension + const extractedFileName = url.split('/').pop() const fileName = extractedFileName.toLowerCase().endsWith(fileExtension) ? 
extractedFileName diff --git a/extensions/model-extension/src/index.test.ts b/extensions/model-extension/src/index.test.ts new file mode 100644 index 000000000..5b126d4cc --- /dev/null +++ b/extensions/model-extension/src/index.test.ts @@ -0,0 +1,788 @@ +/** + * @jest-environment jsdom + */ +const readDirSyncMock = jest.fn() +const existMock = jest.fn() +const readFileSyncMock = jest.fn() +const downloadMock = jest.fn() +const mkdirMock = jest.fn() +const writeFileSyncMock = jest.fn() +const copyFileMock = jest.fn() +const dirNameMock = jest.fn() +const executeMock = jest.fn() + +jest.mock('@janhq/core', () => ({ + ...jest.requireActual('@janhq/core/node'), + events: { + emit: jest.fn(), + }, + fs: { + existsSync: existMock, + readdirSync: readDirSyncMock, + readFileSync: readFileSyncMock, + writeFileSync: writeFileSyncMock, + mkdir: mkdirMock, + copyFile: copyFileMock, + fileStat: () => ({ + isDirectory: false, + }), + }, + dirName: dirNameMock, + joinPath: (paths) => paths.join('/'), + ModelExtension: jest.fn(), + downloadFile: downloadMock, + executeOnMain: executeMock, +})) + +jest.mock('@huggingface/gguf') + +global.fetch = jest.fn(() => + Promise.resolve({ + json: () => Promise.resolve({ test: 100 }), + arrayBuffer: jest.fn(), + }) +) as jest.Mock + +import JanModelExtension from '.' +import { fs, dirName } from '@janhq/core' +import { gguf } from '@huggingface/gguf' + +describe('JanModelExtension', () => { + let sut: JanModelExtension + + beforeAll(() => { + // @ts-ignore + sut = new JanModelExtension() + }) + + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('getConfiguredModels', () => { + describe("when there's no models are pre-populated", () => { + it('should return empty array', async () => { + // Mock configured models data + const configuredModels = [] + existMock.mockReturnValue(true) + readDirSyncMock.mockReturnValue([]) + + const result = await sut.getConfiguredModels() + expect(result).toEqual([]) + }) + }) + + describe("when there's are pre-populated models - all flattened", () => { + it('returns configured models data - flatten folder - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getConfiguredModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model.json', + id: '2', + }), + ]) + ) + }) + }) + + describe("when there's are 
pre-populated models - there are nested folders", () => { + it('returns configured models data - flatten folder - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else if (path.includes('model2/model2-1')) + return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getConfiguredModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('getDownloadedModels', () => { + describe('no models downloaded', () => { + it('should return empty array', async () => { + // Mock downloaded models data + existMock.mockReturnValue(true) + readDirSyncMock.mockReturnValue([]) + + const result = await sut.getDownloadedModels() + expect(result).toEqual([]) + }) + }) + describe('only one model is downloaded', () => { + describe('flatten folder', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2'] + else if (path === 'file://models/model1') + return ['model.json', 'test.gguf'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + ]) + ) + }) + }) + }) + + describe('all models are downloaded', () => { + describe('nested folders', () => { + it('returns downloaded models - 
with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else return ['model.json', 'test.gguf'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('all models are downloaded with uppercased GGUF files', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else if (path === 'file://models/model1') + return ['model.json', 'test.GGUF'] + else return ['model.json', 'test.gguf'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + + describe('all models are downloaded - GGUF & Tensort RT', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: 
{}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else if (path === 'file://models/model1') + return ['model.json', 'test.gguf'] + else return ['model.json', 'test.engine'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('deleteModel', () => { + describe('model is a GGUF model', () => { + it('should delete the GGUF file', async () => { + fs.unlinkSync = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + readDirSyncMock.mockImplementation((path) => { + return ['model.json', 'test.gguf'] + }) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledWith( + 'file://models/model1/test.gguf' + ) + }) + + it('no gguf file presented', async () => { + fs.unlinkSync = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + readDirSyncMock.mockReturnValue(['model.json']) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledTimes(0) + }) + + it('delete an imported model', async () => { + fs.rm = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + readDirSyncMock.mockReturnValue(['model.json', 'test.gguf']) + + // MARK: This is a tricky logic implement? 
+ // I will just add test for now but will align on the legacy implementation + fs.readFileSync = jest.fn().mockReturnValue( + JSON.stringify({ + metadata: { + author: 'user', + }, + }) + ) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.rm).toHaveBeenCalledWith('file://models/model1') + }) + + it('delete tensorrt-models', async () => { + fs.rm = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + readDirSyncMock.mockReturnValue(['model.json', 'test.engine']) + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledWith( + 'file://models/model1/test.engine' + ) + }) + }) + }) + + describe('downloadModel', () => { + const model: any = { + id: 'model-id', + name: 'Test Model', + sources: [ + { url: 'http://example.com/model.gguf', filename: 'model.gguf' }, + ], + engine: 'test-engine', + } + + const network = { + ignoreSSL: true, + proxy: 'http://proxy.example.com', + } + + const gpuSettings: any = { + gpus: [{ name: 'nvidia-rtx-3080', arch: 'ampere' }], + } + + it('should reject with invalid gguf metadata', async () => { + existMock.mockImplementation(() => false) + + expect( + sut.downloadModel(model, gpuSettings, network) + ).rejects.toBeTruthy() + }) + + it('should download corresponding ID', async () => { + existMock.mockImplementation(() => true) + dirNameMock.mockImplementation(() => 'file://models/model1') + downloadMock.mockImplementation(() => { + return Promise.resolve({}) + }) + + expect( + await sut.downloadModel( + { ...model, file_path: 'file://models/model1/model.json' }, + gpuSettings, + network + ) + ).toBeUndefined() + + expect(downloadMock).toHaveBeenCalledWith( + { + localPath: 'file://models/model1/model.gguf', + modelId: 'model-id', + url: 'http://example.com/model.gguf', + }, + { ignoreSSL: true, proxy: 'http://proxy.example.com' } + ) + }) + + it('should handle invalid model file', async () => { + executeMock.mockResolvedValue({}) + + fs.readFileSync = jest.fn(() => { + return JSON.stringify({ metadata: { author: 'user' } }) + }) + + expect( + sut.downloadModel( + { ...model, file_path: 'file://models/model1/model.json' }, + gpuSettings, + network + ) + ).resolves.not.toThrow() + + expect(downloadMock).not.toHaveBeenCalled() + }) + it('should handle model file with no sources', async () => { + executeMock.mockResolvedValue({}) + const modelWithoutSources = { ...model, sources: [] } + + expect( + sut.downloadModel( + { + ...modelWithoutSources, + file_path: 'file://models/model1/model.json', + }, + gpuSettings, + network + ) + ).resolves.toBe(undefined) + + expect(downloadMock).not.toHaveBeenCalled() + }) + + it('should handle model file with multiple sources', async () => { + const modelWithMultipleSources = { + ...model, + sources: [ + { url: 'http://example.com/model1.gguf', filename: 'model1.gguf' }, + { url: 'http://example.com/model2.gguf', filename: 'model2.gguf' }, + ], + } + + executeMock.mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + ;(gguf as jest.Mock).mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + // @ts-ignore + global.NODE = 'node' + // @ts-ignore + global.DEFAULT_MODEL = { + parameters: { stop: [] }, + } + downloadMock.mockImplementation(() => { + return Promise.resolve({}) + }) + 
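 +      // Path resolution sketch (mirrors the download logic in src/index.ts): when model.file_path is set, +      // each source is written next to model.json via dirName(file_path), so both sources below are +      // expected to resolve under file://models/model1/ with one download request per source.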
+ expect( + await sut.downloadModel( + { + ...modelWithMultipleSources, + file_path: 'file://models/model1/model.json', + }, + gpuSettings, + network + ) + ).toBeUndefined() + + expect(downloadMock).toHaveBeenCalledWith( + { + localPath: 'file://models/model1/model1.gguf', + modelId: 'model-id', + url: 'http://example.com/model1.gguf', + }, + { ignoreSSL: true, proxy: 'http://proxy.example.com' } + ) + + expect(downloadMock).toHaveBeenCalledWith( + { + localPath: 'file://models/model1/model2.gguf', + modelId: 'model-id', + url: 'http://example.com/model2.gguf', + }, + { ignoreSSL: true, proxy: 'http://proxy.example.com' } + ) + }) + + it('should handle model file with no file_path', async () => { + executeMock.mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + ;(gguf as jest.Mock).mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + // @ts-ignore + global.NODE = 'node' + // @ts-ignore + global.DEFAULT_MODEL = { + parameters: { stop: [] }, + } + const modelWithoutFilepath = { ...model, file_path: undefined } + + await sut.downloadModel(modelWithoutFilepath, gpuSettings, network) + + expect(downloadMock).toHaveBeenCalledWith( + expect.objectContaining({ + localPath: 'file://models/model-id/model.gguf', + }), + expect.anything() + ) + }) + + it('should handle model file with invalid file_path', async () => { + executeMock.mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + ;(gguf as jest.Mock).mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + // @ts-ignore + global.NODE = 'node' + // @ts-ignore + global.DEFAULT_MODEL = { + parameters: { stop: [] }, + } + const modelWithInvalidFilepath = { + ...model, + file_path: 'file://models/invalid-model.json', + } + + await sut.downloadModel(modelWithInvalidFilepath, gpuSettings, network) + + expect(downloadMock).toHaveBeenCalledWith( + expect.objectContaining({ + localPath: 'file://models/model1/model.gguf', + }), + expect.anything() + ) + }) + }) +}) diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index e2f68a58c..20d23b747 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -22,6 +22,8 @@ import { getFileSize, AllQuantizations, ModelEvent, + ModelFile, + dirName, } from '@janhq/core' import { extractFileName } from './helpers/path' @@ -48,16 +50,7 @@ export default class JanModelExtension extends ModelExtension { ] private static readonly _tensorRtEngineFormat = '.engine' private static readonly _supportedGpuArch = ['ampere', 'ada'] - private static readonly _safetensorsRegexs = [ - /model\.safetensors$/, - /model-[0-9]+-of-[0-9]+\.safetensors$/, - ] - private static readonly _pytorchRegexs = [ - /pytorch_model\.bin$/, - /consolidated\.[0-9]+\.pth$/, - /pytorch_model-[0-9]+-of-[0-9]+\.bin$/, - /.*\.pt$/, - ] + interrupted = false /** @@ -83,15 +76,32 @@ export default class JanModelExtension extends ModelExtension { * @returns A Promise that resolves when the model is downloaded. */ async downloadModel( - model: Model, + model: ModelFile, gpuSettings?: GpuSetting, network?: { ignoreSSL?: boolean; proxy?: string } ): Promise { - // create corresponding directory + // Create corresponding directory const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id]) if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath) - const modelJsonPath = await joinPath([modelDirPath, 'model.json']) + const modelJsonPath = + model.file_path ?? 
(await joinPath([modelDirPath, 'model.json'])) + + // Download HF model - model.json not exist if (!(await fs.existsSync(modelJsonPath))) { + // It supports only one source for HF download + const metadata = await this.fetchModelMetadata(model.sources[0].url) + const updatedModel = await this.retrieveGGUFMetadata(metadata) + if (updatedModel) { + // Update model settings + model.settings = { + ...model.settings, + ...updatedModel.settings, + } + model.parameters = { + ...model.parameters, + ...updatedModel.parameters, + } + } await fs.writeFileSync(modelJsonPath, JSON.stringify(model, null, 2)) events.emit(ModelEvent.OnModelsUpdate, {}) } @@ -142,11 +152,15 @@ export default class JanModelExtension extends ModelExtension { JanModelExtension._supportedModelFormat ) if (source.filename) { - path = await joinPath([modelDirPath, source.filename]) + path = model.file_path + ? await joinPath([await dirName(model.file_path), source.filename]) + : await joinPath([modelDirPath, source.filename]) } + const downloadRequest: DownloadRequest = { url: source.url, localPath: path, + modelId: model.id, } downloadFile(downloadRequest, network) } @@ -156,10 +170,13 @@ export default class JanModelExtension extends ModelExtension { model.sources[0]?.url, JanModelExtension._supportedModelFormat ) - const path = await joinPath([modelDirPath, fileName]) + const path = model.file_path + ? await joinPath([await dirName(model.file_path), fileName]) + : await joinPath([modelDirPath, fileName]) const downloadRequest: DownloadRequest = { url: model.sources[0]?.url, localPath: path, + modelId: model.id, } downloadFile(downloadRequest, network) @@ -319,9 +336,9 @@ export default class JanModelExtension extends ModelExtension { * @param filePath - The path to the model file to delete. * @returns A Promise that resolves when the model is deleted. */ - async deleteModel(modelId: string): Promise { + async deleteModel(model: ModelFile): Promise { try { - const dirPath = await joinPath([JanModelExtension._homeDir, modelId]) + const dirPath = await dirName(model.file_path) const jsonFilePath = await joinPath([ dirPath, JanModelExtension._modelMetadataFileName, @@ -330,6 +347,8 @@ export default class JanModelExtension extends ModelExtension { await this.readModelMetadata(jsonFilePath) ) as Model + // TODO: This is so tricky? + // Should depend on sources? const isUserImportModel = modelInfo.metadata?.author?.toLowerCase() === 'user' if (isUserImportModel) { @@ -350,30 +369,11 @@ export default class JanModelExtension extends ModelExtension { } } - /** - * Saves a model file. - * @param model - The model to save. - * @returns A Promise that resolves when the model is saved. - */ - async saveModel(model: Model): Promise { - const jsonFilePath = await joinPath([ - JanModelExtension._homeDir, - model.id, - JanModelExtension._modelMetadataFileName, - ]) - - try { - await fs.writeFileSync(jsonFilePath, JSON.stringify(model, null, 2)) - } catch (err) { - console.error(err) - } - } - /** * Gets all downloaded models. * @returns A Promise that resolves with an array of all models. 
*/ - async getDownloadedModels(): Promise { + async getDownloadedModels(): Promise { return await this.getModelsMetadata( async (modelDir: string, model: Model) => { if (!JanModelExtension._offlineInferenceEngine.includes(model.engine)) @@ -425,8 +425,10 @@ export default class JanModelExtension extends ModelExtension { ): Promise { // try to find model.json recursively inside each folder if (!(await fs.existsSync(folderFullPath))) return undefined + const files: string[] = await fs.readdirSync(folderFullPath) if (files.length === 0) return undefined + if (files.includes(JanModelExtension._modelMetadataFileName)) { return joinPath([ folderFullPath, @@ -446,7 +448,7 @@ export default class JanModelExtension extends ModelExtension { private async getModelsMetadata( selector?: (path: string, model: Model) => Promise - ): Promise { + ): Promise { try { if (!(await fs.existsSync(JanModelExtension._homeDir))) { console.debug('Model folder not found') @@ -469,6 +471,7 @@ export default class JanModelExtension extends ModelExtension { JanModelExtension._homeDir, dirName, ]) + const jsonPath = await this.getModelJsonPath(folderFullPath) if (await fs.existsSync(jsonPath)) { @@ -486,6 +489,8 @@ export default class JanModelExtension extends ModelExtension { }, ] } + model.file_path = jsonPath + model.file_name = JanModelExtension._modelMetadataFileName if (selector && !(await selector?.(dirName, model))) { return @@ -506,7 +511,7 @@ export default class JanModelExtension extends ModelExtension { typeof result.value === 'object' ? result.value : JSON.parse(result.value) - return model as Model + return model as ModelFile } catch { console.debug(`Unable to parse model metadata: ${result.value}`) } @@ -574,7 +579,7 @@ export default class JanModelExtension extends ModelExtension { ]) ) - const eos_id = metadata?.['tokenizer.ggml.eos_token_id'] + const updatedModel = await this.retrieveGGUFMetadata(metadata) if (!defaultModel) { console.error('Unable to find default model') @@ -594,18 +599,11 @@ export default class JanModelExtension extends ModelExtension { ], parameters: { ...defaultModel.parameters, - stop: eos_id - ? [metadata['tokenizer.ggml.tokens'][eos_id] ?? ''] - : defaultModel.parameters.stop, + ...updatedModel.parameters, }, settings: { ...defaultModel.settings, - prompt_template: - metadata?.parsed_chat_template ?? - defaultModel.settings.prompt_template, - ctx_len: - metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len, - ngl: (metadata?.['llama.block_count'] ?? 32) + 1, + ...updatedModel.settings, llama_model_path: binaryFileName, }, created: Date.now(), @@ -637,7 +635,7 @@ export default class JanModelExtension extends ModelExtension { * Gets all available models. * @returns A Promise that resolves with an array of all models. 
*/ - async getConfiguredModels(): Promise { + async getConfiguredModels(): Promise { return this.getModelsMetadata() } @@ -669,7 +667,7 @@ export default class JanModelExtension extends ModelExtension { modelBinaryPath: string, modelFolderName: string, modelFolderPath: string - ): Promise { + ): Promise { const fileStats = await fs.fileStat(modelBinaryPath, true) const binaryFileSize = fileStats.size @@ -685,9 +683,9 @@ export default class JanModelExtension extends ModelExtension { 'retrieveGGUFMetadata', modelBinaryPath ) - const eos_id = metadata?.['tokenizer.ggml.eos_token_id'] const binaryFileName = await baseName(modelBinaryPath) + const updatedModel = await this.retrieveGGUFMetadata(metadata) const model: Model = { ...defaultModel, @@ -701,19 +699,12 @@ export default class JanModelExtension extends ModelExtension { ], parameters: { ...defaultModel.parameters, - stop: eos_id - ? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? ''] - : defaultModel.parameters.stop, + ...updatedModel.parameters, }, settings: { ...defaultModel.settings, - prompt_template: - metadata?.parsed_chat_template ?? - defaultModel.settings.prompt_template, - ctx_len: - metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len, - ngl: (metadata?.['llama.block_count'] ?? 32) + 1, + ...updatedModel.settings, llama_model_path: binaryFileName, }, created: Date.now(), @@ -732,25 +723,21 @@ export default class JanModelExtension extends ModelExtension { await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2)) - return model + return { + ...model, + file_path: modelFilePath, + file_name: JanModelExtension._modelMetadataFileName, + } } - async updateModelInfo(modelInfo: Partial): Promise { - const modelId = modelInfo.id + async updateModelInfo(modelInfo: Partial): Promise { if (modelInfo.id == null) throw new Error('Model ID is required') - const janDataFolderPath = await getJanDataFolderPath() - const jsonFilePath = await joinPath([ - janDataFolderPath, - 'models', - modelId, - JanModelExtension._modelMetadataFileName, - ]) const model = JSON.parse( - await this.readModelMetadata(jsonFilePath) - ) as Model + await this.readModelMetadata(modelInfo.file_path) + ) as ModelFile - const updatedModel: Model = { + const updatedModel: ModelFile = { ...model, ...modelInfo, parameters: { @@ -765,9 +752,15 @@ export default class JanModelExtension extends ModelExtension { ...model.metadata, ...modelInfo.metadata, }, + // Should not persist file_path & file_name + file_path: undefined, + file_name: undefined, } - await fs.writeFileSync(jsonFilePath, JSON.stringify(updatedModel, null, 2)) + await fs.writeFileSync( + modelInfo.file_path, + JSON.stringify(updatedModel, null, 2) + ) return updatedModel } @@ -877,4 +870,35 @@ export default class JanModelExtension extends ModelExtension { importedModels ) } + + /** + * Retrieve Model Settings from GGUF Metadata + * @param metadata + * @returns + */ + async retrieveGGUFMetadata(metadata: any): Promise> { + const template = await executeOnMain(NODE, 'renderJinjaTemplate', metadata) + const defaultModel = DEFAULT_MODEL as Model + const eos_id = metadata['tokenizer.ggml.eos_token_id'] + const architecture = metadata['general.architecture'] + + return { + settings: { + prompt_template: template ?? defaultModel.settings.prompt_template, + ctx_len: + metadata[`${architecture}.context_length`] ?? + metadata['llama.context_length'] ?? + 4096, + ngl: + (metadata[`${architecture}.block_count`] ?? + metadata['llama.block_count'] ?? 
+ 32) + 1, + }, + parameters: { + stop: eos_id + ? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? ''] + : defaultModel.parameters.stop, + }, + } + } } diff --git a/extensions/model-extension/src/node/index.ts b/extensions/model-extension/src/node/index.ts index 2b498f424..6323d7f97 100644 --- a/extensions/model-extension/src/node/index.ts +++ b/extensions/model-extension/src/node/index.ts @@ -16,27 +16,8 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => { // Parse metadata and tensor info const { metadata } = ggufMetadata(buffer.buffer) - const template = new Template(metadata['tokenizer.chat_template']) - const eos_id = metadata['tokenizer.ggml.eos_token_id'] - const bos_id = metadata['tokenizer.ggml.bos_token_id'] - const eos_token = metadata['tokenizer.ggml.tokens'][eos_id] - const bos_token = metadata['tokenizer.ggml.tokens'][bos_id] // Parse jinja template - const renderedTemplate = template.render({ - add_generation_prompt: true, - eos_token, - bos_token, - messages: [ - { - role: 'system', - content: '{system_message}', - }, - { - role: 'user', - content: '{prompt}', - }, - ], - }) + const renderedTemplate = renderJinjaTemplate(metadata) return { ...metadata, parsed_chat_template: renderedTemplate, @@ -45,3 +26,34 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => { console.log('[MODEL_EXT]', e) } } + +/** + * Convert metadata to jinja template + * @param metadata + */ +export const renderJinjaTemplate = (metadata: any): string => { + const template = new Template(metadata['tokenizer.chat_template']) + const eos_id = metadata['tokenizer.ggml.eos_token_id'] + const bos_id = metadata['tokenizer.ggml.bos_token_id'] + if (eos_id === undefined || bos_id === undefined) { + return '' + } + const eos_token = metadata['tokenizer.ggml.tokens'][eos_id] + const bos_token = metadata['tokenizer.ggml.tokens'][bos_id] + // Parse jinja template + return template.render({ + add_generation_prompt: true, + eos_token, + bos_token, + messages: [ + { + role: 'system', + content: '{system_message}', + }, + { + role: 'user', + content: '{prompt}', + }, + ], + }) +} diff --git a/extensions/model-extension/src/node/node.test.ts b/extensions/model-extension/src/node/node.test.ts new file mode 100644 index 000000000..afd2b8470 --- /dev/null +++ b/extensions/model-extension/src/node/node.test.ts @@ -0,0 +1,53 @@ +import { renderJinjaTemplate } from './index' +import { Template } from '@huggingface/jinja' + +jest.mock('@huggingface/jinja', () => ({ + Template: jest.fn((template: string) => ({ + render: jest.fn(() => `${template}_rendered`), + })), +})) + +describe('renderJinjaTemplate', () => { + beforeEach(() => { + jest.clearAllMocks() // Clear mocks between tests + }) + + it('should render the template with correct parameters', () => { + const metadata = { + 'tokenizer.chat_template': 'Hello, {{ messages }}!', + 'tokenizer.ggml.eos_token_id': 0, + 'tokenizer.ggml.bos_token_id': 1, + 'tokenizer.ggml.tokens': ['EOS', 'BOS'], + } + + const renderedTemplate = renderJinjaTemplate(metadata) + + expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!') + + expect(renderedTemplate).toBe('Hello, {{ messages }}!_rendered') + }) + + it('should handle missing token IDs gracefully', () => { + const metadata = { + 'tokenizer.chat_template': 'Hello, {{ messages }}!', + 'tokenizer.ggml.eos_token_id': 0, + 'tokenizer.ggml.tokens': ['EOS'], + } + + const renderedTemplate = renderJinjaTemplate(metadata) + + expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!') + + 
expect(renderedTemplate).toBe('') + }) + + it('should handle empty template gracefully', () => { + const metadata = {} + + const renderedTemplate = renderJinjaTemplate(metadata) + + expect(Template).toHaveBeenCalledWith(undefined) + + expect(renderedTemplate).toBe("") + }) +}) diff --git a/extensions/model-extension/tsconfig.json b/extensions/model-extension/tsconfig.json index addd8e127..0d3252934 100644 --- a/extensions/model-extension/tsconfig.json +++ b/extensions/model-extension/tsconfig.json @@ -10,5 +10,6 @@ "skipLibCheck": true, "rootDir": "./src" }, - "include": ["./src"] + "include": ["./src"], + "exclude": ["**/*.test.ts"] } diff --git a/extensions/tensorrt-llm-extension/jest.config.js b/extensions/tensorrt-llm-extension/jest.config.js new file mode 100644 index 000000000..3e32adceb --- /dev/null +++ b/extensions/tensorrt-llm-extension/jest.config.js @@ -0,0 +1,9 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', + transform: { + 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest', + }, + transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'], +} diff --git a/extensions/tensorrt-llm-extension/package.json b/extensions/tensorrt-llm-extension/package.json index c5cb54809..7a7ef6ef0 100644 --- a/extensions/tensorrt-llm-extension/package.json +++ b/extensions/tensorrt-llm-extension/package.json @@ -22,6 +22,7 @@ "tensorrtVersion": "0.1.8", "provider": "nitro-tensorrt-llm", "scripts": { + "test": "jest", "build": "tsc --module commonjs && rollup -c rollup.config.ts", "build:publish:win32": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", "build:publish:linux": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", @@ -49,7 +50,12 @@ "rollup-plugin-sourcemaps": "^0.6.3", "rollup-plugin-typescript2": "^0.36.0", "run-script-os": "^1.1.6", - "typescript": "^5.2.2" + "typescript": "^5.2.2", + "@types/jest": "^29.5.12", + "jest": "^29.7.0", + "jest-junit": "^16.0.0", + "jest-runner": "^29.7.0", + "ts-jest": "^29.2.5" }, "dependencies": { "@janhq/core": "file:../../core", diff --git a/extensions/tensorrt-llm-extension/rollup.config.ts b/extensions/tensorrt-llm-extension/rollup.config.ts index 1fad0e711..50b4350e7 100644 --- a/extensions/tensorrt-llm-extension/rollup.config.ts +++ b/extensions/tensorrt-llm-extension/rollup.config.ts @@ -23,10 +23,10 @@ export default [ DOWNLOAD_RUNNER_URL: process.platform === 'win32' ? 
JSON.stringify( - 'https://github.com/janhq/nitro-tensorrt-llm/releases/download/windows-v-tensorrt-llm-v0.7.1/nitro-windows-v-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz' + 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v-tensorrt-llm-v0.7.1/nitro-windows-v-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz' ) : JSON.stringify( - 'https://github.com/janhq/nitro-tensorrt-llm/releases/download/linux-v/nitro-linux-v-amd64-tensorrt-llm-.tar.gz' + 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/linux-v/nitro-linux-v-amd64-tensorrt-llm-.tar.gz' ), NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`), INFERENCE_URL: JSON.stringify( diff --git a/extensions/tensorrt-llm-extension/src/index.test.ts b/extensions/tensorrt-llm-extension/src/index.test.ts new file mode 100644 index 000000000..48d6e71d7 --- /dev/null +++ b/extensions/tensorrt-llm-extension/src/index.test.ts @@ -0,0 +1,186 @@ +import TensorRTLLMExtension from '../src/index' +import { + executeOnMain, + systemInformation, + fs, + baseName, + joinPath, + downloadFile, +} from '@janhq/core' + +jest.mock('@janhq/core', () => ({ + ...jest.requireActual('@janhq/core/node'), + LocalOAIEngine: jest.fn().mockImplementation(function () { + // @ts-ignore + this.registerModels = () => { + return Promise.resolve() + } + // @ts-ignore + return this + }), + systemInformation: jest.fn(), + fs: { + existsSync: jest.fn(), + mkdir: jest.fn(), + }, + joinPath: jest.fn(), + baseName: jest.fn(), + downloadFile: jest.fn(), + executeOnMain: jest.fn(), + showToast: jest.fn(), + events: { + emit: jest.fn(), + // @ts-ignore + on: (event, func) => { + func({ fileName: './' }) + }, + off: jest.fn(), + }, +})) + +// @ts-ignore +global.COMPATIBILITY = { + platform: ['win32'], +} +// @ts-ignore +global.PROVIDER = 'tensorrt-llm' +// @ts-ignore +global.INFERENCE_URL = 'http://localhost:5000' +// @ts-ignore +global.NODE = 'node' +// @ts-ignore +global.MODELS = [] +// @ts-ignore +global.TENSORRT_VERSION = '' +// @ts-ignore +global.DOWNLOAD_RUNNER_URL = '' + +describe('TensorRTLLMExtension', () => { + let extension: TensorRTLLMExtension + + beforeEach(() => { + // @ts-ignore + extension = new TensorRTLLMExtension() + jest.clearAllMocks() + }) + + describe('compatibility', () => { + it('should return the correct compatibility', () => { + const result = extension.compatibility() + expect(result).toEqual({ + platform: ['win32'], + }) + }) + }) + + describe('install', () => { + it('should install if compatible', async () => { + const mockSystemInfo: any = { + osInfo: { platform: 'win32' }, + gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] }, + } + ;(executeOnMain as jest.Mock).mockResolvedValue({}) + ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo) + ;(fs.existsSync as jest.Mock).mockResolvedValue(false) + ;(fs.mkdir as jest.Mock).mockResolvedValue(undefined) + ;(baseName as jest.Mock).mockResolvedValue('./') + ;(joinPath as jest.Mock).mockResolvedValue('./') + ;(downloadFile as jest.Mock).mockResolvedValue({}) + + await extension.install() + + expect(executeOnMain).toHaveBeenCalled() + }) + + it('should not install if not compatible', async () => { + const mockSystemInfo: any = { + osInfo: { platform: 'linux' }, + gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] }, + } + ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo) + + jest.spyOn(extension, 'registerModels').mockReturnValue(Promise.resolve()) + await extension.install() + + expect(executeOnMain).not.toHaveBeenCalled() + 
}) + }) + + describe('installationState', () => { + it('should return NotCompatible if not compatible', async () => { + const mockSystemInfo: any = { + osInfo: { platform: 'linux' }, + gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] }, + } + ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo) + + const result = await extension.installationState() + + expect(result).toBe('NotCompatible') + }) + + it('should return Installed if executable exists', async () => { + const mockSystemInfo: any = { + osInfo: { platform: 'win32' }, + gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] }, + } + ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo) + ;(fs.existsSync as jest.Mock).mockResolvedValue(true) + + const result = await extension.installationState() + + expect(result).toBe('Installed') + }) + + it('should return NotInstalled if executable does not exist', async () => { + const mockSystemInfo: any = { + osInfo: { platform: 'win32' }, + gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] }, + } + ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo) + ;(fs.existsSync as jest.Mock).mockResolvedValue(false) + + const result = await extension.installationState() + + expect(result).toBe('NotInstalled') + }) + }) + + describe('isCompatible', () => { + it('should return true for compatible system', () => { + const mockInfo: any = { + osInfo: { platform: 'win32' }, + gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] }, + } + + const result = extension.isCompatible(mockInfo) + + expect(result).toBe(true) + }) + + it('should return false for incompatible system', () => { + const mockInfo: any = { + osInfo: { platform: 'linux' }, + gpuSetting: { gpus: [{ arch: 'pascal', name: 'AMD GPU' }] }, + } + + const result = extension.isCompatible(mockInfo) + + expect(result).toBe(false) + }) + }) +}) + +describe('GitHub Release File URL Test', () => { + const url = 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v0.1.8-tensorrt-llm-v0.7.1/nitro-windows-v0.1.8-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'; + + it('should return a status code 200 for the release file URL', async () => { + const response = await fetch(url, { method: 'HEAD' }); + expect(response.status).toBe(200); + }); + + it('should not return a 404 status', async () => { + const response = await fetch(url, { method: 'HEAD' }); + expect(response.status).not.toBe(404); + }); +}); diff --git a/extensions/tensorrt-llm-extension/src/index.ts b/extensions/tensorrt-llm-extension/src/index.ts index 189abc706..11c86a9a7 100644 --- a/extensions/tensorrt-llm-extension/src/index.ts +++ b/extensions/tensorrt-llm-extension/src/index.ts @@ -23,6 +23,7 @@ import { ModelEvent, getJanDataFolderPath, SystemInformation, + ModelFile, } from '@janhq/core' /** @@ -40,7 +41,6 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { override nodeModule = NODE private supportedGpuArch = ['ampere', 'ada'] - private supportedPlatform = ['win32', 'linux'] override compatibility() { return COMPATIBILITY as unknown as Compatibility @@ -137,7 +137,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { events.emit(ModelEvent.OnModelsUpdate, {}) } - override async loadModel(model: Model): Promise { + override async loadModel(model: ModelFile): Promise { if ((await this.installationState()) === 'Installed') return super.loadModel(model) @@ -190,7 +190,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { !!info.gpuSetting && 
!!firstGpu && info.gpuSetting.gpus.length > 0 && - this.supportedPlatform.includes(info.osInfo.platform) && + this.compatibility().platform.includes(info.osInfo.platform) && !!firstGpu.arch && firstGpu.name.toLowerCase().includes('nvidia') && this.supportedGpuArch.includes(firstGpu.arch) diff --git a/extensions/tensorrt-llm-extension/src/node/index.ts b/extensions/tensorrt-llm-extension/src/node/index.ts index c8bc48459..77003389f 100644 --- a/extensions/tensorrt-llm-extension/src/node/index.ts +++ b/extensions/tensorrt-llm-extension/src/node/index.ts @@ -97,7 +97,7 @@ function unloadModel(): Promise { } if (subprocess?.pid) { - log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) + log(`[CORTEX]:: Killing PID ${subprocess.pid}`) const pid = subprocess.pid return new Promise((resolve, reject) => { terminate(pid, function (err) { @@ -107,7 +107,7 @@ function unloadModel(): Promise { return tcpPortUsed .waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000) .then(() => resolve()) - .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .catch(() => { killRequest() }) diff --git a/extensions/tensorrt-llm-extension/tsconfig.json b/extensions/tensorrt-llm-extension/tsconfig.json index 478a05728..be07e716c 100644 --- a/extensions/tensorrt-llm-extension/tsconfig.json +++ b/extensions/tensorrt-llm-extension/tsconfig.json @@ -16,5 +16,6 @@ "resolveJsonModule": true, "typeRoots": ["node_modules/@types"] }, - "include": ["src"] + "include": ["src"], + "exclude": ["**/*.test.ts"] } diff --git a/joi/jest.config.js b/joi/jest.config.js index 8543f24e3..676042491 100644 --- a/joi/jest.config.js +++ b/joi/jest.config.js @@ -3,6 +3,7 @@ module.exports = { testEnvironment: 'node', roots: ['/src'], testMatch: ['**/*.test.*'], + collectCoverageFrom: ['src/**/*.{ts,tsx}'], setupFilesAfterEnv: ['/jest.setup.js'], testEnvironment: 'jsdom', } diff --git a/joi/src/core/Tabs/Tabs.test.tsx b/joi/src/core/Tabs/Tabs.test.tsx index b6dcf8a7b..46bd48435 100644 --- a/joi/src/core/Tabs/Tabs.test.tsx +++ b/joi/src/core/Tabs/Tabs.test.tsx @@ -96,4 +96,20 @@ describe('@joi/core/Tabs', () => { 'Disabled tab' ) }) + + it('applies the tabStyle if provided', () => { + render( + {}} + tabStyle="segmented" + /> + ) + + const tabsContainer = screen.getByTestId('segmented-style') + expect(tabsContainer).toHaveClass('tabs') + expect(tabsContainer).toHaveClass('tabs--segmented') + }) }) diff --git a/joi/src/core/Tabs/index.tsx b/joi/src/core/Tabs/index.tsx index af004e2ba..2dca19831 100644 --- a/joi/src/core/Tabs/index.tsx +++ b/joi/src/core/Tabs/index.tsx @@ -7,6 +7,8 @@ import { Tooltip } from '../Tooltip' import './styles.scss' import { twMerge } from 'tailwind-merge' +type TabStyles = 'segmented' + type TabsProps = { options: { name: string @@ -14,8 +16,10 @@ type TabsProps = { disabled?: boolean tooltipContent?: string }[] - children: ReactNode + children?: ReactNode + defaultValue?: string + tabStyle?: TabStyles value: string onValueChange?: (value: string) => void } @@ -40,15 +44,18 @@ const TabsContent = ({ value, children, className }: TabsContentProps) => { const Tabs = ({ options, children, + tabStyle, defaultValue, value, onValueChange, + ...props }: TabsProps) => ( {options.map((option, i) => { diff --git a/joi/src/core/Tabs/styles.scss b/joi/src/core/Tabs/styles.scss index a24585b4e..ce3df013b 100644 --- a/joi/src/core/Tabs/styles.scss +++ b/joi/src/core/Tabs/styles.scss @@ -3,6 +3,27 @@ flex-direction: column; width: 100%; + &--segmented 
{ + background-color: hsla(var(--secondary-bg)); + border-radius: 6px; + height: 33px; + + .tabs__list { + border: none; + justify-content: center; + align-items: center; + height: 33px; + } + + .tabs__trigger[data-state='active'] { + background-color: hsla(var(--app-bg)); + border: none; + height: 25px; + margin: 0 4px; + border-radius: 5px; + } + } + &__list { flex-shrink: 0; display: flex; @@ -14,9 +35,11 @@ flex: 1; height: 38px; display: flex; + color: hsla(var(--text-secondary)); align-items: center; justify-content: center; line-height: 1; + font-weight: medium; user-select: none; &:focus { position: relative; @@ -38,4 +61,5 @@ .tabs__trigger[data-state='active'] { border-bottom: 1px solid hsla(var(--primary-bg)); font-weight: 600; + color: hsla(var(--text-primary)); } diff --git a/joi/src/core/TextArea/TextArea.test.tsx b/joi/src/core/TextArea/TextArea.test.tsx index 8bc64010f..e29eed5d0 100644 --- a/joi/src/core/TextArea/TextArea.test.tsx +++ b/joi/src/core/TextArea/TextArea.test.tsx @@ -1,9 +1,8 @@ import React from 'react' -import { render, screen } from '@testing-library/react' +import { render, screen, act } from '@testing-library/react' import '@testing-library/jest-dom' import { TextArea } from './index' -// Mock the styles import jest.mock('./styles.scss', () => ({})) describe('@joi/core/TextArea', () => { @@ -31,4 +30,40 @@ describe('@joi/core/TextArea', () => { const textareaElement = screen.getByTestId('custom-textarea') expect(textareaElement).toHaveAttribute('rows', '5') }) + + it('should auto resize the textarea based on minResize', () => { + render(