From 98bef7b7cffa811a67945d8c8f4659862c15026c Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 17 Sep 2024 08:34:58 +0700 Subject: [PATCH 01/37] test: add model parameter validation rules and persistence tests (#3618) * test: add model parameter validation rules and persistence tests * chore: fix CI cov step * fix: invalid model settings should fallback to origin value * test: support fallback integer settings --- .../src/node/index.ts | 32 +- .../tensorrt-llm-extension/src/node/index.ts | 4 +- web/containers/Providers/EventHandler.tsx | 4 +- web/containers/SliderRightPanel/index.tsx | 28 +- web/hooks/useSendChatMessage.ts | 9 +- web/hooks/useUpdateModelParameters.test.ts | 314 ++++++++++++++++++ web/hooks/useUpdateModelParameters.ts | 16 +- .../LocalServerRightPanel/index.tsx | 13 +- web/screens/Thread/ThreadRightPanel/index.tsx | 26 +- web/utils/modelParam.test.ts | 183 ++++++++++ web/utils/modelParam.ts | 106 +++++- 11 files changed, 681 insertions(+), 54 deletions(-) create mode 100644 web/hooks/useUpdateModelParameters.test.ts create mode 100644 web/utils/modelParam.test.ts diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts index edc2d013d..3a969ad5e 100644 --- a/extensions/inference-nitro-extension/src/node/index.ts +++ b/extensions/inference-nitro-extension/src/node/index.ts @@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise { if (!settings?.ngl) { settings.ngl = 100 } - log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`) + log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`) return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, { method: 'POST', headers: { @@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise { }) .then((res) => { log( - `[CORTEX]::Debug: Load model success with response ${JSON.stringify( + `[CORTEX]:: Load model success with response ${JSON.stringify( res )}` ) @@ -260,7 +260,7 @@ function loadLLMModel(settings: any): Promise { async function validateModelStatus(modelId: string): Promise { // Send a GET request to the validation URL. // Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries. - log(`[CORTEX]::Debug: Validating model ${modelId}`) + log(`[CORTEX]:: Validating model ${modelId}`) return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, { method: 'POST', body: JSON.stringify({ @@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise { retryDelay: 300, }).then(async (res: Response) => { log( - `[CORTEX]::Debug: Validate model state with response ${JSON.stringify( + `[CORTEX]:: Validate model state with response ${JSON.stringify( res.status )}` ) @@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise { // Otherwise, return an object with an error message. if (body.model_loaded) { log( - `[CORTEX]::Debug: Validate model state success with response ${JSON.stringify( + `[CORTEX]:: Validate model state success with response ${JSON.stringify( body )}` ) @@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise { } const errorBody = await res.text() log( - `[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify( + `[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify( res.statusText )}` ) @@ -310,7 +310,7 @@ async function validateModelStatus(modelId: string): Promise { async function killSubprocess(): Promise { const controller = new AbortController() setTimeout(() => controller.abort(), 5000) - log(`[CORTEX]::Debug: Request to kill cortex`) + log(`[CORTEX]:: Request to kill cortex`) const killRequest = () => { return fetch(NITRO_HTTP_KILL_URL, { @@ -321,17 +321,17 @@ async function killSubprocess(): Promise { .then(() => tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) ) - .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .catch((err) => { log( - `[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}` + `[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}` ) throw 'PORT_NOT_AVAILABLE' }) } if (subprocess?.pid && process.platform !== 'darwin') { - log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) + log(`[CORTEX]:: Killing PID ${subprocess.pid}`) const pid = subprocess.pid return new Promise((resolve, reject) => { terminate(pid, function (err) { @@ -341,7 +341,7 @@ async function killSubprocess(): Promise { } else { tcpPortUsed .waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000) - .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .then(() => resolve()) .catch(() => { log( @@ -362,7 +362,7 @@ async function killSubprocess(): Promise { * @returns A promise that resolves when the Nitro subprocess is started. */ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { - log(`[CORTEX]::Debug: Spawning cortex subprocess...`) + log(`[CORTEX]:: Spawning cortex subprocess...`) return new Promise(async (resolve, reject) => { let executableOptions = executableNitroFile( @@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { const args: string[] = ['1', LOCAL_HOST, PORT.toString()] // Execute the binary log( - `[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}` + `[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}` ) log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`) @@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { // Handle subprocess output subprocess.stdout.on('data', (data: any) => { - log(`[CORTEX]::Debug: ${data}`) + log(`[CORTEX]:: ${data}`) }) subprocess.stderr.on('data', (data: any) => { @@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { }) subprocess.on('close', (code: any) => { - log(`[CORTEX]::Debug: cortex exited with code: ${code}`) + log(`[CORTEX]:: cortex exited with code: ${code}`) subprocess = undefined reject(`child process exited with code ${code}`) }) @@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise { tcpPortUsed .waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000) .then(() => { - log(`[CORTEX]::Debug: cortex is ready`) + log(`[CORTEX]:: cortex is ready`) resolve() }) }) diff --git a/extensions/tensorrt-llm-extension/src/node/index.ts b/extensions/tensorrt-llm-extension/src/node/index.ts index c8bc48459..77003389f 100644 --- a/extensions/tensorrt-llm-extension/src/node/index.ts +++ b/extensions/tensorrt-llm-extension/src/node/index.ts @@ -97,7 +97,7 @@ function unloadModel(): Promise { } if (subprocess?.pid) { - log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`) + log(`[CORTEX]:: Killing PID ${subprocess.pid}`) const pid = subprocess.pid return new Promise((resolve, reject) => { terminate(pid, function (err) { @@ -107,7 +107,7 @@ function unloadModel(): Promise { return tcpPortUsed .waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000) .then(() => resolve()) - .then(() => log(`[CORTEX]::Debug: cortex process is terminated`)) + .then(() => log(`[CORTEX]:: cortex process is terminated`)) .catch(() => { killRequest() }) diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx index e4c96aeb7..4809ce83e 100644 --- a/web/containers/Providers/EventHandler.tsx +++ b/web/containers/Providers/EventHandler.tsx @@ -20,7 +20,7 @@ import { ulid } from 'ulidx' import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel' -import { toRuntimeParams } from '@/utils/modelParam' +import { extractInferenceParams } from '@/utils/modelParam' import { extensionManager } from '@/extension' import { @@ -256,7 +256,7 @@ export default function EventHandler({ children }: { children: ReactNode }) { }, ] - const runtimeParams = toRuntimeParams(activeModelParamsRef.current) + const runtimeParams = extractInferenceParams(activeModelParamsRef.current) const messageRequest: MessageRequest = { id: msgId, diff --git a/web/containers/SliderRightPanel/index.tsx b/web/containers/SliderRightPanel/index.tsx index df415ffb5..c00d9f002 100644 --- a/web/containers/SliderRightPanel/index.tsx +++ b/web/containers/SliderRightPanel/index.tsx @@ -87,26 +87,28 @@ const SliderRightPanel = ({ onValueChanged?.(Number(min)) setVal(min.toString()) setShowTooltip({ max: false, min: true }) + } else { + setVal(Number(e.target.value).toString()) // There is a case .5 but not 0.5 } }} onChange={(e) => { - // Should not accept invalid value or NaN - // E.g. anything changes that trigger onValueChanged - // Which is incorrect - if (Number(e.target.value) > Number(max)) { - setVal(max.toString()) - } else if ( - Number(e.target.value) < Number(min) || - !e.target.value.length - ) { - setVal(min.toString()) - } else if (Number.isNaN(Number(e.target.value))) return - - onValueChanged?.(Number(e.target.value)) // TODO: How to support negative number input? + // Passthru since it validates again onBlur if (/^\d*\.?\d*$/.test(e.target.value)) { setVal(e.target.value) } + + // Should not accept invalid value or NaN + // E.g. anything changes that trigger onValueChanged + // Which is incorrect + if ( + Number(e.target.value) > Number(max) || + Number(e.target.value) < Number(min) || + Number.isNaN(Number(e.target.value)) + ) { + return + } + onValueChanged?.(Number(e.target.value)) }} /> } diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index 8c6013505..1dbd5b45e 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -23,7 +23,10 @@ import { import { Stack } from '@/utils/Stack' import { compressImage, getBase64 } from '@/utils/base64' import { MessageRequestBuilder } from '@/utils/messageRequestBuilder' -import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' +import { + extractInferenceParams, + extractModelLoadParams, +} from '@/utils/modelParam' import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder' @@ -189,8 +192,8 @@ export default function useSendChatMessage() { if (engineParamsUpdate) setReloadModel(true) - const runtimeParams = toRuntimeParams(activeModelParams) - const settingParams = toSettingParams(activeModelParams) + const runtimeParams = extractInferenceParams(activeModelParams) + const settingParams = extractModelLoadParams(activeModelParams) const prompt = message.trim() diff --git a/web/hooks/useUpdateModelParameters.test.ts b/web/hooks/useUpdateModelParameters.test.ts new file mode 100644 index 000000000..bc60aa631 --- /dev/null +++ b/web/hooks/useUpdateModelParameters.test.ts @@ -0,0 +1,314 @@ +import { renderHook, act } from '@testing-library/react' +// Mock dependencies +jest.mock('ulidx') +jest.mock('@/extension') + +import useUpdateModelParameters from './useUpdateModelParameters' +import { extensionManager } from '@/extension' + +// Mock data +let model: any = { + id: 'model-1', + engine: 'nitro', +} + +let extension: any = { + saveThread: jest.fn(), +} + +const mockThread: any = { + id: 'thread-1', + assistants: [ + { + model: { + parameters: {}, + settings: {}, + }, + }, + ], + object: 'thread', + title: 'New Thread', + created: 0, + updated: 0, +} + +describe('useUpdateModelParameters', () => { + beforeAll(() => { + jest.clearAllMocks() + jest.mock('./useRecommendedModel', () => ({ + useRecommendedModel: () => ({ + recommendedModel: model, + setRecommendedModel: jest.fn(), + downloadedModels: [], + }), + })) + }) + + it('should update model parameters and save thread when params are valid', async () => { + const mockValidParameters: any = { + params: { + // Inference + stop: ['', ''], + temperature: 0.5, + token_limit: 1000, + top_k: 0.7, + top_p: 0.1, + stream: true, + max_tokens: 1000, + frequency_penalty: 0.3, + presence_penalty: 0.2, + + // Load model + ctx_len: 1024, + ngl: 12, + embedding: true, + n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + vision_model: 'vision', + text_model: 'text', + }, + modelId: 'model-1', + engine: 'nitro', + } + + // Spy functions + jest.spyOn(extensionManager, 'get').mockReturnValue(extension) + jest.spyOn(extension, 'saveThread').mockReturnValue({}) + + const { result } = renderHook(() => useUpdateModelParameters()) + + await act(async () => { + await result.current.updateModelParameter(mockThread, mockValidParameters) + }) + + // Check if the model parameters are valid before persisting + expect(extension.saveThread).toHaveBeenCalledWith({ + assistants: [ + { + model: { + parameters: { + stop: ['', ''], + temperature: 0.5, + token_limit: 1000, + top_k: 0.7, + top_p: 0.1, + stream: true, + max_tokens: 1000, + frequency_penalty: 0.3, + presence_penalty: 0.2, + }, + settings: { + ctx_len: 1024, + ngl: 12, + embedding: true, + n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + }, + }, + }, + ], + created: 0, + id: 'thread-1', + object: 'thread', + title: 'New Thread', + updated: 0, + }) + }) + + it('should not update invalid model parameters', async () => { + const mockInvalidParameters: any = { + params: { + // Inference + stop: [1, ''], + temperature: '0.5', + token_limit: '1000', + top_k: '0.7', + top_p: '0.1', + stream: 'true', + max_tokens: '1000', + frequency_penalty: '0.3', + presence_penalty: '0.2', + + // Load model + ctx_len: '1024', + ngl: '12', + embedding: 'true', + n_parallel: '2', + cpu_threads: '4', + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + vision_model: 'vision', + text_model: 'text', + }, + modelId: 'model-1', + engine: 'nitro', + } + + // Spy functions + jest.spyOn(extensionManager, 'get').mockReturnValue(extension) + jest.spyOn(extension, 'saveThread').mockReturnValue({}) + + const { result } = renderHook(() => useUpdateModelParameters()) + + await act(async () => { + await result.current.updateModelParameter( + mockThread, + mockInvalidParameters + ) + }) + + // Check if the model parameters are valid before persisting + expect(extension.saveThread).toHaveBeenCalledWith({ + assistants: [ + { + model: { + parameters: { + max_tokens: 1000, + token_limit: 1000, + }, + settings: { + cpu_threads: 4, + ctx_len: 1024, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + n_parallel: 2, + ngl: 12, + }, + }, + }, + ], + created: 0, + id: 'thread-1', + object: 'thread', + title: 'New Thread', + updated: 0, + }) + }) + + it('should update valid model parameters only', async () => { + const mockInvalidParameters: any = { + params: { + // Inference + stop: [''], + temperature: -0.5, + token_limit: 100.2, + top_k: 0.7, + top_p: 0.1, + stream: true, + max_tokens: 1000, + frequency_penalty: 1.2, + presence_penalty: 0.2, + + // Load model + ctx_len: 1024, + ngl: 0, + embedding: 'true', + n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + vision_model: 'vision', + text_model: 'text', + }, + modelId: 'model-1', + engine: 'nitro', + } + + // Spy functions + jest.spyOn(extensionManager, 'get').mockReturnValue(extension) + jest.spyOn(extension, 'saveThread').mockReturnValue({}) + + const { result } = renderHook(() => useUpdateModelParameters()) + + await act(async () => { + await result.current.updateModelParameter( + mockThread, + mockInvalidParameters + ) + }) + + // Check if the model parameters are valid before persisting + expect(extension.saveThread).toHaveBeenCalledWith({ + assistants: [ + { + model: { + parameters: { + stop: [''], + top_k: 0.7, + top_p: 0.1, + stream: true, + token_limit: 100, + max_tokens: 1000, + presence_penalty: 0.2, + }, + settings: { + ctx_len: 1024, + ngl: 0, + n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + }, + }, + }, + ], + created: 0, + id: 'thread-1', + object: 'thread', + title: 'New Thread', + updated: 0, + }) + }) + + it('should handle missing modelId and engine gracefully', async () => { + const mockParametersWithoutModelIdAndEngine: any = { + params: { + stop: ['', ''], + temperature: 0.5, + }, + } + + // Spy functions + jest.spyOn(extensionManager, 'get').mockReturnValue(extension) + jest.spyOn(extension, 'saveThread').mockReturnValue({}) + + const { result } = renderHook(() => useUpdateModelParameters()) + + await act(async () => { + await result.current.updateModelParameter( + mockThread, + mockParametersWithoutModelIdAndEngine + ) + }) + + // Check if the model parameters are valid before persisting + expect(extension.saveThread).toHaveBeenCalledWith({ + assistants: [ + { + model: { + parameters: { + stop: ['', ''], + temperature: 0.5, + }, + settings: {}, + }, + }, + ], + created: 0, + id: 'thread-1', + object: 'thread', + title: 'New Thread', + updated: 0, + }) + }) +}) diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts index 79d877456..46bf07cd5 100644 --- a/web/hooks/useUpdateModelParameters.ts +++ b/web/hooks/useUpdateModelParameters.ts @@ -12,7 +12,10 @@ import { import { useAtom, useAtomValue, useSetAtom } from 'jotai' -import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' +import { + extractInferenceParams, + extractModelLoadParams, +} from '@/utils/modelParam' import useRecommendedModel from './useRecommendedModel' @@ -47,12 +50,17 @@ export default function useUpdateModelParameters() { const toUpdateSettings = processStopWords(settings.params ?? {}) const updatedModelParams = settings.modelId ? toUpdateSettings - : { ...activeModelParams, ...toUpdateSettings } + : { + ...selectedModel?.parameters, + ...selectedModel?.settings, + ...activeModelParams, + ...toUpdateSettings, + } // update the state setThreadModelParams(thread.id, updatedModelParams) - const runtimeParams = toRuntimeParams(updatedModelParams) - const settingParams = toSettingParams(updatedModelParams) + const runtimeParams = extractInferenceParams(updatedModelParams) + const settingParams = extractModelLoadParams(updatedModelParams) const assistants = thread.assistants.map( (assistant: ThreadAssistantInfo) => { diff --git a/web/screens/LocalServer/LocalServerRightPanel/index.tsx b/web/screens/LocalServer/LocalServerRightPanel/index.tsx index 309709c26..13e3cad57 100644 --- a/web/screens/LocalServer/LocalServerRightPanel/index.tsx +++ b/web/screens/LocalServer/LocalServerRightPanel/index.tsx @@ -14,7 +14,10 @@ import { loadModelErrorAtom } from '@/hooks/useActiveModel' import { getConfigurationsData } from '@/utils/componentSettings' -import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' +import { + extractInferenceParams, + extractModelLoadParams, +} from '@/utils/modelParam' import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' @@ -27,16 +30,18 @@ const LocalServerRightPanel = () => { const selectedModel = useAtomValue(selectedModelAtom) const [currentModelSettingParams, setCurrentModelSettingParams] = useState( - toSettingParams(selectedModel?.settings) + extractModelLoadParams(selectedModel?.settings) ) useEffect(() => { if (selectedModel) { - setCurrentModelSettingParams(toSettingParams(selectedModel?.settings)) + setCurrentModelSettingParams( + extractModelLoadParams(selectedModel?.settings) + ) } }, [selectedModel]) - const modelRuntimeParams = toRuntimeParams(selectedModel?.settings) + const modelRuntimeParams = extractInferenceParams(selectedModel?.settings) const componentDataRuntimeSetting = getConfigurationsData( modelRuntimeParams, diff --git a/web/screens/Thread/ThreadRightPanel/index.tsx b/web/screens/Thread/ThreadRightPanel/index.tsx index 9e7cdf7d8..e7d0a27b9 100644 --- a/web/screens/Thread/ThreadRightPanel/index.tsx +++ b/web/screens/Thread/ThreadRightPanel/index.tsx @@ -29,7 +29,10 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters' import { getConfigurationsData } from '@/utils/componentSettings' import { localEngines } from '@/utils/modelEngine' -import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' +import { + extractInferenceParams, + extractModelLoadParams, +} from '@/utils/modelParam' import PromptTemplateSetting from './PromptTemplateSetting' import Tools from './Tools' @@ -68,14 +71,26 @@ const ThreadRightPanel = () => { const settings = useMemo(() => { // runtime setting - const modelRuntimeParams = toRuntimeParams(activeModelParams) + const modelRuntimeParams = extractInferenceParams( + { + ...selectedModel?.parameters, + ...activeModelParams, + }, + selectedModel?.parameters + ) const componentDataRuntimeSetting = getConfigurationsData( modelRuntimeParams, selectedModel ).filter((x) => x.key !== 'prompt_template') // engine setting - const modelEngineParams = toSettingParams(activeModelParams) + const modelEngineParams = extractModelLoadParams( + { + ...selectedModel?.settings, + ...activeModelParams, + }, + selectedModel?.settings + ) const componentDataEngineSetting = getConfigurationsData( modelEngineParams, selectedModel @@ -126,7 +141,10 @@ const ThreadRightPanel = () => { }, [activeModelParams, selectedModel]) const promptTemplateSettings = useMemo(() => { - const modelEngineParams = toSettingParams(activeModelParams) + const modelEngineParams = extractModelLoadParams({ + ...selectedModel?.settings, + ...activeModelParams, + }) const componentDataEngineSetting = getConfigurationsData( modelEngineParams, selectedModel diff --git a/web/utils/modelParam.test.ts b/web/utils/modelParam.test.ts new file mode 100644 index 000000000..f1b858955 --- /dev/null +++ b/web/utils/modelParam.test.ts @@ -0,0 +1,183 @@ +// web/utils/modelParam.test.ts +import { normalizeValue, validationRules } from './modelParam' + +describe('validationRules', () => { + it('should validate temperature correctly', () => { + expect(validationRules.temperature(0.5)).toBe(true) + expect(validationRules.temperature(2)).toBe(true) + expect(validationRules.temperature(0)).toBe(true) + expect(validationRules.temperature(-0.1)).toBe(false) + expect(validationRules.temperature(2.3)).toBe(false) + expect(validationRules.temperature('0.5')).toBe(false) + }) + + it('should validate token_limit correctly', () => { + expect(validationRules.token_limit(100)).toBe(true) + expect(validationRules.token_limit(1)).toBe(true) + expect(validationRules.token_limit(0)).toBe(true) + expect(validationRules.token_limit(-1)).toBe(false) + expect(validationRules.token_limit('100')).toBe(false) + }) + + it('should validate top_k correctly', () => { + expect(validationRules.top_k(0.5)).toBe(true) + expect(validationRules.top_k(1)).toBe(true) + expect(validationRules.top_k(0)).toBe(true) + expect(validationRules.top_k(-0.1)).toBe(false) + expect(validationRules.top_k(1.1)).toBe(false) + expect(validationRules.top_k('0.5')).toBe(false) + }) + + it('should validate top_p correctly', () => { + expect(validationRules.top_p(0.5)).toBe(true) + expect(validationRules.top_p(1)).toBe(true) + expect(validationRules.top_p(0)).toBe(true) + expect(validationRules.top_p(-0.1)).toBe(false) + expect(validationRules.top_p(1.1)).toBe(false) + expect(validationRules.top_p('0.5')).toBe(false) + }) + + it('should validate stream correctly', () => { + expect(validationRules.stream(true)).toBe(true) + expect(validationRules.stream(false)).toBe(true) + expect(validationRules.stream('true')).toBe(false) + expect(validationRules.stream(1)).toBe(false) + }) + + it('should validate max_tokens correctly', () => { + expect(validationRules.max_tokens(100)).toBe(true) + expect(validationRules.max_tokens(1)).toBe(true) + expect(validationRules.max_tokens(0)).toBe(true) + expect(validationRules.max_tokens(-1)).toBe(false) + expect(validationRules.max_tokens('100')).toBe(false) + }) + + it('should validate stop correctly', () => { + expect(validationRules.stop(['word1', 'word2'])).toBe(true) + expect(validationRules.stop([])).toBe(true) + expect(validationRules.stop(['word1', 2])).toBe(false) + expect(validationRules.stop('word1')).toBe(false) + }) + + it('should validate frequency_penalty correctly', () => { + expect(validationRules.frequency_penalty(0.5)).toBe(true) + expect(validationRules.frequency_penalty(1)).toBe(true) + expect(validationRules.frequency_penalty(0)).toBe(true) + expect(validationRules.frequency_penalty(-0.1)).toBe(false) + expect(validationRules.frequency_penalty(1.1)).toBe(false) + expect(validationRules.frequency_penalty('0.5')).toBe(false) + }) + + it('should validate presence_penalty correctly', () => { + expect(validationRules.presence_penalty(0.5)).toBe(true) + expect(validationRules.presence_penalty(1)).toBe(true) + expect(validationRules.presence_penalty(0)).toBe(true) + expect(validationRules.presence_penalty(-0.1)).toBe(false) + expect(validationRules.presence_penalty(1.1)).toBe(false) + expect(validationRules.presence_penalty('0.5')).toBe(false) + }) + + it('should validate ctx_len correctly', () => { + expect(validationRules.ctx_len(1024)).toBe(true) + expect(validationRules.ctx_len(1)).toBe(true) + expect(validationRules.ctx_len(0)).toBe(true) + expect(validationRules.ctx_len(-1)).toBe(false) + expect(validationRules.ctx_len('1024')).toBe(false) + }) + + it('should validate ngl correctly', () => { + expect(validationRules.ngl(12)).toBe(true) + expect(validationRules.ngl(1)).toBe(true) + expect(validationRules.ngl(0)).toBe(true) + expect(validationRules.ngl(-1)).toBe(false) + expect(validationRules.ngl('12')).toBe(false) + }) + + it('should validate embedding correctly', () => { + expect(validationRules.embedding(true)).toBe(true) + expect(validationRules.embedding(false)).toBe(true) + expect(validationRules.embedding('true')).toBe(false) + expect(validationRules.embedding(1)).toBe(false) + }) + + it('should validate n_parallel correctly', () => { + expect(validationRules.n_parallel(2)).toBe(true) + expect(validationRules.n_parallel(1)).toBe(true) + expect(validationRules.n_parallel(0)).toBe(true) + expect(validationRules.n_parallel(-1)).toBe(false) + expect(validationRules.n_parallel('2')).toBe(false) + }) + + it('should validate cpu_threads correctly', () => { + expect(validationRules.cpu_threads(4)).toBe(true) + expect(validationRules.cpu_threads(1)).toBe(true) + expect(validationRules.cpu_threads(0)).toBe(true) + expect(validationRules.cpu_threads(-1)).toBe(false) + expect(validationRules.cpu_threads('4')).toBe(false) + }) + + it('should validate prompt_template correctly', () => { + expect(validationRules.prompt_template('template')).toBe(true) + expect(validationRules.prompt_template('')).toBe(true) + expect(validationRules.prompt_template(123)).toBe(false) + }) + + it('should validate llama_model_path correctly', () => { + expect(validationRules.llama_model_path('path')).toBe(true) + expect(validationRules.llama_model_path('')).toBe(true) + expect(validationRules.llama_model_path(123)).toBe(false) + }) + + it('should validate mmproj correctly', () => { + expect(validationRules.mmproj('mmproj')).toBe(true) + expect(validationRules.mmproj('')).toBe(true) + expect(validationRules.mmproj(123)).toBe(false) + }) + + it('should validate vision_model correctly', () => { + expect(validationRules.vision_model(true)).toBe(true) + expect(validationRules.vision_model(false)).toBe(true) + expect(validationRules.vision_model('true')).toBe(false) + expect(validationRules.vision_model(1)).toBe(false) + }) + + it('should validate text_model correctly', () => { + expect(validationRules.text_model(true)).toBe(true) + expect(validationRules.text_model(false)).toBe(true) + expect(validationRules.text_model('true')).toBe(false) + expect(validationRules.text_model(1)).toBe(false) + }) +}) + +describe('normalizeValue', () => { + it('should normalize ctx_len correctly', () => { + expect(normalizeValue('ctx_len', 100.5)).toBe(100) + expect(normalizeValue('ctx_len', '2')).toBe(2) + expect(normalizeValue('ctx_len', 100)).toBe(100) + }) + it('should normalize token_limit correctly', () => { + expect(normalizeValue('token_limit', 100.5)).toBe(100) + expect(normalizeValue('token_limit', '1')).toBe(1) + expect(normalizeValue('token_limit', 0)).toBe(0) + }) + it('should normalize max_tokens correctly', () => { + expect(normalizeValue('max_tokens', 100.5)).toBe(100) + expect(normalizeValue('max_tokens', '1')).toBe(1) + expect(normalizeValue('max_tokens', 0)).toBe(0) + }) + it('should normalize ngl correctly', () => { + expect(normalizeValue('ngl', 12.5)).toBe(12) + expect(normalizeValue('ngl', '2')).toBe(2) + expect(normalizeValue('ngl', 0)).toBe(0) + }) + it('should normalize n_parallel correctly', () => { + expect(normalizeValue('n_parallel', 2.5)).toBe(2) + expect(normalizeValue('n_parallel', '2')).toBe(2) + expect(normalizeValue('n_parallel', 0)).toBe(0) + }) + it('should normalize cpu_threads correctly', () => { + expect(normalizeValue('cpu_threads', 4.5)).toBe(4) + expect(normalizeValue('cpu_threads', '4')).toBe(4) + expect(normalizeValue('cpu_threads', 0)).toBe(0) + }) +}) diff --git a/web/utils/modelParam.ts b/web/utils/modelParam.ts index a6d144c3e..dda9cf761 100644 --- a/web/utils/modelParam.ts +++ b/web/utils/modelParam.ts @@ -1,9 +1,69 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +/* eslint-disable @typescript-eslint/naming-convention */ import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core' import { ModelParams } from '@/helpers/atoms/Thread.atom' -export const toRuntimeParams = ( - modelParams?: ModelParams +/** + * Validation rules for model parameters + */ +export const validationRules: { [key: string]: (value: any) => boolean } = { + temperature: (value: any) => + typeof value === 'number' && value >= 0 && value <= 2, + token_limit: (value: any) => Number.isInteger(value) && value >= 0, + top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1, + top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1, + stream: (value: any) => typeof value === 'boolean', + max_tokens: (value: any) => Number.isInteger(value) && value >= 0, + stop: (value: any) => + Array.isArray(value) && value.every((v) => typeof v === 'string'), + frequency_penalty: (value: any) => + typeof value === 'number' && value >= 0 && value <= 1, + presence_penalty: (value: any) => + typeof value === 'number' && value >= 0 && value <= 1, + + ctx_len: (value: any) => Number.isInteger(value) && value >= 0, + ngl: (value: any) => Number.isInteger(value) && value >= 0, + embedding: (value: any) => typeof value === 'boolean', + n_parallel: (value: any) => Number.isInteger(value) && value >= 0, + cpu_threads: (value: any) => Number.isInteger(value) && value >= 0, + prompt_template: (value: any) => typeof value === 'string', + llama_model_path: (value: any) => typeof value === 'string', + mmproj: (value: any) => typeof value === 'string', + vision_model: (value: any) => typeof value === 'boolean', + text_model: (value: any) => typeof value === 'boolean', +} + +/** + * There are some parameters that need to be normalized before being sent to the server + * E.g. ctx_len should be an integer, but it can be a float from the input field + * @param key + * @param value + * @returns + */ +export const normalizeValue = (key: string, value: any) => { + if ( + key === 'token_limit' || + key === 'max_tokens' || + key === 'ctx_len' || + key === 'ngl' || + key === 'n_parallel' || + key === 'cpu_threads' + ) { + // Convert to integer + return Math.floor(Number(value)) + } + return value +} + +/** + * Extract inference parameters from flat model parameters + * @param modelParams + * @returns + */ +export const extractInferenceParams = ( + modelParams?: ModelParams, + originParams?: ModelParams ): ModelRuntimeParams => { if (!modelParams) return {} const defaultModelParams: ModelRuntimeParams = { @@ -22,15 +82,35 @@ export const toRuntimeParams = ( for (const [key, value] of Object.entries(modelParams)) { if (key in defaultModelParams) { - Object.assign(runtimeParams, { ...runtimeParams, [key]: value }) + const validate = validationRules[key] + if (validate && !validate(normalizeValue(key, value))) { + // Invalid value - fall back to origin value + if (originParams && key in originParams) { + Object.assign(runtimeParams, { + ...runtimeParams, + [key]: originParams[key as keyof typeof originParams], + }) + } + } else { + Object.assign(runtimeParams, { + ...runtimeParams, + [key]: normalizeValue(key, value), + }) + } } } return runtimeParams } -export const toSettingParams = ( - modelParams?: ModelParams +/** + * Extract model load parameters from flat model parameters + * @param modelParams + * @returns + */ +export const extractModelLoadParams = ( + modelParams?: ModelParams, + originParams?: ModelParams ): ModelSettingParams => { if (!modelParams) return {} const defaultSettingParams: ModelSettingParams = { @@ -49,7 +129,21 @@ export const toSettingParams = ( for (const [key, value] of Object.entries(modelParams)) { if (key in defaultSettingParams) { - Object.assign(settingParams, { ...settingParams, [key]: value }) + const validate = validationRules[key] + if (validate && !validate(normalizeValue(key, value))) { + // Invalid value - fall back to origin value + if (originParams && key in originParams) { + Object.assign(modelParams, { + ...modelParams, + [key]: originParams[key as keyof typeof originParams], + }) + } + } else { + Object.assign(settingParams, { + ...settingParams, + [key]: normalizeValue(key, value), + }) + } } } From 670013baa037003f82c29607345446a03dc07c0c Mon Sep 17 00:00:00 2001 From: Ronnie Ghose <1313566+RONNCC@users.noreply.github.com> Date: Mon, 16 Sep 2024 19:25:08 -0700 Subject: [PATCH 02/37] Add support for 'o1-preview' and 'o1-mini' models (#3659) Add support for 'o1-preview' and 'o1-mini' model names in the OpenAI API. * **Update `models.json`**: - Add 'o1-preview' model details with appropriate parameters and metadata. - Add 'o1-mini' model details with appropriate parameters and metadata. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/janhq/jan?shareId=XXXX-XXXX-XXXX-XXXX). --- .../resources/models.json | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/extensions/inference-openai-extension/resources/models.json b/extensions/inference-openai-extension/resources/models.json index 6852a1892..72517d540 100644 --- a/extensions/inference-openai-extension/resources/models.json +++ b/extensions/inference-openai-extension/resources/models.json @@ -119,5 +119,65 @@ ] }, "engine": "openai" + }, + { + "sources": [ + { + "url": "https://openai.com" + } + ], + "id": "o1-preview", + "object": "model", + "name": "OpenAI o1-preview", + "version": "1.0", + "description": "OpenAI o1-preview is a new model with complex reasoning", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "OpenAI", + "tags": [ + "General" + ] + }, + "engine": "openai" + }, + { + "sources": [ + { + "url": "https://openai.com" + } + ], + "id": "o1-mini", + "object": "model", + "name": "OpenAI o1-mini", + "version": "1.0", + "description": "OpenAI o1-mini is a lightweight reasoning model", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 0.95, + "stream": true, + "stop": [], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "metadata": { + "author": "OpenAI", + "tags": [ + "General" + ] + }, + "engine": "openai" } ] From c8a08f11155a64a2789a233f1518a0df041df623 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 17 Sep 2024 09:25:55 +0700 Subject: [PATCH 03/37] fix: correct prompt template for Phi3 Medium model (#3670) --- extensions/inference-nitro-extension/package.json | 2 +- .../resources/models/phi3-medium/model.json | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions/inference-nitro-extension/package.json b/extensions/inference-nitro-extension/package.json index 425e4b49c..ac3ed180a 100644 --- a/extensions/inference-nitro-extension/package.json +++ b/extensions/inference-nitro-extension/package.json @@ -1,7 +1,7 @@ { "name": "@janhq/inference-cortex-extension", "productName": "Cortex Inference Engine", - "version": "1.0.16", + "version": "1.0.17", "description": "This extension embeds cortex.cpp, a lightweight inference engine written in C++. See https://jan.ai.\nAdditional dependencies could be installed to run without Cuda Toolkit installation.", "main": "dist/index.js", "node": "dist/node/index.cjs.js", diff --git a/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json b/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json index 50944b9fe..7331b2fd8 100644 --- a/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json +++ b/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json @@ -8,12 +8,12 @@ "id": "phi3-medium", "object": "model", "name": "Phi-3 Medium Instruct Q4", - "version": "1.3", + "version": "1.4", "description": "Phi-3 Medium is Microsoft's latest SOTA model.", "format": "gguf", "settings": { "ctx_len": 128000, - "prompt_template": "<|user|> {prompt}<|end|><|assistant|><|end|>", + "prompt_template": "<|user|> {prompt}<|end|><|assistant|>", "llama_model_path": "Phi-3-medium-128k-instruct-Q4_K_M.gguf", "ngl": 33 }, From c3cb1924866e3b94c868d9eacc5c43190a955452 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 17 Sep 2024 16:09:38 +0700 Subject: [PATCH 04/37] fix: #3667 - The recommended label should be hidden (#3687) --- web/containers/ModelLabel/ModelLabel.test.tsx | 100 ++++++++++++++++++ web/containers/ModelLabel/index.tsx | 6 +- 2 files changed, 101 insertions(+), 5 deletions(-) create mode 100644 web/containers/ModelLabel/ModelLabel.test.tsx diff --git a/web/containers/ModelLabel/ModelLabel.test.tsx b/web/containers/ModelLabel/ModelLabel.test.tsx new file mode 100644 index 000000000..48504ff6a --- /dev/null +++ b/web/containers/ModelLabel/ModelLabel.test.tsx @@ -0,0 +1,100 @@ +import React from 'react' +import { render, waitFor, screen } from '@testing-library/react' +import { useAtomValue } from 'jotai' +import { useActiveModel } from '@/hooks/useActiveModel' +import { useSettings } from '@/hooks/useSettings' +import ModelLabel from '@/containers/ModelLabel' + +jest.mock('jotai', () => ({ + useAtomValue: jest.fn(), + atom: jest.fn(), +})) + +jest.mock('@/hooks/useActiveModel', () => ({ + useActiveModel: jest.fn(), +})) + +jest.mock('@/hooks/useSettings', () => ({ + useSettings: jest.fn(), +})) + +describe('ModelLabel', () => { + const mockUseAtomValue = useAtomValue as jest.Mock + const mockUseActiveModel = useActiveModel as jest.Mock + const mockUseSettings = useSettings as jest.Mock + + const defaultProps: any = { + metadata: { + author: 'John Doe', // Add the 'author' property with a value + tags: ['8B'], + size: 100, + }, + compact: false, + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('renders NotEnoughMemoryLabel when minimumRamModel is greater than totalRam', async () => { + mockUseAtomValue + .mockReturnValueOnce(0) + .mockReturnValueOnce(0) + .mockReturnValueOnce(0) + mockUseActiveModel.mockReturnValue({ + activeModel: { metadata: { size: 0 } }, + }) + mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } }) + + render() + await waitFor(() => { + expect(screen.getByText('Not enough RAM')).toBeDefined() + }) + }) + + it('renders SlowOnYourDeviceLabel when minimumRamModel is less than totalRam but greater than availableRam', async () => { + mockUseAtomValue + .mockReturnValueOnce(100) + .mockReturnValueOnce(50) + .mockReturnValueOnce(10) + mockUseActiveModel.mockReturnValue({ + activeModel: { metadata: { size: 0 } }, + }) + mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } }) + + const props = { + ...defaultProps, + metadata: { + ...defaultProps.metadata, + size: 50, + }, + } + + render() + await waitFor(() => { + expect(screen.getByText('Slow on your device')).toBeDefined() + }) + }) + + it('renders nothing when minimumRamModel is less than availableRam', () => { + mockUseAtomValue + .mockReturnValueOnce(100) + .mockReturnValueOnce(50) + .mockReturnValueOnce(0) + mockUseActiveModel.mockReturnValue({ + activeModel: { metadata: { size: 0 } }, + }) + mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } }) + + const props = { + ...defaultProps, + metadata: { + ...defaultProps.metadata, + size: 10, + }, + } + + const { container } = render() + expect(container.firstChild).toBeNull() + }) +}) diff --git a/web/containers/ModelLabel/index.tsx b/web/containers/ModelLabel/index.tsx index 2c32e288c..b0a3da96f 100644 --- a/web/containers/ModelLabel/index.tsx +++ b/web/containers/ModelLabel/index.tsx @@ -10,8 +10,6 @@ import { useSettings } from '@/hooks/useSettings' import NotEnoughMemoryLabel from './NotEnoughMemoryLabel' -import RecommendedLabel from './RecommendedLabel' - import SlowOnYourDeviceLabel from './SlowOnYourDeviceLabel' import { @@ -53,9 +51,7 @@ const ModelLabel = ({ metadata, compact }: Props) => { /> ) } - if (minimumRamModel < availableRam && !compact) { - return - } + if (minimumRamModel < totalRam && minimumRamModel > availableRam) { return } From 8e603bd5dbb80ef3050e313a0b046101ac81cc03 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 17 Sep 2024 16:43:47 +0700 Subject: [PATCH 05/37] fix: #3476 - Mismatch id between model json and path (#3645) * fix: mismatch between model json and path * chore: revert preserve model settings * test: add tests --- .gitignore | 1 + core/src/browser/core.test.ts | 179 +++--- core/src/browser/core.ts | 8 + .../browser/extensions/engines/AIEngine.ts | 6 +- .../extensions/engines/LocalOAIEngine.ts | 16 +- core/src/browser/extensions/model.ts | 10 +- core/src/node/api/processors/app.test.ts | 75 ++- core/src/node/api/processors/app.ts | 16 +- core/src/types/api/index.ts | 1 + core/src/types/file/index.ts | 15 + core/src/types/model/modelEntity.ts | 7 + core/src/types/model/modelInterface.ts | 11 +- .../inference-nitro-extension/src/index.ts | 3 +- .../src/node/index.ts | 4 +- extensions/model-extension/jest.config.js | 9 + extensions/model-extension/package.json | 1 + extensions/model-extension/rollup.config.ts | 4 +- extensions/model-extension/src/index.test.ts | 564 ++++++++++++++++++ extensions/model-extension/src/index.ts | 85 ++- extensions/model-extension/tsconfig.json | 3 +- .../tensorrt-llm-extension/src/index.ts | 3 +- web/containers/ModelDropdown/index.tsx | 22 +- web/helpers/atoms/AppConfig.atom.ts | 7 - web/helpers/atoms/Model.atom.ts | 19 +- web/hooks/useActiveModel.ts | 6 +- web/hooks/useCreateNewThread.ts | 20 +- web/hooks/useDeleteModel.ts | 12 +- web/hooks/useModels.ts | 5 +- web/hooks/useRecommendedModel.ts | 12 +- web/hooks/useUpdateModelParameters.ts | 60 +- .../Hub/ModelList/ModelHeader/index.tsx | 4 +- web/screens/Hub/ModelList/ModelItem/index.tsx | 4 +- web/screens/Hub/ModelList/index.tsx | 14 +- .../ModelDownloadRow/index.tsx | 8 +- .../Settings/MyModels/MyModelList/index.tsx | 4 +- 35 files changed, 879 insertions(+), 339 deletions(-) create mode 100644 extensions/model-extension/jest.config.js create mode 100644 extensions/model-extension/src/index.test.ts diff --git a/.gitignore b/.gitignore index 646e6842a..eaee28a62 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ core/test_results.html coverage .yarn .yarnrc +*.tsbuildinfo diff --git a/core/src/browser/core.test.ts b/core/src/browser/core.test.ts index 84250888e..f38cc0b40 100644 --- a/core/src/browser/core.test.ts +++ b/core/src/browser/core.test.ts @@ -1,98 +1,109 @@ -import { openExternalUrl } from './core'; -import { joinPath } from './core'; -import { openFileExplorer } from './core'; -import { getJanDataFolderPath } from './core'; -import { abortDownload } from './core'; -import { getFileSize } from './core'; -import { executeOnMain } from './core'; +import { openExternalUrl } from './core' +import { joinPath } from './core' +import { openFileExplorer } from './core' +import { getJanDataFolderPath } from './core' +import { abortDownload } from './core' +import { getFileSize } from './core' +import { executeOnMain } from './core' -it('should open external url', async () => { - const url = 'http://example.com'; - globalThis.core = { - api: { - openExternalUrl: jest.fn().mockResolvedValue('opened') +describe('test core apis', () => { + it('should open external url', async () => { + const url = 'http://example.com' + globalThis.core = { + api: { + openExternalUrl: jest.fn().mockResolvedValue('opened'), + }, } - }; - const result = await openExternalUrl(url); - expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url); - expect(result).toBe('opened'); -}); + const result = await openExternalUrl(url) + expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url) + expect(result).toBe('opened') + }) - -it('should join paths', async () => { - const paths = ['/path/one', '/path/two']; - globalThis.core = { - api: { - joinPath: jest.fn().mockResolvedValue('/path/one/path/two') + it('should join paths', async () => { + const paths = ['/path/one', '/path/two'] + globalThis.core = { + api: { + joinPath: jest.fn().mockResolvedValue('/path/one/path/two'), + }, } - }; - const result = await joinPath(paths); - expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths); - expect(result).toBe('/path/one/path/two'); -}); + const result = await joinPath(paths) + expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths) + expect(result).toBe('/path/one/path/two') + }) - -it('should open file explorer', async () => { - const path = '/path/to/open'; - globalThis.core = { - api: { - openFileExplorer: jest.fn().mockResolvedValue('opened') + it('should open file explorer', async () => { + const path = '/path/to/open' + globalThis.core = { + api: { + openFileExplorer: jest.fn().mockResolvedValue('opened'), + }, } - }; - const result = await openFileExplorer(path); - expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path); - expect(result).toBe('opened'); -}); + const result = await openFileExplorer(path) + expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path) + expect(result).toBe('opened') + }) - -it('should get jan data folder path', async () => { - globalThis.core = { - api: { - getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data') + it('should get jan data folder path', async () => { + globalThis.core = { + api: { + getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data'), + }, } - }; - const result = await getJanDataFolderPath(); - expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled(); - expect(result).toBe('/path/to/jan/data'); -}); + const result = await getJanDataFolderPath() + expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled() + expect(result).toBe('/path/to/jan/data') + }) - -it('should abort download', async () => { - const fileName = 'testFile'; - globalThis.core = { - api: { - abortDownload: jest.fn().mockResolvedValue('aborted') + it('should abort download', async () => { + const fileName = 'testFile' + globalThis.core = { + api: { + abortDownload: jest.fn().mockResolvedValue('aborted'), + }, } - }; - const result = await abortDownload(fileName); - expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName); - expect(result).toBe('aborted'); -}); + const result = await abortDownload(fileName) + expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName) + expect(result).toBe('aborted') + }) - -it('should get file size', async () => { - const url = 'http://example.com/file'; - globalThis.core = { - api: { - getFileSize: jest.fn().mockResolvedValue(1024) + it('should get file size', async () => { + const url = 'http://example.com/file' + globalThis.core = { + api: { + getFileSize: jest.fn().mockResolvedValue(1024), + }, } - }; - const result = await getFileSize(url); - expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url); - expect(result).toBe(1024); -}); + const result = await getFileSize(url) + expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url) + expect(result).toBe(1024) + }) - -it('should execute function on main process', async () => { - const extension = 'testExtension'; - const method = 'testMethod'; - const args = ['arg1', 'arg2']; - globalThis.core = { - api: { - invokeExtensionFunc: jest.fn().mockResolvedValue('result') + it('should execute function on main process', async () => { + const extension = 'testExtension' + const method = 'testMethod' + const args = ['arg1', 'arg2'] + globalThis.core = { + api: { + invokeExtensionFunc: jest.fn().mockResolvedValue('result'), + }, } - }; - const result = await executeOnMain(extension, method, ...args); - expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args); - expect(result).toBe('result'); -}); + const result = await executeOnMain(extension, method, ...args) + expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args) + expect(result).toBe('result') + }) +}) + +describe('dirName - just a pass thru api', () => { + it('should retrieve the directory name from a file path', async () => { + const mockDirName = jest.fn() + globalThis.core = { + api: { + dirName: mockDirName.mockResolvedValue('/path/to'), + }, + } + // Normal file path with extension + const path = '/path/to/file.txt' + await globalThis.core.api.dirName(path) + expect(mockDirName).toHaveBeenCalledWith(path) + }) +}) diff --git a/core/src/browser/core.ts b/core/src/browser/core.ts index fdbceb06b..b19e0b339 100644 --- a/core/src/browser/core.ts +++ b/core/src/browser/core.ts @@ -68,6 +68,13 @@ const openFileExplorer: (path: string) => Promise = (path) => const joinPath: (paths: string[]) => Promise = (paths) => globalThis.core.api?.joinPath(paths) +/** + * Get dirname of a file path. + * @param path - The file path to retrieve dirname. + * @returns {Promise} A promise that resolves the dirname. + */ +const dirName: (path: string) => Promise = (path) => globalThis.core.api?.dirName(path) + /** * Retrieve the basename from an url. * @param path - The path to retrieve. @@ -161,5 +168,6 @@ export { systemInformation, showToast, getFileSize, + dirName, FileStat, } diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 7cd9f513e..75354de88 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -2,7 +2,7 @@ import { getJanDataFolderPath, joinPath } from '../../core' import { events } from '../../events' import { BaseExtension } from '../../extension' import { fs } from '../../fs' -import { MessageRequest, Model, ModelEvent } from '../../../types' +import { MessageRequest, Model, ModelEvent, ModelFile } from '../../../types' import { EngineManager } from './EngineManager' /** @@ -21,7 +21,7 @@ export abstract class AIEngine extends BaseExtension { override onLoad() { this.registerEngine() - events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model)) events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) } @@ -78,7 +78,7 @@ export abstract class AIEngine extends BaseExtension { /** * Loads the model. */ - async loadModel(model: Model): Promise { + async loadModel(model: ModelFile): Promise { if (model.engine.toString() !== this.provider) return Promise.resolve() events.emit(ModelEvent.OnModelReady, model) return Promise.resolve() diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts index fb9e4962c..123b9a593 100644 --- a/core/src/browser/extensions/engines/LocalOAIEngine.ts +++ b/core/src/browser/extensions/engines/LocalOAIEngine.ts @@ -1,6 +1,6 @@ -import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core' +import { executeOnMain, systemInformation, dirName } from '../../core' import { events } from '../../events' -import { Model, ModelEvent } from '../../../types' +import { Model, ModelEvent, ModelFile } from '../../../types' import { OAIEngine } from './OAIEngine' /** @@ -14,22 +14,24 @@ export abstract class LocalOAIEngine extends OAIEngine { unloadModelFunctionName: string = 'unloadModel' /** - * On extension load, subscribe to events. + * This class represents a base for local inference providers in the OpenAI architecture. + * It extends the OAIEngine class and provides the implementation of loading and unloading models locally. + * The loadModel function subscribes to the ModelEvent.OnModelInit event, loading models when initiated. + * The unloadModel function subscribes to the ModelEvent.OnModelStop event, unloading models when stopped. */ override onLoad() { super.onLoad() // These events are applicable to local inference providers - events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model)) events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) } /** * Load the model. */ - override async loadModel(model: Model): Promise { + override async loadModel(model: ModelFile): Promise { if (model.engine.toString() !== this.provider) return - const modelFolderName = 'models' - const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id]) + const modelFolder = await dirName(model.file_path) const systemInfo = await systemInformation() const res = await executeOnMain( this.nodeModule, diff --git a/core/src/browser/extensions/model.ts b/core/src/browser/extensions/model.ts index 5b3089403..040542927 100644 --- a/core/src/browser/extensions/model.ts +++ b/core/src/browser/extensions/model.ts @@ -4,6 +4,7 @@ import { HuggingFaceRepoData, ImportingModel, Model, + ModelFile, ModelInterface, OptionType, } from '../../types' @@ -25,12 +26,11 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter network?: { proxy: string; ignoreSSL?: boolean } ): Promise abstract cancelModelDownload(modelId: string): Promise - abstract deleteModel(modelId: string): Promise - abstract saveModel(model: Model): Promise - abstract getDownloadedModels(): Promise - abstract getConfiguredModels(): Promise + abstract deleteModel(model: ModelFile): Promise + abstract getDownloadedModels(): Promise + abstract getConfiguredModels(): Promise abstract importModels(models: ImportingModel[], optionType: OptionType): Promise - abstract updateModelInfo(modelInfo: Partial): Promise + abstract updateModelInfo(modelInfo: Partial): Promise abstract fetchHuggingFaceRepoData(repoId: string): Promise abstract getDefaultModel(): Promise } diff --git a/core/src/node/api/processors/app.test.ts b/core/src/node/api/processors/app.test.ts index 3ada5df1e..5c4daef29 100644 --- a/core/src/node/api/processors/app.test.ts +++ b/core/src/node/api/processors/app.test.ts @@ -1,40 +1,57 @@ -import { App } from './app'; +jest.mock('../../helper', () => ({ + ...jest.requireActual('../../helper'), + getJanDataFolderPath: () => './app', +})) +import { dirname } from 'path' +import { App } from './app' it('should call stopServer', () => { - const app = new App(); - const stopServerMock = jest.fn().mockResolvedValue('Server stopped'); + const app = new App() + const stopServerMock = jest.fn().mockResolvedValue('Server stopped') jest.mock('@janhq/server', () => ({ - stopServer: stopServerMock - })); - const result = app.stopServer(); - expect(stopServerMock).toHaveBeenCalled(); -}); + stopServer: stopServerMock, + })) + app.stopServer() + expect(stopServerMock).toHaveBeenCalled() +}) it('should correctly retrieve basename', () => { - const app = new App(); - const result = app.baseName('/path/to/file.txt'); - expect(result).toBe('file.txt'); -}); + const app = new App() + const result = app.baseName('/path/to/file.txt') + expect(result).toBe('file.txt') +}) it('should correctly identify subdirectories', () => { - const app = new App(); - const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to'; - const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir'; - const result = app.isSubdirectory(basePath, subPath); - expect(result).toBe(true); -}); + const app = new App() + const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to' + const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir' + const result = app.isSubdirectory(basePath, subPath) + expect(result).toBe(true) +}) it('should correctly join multiple paths', () => { - const app = new App(); - const result = app.joinPath(['path', 'to', 'file']); - const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file'; - expect(result).toBe(expectedPath); -}); + const app = new App() + const result = app.joinPath(['path', 'to', 'file']) + const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file' + expect(result).toBe(expectedPath) +}) it('should call correct function with provided arguments using process method', () => { - const app = new App(); - const mockFunc = jest.fn(); - app.joinPath = mockFunc; - app.process('joinPath', ['path1', 'path2']); - expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']); -}); + const app = new App() + const mockFunc = jest.fn() + app.joinPath = mockFunc + app.process('joinPath', ['path1', 'path2']) + expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']) +}) + +it('should retrieve the directory name from a file path (Unix/Windows)', async () => { + const app = new App() + const path = 'C:/Users/John Doe/Desktop/file.txt' + expect(await app.dirName(path)).toBe('C:/Users/John Doe/Desktop') +}) + +it('should retrieve the directory name when using file protocol', async () => { + const app = new App() + const path = 'file:/models/file.txt' + expect(await app.dirName(path)).toBe(process.platform === 'win32' ? 'app\\models' : 'app/models') +}) diff --git a/core/src/node/api/processors/app.ts b/core/src/node/api/processors/app.ts index 15460ba56..a0808c5ac 100644 --- a/core/src/node/api/processors/app.ts +++ b/core/src/node/api/processors/app.ts @@ -1,4 +1,4 @@ -import { basename, isAbsolute, join, relative } from 'path' +import { basename, dirname, isAbsolute, join, relative } from 'path' import { Processor } from './Processor' import { @@ -6,6 +6,8 @@ import { appResourcePath, getAppConfigurations as appConfiguration, updateAppConfiguration, + normalizeFilePath, + getJanDataFolderPath, } from '../../helper' export class App implements Processor { @@ -28,6 +30,18 @@ export class App implements Processor { return join(...args) } + /** + * Get dirname of a file path. + * @param path - The file path to retrieve dirname. + */ + dirName(path: string) { + const arg = + path.startsWith(`file:/`) || path.startsWith(`file:\\`) + ? join(getJanDataFolderPath(), normalizeFilePath(path)) + : path + return dirname(arg) + } + /** * Checks if the given path is a subdirectory of the given directory. * diff --git a/core/src/types/api/index.ts b/core/src/types/api/index.ts index bca11c0a8..8f1ff70bf 100644 --- a/core/src/types/api/index.ts +++ b/core/src/types/api/index.ts @@ -37,6 +37,7 @@ export enum AppRoute { getAppConfigurations = 'getAppConfigurations', updateAppConfiguration = 'updateAppConfiguration', joinPath = 'joinPath', + dirName = 'dirName', isSubdirectory = 'isSubdirectory', baseName = 'baseName', startServer = 'startServer', diff --git a/core/src/types/file/index.ts b/core/src/types/file/index.ts index 1b36a5777..4db956b1e 100644 --- a/core/src/types/file/index.ts +++ b/core/src/types/file/index.ts @@ -52,3 +52,18 @@ type DownloadSize = { total: number transferred: number } + +/** + * The file metadata + */ +export type FileMetadata = { + /** + * The origin file path. + */ + file_path: string + + /** + * The file name. + */ + file_name: string +} diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts index f154f7f04..933c698c3 100644 --- a/core/src/types/model/modelEntity.ts +++ b/core/src/types/model/modelEntity.ts @@ -1,3 +1,5 @@ +import { FileMetadata } from '../file' + /** * Represents the information about a model. * @stored @@ -151,3 +153,8 @@ export type ModelRuntimeParams = { export type ModelInitFailed = Model & { error: Error } + +/** + * ModelFile is the model.json entity and it's file metadata + */ +export type ModelFile = Model & FileMetadata diff --git a/core/src/types/model/modelInterface.ts b/core/src/types/model/modelInterface.ts index 639c7c8d3..5b5856231 100644 --- a/core/src/types/model/modelInterface.ts +++ b/core/src/types/model/modelInterface.ts @@ -1,5 +1,5 @@ import { GpuSetting } from '../miscellaneous' -import { Model } from './modelEntity' +import { Model, ModelFile } from './modelEntity' /** * Model extension for managing models. @@ -29,14 +29,7 @@ export interface ModelInterface { * @param modelId - The ID of the model to delete. * @returns A Promise that resolves when the model has been deleted. */ - deleteModel(modelId: string): Promise - - /** - * Saves a model. - * @param model - The model to save. - * @returns A Promise that resolves when the model has been saved. - */ - saveModel(model: Model): Promise + deleteModel(model: ModelFile): Promise /** * Gets a list of downloaded models. diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index d79e076d4..6e825e8fd 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -22,6 +22,7 @@ import { downloadFile, DownloadState, DownloadEvent, + ModelFile, } from '@janhq/core' declare const CUDA_DOWNLOAD_URL: string @@ -94,7 +95,7 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine { this.nitroProcessInfo = health } - override loadModel(model: Model): Promise { + override loadModel(model: ModelFile): Promise { if (model.engine !== this.provider) return Promise.resolve() this.getNitroProcessHealthIntervalId = setInterval( () => this.periodicallyGetNitroHealth(), diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts index 3a969ad5e..98ca4572f 100644 --- a/extensions/inference-nitro-extension/src/node/index.ts +++ b/extensions/inference-nitro-extension/src/node/index.ts @@ -6,12 +6,12 @@ import fetchRT from 'fetch-retry' import { log, getSystemResourceInfo, - Model, InferenceEngine, ModelSettingParams, PromptTemplate, SystemInformation, getJanDataFolderPath, + ModelFile, } from '@janhq/core/node' import { executableNitroFile } from './execute' import terminate from 'terminate' @@ -25,7 +25,7 @@ const fetchRetry = fetchRT(fetch) */ interface ModelInitOptions { modelFolder: string - model: Model + model: ModelFile } // The PORT to use for the Nitro subprocess const PORT = 3928 diff --git a/extensions/model-extension/jest.config.js b/extensions/model-extension/jest.config.js new file mode 100644 index 000000000..3e32adceb --- /dev/null +++ b/extensions/model-extension/jest.config.js @@ -0,0 +1,9 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', + transform: { + 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest', + }, + transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'], +} diff --git a/extensions/model-extension/package.json b/extensions/model-extension/package.json index 4a2c61b71..9a406dcf4 100644 --- a/extensions/model-extension/package.json +++ b/extensions/model-extension/package.json @@ -8,6 +8,7 @@ "author": "Jan ", "license": "AGPL-3.0", "scripts": { + "test": "jest", "build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs", "build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install" }, diff --git a/extensions/model-extension/rollup.config.ts b/extensions/model-extension/rollup.config.ts index c3f3acc77..d36d8ffac 100644 --- a/extensions/model-extension/rollup.config.ts +++ b/extensions/model-extension/rollup.config.ts @@ -27,7 +27,7 @@ export default [ // Allow json resolution json(), // Compile TypeScript files - typescript({ useTsconfigDeclarationDir: true }), + typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }), // Compile TypeScript files // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) // commonjs(), @@ -62,7 +62,7 @@ export default [ // Allow json resolution json(), // Compile TypeScript files - typescript({ useTsconfigDeclarationDir: true }), + typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }), // Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs) commonjs(), // Allow node_modules resolution, so you can use 'external' to control diff --git a/extensions/model-extension/src/index.test.ts b/extensions/model-extension/src/index.test.ts new file mode 100644 index 000000000..6816d7101 --- /dev/null +++ b/extensions/model-extension/src/index.test.ts @@ -0,0 +1,564 @@ +const readDirSyncMock = jest.fn() +const existMock = jest.fn() +const readFileSyncMock = jest.fn() + +jest.mock('@janhq/core', () => ({ + ...jest.requireActual('@janhq/core/node'), + fs: { + existsSync: existMock, + readdirSync: readDirSyncMock, + readFileSync: readFileSyncMock, + fileStat: () => ({ + isDirectory: false, + }), + }, + dirName: jest.fn(), + joinPath: (paths) => paths.join('/'), + ModelExtension: jest.fn(), +})) + +import JanModelExtension from '.' +import { fs, dirName } from '@janhq/core' + +describe('JanModelExtension', () => { + let sut: JanModelExtension + + beforeAll(() => { + // @ts-ignore + sut = new JanModelExtension() + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + describe('getConfiguredModels', () => { + describe("when there's no models are pre-populated", () => { + it('should return empty array', async () => { + // Mock configured models data + const configuredModels = [] + existMock.mockReturnValue(true) + readDirSyncMock.mockReturnValue([]) + + const result = await sut.getConfiguredModels() + expect(result).toEqual([]) + }) + }) + + describe("when there's are pre-populated models - all flattened", () => { + it('returns configured models data - flatten folder - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getConfiguredModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model.json', + id: '2', + }), + ]) + ) + }) + }) + + describe("when there's are pre-populated models - there are nested folders", () => { + it('returns configured models data - flatten folder - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else if (path.includes('model2/model2-1')) + return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getConfiguredModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('getDownloadedModels', () => { + describe('no models downloaded', () => { + it('should return empty array', async () => { + // Mock downloaded models data + const downloadedModels = [] + existMock.mockReturnValue(true) + readDirSyncMock.mockReturnValue([]) + + const result = await sut.getDownloadedModels() + expect(result).toEqual([]) + }) + }) + describe('only one model is downloaded', () => { + describe('flatten folder', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2'] + else if (path === 'file://models/model1') + return ['model.json', 'test.gguf'] + else return ['model.json'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + ]) + ) + }) + }) + }) + + describe('all models are downloaded', () => { + describe('nested folders', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else return ['model.json', 'test.gguf'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('all models are downloaded with uppercased GGUF files', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else if (path === 'file://models/model1') + return ['model.json', 'test.GGUF'] + else return ['model.json', 'test.gguf'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + + describe('all models are downloaded - GGUF & Tensort RT', () => { + it('returns downloaded models - with correct file_path and model id', async () => { + // Mock configured models data + const configuredModels = [ + { + id: '1', + name: 'Model 1', + version: '1.0.0', + description: 'Model 1 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model1', + }, + format: 'onnx', + sources: [], + created: new Date(), + updated: new Date(), + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + { + id: '2', + name: 'Model 2', + version: '2.0.0', + description: 'Model 2 description', + object: { + type: 'model', + uri: 'http://localhost:5000/models/model2', + }, + format: 'onnx', + sources: [], + parameters: {}, + settings: {}, + metadata: {}, + engine: 'test', + } as any, + ] + existMock.mockReturnValue(true) + + readDirSyncMock.mockImplementation((path) => { + if (path === 'file://models') return ['model1', 'model2/model2-1'] + else if (path === 'file://models/model1') + return ['model.json', 'test.gguf'] + else return ['model.json', 'test.engine'] + }) + + readFileSyncMock.mockImplementation((path) => { + if (path.includes('model1')) + return JSON.stringify(configuredModels[0]) + else return JSON.stringify(configuredModels[1]) + }) + + const result = await sut.getDownloadedModels() + expect(result).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + file_path: 'file://models/model1/model.json', + id: '1', + }), + expect.objectContaining({ + file_path: 'file://models/model2/model2-1/model.json', + id: '2', + }), + ]) + ) + }) + }) + }) + + describe('deleteModel', () => { + describe('model is a GGUF model', () => { + it('should delete the GGUF file', async () => { + fs.unlinkSync = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + readDirSyncMock.mockImplementation((path) => { + return ['model.json', 'test.gguf'] + }) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledWith( + 'file://models/model1/test.gguf' + ) + }) + + it('no gguf file presented', async () => { + fs.unlinkSync = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + readDirSyncMock.mockReturnValue(['model.json']) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledTimes(0) + }) + + it('delete an imported model', async () => { + fs.rm = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + readDirSyncMock.mockReturnValue(['model.json', 'test.gguf']) + + // MARK: This is a tricky logic implement? + // I will just add test for now but will align on the legacy implementation + fs.readFileSync = jest.fn().mockReturnValue( + JSON.stringify({ + metadata: { + author: 'user', + }, + }) + ) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.rm).toHaveBeenCalledWith('file://models/model1') + }) + + it('delete tensorrt-models', async () => { + fs.rm = jest.fn() + const dirMock = dirName as jest.Mock + dirMock.mockReturnValue('file://models/model1') + + readDirSyncMock.mockReturnValue(['model.json', 'test.engine']) + + fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({})) + + existMock.mockReturnValue(true) + + await sut.deleteModel({ + file_path: 'file://models/model1/model.json', + } as any) + + expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine') + }) + }) + }) +}) diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index e2f68a58c..ac9b06a09 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -22,6 +22,8 @@ import { getFileSize, AllQuantizations, ModelEvent, + ModelFile, + dirName, } from '@janhq/core' import { extractFileName } from './helpers/path' @@ -48,16 +50,7 @@ export default class JanModelExtension extends ModelExtension { ] private static readonly _tensorRtEngineFormat = '.engine' private static readonly _supportedGpuArch = ['ampere', 'ada'] - private static readonly _safetensorsRegexs = [ - /model\.safetensors$/, - /model-[0-9]+-of-[0-9]+\.safetensors$/, - ] - private static readonly _pytorchRegexs = [ - /pytorch_model\.bin$/, - /consolidated\.[0-9]+\.pth$/, - /pytorch_model-[0-9]+-of-[0-9]+\.bin$/, - /.*\.pt$/, - ] + interrupted = false /** @@ -319,9 +312,9 @@ export default class JanModelExtension extends ModelExtension { * @param filePath - The path to the model file to delete. * @returns A Promise that resolves when the model is deleted. */ - async deleteModel(modelId: string): Promise { + async deleteModel(model: ModelFile): Promise { try { - const dirPath = await joinPath([JanModelExtension._homeDir, modelId]) + const dirPath = await dirName(model.file_path) const jsonFilePath = await joinPath([ dirPath, JanModelExtension._modelMetadataFileName, @@ -330,9 +323,11 @@ export default class JanModelExtension extends ModelExtension { await this.readModelMetadata(jsonFilePath) ) as Model + // TODO: This is so tricky? + // Should depend on sources? const isUserImportModel = modelInfo.metadata?.author?.toLowerCase() === 'user' - if (isUserImportModel) { + if (isUserImportModel) { // just delete the folder return fs.rm(dirPath) } @@ -350,30 +345,11 @@ export default class JanModelExtension extends ModelExtension { } } - /** - * Saves a model file. - * @param model - The model to save. - * @returns A Promise that resolves when the model is saved. - */ - async saveModel(model: Model): Promise { - const jsonFilePath = await joinPath([ - JanModelExtension._homeDir, - model.id, - JanModelExtension._modelMetadataFileName, - ]) - - try { - await fs.writeFileSync(jsonFilePath, JSON.stringify(model, null, 2)) - } catch (err) { - console.error(err) - } - } - /** * Gets all downloaded models. * @returns A Promise that resolves with an array of all models. */ - async getDownloadedModels(): Promise { + async getDownloadedModels(): Promise { return await this.getModelsMetadata( async (modelDir: string, model: Model) => { if (!JanModelExtension._offlineInferenceEngine.includes(model.engine)) @@ -425,8 +401,10 @@ export default class JanModelExtension extends ModelExtension { ): Promise { // try to find model.json recursively inside each folder if (!(await fs.existsSync(folderFullPath))) return undefined + const files: string[] = await fs.readdirSync(folderFullPath) if (files.length === 0) return undefined + if (files.includes(JanModelExtension._modelMetadataFileName)) { return joinPath([ folderFullPath, @@ -446,7 +424,7 @@ export default class JanModelExtension extends ModelExtension { private async getModelsMetadata( selector?: (path: string, model: Model) => Promise - ): Promise { + ): Promise { try { if (!(await fs.existsSync(JanModelExtension._homeDir))) { console.debug('Model folder not found') @@ -469,6 +447,7 @@ export default class JanModelExtension extends ModelExtension { JanModelExtension._homeDir, dirName, ]) + const jsonPath = await this.getModelJsonPath(folderFullPath) if (await fs.existsSync(jsonPath)) { @@ -486,6 +465,8 @@ export default class JanModelExtension extends ModelExtension { }, ] } + model.file_path = jsonPath + model.file_name = JanModelExtension._modelMetadataFileName if (selector && !(await selector?.(dirName, model))) { return @@ -506,7 +487,7 @@ export default class JanModelExtension extends ModelExtension { typeof result.value === 'object' ? result.value : JSON.parse(result.value) - return model as Model + return model as ModelFile } catch { console.debug(`Unable to parse model metadata: ${result.value}`) } @@ -637,7 +618,7 @@ export default class JanModelExtension extends ModelExtension { * Gets all available models. * @returns A Promise that resolves with an array of all models. */ - async getConfiguredModels(): Promise { + async getConfiguredModels(): Promise { return this.getModelsMetadata() } @@ -669,7 +650,7 @@ export default class JanModelExtension extends ModelExtension { modelBinaryPath: string, modelFolderName: string, modelFolderPath: string - ): Promise { + ): Promise { const fileStats = await fs.fileStat(modelBinaryPath, true) const binaryFileSize = fileStats.size @@ -732,25 +713,21 @@ export default class JanModelExtension extends ModelExtension { await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2)) - return model + return { + ...model, + file_path: modelFilePath, + file_name: JanModelExtension._modelMetadataFileName, + } } - async updateModelInfo(modelInfo: Partial): Promise { - const modelId = modelInfo.id + async updateModelInfo(modelInfo: Partial): Promise { if (modelInfo.id == null) throw new Error('Model ID is required') - const janDataFolderPath = await getJanDataFolderPath() - const jsonFilePath = await joinPath([ - janDataFolderPath, - 'models', - modelId, - JanModelExtension._modelMetadataFileName, - ]) const model = JSON.parse( - await this.readModelMetadata(jsonFilePath) - ) as Model + await this.readModelMetadata(modelInfo.file_path) + ) as ModelFile - const updatedModel: Model = { + const updatedModel: ModelFile = { ...model, ...modelInfo, parameters: { @@ -765,9 +742,15 @@ export default class JanModelExtension extends ModelExtension { ...model.metadata, ...modelInfo.metadata, }, + // Should not persist file_path & file_name + file_path: undefined, + file_name: undefined, } - await fs.writeFileSync(jsonFilePath, JSON.stringify(updatedModel, null, 2)) + await fs.writeFileSync( + modelInfo.file_path, + JSON.stringify(updatedModel, null, 2) + ) return updatedModel } diff --git a/extensions/model-extension/tsconfig.json b/extensions/model-extension/tsconfig.json index addd8e127..0d3252934 100644 --- a/extensions/model-extension/tsconfig.json +++ b/extensions/model-extension/tsconfig.json @@ -10,5 +10,6 @@ "skipLibCheck": true, "rootDir": "./src" }, - "include": ["./src"] + "include": ["./src"], + "exclude": ["**/*.test.ts"] } diff --git a/extensions/tensorrt-llm-extension/src/index.ts b/extensions/tensorrt-llm-extension/src/index.ts index 189abc706..7f68c43bd 100644 --- a/extensions/tensorrt-llm-extension/src/index.ts +++ b/extensions/tensorrt-llm-extension/src/index.ts @@ -23,6 +23,7 @@ import { ModelEvent, getJanDataFolderPath, SystemInformation, + ModelFile, } from '@janhq/core' /** @@ -137,7 +138,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { events.emit(ModelEvent.OnModelsUpdate, {}) } - override async loadModel(model: Model): Promise { + override async loadModel(model: ModelFile): Promise { if ((await this.installationState()) === 'Installed') return super.loadModel(model) diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx index 92d8addd0..d8743ddce 100644 --- a/web/containers/ModelDropdown/index.tsx +++ b/web/containers/ModelDropdown/index.tsx @@ -46,7 +46,6 @@ import { import { extensionManager } from '@/extension' -import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom' import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom' import { configuredModelsAtom, @@ -91,8 +90,6 @@ const ModelDropdown = ({ const featuredModel = configuredModels.filter((x) => x.metadata.tags.includes('Featured') ) - const preserveModelSettings = useAtomValue(preserveModelSettingsAtom) - const { updateThreadMetadata } = useCreateNewThread() useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [ @@ -191,27 +188,14 @@ const ModelDropdown = ({ ], }) - // Default setting ctx_len for the model for a better onboarding experience - // TODO: When Cortex support hardware instructions, we should remove this - const defaultContextLength = preserveModelSettings - ? model?.metadata?.default_ctx_len - : 2048 - const defaultMaxTokens = preserveModelSettings - ? model?.metadata?.default_max_tokens - : 2048 const overriddenSettings = - model?.settings.ctx_len && model.settings.ctx_len > 2048 - ? { ctx_len: defaultContextLength ?? 2048 } - : {} - const overriddenParameters = - model?.parameters.max_tokens && model.parameters.max_tokens - ? { max_tokens: defaultMaxTokens ?? 2048 } + model?.settings.ctx_len && model.settings.ctx_len > 4096 + ? { ctx_len: 4096 } : {} const modelParams = { ...model?.parameters, ...model?.settings, - ...overriddenParameters, ...overriddenSettings, } @@ -222,6 +206,7 @@ const ModelDropdown = ({ if (model) updateModelParameter(activeThread, { params: modelParams, + modelPath: model.file_path, modelId: model.id, engine: model.engine, }) @@ -235,7 +220,6 @@ const ModelDropdown = ({ setThreadModelParams, updateModelParameter, updateThreadMetadata, - preserveModelSettings, ] ) diff --git a/web/helpers/atoms/AppConfig.atom.ts b/web/helpers/atoms/AppConfig.atom.ts index e7b7efaec..f4acc7dc2 100644 --- a/web/helpers/atoms/AppConfig.atom.ts +++ b/web/helpers/atoms/AppConfig.atom.ts @@ -7,7 +7,6 @@ const VULKAN_ENABLED = 'vulkanEnabled' const IGNORE_SSL = 'ignoreSSLFeature' const HTTPS_PROXY_FEATURE = 'httpsProxyFeature' const QUICK_ASK_ENABLED = 'quickAskEnabled' -const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings' export const janDataFolderPathAtom = atom('') @@ -24,9 +23,3 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false) export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false) export const hostAtom = atom('http://localhost:1337/') - -// This feature is to allow user to cache model settings on thread creation -export const preserveModelSettingsAtom = atomWithStorage( - PRESERVE_MODEL_SETTINGS, - false -) diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts index 77b1bfa4e..d2d0ca9f4 100644 --- a/web/helpers/atoms/Model.atom.ts +++ b/web/helpers/atoms/Model.atom.ts @@ -1,4 +1,4 @@ -import { ImportingModel, Model, InferenceEngine } from '@janhq/core' +import { ImportingModel, Model, InferenceEngine, ModelFile } from '@janhq/core' import { atom } from 'jotai' import { localEngines } from '@/utils/modelEngine' @@ -32,18 +32,7 @@ export const removeDownloadingModelAtom = atom( } ) -export const downloadedModelsAtom = atom([]) - -export const updateDownloadedModelAtom = atom( - null, - (get, set, updatedModel: Model) => { - const models: Model[] = get(downloadedModelsAtom).map((c) => - c.id === updatedModel.id ? updatedModel : c - ) - - set(downloadedModelsAtom, models) - } -) +export const downloadedModelsAtom = atom([]) export const removeDownloadedModelAtom = atom( null, @@ -57,7 +46,7 @@ export const removeDownloadedModelAtom = atom( } ) -export const configuredModelsAtom = atom([]) +export const configuredModelsAtom = atom([]) export const defaultModelAtom = atom(undefined) @@ -144,6 +133,6 @@ export const updateImportingModelAtom = atom( } ) -export const selectedModelAtom = atom(undefined) +export const selectedModelAtom = atom(undefined) export const showEngineListModelAtom = atom(localEngines) diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts index 9768ac4c4..2d53678c3 100644 --- a/web/hooks/useActiveModel.ts +++ b/web/hooks/useActiveModel.ts @@ -1,6 +1,6 @@ import { useCallback, useEffect, useRef } from 'react' -import { EngineManager, Model } from '@janhq/core' +import { EngineManager, Model, ModelFile } from '@janhq/core' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { toaster } from '@/containers/Toast' @@ -11,7 +11,7 @@ import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' -export const activeModelAtom = atom(undefined) +export const activeModelAtom = atom(undefined) export const loadModelErrorAtom = atom(undefined) type ModelState = { @@ -37,7 +37,7 @@ export function useActiveModel() { const [pendingModelLoad, setPendingModelLoad] = useAtom(pendingModelLoadAtom) const isVulkanEnabled = useAtomValue(vulkanEnabledAtom) - const downloadedModelsRef = useRef([]) + const downloadedModelsRef = useRef([]) useEffect(() => { downloadedModelsRef.current = downloadedModels diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts index 80acfa3cc..5548259fd 100644 --- a/web/hooks/useCreateNewThread.ts +++ b/web/hooks/useCreateNewThread.ts @@ -7,8 +7,8 @@ import { Thread, ThreadAssistantInfo, ThreadState, - Model, AssistantTool, + ModelFile, } from '@janhq/core' import { atom, useAtomValue, useSetAtom } from 'jotai' @@ -26,10 +26,7 @@ import useSetActiveThread from './useSetActiveThread' import { extensionManager } from '@/extension' -import { - experimentalFeatureEnabledAtom, - preserveModelSettingsAtom, -} from '@/helpers/atoms/AppConfig.atom' +import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { threadsAtom, @@ -67,7 +64,6 @@ export const useCreateNewThread = () => { const copyOverInstructionEnabled = useAtomValue( copyOverInstructionEnabledAtom ) - const preserveModelSettings = useAtomValue(preserveModelSettingsAtom) const activeThread = useAtomValue(activeThreadAtom) const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom) @@ -80,7 +76,7 @@ export const useCreateNewThread = () => { const requestCreateNewThread = async ( assistant: Assistant, - model?: Model | undefined + model?: ModelFile | undefined ) => { // Stop generating if any setIsGeneratingResponse(false) @@ -109,19 +105,13 @@ export const useCreateNewThread = () => { enabled: true, settings: assistant.tools && assistant.tools[0].settings, } - const defaultContextLength = preserveModelSettings - ? defaultModel?.metadata?.default_ctx_len - : 2048 - const defaultMaxTokens = preserveModelSettings - ? defaultModel?.metadata?.default_max_tokens - : 2048 const overriddenSettings = defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048 - ? { ctx_len: defaultContextLength ?? 2048 } + ? { ctx_len: 4096 } : {} const overriddenParameters = defaultModel?.parameters.max_tokens - ? { max_tokens: defaultMaxTokens ?? 2048 } + ? { max_tokens: 4096 } : {} const createdAt = Date.now() diff --git a/web/hooks/useDeleteModel.ts b/web/hooks/useDeleteModel.ts index 9736f8256..5a7a319b2 100644 --- a/web/hooks/useDeleteModel.ts +++ b/web/hooks/useDeleteModel.ts @@ -1,6 +1,6 @@ import { useCallback } from 'react' -import { ExtensionTypeEnum, ModelExtension, Model } from '@janhq/core' +import { ExtensionTypeEnum, ModelExtension, ModelFile } from '@janhq/core' import { useSetAtom } from 'jotai' @@ -13,8 +13,8 @@ export default function useDeleteModel() { const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom) const deleteModel = useCallback( - async (model: Model) => { - await localDeleteModel(model.id) + async (model: ModelFile) => { + await localDeleteModel(model) removeDownloadedModel(model.id) toaster({ title: 'Model Deletion Successful', @@ -28,5 +28,7 @@ export default function useDeleteModel() { return { deleteModel } } -const localDeleteModel = async (id: string) => - extensionManager.get(ExtensionTypeEnum.Model)?.deleteModel(id) +const localDeleteModel = async (model: ModelFile) => + extensionManager + .get(ExtensionTypeEnum.Model) + ?.deleteModel(model) diff --git a/web/hooks/useModels.ts b/web/hooks/useModels.ts index 5a6f13e03..8333c35c3 100644 --- a/web/hooks/useModels.ts +++ b/web/hooks/useModels.ts @@ -5,6 +5,7 @@ import { Model, ModelEvent, ModelExtension, + ModelFile, events, } from '@janhq/core' @@ -63,12 +64,12 @@ const getLocalDefaultModel = async (): Promise => .get(ExtensionTypeEnum.Model) ?.getDefaultModel() -const getLocalConfiguredModels = async (): Promise => +const getLocalConfiguredModels = async (): Promise => extensionManager .get(ExtensionTypeEnum.Model) ?.getConfiguredModels() ?? [] -const getLocalDownloadedModels = async (): Promise => +const getLocalDownloadedModels = async (): Promise => extensionManager .get(ExtensionTypeEnum.Model) ?.getDownloadedModels() ?? [] diff --git a/web/hooks/useRecommendedModel.ts b/web/hooks/useRecommendedModel.ts index 21a9c69e7..ed56efa55 100644 --- a/web/hooks/useRecommendedModel.ts +++ b/web/hooks/useRecommendedModel.ts @@ -1,6 +1,6 @@ import { useCallback, useEffect, useState } from 'react' -import { Model, InferenceEngine } from '@janhq/core' +import { Model, InferenceEngine, ModelFile } from '@janhq/core' import { atom, useAtomValue } from 'jotai' @@ -24,12 +24,16 @@ export const LAST_USED_MODEL_ID = 'last-used-model-id' */ export default function useRecommendedModel() { const activeModel = useAtomValue(activeModelAtom) - const [sortedModels, setSortedModels] = useState([]) - const [recommendedModel, setRecommendedModel] = useState() + const [sortedModels, setSortedModels] = useState([]) + const [recommendedModel, setRecommendedModel] = useState< + ModelFile | undefined + >() const activeThread = useAtomValue(activeThreadAtom) const downloadedModels = useAtomValue(downloadedModelsAtom) - const getAndSortDownloadedModels = useCallback(async (): Promise => { + const getAndSortDownloadedModels = useCallback(async (): Promise< + ModelFile[] + > => { const models = downloadedModels.sort((a, b) => a.engine !== InferenceEngine.nitro && b.engine === InferenceEngine.nitro ? 1 diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts index 46bf07cd5..af30210ad 100644 --- a/web/hooks/useUpdateModelParameters.ts +++ b/web/hooks/useUpdateModelParameters.ts @@ -4,8 +4,6 @@ import { ConversationalExtension, ExtensionTypeEnum, InferenceEngine, - Model, - ModelExtension, Thread, ThreadAssistantInfo, } from '@janhq/core' @@ -17,14 +15,8 @@ import { extractModelLoadParams, } from '@/utils/modelParam' -import useRecommendedModel from './useRecommendedModel' - import { extensionManager } from '@/extension' -import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom' -import { - selectedModelAtom, - updateDownloadedModelAtom, -} from '@/helpers/atoms/Model.atom' +import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { ModelParams, getActiveThreadModelParamsAtom, @@ -34,16 +26,14 @@ import { export type UpdateModelParameter = { params?: ModelParams modelId?: string + modelPath?: string engine?: InferenceEngine } export default function useUpdateModelParameters() { const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) - const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom) + const [selectedModel] = useAtom(selectedModelAtom) const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) - const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom) - const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom) - const { recommendedModel, setRecommendedModel } = useRecommendedModel() const updateModelParameter = useCallback( async (thread: Thread, settings: UpdateModelParameter) => { @@ -83,50 +73,8 @@ export default function useUpdateModelParameters() { await extensionManager .get(ExtensionTypeEnum.Conversational) ?.saveThread(updatedThread) - - // Persists default settings to model file - // Do not overwrite ctx_len and max_tokens - if (preserveModelFeatureEnabled) { - const defaultContextLength = settingParams.ctx_len - const defaultMaxTokens = runtimeParams.max_tokens - - // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars - const { ctx_len, ...toSaveSettings } = settingParams - // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars - const { max_tokens, ...toSaveParams } = runtimeParams - - const updatedModel = { - id: settings.modelId ?? selectedModel?.id, - parameters: { - ...toSaveSettings, - }, - settings: { - ...toSaveParams, - }, - metadata: { - default_ctx_len: defaultContextLength, - default_max_tokens: defaultMaxTokens, - }, - } as Partial - - const model = await extensionManager - .get(ExtensionTypeEnum.Model) - ?.updateModelInfo(updatedModel) - if (model) updateDownloadedModel(model) - if (selectedModel?.id === model?.id) setSelectedModel(model) - if (recommendedModel?.id === model?.id) setRecommendedModel(model) - } }, - [ - activeModelParams, - selectedModel, - setThreadModelParams, - preserveModelFeatureEnabled, - updateDownloadedModel, - setSelectedModel, - recommendedModel, - setRecommendedModel, - ] + [activeModelParams, selectedModel, setThreadModelParams] ) const processStopWords = (params: ModelParams): ModelParams => { diff --git a/web/screens/Hub/ModelList/ModelHeader/index.tsx b/web/screens/Hub/ModelList/ModelHeader/index.tsx index b20977aff..44a3fd278 100644 --- a/web/screens/Hub/ModelList/ModelHeader/index.tsx +++ b/web/screens/Hub/ModelList/ModelHeader/index.tsx @@ -1,6 +1,6 @@ import { useCallback } from 'react' -import { Model } from '@janhq/core' +import { ModelFile } from '@janhq/core' import { Button, Badge, Tooltip } from '@janhq/joi' import { useAtomValue, useSetAtom } from 'jotai' @@ -38,7 +38,7 @@ import { } from '@/helpers/atoms/SystemBar.atom' type Props = { - model: Model + model: ModelFile onClick: () => void open: string } diff --git a/web/screens/Hub/ModelList/ModelItem/index.tsx b/web/screens/Hub/ModelList/ModelItem/index.tsx index c9b2f1329..ec9d885a1 100644 --- a/web/screens/Hub/ModelList/ModelItem/index.tsx +++ b/web/screens/Hub/ModelList/ModelItem/index.tsx @@ -1,6 +1,6 @@ import { useState } from 'react' -import { Model } from '@janhq/core' +import { ModelFile } from '@janhq/core' import { Badge } from '@janhq/joi' import { twMerge } from 'tailwind-merge' @@ -12,7 +12,7 @@ import ModelItemHeader from '@/screens/Hub/ModelList/ModelHeader' import { toGibibytes } from '@/utils/converter' type Props = { - model: Model + model: ModelFile } const ModelItem: React.FC = ({ model }) => { diff --git a/web/screens/Hub/ModelList/index.tsx b/web/screens/Hub/ModelList/index.tsx index aea67b4e3..8fc30d541 100644 --- a/web/screens/Hub/ModelList/index.tsx +++ b/web/screens/Hub/ModelList/index.tsx @@ -1,6 +1,6 @@ import { useMemo } from 'react' -import { Model } from '@janhq/core' +import { ModelFile } from '@janhq/core' import { useAtomValue } from 'jotai' @@ -9,16 +9,16 @@ import ModelItem from '@/screens/Hub/ModelList/ModelItem' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' type Props = { - models: Model[] + models: ModelFile[] } const ModelList = ({ models }: Props) => { const downloadedModels = useAtomValue(downloadedModelsAtom) - const sortedModels: Model[] = useMemo(() => { - const featuredModels: Model[] = [] - const remoteModels: Model[] = [] - const localModels: Model[] = [] - const remainingModels: Model[] = [] + const sortedModels: ModelFile[] = useMemo(() => { + const featuredModels: ModelFile[] = [] + const remoteModels: ModelFile[] = [] + const localModels: ModelFile[] = [] + const remainingModels: ModelFile[] = [] models.forEach((m) => { if (m.metadata?.tags?.includes('Featured')) { featuredModels.push(m) diff --git a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx index 951a11d59..c3f09f171 100644 --- a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx +++ b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx @@ -53,7 +53,7 @@ const ModelDownloadRow: React.FC = ({ const { requestCreateNewThread } = useCreateNewThread() const setMainViewState = useSetAtom(mainViewStateAtom) const assistants = useAtomValue(assistantsAtom) - const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null + const downloadedModel = downloadedModels.find((md) => md.id === fileName) const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom) const defaultModel = useAtomValue(defaultModelAtom) @@ -100,12 +100,12 @@ const ModelDownloadRow: React.FC = ({ alert('No assistant available') return } - await requestCreateNewThread(assistants[0], model) + await requestCreateNewThread(assistants[0], downloadedModel) setMainViewState(MainViewState.Thread) setHfImportingStage('NONE') }, [ assistants, - model, + downloadedModel, requestCreateNewThread, setMainViewState, setHfImportingStage, @@ -139,7 +139,7 @@ const ModelDownloadRow: React.FC = ({ - {isDownloaded ? ( + {downloadedModel ? ( diff --git a/web/screens/Settings/Advanced/index.test.tsx b/web/screens/Settings/Advanced/index.test.tsx new file mode 100644 index 000000000..10ea810b1 --- /dev/null +++ b/web/screens/Settings/Advanced/index.test.tsx @@ -0,0 +1,154 @@ +import React from 'react' +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import '@testing-library/jest-dom' +import Advanced from '.' + +class ResizeObserverMock { + observe() {} + unobserve() {} + disconnect() {} +} + +global.ResizeObserver = ResizeObserverMock +// @ts-ignore +global.window.core = { + api: { + getAppConfigurations: () => jest.fn(), + updateAppConfiguration: () => jest.fn(), + relaunch: () => jest.fn(), + }, +} + +const setSettingsMock = jest.fn() + +// Mock useSettings hook +jest.mock('@/hooks/useSettings', () => ({ + __esModule: true, + useSettings: () => ({ + readSettings: () => ({ + run_mode: 'gpu', + experimental: false, + proxy: false, + gpus: [{ name: 'gpu-1' }, { name: 'gpu-2' }], + gpus_in_use: ['0'], + quick_ask: false, + }), + setSettings: setSettingsMock, + }), +})) + +import * as toast from '@/containers/Toast' + +jest.mock('@/containers/Toast') + +jest.mock('@janhq/core', () => ({ + __esModule: true, + ...jest.requireActual('@janhq/core'), + fs: { + rm: jest.fn(), + }, +})) + +// Simulate a full advanced settings screen +// @ts-ignore +global.isMac = false +// @ts-ignore +global.isWindows = true + +describe('Advanced', () => { + it('renders the component', async () => { + render() + await waitFor(() => { + expect(screen.getByText('Experimental Mode')).toBeInTheDocument() + expect(screen.getByText('HTTPS Proxy')).toBeInTheDocument() + expect(screen.getByText('Ignore SSL certificates')).toBeInTheDocument() + expect(screen.getByText('Jan Data Folder')).toBeInTheDocument() + expect(screen.getByText('Reset to Factory Settings')).toBeInTheDocument() + }) + }) + + it('updates Experimental enabled', async () => { + render() + let experimentalToggle + await waitFor(() => { + experimentalToggle = screen.getByTestId(/experimental-switch/i) + fireEvent.click(experimentalToggle!) + }) + expect(experimentalToggle).toBeChecked() + }) + + it('updates Experimental disabled', async () => { + render() + + let experimentalToggle + await waitFor(() => { + experimentalToggle = screen.getByTestId(/experimental-switch/i) + fireEvent.click(experimentalToggle!) + }) + expect(experimentalToggle).not.toBeChecked() + }) + + it('clears logs', async () => { + const jestMock = jest.fn() + jest.spyOn(toast, 'toaster').mockImplementation(jestMock) + + render() + let clearLogsButton + await waitFor(() => { + clearLogsButton = screen.getByTestId(/clear-logs/i) + fireEvent.click(clearLogsButton) + }) + expect(clearLogsButton).toBeInTheDocument() + expect(jestMock).toHaveBeenCalled() + }) + + it('toggles proxy enabled', async () => { + render() + let proxyToggle + await waitFor(() => { + expect(screen.getByText('HTTPS Proxy')).toBeInTheDocument() + proxyToggle = screen.getByTestId(/proxy-switch/i) + fireEvent.click(proxyToggle) + }) + expect(proxyToggle).toBeChecked() + }) + + it('updates proxy settings', async () => { + render() + let proxyInput + await waitFor(() => { + const proxyToggle = screen.getByTestId(/proxy-switch/i) + fireEvent.click(proxyToggle) + proxyInput = screen.getByTestId(/proxy-input/i) + fireEvent.change(proxyInput, { target: { value: 'http://proxy.com' } }) + }) + expect(proxyInput).toHaveValue('http://proxy.com') + }) + + it('toggles ignore SSL certificates', async () => { + render() + let ignoreSslToggle + await waitFor(() => { + expect(screen.getByText('Ignore SSL certificates')).toBeInTheDocument() + ignoreSslToggle = screen.getByTestId(/ignore-ssl-switch/i) + fireEvent.click(ignoreSslToggle) + }) + expect(ignoreSslToggle).toBeChecked() + }) + + it('renders DataFolder component', async () => { + render() + await waitFor(() => { + expect(screen.getByText('Jan Data Folder')).toBeInTheDocument() + expect(screen.getByTestId(/jan-data-folder-input/i)).toBeInTheDocument() + }) + }) + + it('renders FactoryReset component', async () => { + render() + await waitFor(() => { + expect(screen.getByText('Reset to Factory Settings')).toBeInTheDocument() + expect(screen.getByTestId(/reset-button/i)).toBeInTheDocument() + }) + }) +}) diff --git a/web/screens/Settings/Advanced/index.tsx b/web/screens/Settings/Advanced/index.tsx index f132f81e7..1384f5688 100644 --- a/web/screens/Settings/Advanced/index.tsx +++ b/web/screens/Settings/Advanced/index.tsx @@ -43,19 +43,10 @@ type GPU = { name: string } -const test = [ - { - id: 'test a', - vram: 2, - name: 'nvidia A', - }, - { - id: 'test', - vram: 2, - name: 'nvidia B', - }, -] - +/** + * Advanced Settings Screen + * @returns + */ const Advanced = () => { const [experimentalEnabled, setExperimentalEnabled] = useAtom( experimentalFeatureEnabledAtom @@ -69,7 +60,7 @@ const Advanced = () => { const [partialProxy, setPartialProxy] = useState(proxy) const [gpuEnabled, setGpuEnabled] = useState(false) - const [gpuList, setGpuList] = useState(test) + const [gpuList, setGpuList] = useState([]) const [gpusInUse, setGpusInUse] = useState([]) const [dropdownOptions, setDropdownOptions] = useState( null @@ -87,6 +78,9 @@ const Advanced = () => { return y['name'] }) + /** + * Handle proxy change + */ const onProxyChange = useCallback( (event: ChangeEvent) => { const value = event.target.value || '' @@ -100,6 +94,12 @@ const Advanced = () => { [setPartialProxy, setProxy] ) + /** + * Update Quick Ask Enabled + * @param e + * @param relaunch + * @returns void + */ const updateQuickAskEnabled = async ( e: boolean, relaunch: boolean = true @@ -111,6 +111,12 @@ const Advanced = () => { if (relaunch) window.core?.api?.relaunch() } + /** + * Update Vulkan Enabled + * @param e + * @param relaunch + * @returns void + */ const updateVulkanEnabled = async (e: boolean, relaunch: boolean = true) => { toaster({ title: 'Reload', @@ -123,11 +129,19 @@ const Advanced = () => { if (relaunch) window.location.reload() } + /** + * Update Experimental Enabled + * @param e + * @returns + */ const updateExperimentalEnabled = async ( e: ChangeEvent ) => { setExperimentalEnabled(e.target.checked) - if (e) return + + // If it checked, we don't need to do anything else + // Otherwise have to reset other settings + if (e.target.checked) return // It affects other settings, so we need to reset them const isRelaunch = quickAskEnabled || vulkanEnabled @@ -136,6 +150,9 @@ const Advanced = () => { if (isRelaunch) window.core?.api?.relaunch() } + /** + * useEffect to set GPU enabled if possible + */ useEffect(() => { const setUseGpuIfPossible = async () => { const settings = await readSettings() @@ -149,6 +166,10 @@ const Advanced = () => { setUseGpuIfPossible() }, [readSettings, setGpuList, setGpuEnabled, setGpusInUse, setVulkanEnabled]) + /** + * Clear logs + * @returns + */ const clearLogs = async () => { try { await fs.rm(`file://logs`) @@ -163,6 +184,11 @@ const Advanced = () => { }) } + /** + * Handle GPU Change + * @param gpuId + * @returns + */ const handleGPUChange = (gpuId: string) => { let updatedGpusInUse = [...gpusInUse] if (updatedGpusInUse.includes(gpuId)) { @@ -188,6 +214,9 @@ const Advanced = () => { const gpuSelectionPlaceHolder = gpuList.length > 0 ? 'Select GPU' : "You don't have any compatible GPU" + /** + * Handle click outside + */ useClickOutside(() => setOpen(false), null, [dropdownOptions, toggle]) return ( @@ -204,6 +233,7 @@ const Advanced = () => {

@@ -401,11 +431,13 @@ const Advanced = () => {
setProxyEnabled(!proxyEnabled)} />
:@:'} value={partialProxy} onChange={onProxyChange} @@ -428,6 +460,7 @@ const Advanced = () => {

setIgnoreSSL(e.target.checked)} /> @@ -448,6 +481,7 @@ const Advanced = () => {

{ toaster({ @@ -471,7 +505,11 @@ const Advanced = () => { Clear all logs from Jan app.

- From c62b6e984282003d14160ce1b222c66fa4b79038 Mon Sep 17 00:00:00 2001 From: Faisal Amir Date: Tue, 17 Sep 2024 22:13:18 +0700 Subject: [PATCH 07/37] fix: small leftover issues with new starter screen (#3661) * fix: fix duplicate render progress component * fix: minor ui issue * chore: add manual recommend model * chore: make button create thread invisible * chore: fix conflict * chore: remove selector create thread icon * test: added unit test thread screen --- electron/tests/e2e/thread.e2e.spec.ts | 29 ++++--- web/containers/Layout/RibbonPanel/index.tsx | 16 ++-- web/containers/Layout/TopPanel/index.tsx | 5 +- web/helpers/atoms/Thread.atom.ts | 3 + web/hooks/useStarterScreen.ts | 7 +- .../ChatBody/OnDeviceStarterScreen/index.tsx | 78 +++++++++++-------- web/screens/Thread/index.test.tsx | 35 +++++++++ 7 files changed, 109 insertions(+), 64 deletions(-) create mode 100644 web/screens/Thread/index.test.tsx diff --git a/electron/tests/e2e/thread.e2e.spec.ts b/electron/tests/e2e/thread.e2e.spec.ts index c13e91119..5d7328053 100644 --- a/electron/tests/e2e/thread.e2e.spec.ts +++ b/electron/tests/e2e/thread.e2e.spec.ts @@ -1,32 +1,29 @@ import { expect } from '@playwright/test' import { page, test, TIMEOUT } from '../config/fixtures' -test('Select GPT model from Hub and Chat with Invalid API Key', async ({ hubPage }) => { +test('Select GPT model from Hub and Chat with Invalid API Key', async ({ + hubPage, +}) => { await hubPage.navigateByMenu() await hubPage.verifyContainerVisible() // Select the first GPT model await page .locator('[data-testid^="use-model-btn"][data-testid*="gpt"]') - .first().click() - - // Attempt to create thread and chat in Thread page - await page - .getByTestId('btn-create-thread') + .first() .click() - await page - .getByTestId('txt-input-chat') - .fill('dummy value') + await page.getByTestId('txt-input-chat').fill('dummy value') - await page - .getByTestId('btn-send-chat') - .click() + await page.getByTestId('btn-send-chat').click() - await page.waitForFunction(() => { - const loaders = document.querySelectorAll('[data-testid$="loader"]'); - return !loaders.length; - }, { timeout: TIMEOUT }); + await page.waitForFunction( + () => { + const loaders = document.querySelectorAll('[data-testid$="loader"]') + return !loaders.length + }, + { timeout: TIMEOUT } + ) const APIKeyError = page.getByTestId('invalid-API-key-error') await expect(APIKeyError).toBeVisible({ diff --git a/web/containers/Layout/RibbonPanel/index.tsx b/web/containers/Layout/RibbonPanel/index.tsx index 6bed2b424..7613584e0 100644 --- a/web/containers/Layout/RibbonPanel/index.tsx +++ b/web/containers/Layout/RibbonPanel/index.tsx @@ -12,17 +12,18 @@ import { twMerge } from 'tailwind-merge' import { MainViewState } from '@/constants/screens' -import { localEngines } from '@/utils/modelEngine' - import { mainViewStateAtom, showLeftPanelAtom } from '@/helpers/atoms/App.atom' import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom' import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' + import { reduceTransparentAtom, selectedSettingAtom, } from '@/helpers/atoms/Setting.atom' -import { threadsAtom } from '@/helpers/atoms/Thread.atom' +import { + isDownloadALocalModelAtom, + threadsAtom, +} from '@/helpers/atoms/Thread.atom' export default function RibbonPanel() { const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom) @@ -32,8 +33,9 @@ export default function RibbonPanel() { const matches = useMediaQuery('(max-width: 880px)') const reduceTransparent = useAtomValue(reduceTransparentAtom) const setSelectedSetting = useSetAtom(selectedSettingAtom) - const downloadedModels = useAtomValue(downloadedModelsAtom) + const threads = useAtomValue(threadsAtom) + const isDownloadALocalModel = useAtomValue(isDownloadALocalModelAtom) const onMenuClick = (state: MainViewState) => { if (mainViewState === state) return @@ -43,10 +45,6 @@ export default function RibbonPanel() { setEditMessage('') } - const isDownloadALocalModel = downloadedModels.some((x) => - localEngines.includes(x.engine) - ) - const RibbonNavMenus = [ { name: 'Thread', diff --git a/web/containers/Layout/TopPanel/index.tsx b/web/containers/Layout/TopPanel/index.tsx index 213f7dfa9..aff616973 100644 --- a/web/containers/Layout/TopPanel/index.tsx +++ b/web/containers/Layout/TopPanel/index.tsx @@ -23,6 +23,7 @@ import { toaster } from '@/containers/Toast' import { MainViewState } from '@/constants/screens' import { useCreateNewThread } from '@/hooks/useCreateNewThread' +import { useStarterScreen } from '@/hooks/useStarterScreen' import { mainViewStateAtom, @@ -58,6 +59,8 @@ const TopPanel = () => { requestCreateNewThread(assistants[0]) } + const { isShowStarterScreen } = useStarterScreen() + return (
{ )} )} - {mainViewState === MainViewState.Thread && ( + {mainViewState === MainViewState.Thread && !isShowStarterScreen && ( - + {toGibibytes(featModel.metadata.size)}
@@ -257,7 +271,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => { ) })} -
+

Cloud Models

@@ -268,7 +282,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => { return (
{row.map((remoteEngine) => { const engineLogo = getLogoEngine( @@ -298,7 +312,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => { /> )} -

+

{getTitleByEngine( remoteEngine as InferenceEngine )} diff --git a/web/screens/Thread/index.test.tsx b/web/screens/Thread/index.test.tsx new file mode 100644 index 000000000..01af0ffc5 --- /dev/null +++ b/web/screens/Thread/index.test.tsx @@ -0,0 +1,35 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import ThreadScreen from './index' +import { useStarterScreen } from '../../hooks/useStarterScreen' +import '@testing-library/jest-dom' + +global.ResizeObserver = class { + observe() {} + unobserve() {} + disconnect() {} +} +// Mock the useStarterScreen hook +jest.mock('@/hooks/useStarterScreen') + +describe('ThreadScreen', () => { + it('renders OnDeviceStarterScreen when isShowStarterScreen is true', () => { + ;(useStarterScreen as jest.Mock).mockReturnValue({ + isShowStarterScreen: true, + extensionHasSettings: false, + }) + + const { getByText } = render() + expect(getByText('Select a model to start')).toBeInTheDocument() + }) + + it('renders Thread panels when isShowStarterScreen is false', () => { + ;(useStarterScreen as jest.Mock).mockReturnValue({ + isShowStarterScreen: false, + extensionHasSettings: false, + }) + + const { getByText } = render() + expect(getByText('Welcome!')).toBeInTheDocument() + }) +}) From 3949515c8a68dee16e2209b513bcf239d2b5343a Mon Sep 17 00:00:00 2001 From: 0xSage Date: Wed, 18 Sep 2024 17:02:41 +0800 Subject: [PATCH 08/37] chore: copy nits --- .../BottomPanel/SystemMonitor/TableActiveModel/index.tsx | 2 +- web/screens/LocalServer/LocalServerLeftPanel/index.tsx | 2 +- web/screens/Settings/Advanced/FactoryReset/index.tsx | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx index c9d86e5e8..e68f843a9 100644 --- a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx +++ b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx @@ -79,7 +79,7 @@ const TableActiveModel = () => { ) : ( - No on-device model running + No models are loaded into memory )} diff --git a/web/screens/LocalServer/LocalServerLeftPanel/index.tsx b/web/screens/LocalServer/LocalServerLeftPanel/index.tsx index f66945929..16aa75af5 100644 --- a/web/screens/LocalServer/LocalServerLeftPanel/index.tsx +++ b/web/screens/LocalServer/LocalServerLeftPanel/index.tsx @@ -130,7 +130,7 @@ const LocalServerLeftPanel = () => { {serverEnabled && ( )} diff --git a/web/screens/Settings/Advanced/FactoryReset/index.tsx b/web/screens/Settings/Advanced/FactoryReset/index.tsx index 3bbce39ef..181b0bd4b 100644 --- a/web/screens/Settings/Advanced/FactoryReset/index.tsx +++ b/web/screens/Settings/Advanced/FactoryReset/index.tsx @@ -17,8 +17,7 @@ const FactoryReset = () => {

- Reset the application to its initial state, deleting all your usage - data, including conversation history. This action is irreversible and + Restore app to initial state, erasing all models and chat history. This action is irreversible and recommended only if the application is in a corrupted state.

From 062af9bcda43256b6cc14d6c5dd0cbd927a43c7b Mon Sep 17 00:00:00 2001 From: 0xSage Date: Wed, 18 Sep 2024 17:42:35 +0800 Subject: [PATCH 09/37] nits --- .../Settings/Advanced/FactoryReset/ModalConfirmReset.tsx | 5 ++--- web/screens/Settings/Advanced/FactoryReset/index.tsx | 4 ++-- web/screens/Settings/CancelModelImportModal/index.tsx | 3 +-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx b/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx index 8173574a9..268192627 100644 --- a/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx +++ b/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx @@ -30,9 +30,8 @@ const ModalConfirmReset = () => { content={

- It will reset the application to its original state, deleting all - your usage data, including model customizations and conversation - history. This action is irreversible. + Restore app to initial state, erasing all models and chat history. This + action is irreversible and recommended only if the application is corrupted.

diff --git a/web/screens/Settings/Advanced/FactoryReset/index.tsx b/web/screens/Settings/Advanced/FactoryReset/index.tsx index e79bfe54c..fb789e5b3 100644 --- a/web/screens/Settings/Advanced/FactoryReset/index.tsx +++ b/web/screens/Settings/Advanced/FactoryReset/index.tsx @@ -17,8 +17,8 @@ const FactoryReset = () => {

- Restore app to initial state, erasing all models and chat history. This action is irreversible and - recommended only if the application is in a corrupted state. + Restore app to initial state, erasing all models and chat history. This + action is irreversible and recommended only if the application is corrupted.

{displayDate(props.created)} From 8fe376340a4de0c9c1c998bdce32350ff6b3f23c Mon Sep 17 00:00:00 2001 From: Faisal Amir Date: Thu, 19 Sep 2024 10:06:27 +0700 Subject: [PATCH 11/37] chore: fix linter issue CI --- web/screens/Settings/CoreExtensions/ExtensionItem.tsx | 4 ++-- .../Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/web/screens/Settings/CoreExtensions/ExtensionItem.tsx b/web/screens/Settings/CoreExtensions/ExtensionItem.tsx index ec72f5f43..497b8ac4a 100644 --- a/web/screens/Settings/CoreExtensions/ExtensionItem.tsx +++ b/web/screens/Settings/CoreExtensions/ExtensionItem.tsx @@ -32,8 +32,8 @@ const ExtensionItem: React.FC = ({ item }) => { ) const progress = isInstalling - ? installingExtensions.find((e) => e.extensionId === item.name) - ?.percentage ?? -1 + ? (installingExtensions.find((e) => e.extensionId === item.name) + ?.percentage ?? -1) : -1 useEffect(() => { diff --git a/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx index d7d52a093..abbe6db43 100644 --- a/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx @@ -178,7 +178,7 @@ const SimpleTextMessage: React.FC = (props) => { > {isUser ? props.role - : activeThread?.assistants[0].assistant_name ?? props.role} + : (activeThread?.assistants[0].assistant_name ?? props.role)}

{displayDate(props.created)} From ba3c07eba8973b184cc0701f90f3c955a8a4b894 Mon Sep 17 00:00:00 2001 From: Faisal Amir Date: Thu, 19 Sep 2024 10:10:30 +0700 Subject: [PATCH 12/37] feat: textarea auto resize (#3695) * feat: improve textarea user experience with autoresize * chore: remove log * chore: update test * chore: update test and cleanup logic useEffect --- joi/src/core/TextArea/TextArea.test.tsx | 39 ++++++++++++++++++- joi/src/core/TextArea/index.tsx | 32 ++++++++++++--- web/containers/ModelConfigInput/index.tsx | 2 +- web/screens/Thread/ThreadRightPanel/index.tsx | 2 +- 4 files changed, 66 insertions(+), 9 deletions(-) diff --git a/joi/src/core/TextArea/TextArea.test.tsx b/joi/src/core/TextArea/TextArea.test.tsx index 8bc64010f..e29eed5d0 100644 --- a/joi/src/core/TextArea/TextArea.test.tsx +++ b/joi/src/core/TextArea/TextArea.test.tsx @@ -1,9 +1,8 @@ import React from 'react' -import { render, screen } from '@testing-library/react' +import { render, screen, act } from '@testing-library/react' import '@testing-library/jest-dom' import { TextArea } from './index' -// Mock the styles import jest.mock('./styles.scss', () => ({})) describe('@joi/core/TextArea', () => { @@ -31,4 +30,40 @@ describe('@joi/core/TextArea', () => { const textareaElement = screen.getByTestId('custom-textarea') expect(textareaElement).toHaveAttribute('rows', '5') }) + + it('should auto resize the textarea based on minResize', () => { + render(