* fix: mismatch between model json and path * chore: revert preserve model settings * test: add tests
This commit is contained in:
parent
c3cb192486
commit
8e603bd5db
1
.gitignore
vendored
1
.gitignore
vendored
@ -45,3 +45,4 @@ core/test_results.html
|
||||
coverage
|
||||
.yarn
|
||||
.yarnrc
|
||||
*.tsbuildinfo
|
||||
|
||||
@ -1,98 +1,109 @@
|
||||
import { openExternalUrl } from './core';
|
||||
import { joinPath } from './core';
|
||||
import { openFileExplorer } from './core';
|
||||
import { getJanDataFolderPath } from './core';
|
||||
import { abortDownload } from './core';
|
||||
import { getFileSize } from './core';
|
||||
import { executeOnMain } from './core';
|
||||
import { openExternalUrl } from './core'
|
||||
import { joinPath } from './core'
|
||||
import { openFileExplorer } from './core'
|
||||
import { getJanDataFolderPath } from './core'
|
||||
import { abortDownload } from './core'
|
||||
import { getFileSize } from './core'
|
||||
import { executeOnMain } from './core'
|
||||
|
||||
it('should open external url', async () => {
|
||||
const url = 'http://example.com';
|
||||
globalThis.core = {
|
||||
api: {
|
||||
openExternalUrl: jest.fn().mockResolvedValue('opened')
|
||||
describe('test core apis', () => {
|
||||
it('should open external url', async () => {
|
||||
const url = 'http://example.com'
|
||||
globalThis.core = {
|
||||
api: {
|
||||
openExternalUrl: jest.fn().mockResolvedValue('opened'),
|
||||
},
|
||||
}
|
||||
};
|
||||
const result = await openExternalUrl(url);
|
||||
expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url);
|
||||
expect(result).toBe('opened');
|
||||
});
|
||||
const result = await openExternalUrl(url)
|
||||
expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url)
|
||||
expect(result).toBe('opened')
|
||||
})
|
||||
|
||||
|
||||
it('should join paths', async () => {
|
||||
const paths = ['/path/one', '/path/two'];
|
||||
globalThis.core = {
|
||||
api: {
|
||||
joinPath: jest.fn().mockResolvedValue('/path/one/path/two')
|
||||
it('should join paths', async () => {
|
||||
const paths = ['/path/one', '/path/two']
|
||||
globalThis.core = {
|
||||
api: {
|
||||
joinPath: jest.fn().mockResolvedValue('/path/one/path/two'),
|
||||
},
|
||||
}
|
||||
};
|
||||
const result = await joinPath(paths);
|
||||
expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths);
|
||||
expect(result).toBe('/path/one/path/two');
|
||||
});
|
||||
const result = await joinPath(paths)
|
||||
expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths)
|
||||
expect(result).toBe('/path/one/path/two')
|
||||
})
|
||||
|
||||
|
||||
it('should open file explorer', async () => {
|
||||
const path = '/path/to/open';
|
||||
globalThis.core = {
|
||||
api: {
|
||||
openFileExplorer: jest.fn().mockResolvedValue('opened')
|
||||
it('should open file explorer', async () => {
|
||||
const path = '/path/to/open'
|
||||
globalThis.core = {
|
||||
api: {
|
||||
openFileExplorer: jest.fn().mockResolvedValue('opened'),
|
||||
},
|
||||
}
|
||||
};
|
||||
const result = await openFileExplorer(path);
|
||||
expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path);
|
||||
expect(result).toBe('opened');
|
||||
});
|
||||
const result = await openFileExplorer(path)
|
||||
expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path)
|
||||
expect(result).toBe('opened')
|
||||
})
|
||||
|
||||
|
||||
it('should get jan data folder path', async () => {
|
||||
globalThis.core = {
|
||||
api: {
|
||||
getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data')
|
||||
it('should get jan data folder path', async () => {
|
||||
globalThis.core = {
|
||||
api: {
|
||||
getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data'),
|
||||
},
|
||||
}
|
||||
};
|
||||
const result = await getJanDataFolderPath();
|
||||
expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled();
|
||||
expect(result).toBe('/path/to/jan/data');
|
||||
});
|
||||
const result = await getJanDataFolderPath()
|
||||
expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled()
|
||||
expect(result).toBe('/path/to/jan/data')
|
||||
})
|
||||
|
||||
|
||||
it('should abort download', async () => {
|
||||
const fileName = 'testFile';
|
||||
globalThis.core = {
|
||||
api: {
|
||||
abortDownload: jest.fn().mockResolvedValue('aborted')
|
||||
it('should abort download', async () => {
|
||||
const fileName = 'testFile'
|
||||
globalThis.core = {
|
||||
api: {
|
||||
abortDownload: jest.fn().mockResolvedValue('aborted'),
|
||||
},
|
||||
}
|
||||
};
|
||||
const result = await abortDownload(fileName);
|
||||
expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName);
|
||||
expect(result).toBe('aborted');
|
||||
});
|
||||
const result = await abortDownload(fileName)
|
||||
expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName)
|
||||
expect(result).toBe('aborted')
|
||||
})
|
||||
|
||||
|
||||
it('should get file size', async () => {
|
||||
const url = 'http://example.com/file';
|
||||
globalThis.core = {
|
||||
api: {
|
||||
getFileSize: jest.fn().mockResolvedValue(1024)
|
||||
it('should get file size', async () => {
|
||||
const url = 'http://example.com/file'
|
||||
globalThis.core = {
|
||||
api: {
|
||||
getFileSize: jest.fn().mockResolvedValue(1024),
|
||||
},
|
||||
}
|
||||
};
|
||||
const result = await getFileSize(url);
|
||||
expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url);
|
||||
expect(result).toBe(1024);
|
||||
});
|
||||
const result = await getFileSize(url)
|
||||
expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url)
|
||||
expect(result).toBe(1024)
|
||||
})
|
||||
|
||||
|
||||
it('should execute function on main process', async () => {
|
||||
const extension = 'testExtension';
|
||||
const method = 'testMethod';
|
||||
const args = ['arg1', 'arg2'];
|
||||
globalThis.core = {
|
||||
api: {
|
||||
invokeExtensionFunc: jest.fn().mockResolvedValue('result')
|
||||
it('should execute function on main process', async () => {
|
||||
const extension = 'testExtension'
|
||||
const method = 'testMethod'
|
||||
const args = ['arg1', 'arg2']
|
||||
globalThis.core = {
|
||||
api: {
|
||||
invokeExtensionFunc: jest.fn().mockResolvedValue('result'),
|
||||
},
|
||||
}
|
||||
};
|
||||
const result = await executeOnMain(extension, method, ...args);
|
||||
expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args);
|
||||
expect(result).toBe('result');
|
||||
});
|
||||
const result = await executeOnMain(extension, method, ...args)
|
||||
expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args)
|
||||
expect(result).toBe('result')
|
||||
})
|
||||
})
|
||||
|
||||
describe('dirName - just a pass thru api', () => {
|
||||
it('should retrieve the directory name from a file path', async () => {
|
||||
const mockDirName = jest.fn()
|
||||
globalThis.core = {
|
||||
api: {
|
||||
dirName: mockDirName.mockResolvedValue('/path/to'),
|
||||
},
|
||||
}
|
||||
// Normal file path with extension
|
||||
const path = '/path/to/file.txt'
|
||||
await globalThis.core.api.dirName(path)
|
||||
expect(mockDirName).toHaveBeenCalledWith(path)
|
||||
})
|
||||
})
|
||||
|
||||
@ -68,6 +68,13 @@ const openFileExplorer: (path: string) => Promise<any> = (path) =>
|
||||
const joinPath: (paths: string[]) => Promise<string> = (paths) =>
|
||||
globalThis.core.api?.joinPath(paths)
|
||||
|
||||
/**
|
||||
* Get dirname of a file path.
|
||||
* @param path - The file path to retrieve dirname.
|
||||
* @returns {Promise<string>} A promise that resolves the dirname.
|
||||
*/
|
||||
const dirName: (path: string) => Promise<string> = (path) => globalThis.core.api?.dirName(path)
|
||||
|
||||
/**
|
||||
* Retrieve the basename from an url.
|
||||
* @param path - The path to retrieve.
|
||||
@ -161,5 +168,6 @@ export {
|
||||
systemInformation,
|
||||
showToast,
|
||||
getFileSize,
|
||||
dirName,
|
||||
FileStat,
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@ import { getJanDataFolderPath, joinPath } from '../../core'
|
||||
import { events } from '../../events'
|
||||
import { BaseExtension } from '../../extension'
|
||||
import { fs } from '../../fs'
|
||||
import { MessageRequest, Model, ModelEvent } from '../../../types'
|
||||
import { MessageRequest, Model, ModelEvent, ModelFile } from '../../../types'
|
||||
import { EngineManager } from './EngineManager'
|
||||
|
||||
/**
|
||||
@ -21,7 +21,7 @@ export abstract class AIEngine extends BaseExtension {
|
||||
override onLoad() {
|
||||
this.registerEngine()
|
||||
|
||||
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
|
||||
events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
|
||||
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
|
||||
}
|
||||
|
||||
@ -78,7 +78,7 @@ export abstract class AIEngine extends BaseExtension {
|
||||
/**
|
||||
* Loads the model.
|
||||
*/
|
||||
async loadModel(model: Model): Promise<any> {
|
||||
async loadModel(model: ModelFile): Promise<any> {
|
||||
if (model.engine.toString() !== this.provider) return Promise.resolve()
|
||||
events.emit(ModelEvent.OnModelReady, model)
|
||||
return Promise.resolve()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core'
|
||||
import { executeOnMain, systemInformation, dirName } from '../../core'
|
||||
import { events } from '../../events'
|
||||
import { Model, ModelEvent } from '../../../types'
|
||||
import { Model, ModelEvent, ModelFile } from '../../../types'
|
||||
import { OAIEngine } from './OAIEngine'
|
||||
|
||||
/**
|
||||
@ -14,22 +14,24 @@ export abstract class LocalOAIEngine extends OAIEngine {
|
||||
unloadModelFunctionName: string = 'unloadModel'
|
||||
|
||||
/**
|
||||
* On extension load, subscribe to events.
|
||||
* This class represents a base for local inference providers in the OpenAI architecture.
|
||||
* It extends the OAIEngine class and provides the implementation of loading and unloading models locally.
|
||||
* The loadModel function subscribes to the ModelEvent.OnModelInit event, loading models when initiated.
|
||||
* The unloadModel function subscribes to the ModelEvent.OnModelStop event, unloading models when stopped.
|
||||
*/
|
||||
override onLoad() {
|
||||
super.onLoad()
|
||||
// These events are applicable to local inference providers
|
||||
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
|
||||
events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
|
||||
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the model.
|
||||
*/
|
||||
override async loadModel(model: Model): Promise<void> {
|
||||
override async loadModel(model: ModelFile): Promise<void> {
|
||||
if (model.engine.toString() !== this.provider) return
|
||||
const modelFolderName = 'models'
|
||||
const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
|
||||
const modelFolder = await dirName(model.file_path)
|
||||
const systemInfo = await systemInformation()
|
||||
const res = await executeOnMain(
|
||||
this.nodeModule,
|
||||
|
||||
@ -4,6 +4,7 @@ import {
|
||||
HuggingFaceRepoData,
|
||||
ImportingModel,
|
||||
Model,
|
||||
ModelFile,
|
||||
ModelInterface,
|
||||
OptionType,
|
||||
} from '../../types'
|
||||
@ -25,12 +26,11 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
|
||||
network?: { proxy: string; ignoreSSL?: boolean }
|
||||
): Promise<void>
|
||||
abstract cancelModelDownload(modelId: string): Promise<void>
|
||||
abstract deleteModel(modelId: string): Promise<void>
|
||||
abstract saveModel(model: Model): Promise<void>
|
||||
abstract getDownloadedModels(): Promise<Model[]>
|
||||
abstract getConfiguredModels(): Promise<Model[]>
|
||||
abstract deleteModel(model: ModelFile): Promise<void>
|
||||
abstract getDownloadedModels(): Promise<ModelFile[]>
|
||||
abstract getConfiguredModels(): Promise<ModelFile[]>
|
||||
abstract importModels(models: ImportingModel[], optionType: OptionType): Promise<void>
|
||||
abstract updateModelInfo(modelInfo: Partial<Model>): Promise<Model>
|
||||
abstract updateModelInfo(modelInfo: Partial<ModelFile>): Promise<ModelFile>
|
||||
abstract fetchHuggingFaceRepoData(repoId: string): Promise<HuggingFaceRepoData>
|
||||
abstract getDefaultModel(): Promise<Model>
|
||||
}
|
||||
|
||||
@ -1,40 +1,57 @@
|
||||
import { App } from './app';
|
||||
jest.mock('../../helper', () => ({
|
||||
...jest.requireActual('../../helper'),
|
||||
getJanDataFolderPath: () => './app',
|
||||
}))
|
||||
import { dirname } from 'path'
|
||||
import { App } from './app'
|
||||
|
||||
it('should call stopServer', () => {
|
||||
const app = new App();
|
||||
const stopServerMock = jest.fn().mockResolvedValue('Server stopped');
|
||||
const app = new App()
|
||||
const stopServerMock = jest.fn().mockResolvedValue('Server stopped')
|
||||
jest.mock('@janhq/server', () => ({
|
||||
stopServer: stopServerMock
|
||||
}));
|
||||
const result = app.stopServer();
|
||||
expect(stopServerMock).toHaveBeenCalled();
|
||||
});
|
||||
stopServer: stopServerMock,
|
||||
}))
|
||||
app.stopServer()
|
||||
expect(stopServerMock).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should correctly retrieve basename', () => {
|
||||
const app = new App();
|
||||
const result = app.baseName('/path/to/file.txt');
|
||||
expect(result).toBe('file.txt');
|
||||
});
|
||||
const app = new App()
|
||||
const result = app.baseName('/path/to/file.txt')
|
||||
expect(result).toBe('file.txt')
|
||||
})
|
||||
|
||||
it('should correctly identify subdirectories', () => {
|
||||
const app = new App();
|
||||
const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to';
|
||||
const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir';
|
||||
const result = app.isSubdirectory(basePath, subPath);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
const app = new App()
|
||||
const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to'
|
||||
const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir'
|
||||
const result = app.isSubdirectory(basePath, subPath)
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
it('should correctly join multiple paths', () => {
|
||||
const app = new App();
|
||||
const result = app.joinPath(['path', 'to', 'file']);
|
||||
const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file';
|
||||
expect(result).toBe(expectedPath);
|
||||
});
|
||||
const app = new App()
|
||||
const result = app.joinPath(['path', 'to', 'file'])
|
||||
const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file'
|
||||
expect(result).toBe(expectedPath)
|
||||
})
|
||||
|
||||
it('should call correct function with provided arguments using process method', () => {
|
||||
const app = new App();
|
||||
const mockFunc = jest.fn();
|
||||
app.joinPath = mockFunc;
|
||||
app.process('joinPath', ['path1', 'path2']);
|
||||
expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']);
|
||||
});
|
||||
const app = new App()
|
||||
const mockFunc = jest.fn()
|
||||
app.joinPath = mockFunc
|
||||
app.process('joinPath', ['path1', 'path2'])
|
||||
expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2'])
|
||||
})
|
||||
|
||||
it('should retrieve the directory name from a file path (Unix/Windows)', async () => {
|
||||
const app = new App()
|
||||
const path = 'C:/Users/John Doe/Desktop/file.txt'
|
||||
expect(await app.dirName(path)).toBe('C:/Users/John Doe/Desktop')
|
||||
})
|
||||
|
||||
it('should retrieve the directory name when using file protocol', async () => {
|
||||
const app = new App()
|
||||
const path = 'file:/models/file.txt'
|
||||
expect(await app.dirName(path)).toBe(process.platform === 'win32' ? 'app\\models' : 'app/models')
|
||||
})
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { basename, isAbsolute, join, relative } from 'path'
|
||||
import { basename, dirname, isAbsolute, join, relative } from 'path'
|
||||
|
||||
import { Processor } from './Processor'
|
||||
import {
|
||||
@ -6,6 +6,8 @@ import {
|
||||
appResourcePath,
|
||||
getAppConfigurations as appConfiguration,
|
||||
updateAppConfiguration,
|
||||
normalizeFilePath,
|
||||
getJanDataFolderPath,
|
||||
} from '../../helper'
|
||||
|
||||
export class App implements Processor {
|
||||
@ -28,6 +30,18 @@ export class App implements Processor {
|
||||
return join(...args)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get dirname of a file path.
|
||||
* @param path - The file path to retrieve dirname.
|
||||
*/
|
||||
dirName(path: string) {
|
||||
const arg =
|
||||
path.startsWith(`file:/`) || path.startsWith(`file:\\`)
|
||||
? join(getJanDataFolderPath(), normalizeFilePath(path))
|
||||
: path
|
||||
return dirname(arg)
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the given path is a subdirectory of the given directory.
|
||||
*
|
||||
|
||||
@ -37,6 +37,7 @@ export enum AppRoute {
|
||||
getAppConfigurations = 'getAppConfigurations',
|
||||
updateAppConfiguration = 'updateAppConfiguration',
|
||||
joinPath = 'joinPath',
|
||||
dirName = 'dirName',
|
||||
isSubdirectory = 'isSubdirectory',
|
||||
baseName = 'baseName',
|
||||
startServer = 'startServer',
|
||||
|
||||
@ -52,3 +52,18 @@ type DownloadSize = {
|
||||
total: number
|
||||
transferred: number
|
||||
}
|
||||
|
||||
/**
|
||||
* The file metadata
|
||||
*/
|
||||
export type FileMetadata = {
|
||||
/**
|
||||
* The origin file path.
|
||||
*/
|
||||
file_path: string
|
||||
|
||||
/**
|
||||
* The file name.
|
||||
*/
|
||||
file_name: string
|
||||
}
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import { FileMetadata } from '../file'
|
||||
|
||||
/**
|
||||
* Represents the information about a model.
|
||||
* @stored
|
||||
@ -151,3 +153,8 @@ export type ModelRuntimeParams = {
|
||||
export type ModelInitFailed = Model & {
|
||||
error: Error
|
||||
}
|
||||
|
||||
/**
|
||||
* ModelFile is the model.json entity and it's file metadata
|
||||
*/
|
||||
export type ModelFile = Model & FileMetadata
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { GpuSetting } from '../miscellaneous'
|
||||
import { Model } from './modelEntity'
|
||||
import { Model, ModelFile } from './modelEntity'
|
||||
|
||||
/**
|
||||
* Model extension for managing models.
|
||||
@ -29,14 +29,7 @@ export interface ModelInterface {
|
||||
* @param modelId - The ID of the model to delete.
|
||||
* @returns A Promise that resolves when the model has been deleted.
|
||||
*/
|
||||
deleteModel(modelId: string): Promise<void>
|
||||
|
||||
/**
|
||||
* Saves a model.
|
||||
* @param model - The model to save.
|
||||
* @returns A Promise that resolves when the model has been saved.
|
||||
*/
|
||||
saveModel(model: Model): Promise<void>
|
||||
deleteModel(model: ModelFile): Promise<void>
|
||||
|
||||
/**
|
||||
* Gets a list of downloaded models.
|
||||
|
||||
@ -22,6 +22,7 @@ import {
|
||||
downloadFile,
|
||||
DownloadState,
|
||||
DownloadEvent,
|
||||
ModelFile,
|
||||
} from '@janhq/core'
|
||||
|
||||
declare const CUDA_DOWNLOAD_URL: string
|
||||
@ -94,7 +95,7 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
|
||||
this.nitroProcessInfo = health
|
||||
}
|
||||
|
||||
override loadModel(model: Model): Promise<void> {
|
||||
override loadModel(model: ModelFile): Promise<void> {
|
||||
if (model.engine !== this.provider) return Promise.resolve()
|
||||
this.getNitroProcessHealthIntervalId = setInterval(
|
||||
() => this.periodicallyGetNitroHealth(),
|
||||
|
||||
@ -6,12 +6,12 @@ import fetchRT from 'fetch-retry'
|
||||
import {
|
||||
log,
|
||||
getSystemResourceInfo,
|
||||
Model,
|
||||
InferenceEngine,
|
||||
ModelSettingParams,
|
||||
PromptTemplate,
|
||||
SystemInformation,
|
||||
getJanDataFolderPath,
|
||||
ModelFile,
|
||||
} from '@janhq/core/node'
|
||||
import { executableNitroFile } from './execute'
|
||||
import terminate from 'terminate'
|
||||
@ -25,7 +25,7 @@ const fetchRetry = fetchRT(fetch)
|
||||
*/
|
||||
interface ModelInitOptions {
|
||||
modelFolder: string
|
||||
model: Model
|
||||
model: ModelFile
|
||||
}
|
||||
// The PORT to use for the Nitro subprocess
|
||||
const PORT = 3928
|
||||
|
||||
9
extensions/model-extension/jest.config.js
Normal file
9
extensions/model-extension/jest.config.js
Normal file
@ -0,0 +1,9 @@
|
||||
/** @type {import('ts-jest').JestConfigWithTsJest} */
|
||||
module.exports = {
|
||||
preset: 'ts-jest',
|
||||
testEnvironment: 'node',
|
||||
transform: {
|
||||
'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
|
||||
},
|
||||
transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
|
||||
}
|
||||
@ -8,6 +8,7 @@
|
||||
"author": "Jan <service@jan.ai>",
|
||||
"license": "AGPL-3.0",
|
||||
"scripts": {
|
||||
"test": "jest",
|
||||
"build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs",
|
||||
"build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install"
|
||||
},
|
||||
|
||||
@ -27,7 +27,7 @@ export default [
|
||||
// Allow json resolution
|
||||
json(),
|
||||
// Compile TypeScript files
|
||||
typescript({ useTsconfigDeclarationDir: true }),
|
||||
typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
|
||||
// Compile TypeScript files
|
||||
// Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
|
||||
// commonjs(),
|
||||
@ -62,7 +62,7 @@ export default [
|
||||
// Allow json resolution
|
||||
json(),
|
||||
// Compile TypeScript files
|
||||
typescript({ useTsconfigDeclarationDir: true }),
|
||||
typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
|
||||
// Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
|
||||
commonjs(),
|
||||
// Allow node_modules resolution, so you can use 'external' to control
|
||||
|
||||
564
extensions/model-extension/src/index.test.ts
Normal file
564
extensions/model-extension/src/index.test.ts
Normal file
@ -0,0 +1,564 @@
|
||||
const readDirSyncMock = jest.fn()
|
||||
const existMock = jest.fn()
|
||||
const readFileSyncMock = jest.fn()
|
||||
|
||||
jest.mock('@janhq/core', () => ({
|
||||
...jest.requireActual('@janhq/core/node'),
|
||||
fs: {
|
||||
existsSync: existMock,
|
||||
readdirSync: readDirSyncMock,
|
||||
readFileSync: readFileSyncMock,
|
||||
fileStat: () => ({
|
||||
isDirectory: false,
|
||||
}),
|
||||
},
|
||||
dirName: jest.fn(),
|
||||
joinPath: (paths) => paths.join('/'),
|
||||
ModelExtension: jest.fn(),
|
||||
}))
|
||||
|
||||
import JanModelExtension from '.'
|
||||
import { fs, dirName } from '@janhq/core'
|
||||
|
||||
describe('JanModelExtension', () => {
|
||||
let sut: JanModelExtension
|
||||
|
||||
beforeAll(() => {
|
||||
// @ts-ignore
|
||||
sut = new JanModelExtension()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('getConfiguredModels', () => {
|
||||
describe("when there's no models are pre-populated", () => {
|
||||
it('should return empty array', async () => {
|
||||
// Mock configured models data
|
||||
const configuredModels = []
|
||||
existMock.mockReturnValue(true)
|
||||
readDirSyncMock.mockReturnValue([])
|
||||
|
||||
const result = await sut.getConfiguredModels()
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe("when there's are pre-populated models - all flattened", () => {
|
||||
it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
|
||||
// Mock configured models data
|
||||
const configuredModels = [
|
||||
{
|
||||
id: '1',
|
||||
name: 'Model 1',
|
||||
version: '1.0.0',
|
||||
description: 'Model 1 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model1',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
created: new Date(),
|
||||
updated: new Date(),
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
{
|
||||
id: '2',
|
||||
name: 'Model 2',
|
||||
version: '2.0.0',
|
||||
description: 'Model 2 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model2',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
]
|
||||
existMock.mockReturnValue(true)
|
||||
|
||||
readDirSyncMock.mockImplementation((path) => {
|
||||
if (path === 'file://models') return ['model1', 'model2']
|
||||
else return ['model.json']
|
||||
})
|
||||
|
||||
readFileSyncMock.mockImplementation((path) => {
|
||||
if (path.includes('model1'))
|
||||
return JSON.stringify(configuredModels[0])
|
||||
else return JSON.stringify(configuredModels[1])
|
||||
})
|
||||
|
||||
const result = await sut.getConfiguredModels()
|
||||
expect(result).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
file_path: 'file://models/model1/model.json',
|
||||
id: '1',
|
||||
}),
|
||||
expect.objectContaining({
|
||||
file_path: 'file://models/model2/model.json',
|
||||
id: '2',
|
||||
}),
|
||||
])
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe("when there's are pre-populated models - there are nested folders", () => {
|
||||
it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
|
||||
// Mock configured models data
|
||||
const configuredModels = [
|
||||
{
|
||||
id: '1',
|
||||
name: 'Model 1',
|
||||
version: '1.0.0',
|
||||
description: 'Model 1 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model1',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
created: new Date(),
|
||||
updated: new Date(),
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
{
|
||||
id: '2',
|
||||
name: 'Model 2',
|
||||
version: '2.0.0',
|
||||
description: 'Model 2 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model2',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
]
|
||||
existMock.mockReturnValue(true)
|
||||
|
||||
readDirSyncMock.mockImplementation((path) => {
|
||||
if (path === 'file://models') return ['model1', 'model2/model2-1']
|
||||
else return ['model.json']
|
||||
})
|
||||
|
||||
readFileSyncMock.mockImplementation((path) => {
|
||||
if (path.includes('model1'))
|
||||
return JSON.stringify(configuredModels[0])
|
||||
else if (path.includes('model2/model2-1'))
|
||||
return JSON.stringify(configuredModels[1])
|
||||
})
|
||||
|
||||
const result = await sut.getConfiguredModels()
|
||||
expect(result).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
file_path: 'file://models/model1/model.json',
|
||||
id: '1',
|
||||
}),
|
||||
expect.objectContaining({
|
||||
file_path: 'file://models/model2/model2-1/model.json',
|
||||
id: '2',
|
||||
}),
|
||||
])
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getDownloadedModels', () => {
|
||||
describe('no models downloaded', () => {
|
||||
it('should return empty array', async () => {
|
||||
// Mock downloaded models data
|
||||
const downloadedModels = []
|
||||
existMock.mockReturnValue(true)
|
||||
readDirSyncMock.mockReturnValue([])
|
||||
|
||||
const result = await sut.getDownloadedModels()
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
describe('only one model is downloaded', () => {
|
||||
describe('flatten folder', () => {
|
||||
it('returns downloaded models - with correct file_path and model id', async () => {
|
||||
// Mock configured models data
|
||||
const configuredModels = [
|
||||
{
|
||||
id: '1',
|
||||
name: 'Model 1',
|
||||
version: '1.0.0',
|
||||
description: 'Model 1 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model1',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
created: new Date(),
|
||||
updated: new Date(),
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
{
|
||||
id: '2',
|
||||
name: 'Model 2',
|
||||
version: '2.0.0',
|
||||
description: 'Model 2 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model2',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
]
|
||||
existMock.mockReturnValue(true)
|
||||
|
||||
readDirSyncMock.mockImplementation((path) => {
|
||||
if (path === 'file://models') return ['model1', 'model2']
|
||||
else if (path === 'file://models/model1')
|
||||
return ['model.json', 'test.gguf']
|
||||
else return ['model.json']
|
||||
})
|
||||
|
||||
readFileSyncMock.mockImplementation((path) => {
|
||||
if (path.includes('model1'))
|
||||
return JSON.stringify(configuredModels[0])
|
||||
else return JSON.stringify(configuredModels[1])
|
||||
})
|
||||
|
||||
const result = await sut.getDownloadedModels()
|
||||
expect(result).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
file_path: 'file://models/model1/model.json',
|
||||
id: '1',
|
||||
}),
|
||||
])
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('all models are downloaded', () => {
|
||||
describe('nested folders', () => {
|
||||
it('returns downloaded models - with correct file_path and model id', async () => {
|
||||
// Mock configured models data
|
||||
const configuredModels = [
|
||||
{
|
||||
id: '1',
|
||||
name: 'Model 1',
|
||||
version: '1.0.0',
|
||||
description: 'Model 1 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model1',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
created: new Date(),
|
||||
updated: new Date(),
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
{
|
||||
id: '2',
|
||||
name: 'Model 2',
|
||||
version: '2.0.0',
|
||||
description: 'Model 2 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model2',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
]
|
||||
existMock.mockReturnValue(true)
|
||||
|
||||
readDirSyncMock.mockImplementation((path) => {
|
||||
if (path === 'file://models') return ['model1', 'model2/model2-1']
|
||||
else return ['model.json', 'test.gguf']
|
||||
})
|
||||
|
||||
readFileSyncMock.mockImplementation((path) => {
|
||||
if (path.includes('model1'))
|
||||
return JSON.stringify(configuredModels[0])
|
||||
else return JSON.stringify(configuredModels[1])
|
||||
})
|
||||
|
||||
const result = await sut.getDownloadedModels()
|
||||
expect(result).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
file_path: 'file://models/model1/model.json',
|
||||
id: '1',
|
||||
}),
|
||||
expect.objectContaining({
|
||||
file_path: 'file://models/model2/model2-1/model.json',
|
||||
id: '2',
|
||||
}),
|
||||
])
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('all models are downloaded with uppercased GGUF files', () => {
|
||||
it('returns downloaded models - with correct file_path and model id', async () => {
|
||||
// Mock configured models data
|
||||
const configuredModels = [
|
||||
{
|
||||
id: '1',
|
||||
name: 'Model 1',
|
||||
version: '1.0.0',
|
||||
description: 'Model 1 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model1',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
created: new Date(),
|
||||
updated: new Date(),
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
{
|
||||
id: '2',
|
||||
name: 'Model 2',
|
||||
version: '2.0.0',
|
||||
description: 'Model 2 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model2',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
]
|
||||
existMock.mockReturnValue(true)
|
||||
|
||||
readDirSyncMock.mockImplementation((path) => {
|
||||
if (path === 'file://models') return ['model1', 'model2/model2-1']
|
||||
else if (path === 'file://models/model1')
|
||||
return ['model.json', 'test.GGUF']
|
||||
else return ['model.json', 'test.gguf']
|
||||
})
|
||||
|
||||
readFileSyncMock.mockImplementation((path) => {
|
||||
if (path.includes('model1'))
|
||||
return JSON.stringify(configuredModels[0])
|
||||
else return JSON.stringify(configuredModels[1])
|
||||
})
|
||||
|
||||
const result = await sut.getDownloadedModels()
|
||||
expect(result).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
file_path: 'file://models/model1/model.json',
|
||||
id: '1',
|
||||
}),
|
||||
expect.objectContaining({
|
||||
file_path: 'file://models/model2/model2-1/model.json',
|
||||
id: '2',
|
||||
}),
|
||||
])
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('all models are downloaded - GGUF & Tensort RT', () => {
|
||||
it('returns downloaded models - with correct file_path and model id', async () => {
|
||||
// Mock configured models data
|
||||
const configuredModels = [
|
||||
{
|
||||
id: '1',
|
||||
name: 'Model 1',
|
||||
version: '1.0.0',
|
||||
description: 'Model 1 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model1',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
created: new Date(),
|
||||
updated: new Date(),
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
{
|
||||
id: '2',
|
||||
name: 'Model 2',
|
||||
version: '2.0.0',
|
||||
description: 'Model 2 description',
|
||||
object: {
|
||||
type: 'model',
|
||||
uri: 'http://localhost:5000/models/model2',
|
||||
},
|
||||
format: 'onnx',
|
||||
sources: [],
|
||||
parameters: {},
|
||||
settings: {},
|
||||
metadata: {},
|
||||
engine: 'test',
|
||||
} as any,
|
||||
]
|
||||
existMock.mockReturnValue(true)
|
||||
|
||||
readDirSyncMock.mockImplementation((path) => {
|
||||
if (path === 'file://models') return ['model1', 'model2/model2-1']
|
||||
else if (path === 'file://models/model1')
|
||||
return ['model.json', 'test.gguf']
|
||||
else return ['model.json', 'test.engine']
|
||||
})
|
||||
|
||||
readFileSyncMock.mockImplementation((path) => {
|
||||
if (path.includes('model1'))
|
||||
return JSON.stringify(configuredModels[0])
|
||||
else return JSON.stringify(configuredModels[1])
|
||||
})
|
||||
|
||||
const result = await sut.getDownloadedModels()
|
||||
expect(result).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
file_path: 'file://models/model1/model.json',
|
||||
id: '1',
|
||||
}),
|
||||
expect.objectContaining({
|
||||
file_path: 'file://models/model2/model2-1/model.json',
|
||||
id: '2',
|
||||
}),
|
||||
])
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('deleteModel', () => {
|
||||
describe('model is a GGUF model', () => {
|
||||
it('should delete the GGUF file', async () => {
|
||||
fs.unlinkSync = jest.fn()
|
||||
const dirMock = dirName as jest.Mock
|
||||
dirMock.mockReturnValue('file://models/model1')
|
||||
|
||||
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
|
||||
|
||||
readDirSyncMock.mockImplementation((path) => {
|
||||
return ['model.json', 'test.gguf']
|
||||
})
|
||||
|
||||
existMock.mockReturnValue(true)
|
||||
|
||||
await sut.deleteModel({
|
||||
file_path: 'file://models/model1/model.json',
|
||||
} as any)
|
||||
|
||||
expect(fs.unlinkSync).toHaveBeenCalledWith(
|
||||
'file://models/model1/test.gguf'
|
||||
)
|
||||
})
|
||||
|
||||
it('no gguf file presented', async () => {
|
||||
fs.unlinkSync = jest.fn()
|
||||
const dirMock = dirName as jest.Mock
|
||||
dirMock.mockReturnValue('file://models/model1')
|
||||
|
||||
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
|
||||
|
||||
readDirSyncMock.mockReturnValue(['model.json'])
|
||||
|
||||
existMock.mockReturnValue(true)
|
||||
|
||||
await sut.deleteModel({
|
||||
file_path: 'file://models/model1/model.json',
|
||||
} as any)
|
||||
|
||||
expect(fs.unlinkSync).toHaveBeenCalledTimes(0)
|
||||
})
|
||||
|
||||
it('delete an imported model', async () => {
|
||||
fs.rm = jest.fn()
|
||||
const dirMock = dirName as jest.Mock
|
||||
dirMock.mockReturnValue('file://models/model1')
|
||||
|
||||
readDirSyncMock.mockReturnValue(['model.json', 'test.gguf'])
|
||||
|
||||
// MARK: This is a tricky logic implement?
|
||||
// I will just add test for now but will align on the legacy implementation
|
||||
fs.readFileSync = jest.fn().mockReturnValue(
|
||||
JSON.stringify({
|
||||
metadata: {
|
||||
author: 'user',
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
existMock.mockReturnValue(true)
|
||||
|
||||
await sut.deleteModel({
|
||||
file_path: 'file://models/model1/model.json',
|
||||
} as any)
|
||||
|
||||
expect(fs.rm).toHaveBeenCalledWith('file://models/model1')
|
||||
})
|
||||
|
||||
it('delete tensorrt-models', async () => {
|
||||
fs.rm = jest.fn()
|
||||
const dirMock = dirName as jest.Mock
|
||||
dirMock.mockReturnValue('file://models/model1')
|
||||
|
||||
readDirSyncMock.mockReturnValue(['model.json', 'test.engine'])
|
||||
|
||||
fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
|
||||
|
||||
existMock.mockReturnValue(true)
|
||||
|
||||
await sut.deleteModel({
|
||||
file_path: 'file://models/model1/model.json',
|
||||
} as any)
|
||||
|
||||
expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine')
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -22,6 +22,8 @@ import {
|
||||
getFileSize,
|
||||
AllQuantizations,
|
||||
ModelEvent,
|
||||
ModelFile,
|
||||
dirName,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { extractFileName } from './helpers/path'
|
||||
@ -48,16 +50,7 @@ export default class JanModelExtension extends ModelExtension {
|
||||
]
|
||||
private static readonly _tensorRtEngineFormat = '.engine'
|
||||
private static readonly _supportedGpuArch = ['ampere', 'ada']
|
||||
private static readonly _safetensorsRegexs = [
|
||||
/model\.safetensors$/,
|
||||
/model-[0-9]+-of-[0-9]+\.safetensors$/,
|
||||
]
|
||||
private static readonly _pytorchRegexs = [
|
||||
/pytorch_model\.bin$/,
|
||||
/consolidated\.[0-9]+\.pth$/,
|
||||
/pytorch_model-[0-9]+-of-[0-9]+\.bin$/,
|
||||
/.*\.pt$/,
|
||||
]
|
||||
|
||||
interrupted = false
|
||||
|
||||
/**
|
||||
@ -319,9 +312,9 @@ export default class JanModelExtension extends ModelExtension {
|
||||
* @param filePath - The path to the model file to delete.
|
||||
* @returns A Promise that resolves when the model is deleted.
|
||||
*/
|
||||
async deleteModel(modelId: string): Promise<void> {
|
||||
async deleteModel(model: ModelFile): Promise<void> {
|
||||
try {
|
||||
const dirPath = await joinPath([JanModelExtension._homeDir, modelId])
|
||||
const dirPath = await dirName(model.file_path)
|
||||
const jsonFilePath = await joinPath([
|
||||
dirPath,
|
||||
JanModelExtension._modelMetadataFileName,
|
||||
@ -330,9 +323,11 @@ export default class JanModelExtension extends ModelExtension {
|
||||
await this.readModelMetadata(jsonFilePath)
|
||||
) as Model
|
||||
|
||||
// TODO: This is so tricky?
|
||||
// Should depend on sources?
|
||||
const isUserImportModel =
|
||||
modelInfo.metadata?.author?.toLowerCase() === 'user'
|
||||
if (isUserImportModel) {
|
||||
if (isUserImportModel) {
|
||||
// just delete the folder
|
||||
return fs.rm(dirPath)
|
||||
}
|
||||
@ -350,30 +345,11 @@ export default class JanModelExtension extends ModelExtension {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves a model file.
|
||||
* @param model - The model to save.
|
||||
* @returns A Promise that resolves when the model is saved.
|
||||
*/
|
||||
async saveModel(model: Model): Promise<void> {
|
||||
const jsonFilePath = await joinPath([
|
||||
JanModelExtension._homeDir,
|
||||
model.id,
|
||||
JanModelExtension._modelMetadataFileName,
|
||||
])
|
||||
|
||||
try {
|
||||
await fs.writeFileSync(jsonFilePath, JSON.stringify(model, null, 2))
|
||||
} catch (err) {
|
||||
console.error(err)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets all downloaded models.
|
||||
* @returns A Promise that resolves with an array of all models.
|
||||
*/
|
||||
async getDownloadedModels(): Promise<Model[]> {
|
||||
async getDownloadedModels(): Promise<ModelFile[]> {
|
||||
return await this.getModelsMetadata(
|
||||
async (modelDir: string, model: Model) => {
|
||||
if (!JanModelExtension._offlineInferenceEngine.includes(model.engine))
|
||||
@ -425,8 +401,10 @@ export default class JanModelExtension extends ModelExtension {
|
||||
): Promise<string | undefined> {
|
||||
// try to find model.json recursively inside each folder
|
||||
if (!(await fs.existsSync(folderFullPath))) return undefined
|
||||
|
||||
const files: string[] = await fs.readdirSync(folderFullPath)
|
||||
if (files.length === 0) return undefined
|
||||
|
||||
if (files.includes(JanModelExtension._modelMetadataFileName)) {
|
||||
return joinPath([
|
||||
folderFullPath,
|
||||
@ -446,7 +424,7 @@ export default class JanModelExtension extends ModelExtension {
|
||||
|
||||
private async getModelsMetadata(
|
||||
selector?: (path: string, model: Model) => Promise<boolean>
|
||||
): Promise<Model[]> {
|
||||
): Promise<ModelFile[]> {
|
||||
try {
|
||||
if (!(await fs.existsSync(JanModelExtension._homeDir))) {
|
||||
console.debug('Model folder not found')
|
||||
@ -469,6 +447,7 @@ export default class JanModelExtension extends ModelExtension {
|
||||
JanModelExtension._homeDir,
|
||||
dirName,
|
||||
])
|
||||
|
||||
const jsonPath = await this.getModelJsonPath(folderFullPath)
|
||||
|
||||
if (await fs.existsSync(jsonPath)) {
|
||||
@ -486,6 +465,8 @@ export default class JanModelExtension extends ModelExtension {
|
||||
},
|
||||
]
|
||||
}
|
||||
model.file_path = jsonPath
|
||||
model.file_name = JanModelExtension._modelMetadataFileName
|
||||
|
||||
if (selector && !(await selector?.(dirName, model))) {
|
||||
return
|
||||
@ -506,7 +487,7 @@ export default class JanModelExtension extends ModelExtension {
|
||||
typeof result.value === 'object'
|
||||
? result.value
|
||||
: JSON.parse(result.value)
|
||||
return model as Model
|
||||
return model as ModelFile
|
||||
} catch {
|
||||
console.debug(`Unable to parse model metadata: ${result.value}`)
|
||||
}
|
||||
@ -637,7 +618,7 @@ export default class JanModelExtension extends ModelExtension {
|
||||
* Gets all available models.
|
||||
* @returns A Promise that resolves with an array of all models.
|
||||
*/
|
||||
async getConfiguredModels(): Promise<Model[]> {
|
||||
async getConfiguredModels(): Promise<ModelFile[]> {
|
||||
return this.getModelsMetadata()
|
||||
}
|
||||
|
||||
@ -669,7 +650,7 @@ export default class JanModelExtension extends ModelExtension {
|
||||
modelBinaryPath: string,
|
||||
modelFolderName: string,
|
||||
modelFolderPath: string
|
||||
): Promise<Model> {
|
||||
): Promise<ModelFile> {
|
||||
const fileStats = await fs.fileStat(modelBinaryPath, true)
|
||||
const binaryFileSize = fileStats.size
|
||||
|
||||
@ -732,25 +713,21 @@ export default class JanModelExtension extends ModelExtension {
|
||||
|
||||
await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2))
|
||||
|
||||
return model
|
||||
return {
|
||||
...model,
|
||||
file_path: modelFilePath,
|
||||
file_name: JanModelExtension._modelMetadataFileName,
|
||||
}
|
||||
}
|
||||
|
||||
async updateModelInfo(modelInfo: Partial<Model>): Promise<Model> {
|
||||
const modelId = modelInfo.id
|
||||
async updateModelInfo(modelInfo: Partial<ModelFile>): Promise<ModelFile> {
|
||||
if (modelInfo.id == null) throw new Error('Model ID is required')
|
||||
|
||||
const janDataFolderPath = await getJanDataFolderPath()
|
||||
const jsonFilePath = await joinPath([
|
||||
janDataFolderPath,
|
||||
'models',
|
||||
modelId,
|
||||
JanModelExtension._modelMetadataFileName,
|
||||
])
|
||||
const model = JSON.parse(
|
||||
await this.readModelMetadata(jsonFilePath)
|
||||
) as Model
|
||||
await this.readModelMetadata(modelInfo.file_path)
|
||||
) as ModelFile
|
||||
|
||||
const updatedModel: Model = {
|
||||
const updatedModel: ModelFile = {
|
||||
...model,
|
||||
...modelInfo,
|
||||
parameters: {
|
||||
@ -765,9 +742,15 @@ export default class JanModelExtension extends ModelExtension {
|
||||
...model.metadata,
|
||||
...modelInfo.metadata,
|
||||
},
|
||||
// Should not persist file_path & file_name
|
||||
file_path: undefined,
|
||||
file_name: undefined,
|
||||
}
|
||||
|
||||
await fs.writeFileSync(jsonFilePath, JSON.stringify(updatedModel, null, 2))
|
||||
await fs.writeFileSync(
|
||||
modelInfo.file_path,
|
||||
JSON.stringify(updatedModel, null, 2)
|
||||
)
|
||||
return updatedModel
|
||||
}
|
||||
|
||||
|
||||
@ -10,5 +10,6 @@
|
||||
"skipLibCheck": true,
|
||||
"rootDir": "./src"
|
||||
},
|
||||
"include": ["./src"]
|
||||
"include": ["./src"],
|
||||
"exclude": ["**/*.test.ts"]
|
||||
}
|
||||
|
||||
@ -23,6 +23,7 @@ import {
|
||||
ModelEvent,
|
||||
getJanDataFolderPath,
|
||||
SystemInformation,
|
||||
ModelFile,
|
||||
} from '@janhq/core'
|
||||
|
||||
/**
|
||||
@ -137,7 +138,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
|
||||
events.emit(ModelEvent.OnModelsUpdate, {})
|
||||
}
|
||||
|
||||
override async loadModel(model: Model): Promise<void> {
|
||||
override async loadModel(model: ModelFile): Promise<void> {
|
||||
if ((await this.installationState()) === 'Installed')
|
||||
return super.loadModel(model)
|
||||
|
||||
|
||||
@ -46,7 +46,6 @@ import {
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
|
||||
import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
|
||||
import {
|
||||
configuredModelsAtom,
|
||||
@ -91,8 +90,6 @@ const ModelDropdown = ({
|
||||
const featuredModel = configuredModels.filter((x) =>
|
||||
x.metadata.tags.includes('Featured')
|
||||
)
|
||||
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
|
||||
|
||||
const { updateThreadMetadata } = useCreateNewThread()
|
||||
|
||||
useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
|
||||
@ -191,27 +188,14 @@ const ModelDropdown = ({
|
||||
],
|
||||
})
|
||||
|
||||
// Default setting ctx_len for the model for a better onboarding experience
|
||||
// TODO: When Cortex support hardware instructions, we should remove this
|
||||
const defaultContextLength = preserveModelSettings
|
||||
? model?.metadata?.default_ctx_len
|
||||
: 2048
|
||||
const defaultMaxTokens = preserveModelSettings
|
||||
? model?.metadata?.default_max_tokens
|
||||
: 2048
|
||||
const overriddenSettings =
|
||||
model?.settings.ctx_len && model.settings.ctx_len > 2048
|
||||
? { ctx_len: defaultContextLength ?? 2048 }
|
||||
: {}
|
||||
const overriddenParameters =
|
||||
model?.parameters.max_tokens && model.parameters.max_tokens
|
||||
? { max_tokens: defaultMaxTokens ?? 2048 }
|
||||
model?.settings.ctx_len && model.settings.ctx_len > 4096
|
||||
? { ctx_len: 4096 }
|
||||
: {}
|
||||
|
||||
const modelParams = {
|
||||
...model?.parameters,
|
||||
...model?.settings,
|
||||
...overriddenParameters,
|
||||
...overriddenSettings,
|
||||
}
|
||||
|
||||
@ -222,6 +206,7 @@ const ModelDropdown = ({
|
||||
if (model)
|
||||
updateModelParameter(activeThread, {
|
||||
params: modelParams,
|
||||
modelPath: model.file_path,
|
||||
modelId: model.id,
|
||||
engine: model.engine,
|
||||
})
|
||||
@ -235,7 +220,6 @@ const ModelDropdown = ({
|
||||
setThreadModelParams,
|
||||
updateModelParameter,
|
||||
updateThreadMetadata,
|
||||
preserveModelSettings,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@ -7,7 +7,6 @@ const VULKAN_ENABLED = 'vulkanEnabled'
|
||||
const IGNORE_SSL = 'ignoreSSLFeature'
|
||||
const HTTPS_PROXY_FEATURE = 'httpsProxyFeature'
|
||||
const QUICK_ASK_ENABLED = 'quickAskEnabled'
|
||||
const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings'
|
||||
|
||||
export const janDataFolderPathAtom = atom('')
|
||||
|
||||
@ -24,9 +23,3 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false)
|
||||
export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false)
|
||||
|
||||
export const hostAtom = atom('http://localhost:1337/')
|
||||
|
||||
// This feature is to allow user to cache model settings on thread creation
|
||||
export const preserveModelSettingsAtom = atomWithStorage(
|
||||
PRESERVE_MODEL_SETTINGS,
|
||||
false
|
||||
)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { ImportingModel, Model, InferenceEngine } from '@janhq/core'
|
||||
import { ImportingModel, Model, InferenceEngine, ModelFile } from '@janhq/core'
|
||||
import { atom } from 'jotai'
|
||||
|
||||
import { localEngines } from '@/utils/modelEngine'
|
||||
@ -32,18 +32,7 @@ export const removeDownloadingModelAtom = atom(
|
||||
}
|
||||
)
|
||||
|
||||
export const downloadedModelsAtom = atom<Model[]>([])
|
||||
|
||||
export const updateDownloadedModelAtom = atom(
|
||||
null,
|
||||
(get, set, updatedModel: Model) => {
|
||||
const models: Model[] = get(downloadedModelsAtom).map((c) =>
|
||||
c.id === updatedModel.id ? updatedModel : c
|
||||
)
|
||||
|
||||
set(downloadedModelsAtom, models)
|
||||
}
|
||||
)
|
||||
export const downloadedModelsAtom = atom<ModelFile[]>([])
|
||||
|
||||
export const removeDownloadedModelAtom = atom(
|
||||
null,
|
||||
@ -57,7 +46,7 @@ export const removeDownloadedModelAtom = atom(
|
||||
}
|
||||
)
|
||||
|
||||
export const configuredModelsAtom = atom<Model[]>([])
|
||||
export const configuredModelsAtom = atom<ModelFile[]>([])
|
||||
|
||||
export const defaultModelAtom = atom<Model | undefined>(undefined)
|
||||
|
||||
@ -144,6 +133,6 @@ export const updateImportingModelAtom = atom(
|
||||
}
|
||||
)
|
||||
|
||||
export const selectedModelAtom = atom<Model | undefined>(undefined)
|
||||
export const selectedModelAtom = atom<ModelFile | undefined>(undefined)
|
||||
|
||||
export const showEngineListModelAtom = atom<InferenceEngine[]>(localEngines)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { useCallback, useEffect, useRef } from 'react'
|
||||
|
||||
import { EngineManager, Model } from '@janhq/core'
|
||||
import { EngineManager, Model, ModelFile } from '@janhq/core'
|
||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { toaster } from '@/containers/Toast'
|
||||
@ -11,7 +11,7 @@ import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export const activeModelAtom = atom<Model | undefined>(undefined)
|
||||
export const activeModelAtom = atom<ModelFile | undefined>(undefined)
|
||||
export const loadModelErrorAtom = atom<string | undefined>(undefined)
|
||||
|
||||
type ModelState = {
|
||||
@ -37,7 +37,7 @@ export function useActiveModel() {
|
||||
const [pendingModelLoad, setPendingModelLoad] = useAtom(pendingModelLoadAtom)
|
||||
const isVulkanEnabled = useAtomValue(vulkanEnabledAtom)
|
||||
|
||||
const downloadedModelsRef = useRef<Model[]>([])
|
||||
const downloadedModelsRef = useRef<ModelFile[]>([])
|
||||
|
||||
useEffect(() => {
|
||||
downloadedModelsRef.current = downloadedModels
|
||||
|
||||
@ -7,8 +7,8 @@ import {
|
||||
Thread,
|
||||
ThreadAssistantInfo,
|
||||
ThreadState,
|
||||
Model,
|
||||
AssistantTool,
|
||||
ModelFile,
|
||||
} from '@janhq/core'
|
||||
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
@ -26,10 +26,7 @@ import useSetActiveThread from './useSetActiveThread'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
|
||||
import {
|
||||
experimentalFeatureEnabledAtom,
|
||||
preserveModelSettingsAtom,
|
||||
} from '@/helpers/atoms/AppConfig.atom'
|
||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
threadsAtom,
|
||||
@ -67,7 +64,6 @@ export const useCreateNewThread = () => {
|
||||
const copyOverInstructionEnabled = useAtomValue(
|
||||
copyOverInstructionEnabledAtom
|
||||
)
|
||||
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
|
||||
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
|
||||
@ -80,7 +76,7 @@ export const useCreateNewThread = () => {
|
||||
|
||||
const requestCreateNewThread = async (
|
||||
assistant: Assistant,
|
||||
model?: Model | undefined
|
||||
model?: ModelFile | undefined
|
||||
) => {
|
||||
// Stop generating if any
|
||||
setIsGeneratingResponse(false)
|
||||
@ -109,19 +105,13 @@ export const useCreateNewThread = () => {
|
||||
enabled: true,
|
||||
settings: assistant.tools && assistant.tools[0].settings,
|
||||
}
|
||||
const defaultContextLength = preserveModelSettings
|
||||
? defaultModel?.metadata?.default_ctx_len
|
||||
: 2048
|
||||
const defaultMaxTokens = preserveModelSettings
|
||||
? defaultModel?.metadata?.default_max_tokens
|
||||
: 2048
|
||||
const overriddenSettings =
|
||||
defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048
|
||||
? { ctx_len: defaultContextLength ?? 2048 }
|
||||
? { ctx_len: 4096 }
|
||||
: {}
|
||||
|
||||
const overriddenParameters = defaultModel?.parameters.max_tokens
|
||||
? { max_tokens: defaultMaxTokens ?? 2048 }
|
||||
? { max_tokens: 4096 }
|
||||
: {}
|
||||
|
||||
const createdAt = Date.now()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { useCallback } from 'react'
|
||||
|
||||
import { ExtensionTypeEnum, ModelExtension, Model } from '@janhq/core'
|
||||
import { ExtensionTypeEnum, ModelExtension, ModelFile } from '@janhq/core'
|
||||
|
||||
import { useSetAtom } from 'jotai'
|
||||
|
||||
@ -13,8 +13,8 @@ export default function useDeleteModel() {
|
||||
const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom)
|
||||
|
||||
const deleteModel = useCallback(
|
||||
async (model: Model) => {
|
||||
await localDeleteModel(model.id)
|
||||
async (model: ModelFile) => {
|
||||
await localDeleteModel(model)
|
||||
removeDownloadedModel(model.id)
|
||||
toaster({
|
||||
title: 'Model Deletion Successful',
|
||||
@ -28,5 +28,7 @@ export default function useDeleteModel() {
|
||||
return { deleteModel }
|
||||
}
|
||||
|
||||
const localDeleteModel = async (id: string) =>
|
||||
extensionManager.get<ModelExtension>(ExtensionTypeEnum.Model)?.deleteModel(id)
|
||||
const localDeleteModel = async (model: ModelFile) =>
|
||||
extensionManager
|
||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||
?.deleteModel(model)
|
||||
|
||||
@ -5,6 +5,7 @@ import {
|
||||
Model,
|
||||
ModelEvent,
|
||||
ModelExtension,
|
||||
ModelFile,
|
||||
events,
|
||||
} from '@janhq/core'
|
||||
|
||||
@ -63,12 +64,12 @@ const getLocalDefaultModel = async (): Promise<Model | undefined> =>
|
||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||
?.getDefaultModel()
|
||||
|
||||
const getLocalConfiguredModels = async (): Promise<Model[]> =>
|
||||
const getLocalConfiguredModels = async (): Promise<ModelFile[]> =>
|
||||
extensionManager
|
||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||
?.getConfiguredModels() ?? []
|
||||
|
||||
const getLocalDownloadedModels = async (): Promise<Model[]> =>
|
||||
const getLocalDownloadedModels = async (): Promise<ModelFile[]> =>
|
||||
extensionManager
|
||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||
?.getDownloadedModels() ?? []
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
|
||||
import { Model, InferenceEngine } from '@janhq/core'
|
||||
import { Model, InferenceEngine, ModelFile } from '@janhq/core'
|
||||
|
||||
import { atom, useAtomValue } from 'jotai'
|
||||
|
||||
@ -24,12 +24,16 @@ export const LAST_USED_MODEL_ID = 'last-used-model-id'
|
||||
*/
|
||||
export default function useRecommendedModel() {
|
||||
const activeModel = useAtomValue(activeModelAtom)
|
||||
const [sortedModels, setSortedModels] = useState<Model[]>([])
|
||||
const [recommendedModel, setRecommendedModel] = useState<Model | undefined>()
|
||||
const [sortedModels, setSortedModels] = useState<ModelFile[]>([])
|
||||
const [recommendedModel, setRecommendedModel] = useState<
|
||||
ModelFile | undefined
|
||||
>()
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const downloadedModels = useAtomValue(downloadedModelsAtom)
|
||||
|
||||
const getAndSortDownloadedModels = useCallback(async (): Promise<Model[]> => {
|
||||
const getAndSortDownloadedModels = useCallback(async (): Promise<
|
||||
ModelFile[]
|
||||
> => {
|
||||
const models = downloadedModels.sort((a, b) =>
|
||||
a.engine !== InferenceEngine.nitro && b.engine === InferenceEngine.nitro
|
||||
? 1
|
||||
|
||||
@ -4,8 +4,6 @@ import {
|
||||
ConversationalExtension,
|
||||
ExtensionTypeEnum,
|
||||
InferenceEngine,
|
||||
Model,
|
||||
ModelExtension,
|
||||
Thread,
|
||||
ThreadAssistantInfo,
|
||||
} from '@janhq/core'
|
||||
@ -17,14 +15,8 @@ import {
|
||||
extractModelLoadParams,
|
||||
} from '@/utils/modelParam'
|
||||
|
||||
import useRecommendedModel from './useRecommendedModel'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import {
|
||||
selectedModelAtom,
|
||||
updateDownloadedModelAtom,
|
||||
} from '@/helpers/atoms/Model.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
ModelParams,
|
||||
getActiveThreadModelParamsAtom,
|
||||
@ -34,16 +26,14 @@ import {
|
||||
export type UpdateModelParameter = {
|
||||
params?: ModelParams
|
||||
modelId?: string
|
||||
modelPath?: string
|
||||
engine?: InferenceEngine
|
||||
}
|
||||
|
||||
export default function useUpdateModelParameters() {
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
|
||||
const [selectedModel] = useAtom(selectedModelAtom)
|
||||
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
||||
const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom)
|
||||
const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom)
|
||||
const { recommendedModel, setRecommendedModel } = useRecommendedModel()
|
||||
|
||||
const updateModelParameter = useCallback(
|
||||
async (thread: Thread, settings: UpdateModelParameter) => {
|
||||
@ -83,50 +73,8 @@ export default function useUpdateModelParameters() {
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.saveThread(updatedThread)
|
||||
|
||||
// Persists default settings to model file
|
||||
// Do not overwrite ctx_len and max_tokens
|
||||
if (preserveModelFeatureEnabled) {
|
||||
const defaultContextLength = settingParams.ctx_len
|
||||
const defaultMaxTokens = runtimeParams.max_tokens
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars
|
||||
const { ctx_len, ...toSaveSettings } = settingParams
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars
|
||||
const { max_tokens, ...toSaveParams } = runtimeParams
|
||||
|
||||
const updatedModel = {
|
||||
id: settings.modelId ?? selectedModel?.id,
|
||||
parameters: {
|
||||
...toSaveSettings,
|
||||
},
|
||||
settings: {
|
||||
...toSaveParams,
|
||||
},
|
||||
metadata: {
|
||||
default_ctx_len: defaultContextLength,
|
||||
default_max_tokens: defaultMaxTokens,
|
||||
},
|
||||
} as Partial<Model>
|
||||
|
||||
const model = await extensionManager
|
||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||
?.updateModelInfo(updatedModel)
|
||||
if (model) updateDownloadedModel(model)
|
||||
if (selectedModel?.id === model?.id) setSelectedModel(model)
|
||||
if (recommendedModel?.id === model?.id) setRecommendedModel(model)
|
||||
}
|
||||
},
|
||||
[
|
||||
activeModelParams,
|
||||
selectedModel,
|
||||
setThreadModelParams,
|
||||
preserveModelFeatureEnabled,
|
||||
updateDownloadedModel,
|
||||
setSelectedModel,
|
||||
recommendedModel,
|
||||
setRecommendedModel,
|
||||
]
|
||||
[activeModelParams, selectedModel, setThreadModelParams]
|
||||
)
|
||||
|
||||
const processStopWords = (params: ModelParams): ModelParams => {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { useCallback } from 'react'
|
||||
|
||||
import { Model } from '@janhq/core'
|
||||
import { ModelFile } from '@janhq/core'
|
||||
import { Button, Badge, Tooltip } from '@janhq/joi'
|
||||
|
||||
import { useAtomValue, useSetAtom } from 'jotai'
|
||||
@ -38,7 +38,7 @@ import {
|
||||
} from '@/helpers/atoms/SystemBar.atom'
|
||||
|
||||
type Props = {
|
||||
model: Model
|
||||
model: ModelFile
|
||||
onClick: () => void
|
||||
open: string
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { useState } from 'react'
|
||||
|
||||
import { Model } from '@janhq/core'
|
||||
import { ModelFile } from '@janhq/core'
|
||||
import { Badge } from '@janhq/joi'
|
||||
|
||||
import { twMerge } from 'tailwind-merge'
|
||||
@ -12,7 +12,7 @@ import ModelItemHeader from '@/screens/Hub/ModelList/ModelHeader'
|
||||
import { toGibibytes } from '@/utils/converter'
|
||||
|
||||
type Props = {
|
||||
model: Model
|
||||
model: ModelFile
|
||||
}
|
||||
|
||||
const ModelItem: React.FC<Props> = ({ model }) => {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { useMemo } from 'react'
|
||||
|
||||
import { Model } from '@janhq/core'
|
||||
import { ModelFile } from '@janhq/core'
|
||||
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
@ -9,16 +9,16 @@ import ModelItem from '@/screens/Hub/ModelList/ModelItem'
|
||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
|
||||
type Props = {
|
||||
models: Model[]
|
||||
models: ModelFile[]
|
||||
}
|
||||
|
||||
const ModelList = ({ models }: Props) => {
|
||||
const downloadedModels = useAtomValue(downloadedModelsAtom)
|
||||
const sortedModels: Model[] = useMemo(() => {
|
||||
const featuredModels: Model[] = []
|
||||
const remoteModels: Model[] = []
|
||||
const localModels: Model[] = []
|
||||
const remainingModels: Model[] = []
|
||||
const sortedModels: ModelFile[] = useMemo(() => {
|
||||
const featuredModels: ModelFile[] = []
|
||||
const remoteModels: ModelFile[] = []
|
||||
const localModels: ModelFile[] = []
|
||||
const remainingModels: ModelFile[] = []
|
||||
models.forEach((m) => {
|
||||
if (m.metadata?.tags?.includes('Featured')) {
|
||||
featuredModels.push(m)
|
||||
|
||||
@ -53,7 +53,7 @@ const ModelDownloadRow: React.FC<Props> = ({
|
||||
const { requestCreateNewThread } = useCreateNewThread()
|
||||
const setMainViewState = useSetAtom(mainViewStateAtom)
|
||||
const assistants = useAtomValue(assistantsAtom)
|
||||
const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null
|
||||
const downloadedModel = downloadedModels.find((md) => md.id === fileName)
|
||||
|
||||
const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom)
|
||||
const defaultModel = useAtomValue(defaultModelAtom)
|
||||
@ -100,12 +100,12 @@ const ModelDownloadRow: React.FC<Props> = ({
|
||||
alert('No assistant available')
|
||||
return
|
||||
}
|
||||
await requestCreateNewThread(assistants[0], model)
|
||||
await requestCreateNewThread(assistants[0], downloadedModel)
|
||||
setMainViewState(MainViewState.Thread)
|
||||
setHfImportingStage('NONE')
|
||||
}, [
|
||||
assistants,
|
||||
model,
|
||||
downloadedModel,
|
||||
requestCreateNewThread,
|
||||
setMainViewState,
|
||||
setHfImportingStage,
|
||||
@ -139,7 +139,7 @@ const ModelDownloadRow: React.FC<Props> = ({
|
||||
</Badge>
|
||||
</div>
|
||||
|
||||
{isDownloaded ? (
|
||||
{downloadedModel ? (
|
||||
<Button
|
||||
variant="soft"
|
||||
className="min-w-[98px]"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { memo, useState } from 'react'
|
||||
|
||||
import { InferenceEngine, Model } from '@janhq/core'
|
||||
import { InferenceEngine, ModelFile } from '@janhq/core'
|
||||
import { Badge, Button, Tooltip, useClickOutside } from '@janhq/joi'
|
||||
import { useAtom } from 'jotai'
|
||||
import {
|
||||
@ -21,7 +21,7 @@ import { localEngines } from '@/utils/modelEngine'
|
||||
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
|
||||
|
||||
type Props = {
|
||||
model: Model
|
||||
model: ModelFile
|
||||
groupTitle?: string
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user