Merge pull request #3746 from janhq/dev

release: Jan Release Cut v0.5.5
This commit is contained in:
Louis 2024-10-01 14:18:06 +07:00 committed by GitHub
commit b0b49f44f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
223 changed files with 9757 additions and 851 deletions

View File

@ -0,0 +1,25 @@
name: Trigger Docs Workflow
on:
release:
types:
- published
workflow_dispatch:
push:
branches:
- ci/auto-trigger-jan-docs-for-new-release
jobs:
trigger_docs_workflow:
runs-on: ubuntu-latest
steps:
- name: Trigger external workflow using GitHub API
env:
GITHUB_TOKEN: ${{ secrets.PAT_SERVICE_ACCOUNT }}
run: |
curl -X POST \
-H "Accept: application/vnd.github.v3+json" \
-H "Authorization: token $GITHUB_TOKEN" \
https://api.github.com/repos/janhq/docs/actions/workflows/jan-docs.yml/dispatches \
-d '{"ref":"main"}'

2
.gitignore vendored
View File

@ -45,3 +45,5 @@ core/test_results.html
coverage
.yarn
.yarnrc
test_results.html
*.tsbuildinfo

View File

@ -1,8 +1,17 @@
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
collectCoverageFrom: ['src/**/*.{ts,tsx}'],
moduleNameMapper: {
'@/(.*)': '<rootDir>/src/$1',
},
runner: './testRunner.js',
transform: {
"^.+\\.tsx?$": [
"ts-jest",
{
diagnostics: false,
},
],
},
}

View File

@ -1,98 +1,109 @@
import { openExternalUrl } from './core';
import { joinPath } from './core';
import { openFileExplorer } from './core';
import { getJanDataFolderPath } from './core';
import { abortDownload } from './core';
import { getFileSize } from './core';
import { executeOnMain } from './core';
import { openExternalUrl } from './core'
import { joinPath } from './core'
import { openFileExplorer } from './core'
import { getJanDataFolderPath } from './core'
import { abortDownload } from './core'
import { getFileSize } from './core'
import { executeOnMain } from './core'
it('should open external url', async () => {
const url = 'http://example.com';
globalThis.core = {
api: {
openExternalUrl: jest.fn().mockResolvedValue('opened')
describe('test core apis', () => {
it('should open external url', async () => {
const url = 'http://example.com'
globalThis.core = {
api: {
openExternalUrl: jest.fn().mockResolvedValue('opened'),
},
}
};
const result = await openExternalUrl(url);
expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url);
expect(result).toBe('opened');
});
const result = await openExternalUrl(url)
expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url)
expect(result).toBe('opened')
})
it('should join paths', async () => {
const paths = ['/path/one', '/path/two'];
globalThis.core = {
api: {
joinPath: jest.fn().mockResolvedValue('/path/one/path/two')
it('should join paths', async () => {
const paths = ['/path/one', '/path/two']
globalThis.core = {
api: {
joinPath: jest.fn().mockResolvedValue('/path/one/path/two'),
},
}
};
const result = await joinPath(paths);
expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths);
expect(result).toBe('/path/one/path/two');
});
const result = await joinPath(paths)
expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths)
expect(result).toBe('/path/one/path/two')
})
it('should open file explorer', async () => {
const path = '/path/to/open';
globalThis.core = {
api: {
openFileExplorer: jest.fn().mockResolvedValue('opened')
it('should open file explorer', async () => {
const path = '/path/to/open'
globalThis.core = {
api: {
openFileExplorer: jest.fn().mockResolvedValue('opened'),
},
}
};
const result = await openFileExplorer(path);
expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path);
expect(result).toBe('opened');
});
const result = await openFileExplorer(path)
expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path)
expect(result).toBe('opened')
})
it('should get jan data folder path', async () => {
globalThis.core = {
api: {
getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data')
it('should get jan data folder path', async () => {
globalThis.core = {
api: {
getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data'),
},
}
};
const result = await getJanDataFolderPath();
expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled();
expect(result).toBe('/path/to/jan/data');
});
const result = await getJanDataFolderPath()
expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled()
expect(result).toBe('/path/to/jan/data')
})
it('should abort download', async () => {
const fileName = 'testFile';
globalThis.core = {
api: {
abortDownload: jest.fn().mockResolvedValue('aborted')
it('should abort download', async () => {
const fileName = 'testFile'
globalThis.core = {
api: {
abortDownload: jest.fn().mockResolvedValue('aborted'),
},
}
};
const result = await abortDownload(fileName);
expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName);
expect(result).toBe('aborted');
});
const result = await abortDownload(fileName)
expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName)
expect(result).toBe('aborted')
})
it('should get file size', async () => {
const url = 'http://example.com/file';
globalThis.core = {
api: {
getFileSize: jest.fn().mockResolvedValue(1024)
it('should get file size', async () => {
const url = 'http://example.com/file'
globalThis.core = {
api: {
getFileSize: jest.fn().mockResolvedValue(1024),
},
}
};
const result = await getFileSize(url);
expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url);
expect(result).toBe(1024);
});
const result = await getFileSize(url)
expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url)
expect(result).toBe(1024)
})
it('should execute function on main process', async () => {
const extension = 'testExtension';
const method = 'testMethod';
const args = ['arg1', 'arg2'];
globalThis.core = {
api: {
invokeExtensionFunc: jest.fn().mockResolvedValue('result')
it('should execute function on main process', async () => {
const extension = 'testExtension'
const method = 'testMethod'
const args = ['arg1', 'arg2']
globalThis.core = {
api: {
invokeExtensionFunc: jest.fn().mockResolvedValue('result'),
},
}
};
const result = await executeOnMain(extension, method, ...args);
expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args);
expect(result).toBe('result');
});
const result = await executeOnMain(extension, method, ...args)
expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args)
expect(result).toBe('result')
})
})
describe('dirName - just a pass thru api', () => {
it('should retrieve the directory name from a file path', async () => {
const mockDirName = jest.fn()
globalThis.core = {
api: {
dirName: mockDirName.mockResolvedValue('/path/to'),
},
}
// Normal file path with extension
const path = '/path/to/file.txt'
await globalThis.core.api.dirName(path)
expect(mockDirName).toHaveBeenCalledWith(path)
})
})

View File

@ -68,6 +68,13 @@ const openFileExplorer: (path: string) => Promise<any> = (path) =>
const joinPath: (paths: string[]) => Promise<string> = (paths) =>
globalThis.core.api?.joinPath(paths)
/**
* Get dirname of a file path.
* @param path - The file path to retrieve dirname.
* @returns {Promise<string>} A promise that resolves the dirname.
*/
const dirName: (path: string) => Promise<string> = (path) => globalThis.core.api?.dirName(path)
/**
* Retrieve the basename from an url.
* @param path - The path to retrieve.
@ -161,5 +168,6 @@ export {
systemInformation,
showToast,
getFileSize,
dirName,
FileStat,
}

View File

@ -1,4 +1,9 @@
import { BaseExtension } from './extension'
import { SettingComponentProps } from '../types'
import { getJanDataFolderPath, joinPath } from './core'
import { fs } from './fs'
jest.mock('./core')
jest.mock('./fs')
class TestBaseExtension extends BaseExtension {
onLoad(): void {}
@ -44,3 +49,103 @@ describe('BaseExtension', () => {
// Add your assertions here
})
})
describe('BaseExtension', () => {
class TestBaseExtension extends BaseExtension {
onLoad(): void {}
onUnload(): void {}
}
let baseExtension: TestBaseExtension
beforeEach(() => {
baseExtension = new TestBaseExtension('https://example.com', 'TestExtension')
})
afterEach(() => {
jest.resetAllMocks()
})
it('should have the correct properties', () => {
expect(baseExtension.name).toBe('TestExtension')
expect(baseExtension.productName).toBeUndefined()
expect(baseExtension.url).toBe('https://example.com')
expect(baseExtension.active).toBeUndefined()
expect(baseExtension.description).toBeUndefined()
expect(baseExtension.version).toBeUndefined()
})
it('should return undefined for type()', () => {
expect(baseExtension.type()).toBeUndefined()
})
it('should have abstract methods onLoad() and onUnload()', () => {
expect(baseExtension.onLoad).toBeDefined()
expect(baseExtension.onUnload).toBeDefined()
})
it('should have installationState() return "NotRequired"', async () => {
const installationState = await baseExtension.installationState()
expect(installationState).toBe('NotRequired')
})
it('should install the extension', async () => {
await baseExtension.install()
// Add your assertions here
})
it('should register settings', async () => {
const settings: SettingComponentProps[] = [
{ key: 'setting1', controllerProps: { value: 'value1' } } as any,
{ key: 'setting2', controllerProps: { value: 'value2' } } as any,
]
;(getJanDataFolderPath as jest.Mock).mockResolvedValue('/data')
;(joinPath as jest.Mock).mockResolvedValue('/data/settings/TestExtension')
;(fs.existsSync as jest.Mock).mockResolvedValue(false)
;(fs.mkdir as jest.Mock).mockResolvedValue(undefined)
;(fs.writeFileSync as jest.Mock).mockResolvedValue(undefined)
await baseExtension.registerSettings(settings)
expect(fs.mkdir).toHaveBeenCalledWith('/data/settings/TestExtension')
expect(fs.writeFileSync).toHaveBeenCalledWith(
'/data/settings/TestExtension',
JSON.stringify(settings, null, 2)
)
})
it('should get setting with default value', async () => {
const settings: SettingComponentProps[] = [
{ key: 'setting1', controllerProps: { value: 'value1' } } as any,
]
jest.spyOn(baseExtension, 'getSettings').mockResolvedValue(settings)
const value = await baseExtension.getSetting('setting1', 'defaultValue')
expect(value).toBe('value1')
const defaultValue = await baseExtension.getSetting('setting2', 'defaultValue')
expect(defaultValue).toBe('defaultValue')
})
it('should update settings', async () => {
const settings: SettingComponentProps[] = [
{ key: 'setting1', controllerProps: { value: 'value1' } } as any,
]
jest.spyOn(baseExtension, 'getSettings').mockResolvedValue(settings)
;(getJanDataFolderPath as jest.Mock).mockResolvedValue('/data')
;(joinPath as jest.Mock).mockResolvedValue('/data/settings/TestExtension/settings.json')
;(fs.writeFileSync as jest.Mock).mockResolvedValue(undefined)
await baseExtension.updateSettings([
{ key: 'setting1', controllerProps: { value: 'newValue' } } as any,
])
expect(fs.writeFileSync).toHaveBeenCalledWith(
'/data/settings/TestExtension/settings.json',
JSON.stringify([{ key: 'setting1', controllerProps: { value: 'newValue' } }], null, 2)
)
})
})

View File

@ -0,0 +1,8 @@
import { AssistantExtension } from './assistant';
import { ExtensionTypeEnum } from '../extension';
it('should return the correct type', () => {
const extension = new AssistantExtension();
expect(extension.type()).toBe(ExtensionTypeEnum.Assistant);
});

View File

@ -0,0 +1,59 @@
import { AIEngine } from './AIEngine'
import { events } from '../../events'
import { ModelEvent, Model, ModelFile, InferenceEngine } from '../../../types'
import { EngineManager } from './EngineManager'
import { fs } from '../../fs'
jest.mock('../../events')
jest.mock('./EngineManager')
jest.mock('../../fs')
class TestAIEngine extends AIEngine {
onUnload(): void {}
provider = 'test-provider'
inference(data: any) {}
stopInference() {}
}
describe('AIEngine', () => {
let engine: TestAIEngine
beforeEach(() => {
engine = new TestAIEngine('', '')
jest.clearAllMocks()
})
it('should load model if provider matches', async () => {
const model: ModelFile = { id: 'model1', engine: 'test-provider' } as any
await engine.loadModel(model)
expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelReady, model)
})
it('should not load model if provider does not match', async () => {
const model: ModelFile = { id: 'model1', engine: 'other-provider' } as any
await engine.loadModel(model)
expect(events.emit).not.toHaveBeenCalledWith(ModelEvent.OnModelReady, model)
})
it('should unload model if provider matches', async () => {
const model: Model = { id: 'model1', version: '1.0', engine: 'test-provider' } as any
await engine.unloadModel(model)
expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelStopped, model)
})
it('should not unload model if provider does not match', async () => {
const model: Model = { id: 'model1', version: '1.0', engine: 'other-provider' } as any
await engine.unloadModel(model)
expect(events.emit).not.toHaveBeenCalledWith(ModelEvent.OnModelStopped, model)
})
})

View File

@ -2,7 +2,7 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { fs } from '../../fs'
import { MessageRequest, Model, ModelEvent } from '../../../types'
import { MessageRequest, Model, ModelEvent, ModelFile } from '../../../types'
import { EngineManager } from './EngineManager'
/**
@ -21,7 +21,7 @@ export abstract class AIEngine extends BaseExtension {
override onLoad() {
this.registerEngine()
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}
@ -78,7 +78,7 @@ export abstract class AIEngine extends BaseExtension {
/**
* Loads the model.
*/
async loadModel(model: Model): Promise<any> {
async loadModel(model: ModelFile): Promise<any> {
if (model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()

View File

@ -0,0 +1,43 @@
/**
* @jest-environment jsdom
*/
import { EngineManager } from './EngineManager'
import { AIEngine } from './AIEngine'
// @ts-ignore
class MockAIEngine implements AIEngine {
provider: string
constructor(provider: string) {
this.provider = provider
}
}
describe('EngineManager', () => {
let engineManager: EngineManager
beforeEach(() => {
engineManager = new EngineManager()
})
test('should register an engine', () => {
const engine = new MockAIEngine('testProvider')
// @ts-ignore
engineManager.register(engine)
expect(engineManager.engines.get('testProvider')).toBe(engine)
})
test('should retrieve a registered engine by provider', () => {
const engine = new MockAIEngine('testProvider')
// @ts-ignore
engineManager.register(engine)
// @ts-ignore
const retrievedEngine = engineManager.get<MockAIEngine>('testProvider')
expect(retrievedEngine).toBe(engine)
})
test('should return undefined for an unregistered provider', () => {
// @ts-ignore
const retrievedEngine = engineManager.get<MockAIEngine>('nonExistentProvider')
expect(retrievedEngine).toBeUndefined()
})
})

View File

@ -0,0 +1,100 @@
/**
* @jest-environment jsdom
*/
import { LocalOAIEngine } from './LocalOAIEngine'
import { events } from '../../events'
import { ModelEvent, ModelFile, Model } from '../../../types'
import { executeOnMain, systemInformation, dirName } from '../../core'
jest.mock('../../core', () => ({
executeOnMain: jest.fn(),
systemInformation: jest.fn(),
dirName: jest.fn(),
}))
jest.mock('../../events', () => ({
events: {
on: jest.fn(),
emit: jest.fn(),
},
}))
class TestLocalOAIEngine extends LocalOAIEngine {
inferenceUrl = ''
nodeModule = 'testNodeModule'
provider = 'testProvider'
}
describe('LocalOAIEngine', () => {
let engine: TestLocalOAIEngine
beforeEach(() => {
engine = new TestLocalOAIEngine('', '')
})
afterEach(() => {
jest.clearAllMocks()
})
it('should subscribe to events on load', () => {
engine.onLoad()
expect(events.on).toHaveBeenCalledWith(ModelEvent.OnModelInit, expect.any(Function))
expect(events.on).toHaveBeenCalledWith(ModelEvent.OnModelStop, expect.any(Function))
})
it('should load model correctly', async () => {
const model: ModelFile = { engine: 'testProvider', file_path: 'path/to/model' } as any
const modelFolder = 'path/to'
const systemInfo = { os: 'testOS' }
const res = { error: null }
;(dirName as jest.Mock).mockResolvedValue(modelFolder)
;(systemInformation as jest.Mock).mockResolvedValue(systemInfo)
;(executeOnMain as jest.Mock).mockResolvedValue(res)
await engine.loadModel(model)
expect(dirName).toHaveBeenCalledWith(model.file_path)
expect(systemInformation).toHaveBeenCalled()
expect(executeOnMain).toHaveBeenCalledWith(
engine.nodeModule,
engine.loadModelFunctionName,
{ modelFolder, model },
systemInfo
)
expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelReady, model)
})
it('should handle load model error', async () => {
const model: ModelFile = { engine: 'testProvider', file_path: 'path/to/model' } as any
const modelFolder = 'path/to'
const systemInfo = { os: 'testOS' }
const res = { error: 'load error' }
;(dirName as jest.Mock).mockResolvedValue(modelFolder)
;(systemInformation as jest.Mock).mockResolvedValue(systemInfo)
;(executeOnMain as jest.Mock).mockResolvedValue(res)
await expect(engine.loadModel(model)).rejects.toEqual('load error')
expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelFail, { error: res.error })
})
it('should unload model correctly', async () => {
const model: Model = { engine: 'testProvider' } as any
await engine.unloadModel(model)
expect(executeOnMain).toHaveBeenCalledWith(engine.nodeModule, engine.unloadModelFunctionName)
expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelStopped, {})
})
it('should not unload model if engine does not match', async () => {
const model: Model = { engine: 'otherProvider' } as any
await engine.unloadModel(model)
expect(executeOnMain).not.toHaveBeenCalled()
expect(events.emit).not.toHaveBeenCalledWith(ModelEvent.OnModelStopped, {})
})
})

View File

@ -1,6 +1,6 @@
import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core'
import { executeOnMain, systemInformation, dirName } from '../../core'
import { events } from '../../events'
import { Model, ModelEvent } from '../../../types'
import { Model, ModelEvent, ModelFile } from '../../../types'
import { OAIEngine } from './OAIEngine'
/**
@ -14,22 +14,24 @@ export abstract class LocalOAIEngine extends OAIEngine {
unloadModelFunctionName: string = 'unloadModel'
/**
* On extension load, subscribe to events.
* This class represents a base for local inference providers in the OpenAI architecture.
* It extends the OAIEngine class and provides the implementation of loading and unloading models locally.
* The loadModel function subscribes to the ModelEvent.OnModelInit event, loading models when initiated.
* The unloadModel function subscribes to the ModelEvent.OnModelStop event, unloading models when stopped.
*/
override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}
/**
* Load the model.
*/
override async loadModel(model: Model): Promise<void> {
override async loadModel(model: ModelFile): Promise<void> {
if (model.engine.toString() !== this.provider) return
const modelFolderName = 'models'
const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
const modelFolder = await dirName(model.file_path)
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,

View File

@ -0,0 +1,119 @@
/**
* @jest-environment jsdom
*/
import { OAIEngine } from './OAIEngine'
import { events } from '../../events'
import {
MessageEvent,
InferenceEvent,
MessageRequest,
MessageRequestType,
MessageStatus,
ChatCompletionRole,
ContentType,
} from '../../../types'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulidx'
jest.mock('./helpers/sse')
jest.mock('ulidx')
jest.mock('../../events')
class TestOAIEngine extends OAIEngine {
inferenceUrl = 'http://test-inference-url'
provider = 'test-provider'
async headers() {
return { Authorization: 'Bearer test-token' }
}
}
describe('OAIEngine', () => {
let engine: TestOAIEngine
beforeEach(() => {
engine = new TestOAIEngine('', '')
jest.clearAllMocks()
})
it('should subscribe to events on load', () => {
engine.onLoad()
expect(events.on).toHaveBeenCalledWith(MessageEvent.OnMessageSent, expect.any(Function))
expect(events.on).toHaveBeenCalledWith(InferenceEvent.OnInferenceStopped, expect.any(Function))
})
it('should handle inference request', async () => {
const data: MessageRequest = {
model: { engine: 'test-provider', id: 'test-model' } as any,
threadId: 'test-thread',
type: MessageRequestType.Thread,
assistantId: 'test-assistant',
messages: [{ role: ChatCompletionRole.User, content: 'Hello' }],
}
;(ulid as jest.Mock).mockReturnValue('test-id')
;(requestInference as jest.Mock).mockReturnValue({
subscribe: ({ next, complete }: any) => {
next('test response')
complete()
},
})
await engine.inference(data)
expect(requestInference).toHaveBeenCalledWith(
'http://test-inference-url',
expect.objectContaining({ model: 'test-model' }),
expect.any(Object),
expect.any(AbortController),
{ Authorization: 'Bearer test-token' },
undefined
)
expect(events.emit).toHaveBeenCalledWith(
MessageEvent.OnMessageResponse,
expect.objectContaining({ id: 'test-id' })
)
expect(events.emit).toHaveBeenCalledWith(
MessageEvent.OnMessageUpdate,
expect.objectContaining({
content: [{ type: ContentType.Text, text: { value: 'test response', annotations: [] } }],
status: MessageStatus.Ready,
})
)
})
it('should handle inference error', async () => {
const data: MessageRequest = {
model: { engine: 'test-provider', id: 'test-model' } as any,
threadId: 'test-thread',
type: MessageRequestType.Thread,
assistantId: 'test-assistant',
messages: [{ role: ChatCompletionRole.User, content: 'Hello' }],
}
;(ulid as jest.Mock).mockReturnValue('test-id')
;(requestInference as jest.Mock).mockReturnValue({
subscribe: ({ error }: any) => {
error({ message: 'test error', code: 500 })
},
})
await engine.inference(data)
expect(events.emit).toHaveBeenCalledWith(
MessageEvent.OnMessageUpdate,
expect.objectContaining({
content: [{ type: ContentType.Text, text: { value: 'test error', annotations: [] } }],
status: MessageStatus.Error,
error_code: 500,
})
)
})
it('should stop inference', () => {
engine.stopInference()
expect(engine.isCancelled).toBe(true)
expect(engine.controller.signal.aborted).toBe(true)
})
})

View File

@ -0,0 +1,43 @@
/**
* @jest-environment jsdom
*/
import { RemoteOAIEngine } from './'
class TestRemoteOAIEngine extends RemoteOAIEngine {
inferenceUrl: string = ''
provider: string = 'TestRemoteOAIEngine'
}
describe('RemoteOAIEngine', () => {
let engine: TestRemoteOAIEngine
beforeEach(() => {
engine = new TestRemoteOAIEngine('', '')
})
test('should call onLoad and super.onLoad', () => {
const onLoadSpy = jest.spyOn(engine, 'onLoad')
const superOnLoadSpy = jest.spyOn(Object.getPrototypeOf(RemoteOAIEngine.prototype), 'onLoad')
engine.onLoad()
expect(onLoadSpy).toHaveBeenCalled()
expect(superOnLoadSpy).toHaveBeenCalled()
})
test('should return headers with apiKey', async () => {
engine.apiKey = 'test-api-key'
const headers = await engine.headers()
expect(headers).toEqual({
'Authorization': 'Bearer test-api-key',
'api-key': 'test-api-key',
})
})
test('should return empty headers when apiKey is not set', async () => {
engine.apiKey = undefined
const headers = await engine.headers()
expect(headers).toEqual({})
})
})

View File

@ -1,6 +1,7 @@
import { lastValueFrom, Observable } from 'rxjs'
import { requestInference } from './sse'
import { ReadableStream } from 'stream/web';
describe('requestInference', () => {
it('should send a request to the inference server and return an Observable', () => {
// Mock the fetch function
@ -58,3 +59,66 @@ describe('requestInference', () => {
expect(lastValueFrom(result)).rejects.toEqual({ message: 'Wrong API Key', code: 'invalid_api_key' })
})
})
it('should handle a successful response with a transformResponse function', () => {
// Mock the fetch function
const mockFetch: any = jest.fn(() =>
Promise.resolve({
ok: true,
json: () => Promise.resolve({ choices: [{ message: { content: 'Generated response' } }] }),
headers: new Headers(),
redirected: false,
status: 200,
statusText: 'OK',
})
)
jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
// Define the test inputs
const inferenceUrl = 'https://inference-server.com'
const requestBody = { message: 'Hello' }
const model = { id: 'model-id', parameters: { stream: false } }
const transformResponse = (data: any) => data.choices[0].message.content.toUpperCase()
// Call the function
const result = requestInference(inferenceUrl, requestBody, model, undefined, undefined, transformResponse)
// Assert the expected behavior
expect(result).toBeInstanceOf(Observable)
expect(lastValueFrom(result)).resolves.toEqual('GENERATED RESPONSE')
})
it('should handle a successful response with streaming enabled', () => {
// Mock the fetch function
const mockFetch: any = jest.fn(() =>
Promise.resolve({
ok: true,
body: new ReadableStream({
start(controller) {
controller.enqueue(new TextEncoder().encode('data: {"choices": [{"delta": {"content": "Streamed"}}]}'));
controller.enqueue(new TextEncoder().encode('data: [DONE]'));
controller.close();
}
}),
headers: new Headers(),
redirected: false,
status: 200,
statusText: 'OK',
})
);
jest.spyOn(global, 'fetch').mockImplementation(mockFetch);
// Define the test inputs
const inferenceUrl = 'https://inference-server.com';
const requestBody = { message: 'Hello' };
const model = { id: 'model-id', parameters: { stream: true } };
// Call the function
const result = requestInference(inferenceUrl, requestBody, model);
// Assert the expected behavior
expect(result).toBeInstanceOf(Observable);
expect(lastValueFrom(result)).resolves.toEqual('Streamed');
});

View File

@ -0,0 +1,6 @@
import { expect } from '@jest/globals';
it('should re-export all exports from ./AIEngine', () => {
expect(require('./index')).toHaveProperty('AIEngine');
});

View File

@ -0,0 +1,32 @@
import { ConversationalExtension } from './index';
import { InferenceExtension } from './index';
import { MonitoringExtension } from './index';
import { AssistantExtension } from './index';
import { ModelExtension } from './index';
import * as Engines from './index';
describe('index.ts exports', () => {
test('should export ConversationalExtension', () => {
expect(ConversationalExtension).toBeDefined();
});
test('should export InferenceExtension', () => {
expect(InferenceExtension).toBeDefined();
});
test('should export MonitoringExtension', () => {
expect(MonitoringExtension).toBeDefined();
});
test('should export AssistantExtension', () => {
expect(AssistantExtension).toBeDefined();
});
test('should export ModelExtension', () => {
expect(ModelExtension).toBeDefined();
});
test('should export Engines', () => {
expect(Engines).toBeDefined();
});
});

View File

@ -0,0 +1,45 @@
import { MessageRequest, ThreadMessage } from '../../types'
import { BaseExtension, ExtensionTypeEnum } from '../extension'
import { InferenceExtension } from './'
// Mock the MessageRequest and ThreadMessage types
type MockMessageRequest = {
text: string
}
type MockThreadMessage = {
text: string
userId: string
}
// Mock the BaseExtension class
class MockBaseExtension extends BaseExtension {
type(): ExtensionTypeEnum | undefined {
return ExtensionTypeEnum.Base
}
}
// Create a mock implementation of InferenceExtension
class MockInferenceExtension extends InferenceExtension {
async inference(data: MessageRequest): Promise<ThreadMessage> {
return { text: 'Mock response', userId: '123' } as unknown as ThreadMessage
}
}
describe('InferenceExtension', () => {
let inferenceExtension: InferenceExtension
beforeEach(() => {
inferenceExtension = new MockInferenceExtension()
})
it('should have the correct type', () => {
expect(inferenceExtension.type()).toBe(ExtensionTypeEnum.Inference)
})
it('should implement the inference method', async () => {
const messageRequest: MessageRequest = { text: 'Hello' } as unknown as MessageRequest
const result = await inferenceExtension.inference(messageRequest)
expect(result).toEqual({ text: 'Mock response', userId: '123' } as unknown as ThreadMessage)
})
})

View File

@ -4,6 +4,7 @@ import {
HuggingFaceRepoData,
ImportingModel,
Model,
ModelFile,
ModelInterface,
OptionType,
} from '../../types'
@ -25,12 +26,11 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
network?: { proxy: string; ignoreSSL?: boolean }
): Promise<void>
abstract cancelModelDownload(modelId: string): Promise<void>
abstract deleteModel(modelId: string): Promise<void>
abstract saveModel(model: Model): Promise<void>
abstract getDownloadedModels(): Promise<Model[]>
abstract getConfiguredModels(): Promise<Model[]>
abstract deleteModel(model: ModelFile): Promise<void>
abstract getDownloadedModels(): Promise<ModelFile[]>
abstract getConfiguredModels(): Promise<ModelFile[]>
abstract importModels(models: ImportingModel[], optionType: OptionType): Promise<void>
abstract updateModelInfo(modelInfo: Partial<Model>): Promise<Model>
abstract updateModelInfo(modelInfo: Partial<ModelFile>): Promise<ModelFile>
abstract fetchHuggingFaceRepoData(repoId: string): Promise<HuggingFaceRepoData>
abstract getDefaultModel(): Promise<Model>
}

View File

@ -0,0 +1,42 @@
import { ExtensionTypeEnum } from '../extension';
import { MonitoringExtension } from './monitoring';
it('should have the correct type', () => {
class TestMonitoringExtension extends MonitoringExtension {
getGpuSetting(): Promise<GpuSetting | undefined> {
throw new Error('Method not implemented.');
}
getResourcesInfo(): Promise<any> {
throw new Error('Method not implemented.');
}
getCurrentLoad(): Promise<any> {
throw new Error('Method not implemented.');
}
getOsInfo(): Promise<OperatingSystemInfo> {
throw new Error('Method not implemented.');
}
}
const monitoringExtension = new TestMonitoringExtension();
expect(monitoringExtension.type()).toBe(ExtensionTypeEnum.SystemMonitoring);
});
it('should create an instance of MonitoringExtension', () => {
class TestMonitoringExtension extends MonitoringExtension {
getGpuSetting(): Promise<GpuSetting | undefined> {
throw new Error('Method not implemented.');
}
getResourcesInfo(): Promise<any> {
throw new Error('Method not implemented.');
}
getCurrentLoad(): Promise<any> {
throw new Error('Method not implemented.');
}
getOsInfo(): Promise<OperatingSystemInfo> {
throw new Error('Method not implemented.');
}
}
const monitoringExtension = new TestMonitoringExtension();
expect(monitoringExtension).toBeInstanceOf(MonitoringExtension);
});

View File

@ -0,0 +1,97 @@
import { fs } from './fs'
describe('fs module', () => {
beforeEach(() => {
globalThis.core = {
api: {
writeFileSync: jest.fn(),
writeBlob: jest.fn(),
readFileSync: jest.fn(),
existsSync: jest.fn(),
readdirSync: jest.fn(),
mkdir: jest.fn(),
rm: jest.fn(),
unlinkSync: jest.fn(),
appendFileSync: jest.fn(),
copyFile: jest.fn(),
getGgufFiles: jest.fn(),
fileStat: jest.fn(),
},
}
})
it('should call writeFileSync with correct arguments', () => {
const args = ['path/to/file', 'data']
fs.writeFileSync(...args)
expect(globalThis.core.api.writeFileSync).toHaveBeenCalledWith(...args)
})
it('should call writeBlob with correct arguments', async () => {
const path = 'path/to/file'
const data = 'blob data'
await fs.writeBlob(path, data)
expect(globalThis.core.api.writeBlob).toHaveBeenCalledWith(path, data)
})
it('should call readFileSync with correct arguments', () => {
const args = ['path/to/file']
fs.readFileSync(...args)
expect(globalThis.core.api.readFileSync).toHaveBeenCalledWith(...args)
})
it('should call existsSync with correct arguments', () => {
const args = ['path/to/file']
fs.existsSync(...args)
expect(globalThis.core.api.existsSync).toHaveBeenCalledWith(...args)
})
it('should call readdirSync with correct arguments', () => {
const args = ['path/to/directory']
fs.readdirSync(...args)
expect(globalThis.core.api.readdirSync).toHaveBeenCalledWith(...args)
})
it('should call mkdir with correct arguments', () => {
const args = ['path/to/directory']
fs.mkdir(...args)
expect(globalThis.core.api.mkdir).toHaveBeenCalledWith(...args)
})
it('should call rm with correct arguments', () => {
const args = ['path/to/directory']
fs.rm(...args)
expect(globalThis.core.api.rm).toHaveBeenCalledWith(...args, { recursive: true, force: true })
})
it('should call unlinkSync with correct arguments', () => {
const args = ['path/to/file']
fs.unlinkSync(...args)
expect(globalThis.core.api.unlinkSync).toHaveBeenCalledWith(...args)
})
it('should call appendFileSync with correct arguments', () => {
const args = ['path/to/file', 'data']
fs.appendFileSync(...args)
expect(globalThis.core.api.appendFileSync).toHaveBeenCalledWith(...args)
})
it('should call copyFile with correct arguments', async () => {
const src = 'path/to/src'
const dest = 'path/to/dest'
await fs.copyFile(src, dest)
expect(globalThis.core.api.copyFile).toHaveBeenCalledWith(src, dest)
})
it('should call getGgufFiles with correct arguments', async () => {
const paths = ['path/to/file1', 'path/to/file2']
await fs.getGgufFiles(paths)
expect(globalThis.core.api.getGgufFiles).toHaveBeenCalledWith(paths)
})
it('should call fileStat with correct arguments', async () => {
const path = 'path/to/file'
const outsideJanDataFolder = true
await fs.fileStat(path, outsideJanDataFolder)
expect(globalThis.core.api.fileStat).toHaveBeenCalledWith(path, outsideJanDataFolder)
})
})

View File

@ -0,0 +1,5 @@
it('should not throw any errors when imported', () => {
expect(() => require('./index')).not.toThrow();
})

View File

@ -0,0 +1,63 @@
import { ToolManager } from '../../browser/tools/manager'
import { InferenceTool } from '../../browser/tools/tool'
import { AssistantTool, MessageRequest } from '../../types'
class MockInferenceTool implements InferenceTool {
name = 'mockTool'
process(request: MessageRequest, tool: AssistantTool): Promise<MessageRequest> {
return Promise.resolve(request)
}
}
it('should register a tool', () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
expect(manager.get(tool.name)).toBe(tool)
})
it('should retrieve a tool by its name', () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
const retrievedTool = manager.get(tool.name)
expect(retrievedTool).toBe(tool)
})
it('should return undefined for a non-existent tool', () => {
const manager = new ToolManager()
const retrievedTool = manager.get('nonExistentTool')
expect(retrievedTool).toBeUndefined()
})
it('should process the message request with enabled tools', async () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
const request: MessageRequest = { message: 'test' } as any
const tools: AssistantTool[] = [{ type: 'mockTool', enabled: true }] as any
const result = await manager.process(request, tools)
expect(result).toBe(request)
})
it('should skip processing for disabled tools', async () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
const request: MessageRequest = { message: 'test' } as any
const tools: AssistantTool[] = [{ type: 'mockTool', enabled: false }] as any
const result = await manager.process(request, tools)
expect(result).toBe(request)
})
it('should throw an error when process is called without implementation', () => {
class TestTool extends InferenceTool {
name = 'testTool'
}
const tool = new TestTool()
expect(() => tool.process({} as MessageRequest)).toThrowError()
})

7
core/src/index.test.ts Normal file
View File

@ -0,0 +1,7 @@
it('should declare global object core when importing the module and then deleting it', () => {
import('./index');
delete globalThis.core;
expect(typeof globalThis.core).toBe('undefined');
});

View File

@ -0,0 +1,7 @@
import * as restfulV1 from './restful/v1';
it('should re-export from restful/v1', () => {
const restfulV1Exports = require('./restful/v1');
expect(restfulV1Exports).toBeDefined();
})

View File

@ -0,0 +1,6 @@
import { Processor } from './Processor';
it('should be defined', () => {
expect(Processor).toBeDefined();
});

View File

@ -1,40 +1,57 @@
import { App } from './app';
jest.mock('../../helper', () => ({
...jest.requireActual('../../helper'),
getJanDataFolderPath: () => './app',
}))
import { dirname } from 'path'
import { App } from './app'
it('should call stopServer', () => {
const app = new App();
const stopServerMock = jest.fn().mockResolvedValue('Server stopped');
const app = new App()
const stopServerMock = jest.fn().mockResolvedValue('Server stopped')
jest.mock('@janhq/server', () => ({
stopServer: stopServerMock
}));
const result = app.stopServer();
expect(stopServerMock).toHaveBeenCalled();
});
stopServer: stopServerMock,
}))
app.stopServer()
expect(stopServerMock).toHaveBeenCalled()
})
it('should correctly retrieve basename', () => {
const app = new App();
const result = app.baseName('/path/to/file.txt');
expect(result).toBe('file.txt');
});
const app = new App()
const result = app.baseName('/path/to/file.txt')
expect(result).toBe('file.txt')
})
it('should correctly identify subdirectories', () => {
const app = new App();
const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to';
const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir';
const result = app.isSubdirectory(basePath, subPath);
expect(result).toBe(true);
});
const app = new App()
const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to'
const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir'
const result = app.isSubdirectory(basePath, subPath)
expect(result).toBe(true)
})
it('should correctly join multiple paths', () => {
const app = new App();
const result = app.joinPath(['path', 'to', 'file']);
const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file';
expect(result).toBe(expectedPath);
});
const app = new App()
const result = app.joinPath(['path', 'to', 'file'])
const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file'
expect(result).toBe(expectedPath)
})
it('should call correct function with provided arguments using process method', () => {
const app = new App();
const mockFunc = jest.fn();
app.joinPath = mockFunc;
app.process('joinPath', ['path1', 'path2']);
expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']);
});
const app = new App()
const mockFunc = jest.fn()
app.joinPath = mockFunc
app.process('joinPath', ['path1', 'path2'])
expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2'])
})
it('should retrieve the directory name from a file path (Unix/Windows)', async () => {
const app = new App()
const path = 'C:/Users/John Doe/Desktop/file.txt'
expect(await app.dirName(path)).toBe('C:/Users/John Doe/Desktop')
})
it('should retrieve the directory name when using file protocol', async () => {
const app = new App()
const path = 'file:/models/file.txt'
expect(await app.dirName(path)).toBe(process.platform === 'win32' ? 'app\\models' : 'app/models')
})

View File

@ -1,4 +1,4 @@
import { basename, isAbsolute, join, relative } from 'path'
import { basename, dirname, isAbsolute, join, relative } from 'path'
import { Processor } from './Processor'
import {
@ -6,6 +6,8 @@ import {
appResourcePath,
getAppConfigurations as appConfiguration,
updateAppConfiguration,
normalizeFilePath,
getJanDataFolderPath,
} from '../../helper'
export class App implements Processor {
@ -28,6 +30,18 @@ export class App implements Processor {
return join(...args)
}
/**
* Get dirname of a file path.
* @param path - The file path to retrieve dirname.
*/
dirName(path: string) {
const arg =
path.startsWith(`file:/`) || path.startsWith(`file:\\`)
? join(getJanDataFolderPath(), normalizeFilePath(path))
: path
return dirname(arg)
}
/**
* Checks if the given path is a subdirectory of the given directory.
*

View File

@ -1,59 +1,131 @@
import { Downloader } from './download';
import { DownloadEvent } from '../../../types/api';
import { DownloadManager } from '../../helper/download';
import { Downloader } from './download'
import { DownloadEvent } from '../../../types/api'
import { DownloadManager } from '../../helper/download'
it('should handle getFileSize errors correctly', async () => {
const observer = jest.fn();
const url = 'http://example.com/file';
jest.mock('../../helper', () => ({
getJanDataFolderPath: jest.fn().mockReturnValue('path/to/folder'),
}))
const downloader = new Downloader(observer);
const requestMock = jest.fn((options, callback) => {
callback(new Error('Test error'), null);
});
jest.mock('request', () => requestMock);
jest.mock('../../helper/path', () => ({
validatePath: jest.fn().mockReturnValue('path/to/folder'),
normalizeFilePath: () => process.platform === 'win32' ? 'C:\\Users\path\\to\\file.gguf' : '/Users/path/to/file.gguf',
}))
await expect(downloader.getFileSize(observer, url)).rejects.toThrow('Test error');
});
jest.mock(
'request',
jest.fn().mockReturnValue(() => ({
on: jest.fn(),
}))
)
jest.mock('fs', () => ({
createWriteStream: jest.fn(),
}))
it('should pause download correctly', () => {
const observer = jest.fn();
const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file';
jest.mock('request-progress', () => {
return jest.fn().mockImplementation(() => {
return {
on: jest.fn().mockImplementation((event, callback) => {
if (event === 'error') {
callback(new Error('Download failed'))
}
return {
on: jest.fn().mockImplementation((event, callback) => {
if (event === 'error') {
callback(new Error('Download failed'))
}
return {
on: jest.fn().mockImplementation((event, callback) => {
if (event === 'error') {
callback(new Error('Download failed'))
}
return { pipe: jest.fn() }
}),
}
}),
}
}),
}
})
})
const downloader = new Downloader(observer);
const pauseMock = jest.fn();
DownloadManager.instance.networkRequests[fileName] = { pause: pauseMock };
describe('Downloader', () => {
beforeEach(() => {
jest.resetAllMocks()
})
it('should handle getFileSize errors correctly', async () => {
const observer = jest.fn()
const url = 'http://example.com/file'
downloader.pauseDownload(observer, fileName);
const downloader = new Downloader(observer)
const requestMock = jest.fn((options, callback) => {
callback(new Error('Test error'), null)
})
jest.mock('request', () => requestMock)
expect(pauseMock).toHaveBeenCalled();
});
await expect(downloader.getFileSize(observer, url)).rejects.toThrow('Test error')
})
it('should resume download correctly', () => {
const observer = jest.fn();
const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file';
it('should pause download correctly', () => {
const observer = jest.fn()
const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file'
const downloader = new Downloader(observer);
const resumeMock = jest.fn();
DownloadManager.instance.networkRequests[fileName] = { resume: resumeMock };
const downloader = new Downloader(observer)
const pauseMock = jest.fn()
DownloadManager.instance.networkRequests[fileName] = { pause: pauseMock }
downloader.resumeDownload(observer, fileName);
downloader.pauseDownload(observer, fileName)
expect(resumeMock).toHaveBeenCalled();
});
expect(pauseMock).toHaveBeenCalled()
})
it('should handle aborting a download correctly', () => {
const observer = jest.fn();
const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file';
it('should resume download correctly', () => {
const observer = jest.fn()
const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file'
const downloader = new Downloader(observer);
const abortMock = jest.fn();
DownloadManager.instance.networkRequests[fileName] = { abort: abortMock };
const downloader = new Downloader(observer)
const resumeMock = jest.fn()
DownloadManager.instance.networkRequests[fileName] = { resume: resumeMock }
downloader.abortDownload(observer, fileName);
downloader.resumeDownload(observer, fileName)
expect(abortMock).toHaveBeenCalled();
expect(observer).toHaveBeenCalledWith(DownloadEvent.onFileDownloadError, expect.objectContaining({
error: 'aborted'
}));
});
expect(resumeMock).toHaveBeenCalled()
})
it('should handle aborting a download correctly', () => {
const observer = jest.fn()
const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file'
const downloader = new Downloader(observer)
const abortMock = jest.fn()
DownloadManager.instance.networkRequests[fileName] = { abort: abortMock }
downloader.abortDownload(observer, fileName)
expect(abortMock).toHaveBeenCalled()
expect(observer).toHaveBeenCalledWith(
DownloadEvent.onFileDownloadError,
expect.objectContaining({
error: 'aborted',
})
)
})
it('should handle download fail correctly', () => {
const observer = jest.fn()
const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file.gguf'
const downloader = new Downloader(observer)
downloader.downloadFile(observer, {
localPath: fileName,
url: 'http://127.0.0.1',
})
expect(observer).toHaveBeenCalledWith(
DownloadEvent.onFileDownloadError,
expect.objectContaining({
error: expect.anything(),
})
)
})
})

View File

@ -34,7 +34,7 @@ export class Downloader implements Processor {
}
const array = normalizedPath.split(sep)
const fileName = array.pop() ?? ''
const modelId = array.pop() ?? ''
const modelId = downloadRequest.modelId ?? array.pop() ?? ''
const destination = resolve(getJanDataFolderPath(), normalizedPath)
validatePath(destination)
@ -100,7 +100,11 @@ export class Downloader implements Processor {
})
.on('end', () => {
const currentDownloadState = DownloadManager.instance.downloadProgressMap[modelId]
if (currentDownloadState && DownloadManager.instance.networkRequests[normalizedPath]) {
if (
currentDownloadState &&
DownloadManager.instance.networkRequests[normalizedPath] &&
DownloadManager.instance.downloadProgressMap[modelId]?.downloadState !== 'error'
) {
// Finished downloading, rename temp file to actual file
renameSync(downloadingTempFile, destination)
const downloadState: DownloadState = {

View File

@ -7,3 +7,34 @@ it('should call function associated with key in process method', () => {
extension.process('testKey', 'arg1', 'arg2');
expect(mockFunc).toHaveBeenCalledWith('arg1', 'arg2');
});
it('should_handle_empty_extension_list_for_install', async () => {
jest.mock('../../extension/store', () => ({
installExtensions: jest.fn(() => Promise.resolve([])),
}));
const extension = new Extension();
const result = await extension.installExtension([]);
expect(result).toEqual([]);
});
it('should_handle_empty_extension_list_for_update', async () => {
jest.mock('../../extension/store', () => ({
getExtension: jest.fn(() => ({ update: jest.fn(() => Promise.resolve(true)) })),
}));
const extension = new Extension();
const result = await extension.updateExtension([]);
expect(result).toEqual([]);
});
it('should_handle_empty_extension_list', async () => {
jest.mock('../../extension/store', () => ({
getExtension: jest.fn(() => ({ uninstall: jest.fn(() => Promise.resolve(true)) })),
removeExtension: jest.fn(),
}));
const extension = new Extension();
const result = await extension.uninstallExtension([]);
expect(result).toBe(true);
});

View File

@ -0,0 +1,305 @@
import {
existsSync,
readdirSync,
readFileSync,
writeFileSync,
mkdirSync,
appendFileSync,
rmdirSync,
} from 'fs'
import { join } from 'path'
import {
getBuilder,
retrieveBuilder,
deleteBuilder,
getMessages,
retrieveMessage,
createThread,
updateThread,
createMessage,
downloadModel,
chatCompletions,
} from './builder'
import { RouteConfiguration } from './configuration'
jest.mock('fs')
jest.mock('path')
jest.mock('../../../helper', () => ({
getEngineConfiguration: jest.fn(),
getJanDataFolderPath: jest.fn().mockReturnValue('/mock/path'),
}))
jest.mock('request')
jest.mock('request-progress')
jest.mock('node-fetch')
describe('builder helper functions', () => {
const mockConfiguration: RouteConfiguration = {
dirName: 'mockDir',
metadataFileName: 'metadata.json',
delete: {
object: 'mockObject',
},
}
beforeEach(() => {
jest.clearAllMocks()
})
describe('getBuilder', () => {
it('should return an empty array if directory does not exist', async () => {
;(existsSync as jest.Mock).mockReturnValue(false)
const result = await getBuilder(mockConfiguration)
expect(result).toEqual([])
})
it('should return model data if directory exists', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['file1'])
;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
const result = await getBuilder(mockConfiguration)
expect(result).toEqual([{ id: 'model1' }])
})
})
describe('retrieveBuilder', () => {
it('should return undefined if no data matches the id', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['file1'])
;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
const result = await retrieveBuilder(mockConfiguration, 'nonexistentId')
expect(result).toBeUndefined()
})
it('should return the matching data', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['file1'])
;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
const result = await retrieveBuilder(mockConfiguration, 'model1')
expect(result).toEqual({ id: 'model1' })
})
})
describe('deleteBuilder', () => {
it('should return a message if trying to delete Jan assistant', async () => {
const result = await deleteBuilder({ ...mockConfiguration, dirName: 'assistants' }, 'jan')
expect(result).toEqual({ message: 'Cannot delete Jan assistant' })
})
it('should return a message if data is not found', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['file1'])
;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
const result = await deleteBuilder(mockConfiguration, 'nonexistentId')
expect(result).toEqual({ message: 'Not found' })
})
it('should delete the directory and return success message', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['file1'])
;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
const result = await deleteBuilder(mockConfiguration, 'model1')
expect(rmdirSync).toHaveBeenCalledWith(join('/mock/path', 'mockDir', 'model1'), {
recursive: true,
})
expect(result).toEqual({ id: 'model1', object: 'mockObject', deleted: true })
})
})
describe('getMessages', () => {
it('should return an empty array if message file does not exist', async () => {
;(existsSync as jest.Mock).mockReturnValue(false)
const result = await getMessages('thread1')
expect(result).toEqual([])
})
it('should return messages if message file exists', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['messages.jsonl'])
;(readFileSync as jest.Mock).mockReturnValue('{"id":"msg1"}\n{"id":"msg2"}\n')
const result = await getMessages('thread1')
expect(result).toEqual([{ id: 'msg1' }, { id: 'msg2' }])
})
})
describe('retrieveMessage', () => {
it('should return a message if no messages match the id', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['messages.jsonl'])
;(readFileSync as jest.Mock).mockReturnValue('{"id":"msg1"}\n')
const result = await retrieveMessage('thread1', 'nonexistentId')
expect(result).toEqual({ message: 'Not found' })
})
it('should return the matching message', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['messages.jsonl'])
;(readFileSync as jest.Mock).mockReturnValue('{"id":"msg1"}\n')
const result = await retrieveMessage('thread1', 'msg1')
expect(result).toEqual({ id: 'msg1' })
})
})
describe('createThread', () => {
it('should return a message if thread has no assistants', async () => {
const result = await createThread({})
expect(result).toEqual({ message: 'Thread must have at least one assistant' })
})
it('should create a thread and return the updated thread', async () => {
;(existsSync as jest.Mock).mockReturnValue(false)
const thread = { assistants: [{ assistant_id: 'assistant1' }] }
const result = await createThread(thread)
expect(mkdirSync).toHaveBeenCalled()
expect(writeFileSync).toHaveBeenCalled()
expect(result.id).toBeDefined()
})
})
describe('updateThread', () => {
it('should return a message if thread is not found', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['file1'])
;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
const result = await updateThread('nonexistentId', {})
expect(result).toEqual({ message: 'Thread not found' })
})
it('should update the thread and return the updated thread', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['file1'])
;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
const result = await updateThread('model1', { name: 'updatedName' })
expect(writeFileSync).toHaveBeenCalled()
expect(result.name).toEqual('updatedName')
})
})
describe('createMessage', () => {
it('should create a message and return the created message', async () => {
;(existsSync as jest.Mock).mockReturnValue(false)
const message = { role: 'user', content: 'Hello' }
const result = (await createMessage('thread1', message)) as any
expect(mkdirSync).toHaveBeenCalled()
expect(appendFileSync).toHaveBeenCalled()
expect(result.id).toBeDefined()
})
})
describe('downloadModel', () => {
it('should return a message if model is not found', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['file1'])
;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
const result = await downloadModel('nonexistentId')
expect(result).toEqual({ message: 'Model not found' })
})
it('should start downloading the model', async () => {
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['file1'])
;(readFileSync as jest.Mock).mockReturnValue(
JSON.stringify({ id: 'model1', object: 'model', sources: ['http://example.com'] })
)
const result = await downloadModel('model1')
expect(result).toEqual({ message: 'Starting download model1' })
})
})
describe('chatCompletions', () => {
it('should return an error if model is not found', async () => {
const request = { body: { model: 'nonexistentModel' } }
const reply = { code: jest.fn().mockReturnThis(), send: jest.fn() }
await chatCompletions(request, reply)
expect(reply.code).toHaveBeenCalledWith(404)
expect(reply.send).toHaveBeenCalledWith({
error: {
message: 'The model nonexistentModel does not exist',
type: 'invalid_request_error',
param: null,
code: 'model_not_found',
},
})
})
it('should return the error on status not ok', async () => {
const request = { body: { model: 'model1' } }
const mockSend = jest.fn()
const reply = {
code: jest.fn().mockReturnThis(),
send: jest.fn(),
headers: jest.fn().mockReturnValue({
send: mockSend,
}),
raw: {
writeHead: jest.fn(),
pipe: jest.fn(),
},
}
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['file1'])
;(readFileSync as jest.Mock).mockReturnValue(
JSON.stringify({ id: 'model1', engine: 'openai' })
)
// Mock fetch
const fetch = require('node-fetch')
fetch.mockResolvedValue({
status: 400,
headers: new Map([
['content-type', 'application/json'],
['x-request-id', '123456'],
]),
body: { pipe: jest.fn() },
text: jest.fn().mockResolvedValue({ error: 'Mock error response' }),
})
await chatCompletions(request, reply)
expect(reply.code).toHaveBeenCalledWith(400)
expect(mockSend).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Mock error response',
})
)
})
it('should return the chat completions', async () => {
const request = { body: { model: 'model1' } }
const reply = {
code: jest.fn().mockReturnThis(),
send: jest.fn(),
raw: { writeHead: jest.fn(), pipe: jest.fn() },
}
;(existsSync as jest.Mock).mockReturnValue(true)
;(readdirSync as jest.Mock).mockReturnValue(['file1'])
;(readFileSync as jest.Mock).mockReturnValue(
JSON.stringify({ id: 'model1', engine: 'openai' })
)
// Mock fetch
const fetch = require('node-fetch')
fetch.mockResolvedValue({
status: 200,
body: { pipe: jest.fn() },
json: jest.fn().mockResolvedValue({ completions: ['completion1'] }),
})
await chatCompletions(request, reply)
expect(reply.raw.writeHead).toHaveBeenCalledWith(200, expect.any(Object))
})
})
})

View File

@ -280,13 +280,13 @@ export const downloadModel = async (
for (const source of model.sources) {
const rq = request({ url: source, strictSSL, proxy })
progress(rq, {})
.on('progress', function (state: any) {
?.on('progress', function (state: any) {
console.debug('progress', JSON.stringify(state, null, 2))
})
.on('error', function (err: Error) {
?.on('error', function (err: Error) {
console.error('error', err)
})
.on('end', function () {
?.on('end', function () {
console.debug('end')
})
.pipe(createWriteStream(modelBinaryPath))
@ -353,8 +353,10 @@ export const chatCompletions = async (request: any, reply: any) => {
body: JSON.stringify(request.body),
})
if (response.status !== 200) {
console.error(response)
reply.code(400).send(response)
// Forward the error response to client via reply
const responseBody = await response.text()
const responseHeaders = Object.fromEntries(response.headers)
reply.code(response.status).headers(responseHeaders).send(responseBody)
} else {
reply.raw.writeHead(200, {
'Content-Type': request.body.stream === true ? 'text/event-stream' : 'application/json',

View File

@ -0,0 +1,6 @@
import { NITRO_DEFAULT_PORT } from './consts';
it('should test NITRO_DEFAULT_PORT', () => {
expect(NITRO_DEFAULT_PORT).toBe(3928);
});

View File

@ -0,0 +1,16 @@
import { startModel } from './startStopModel'
describe('startModel', () => {
it('test_startModel_error', async () => {
const modelId = 'testModelId'
const settingParams = undefined
const result = await startModel(modelId, settingParams)
expect(result).toEqual({
error: expect.any(Error),
})
})
})

View File

@ -0,0 +1,7 @@
import { useExtensions } from './index'
test('testUseExtensionsMissingPath', () => {
expect(() => useExtensions(undefined as any)).toThrowError('A path to the extensions folder is required to use extensions')
})

View File

@ -1,6 +1,8 @@
import { getEngineConfiguration } from './config';
import { getAppConfigurations, defaultAppConfig } from './config';
import { getJanExtensionsPath } from './config';
import { getJanDataFolderPath } from './config';
it('should return undefined for invalid engine ID', async () => {
const config = await getEngineConfiguration('invalid_engine');
expect(config).toBeUndefined();
@ -12,3 +14,15 @@ it('should return default config when CI is e2e', () => {
const config = getAppConfigurations();
expect(config).toEqual(defaultAppConfig());
});
it('should return extensions path when retrieved successfully', () => {
const extensionsPath = getJanExtensionsPath();
expect(extensionsPath).not.toBeUndefined();
});
it('should return data folder path when retrieved successfully', () => {
const dataFolderPath = getJanDataFolderPath();
expect(dataFolderPath).not.toBeUndefined();
});

View File

@ -0,0 +1,24 @@
import { NativeRoute } from '../index';
test('testNativeRouteEnum', () => {
expect(NativeRoute.openExternalUrl).toBe('openExternalUrl');
expect(NativeRoute.openAppDirectory).toBe('openAppDirectory');
expect(NativeRoute.openFileExplore).toBe('openFileExplorer');
expect(NativeRoute.selectDirectory).toBe('selectDirectory');
expect(NativeRoute.selectFiles).toBe('selectFiles');
expect(NativeRoute.relaunch).toBe('relaunch');
expect(NativeRoute.setNativeThemeLight).toBe('setNativeThemeLight');
expect(NativeRoute.setNativeThemeDark).toBe('setNativeThemeDark');
expect(NativeRoute.setMinimizeApp).toBe('setMinimizeApp');
expect(NativeRoute.setCloseApp).toBe('setCloseApp');
expect(NativeRoute.setMaximizeApp).toBe('setMaximizeApp');
expect(NativeRoute.showOpenMenu).toBe('showOpenMenu');
expect(NativeRoute.hideQuickAskWindow).toBe('hideQuickAskWindow');
expect(NativeRoute.sendQuickAskInput).toBe('sendQuickAskInput');
expect(NativeRoute.hideMainWindow).toBe('hideMainWindow');
expect(NativeRoute.showMainWindow).toBe('showMainWindow');
expect(NativeRoute.quickAskSizeUpdated).toBe('quickAskSizeUpdated');
expect(NativeRoute.ackDeepLink).toBe('ackDeepLink');
});

View File

@ -37,6 +37,7 @@ export enum AppRoute {
getAppConfigurations = 'getAppConfigurations',
updateAppConfiguration = 'updateAppConfiguration',
joinPath = 'joinPath',
dirName = 'dirName',
isSubdirectory = 'isSubdirectory',
baseName = 'baseName',
startServer = 'startServer',

View File

@ -0,0 +1,9 @@
import { AppConfigurationEventName } from './appConfigEvent';
describe('AppConfigurationEventName', () => {
it('should have the correct value for OnConfigurationUpdate', () => {
expect(AppConfigurationEventName.OnConfigurationUpdate).toBe('OnConfigurationUpdate');
});
});

View File

@ -40,6 +40,14 @@ export type DownloadRequest = {
*/
extensionId?: string
/**
* The model ID of the model that initiated the download.
*/
modelId?: string
/**
* The download type.
*/
downloadType?: DownloadType | string
}
@ -52,3 +60,18 @@ type DownloadSize = {
total: number
transferred: number
}
/**
* The file metadata
*/
export type FileMetadata = {
/**
* The origin file path.
*/
file_path: string
/**
* The file name.
*/
file_name: string
}

View File

@ -0,0 +1,28 @@
import { AllQuantizations } from './huggingfaceEntity';
test('testAllQuantizationsArray', () => {
expect(AllQuantizations).toEqual([
'Q3_K_S',
'Q3_K_M',
'Q3_K_L',
'Q4_K_S',
'Q4_K_M',
'Q5_K_S',
'Q5_K_M',
'Q4_0',
'Q4_1',
'Q5_0',
'Q5_1',
'IQ2_XXS',
'IQ2_XS',
'Q2_K',
'Q2_K_S',
'Q6_K',
'Q8_0',
'F16',
'F32',
'COPY',
]);
});

View File

@ -0,0 +1,8 @@
import * as huggingfaceEntity from './huggingfaceEntity';
import * as index from './index';
test('test_exports_from_huggingfaceEntity', () => {
expect(index).toEqual(huggingfaceEntity);
});

View File

@ -0,0 +1,28 @@
import * as assistant from './assistant';
import * as model from './model';
import * as thread from './thread';
import * as message from './message';
import * as inference from './inference';
import * as monitoring from './monitoring';
import * as file from './file';
import * as config from './config';
import * as huggingface from './huggingface';
import * as miscellaneous from './miscellaneous';
import * as api from './api';
import * as setting from './setting';
test('test_module_exports', () => {
expect(assistant).toBeDefined();
expect(model).toBeDefined();
expect(thread).toBeDefined();
expect(message).toBeDefined();
expect(inference).toBeDefined();
expect(monitoring).toBeDefined();
expect(file).toBeDefined();
expect(config).toBeDefined();
expect(huggingface).toBeDefined();
expect(miscellaneous).toBeDefined();
expect(api).toBeDefined();
expect(setting).toBeDefined();
});

View File

@ -0,0 +1,13 @@
import { ChatCompletionMessage, ChatCompletionRole } from './inferenceEntity';
test('test_chatCompletionMessage_withStringContent_andSystemRole', () => {
const message: ChatCompletionMessage = {
content: 'Hello, world!',
role: ChatCompletionRole.System,
};
expect(message.content).toBe('Hello, world!');
expect(message.role).toBe(ChatCompletionRole.System);
});

View File

@ -0,0 +1,7 @@
import { InferenceEvent } from './inferenceEvent';
test('testInferenceEventEnumContainsOnInferenceStopped', () => {
expect(InferenceEvent.OnInferenceStopped).toBe('OnInferenceStopped');
});

View File

@ -0,0 +1,9 @@
import { MessageStatus } from './messageEntity';
it('should have correct values', () => {
expect(MessageStatus.Ready).toBe('ready');
expect(MessageStatus.Pending).toBe('pending');
expect(MessageStatus.Error).toBe('error');
expect(MessageStatus.Stopped).toBe('stopped');
})

View File

@ -0,0 +1,7 @@
import { MessageEvent } from './messageEvent';
test('testOnMessageSentValue', () => {
expect(MessageEvent.OnMessageSent).toBe('OnMessageSent');
});

View File

@ -0,0 +1,7 @@
import { MessageRequestType } from './messageRequestType';
test('testMessageRequestTypeEnumContainsThread', () => {
expect(MessageRequestType.Thread).toBe('Thread');
});

View File

@ -0,0 +1,6 @@
import { SupportedPlatforms } from './systemResourceInfo';
it('should contain the correct values', () => {
expect(SupportedPlatforms).toEqual(['win32', 'linux', 'darwin']);
});

View File

@ -0,0 +1,30 @@
import { Model, ModelSettingParams, ModelRuntimeParams, InferenceEngine } from '../model'
test('testValidModelCreation', () => {
const model: Model = {
object: 'model',
version: '1.0',
format: 'format1',
sources: [{ filename: 'model.bin', url: 'http://example.com/model.bin' }],
id: 'model1',
name: 'Test Model',
created: Date.now(),
description: 'A cool model from Huggingface',
settings: { ctx_len: 100, ngl: 50, embedding: true },
parameters: { temperature: 0.5, token_limit: 100, top_k: 10 },
metadata: { author: 'Author', tags: ['tag1', 'tag2'], size: 100 },
engine: InferenceEngine.anthropic
};
expect(model).toBeDefined();
expect(model.object).toBe('model');
expect(model.version).toBe('1.0');
expect(model.sources).toHaveLength(1);
expect(model.sources[0].filename).toBe('model.bin');
expect(model.settings).toBeDefined();
expect(model.parameters).toBeDefined();
expect(model.metadata).toBeDefined();
expect(model.engine).toBe(InferenceEngine.anthropic);
});

View File

@ -1,3 +1,5 @@
import { FileMetadata } from '../file'
/**
* Represents the information about a model.
* @stored
@ -151,3 +153,8 @@ export type ModelRuntimeParams = {
export type ModelInitFailed = Model & {
error: Error
}
/**
* ModelFile is the model.json entity and it's file metadata
*/
export type ModelFile = Model & FileMetadata

View File

@ -0,0 +1,7 @@
import { ModelEvent } from './modelEvent';
test('testOnModelInit', () => {
expect(ModelEvent.OnModelInit).toBe('OnModelInit');
});

View File

@ -1,5 +1,5 @@
import { GpuSetting } from '../miscellaneous'
import { Model } from './modelEntity'
import { Model, ModelFile } from './modelEntity'
/**
* Model extension for managing models.
@ -12,7 +12,7 @@ export interface ModelInterface {
* @returns A Promise that resolves when the model has been downloaded.
*/
downloadModel(
model: Model,
model: ModelFile,
gpuSettings?: GpuSetting,
network?: { ignoreSSL?: boolean; proxy?: string }
): Promise<void>
@ -29,24 +29,17 @@ export interface ModelInterface {
* @param modelId - The ID of the model to delete.
* @returns A Promise that resolves when the model has been deleted.
*/
deleteModel(modelId: string): Promise<void>
/**
* Saves a model.
* @param model - The model to save.
* @returns A Promise that resolves when the model has been saved.
*/
saveModel(model: Model): Promise<void>
deleteModel(model: ModelFile): Promise<void>
/**
* Gets a list of downloaded models.
* @returns A Promise that resolves with an array of downloaded models.
*/
getDownloadedModels(): Promise<Model[]>
getDownloadedModels(): Promise<ModelFile[]>
/**
* Gets a list of configured models.
* @returns A Promise that resolves with an array of configured models.
*/
getConfiguredModels(): Promise<Model[]>
getConfiguredModels(): Promise<ModelFile[]>
}

View File

@ -0,0 +1,16 @@
import * as monitoringInterface from './monitoringInterface';
import * as resourceInfo from './resourceInfo';
import * as index from './index';
import * as monitoringInterface from './monitoringInterface';
import * as resourceInfo from './resourceInfo';
it('should re-export all symbols from monitoringInterface and resourceInfo', () => {
for (const key in monitoringInterface) {
expect(index[key]).toBe(monitoringInterface[key]);
}
for (const key in resourceInfo) {
expect(index[key]).toBe(resourceInfo[key]);
}
});

View File

@ -0,0 +1,5 @@
it('should not throw any errors', () => {
expect(() => require('./index')).not.toThrow();
});

View File

@ -0,0 +1,19 @@
import { createSettingComponent } from './settingComponent';
it('should throw an error when creating a setting component with invalid controller type', () => {
const props: SettingComponentProps = {
key: 'invalidControllerKey',
title: 'Invalid Controller Title',
description: 'Invalid Controller Description',
controllerType: 'invalid' as any,
controllerProps: {
placeholder: 'Enter text',
value: 'Initial Value',
type: 'text',
textAlign: 'left',
inputActions: ['unobscure'],
},
};
expect(() => createSettingComponent(props)).toThrowError();
});

View File

@ -0,0 +1,6 @@
import { ThreadEvent } from './threadEvent';
it('should have the correct values', () => {
expect(ThreadEvent.OnThreadStarted).toBe('OnThreadStarted');
});

18
electron/jest.config.js Normal file
View File

@ -0,0 +1,18 @@
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
collectCoverageFrom: ['src/**/*.{ts,tsx}'],
modulePathIgnorePatterns: ['<rootDir>/tests'],
moduleNameMapper: {
'@/(.*)': '<rootDir>/src/$1',
},
runner: './testRunner.js',
transform: {
'^.+\\.tsx?$': [
'ts-jest',
{
diagnostics: false,
},
],
},
}

10
electron/testRunner.js Normal file
View File

@ -0,0 +1,10 @@
const jestRunner = require('jest-runner');
class EmptyTestFileRunner extends jestRunner.default {
async runTests(tests, watcher, onStart, onResult, onFailure, options) {
const nonEmptyTests = tests.filter(test => test.context.hasteFS.getSize(test.path) > 0);
return super.runTests(nonEmptyTests, watcher, onStart, onResult, onFailure, options);
}
}
module.exports = EmptyTestFileRunner;

View File

@ -1,32 +1,29 @@
import { expect } from '@playwright/test'
import { page, test, TIMEOUT } from '../config/fixtures'
test('Select GPT model from Hub and Chat with Invalid API Key', async ({ hubPage }) => {
test('Select GPT model from Hub and Chat with Invalid API Key', async ({
hubPage,
}) => {
await hubPage.navigateByMenu()
await hubPage.verifyContainerVisible()
// Select the first GPT model
await page
.locator('[data-testid^="use-model-btn"][data-testid*="gpt"]')
.first().click()
// Attempt to create thread and chat in Thread page
await page
.getByTestId('btn-create-thread')
.first()
.click()
await page
.getByTestId('txt-input-chat')
.fill('dummy value')
await page.getByTestId('txt-input-chat').fill('dummy value')
await page
.getByTestId('btn-send-chat')
.click()
await page.getByTestId('btn-send-chat').click()
await page.waitForFunction(() => {
const loaders = document.querySelectorAll('[data-testid$="loader"]');
return !loaders.length;
}, { timeout: TIMEOUT });
await page.waitForFunction(
() => {
const loaders = document.querySelectorAll('[data-testid$="loader"]')
return !loaders.length
},
{ timeout: TIMEOUT }
)
const APIKeyError = page.getByTestId('invalid-API-key-error')
await expect(APIKeyError).toBeVisible({

View File

@ -0,0 +1,5 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
}

View File

@ -7,6 +7,7 @@
"author": "Jan <service@jan.ai>",
"license": "MIT",
"scripts": {
"test": "jest",
"build": "tsc -b . && webpack --config webpack.config.js",
"build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install"
},

View File

@ -0,0 +1,408 @@
/**
* @jest-environment jsdom
*/
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
fs: {
existsSync: jest.fn(),
mkdir: jest.fn(),
writeFileSync: jest.fn(),
readdirSync: jest.fn(),
readFileSync: jest.fn(),
appendFileSync: jest.fn(),
rm: jest.fn(),
writeBlob: jest.fn(),
joinPath: jest.fn(),
fileStat: jest.fn(),
},
joinPath: jest.fn(),
ConversationalExtension: jest.fn(),
}))
import { fs } from '@janhq/core'
import JSONConversationalExtension from '.'
describe('JSONConversationalExtension Tests', () => {
let extension: JSONConversationalExtension
beforeEach(() => {
// @ts-ignore
extension = new JSONConversationalExtension()
})
it('should create thread folder on load if it does not exist', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
await extension.onLoad()
expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
})
it('should log message on unload', () => {
const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
extension.onUnload()
expect(consoleSpy).toHaveBeenCalledWith(
'JSONConversationalExtension unloaded'
)
})
it('should return sorted threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce({ updated: '2023-01-01' })
.mockResolvedValueOnce({ updated: '2023-01-02' })
const threads = await extension.getThreads()
expect(threads).toEqual([
{ updated: '2023-01-02' },
{ updated: '2023-01-01' },
])
})
it('should ignore broken threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
.mockResolvedValueOnce('this_is_an_invalid_json_content')
const threads = await extension.getThreads()
expect(threads).toEqual([{ updated: '2023-01-01' }])
})
it('should save thread', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const writeFileSyncSpy = jest
.spyOn(fs, 'writeFileSync')
.mockResolvedValue({})
const thread = { id: '1', updated: '2023-01-01' } as any
await extension.saveThread(thread)
expect(mkdirSpy).toHaveBeenCalled()
expect(writeFileSyncSpy).toHaveBeenCalled()
})
it('should delete thread', async () => {
const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
await extension.deleteThread('1')
expect(rmSpy).toHaveBeenCalled()
})
it('should add new message', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
const message = {
thread_id: '1',
content: [{ type: 'text', text: { annotations: [] } }],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should store image', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeImage(
'data:image/png;base64,abcd',
'path/to/image.png'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
it('should store file', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeFile(
'data:application/pdf;base64,abcd',
'path/to/file.pdf'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
it('should write messages', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const writeFileSyncSpy = jest
.spyOn(fs, 'writeFileSync')
.mockResolvedValue({})
const messages = [{ id: '1', thread_id: '1', content: [] }] as any
await extension.writeMessages('1', messages)
expect(mkdirSpy).toHaveBeenCalled()
expect(writeFileSyncSpy).toHaveBeenCalled()
})
it('should get all messages on string response', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
jest.spyOn(fs, 'readFileSync').mockResolvedValue('{"id":"1"}\n{"id":"2"}\n')
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([{ id: '1' }, { id: '2' }])
})
it('should get all messages on object response', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
jest.spyOn(fs, 'readFileSync').mockResolvedValue({ id: 1 })
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([{ id: 1 }])
})
it('get all messages return empty on error', async () => {
jest.spyOn(fs, 'readdirSync').mockRejectedValue(['messages.jsonl'])
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([])
})
it('return empty messages on no messages file', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue([])
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([])
})
it('should ignore error message', async () => {
jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
jest
.spyOn(fs, 'readFileSync')
.mockResolvedValue('{"id":"1"}\nyolo\n{"id":"2"}\n')
const messages = await extension.getAllMessages('1')
expect(messages).toEqual([{ id: '1' }, { id: '2' }])
})
it('should create thread folder on load if it does not exist', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
await extension.onLoad()
expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
})
it('should log message on unload', () => {
const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
extension.onUnload()
expect(consoleSpy).toHaveBeenCalledWith(
'JSONConversationalExtension unloaded'
)
})
it('should return sorted threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce({ updated: '2023-01-01' })
.mockResolvedValueOnce({ updated: '2023-01-02' })
const threads = await extension.getThreads()
expect(threads).toEqual([
{ updated: '2023-01-02' },
{ updated: '2023-01-01' },
])
})
it('should ignore broken threads', async () => {
jest
.spyOn(extension, 'getValidThreadDirs')
.mockResolvedValue(['dir1', 'dir2'])
jest
.spyOn(extension, 'readThread')
.mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
.mockResolvedValueOnce('this_is_an_invalid_json_content')
const threads = await extension.getThreads()
expect(threads).toEqual([{ updated: '2023-01-01' }])
})
it('should save thread', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const writeFileSyncSpy = jest
.spyOn(fs, 'writeFileSync')
.mockResolvedValue({})
const thread = { id: '1', updated: '2023-01-01' } as any
await extension.saveThread(thread)
expect(mkdirSpy).toHaveBeenCalled()
expect(writeFileSyncSpy).toHaveBeenCalled()
})
it('should delete thread', async () => {
const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
await extension.deleteThread('1')
expect(rmSpy).toHaveBeenCalled()
})
it('should add new message', async () => {
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
const message = {
thread_id: '1',
content: [{ type: 'text', text: { annotations: [] } }],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should add new image message', async () => {
jest
.spyOn(fs, 'existsSync')
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(true)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
const message = {
thread_id: '1',
content: [
{ type: 'image', text: { annotations: ['data:image;base64,hehe'] } },
],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should add new pdf message', async () => {
jest
.spyOn(fs, 'existsSync')
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(false)
// @ts-ignore
.mockResolvedValueOnce(true)
const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
const appendFileSyncSpy = jest
.spyOn(fs, 'appendFileSync')
.mockResolvedValue({})
jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
const message = {
thread_id: '1',
content: [
{ type: 'pdf', text: { annotations: ['data:pdf;base64,hehe'] } },
],
} as any
await extension.addNewMessage(message)
expect(mkdirSpy).toHaveBeenCalled()
expect(appendFileSyncSpy).toHaveBeenCalled()
})
it('should store image', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeImage(
'data:image/png;base64,abcd',
'path/to/image.png'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
it('should store file', async () => {
const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
await extension.storeFile(
'data:application/pdf;base64,abcd',
'path/to/file.pdf'
)
expect(writeBlobSpy).toHaveBeenCalled()
})
})
describe('test readThread', () => {
let extension: JSONConversationalExtension
beforeEach(() => {
// @ts-ignore
extension = new JSONConversationalExtension()
})
it('should read thread', async () => {
jest
.spyOn(fs, 'readFileSync')
.mockResolvedValue(JSON.stringify({ id: '1' }))
const thread = await extension.readThread('1')
expect(thread).toEqual(`{"id":"1"}`)
})
it('getValidThreadDirs should return valid thread directories', async () => {
jest
.spyOn(fs, 'readdirSync')
.mockResolvedValueOnce(['1', '2', '3'])
.mockResolvedValueOnce(['thread.json'])
.mockResolvedValueOnce(['thread.json'])
.mockResolvedValueOnce([])
// @ts-ignore
jest.spyOn(fs, 'existsSync').mockResolvedValue(true)
jest.spyOn(fs, 'fileStat').mockResolvedValue({
isDirectory: true,
} as any)
const validThreadDirs = await extension.getValidThreadDirs()
expect(validThreadDirs).toEqual(['1', '2'])
})
})

View File

@ -5,6 +5,7 @@ import {
Thread,
ThreadMessage,
} from '@janhq/core'
import { safelyParseJSON } from './jsonUtil'
/**
* JSONConversationalExtension is a ConversationalExtension implementation that provides
@ -45,10 +46,11 @@ export default class JSONConversationalExtension extends ConversationalExtension
if (result.status === 'fulfilled') {
return typeof result.value === 'object'
? result.value
: JSON.parse(result.value)
: safelyParseJSON(result.value)
}
return undefined
})
.filter((convo) => convo != null)
.filter((convo) => !!convo)
convos.sort(
(a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime()
)
@ -195,7 +197,7 @@ export default class JSONConversationalExtension extends ConversationalExtension
* @param threadDirName the thread dir we are reading from.
* @returns data of the thread
*/
private async readThread(threadDirName: string): Promise<any> {
async readThread(threadDirName: string): Promise<any> {
return fs.readFileSync(
await joinPath([
JSONConversationalExtension._threadFolder,
@ -210,7 +212,7 @@ export default class JSONConversationalExtension extends ConversationalExtension
* Returns a Promise that resolves to an array of thread directories.
* @private
*/
private async getValidThreadDirs(): Promise<string[]> {
async getValidThreadDirs(): Promise<string[]> {
const fileInsideThread: string[] = await fs.readdirSync(
JSONConversationalExtension._threadFolder
)
@ -266,7 +268,8 @@ export default class JSONConversationalExtension extends ConversationalExtension
const messages: ThreadMessage[] = []
result.forEach((line: string) => {
messages.push(JSON.parse(line))
const message = safelyParseJSON(line)
if (message) messages.push(safelyParseJSON(line))
})
return messages
} catch (err) {

View File

@ -0,0 +1,14 @@
// Note about performance
// The v8 JavaScript engine used by Node.js cannot optimise functions which contain a try/catch block.
// v8 4.5 and above can optimise try/catch
export function safelyParseJSON(json) {
// This function cannot be optimised, it's best to
// keep it small!
var parsed
try {
parsed = JSON.parse(json)
} catch (e) {
return undefined
}
return parsed // Could be undefined!
}

View File

@ -10,5 +10,6 @@
"skipLibCheck": true,
"rootDir": "./src"
},
"include": ["./src"]
"include": ["./src"],
"exclude": ["src/**/*.test.ts"]
}

View File

@ -0,0 +1,9 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
transform: {
'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
},
transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
}

View File

@ -9,6 +9,7 @@
"author": "Jan <service@jan.ai>",
"license": "AGPL-3.0",
"scripts": {
"test": "jest test",
"build": "tsc -b . && webpack --config webpack.config.js",
"build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install",
"sync:core": "cd ../.. && yarn build:core && cd extensions && rm yarn.lock && cd inference-anthropic-extension && yarn && yarn build:publish"

View File

@ -0,0 +1,77 @@
// Import necessary modules
import JanInferenceAnthropicExtension, { Settings } from '.'
import { PayloadType, ChatCompletionRole } from '@janhq/core'
// Mocks
jest.mock('@janhq/core', () => ({
RemoteOAIEngine: jest.fn().mockImplementation(() => ({
registerSettings: jest.fn(),
registerModels: jest.fn(),
getSetting: jest.fn(),
onChange: jest.fn(),
onSettingUpdate: jest.fn(),
onLoad: jest.fn(),
headers: jest.fn(),
})),
PayloadType: jest.fn(),
ChatCompletionRole: {
User: 'user' as const,
Assistant: 'assistant' as const,
System: 'system' as const,
},
}))
// Helper functions
const createMockPayload = (): PayloadType => ({
messages: [
{ role: ChatCompletionRole.System, content: 'Meow' },
{ role: ChatCompletionRole.User, content: 'Hello' },
{ role: ChatCompletionRole.Assistant, content: 'Hi there' },
],
model: 'claude-v1',
stream: false,
})
describe('JanInferenceAnthropicExtension', () => {
let extension: JanInferenceAnthropicExtension
beforeEach(() => {
extension = new JanInferenceAnthropicExtension('', '')
extension.apiKey = 'mock-api-key'
extension.inferenceUrl = 'mock-endpoint'
jest.clearAllMocks()
})
it('should initialize with correct settings', async () => {
await extension.onLoad()
expect(extension.apiKey).toBe('mock-api-key')
expect(extension.inferenceUrl).toBe('mock-endpoint')
})
it('should transform payload correctly', () => {
const payload = createMockPayload()
const transformedPayload = extension.transformPayload(payload)
expect(transformedPayload).toEqual({
max_tokens: 4096,
model: 'claude-v1',
stream: false,
system: 'Meow',
messages: [
{ role: 'user', content: 'Hello' },
{ role: 'assistant', content: 'Hi there' },
],
})
})
it('should transform response correctly', () => {
const nonStreamResponse = { content: [{ text: 'Test response' }] }
const streamResponse =
'data: {"type":"content_block_delta","delta":{"text":"Hello"}}'
expect(extension.transformResponse(nonStreamResponse)).toBe('Test response')
expect(extension.transformResponse(streamResponse)).toBe('Hello')
expect(extension.transformResponse('')).toBe('')
expect(extension.transformResponse('event: something')).toBe('')
})
})

View File

@ -13,7 +13,7 @@ import { ChatCompletionRole } from '@janhq/core'
declare const SETTINGS: Array<any>
declare const MODELS: Array<any>
enum Settings {
export enum Settings {
apiKey = 'anthropic-api-key',
chatCompletionsEndPoint = 'chat-completions-endpoint',
}
@ -23,6 +23,7 @@ type AnthropicPayloadType = {
model?: string
max_tokens?: number
messages?: Array<{ role: string; content: string }>
system?: string
}
/**
@ -113,6 +114,10 @@ export default class JanInferenceAnthropicExtension extends RemoteOAIEngine {
role: 'assistant',
content: item.content as string,
})
} else if (item.role === ChatCompletionRole.System) {
// When using Claude, you can dramatically improve its performance by using the system parameter to give it a role.
// This technique, known as role prompting, is the most powerful way to use system prompts with Claude.
convertedData.system = item.content as string
}
})

View File

@ -10,5 +10,6 @@
"skipLibCheck": true,
"rootDir": "./src"
},
"include": ["./src"]
"include": ["./src"],
"exclude": ["**/*.test.ts"]
}

View File

@ -1,7 +1,7 @@
{
"name": "@janhq/inference-cortex-extension",
"productName": "Cortex Inference Engine",
"version": "1.0.17",
"version": "1.0.18",
"description": "This extension embeds cortex.cpp, a lightweight inference engine written in C++. See https://jan.ai.\nAdditional dependencies could be installed to run without Cuda Toolkit installation.",
"main": "dist/index.js",
"node": "dist/node/index.cjs.js",

View File

@ -8,7 +8,7 @@
"id": "deepseek-coder-1.3b",
"object": "model",
"name": "Deepseek Coder 1.3B Instruct Q8",
"version": "1.3",
"version": "1.4",
"description": "Deepseek Coder excelled in project-level code completion with advanced capabilities across multiple programming languages.",
"format": "gguf",
"settings": {
@ -22,13 +22,13 @@
"top_p": 0.95,
"stream": true,
"max_tokens": 16384,
"stop": [],
"stop": ["<|EOT|>"],
"frequency_penalty": 0,
"presence_penalty": 0
},
"metadata": {
"author": "Deepseek, The Bloke",
"tags": ["Tiny", "Foundational Model"],
"tags": ["Tiny"],
"size": 1430000000
},
"engine": "nitro"

View File

@ -2,13 +2,13 @@
"sources": [
{
"filename": "deepseek-coder-33b-instruct.Q4_K_M.gguf",
"url": "https://huggingface.co/TheBloke/deepseek-coder-33B-instruct-GGUF/resolve/main/deepseek-coder-33b-instruct.Q4_K_M.gguf"
"url": "https://huggingface.co/mradermacher/deepseek-coder-33b-instruct-GGUF/resolve/main/deepseek-coder-33b-instruct.Q4_K_M.gguf"
}
],
"id": "deepseek-coder-34b",
"object": "model",
"name": "Deepseek Coder 33B Instruct Q4",
"version": "1.3",
"version": "1.4",
"description": "Deepseek Coder excelled in project-level code completion with advanced capabilities across multiple programming languages.",
"format": "gguf",
"settings": {
@ -22,13 +22,13 @@
"top_p": 0.95,
"stream": true,
"max_tokens": 16384,
"stop": [],
"stop": ["<|EOT|>"],
"frequency_penalty": 0,
"presence_penalty": 0
},
"metadata": {
"author": "Deepseek, The Bloke",
"tags": ["34B", "Foundational Model"],
"author": "Deepseek",
"tags": ["33B"],
"size": 19940000000
},
"engine": "nitro"

View File

@ -22,6 +22,7 @@ import {
downloadFile,
DownloadState,
DownloadEvent,
ModelFile,
} from '@janhq/core'
declare const CUDA_DOWNLOAD_URL: string
@ -94,7 +95,7 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
this.nitroProcessInfo = health
}
override loadModel(model: Model): Promise<void> {
override loadModel(model: ModelFile): Promise<void> {
if (model.engine !== this.provider) return Promise.resolve()
this.getNitroProcessHealthIntervalId = setInterval(
() => this.periodicallyGetNitroHealth(),

View File

@ -6,12 +6,12 @@ import fetchRT from 'fetch-retry'
import {
log,
getSystemResourceInfo,
Model,
InferenceEngine,
ModelSettingParams,
PromptTemplate,
SystemInformation,
getJanDataFolderPath,
ModelFile,
} from '@janhq/core/node'
import { executableNitroFile } from './execute'
import terminate from 'terminate'
@ -25,7 +25,7 @@ const fetchRetry = fetchRT(fetch)
*/
interface ModelInitOptions {
modelFolder: string
model: Model
model: ModelFile
}
// The PORT to use for the Nitro subprocess
const PORT = 3928
@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise<Response> {
if (!settings?.ngl) {
settings.ngl = 100
}
log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`)
log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`)
return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, {
method: 'POST',
headers: {
@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise<Response> {
})
.then((res) => {
log(
`[CORTEX]::Debug: Load model success with response ${JSON.stringify(
`[CORTEX]:: Load model success with response ${JSON.stringify(
res
)}`
)
@ -260,7 +260,7 @@ function loadLLMModel(settings: any): Promise<Response> {
async function validateModelStatus(modelId: string): Promise<void> {
// Send a GET request to the validation URL.
// Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries.
log(`[CORTEX]::Debug: Validating model ${modelId}`)
log(`[CORTEX]:: Validating model ${modelId}`)
return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, {
method: 'POST',
body: JSON.stringify({
@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
retryDelay: 300,
}).then(async (res: Response) => {
log(
`[CORTEX]::Debug: Validate model state with response ${JSON.stringify(
`[CORTEX]:: Validate model state with response ${JSON.stringify(
res.status
)}`
)
@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
// Otherwise, return an object with an error message.
if (body.model_loaded) {
log(
`[CORTEX]::Debug: Validate model state success with response ${JSON.stringify(
`[CORTEX]:: Validate model state success with response ${JSON.stringify(
body
)}`
)
@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
}
const errorBody = await res.text()
log(
`[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
`[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
res.statusText
)}`
)
@ -310,7 +310,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
async function killSubprocess(): Promise<void> {
const controller = new AbortController()
setTimeout(() => controller.abort(), 5000)
log(`[CORTEX]::Debug: Request to kill cortex`)
log(`[CORTEX]:: Request to kill cortex`)
const killRequest = () => {
return fetch(NITRO_HTTP_KILL_URL, {
@ -321,17 +321,17 @@ async function killSubprocess(): Promise<void> {
.then(() =>
tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
)
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
.then(() => log(`[CORTEX]:: cortex process is terminated`))
.catch((err) => {
log(
`[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
`[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
)
throw 'PORT_NOT_AVAILABLE'
})
}
if (subprocess?.pid && process.platform !== 'darwin') {
log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
const pid = subprocess.pid
return new Promise((resolve, reject) => {
terminate(pid, function (err) {
@ -341,7 +341,7 @@ async function killSubprocess(): Promise<void> {
} else {
tcpPortUsed
.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
.then(() => log(`[CORTEX]:: cortex process is terminated`))
.then(() => resolve())
.catch(() => {
log(
@ -362,7 +362,7 @@ async function killSubprocess(): Promise<void> {
* @returns A promise that resolves when the Nitro subprocess is started.
*/
function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
log(`[CORTEX]::Debug: Spawning cortex subprocess...`)
log(`[CORTEX]:: Spawning cortex subprocess...`)
return new Promise<void>(async (resolve, reject) => {
let executableOptions = executableNitroFile(
@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
// Execute the binary
log(
`[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
`[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
)
log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`)
@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
// Handle subprocess output
subprocess.stdout.on('data', (data: any) => {
log(`[CORTEX]::Debug: ${data}`)
log(`[CORTEX]:: ${data}`)
})
subprocess.stderr.on('data', (data: any) => {
@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
})
subprocess.on('close', (code: any) => {
log(`[CORTEX]::Debug: cortex exited with code: ${code}`)
log(`[CORTEX]:: cortex exited with code: ${code}`)
subprocess = undefined
reject(`child process exited with code ${code}`)
})
@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
tcpPortUsed
.waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000)
.then(() => {
log(`[CORTEX]::Debug: cortex is ready`)
log(`[CORTEX]:: cortex is ready`)
resolve()
})
})

View File

@ -0,0 +1,9 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
transform: {
'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
},
transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
}

View File

@ -119,5 +119,65 @@
]
},
"engine": "openai"
},
{
"sources": [
{
"url": "https://openai.com"
}
],
"id": "o1-preview",
"object": "model",
"name": "OpenAI o1-preview",
"version": "1.0",
"description": "OpenAI o1-preview is a new model with complex reasoning",
"format": "api",
"settings": {},
"parameters": {
"max_tokens": 4096,
"temperature": 0.7,
"top_p": 0.95,
"stream": true,
"stop": [],
"frequency_penalty": 0,
"presence_penalty": 0
},
"metadata": {
"author": "OpenAI",
"tags": [
"General"
]
},
"engine": "openai"
},
{
"sources": [
{
"url": "https://openai.com"
}
],
"id": "o1-mini",
"object": "model",
"name": "OpenAI o1-mini",
"version": "1.0",
"description": "OpenAI o1-mini is a lightweight reasoning model",
"format": "api",
"settings": {},
"parameters": {
"max_tokens": 4096,
"temperature": 0.7,
"top_p": 0.95,
"stream": true,
"stop": [],
"frequency_penalty": 0,
"presence_penalty": 0
},
"metadata": {
"author": "OpenAI",
"tags": [
"General"
]
},
"engine": "openai"
}
]

View File

@ -0,0 +1,54 @@
/**
* @jest-environment jsdom
*/
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
RemoteOAIEngine: jest.fn().mockImplementation(() => ({
onLoad: jest.fn(),
registerSettings: jest.fn(),
registerModels: jest.fn(),
getSetting: jest.fn(),
onSettingUpdate: jest.fn(),
})),
}))
import JanInferenceOpenAIExtension, { Settings } from '.'
describe('JanInferenceOpenAIExtension', () => {
let extension: JanInferenceOpenAIExtension
beforeEach(() => {
// @ts-ignore
extension = new JanInferenceOpenAIExtension()
})
it('should initialize with settings and models', async () => {
await extension.onLoad()
// Assuming there are some default SETTINGS and MODELS being registered
expect(extension.apiKey).toBe(undefined)
expect(extension.inferenceUrl).toBe('')
})
it('should transform the payload for preview models', () => {
const payload: any = {
max_tokens: 100,
model: 'o1-mini',
// Add other required properties...
}
const transformedPayload = extension.transformPayload(payload)
expect(transformedPayload.max_completion_tokens).toBe(payload.max_tokens)
expect(transformedPayload).not.toHaveProperty('max_tokens')
expect(transformedPayload).toHaveProperty('max_completion_tokens')
})
it('should not transform the payload for non-preview models', () => {
const payload: any = {
max_tokens: 100,
model: 'non-preview-model',
// Add other required properties...
}
const transformedPayload = extension.transformPayload(payload)
expect(transformedPayload).toEqual(payload)
})
})

View File

@ -6,16 +6,17 @@
* @module inference-openai-extension/src/index
*/
import { RemoteOAIEngine, SettingComponentProps } from '@janhq/core'
import { ModelRuntimeParams, PayloadType, RemoteOAIEngine } from '@janhq/core'
declare const SETTINGS: Array<any>
declare const MODELS: Array<any>
enum Settings {
export enum Settings {
apiKey = 'openai-api-key',
chatCompletionsEndPoint = 'chat-completions-endpoint',
}
type OpenAIPayloadType = PayloadType &
ModelRuntimeParams & { max_completion_tokens: number }
/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
@ -24,6 +25,7 @@ enum Settings {
export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
inferenceUrl: string = ''
provider: string = 'openai'
previewModels = ['o1-mini', 'o1-preview']
override async onLoad(): Promise<void> {
super.onLoad()
@ -63,4 +65,24 @@ export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
}
}
}
/**
* Tranform the payload before sending it to the inference endpoint.
* The new preview models such as o1-mini and o1-preview replaced max_tokens by max_completion_tokens parameter.
* Others do not.
* @param payload
* @returns
*/
transformPayload = (payload: OpenAIPayloadType): OpenAIPayloadType => {
// Transform the payload for preview models
if (this.previewModels.includes(payload.model)) {
const { max_tokens, ...params } = payload
return {
...params,
max_completion_tokens: max_tokens,
}
}
// Pass through for non-preview models
return payload
}
}

View File

@ -10,5 +10,6 @@
"skipLibCheck": true,
"rootDir": "./src"
},
"include": ["./src"]
"include": ["./src"],
"exclude": ["**/*.test.ts"]
}

View File

@ -0,0 +1,9 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
transform: {
'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
},
transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
}

View File

@ -8,6 +8,7 @@
"author": "Jan <service@jan.ai>",
"license": "AGPL-3.0",
"scripts": {
"test": "jest",
"build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs",
"build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install"
},

View File

@ -27,7 +27,7 @@ export default [
// Allow json resolution
json(),
// Compile TypeScript files
typescript({ useTsconfigDeclarationDir: true }),
typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
// Compile TypeScript files
// Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
// commonjs(),
@ -62,7 +62,7 @@ export default [
// Allow json resolution
json(),
// Compile TypeScript files
typescript({ useTsconfigDeclarationDir: true }),
typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
// Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
commonjs(),
// Allow node_modules resolution, so you can use 'external' to control

View File

@ -0,0 +1,87 @@
import { extractFileName } from './path';
describe('extractFileName Function', () => {
it('should correctly extract the file name with the provided file extension', () => {
const url = 'http://example.com/some/path/to/file.ext';
const fileExtension = '.ext';
const fileName = extractFileName(url, fileExtension);
expect(fileName).toBe('file.ext');
});
it('should correctly append the file extension if it does not already exist in the file name', () => {
const url = 'http://example.com/some/path/to/file';
const fileExtension = '.txt';
const fileName = extractFileName(url, fileExtension);
expect(fileName).toBe('file.txt');
});
it('should handle cases where the URL does not have a file extension correctly', () => {
const url = 'http://example.com/some/path/to/file';
const fileExtension = '.jpg';
const fileName = extractFileName(url, fileExtension);
expect(fileName).toBe('file.jpg');
});
it('should correctly handle URLs without a trailing slash', () => {
const url = 'http://example.com/some/path/tofile';
const fileExtension = '.txt';
const fileName = extractFileName(url, fileExtension);
expect(fileName).toBe('tofile.txt');
});
it('should correctly handle URLs with multiple file extensions', () => {
const url = 'http://example.com/some/path/tofile.tar.gz';
const fileExtension = '.gz';
const fileName = extractFileName(url, fileExtension);
expect(fileName).toBe('tofile.tar.gz');
});
it('should correctly handle URLs with special characters', () => {
const url = 'http://example.com/some/path/tófílë.extë';
const fileExtension = '.extë';
const fileName = extractFileName(url, fileExtension);
expect(fileName).toBe('tófílë.extë');
});
it('should correctly handle URLs that are just a file with no path', () => {
const url = 'http://example.com/file.txt';
const fileExtension = '.txt';
const fileName = extractFileName(url, fileExtension);
expect(fileName).toBe('file.txt');
});
it('should correctly handle URLs that have special query parameters', () => {
const url = 'http://example.com/some/path/tofile.ext?query=1';
const fileExtension = '.ext';
const fileName = extractFileName(url.split('?')[0], fileExtension);
expect(fileName).toBe('tofile.ext');
});
it('should correctly handle URLs that have uppercase characters', () => {
const url = 'http://EXAMPLE.COM/PATH/TO/FILE.EXT';
const fileExtension = '.ext';
const fileName = extractFileName(url, fileExtension);
expect(fileName).toBe('FILE.EXT');
});
it('should correctly handle invalid URLs', () => {
const url = 'invalid-url';
const fileExtension = '.txt';
const fileName = extractFileName(url, fileExtension);
expect(fileName).toBe('invalid-url.txt');
});
it('should correctly handle empty URLs', () => {
const url = '';
const fileExtension = '.txt';
const fileName = extractFileName(url, fileExtension);
expect(fileName).toBe('.txt');
});
it('should correctly handle undefined URLs', () => {
const url = undefined;
const fileExtension = '.txt';
const fileName = extractFileName(url as any, fileExtension);
expect(fileName).toBe('.txt');
});
});

View File

@ -3,6 +3,8 @@
*/
export function extractFileName(url: string, fileExtension: string): string {
if(!url) return fileExtension
const extractedFileName = url.split('/').pop()
const fileName = extractedFileName.toLowerCase().endsWith(fileExtension)
? extractedFileName

View File

@ -0,0 +1,788 @@
/**
* @jest-environment jsdom
*/
const readDirSyncMock = jest.fn()
const existMock = jest.fn()
const readFileSyncMock = jest.fn()
const downloadMock = jest.fn()
const mkdirMock = jest.fn()
const writeFileSyncMock = jest.fn()
const copyFileMock = jest.fn()
const dirNameMock = jest.fn()
const executeMock = jest.fn()
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
events: {
emit: jest.fn(),
},
fs: {
existsSync: existMock,
readdirSync: readDirSyncMock,
readFileSync: readFileSyncMock,
writeFileSync: writeFileSyncMock,
mkdir: mkdirMock,
copyFile: copyFileMock,
fileStat: () => ({
isDirectory: false,
}),
},
dirName: dirNameMock,
joinPath: (paths) => paths.join('/'),
ModelExtension: jest.fn(),
downloadFile: downloadMock,
executeOnMain: executeMock,
}))
jest.mock('@huggingface/gguf')
global.fetch = jest.fn(() =>
Promise.resolve({
json: () => Promise.resolve({ test: 100 }),
arrayBuffer: jest.fn(),
})
) as jest.Mock
import JanModelExtension from '.'
import { fs, dirName } from '@janhq/core'
import { gguf } from '@huggingface/gguf'
describe('JanModelExtension', () => {
let sut: JanModelExtension
beforeAll(() => {
// @ts-ignore
sut = new JanModelExtension()
})
beforeEach(() => {
jest.clearAllMocks()
})
describe('getConfiguredModels', () => {
describe("when there's no models are pre-populated", () => {
it('should return empty array', async () => {
// Mock configured models data
const configuredModels = []
existMock.mockReturnValue(true)
readDirSyncMock.mockReturnValue([])
const result = await sut.getConfiguredModels()
expect(result).toEqual([])
})
})
describe("when there's are pre-populated models - all flattened", () => {
it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2']
else return ['model.json']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else return JSON.stringify(configuredModels[1])
})
const result = await sut.getConfiguredModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
expect.objectContaining({
file_path: 'file://models/model2/model.json',
id: '2',
}),
])
)
})
})
describe("when there's are pre-populated models - there are nested folders", () => {
it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2/model2-1']
else return ['model.json']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else if (path.includes('model2/model2-1'))
return JSON.stringify(configuredModels[1])
})
const result = await sut.getConfiguredModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
expect.objectContaining({
file_path: 'file://models/model2/model2-1/model.json',
id: '2',
}),
])
)
})
})
})
describe('getDownloadedModels', () => {
describe('no models downloaded', () => {
it('should return empty array', async () => {
// Mock downloaded models data
existMock.mockReturnValue(true)
readDirSyncMock.mockReturnValue([])
const result = await sut.getDownloadedModels()
expect(result).toEqual([])
})
})
describe('only one model is downloaded', () => {
describe('flatten folder', () => {
it('returns downloaded models - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2']
else if (path === 'file://models/model1')
return ['model.json', 'test.gguf']
else return ['model.json']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else return JSON.stringify(configuredModels[1])
})
const result = await sut.getDownloadedModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
])
)
})
})
})
describe('all models are downloaded', () => {
describe('nested folders', () => {
it('returns downloaded models - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2/model2-1']
else return ['model.json', 'test.gguf']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else return JSON.stringify(configuredModels[1])
})
const result = await sut.getDownloadedModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
expect.objectContaining({
file_path: 'file://models/model2/model2-1/model.json',
id: '2',
}),
])
)
})
})
})
describe('all models are downloaded with uppercased GGUF files', () => {
it('returns downloaded models - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2/model2-1']
else if (path === 'file://models/model1')
return ['model.json', 'test.GGUF']
else return ['model.json', 'test.gguf']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else return JSON.stringify(configuredModels[1])
})
const result = await sut.getDownloadedModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
expect.objectContaining({
file_path: 'file://models/model2/model2-1/model.json',
id: '2',
}),
])
)
})
})
describe('all models are downloaded - GGUF & Tensort RT', () => {
it('returns downloaded models - with correct file_path and model id', async () => {
// Mock configured models data
const configuredModels = [
{
id: '1',
name: 'Model 1',
version: '1.0.0',
description: 'Model 1 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model1',
},
format: 'onnx',
sources: [],
created: new Date(),
updated: new Date(),
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
{
id: '2',
name: 'Model 2',
version: '2.0.0',
description: 'Model 2 description',
object: {
type: 'model',
uri: 'http://localhost:5000/models/model2',
},
format: 'onnx',
sources: [],
parameters: {},
settings: {},
metadata: {},
engine: 'test',
} as any,
]
existMock.mockReturnValue(true)
readDirSyncMock.mockImplementation((path) => {
if (path === 'file://models') return ['model1', 'model2/model2-1']
else if (path === 'file://models/model1')
return ['model.json', 'test.gguf']
else return ['model.json', 'test.engine']
})
readFileSyncMock.mockImplementation((path) => {
if (path.includes('model1'))
return JSON.stringify(configuredModels[0])
else return JSON.stringify(configuredModels[1])
})
const result = await sut.getDownloadedModels()
expect(result).toEqual(
expect.arrayContaining([
expect.objectContaining({
file_path: 'file://models/model1/model.json',
id: '1',
}),
expect.objectContaining({
file_path: 'file://models/model2/model2-1/model.json',
id: '2',
}),
])
)
})
})
})
describe('deleteModel', () => {
describe('model is a GGUF model', () => {
it('should delete the GGUF file', async () => {
fs.unlinkSync = jest.fn()
const dirMock = dirName as jest.Mock
dirMock.mockReturnValue('file://models/model1')
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
readDirSyncMock.mockImplementation((path) => {
return ['model.json', 'test.gguf']
})
existMock.mockReturnValue(true)
await sut.deleteModel({
file_path: 'file://models/model1/model.json',
} as any)
expect(fs.unlinkSync).toHaveBeenCalledWith(
'file://models/model1/test.gguf'
)
})
it('no gguf file presented', async () => {
fs.unlinkSync = jest.fn()
const dirMock = dirName as jest.Mock
dirMock.mockReturnValue('file://models/model1')
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
readDirSyncMock.mockReturnValue(['model.json'])
existMock.mockReturnValue(true)
await sut.deleteModel({
file_path: 'file://models/model1/model.json',
} as any)
expect(fs.unlinkSync).toHaveBeenCalledTimes(0)
})
it('delete an imported model', async () => {
fs.rm = jest.fn()
const dirMock = dirName as jest.Mock
dirMock.mockReturnValue('file://models/model1')
readDirSyncMock.mockReturnValue(['model.json', 'test.gguf'])
// MARK: This is a tricky logic implement?
// I will just add test for now but will align on the legacy implementation
fs.readFileSync = jest.fn().mockReturnValue(
JSON.stringify({
metadata: {
author: 'user',
},
})
)
existMock.mockReturnValue(true)
await sut.deleteModel({
file_path: 'file://models/model1/model.json',
} as any)
expect(fs.rm).toHaveBeenCalledWith('file://models/model1')
})
it('delete tensorrt-models', async () => {
fs.rm = jest.fn()
const dirMock = dirName as jest.Mock
dirMock.mockReturnValue('file://models/model1')
readDirSyncMock.mockReturnValue(['model.json', 'test.engine'])
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
existMock.mockReturnValue(true)
await sut.deleteModel({
file_path: 'file://models/model1/model.json',
} as any)
expect(fs.unlinkSync).toHaveBeenCalledWith(
'file://models/model1/test.engine'
)
})
})
})
describe('downloadModel', () => {
const model: any = {
id: 'model-id',
name: 'Test Model',
sources: [
{ url: 'http://example.com/model.gguf', filename: 'model.gguf' },
],
engine: 'test-engine',
}
const network = {
ignoreSSL: true,
proxy: 'http://proxy.example.com',
}
const gpuSettings: any = {
gpus: [{ name: 'nvidia-rtx-3080', arch: 'ampere' }],
}
it('should reject with invalid gguf metadata', async () => {
existMock.mockImplementation(() => false)
expect(
sut.downloadModel(model, gpuSettings, network)
).rejects.toBeTruthy()
})
it('should download corresponding ID', async () => {
existMock.mockImplementation(() => true)
dirNameMock.mockImplementation(() => 'file://models/model1')
downloadMock.mockImplementation(() => {
return Promise.resolve({})
})
expect(
await sut.downloadModel(
{ ...model, file_path: 'file://models/model1/model.json' },
gpuSettings,
network
)
).toBeUndefined()
expect(downloadMock).toHaveBeenCalledWith(
{
localPath: 'file://models/model1/model.gguf',
modelId: 'model-id',
url: 'http://example.com/model.gguf',
},
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
)
})
it('should handle invalid model file', async () => {
executeMock.mockResolvedValue({})
fs.readFileSync = jest.fn(() => {
return JSON.stringify({ metadata: { author: 'user' } })
})
expect(
sut.downloadModel(
{ ...model, file_path: 'file://models/model1/model.json' },
gpuSettings,
network
)
).resolves.not.toThrow()
expect(downloadMock).not.toHaveBeenCalled()
})
it('should handle model file with no sources', async () => {
executeMock.mockResolvedValue({})
const modelWithoutSources = { ...model, sources: [] }
expect(
sut.downloadModel(
{
...modelWithoutSources,
file_path: 'file://models/model1/model.json',
},
gpuSettings,
network
)
).resolves.toBe(undefined)
expect(downloadMock).not.toHaveBeenCalled()
})
it('should handle model file with multiple sources', async () => {
const modelWithMultipleSources = {
...model,
sources: [
{ url: 'http://example.com/model1.gguf', filename: 'model1.gguf' },
{ url: 'http://example.com/model2.gguf', filename: 'model2.gguf' },
],
}
executeMock.mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
;(gguf as jest.Mock).mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
// @ts-ignore
global.NODE = 'node'
// @ts-ignore
global.DEFAULT_MODEL = {
parameters: { stop: [] },
}
downloadMock.mockImplementation(() => {
return Promise.resolve({})
})
expect(
await sut.downloadModel(
{
...modelWithMultipleSources,
file_path: 'file://models/model1/model.json',
},
gpuSettings,
network
)
).toBeUndefined()
expect(downloadMock).toHaveBeenCalledWith(
{
localPath: 'file://models/model1/model1.gguf',
modelId: 'model-id',
url: 'http://example.com/model1.gguf',
},
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
)
expect(downloadMock).toHaveBeenCalledWith(
{
localPath: 'file://models/model1/model2.gguf',
modelId: 'model-id',
url: 'http://example.com/model2.gguf',
},
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
)
})
it('should handle model file with no file_path', async () => {
executeMock.mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
;(gguf as jest.Mock).mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
// @ts-ignore
global.NODE = 'node'
// @ts-ignore
global.DEFAULT_MODEL = {
parameters: { stop: [] },
}
const modelWithoutFilepath = { ...model, file_path: undefined }
await sut.downloadModel(modelWithoutFilepath, gpuSettings, network)
expect(downloadMock).toHaveBeenCalledWith(
expect.objectContaining({
localPath: 'file://models/model-id/model.gguf',
}),
expect.anything()
)
})
it('should handle model file with invalid file_path', async () => {
executeMock.mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
;(gguf as jest.Mock).mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
// @ts-ignore
global.NODE = 'node'
// @ts-ignore
global.DEFAULT_MODEL = {
parameters: { stop: [] },
}
const modelWithInvalidFilepath = {
...model,
file_path: 'file://models/invalid-model.json',
}
await sut.downloadModel(modelWithInvalidFilepath, gpuSettings, network)
expect(downloadMock).toHaveBeenCalledWith(
expect.objectContaining({
localPath: 'file://models/model1/model.gguf',
}),
expect.anything()
)
})
})
})

View File

@ -22,6 +22,8 @@ import {
getFileSize,
AllQuantizations,
ModelEvent,
ModelFile,
dirName,
} from '@janhq/core'
import { extractFileName } from './helpers/path'
@ -48,16 +50,7 @@ export default class JanModelExtension extends ModelExtension {
]
private static readonly _tensorRtEngineFormat = '.engine'
private static readonly _supportedGpuArch = ['ampere', 'ada']
private static readonly _safetensorsRegexs = [
/model\.safetensors$/,
/model-[0-9]+-of-[0-9]+\.safetensors$/,
]
private static readonly _pytorchRegexs = [
/pytorch_model\.bin$/,
/consolidated\.[0-9]+\.pth$/,
/pytorch_model-[0-9]+-of-[0-9]+\.bin$/,
/.*\.pt$/,
]
interrupted = false
/**
@ -83,15 +76,32 @@ export default class JanModelExtension extends ModelExtension {
* @returns A Promise that resolves when the model is downloaded.
*/
async downloadModel(
model: Model,
model: ModelFile,
gpuSettings?: GpuSetting,
network?: { ignoreSSL?: boolean; proxy?: string }
): Promise<void> {
// create corresponding directory
// Create corresponding directory
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
const modelJsonPath = await joinPath([modelDirPath, 'model.json'])
const modelJsonPath =
model.file_path ?? (await joinPath([modelDirPath, 'model.json']))
// Download HF model - model.json not exist
if (!(await fs.existsSync(modelJsonPath))) {
// It supports only one source for HF download
const metadata = await this.fetchModelMetadata(model.sources[0].url)
const updatedModel = await this.retrieveGGUFMetadata(metadata)
if (updatedModel) {
// Update model settings
model.settings = {
...model.settings,
...updatedModel.settings,
}
model.parameters = {
...model.parameters,
...updatedModel.parameters,
}
}
await fs.writeFileSync(modelJsonPath, JSON.stringify(model, null, 2))
events.emit(ModelEvent.OnModelsUpdate, {})
}
@ -142,11 +152,15 @@ export default class JanModelExtension extends ModelExtension {
JanModelExtension._supportedModelFormat
)
if (source.filename) {
path = await joinPath([modelDirPath, source.filename])
path = model.file_path
? await joinPath([await dirName(model.file_path), source.filename])
: await joinPath([modelDirPath, source.filename])
}
const downloadRequest: DownloadRequest = {
url: source.url,
localPath: path,
modelId: model.id,
}
downloadFile(downloadRequest, network)
}
@ -156,10 +170,13 @@ export default class JanModelExtension extends ModelExtension {
model.sources[0]?.url,
JanModelExtension._supportedModelFormat
)
const path = await joinPath([modelDirPath, fileName])
const path = model.file_path
? await joinPath([await dirName(model.file_path), fileName])
: await joinPath([modelDirPath, fileName])
const downloadRequest: DownloadRequest = {
url: model.sources[0]?.url,
localPath: path,
modelId: model.id,
}
downloadFile(downloadRequest, network)
@ -319,9 +336,9 @@ export default class JanModelExtension extends ModelExtension {
* @param filePath - The path to the model file to delete.
* @returns A Promise that resolves when the model is deleted.
*/
async deleteModel(modelId: string): Promise<void> {
async deleteModel(model: ModelFile): Promise<void> {
try {
const dirPath = await joinPath([JanModelExtension._homeDir, modelId])
const dirPath = await dirName(model.file_path)
const jsonFilePath = await joinPath([
dirPath,
JanModelExtension._modelMetadataFileName,
@ -330,6 +347,8 @@ export default class JanModelExtension extends ModelExtension {
await this.readModelMetadata(jsonFilePath)
) as Model
// TODO: This is so tricky?
// Should depend on sources?
const isUserImportModel =
modelInfo.metadata?.author?.toLowerCase() === 'user'
if (isUserImportModel) {
@ -350,30 +369,11 @@ export default class JanModelExtension extends ModelExtension {
}
}
/**
* Saves a model file.
* @param model - The model to save.
* @returns A Promise that resolves when the model is saved.
*/
async saveModel(model: Model): Promise<void> {
const jsonFilePath = await joinPath([
JanModelExtension._homeDir,
model.id,
JanModelExtension._modelMetadataFileName,
])
try {
await fs.writeFileSync(jsonFilePath, JSON.stringify(model, null, 2))
} catch (err) {
console.error(err)
}
}
/**
* Gets all downloaded models.
* @returns A Promise that resolves with an array of all models.
*/
async getDownloadedModels(): Promise<Model[]> {
async getDownloadedModels(): Promise<ModelFile[]> {
return await this.getModelsMetadata(
async (modelDir: string, model: Model) => {
if (!JanModelExtension._offlineInferenceEngine.includes(model.engine))
@ -425,8 +425,10 @@ export default class JanModelExtension extends ModelExtension {
): Promise<string | undefined> {
// try to find model.json recursively inside each folder
if (!(await fs.existsSync(folderFullPath))) return undefined
const files: string[] = await fs.readdirSync(folderFullPath)
if (files.length === 0) return undefined
if (files.includes(JanModelExtension._modelMetadataFileName)) {
return joinPath([
folderFullPath,
@ -446,7 +448,7 @@ export default class JanModelExtension extends ModelExtension {
private async getModelsMetadata(
selector?: (path: string, model: Model) => Promise<boolean>
): Promise<Model[]> {
): Promise<ModelFile[]> {
try {
if (!(await fs.existsSync(JanModelExtension._homeDir))) {
console.debug('Model folder not found')
@ -469,6 +471,7 @@ export default class JanModelExtension extends ModelExtension {
JanModelExtension._homeDir,
dirName,
])
const jsonPath = await this.getModelJsonPath(folderFullPath)
if (await fs.existsSync(jsonPath)) {
@ -486,6 +489,8 @@ export default class JanModelExtension extends ModelExtension {
},
]
}
model.file_path = jsonPath
model.file_name = JanModelExtension._modelMetadataFileName
if (selector && !(await selector?.(dirName, model))) {
return
@ -506,7 +511,7 @@ export default class JanModelExtension extends ModelExtension {
typeof result.value === 'object'
? result.value
: JSON.parse(result.value)
return model as Model
return model as ModelFile
} catch {
console.debug(`Unable to parse model metadata: ${result.value}`)
}
@ -574,7 +579,7 @@ export default class JanModelExtension extends ModelExtension {
])
)
const eos_id = metadata?.['tokenizer.ggml.eos_token_id']
const updatedModel = await this.retrieveGGUFMetadata(metadata)
if (!defaultModel) {
console.error('Unable to find default model')
@ -594,18 +599,11 @@ export default class JanModelExtension extends ModelExtension {
],
parameters: {
...defaultModel.parameters,
stop: eos_id
? [metadata['tokenizer.ggml.tokens'][eos_id] ?? '']
: defaultModel.parameters.stop,
...updatedModel.parameters,
},
settings: {
...defaultModel.settings,
prompt_template:
metadata?.parsed_chat_template ??
defaultModel.settings.prompt_template,
ctx_len:
metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len,
ngl: (metadata?.['llama.block_count'] ?? 32) + 1,
...updatedModel.settings,
llama_model_path: binaryFileName,
},
created: Date.now(),
@ -637,7 +635,7 @@ export default class JanModelExtension extends ModelExtension {
* Gets all available models.
* @returns A Promise that resolves with an array of all models.
*/
async getConfiguredModels(): Promise<Model[]> {
async getConfiguredModels(): Promise<ModelFile[]> {
return this.getModelsMetadata()
}
@ -669,7 +667,7 @@ export default class JanModelExtension extends ModelExtension {
modelBinaryPath: string,
modelFolderName: string,
modelFolderPath: string
): Promise<Model> {
): Promise<ModelFile> {
const fileStats = await fs.fileStat(modelBinaryPath, true)
const binaryFileSize = fileStats.size
@ -685,9 +683,9 @@ export default class JanModelExtension extends ModelExtension {
'retrieveGGUFMetadata',
modelBinaryPath
)
const eos_id = metadata?.['tokenizer.ggml.eos_token_id']
const binaryFileName = await baseName(modelBinaryPath)
const updatedModel = await this.retrieveGGUFMetadata(metadata)
const model: Model = {
...defaultModel,
@ -701,19 +699,12 @@ export default class JanModelExtension extends ModelExtension {
],
parameters: {
...defaultModel.parameters,
stop: eos_id
? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? '']
: defaultModel.parameters.stop,
...updatedModel.parameters,
},
settings: {
...defaultModel.settings,
prompt_template:
metadata?.parsed_chat_template ??
defaultModel.settings.prompt_template,
ctx_len:
metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len,
ngl: (metadata?.['llama.block_count'] ?? 32) + 1,
...updatedModel.settings,
llama_model_path: binaryFileName,
},
created: Date.now(),
@ -732,25 +723,21 @@ export default class JanModelExtension extends ModelExtension {
await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2))
return model
return {
...model,
file_path: modelFilePath,
file_name: JanModelExtension._modelMetadataFileName,
}
}
async updateModelInfo(modelInfo: Partial<Model>): Promise<Model> {
const modelId = modelInfo.id
async updateModelInfo(modelInfo: Partial<ModelFile>): Promise<ModelFile> {
if (modelInfo.id == null) throw new Error('Model ID is required')
const janDataFolderPath = await getJanDataFolderPath()
const jsonFilePath = await joinPath([
janDataFolderPath,
'models',
modelId,
JanModelExtension._modelMetadataFileName,
])
const model = JSON.parse(
await this.readModelMetadata(jsonFilePath)
) as Model
await this.readModelMetadata(modelInfo.file_path)
) as ModelFile
const updatedModel: Model = {
const updatedModel: ModelFile = {
...model,
...modelInfo,
parameters: {
@ -765,9 +752,15 @@ export default class JanModelExtension extends ModelExtension {
...model.metadata,
...modelInfo.metadata,
},
// Should not persist file_path & file_name
file_path: undefined,
file_name: undefined,
}
await fs.writeFileSync(jsonFilePath, JSON.stringify(updatedModel, null, 2))
await fs.writeFileSync(
modelInfo.file_path,
JSON.stringify(updatedModel, null, 2)
)
return updatedModel
}
@ -877,4 +870,35 @@ export default class JanModelExtension extends ModelExtension {
importedModels
)
}
/**
* Retrieve Model Settings from GGUF Metadata
* @param metadata
* @returns
*/
async retrieveGGUFMetadata(metadata: any): Promise<Partial<Model>> {
const template = await executeOnMain(NODE, 'renderJinjaTemplate', metadata)
const defaultModel = DEFAULT_MODEL as Model
const eos_id = metadata['tokenizer.ggml.eos_token_id']
const architecture = metadata['general.architecture']
return {
settings: {
prompt_template: template ?? defaultModel.settings.prompt_template,
ctx_len:
metadata[`${architecture}.context_length`] ??
metadata['llama.context_length'] ??
4096,
ngl:
(metadata[`${architecture}.block_count`] ??
metadata['llama.block_count'] ??
32) + 1,
},
parameters: {
stop: eos_id
? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? '']
: defaultModel.parameters.stop,
},
}
}
}

View File

@ -16,27 +16,8 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => {
// Parse metadata and tensor info
const { metadata } = ggufMetadata(buffer.buffer)
const template = new Template(metadata['tokenizer.chat_template'])
const eos_id = metadata['tokenizer.ggml.eos_token_id']
const bos_id = metadata['tokenizer.ggml.bos_token_id']
const eos_token = metadata['tokenizer.ggml.tokens'][eos_id]
const bos_token = metadata['tokenizer.ggml.tokens'][bos_id]
// Parse jinja template
const renderedTemplate = template.render({
add_generation_prompt: true,
eos_token,
bos_token,
messages: [
{
role: 'system',
content: '{system_message}',
},
{
role: 'user',
content: '{prompt}',
},
],
})
const renderedTemplate = renderJinjaTemplate(metadata)
return {
...metadata,
parsed_chat_template: renderedTemplate,
@ -45,3 +26,34 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => {
console.log('[MODEL_EXT]', e)
}
}
/**
* Convert metadata to jinja template
* @param metadata
*/
export const renderJinjaTemplate = (metadata: any): string => {
const template = new Template(metadata['tokenizer.chat_template'])
const eos_id = metadata['tokenizer.ggml.eos_token_id']
const bos_id = metadata['tokenizer.ggml.bos_token_id']
if (eos_id === undefined || bos_id === undefined) {
return ''
}
const eos_token = metadata['tokenizer.ggml.tokens'][eos_id]
const bos_token = metadata['tokenizer.ggml.tokens'][bos_id]
// Parse jinja template
return template.render({
add_generation_prompt: true,
eos_token,
bos_token,
messages: [
{
role: 'system',
content: '{system_message}',
},
{
role: 'user',
content: '{prompt}',
},
],
})
}

View File

@ -0,0 +1,53 @@
import { renderJinjaTemplate } from './index'
import { Template } from '@huggingface/jinja'
jest.mock('@huggingface/jinja', () => ({
Template: jest.fn((template: string) => ({
render: jest.fn(() => `${template}_rendered`),
})),
}))
describe('renderJinjaTemplate', () => {
beforeEach(() => {
jest.clearAllMocks() // Clear mocks between tests
})
it('should render the template with correct parameters', () => {
const metadata = {
'tokenizer.chat_template': 'Hello, {{ messages }}!',
'tokenizer.ggml.eos_token_id': 0,
'tokenizer.ggml.bos_token_id': 1,
'tokenizer.ggml.tokens': ['EOS', 'BOS'],
}
const renderedTemplate = renderJinjaTemplate(metadata)
expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!')
expect(renderedTemplate).toBe('Hello, {{ messages }}!_rendered')
})
it('should handle missing token IDs gracefully', () => {
const metadata = {
'tokenizer.chat_template': 'Hello, {{ messages }}!',
'tokenizer.ggml.eos_token_id': 0,
'tokenizer.ggml.tokens': ['EOS'],
}
const renderedTemplate = renderJinjaTemplate(metadata)
expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!')
expect(renderedTemplate).toBe('')
})
it('should handle empty template gracefully', () => {
const metadata = {}
const renderedTemplate = renderJinjaTemplate(metadata)
expect(Template).toHaveBeenCalledWith(undefined)
expect(renderedTemplate).toBe("")
})
})

View File

@ -10,5 +10,6 @@
"skipLibCheck": true,
"rootDir": "./src"
},
"include": ["./src"]
"include": ["./src"],
"exclude": ["**/*.test.ts"]
}

View File

@ -0,0 +1,9 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
transform: {
'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
},
transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
}

View File

@ -22,6 +22,7 @@
"tensorrtVersion": "0.1.8",
"provider": "nitro-tensorrt-llm",
"scripts": {
"test": "jest",
"build": "tsc --module commonjs && rollup -c rollup.config.ts",
"build:publish:win32": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
"build:publish:linux": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
@ -49,7 +50,12 @@
"rollup-plugin-sourcemaps": "^0.6.3",
"rollup-plugin-typescript2": "^0.36.0",
"run-script-os": "^1.1.6",
"typescript": "^5.2.2"
"typescript": "^5.2.2",
"@types/jest": "^29.5.12",
"jest": "^29.7.0",
"jest-junit": "^16.0.0",
"jest-runner": "^29.7.0",
"ts-jest": "^29.2.5"
},
"dependencies": {
"@janhq/core": "file:../../core",

View File

@ -23,10 +23,10 @@ export default [
DOWNLOAD_RUNNER_URL:
process.platform === 'win32'
? JSON.stringify(
'https://github.com/janhq/nitro-tensorrt-llm/releases/download/windows-v<version>-tensorrt-llm-v0.7.1/nitro-windows-v<version>-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v<version>-tensorrt-llm-v0.7.1/nitro-windows-v<version>-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
)
: JSON.stringify(
'https://github.com/janhq/nitro-tensorrt-llm/releases/download/linux-v<version>/nitro-linux-v<version>-amd64-tensorrt-llm-<gpuarch>.tar.gz'
'https://github.com/janhq/cortex.tensorrt-llm/releases/download/linux-v<version>/nitro-linux-v<version>-amd64-tensorrt-llm-<gpuarch>.tar.gz'
),
NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`),
INFERENCE_URL: JSON.stringify(

View File

@ -0,0 +1,186 @@
import TensorRTLLMExtension from '../src/index'
import {
executeOnMain,
systemInformation,
fs,
baseName,
joinPath,
downloadFile,
} from '@janhq/core'
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
LocalOAIEngine: jest.fn().mockImplementation(function () {
// @ts-ignore
this.registerModels = () => {
return Promise.resolve()
}
// @ts-ignore
return this
}),
systemInformation: jest.fn(),
fs: {
existsSync: jest.fn(),
mkdir: jest.fn(),
},
joinPath: jest.fn(),
baseName: jest.fn(),
downloadFile: jest.fn(),
executeOnMain: jest.fn(),
showToast: jest.fn(),
events: {
emit: jest.fn(),
// @ts-ignore
on: (event, func) => {
func({ fileName: './' })
},
off: jest.fn(),
},
}))
// @ts-ignore
global.COMPATIBILITY = {
platform: ['win32'],
}
// @ts-ignore
global.PROVIDER = 'tensorrt-llm'
// @ts-ignore
global.INFERENCE_URL = 'http://localhost:5000'
// @ts-ignore
global.NODE = 'node'
// @ts-ignore
global.MODELS = []
// @ts-ignore
global.TENSORRT_VERSION = ''
// @ts-ignore
global.DOWNLOAD_RUNNER_URL = ''
describe('TensorRTLLMExtension', () => {
let extension: TensorRTLLMExtension
beforeEach(() => {
// @ts-ignore
extension = new TensorRTLLMExtension()
jest.clearAllMocks()
})
describe('compatibility', () => {
it('should return the correct compatibility', () => {
const result = extension.compatibility()
expect(result).toEqual({
platform: ['win32'],
})
})
})
describe('install', () => {
it('should install if compatible', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}
;(executeOnMain as jest.Mock).mockResolvedValue({})
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
;(fs.existsSync as jest.Mock).mockResolvedValue(false)
;(fs.mkdir as jest.Mock).mockResolvedValue(undefined)
;(baseName as jest.Mock).mockResolvedValue('./')
;(joinPath as jest.Mock).mockResolvedValue('./')
;(downloadFile as jest.Mock).mockResolvedValue({})
await extension.install()
expect(executeOnMain).toHaveBeenCalled()
})
it('should not install if not compatible', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'linux' },
gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
jest.spyOn(extension, 'registerModels').mockReturnValue(Promise.resolve())
await extension.install()
expect(executeOnMain).not.toHaveBeenCalled()
})
})
describe('installationState', () => {
it('should return NotCompatible if not compatible', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'linux' },
gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
const result = await extension.installationState()
expect(result).toBe('NotCompatible')
})
it('should return Installed if executable exists', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
;(fs.existsSync as jest.Mock).mockResolvedValue(true)
const result = await extension.installationState()
expect(result).toBe('Installed')
})
it('should return NotInstalled if executable does not exist', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
;(fs.existsSync as jest.Mock).mockResolvedValue(false)
const result = await extension.installationState()
expect(result).toBe('NotInstalled')
})
})
describe('isCompatible', () => {
it('should return true for compatible system', () => {
const mockInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}
const result = extension.isCompatible(mockInfo)
expect(result).toBe(true)
})
it('should return false for incompatible system', () => {
const mockInfo: any = {
osInfo: { platform: 'linux' },
gpuSetting: { gpus: [{ arch: 'pascal', name: 'AMD GPU' }] },
}
const result = extension.isCompatible(mockInfo)
expect(result).toBe(false)
})
})
})
describe('GitHub Release File URL Test', () => {
const url = 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v0.1.8-tensorrt-llm-v0.7.1/nitro-windows-v0.1.8-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz';
it('should return a status code 200 for the release file URL', async () => {
const response = await fetch(url, { method: 'HEAD' });
expect(response.status).toBe(200);
});
it('should not return a 404 status', async () => {
const response = await fetch(url, { method: 'HEAD' });
expect(response.status).not.toBe(404);
});
});

View File

@ -23,6 +23,7 @@ import {
ModelEvent,
getJanDataFolderPath,
SystemInformation,
ModelFile,
} from '@janhq/core'
/**
@ -40,7 +41,6 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
override nodeModule = NODE
private supportedGpuArch = ['ampere', 'ada']
private supportedPlatform = ['win32', 'linux']
override compatibility() {
return COMPATIBILITY as unknown as Compatibility
@ -137,7 +137,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
events.emit(ModelEvent.OnModelsUpdate, {})
}
override async loadModel(model: Model): Promise<void> {
override async loadModel(model: ModelFile): Promise<void> {
if ((await this.installationState()) === 'Installed')
return super.loadModel(model)
@ -190,7 +190,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
!!info.gpuSetting &&
!!firstGpu &&
info.gpuSetting.gpus.length > 0 &&
this.supportedPlatform.includes(info.osInfo.platform) &&
this.compatibility().platform.includes(info.osInfo.platform) &&
!!firstGpu.arch &&
firstGpu.name.toLowerCase().includes('nvidia') &&
this.supportedGpuArch.includes(firstGpu.arch)

View File

@ -97,7 +97,7 @@ function unloadModel(): Promise<void> {
}
if (subprocess?.pid) {
log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
const pid = subprocess.pid
return new Promise((resolve, reject) => {
terminate(pid, function (err) {
@ -107,7 +107,7 @@ function unloadModel(): Promise<void> {
return tcpPortUsed
.waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000)
.then(() => resolve())
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
.then(() => log(`[CORTEX]:: cortex process is terminated`))
.catch(() => {
killRequest()
})

View File

@ -16,5 +16,6 @@
"resolveJsonModule": true,
"typeRoots": ["node_modules/@types"]
},
"include": ["src"]
"include": ["src"],
"exclude": ["**/*.test.ts"]
}

Some files were not shown because too many files have changed in this diff Show More