From 98bef7b7cffa811a67945d8c8f4659862c15026c Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 17 Sep 2024 08:34:58 +0700
Subject: [PATCH 01/37] test: add model parameter validation rules and
persistence tests (#3618)
* test: add model parameter validation rules and persistence tests
* chore: fix CI cov step
* fix: invalid model settings should fallback to origin value
* test: support fallback integer settings
---
.../src/node/index.ts | 32 +-
.../tensorrt-llm-extension/src/node/index.ts | 4 +-
web/containers/Providers/EventHandler.tsx | 4 +-
web/containers/SliderRightPanel/index.tsx | 28 +-
web/hooks/useSendChatMessage.ts | 9 +-
web/hooks/useUpdateModelParameters.test.ts | 314 ++++++++++++++++++
web/hooks/useUpdateModelParameters.ts | 16 +-
.../LocalServerRightPanel/index.tsx | 13 +-
web/screens/Thread/ThreadRightPanel/index.tsx | 26 +-
web/utils/modelParam.test.ts | 183 ++++++++++
web/utils/modelParam.ts | 106 +++++-
11 files changed, 681 insertions(+), 54 deletions(-)
create mode 100644 web/hooks/useUpdateModelParameters.test.ts
create mode 100644 web/utils/modelParam.test.ts
diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts
index edc2d013d..3a969ad5e 100644
--- a/extensions/inference-nitro-extension/src/node/index.ts
+++ b/extensions/inference-nitro-extension/src/node/index.ts
@@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise {
if (!settings?.ngl) {
settings.ngl = 100
}
- log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`)
+ log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`)
return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, {
method: 'POST',
headers: {
@@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise {
})
.then((res) => {
log(
- `[CORTEX]::Debug: Load model success with response ${JSON.stringify(
+ `[CORTEX]:: Load model success with response ${JSON.stringify(
res
)}`
)
@@ -260,7 +260,7 @@ function loadLLMModel(settings: any): Promise {
async function validateModelStatus(modelId: string): Promise {
// Send a GET request to the validation URL.
// Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries.
- log(`[CORTEX]::Debug: Validating model ${modelId}`)
+ log(`[CORTEX]:: Validating model ${modelId}`)
return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, {
method: 'POST',
body: JSON.stringify({
@@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise {
retryDelay: 300,
}).then(async (res: Response) => {
log(
- `[CORTEX]::Debug: Validate model state with response ${JSON.stringify(
+ `[CORTEX]:: Validate model state with response ${JSON.stringify(
res.status
)}`
)
@@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise {
// Otherwise, return an object with an error message.
if (body.model_loaded) {
log(
- `[CORTEX]::Debug: Validate model state success with response ${JSON.stringify(
+ `[CORTEX]:: Validate model state success with response ${JSON.stringify(
body
)}`
)
@@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise {
}
const errorBody = await res.text()
log(
- `[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
+ `[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
res.statusText
)}`
)
@@ -310,7 +310,7 @@ async function validateModelStatus(modelId: string): Promise {
async function killSubprocess(): Promise {
const controller = new AbortController()
setTimeout(() => controller.abort(), 5000)
- log(`[CORTEX]::Debug: Request to kill cortex`)
+ log(`[CORTEX]:: Request to kill cortex`)
const killRequest = () => {
return fetch(NITRO_HTTP_KILL_URL, {
@@ -321,17 +321,17 @@ async function killSubprocess(): Promise {
.then(() =>
tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
)
- .then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
+ .then(() => log(`[CORTEX]:: cortex process is terminated`))
.catch((err) => {
log(
- `[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
+ `[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
)
throw 'PORT_NOT_AVAILABLE'
})
}
if (subprocess?.pid && process.platform !== 'darwin') {
- log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
+ log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
const pid = subprocess.pid
return new Promise((resolve, reject) => {
terminate(pid, function (err) {
@@ -341,7 +341,7 @@ async function killSubprocess(): Promise {
} else {
tcpPortUsed
.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
- .then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
+ .then(() => log(`[CORTEX]:: cortex process is terminated`))
.then(() => resolve())
.catch(() => {
log(
@@ -362,7 +362,7 @@ async function killSubprocess(): Promise {
* @returns A promise that resolves when the Nitro subprocess is started.
*/
function spawnNitroProcess(systemInfo?: SystemInformation): Promise {
- log(`[CORTEX]::Debug: Spawning cortex subprocess...`)
+ log(`[CORTEX]:: Spawning cortex subprocess...`)
return new Promise(async (resolve, reject) => {
let executableOptions = executableNitroFile(
@@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise {
const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
// Execute the binary
log(
- `[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
+ `[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
)
log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`)
@@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise {
// Handle subprocess output
subprocess.stdout.on('data', (data: any) => {
- log(`[CORTEX]::Debug: ${data}`)
+ log(`[CORTEX]:: ${data}`)
})
subprocess.stderr.on('data', (data: any) => {
@@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise {
})
subprocess.on('close', (code: any) => {
- log(`[CORTEX]::Debug: cortex exited with code: ${code}`)
+ log(`[CORTEX]:: cortex exited with code: ${code}`)
subprocess = undefined
reject(`child process exited with code ${code}`)
})
@@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise {
tcpPortUsed
.waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000)
.then(() => {
- log(`[CORTEX]::Debug: cortex is ready`)
+ log(`[CORTEX]:: cortex is ready`)
resolve()
})
})
diff --git a/extensions/tensorrt-llm-extension/src/node/index.ts b/extensions/tensorrt-llm-extension/src/node/index.ts
index c8bc48459..77003389f 100644
--- a/extensions/tensorrt-llm-extension/src/node/index.ts
+++ b/extensions/tensorrt-llm-extension/src/node/index.ts
@@ -97,7 +97,7 @@ function unloadModel(): Promise {
}
if (subprocess?.pid) {
- log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
+ log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
const pid = subprocess.pid
return new Promise((resolve, reject) => {
terminate(pid, function (err) {
@@ -107,7 +107,7 @@ function unloadModel(): Promise {
return tcpPortUsed
.waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000)
.then(() => resolve())
- .then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
+ .then(() => log(`[CORTEX]:: cortex process is terminated`))
.catch(() => {
killRequest()
})
diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx
index e4c96aeb7..4809ce83e 100644
--- a/web/containers/Providers/EventHandler.tsx
+++ b/web/containers/Providers/EventHandler.tsx
@@ -20,7 +20,7 @@ import { ulid } from 'ulidx'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
-import { toRuntimeParams } from '@/utils/modelParam'
+import { extractInferenceParams } from '@/utils/modelParam'
import { extensionManager } from '@/extension'
import {
@@ -256,7 +256,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
},
]
- const runtimeParams = toRuntimeParams(activeModelParamsRef.current)
+ const runtimeParams = extractInferenceParams(activeModelParamsRef.current)
const messageRequest: MessageRequest = {
id: msgId,
diff --git a/web/containers/SliderRightPanel/index.tsx b/web/containers/SliderRightPanel/index.tsx
index df415ffb5..c00d9f002 100644
--- a/web/containers/SliderRightPanel/index.tsx
+++ b/web/containers/SliderRightPanel/index.tsx
@@ -87,26 +87,28 @@ const SliderRightPanel = ({
onValueChanged?.(Number(min))
setVal(min.toString())
setShowTooltip({ max: false, min: true })
+ } else {
+ setVal(Number(e.target.value).toString()) // There is a case .5 but not 0.5
}
}}
onChange={(e) => {
- // Should not accept invalid value or NaN
- // E.g. anything changes that trigger onValueChanged
- // Which is incorrect
- if (Number(e.target.value) > Number(max)) {
- setVal(max.toString())
- } else if (
- Number(e.target.value) < Number(min) ||
- !e.target.value.length
- ) {
- setVal(min.toString())
- } else if (Number.isNaN(Number(e.target.value))) return
-
- onValueChanged?.(Number(e.target.value))
// TODO: How to support negative number input?
+ // Passthru since it validates again onBlur
if (/^\d*\.?\d*$/.test(e.target.value)) {
setVal(e.target.value)
}
+
+ // Should not accept invalid value or NaN
+ // E.g. anything changes that trigger onValueChanged
+ // Which is incorrect
+ if (
+ Number(e.target.value) > Number(max) ||
+ Number(e.target.value) < Number(min) ||
+ Number.isNaN(Number(e.target.value))
+ ) {
+ return
+ }
+ onValueChanged?.(Number(e.target.value))
}}
/>
}
diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts
index 8c6013505..1dbd5b45e 100644
--- a/web/hooks/useSendChatMessage.ts
+++ b/web/hooks/useSendChatMessage.ts
@@ -23,7 +23,10 @@ import {
import { Stack } from '@/utils/Stack'
import { compressImage, getBase64 } from '@/utils/base64'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
-import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
+import {
+ extractInferenceParams,
+ extractModelLoadParams,
+} from '@/utils/modelParam'
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
@@ -189,8 +192,8 @@ export default function useSendChatMessage() {
if (engineParamsUpdate) setReloadModel(true)
- const runtimeParams = toRuntimeParams(activeModelParams)
- const settingParams = toSettingParams(activeModelParams)
+ const runtimeParams = extractInferenceParams(activeModelParams)
+ const settingParams = extractModelLoadParams(activeModelParams)
const prompt = message.trim()
diff --git a/web/hooks/useUpdateModelParameters.test.ts b/web/hooks/useUpdateModelParameters.test.ts
new file mode 100644
index 000000000..bc60aa631
--- /dev/null
+++ b/web/hooks/useUpdateModelParameters.test.ts
@@ -0,0 +1,314 @@
+import { renderHook, act } from '@testing-library/react'
+// Mock dependencies
+jest.mock('ulidx')
+jest.mock('@/extension')
+
+import useUpdateModelParameters from './useUpdateModelParameters'
+import { extensionManager } from '@/extension'
+
+// Mock data
+let model: any = {
+ id: 'model-1',
+ engine: 'nitro',
+}
+
+let extension: any = {
+ saveThread: jest.fn(),
+}
+
+const mockThread: any = {
+ id: 'thread-1',
+ assistants: [
+ {
+ model: {
+ parameters: {},
+ settings: {},
+ },
+ },
+ ],
+ object: 'thread',
+ title: 'New Thread',
+ created: 0,
+ updated: 0,
+}
+
+describe('useUpdateModelParameters', () => {
+ beforeAll(() => {
+ jest.clearAllMocks()
+ jest.mock('./useRecommendedModel', () => ({
+ useRecommendedModel: () => ({
+ recommendedModel: model,
+ setRecommendedModel: jest.fn(),
+ downloadedModels: [],
+ }),
+ }))
+ })
+
+ it('should update model parameters and save thread when params are valid', async () => {
+ const mockValidParameters: any = {
+ params: {
+ // Inference
+ stop: ['', ''],
+ temperature: 0.5,
+ token_limit: 1000,
+ top_k: 0.7,
+ top_p: 0.1,
+ stream: true,
+ max_tokens: 1000,
+ frequency_penalty: 0.3,
+ presence_penalty: 0.2,
+
+ // Load model
+ ctx_len: 1024,
+ ngl: 12,
+ embedding: true,
+ n_parallel: 2,
+ cpu_threads: 4,
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ vision_model: 'vision',
+ text_model: 'text',
+ },
+ modelId: 'model-1',
+ engine: 'nitro',
+ }
+
+ // Spy functions
+ jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
+ jest.spyOn(extension, 'saveThread').mockReturnValue({})
+
+ const { result } = renderHook(() => useUpdateModelParameters())
+
+ await act(async () => {
+ await result.current.updateModelParameter(mockThread, mockValidParameters)
+ })
+
+ // Check if the model parameters are valid before persisting
+ expect(extension.saveThread).toHaveBeenCalledWith({
+ assistants: [
+ {
+ model: {
+ parameters: {
+ stop: ['', ''],
+ temperature: 0.5,
+ token_limit: 1000,
+ top_k: 0.7,
+ top_p: 0.1,
+ stream: true,
+ max_tokens: 1000,
+ frequency_penalty: 0.3,
+ presence_penalty: 0.2,
+ },
+ settings: {
+ ctx_len: 1024,
+ ngl: 12,
+ embedding: true,
+ n_parallel: 2,
+ cpu_threads: 4,
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ },
+ },
+ },
+ ],
+ created: 0,
+ id: 'thread-1',
+ object: 'thread',
+ title: 'New Thread',
+ updated: 0,
+ })
+ })
+
+ it('should not update invalid model parameters', async () => {
+ const mockInvalidParameters: any = {
+ params: {
+ // Inference
+ stop: [1, ''],
+ temperature: '0.5',
+ token_limit: '1000',
+ top_k: '0.7',
+ top_p: '0.1',
+ stream: 'true',
+ max_tokens: '1000',
+ frequency_penalty: '0.3',
+ presence_penalty: '0.2',
+
+ // Load model
+ ctx_len: '1024',
+ ngl: '12',
+ embedding: 'true',
+ n_parallel: '2',
+ cpu_threads: '4',
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ vision_model: 'vision',
+ text_model: 'text',
+ },
+ modelId: 'model-1',
+ engine: 'nitro',
+ }
+
+ // Spy functions
+ jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
+ jest.spyOn(extension, 'saveThread').mockReturnValue({})
+
+ const { result } = renderHook(() => useUpdateModelParameters())
+
+ await act(async () => {
+ await result.current.updateModelParameter(
+ mockThread,
+ mockInvalidParameters
+ )
+ })
+
+ // Check if the model parameters are valid before persisting
+ expect(extension.saveThread).toHaveBeenCalledWith({
+ assistants: [
+ {
+ model: {
+ parameters: {
+ max_tokens: 1000,
+ token_limit: 1000,
+ },
+ settings: {
+ cpu_threads: 4,
+ ctx_len: 1024,
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ n_parallel: 2,
+ ngl: 12,
+ },
+ },
+ },
+ ],
+ created: 0,
+ id: 'thread-1',
+ object: 'thread',
+ title: 'New Thread',
+ updated: 0,
+ })
+ })
+
+ it('should update valid model parameters only', async () => {
+ const mockInvalidParameters: any = {
+ params: {
+ // Inference
+ stop: [''],
+ temperature: -0.5,
+ token_limit: 100.2,
+ top_k: 0.7,
+ top_p: 0.1,
+ stream: true,
+ max_tokens: 1000,
+ frequency_penalty: 1.2,
+ presence_penalty: 0.2,
+
+ // Load model
+ ctx_len: 1024,
+ ngl: 0,
+ embedding: 'true',
+ n_parallel: 2,
+ cpu_threads: 4,
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ vision_model: 'vision',
+ text_model: 'text',
+ },
+ modelId: 'model-1',
+ engine: 'nitro',
+ }
+
+ // Spy functions
+ jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
+ jest.spyOn(extension, 'saveThread').mockReturnValue({})
+
+ const { result } = renderHook(() => useUpdateModelParameters())
+
+ await act(async () => {
+ await result.current.updateModelParameter(
+ mockThread,
+ mockInvalidParameters
+ )
+ })
+
+ // Check if the model parameters are valid before persisting
+ expect(extension.saveThread).toHaveBeenCalledWith({
+ assistants: [
+ {
+ model: {
+ parameters: {
+ stop: [''],
+ top_k: 0.7,
+ top_p: 0.1,
+ stream: true,
+ token_limit: 100,
+ max_tokens: 1000,
+ presence_penalty: 0.2,
+ },
+ settings: {
+ ctx_len: 1024,
+ ngl: 0,
+ n_parallel: 2,
+ cpu_threads: 4,
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ },
+ },
+ },
+ ],
+ created: 0,
+ id: 'thread-1',
+ object: 'thread',
+ title: 'New Thread',
+ updated: 0,
+ })
+ })
+
+ it('should handle missing modelId and engine gracefully', async () => {
+ const mockParametersWithoutModelIdAndEngine: any = {
+ params: {
+ stop: ['', ''],
+ temperature: 0.5,
+ },
+ }
+
+ // Spy functions
+ jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
+ jest.spyOn(extension, 'saveThread').mockReturnValue({})
+
+ const { result } = renderHook(() => useUpdateModelParameters())
+
+ await act(async () => {
+ await result.current.updateModelParameter(
+ mockThread,
+ mockParametersWithoutModelIdAndEngine
+ )
+ })
+
+ // Check if the model parameters are valid before persisting
+ expect(extension.saveThread).toHaveBeenCalledWith({
+ assistants: [
+ {
+ model: {
+ parameters: {
+ stop: ['', ''],
+ temperature: 0.5,
+ },
+ settings: {},
+ },
+ },
+ ],
+ created: 0,
+ id: 'thread-1',
+ object: 'thread',
+ title: 'New Thread',
+ updated: 0,
+ })
+ })
+})
diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts
index 79d877456..46bf07cd5 100644
--- a/web/hooks/useUpdateModelParameters.ts
+++ b/web/hooks/useUpdateModelParameters.ts
@@ -12,7 +12,10 @@ import {
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
-import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
+import {
+ extractInferenceParams,
+ extractModelLoadParams,
+} from '@/utils/modelParam'
import useRecommendedModel from './useRecommendedModel'
@@ -47,12 +50,17 @@ export default function useUpdateModelParameters() {
const toUpdateSettings = processStopWords(settings.params ?? {})
const updatedModelParams = settings.modelId
? toUpdateSettings
- : { ...activeModelParams, ...toUpdateSettings }
+ : {
+ ...selectedModel?.parameters,
+ ...selectedModel?.settings,
+ ...activeModelParams,
+ ...toUpdateSettings,
+ }
// update the state
setThreadModelParams(thread.id, updatedModelParams)
- const runtimeParams = toRuntimeParams(updatedModelParams)
- const settingParams = toSettingParams(updatedModelParams)
+ const runtimeParams = extractInferenceParams(updatedModelParams)
+ const settingParams = extractModelLoadParams(updatedModelParams)
const assistants = thread.assistants.map(
(assistant: ThreadAssistantInfo) => {
diff --git a/web/screens/LocalServer/LocalServerRightPanel/index.tsx b/web/screens/LocalServer/LocalServerRightPanel/index.tsx
index 309709c26..13e3cad57 100644
--- a/web/screens/LocalServer/LocalServerRightPanel/index.tsx
+++ b/web/screens/LocalServer/LocalServerRightPanel/index.tsx
@@ -14,7 +14,10 @@ import { loadModelErrorAtom } from '@/hooks/useActiveModel'
import { getConfigurationsData } from '@/utils/componentSettings'
-import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
+import {
+ extractInferenceParams,
+ extractModelLoadParams,
+} from '@/utils/modelParam'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
@@ -27,16 +30,18 @@ const LocalServerRightPanel = () => {
const selectedModel = useAtomValue(selectedModelAtom)
const [currentModelSettingParams, setCurrentModelSettingParams] = useState(
- toSettingParams(selectedModel?.settings)
+ extractModelLoadParams(selectedModel?.settings)
)
useEffect(() => {
if (selectedModel) {
- setCurrentModelSettingParams(toSettingParams(selectedModel?.settings))
+ setCurrentModelSettingParams(
+ extractModelLoadParams(selectedModel?.settings)
+ )
}
}, [selectedModel])
- const modelRuntimeParams = toRuntimeParams(selectedModel?.settings)
+ const modelRuntimeParams = extractInferenceParams(selectedModel?.settings)
const componentDataRuntimeSetting = getConfigurationsData(
modelRuntimeParams,
diff --git a/web/screens/Thread/ThreadRightPanel/index.tsx b/web/screens/Thread/ThreadRightPanel/index.tsx
index 9e7cdf7d8..e7d0a27b9 100644
--- a/web/screens/Thread/ThreadRightPanel/index.tsx
+++ b/web/screens/Thread/ThreadRightPanel/index.tsx
@@ -29,7 +29,10 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
import { getConfigurationsData } from '@/utils/componentSettings'
import { localEngines } from '@/utils/modelEngine'
-import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
+import {
+ extractInferenceParams,
+ extractModelLoadParams,
+} from '@/utils/modelParam'
import PromptTemplateSetting from './PromptTemplateSetting'
import Tools from './Tools'
@@ -68,14 +71,26 @@ const ThreadRightPanel = () => {
const settings = useMemo(() => {
// runtime setting
- const modelRuntimeParams = toRuntimeParams(activeModelParams)
+ const modelRuntimeParams = extractInferenceParams(
+ {
+ ...selectedModel?.parameters,
+ ...activeModelParams,
+ },
+ selectedModel?.parameters
+ )
const componentDataRuntimeSetting = getConfigurationsData(
modelRuntimeParams,
selectedModel
).filter((x) => x.key !== 'prompt_template')
// engine setting
- const modelEngineParams = toSettingParams(activeModelParams)
+ const modelEngineParams = extractModelLoadParams(
+ {
+ ...selectedModel?.settings,
+ ...activeModelParams,
+ },
+ selectedModel?.settings
+ )
const componentDataEngineSetting = getConfigurationsData(
modelEngineParams,
selectedModel
@@ -126,7 +141,10 @@ const ThreadRightPanel = () => {
}, [activeModelParams, selectedModel])
const promptTemplateSettings = useMemo(() => {
- const modelEngineParams = toSettingParams(activeModelParams)
+ const modelEngineParams = extractModelLoadParams({
+ ...selectedModel?.settings,
+ ...activeModelParams,
+ })
const componentDataEngineSetting = getConfigurationsData(
modelEngineParams,
selectedModel
diff --git a/web/utils/modelParam.test.ts b/web/utils/modelParam.test.ts
new file mode 100644
index 000000000..f1b858955
--- /dev/null
+++ b/web/utils/modelParam.test.ts
@@ -0,0 +1,183 @@
+// web/utils/modelParam.test.ts
+import { normalizeValue, validationRules } from './modelParam'
+
+describe('validationRules', () => {
+ it('should validate temperature correctly', () => {
+ expect(validationRules.temperature(0.5)).toBe(true)
+ expect(validationRules.temperature(2)).toBe(true)
+ expect(validationRules.temperature(0)).toBe(true)
+ expect(validationRules.temperature(-0.1)).toBe(false)
+ expect(validationRules.temperature(2.3)).toBe(false)
+ expect(validationRules.temperature('0.5')).toBe(false)
+ })
+
+ it('should validate token_limit correctly', () => {
+ expect(validationRules.token_limit(100)).toBe(true)
+ expect(validationRules.token_limit(1)).toBe(true)
+ expect(validationRules.token_limit(0)).toBe(true)
+ expect(validationRules.token_limit(-1)).toBe(false)
+ expect(validationRules.token_limit('100')).toBe(false)
+ })
+
+ it('should validate top_k correctly', () => {
+ expect(validationRules.top_k(0.5)).toBe(true)
+ expect(validationRules.top_k(1)).toBe(true)
+ expect(validationRules.top_k(0)).toBe(true)
+ expect(validationRules.top_k(-0.1)).toBe(false)
+ expect(validationRules.top_k(1.1)).toBe(false)
+ expect(validationRules.top_k('0.5')).toBe(false)
+ })
+
+ it('should validate top_p correctly', () => {
+ expect(validationRules.top_p(0.5)).toBe(true)
+ expect(validationRules.top_p(1)).toBe(true)
+ expect(validationRules.top_p(0)).toBe(true)
+ expect(validationRules.top_p(-0.1)).toBe(false)
+ expect(validationRules.top_p(1.1)).toBe(false)
+ expect(validationRules.top_p('0.5')).toBe(false)
+ })
+
+ it('should validate stream correctly', () => {
+ expect(validationRules.stream(true)).toBe(true)
+ expect(validationRules.stream(false)).toBe(true)
+ expect(validationRules.stream('true')).toBe(false)
+ expect(validationRules.stream(1)).toBe(false)
+ })
+
+ it('should validate max_tokens correctly', () => {
+ expect(validationRules.max_tokens(100)).toBe(true)
+ expect(validationRules.max_tokens(1)).toBe(true)
+ expect(validationRules.max_tokens(0)).toBe(true)
+ expect(validationRules.max_tokens(-1)).toBe(false)
+ expect(validationRules.max_tokens('100')).toBe(false)
+ })
+
+ it('should validate stop correctly', () => {
+ expect(validationRules.stop(['word1', 'word2'])).toBe(true)
+ expect(validationRules.stop([])).toBe(true)
+ expect(validationRules.stop(['word1', 2])).toBe(false)
+ expect(validationRules.stop('word1')).toBe(false)
+ })
+
+ it('should validate frequency_penalty correctly', () => {
+ expect(validationRules.frequency_penalty(0.5)).toBe(true)
+ expect(validationRules.frequency_penalty(1)).toBe(true)
+ expect(validationRules.frequency_penalty(0)).toBe(true)
+ expect(validationRules.frequency_penalty(-0.1)).toBe(false)
+ expect(validationRules.frequency_penalty(1.1)).toBe(false)
+ expect(validationRules.frequency_penalty('0.5')).toBe(false)
+ })
+
+ it('should validate presence_penalty correctly', () => {
+ expect(validationRules.presence_penalty(0.5)).toBe(true)
+ expect(validationRules.presence_penalty(1)).toBe(true)
+ expect(validationRules.presence_penalty(0)).toBe(true)
+ expect(validationRules.presence_penalty(-0.1)).toBe(false)
+ expect(validationRules.presence_penalty(1.1)).toBe(false)
+ expect(validationRules.presence_penalty('0.5')).toBe(false)
+ })
+
+ it('should validate ctx_len correctly', () => {
+ expect(validationRules.ctx_len(1024)).toBe(true)
+ expect(validationRules.ctx_len(1)).toBe(true)
+ expect(validationRules.ctx_len(0)).toBe(true)
+ expect(validationRules.ctx_len(-1)).toBe(false)
+ expect(validationRules.ctx_len('1024')).toBe(false)
+ })
+
+ it('should validate ngl correctly', () => {
+ expect(validationRules.ngl(12)).toBe(true)
+ expect(validationRules.ngl(1)).toBe(true)
+ expect(validationRules.ngl(0)).toBe(true)
+ expect(validationRules.ngl(-1)).toBe(false)
+ expect(validationRules.ngl('12')).toBe(false)
+ })
+
+ it('should validate embedding correctly', () => {
+ expect(validationRules.embedding(true)).toBe(true)
+ expect(validationRules.embedding(false)).toBe(true)
+ expect(validationRules.embedding('true')).toBe(false)
+ expect(validationRules.embedding(1)).toBe(false)
+ })
+
+ it('should validate n_parallel correctly', () => {
+ expect(validationRules.n_parallel(2)).toBe(true)
+ expect(validationRules.n_parallel(1)).toBe(true)
+ expect(validationRules.n_parallel(0)).toBe(true)
+ expect(validationRules.n_parallel(-1)).toBe(false)
+ expect(validationRules.n_parallel('2')).toBe(false)
+ })
+
+ it('should validate cpu_threads correctly', () => {
+ expect(validationRules.cpu_threads(4)).toBe(true)
+ expect(validationRules.cpu_threads(1)).toBe(true)
+ expect(validationRules.cpu_threads(0)).toBe(true)
+ expect(validationRules.cpu_threads(-1)).toBe(false)
+ expect(validationRules.cpu_threads('4')).toBe(false)
+ })
+
+ it('should validate prompt_template correctly', () => {
+ expect(validationRules.prompt_template('template')).toBe(true)
+ expect(validationRules.prompt_template('')).toBe(true)
+ expect(validationRules.prompt_template(123)).toBe(false)
+ })
+
+ it('should validate llama_model_path correctly', () => {
+ expect(validationRules.llama_model_path('path')).toBe(true)
+ expect(validationRules.llama_model_path('')).toBe(true)
+ expect(validationRules.llama_model_path(123)).toBe(false)
+ })
+
+ it('should validate mmproj correctly', () => {
+ expect(validationRules.mmproj('mmproj')).toBe(true)
+ expect(validationRules.mmproj('')).toBe(true)
+ expect(validationRules.mmproj(123)).toBe(false)
+ })
+
+ it('should validate vision_model correctly', () => {
+ expect(validationRules.vision_model(true)).toBe(true)
+ expect(validationRules.vision_model(false)).toBe(true)
+ expect(validationRules.vision_model('true')).toBe(false)
+ expect(validationRules.vision_model(1)).toBe(false)
+ })
+
+ it('should validate text_model correctly', () => {
+ expect(validationRules.text_model(true)).toBe(true)
+ expect(validationRules.text_model(false)).toBe(true)
+ expect(validationRules.text_model('true')).toBe(false)
+ expect(validationRules.text_model(1)).toBe(false)
+ })
+})
+
+describe('normalizeValue', () => {
+ it('should normalize ctx_len correctly', () => {
+ expect(normalizeValue('ctx_len', 100.5)).toBe(100)
+ expect(normalizeValue('ctx_len', '2')).toBe(2)
+ expect(normalizeValue('ctx_len', 100)).toBe(100)
+ })
+ it('should normalize token_limit correctly', () => {
+ expect(normalizeValue('token_limit', 100.5)).toBe(100)
+ expect(normalizeValue('token_limit', '1')).toBe(1)
+ expect(normalizeValue('token_limit', 0)).toBe(0)
+ })
+ it('should normalize max_tokens correctly', () => {
+ expect(normalizeValue('max_tokens', 100.5)).toBe(100)
+ expect(normalizeValue('max_tokens', '1')).toBe(1)
+ expect(normalizeValue('max_tokens', 0)).toBe(0)
+ })
+ it('should normalize ngl correctly', () => {
+ expect(normalizeValue('ngl', 12.5)).toBe(12)
+ expect(normalizeValue('ngl', '2')).toBe(2)
+ expect(normalizeValue('ngl', 0)).toBe(0)
+ })
+ it('should normalize n_parallel correctly', () => {
+ expect(normalizeValue('n_parallel', 2.5)).toBe(2)
+ expect(normalizeValue('n_parallel', '2')).toBe(2)
+ expect(normalizeValue('n_parallel', 0)).toBe(0)
+ })
+ it('should normalize cpu_threads correctly', () => {
+ expect(normalizeValue('cpu_threads', 4.5)).toBe(4)
+ expect(normalizeValue('cpu_threads', '4')).toBe(4)
+ expect(normalizeValue('cpu_threads', 0)).toBe(0)
+ })
+})
diff --git a/web/utils/modelParam.ts b/web/utils/modelParam.ts
index a6d144c3e..dda9cf761 100644
--- a/web/utils/modelParam.ts
+++ b/web/utils/modelParam.ts
@@ -1,9 +1,69 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
+/* eslint-disable @typescript-eslint/naming-convention */
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
import { ModelParams } from '@/helpers/atoms/Thread.atom'
-export const toRuntimeParams = (
- modelParams?: ModelParams
+/**
+ * Validation rules for model parameters
+ */
+export const validationRules: { [key: string]: (value: any) => boolean } = {
+ temperature: (value: any) =>
+ typeof value === 'number' && value >= 0 && value <= 2,
+ token_limit: (value: any) => Number.isInteger(value) && value >= 0,
+ top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
+ top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
+ stream: (value: any) => typeof value === 'boolean',
+ max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
+ stop: (value: any) =>
+ Array.isArray(value) && value.every((v) => typeof v === 'string'),
+ frequency_penalty: (value: any) =>
+ typeof value === 'number' && value >= 0 && value <= 1,
+ presence_penalty: (value: any) =>
+ typeof value === 'number' && value >= 0 && value <= 1,
+
+ ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
+ ngl: (value: any) => Number.isInteger(value) && value >= 0,
+ embedding: (value: any) => typeof value === 'boolean',
+ n_parallel: (value: any) => Number.isInteger(value) && value >= 0,
+ cpu_threads: (value: any) => Number.isInteger(value) && value >= 0,
+ prompt_template: (value: any) => typeof value === 'string',
+ llama_model_path: (value: any) => typeof value === 'string',
+ mmproj: (value: any) => typeof value === 'string',
+ vision_model: (value: any) => typeof value === 'boolean',
+ text_model: (value: any) => typeof value === 'boolean',
+}
+
+/**
+ * There are some parameters that need to be normalized before being sent to the server
+ * E.g. ctx_len should be an integer, but it can be a float from the input field
+ * @param key
+ * @param value
+ * @returns
+ */
+export const normalizeValue = (key: string, value: any) => {
+ if (
+ key === 'token_limit' ||
+ key === 'max_tokens' ||
+ key === 'ctx_len' ||
+ key === 'ngl' ||
+ key === 'n_parallel' ||
+ key === 'cpu_threads'
+ ) {
+ // Convert to integer
+ return Math.floor(Number(value))
+ }
+ return value
+}
+
+/**
+ * Extract inference parameters from flat model parameters
+ * @param modelParams
+ * @returns
+ */
+export const extractInferenceParams = (
+ modelParams?: ModelParams,
+ originParams?: ModelParams
): ModelRuntimeParams => {
if (!modelParams) return {}
const defaultModelParams: ModelRuntimeParams = {
@@ -22,15 +82,35 @@ export const toRuntimeParams = (
for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultModelParams) {
- Object.assign(runtimeParams, { ...runtimeParams, [key]: value })
+ const validate = validationRules[key]
+ if (validate && !validate(normalizeValue(key, value))) {
+ // Invalid value - fall back to origin value
+ if (originParams && key in originParams) {
+ Object.assign(runtimeParams, {
+ ...runtimeParams,
+ [key]: originParams[key as keyof typeof originParams],
+ })
+ }
+ } else {
+ Object.assign(runtimeParams, {
+ ...runtimeParams,
+ [key]: normalizeValue(key, value),
+ })
+ }
}
}
return runtimeParams
}
-export const toSettingParams = (
- modelParams?: ModelParams
+/**
+ * Extract model load parameters from flat model parameters
+ * @param modelParams
+ * @returns
+ */
+export const extractModelLoadParams = (
+ modelParams?: ModelParams,
+ originParams?: ModelParams
): ModelSettingParams => {
if (!modelParams) return {}
const defaultSettingParams: ModelSettingParams = {
@@ -49,7 +129,21 @@ export const toSettingParams = (
for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultSettingParams) {
- Object.assign(settingParams, { ...settingParams, [key]: value })
+ const validate = validationRules[key]
+ if (validate && !validate(normalizeValue(key, value))) {
+ // Invalid value - fall back to origin value
+ if (originParams && key in originParams) {
+ Object.assign(modelParams, {
+ ...modelParams,
+ [key]: originParams[key as keyof typeof originParams],
+ })
+ }
+ } else {
+ Object.assign(settingParams, {
+ ...settingParams,
+ [key]: normalizeValue(key, value),
+ })
+ }
}
}
From 670013baa037003f82c29607345446a03dc07c0c Mon Sep 17 00:00:00 2001
From: Ronnie Ghose <1313566+RONNCC@users.noreply.github.com>
Date: Mon, 16 Sep 2024 19:25:08 -0700
Subject: [PATCH 02/37] Add support for 'o1-preview' and 'o1-mini' models
(#3659)
Add support for 'o1-preview' and 'o1-mini' model names in the OpenAI API.
* **Update `models.json`**:
- Add 'o1-preview' model details with appropriate parameters and metadata.
- Add 'o1-mini' model details with appropriate parameters and metadata.
---
For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/janhq/jan?shareId=XXXX-XXXX-XXXX-XXXX).
---
.../resources/models.json | 60 +++++++++++++++++++
1 file changed, 60 insertions(+)
diff --git a/extensions/inference-openai-extension/resources/models.json b/extensions/inference-openai-extension/resources/models.json
index 6852a1892..72517d540 100644
--- a/extensions/inference-openai-extension/resources/models.json
+++ b/extensions/inference-openai-extension/resources/models.json
@@ -119,5 +119,65 @@
]
},
"engine": "openai"
+ },
+ {
+ "sources": [
+ {
+ "url": "https://openai.com"
+ }
+ ],
+ "id": "o1-preview",
+ "object": "model",
+ "name": "OpenAI o1-preview",
+ "version": "1.0",
+ "description": "OpenAI o1-preview is a new model with complex reasoning",
+ "format": "api",
+ "settings": {},
+ "parameters": {
+ "max_tokens": 4096,
+ "temperature": 0.7,
+ "top_p": 0.95,
+ "stream": true,
+ "stop": [],
+ "frequency_penalty": 0,
+ "presence_penalty": 0
+ },
+ "metadata": {
+ "author": "OpenAI",
+ "tags": [
+ "General"
+ ]
+ },
+ "engine": "openai"
+ },
+ {
+ "sources": [
+ {
+ "url": "https://openai.com"
+ }
+ ],
+ "id": "o1-mini",
+ "object": "model",
+ "name": "OpenAI o1-mini",
+ "version": "1.0",
+ "description": "OpenAI o1-mini is a lightweight reasoning model",
+ "format": "api",
+ "settings": {},
+ "parameters": {
+ "max_tokens": 4096,
+ "temperature": 0.7,
+ "top_p": 0.95,
+ "stream": true,
+ "stop": [],
+ "frequency_penalty": 0,
+ "presence_penalty": 0
+ },
+ "metadata": {
+ "author": "OpenAI",
+ "tags": [
+ "General"
+ ]
+ },
+ "engine": "openai"
}
]
From c8a08f11155a64a2789a233f1518a0df041df623 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 17 Sep 2024 09:25:55 +0700
Subject: [PATCH 03/37] fix: correct prompt template for Phi3 Medium model
(#3670)
---
extensions/inference-nitro-extension/package.json | 2 +-
.../resources/models/phi3-medium/model.json | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/extensions/inference-nitro-extension/package.json b/extensions/inference-nitro-extension/package.json
index 425e4b49c..ac3ed180a 100644
--- a/extensions/inference-nitro-extension/package.json
+++ b/extensions/inference-nitro-extension/package.json
@@ -1,7 +1,7 @@
{
"name": "@janhq/inference-cortex-extension",
"productName": "Cortex Inference Engine",
- "version": "1.0.16",
+ "version": "1.0.17",
"description": "This extension embeds cortex.cpp, a lightweight inference engine written in C++. See https://jan.ai.\nAdditional dependencies could be installed to run without Cuda Toolkit installation.",
"main": "dist/index.js",
"node": "dist/node/index.cjs.js",
diff --git a/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json b/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json
index 50944b9fe..7331b2fd8 100644
--- a/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json
+++ b/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json
@@ -8,12 +8,12 @@
"id": "phi3-medium",
"object": "model",
"name": "Phi-3 Medium Instruct Q4",
- "version": "1.3",
+ "version": "1.4",
"description": "Phi-3 Medium is Microsoft's latest SOTA model.",
"format": "gguf",
"settings": {
"ctx_len": 128000,
- "prompt_template": "<|user|> {prompt}<|end|><|assistant|><|end|>",
+ "prompt_template": "<|user|> {prompt}<|end|><|assistant|>",
"llama_model_path": "Phi-3-medium-128k-instruct-Q4_K_M.gguf",
"ngl": 33
},
From c3cb1924866e3b94c868d9eacc5c43190a955452 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 17 Sep 2024 16:09:38 +0700
Subject: [PATCH 04/37] fix: #3667 - The recommended label should be hidden
(#3687)
---
web/containers/ModelLabel/ModelLabel.test.tsx | 100 ++++++++++++++++++
web/containers/ModelLabel/index.tsx | 6 +-
2 files changed, 101 insertions(+), 5 deletions(-)
create mode 100644 web/containers/ModelLabel/ModelLabel.test.tsx
diff --git a/web/containers/ModelLabel/ModelLabel.test.tsx b/web/containers/ModelLabel/ModelLabel.test.tsx
new file mode 100644
index 000000000..48504ff6a
--- /dev/null
+++ b/web/containers/ModelLabel/ModelLabel.test.tsx
@@ -0,0 +1,100 @@
+import React from 'react'
+import { render, waitFor, screen } from '@testing-library/react'
+import { useAtomValue } from 'jotai'
+import { useActiveModel } from '@/hooks/useActiveModel'
+import { useSettings } from '@/hooks/useSettings'
+import ModelLabel from '@/containers/ModelLabel'
+
+jest.mock('jotai', () => ({
+ useAtomValue: jest.fn(),
+ atom: jest.fn(),
+}))
+
+jest.mock('@/hooks/useActiveModel', () => ({
+ useActiveModel: jest.fn(),
+}))
+
+jest.mock('@/hooks/useSettings', () => ({
+ useSettings: jest.fn(),
+}))
+
+describe('ModelLabel', () => {
+ const mockUseAtomValue = useAtomValue as jest.Mock
+ const mockUseActiveModel = useActiveModel as jest.Mock
+ const mockUseSettings = useSettings as jest.Mock
+
+ const defaultProps: any = {
+ metadata: {
+ author: 'John Doe', // Add the 'author' property with a value
+ tags: ['8B'],
+ size: 100,
+ },
+ compact: false,
+ }
+
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('renders NotEnoughMemoryLabel when minimumRamModel is greater than totalRam', async () => {
+ mockUseAtomValue
+ .mockReturnValueOnce(0)
+ .mockReturnValueOnce(0)
+ .mockReturnValueOnce(0)
+ mockUseActiveModel.mockReturnValue({
+ activeModel: { metadata: { size: 0 } },
+ })
+ mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
+
+ render( )
+ await waitFor(() => {
+ expect(screen.getByText('Not enough RAM')).toBeDefined()
+ })
+ })
+
+ it('renders SlowOnYourDeviceLabel when minimumRamModel is less than totalRam but greater than availableRam', async () => {
+ mockUseAtomValue
+ .mockReturnValueOnce(100)
+ .mockReturnValueOnce(50)
+ .mockReturnValueOnce(10)
+ mockUseActiveModel.mockReturnValue({
+ activeModel: { metadata: { size: 0 } },
+ })
+ mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
+
+ const props = {
+ ...defaultProps,
+ metadata: {
+ ...defaultProps.metadata,
+ size: 50,
+ },
+ }
+
+ render( )
+ await waitFor(() => {
+ expect(screen.getByText('Slow on your device')).toBeDefined()
+ })
+ })
+
+ it('renders nothing when minimumRamModel is less than availableRam', () => {
+ mockUseAtomValue
+ .mockReturnValueOnce(100)
+ .mockReturnValueOnce(50)
+ .mockReturnValueOnce(0)
+ mockUseActiveModel.mockReturnValue({
+ activeModel: { metadata: { size: 0 } },
+ })
+ mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
+
+ const props = {
+ ...defaultProps,
+ metadata: {
+ ...defaultProps.metadata,
+ size: 10,
+ },
+ }
+
+ const { container } = render( )
+ expect(container.firstChild).toBeNull()
+ })
+})
diff --git a/web/containers/ModelLabel/index.tsx b/web/containers/ModelLabel/index.tsx
index 2c32e288c..b0a3da96f 100644
--- a/web/containers/ModelLabel/index.tsx
+++ b/web/containers/ModelLabel/index.tsx
@@ -10,8 +10,6 @@ import { useSettings } from '@/hooks/useSettings'
import NotEnoughMemoryLabel from './NotEnoughMemoryLabel'
-import RecommendedLabel from './RecommendedLabel'
-
import SlowOnYourDeviceLabel from './SlowOnYourDeviceLabel'
import {
@@ -53,9 +51,7 @@ const ModelLabel = ({ metadata, compact }: Props) => {
/>
)
}
- if (minimumRamModel < availableRam && !compact) {
- return
- }
+
if (minimumRamModel < totalRam && minimumRamModel > availableRam) {
return
}
From 8e603bd5dbb80ef3050e313a0b046101ac81cc03 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 17 Sep 2024 16:43:47 +0700
Subject: [PATCH 05/37] fix: #3476 - Mismatch id between model json and path
(#3645)
* fix: mismatch between model json and path
* chore: revert preserve model settings
* test: add tests
---
.gitignore | 1 +
core/src/browser/core.test.ts | 179 +++---
core/src/browser/core.ts | 8 +
.../browser/extensions/engines/AIEngine.ts | 6 +-
.../extensions/engines/LocalOAIEngine.ts | 16 +-
core/src/browser/extensions/model.ts | 10 +-
core/src/node/api/processors/app.test.ts | 75 ++-
core/src/node/api/processors/app.ts | 16 +-
core/src/types/api/index.ts | 1 +
core/src/types/file/index.ts | 15 +
core/src/types/model/modelEntity.ts | 7 +
core/src/types/model/modelInterface.ts | 11 +-
.../inference-nitro-extension/src/index.ts | 3 +-
.../src/node/index.ts | 4 +-
extensions/model-extension/jest.config.js | 9 +
extensions/model-extension/package.json | 1 +
extensions/model-extension/rollup.config.ts | 4 +-
extensions/model-extension/src/index.test.ts | 564 ++++++++++++++++++
extensions/model-extension/src/index.ts | 85 ++-
extensions/model-extension/tsconfig.json | 3 +-
.../tensorrt-llm-extension/src/index.ts | 3 +-
web/containers/ModelDropdown/index.tsx | 22 +-
web/helpers/atoms/AppConfig.atom.ts | 7 -
web/helpers/atoms/Model.atom.ts | 19 +-
web/hooks/useActiveModel.ts | 6 +-
web/hooks/useCreateNewThread.ts | 20 +-
web/hooks/useDeleteModel.ts | 12 +-
web/hooks/useModels.ts | 5 +-
web/hooks/useRecommendedModel.ts | 12 +-
web/hooks/useUpdateModelParameters.ts | 60 +-
.../Hub/ModelList/ModelHeader/index.tsx | 4 +-
web/screens/Hub/ModelList/ModelItem/index.tsx | 4 +-
web/screens/Hub/ModelList/index.tsx | 14 +-
.../ModelDownloadRow/index.tsx | 8 +-
.../Settings/MyModels/MyModelList/index.tsx | 4 +-
35 files changed, 879 insertions(+), 339 deletions(-)
create mode 100644 extensions/model-extension/jest.config.js
create mode 100644 extensions/model-extension/src/index.test.ts
diff --git a/.gitignore b/.gitignore
index 646e6842a..eaee28a62 100644
--- a/.gitignore
+++ b/.gitignore
@@ -45,3 +45,4 @@ core/test_results.html
coverage
.yarn
.yarnrc
+*.tsbuildinfo
diff --git a/core/src/browser/core.test.ts b/core/src/browser/core.test.ts
index 84250888e..f38cc0b40 100644
--- a/core/src/browser/core.test.ts
+++ b/core/src/browser/core.test.ts
@@ -1,98 +1,109 @@
-import { openExternalUrl } from './core';
-import { joinPath } from './core';
-import { openFileExplorer } from './core';
-import { getJanDataFolderPath } from './core';
-import { abortDownload } from './core';
-import { getFileSize } from './core';
-import { executeOnMain } from './core';
+import { openExternalUrl } from './core'
+import { joinPath } from './core'
+import { openFileExplorer } from './core'
+import { getJanDataFolderPath } from './core'
+import { abortDownload } from './core'
+import { getFileSize } from './core'
+import { executeOnMain } from './core'
-it('should open external url', async () => {
- const url = 'http://example.com';
- globalThis.core = {
- api: {
- openExternalUrl: jest.fn().mockResolvedValue('opened')
+describe('test core apis', () => {
+ it('should open external url', async () => {
+ const url = 'http://example.com'
+ globalThis.core = {
+ api: {
+ openExternalUrl: jest.fn().mockResolvedValue('opened'),
+ },
}
- };
- const result = await openExternalUrl(url);
- expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url);
- expect(result).toBe('opened');
-});
+ const result = await openExternalUrl(url)
+ expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url)
+ expect(result).toBe('opened')
+ })
-
-it('should join paths', async () => {
- const paths = ['/path/one', '/path/two'];
- globalThis.core = {
- api: {
- joinPath: jest.fn().mockResolvedValue('/path/one/path/two')
+ it('should join paths', async () => {
+ const paths = ['/path/one', '/path/two']
+ globalThis.core = {
+ api: {
+ joinPath: jest.fn().mockResolvedValue('/path/one/path/two'),
+ },
}
- };
- const result = await joinPath(paths);
- expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths);
- expect(result).toBe('/path/one/path/two');
-});
+ const result = await joinPath(paths)
+ expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths)
+ expect(result).toBe('/path/one/path/two')
+ })
-
-it('should open file explorer', async () => {
- const path = '/path/to/open';
- globalThis.core = {
- api: {
- openFileExplorer: jest.fn().mockResolvedValue('opened')
+ it('should open file explorer', async () => {
+ const path = '/path/to/open'
+ globalThis.core = {
+ api: {
+ openFileExplorer: jest.fn().mockResolvedValue('opened'),
+ },
}
- };
- const result = await openFileExplorer(path);
- expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path);
- expect(result).toBe('opened');
-});
+ const result = await openFileExplorer(path)
+ expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path)
+ expect(result).toBe('opened')
+ })
-
-it('should get jan data folder path', async () => {
- globalThis.core = {
- api: {
- getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data')
+ it('should get jan data folder path', async () => {
+ globalThis.core = {
+ api: {
+ getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data'),
+ },
}
- };
- const result = await getJanDataFolderPath();
- expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled();
- expect(result).toBe('/path/to/jan/data');
-});
+ const result = await getJanDataFolderPath()
+ expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled()
+ expect(result).toBe('/path/to/jan/data')
+ })
-
-it('should abort download', async () => {
- const fileName = 'testFile';
- globalThis.core = {
- api: {
- abortDownload: jest.fn().mockResolvedValue('aborted')
+ it('should abort download', async () => {
+ const fileName = 'testFile'
+ globalThis.core = {
+ api: {
+ abortDownload: jest.fn().mockResolvedValue('aborted'),
+ },
}
- };
- const result = await abortDownload(fileName);
- expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName);
- expect(result).toBe('aborted');
-});
+ const result = await abortDownload(fileName)
+ expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName)
+ expect(result).toBe('aborted')
+ })
-
-it('should get file size', async () => {
- const url = 'http://example.com/file';
- globalThis.core = {
- api: {
- getFileSize: jest.fn().mockResolvedValue(1024)
+ it('should get file size', async () => {
+ const url = 'http://example.com/file'
+ globalThis.core = {
+ api: {
+ getFileSize: jest.fn().mockResolvedValue(1024),
+ },
}
- };
- const result = await getFileSize(url);
- expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url);
- expect(result).toBe(1024);
-});
+ const result = await getFileSize(url)
+ expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url)
+ expect(result).toBe(1024)
+ })
-
-it('should execute function on main process', async () => {
- const extension = 'testExtension';
- const method = 'testMethod';
- const args = ['arg1', 'arg2'];
- globalThis.core = {
- api: {
- invokeExtensionFunc: jest.fn().mockResolvedValue('result')
+ it('should execute function on main process', async () => {
+ const extension = 'testExtension'
+ const method = 'testMethod'
+ const args = ['arg1', 'arg2']
+ globalThis.core = {
+ api: {
+ invokeExtensionFunc: jest.fn().mockResolvedValue('result'),
+ },
}
- };
- const result = await executeOnMain(extension, method, ...args);
- expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args);
- expect(result).toBe('result');
-});
+ const result = await executeOnMain(extension, method, ...args)
+ expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args)
+ expect(result).toBe('result')
+ })
+})
+
+describe('dirName - just a pass thru api', () => {
+ it('should retrieve the directory name from a file path', async () => {
+ const mockDirName = jest.fn()
+ globalThis.core = {
+ api: {
+ dirName: mockDirName.mockResolvedValue('/path/to'),
+ },
+ }
+ // Normal file path with extension
+ const path = '/path/to/file.txt'
+ await globalThis.core.api.dirName(path)
+ expect(mockDirName).toHaveBeenCalledWith(path)
+ })
+})
diff --git a/core/src/browser/core.ts b/core/src/browser/core.ts
index fdbceb06b..b19e0b339 100644
--- a/core/src/browser/core.ts
+++ b/core/src/browser/core.ts
@@ -68,6 +68,13 @@ const openFileExplorer: (path: string) => Promise = (path) =>
const joinPath: (paths: string[]) => Promise = (paths) =>
globalThis.core.api?.joinPath(paths)
+/**
+ * Get dirname of a file path.
+ * @param path - The file path to retrieve dirname.
+ * @returns {Promise} A promise that resolves the dirname.
+ */
+const dirName: (path: string) => Promise = (path) => globalThis.core.api?.dirName(path)
+
/**
* Retrieve the basename from an url.
* @param path - The path to retrieve.
@@ -161,5 +168,6 @@ export {
systemInformation,
showToast,
getFileSize,
+ dirName,
FileStat,
}
diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts
index 7cd9f513e..75354de88 100644
--- a/core/src/browser/extensions/engines/AIEngine.ts
+++ b/core/src/browser/extensions/engines/AIEngine.ts
@@ -2,7 +2,7 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { fs } from '../../fs'
-import { MessageRequest, Model, ModelEvent } from '../../../types'
+import { MessageRequest, Model, ModelEvent, ModelFile } from '../../../types'
import { EngineManager } from './EngineManager'
/**
@@ -21,7 +21,7 @@ export abstract class AIEngine extends BaseExtension {
override onLoad() {
this.registerEngine()
- events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
+ events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}
@@ -78,7 +78,7 @@ export abstract class AIEngine extends BaseExtension {
/**
* Loads the model.
*/
- async loadModel(model: Model): Promise {
+ async loadModel(model: ModelFile): Promise {
if (model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts
index fb9e4962c..123b9a593 100644
--- a/core/src/browser/extensions/engines/LocalOAIEngine.ts
+++ b/core/src/browser/extensions/engines/LocalOAIEngine.ts
@@ -1,6 +1,6 @@
-import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core'
+import { executeOnMain, systemInformation, dirName } from '../../core'
import { events } from '../../events'
-import { Model, ModelEvent } from '../../../types'
+import { Model, ModelEvent, ModelFile } from '../../../types'
import { OAIEngine } from './OAIEngine'
/**
@@ -14,22 +14,24 @@ export abstract class LocalOAIEngine extends OAIEngine {
unloadModelFunctionName: string = 'unloadModel'
/**
- * On extension load, subscribe to events.
+ * This class represents a base for local inference providers in the OpenAI architecture.
+ * It extends the OAIEngine class and provides the implementation of loading and unloading models locally.
+ * The loadModel function subscribes to the ModelEvent.OnModelInit event, loading models when initiated.
+ * The unloadModel function subscribes to the ModelEvent.OnModelStop event, unloading models when stopped.
*/
override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
- events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
+ events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}
/**
* Load the model.
*/
- override async loadModel(model: Model): Promise {
+ override async loadModel(model: ModelFile): Promise {
if (model.engine.toString() !== this.provider) return
- const modelFolderName = 'models'
- const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
+ const modelFolder = await dirName(model.file_path)
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,
diff --git a/core/src/browser/extensions/model.ts b/core/src/browser/extensions/model.ts
index 5b3089403..040542927 100644
--- a/core/src/browser/extensions/model.ts
+++ b/core/src/browser/extensions/model.ts
@@ -4,6 +4,7 @@ import {
HuggingFaceRepoData,
ImportingModel,
Model,
+ ModelFile,
ModelInterface,
OptionType,
} from '../../types'
@@ -25,12 +26,11 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
network?: { proxy: string; ignoreSSL?: boolean }
): Promise
abstract cancelModelDownload(modelId: string): Promise
- abstract deleteModel(modelId: string): Promise
- abstract saveModel(model: Model): Promise
- abstract getDownloadedModels(): Promise
- abstract getConfiguredModels(): Promise
+ abstract deleteModel(model: ModelFile): Promise
+ abstract getDownloadedModels(): Promise
+ abstract getConfiguredModels(): Promise
abstract importModels(models: ImportingModel[], optionType: OptionType): Promise
- abstract updateModelInfo(modelInfo: Partial): Promise
+ abstract updateModelInfo(modelInfo: Partial): Promise
abstract fetchHuggingFaceRepoData(repoId: string): Promise
abstract getDefaultModel(): Promise
}
diff --git a/core/src/node/api/processors/app.test.ts b/core/src/node/api/processors/app.test.ts
index 3ada5df1e..5c4daef29 100644
--- a/core/src/node/api/processors/app.test.ts
+++ b/core/src/node/api/processors/app.test.ts
@@ -1,40 +1,57 @@
-import { App } from './app';
+jest.mock('../../helper', () => ({
+ ...jest.requireActual('../../helper'),
+ getJanDataFolderPath: () => './app',
+}))
+import { dirname } from 'path'
+import { App } from './app'
it('should call stopServer', () => {
- const app = new App();
- const stopServerMock = jest.fn().mockResolvedValue('Server stopped');
+ const app = new App()
+ const stopServerMock = jest.fn().mockResolvedValue('Server stopped')
jest.mock('@janhq/server', () => ({
- stopServer: stopServerMock
- }));
- const result = app.stopServer();
- expect(stopServerMock).toHaveBeenCalled();
-});
+ stopServer: stopServerMock,
+ }))
+ app.stopServer()
+ expect(stopServerMock).toHaveBeenCalled()
+})
it('should correctly retrieve basename', () => {
- const app = new App();
- const result = app.baseName('/path/to/file.txt');
- expect(result).toBe('file.txt');
-});
+ const app = new App()
+ const result = app.baseName('/path/to/file.txt')
+ expect(result).toBe('file.txt')
+})
it('should correctly identify subdirectories', () => {
- const app = new App();
- const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to';
- const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir';
- const result = app.isSubdirectory(basePath, subPath);
- expect(result).toBe(true);
-});
+ const app = new App()
+ const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to'
+ const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir'
+ const result = app.isSubdirectory(basePath, subPath)
+ expect(result).toBe(true)
+})
it('should correctly join multiple paths', () => {
- const app = new App();
- const result = app.joinPath(['path', 'to', 'file']);
- const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file';
- expect(result).toBe(expectedPath);
-});
+ const app = new App()
+ const result = app.joinPath(['path', 'to', 'file'])
+ const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file'
+ expect(result).toBe(expectedPath)
+})
it('should call correct function with provided arguments using process method', () => {
- const app = new App();
- const mockFunc = jest.fn();
- app.joinPath = mockFunc;
- app.process('joinPath', ['path1', 'path2']);
- expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']);
-});
+ const app = new App()
+ const mockFunc = jest.fn()
+ app.joinPath = mockFunc
+ app.process('joinPath', ['path1', 'path2'])
+ expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2'])
+})
+
+it('should retrieve the directory name from a file path (Unix/Windows)', async () => {
+ const app = new App()
+ const path = 'C:/Users/John Doe/Desktop/file.txt'
+ expect(await app.dirName(path)).toBe('C:/Users/John Doe/Desktop')
+})
+
+it('should retrieve the directory name when using file protocol', async () => {
+ const app = new App()
+ const path = 'file:/models/file.txt'
+ expect(await app.dirName(path)).toBe(process.platform === 'win32' ? 'app\\models' : 'app/models')
+})
diff --git a/core/src/node/api/processors/app.ts b/core/src/node/api/processors/app.ts
index 15460ba56..a0808c5ac 100644
--- a/core/src/node/api/processors/app.ts
+++ b/core/src/node/api/processors/app.ts
@@ -1,4 +1,4 @@
-import { basename, isAbsolute, join, relative } from 'path'
+import { basename, dirname, isAbsolute, join, relative } from 'path'
import { Processor } from './Processor'
import {
@@ -6,6 +6,8 @@ import {
appResourcePath,
getAppConfigurations as appConfiguration,
updateAppConfiguration,
+ normalizeFilePath,
+ getJanDataFolderPath,
} from '../../helper'
export class App implements Processor {
@@ -28,6 +30,18 @@ export class App implements Processor {
return join(...args)
}
+ /**
+ * Get dirname of a file path.
+ * @param path - The file path to retrieve dirname.
+ */
+ dirName(path: string) {
+ const arg =
+ path.startsWith(`file:/`) || path.startsWith(`file:\\`)
+ ? join(getJanDataFolderPath(), normalizeFilePath(path))
+ : path
+ return dirname(arg)
+ }
+
/**
* Checks if the given path is a subdirectory of the given directory.
*
diff --git a/core/src/types/api/index.ts b/core/src/types/api/index.ts
index bca11c0a8..8f1ff70bf 100644
--- a/core/src/types/api/index.ts
+++ b/core/src/types/api/index.ts
@@ -37,6 +37,7 @@ export enum AppRoute {
getAppConfigurations = 'getAppConfigurations',
updateAppConfiguration = 'updateAppConfiguration',
joinPath = 'joinPath',
+ dirName = 'dirName',
isSubdirectory = 'isSubdirectory',
baseName = 'baseName',
startServer = 'startServer',
diff --git a/core/src/types/file/index.ts b/core/src/types/file/index.ts
index 1b36a5777..4db956b1e 100644
--- a/core/src/types/file/index.ts
+++ b/core/src/types/file/index.ts
@@ -52,3 +52,18 @@ type DownloadSize = {
total: number
transferred: number
}
+
+/**
+ * The file metadata
+ */
+export type FileMetadata = {
+ /**
+ * The origin file path.
+ */
+ file_path: string
+
+ /**
+ * The file name.
+ */
+ file_name: string
+}
diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts
index f154f7f04..933c698c3 100644
--- a/core/src/types/model/modelEntity.ts
+++ b/core/src/types/model/modelEntity.ts
@@ -1,3 +1,5 @@
+import { FileMetadata } from '../file'
+
/**
* Represents the information about a model.
* @stored
@@ -151,3 +153,8 @@ export type ModelRuntimeParams = {
export type ModelInitFailed = Model & {
error: Error
}
+
+/**
+ * ModelFile is the model.json entity and it's file metadata
+ */
+export type ModelFile = Model & FileMetadata
diff --git a/core/src/types/model/modelInterface.ts b/core/src/types/model/modelInterface.ts
index 639c7c8d3..5b5856231 100644
--- a/core/src/types/model/modelInterface.ts
+++ b/core/src/types/model/modelInterface.ts
@@ -1,5 +1,5 @@
import { GpuSetting } from '../miscellaneous'
-import { Model } from './modelEntity'
+import { Model, ModelFile } from './modelEntity'
/**
* Model extension for managing models.
@@ -29,14 +29,7 @@ export interface ModelInterface {
* @param modelId - The ID of the model to delete.
* @returns A Promise that resolves when the model has been deleted.
*/
- deleteModel(modelId: string): Promise
-
- /**
- * Saves a model.
- * @param model - The model to save.
- * @returns A Promise that resolves when the model has been saved.
- */
- saveModel(model: Model): Promise
+ deleteModel(model: ModelFile): Promise
/**
* Gets a list of downloaded models.
diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts
index d79e076d4..6e825e8fd 100644
--- a/extensions/inference-nitro-extension/src/index.ts
+++ b/extensions/inference-nitro-extension/src/index.ts
@@ -22,6 +22,7 @@ import {
downloadFile,
DownloadState,
DownloadEvent,
+ ModelFile,
} from '@janhq/core'
declare const CUDA_DOWNLOAD_URL: string
@@ -94,7 +95,7 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
this.nitroProcessInfo = health
}
- override loadModel(model: Model): Promise {
+ override loadModel(model: ModelFile): Promise {
if (model.engine !== this.provider) return Promise.resolve()
this.getNitroProcessHealthIntervalId = setInterval(
() => this.periodicallyGetNitroHealth(),
diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts
index 3a969ad5e..98ca4572f 100644
--- a/extensions/inference-nitro-extension/src/node/index.ts
+++ b/extensions/inference-nitro-extension/src/node/index.ts
@@ -6,12 +6,12 @@ import fetchRT from 'fetch-retry'
import {
log,
getSystemResourceInfo,
- Model,
InferenceEngine,
ModelSettingParams,
PromptTemplate,
SystemInformation,
getJanDataFolderPath,
+ ModelFile,
} from '@janhq/core/node'
import { executableNitroFile } from './execute'
import terminate from 'terminate'
@@ -25,7 +25,7 @@ const fetchRetry = fetchRT(fetch)
*/
interface ModelInitOptions {
modelFolder: string
- model: Model
+ model: ModelFile
}
// The PORT to use for the Nitro subprocess
const PORT = 3928
diff --git a/extensions/model-extension/jest.config.js b/extensions/model-extension/jest.config.js
new file mode 100644
index 000000000..3e32adceb
--- /dev/null
+++ b/extensions/model-extension/jest.config.js
@@ -0,0 +1,9 @@
+/** @type {import('ts-jest').JestConfigWithTsJest} */
+module.exports = {
+ preset: 'ts-jest',
+ testEnvironment: 'node',
+ transform: {
+ 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
+ },
+ transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
+}
diff --git a/extensions/model-extension/package.json b/extensions/model-extension/package.json
index 4a2c61b71..9a406dcf4 100644
--- a/extensions/model-extension/package.json
+++ b/extensions/model-extension/package.json
@@ -8,6 +8,7 @@
"author": "Jan ",
"license": "AGPL-3.0",
"scripts": {
+ "test": "jest",
"build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs",
"build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install"
},
diff --git a/extensions/model-extension/rollup.config.ts b/extensions/model-extension/rollup.config.ts
index c3f3acc77..d36d8ffac 100644
--- a/extensions/model-extension/rollup.config.ts
+++ b/extensions/model-extension/rollup.config.ts
@@ -27,7 +27,7 @@ export default [
// Allow json resolution
json(),
// Compile TypeScript files
- typescript({ useTsconfigDeclarationDir: true }),
+ typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
// Compile TypeScript files
// Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
// commonjs(),
@@ -62,7 +62,7 @@ export default [
// Allow json resolution
json(),
// Compile TypeScript files
- typescript({ useTsconfigDeclarationDir: true }),
+ typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
// Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
commonjs(),
// Allow node_modules resolution, so you can use 'external' to control
diff --git a/extensions/model-extension/src/index.test.ts b/extensions/model-extension/src/index.test.ts
new file mode 100644
index 000000000..6816d7101
--- /dev/null
+++ b/extensions/model-extension/src/index.test.ts
@@ -0,0 +1,564 @@
+const readDirSyncMock = jest.fn()
+const existMock = jest.fn()
+const readFileSyncMock = jest.fn()
+
+jest.mock('@janhq/core', () => ({
+ ...jest.requireActual('@janhq/core/node'),
+ fs: {
+ existsSync: existMock,
+ readdirSync: readDirSyncMock,
+ readFileSync: readFileSyncMock,
+ fileStat: () => ({
+ isDirectory: false,
+ }),
+ },
+ dirName: jest.fn(),
+ joinPath: (paths) => paths.join('/'),
+ ModelExtension: jest.fn(),
+}))
+
+import JanModelExtension from '.'
+import { fs, dirName } from '@janhq/core'
+
+describe('JanModelExtension', () => {
+ let sut: JanModelExtension
+
+ beforeAll(() => {
+ // @ts-ignore
+ sut = new JanModelExtension()
+ })
+
+ afterEach(() => {
+ jest.clearAllMocks()
+ })
+
+ describe('getConfiguredModels', () => {
+ describe("when there's no models are pre-populated", () => {
+ it('should return empty array', async () => {
+ // Mock configured models data
+ const configuredModels = []
+ existMock.mockReturnValue(true)
+ readDirSyncMock.mockReturnValue([])
+
+ const result = await sut.getConfiguredModels()
+ expect(result).toEqual([])
+ })
+ })
+
+ describe("when there's are pre-populated models - all flattened", () => {
+ it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2']
+ else return ['model.json']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getConfiguredModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ expect.objectContaining({
+ file_path: 'file://models/model2/model.json',
+ id: '2',
+ }),
+ ])
+ )
+ })
+ })
+
+ describe("when there's are pre-populated models - there are nested folders", () => {
+ it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2/model2-1']
+ else return ['model.json']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else if (path.includes('model2/model2-1'))
+ return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getConfiguredModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ expect.objectContaining({
+ file_path: 'file://models/model2/model2-1/model.json',
+ id: '2',
+ }),
+ ])
+ )
+ })
+ })
+ })
+
+ describe('getDownloadedModels', () => {
+ describe('no models downloaded', () => {
+ it('should return empty array', async () => {
+ // Mock downloaded models data
+ const downloadedModels = []
+ existMock.mockReturnValue(true)
+ readDirSyncMock.mockReturnValue([])
+
+ const result = await sut.getDownloadedModels()
+ expect(result).toEqual([])
+ })
+ })
+ describe('only one model is downloaded', () => {
+ describe('flatten folder', () => {
+ it('returns downloaded models - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2']
+ else if (path === 'file://models/model1')
+ return ['model.json', 'test.gguf']
+ else return ['model.json']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getDownloadedModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ ])
+ )
+ })
+ })
+ })
+
+ describe('all models are downloaded', () => {
+ describe('nested folders', () => {
+ it('returns downloaded models - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2/model2-1']
+ else return ['model.json', 'test.gguf']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getDownloadedModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ expect.objectContaining({
+ file_path: 'file://models/model2/model2-1/model.json',
+ id: '2',
+ }),
+ ])
+ )
+ })
+ })
+ })
+
+ describe('all models are downloaded with uppercased GGUF files', () => {
+ it('returns downloaded models - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2/model2-1']
+ else if (path === 'file://models/model1')
+ return ['model.json', 'test.GGUF']
+ else return ['model.json', 'test.gguf']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getDownloadedModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ expect.objectContaining({
+ file_path: 'file://models/model2/model2-1/model.json',
+ id: '2',
+ }),
+ ])
+ )
+ })
+ })
+
+ describe('all models are downloaded - GGUF & Tensort RT', () => {
+ it('returns downloaded models - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2/model2-1']
+ else if (path === 'file://models/model1')
+ return ['model.json', 'test.gguf']
+ else return ['model.json', 'test.engine']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getDownloadedModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ expect.objectContaining({
+ file_path: 'file://models/model2/model2-1/model.json',
+ id: '2',
+ }),
+ ])
+ )
+ })
+ })
+ })
+
+ describe('deleteModel', () => {
+ describe('model is a GGUF model', () => {
+ it('should delete the GGUF file', async () => {
+ fs.unlinkSync = jest.fn()
+ const dirMock = dirName as jest.Mock
+ dirMock.mockReturnValue('file://models/model1')
+
+ fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
+
+ readDirSyncMock.mockImplementation((path) => {
+ return ['model.json', 'test.gguf']
+ })
+
+ existMock.mockReturnValue(true)
+
+ await sut.deleteModel({
+ file_path: 'file://models/model1/model.json',
+ } as any)
+
+ expect(fs.unlinkSync).toHaveBeenCalledWith(
+ 'file://models/model1/test.gguf'
+ )
+ })
+
+ it('no gguf file presented', async () => {
+ fs.unlinkSync = jest.fn()
+ const dirMock = dirName as jest.Mock
+ dirMock.mockReturnValue('file://models/model1')
+
+ fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
+
+ readDirSyncMock.mockReturnValue(['model.json'])
+
+ existMock.mockReturnValue(true)
+
+ await sut.deleteModel({
+ file_path: 'file://models/model1/model.json',
+ } as any)
+
+ expect(fs.unlinkSync).toHaveBeenCalledTimes(0)
+ })
+
+ it('delete an imported model', async () => {
+ fs.rm = jest.fn()
+ const dirMock = dirName as jest.Mock
+ dirMock.mockReturnValue('file://models/model1')
+
+ readDirSyncMock.mockReturnValue(['model.json', 'test.gguf'])
+
+ // MARK: This is a tricky logic implement?
+ // I will just add test for now but will align on the legacy implementation
+ fs.readFileSync = jest.fn().mockReturnValue(
+ JSON.stringify({
+ metadata: {
+ author: 'user',
+ },
+ })
+ )
+
+ existMock.mockReturnValue(true)
+
+ await sut.deleteModel({
+ file_path: 'file://models/model1/model.json',
+ } as any)
+
+ expect(fs.rm).toHaveBeenCalledWith('file://models/model1')
+ })
+
+ it('delete tensorrt-models', async () => {
+ fs.rm = jest.fn()
+ const dirMock = dirName as jest.Mock
+ dirMock.mockReturnValue('file://models/model1')
+
+ readDirSyncMock.mockReturnValue(['model.json', 'test.engine'])
+
+ fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
+
+ existMock.mockReturnValue(true)
+
+ await sut.deleteModel({
+ file_path: 'file://models/model1/model.json',
+ } as any)
+
+ expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine')
+ })
+ })
+ })
+})
diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts
index e2f68a58c..ac9b06a09 100644
--- a/extensions/model-extension/src/index.ts
+++ b/extensions/model-extension/src/index.ts
@@ -22,6 +22,8 @@ import {
getFileSize,
AllQuantizations,
ModelEvent,
+ ModelFile,
+ dirName,
} from '@janhq/core'
import { extractFileName } from './helpers/path'
@@ -48,16 +50,7 @@ export default class JanModelExtension extends ModelExtension {
]
private static readonly _tensorRtEngineFormat = '.engine'
private static readonly _supportedGpuArch = ['ampere', 'ada']
- private static readonly _safetensorsRegexs = [
- /model\.safetensors$/,
- /model-[0-9]+-of-[0-9]+\.safetensors$/,
- ]
- private static readonly _pytorchRegexs = [
- /pytorch_model\.bin$/,
- /consolidated\.[0-9]+\.pth$/,
- /pytorch_model-[0-9]+-of-[0-9]+\.bin$/,
- /.*\.pt$/,
- ]
+
interrupted = false
/**
@@ -319,9 +312,9 @@ export default class JanModelExtension extends ModelExtension {
* @param filePath - The path to the model file to delete.
* @returns A Promise that resolves when the model is deleted.
*/
- async deleteModel(modelId: string): Promise {
+ async deleteModel(model: ModelFile): Promise {
try {
- const dirPath = await joinPath([JanModelExtension._homeDir, modelId])
+ const dirPath = await dirName(model.file_path)
const jsonFilePath = await joinPath([
dirPath,
JanModelExtension._modelMetadataFileName,
@@ -330,9 +323,11 @@ export default class JanModelExtension extends ModelExtension {
await this.readModelMetadata(jsonFilePath)
) as Model
+ // TODO: This is so tricky?
+ // Should depend on sources?
const isUserImportModel =
modelInfo.metadata?.author?.toLowerCase() === 'user'
- if (isUserImportModel) {
+ if (isUserImportModel) {
// just delete the folder
return fs.rm(dirPath)
}
@@ -350,30 +345,11 @@ export default class JanModelExtension extends ModelExtension {
}
}
- /**
- * Saves a model file.
- * @param model - The model to save.
- * @returns A Promise that resolves when the model is saved.
- */
- async saveModel(model: Model): Promise {
- const jsonFilePath = await joinPath([
- JanModelExtension._homeDir,
- model.id,
- JanModelExtension._modelMetadataFileName,
- ])
-
- try {
- await fs.writeFileSync(jsonFilePath, JSON.stringify(model, null, 2))
- } catch (err) {
- console.error(err)
- }
- }
-
/**
* Gets all downloaded models.
* @returns A Promise that resolves with an array of all models.
*/
- async getDownloadedModels(): Promise {
+ async getDownloadedModels(): Promise {
return await this.getModelsMetadata(
async (modelDir: string, model: Model) => {
if (!JanModelExtension._offlineInferenceEngine.includes(model.engine))
@@ -425,8 +401,10 @@ export default class JanModelExtension extends ModelExtension {
): Promise {
// try to find model.json recursively inside each folder
if (!(await fs.existsSync(folderFullPath))) return undefined
+
const files: string[] = await fs.readdirSync(folderFullPath)
if (files.length === 0) return undefined
+
if (files.includes(JanModelExtension._modelMetadataFileName)) {
return joinPath([
folderFullPath,
@@ -446,7 +424,7 @@ export default class JanModelExtension extends ModelExtension {
private async getModelsMetadata(
selector?: (path: string, model: Model) => Promise
- ): Promise {
+ ): Promise {
try {
if (!(await fs.existsSync(JanModelExtension._homeDir))) {
console.debug('Model folder not found')
@@ -469,6 +447,7 @@ export default class JanModelExtension extends ModelExtension {
JanModelExtension._homeDir,
dirName,
])
+
const jsonPath = await this.getModelJsonPath(folderFullPath)
if (await fs.existsSync(jsonPath)) {
@@ -486,6 +465,8 @@ export default class JanModelExtension extends ModelExtension {
},
]
}
+ model.file_path = jsonPath
+ model.file_name = JanModelExtension._modelMetadataFileName
if (selector && !(await selector?.(dirName, model))) {
return
@@ -506,7 +487,7 @@ export default class JanModelExtension extends ModelExtension {
typeof result.value === 'object'
? result.value
: JSON.parse(result.value)
- return model as Model
+ return model as ModelFile
} catch {
console.debug(`Unable to parse model metadata: ${result.value}`)
}
@@ -637,7 +618,7 @@ export default class JanModelExtension extends ModelExtension {
* Gets all available models.
* @returns A Promise that resolves with an array of all models.
*/
- async getConfiguredModels(): Promise {
+ async getConfiguredModels(): Promise {
return this.getModelsMetadata()
}
@@ -669,7 +650,7 @@ export default class JanModelExtension extends ModelExtension {
modelBinaryPath: string,
modelFolderName: string,
modelFolderPath: string
- ): Promise {
+ ): Promise {
const fileStats = await fs.fileStat(modelBinaryPath, true)
const binaryFileSize = fileStats.size
@@ -732,25 +713,21 @@ export default class JanModelExtension extends ModelExtension {
await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2))
- return model
+ return {
+ ...model,
+ file_path: modelFilePath,
+ file_name: JanModelExtension._modelMetadataFileName,
+ }
}
- async updateModelInfo(modelInfo: Partial): Promise {
- const modelId = modelInfo.id
+ async updateModelInfo(modelInfo: Partial): Promise {
if (modelInfo.id == null) throw new Error('Model ID is required')
- const janDataFolderPath = await getJanDataFolderPath()
- const jsonFilePath = await joinPath([
- janDataFolderPath,
- 'models',
- modelId,
- JanModelExtension._modelMetadataFileName,
- ])
const model = JSON.parse(
- await this.readModelMetadata(jsonFilePath)
- ) as Model
+ await this.readModelMetadata(modelInfo.file_path)
+ ) as ModelFile
- const updatedModel: Model = {
+ const updatedModel: ModelFile = {
...model,
...modelInfo,
parameters: {
@@ -765,9 +742,15 @@ export default class JanModelExtension extends ModelExtension {
...model.metadata,
...modelInfo.metadata,
},
+ // Should not persist file_path & file_name
+ file_path: undefined,
+ file_name: undefined,
}
- await fs.writeFileSync(jsonFilePath, JSON.stringify(updatedModel, null, 2))
+ await fs.writeFileSync(
+ modelInfo.file_path,
+ JSON.stringify(updatedModel, null, 2)
+ )
return updatedModel
}
diff --git a/extensions/model-extension/tsconfig.json b/extensions/model-extension/tsconfig.json
index addd8e127..0d3252934 100644
--- a/extensions/model-extension/tsconfig.json
+++ b/extensions/model-extension/tsconfig.json
@@ -10,5 +10,6 @@
"skipLibCheck": true,
"rootDir": "./src"
},
- "include": ["./src"]
+ "include": ["./src"],
+ "exclude": ["**/*.test.ts"]
}
diff --git a/extensions/tensorrt-llm-extension/src/index.ts b/extensions/tensorrt-llm-extension/src/index.ts
index 189abc706..7f68c43bd 100644
--- a/extensions/tensorrt-llm-extension/src/index.ts
+++ b/extensions/tensorrt-llm-extension/src/index.ts
@@ -23,6 +23,7 @@ import {
ModelEvent,
getJanDataFolderPath,
SystemInformation,
+ ModelFile,
} from '@janhq/core'
/**
@@ -137,7 +138,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
events.emit(ModelEvent.OnModelsUpdate, {})
}
- override async loadModel(model: Model): Promise {
+ override async loadModel(model: ModelFile): Promise {
if ((await this.installationState()) === 'Installed')
return super.loadModel(model)
diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx
index 92d8addd0..d8743ddce 100644
--- a/web/containers/ModelDropdown/index.tsx
+++ b/web/containers/ModelDropdown/index.tsx
@@ -46,7 +46,6 @@ import {
import { extensionManager } from '@/extension'
-import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
import {
configuredModelsAtom,
@@ -91,8 +90,6 @@ const ModelDropdown = ({
const featuredModel = configuredModels.filter((x) =>
x.metadata.tags.includes('Featured')
)
- const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
-
const { updateThreadMetadata } = useCreateNewThread()
useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
@@ -191,27 +188,14 @@ const ModelDropdown = ({
],
})
- // Default setting ctx_len for the model for a better onboarding experience
- // TODO: When Cortex support hardware instructions, we should remove this
- const defaultContextLength = preserveModelSettings
- ? model?.metadata?.default_ctx_len
- : 2048
- const defaultMaxTokens = preserveModelSettings
- ? model?.metadata?.default_max_tokens
- : 2048
const overriddenSettings =
- model?.settings.ctx_len && model.settings.ctx_len > 2048
- ? { ctx_len: defaultContextLength ?? 2048 }
- : {}
- const overriddenParameters =
- model?.parameters.max_tokens && model.parameters.max_tokens
- ? { max_tokens: defaultMaxTokens ?? 2048 }
+ model?.settings.ctx_len && model.settings.ctx_len > 4096
+ ? { ctx_len: 4096 }
: {}
const modelParams = {
...model?.parameters,
...model?.settings,
- ...overriddenParameters,
...overriddenSettings,
}
@@ -222,6 +206,7 @@ const ModelDropdown = ({
if (model)
updateModelParameter(activeThread, {
params: modelParams,
+ modelPath: model.file_path,
modelId: model.id,
engine: model.engine,
})
@@ -235,7 +220,6 @@ const ModelDropdown = ({
setThreadModelParams,
updateModelParameter,
updateThreadMetadata,
- preserveModelSettings,
]
)
diff --git a/web/helpers/atoms/AppConfig.atom.ts b/web/helpers/atoms/AppConfig.atom.ts
index e7b7efaec..f4acc7dc2 100644
--- a/web/helpers/atoms/AppConfig.atom.ts
+++ b/web/helpers/atoms/AppConfig.atom.ts
@@ -7,7 +7,6 @@ const VULKAN_ENABLED = 'vulkanEnabled'
const IGNORE_SSL = 'ignoreSSLFeature'
const HTTPS_PROXY_FEATURE = 'httpsProxyFeature'
const QUICK_ASK_ENABLED = 'quickAskEnabled'
-const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings'
export const janDataFolderPathAtom = atom('')
@@ -24,9 +23,3 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false)
export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false)
export const hostAtom = atom('http://localhost:1337/')
-
-// This feature is to allow user to cache model settings on thread creation
-export const preserveModelSettingsAtom = atomWithStorage(
- PRESERVE_MODEL_SETTINGS,
- false
-)
diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts
index 77b1bfa4e..d2d0ca9f4 100644
--- a/web/helpers/atoms/Model.atom.ts
+++ b/web/helpers/atoms/Model.atom.ts
@@ -1,4 +1,4 @@
-import { ImportingModel, Model, InferenceEngine } from '@janhq/core'
+import { ImportingModel, Model, InferenceEngine, ModelFile } from '@janhq/core'
import { atom } from 'jotai'
import { localEngines } from '@/utils/modelEngine'
@@ -32,18 +32,7 @@ export const removeDownloadingModelAtom = atom(
}
)
-export const downloadedModelsAtom = atom([])
-
-export const updateDownloadedModelAtom = atom(
- null,
- (get, set, updatedModel: Model) => {
- const models: Model[] = get(downloadedModelsAtom).map((c) =>
- c.id === updatedModel.id ? updatedModel : c
- )
-
- set(downloadedModelsAtom, models)
- }
-)
+export const downloadedModelsAtom = atom([])
export const removeDownloadedModelAtom = atom(
null,
@@ -57,7 +46,7 @@ export const removeDownloadedModelAtom = atom(
}
)
-export const configuredModelsAtom = atom([])
+export const configuredModelsAtom = atom([])
export const defaultModelAtom = atom(undefined)
@@ -144,6 +133,6 @@ export const updateImportingModelAtom = atom(
}
)
-export const selectedModelAtom = atom(undefined)
+export const selectedModelAtom = atom(undefined)
export const showEngineListModelAtom = atom(localEngines)
diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts
index 9768ac4c4..2d53678c3 100644
--- a/web/hooks/useActiveModel.ts
+++ b/web/hooks/useActiveModel.ts
@@ -1,6 +1,6 @@
import { useCallback, useEffect, useRef } from 'react'
-import { EngineManager, Model } from '@janhq/core'
+import { EngineManager, Model, ModelFile } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { toaster } from '@/containers/Toast'
@@ -11,7 +11,7 @@ import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
-export const activeModelAtom = atom(undefined)
+export const activeModelAtom = atom(undefined)
export const loadModelErrorAtom = atom(undefined)
type ModelState = {
@@ -37,7 +37,7 @@ export function useActiveModel() {
const [pendingModelLoad, setPendingModelLoad] = useAtom(pendingModelLoadAtom)
const isVulkanEnabled = useAtomValue(vulkanEnabledAtom)
- const downloadedModelsRef = useRef([])
+ const downloadedModelsRef = useRef([])
useEffect(() => {
downloadedModelsRef.current = downloadedModels
diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts
index 80acfa3cc..5548259fd 100644
--- a/web/hooks/useCreateNewThread.ts
+++ b/web/hooks/useCreateNewThread.ts
@@ -7,8 +7,8 @@ import {
Thread,
ThreadAssistantInfo,
ThreadState,
- Model,
AssistantTool,
+ ModelFile,
} from '@janhq/core'
import { atom, useAtomValue, useSetAtom } from 'jotai'
@@ -26,10 +26,7 @@ import useSetActiveThread from './useSetActiveThread'
import { extensionManager } from '@/extension'
-import {
- experimentalFeatureEnabledAtom,
- preserveModelSettingsAtom,
-} from '@/helpers/atoms/AppConfig.atom'
+import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
threadsAtom,
@@ -67,7 +64,6 @@ export const useCreateNewThread = () => {
const copyOverInstructionEnabled = useAtomValue(
copyOverInstructionEnabledAtom
)
- const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
const activeThread = useAtomValue(activeThreadAtom)
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
@@ -80,7 +76,7 @@ export const useCreateNewThread = () => {
const requestCreateNewThread = async (
assistant: Assistant,
- model?: Model | undefined
+ model?: ModelFile | undefined
) => {
// Stop generating if any
setIsGeneratingResponse(false)
@@ -109,19 +105,13 @@ export const useCreateNewThread = () => {
enabled: true,
settings: assistant.tools && assistant.tools[0].settings,
}
- const defaultContextLength = preserveModelSettings
- ? defaultModel?.metadata?.default_ctx_len
- : 2048
- const defaultMaxTokens = preserveModelSettings
- ? defaultModel?.metadata?.default_max_tokens
- : 2048
const overriddenSettings =
defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048
- ? { ctx_len: defaultContextLength ?? 2048 }
+ ? { ctx_len: 4096 }
: {}
const overriddenParameters = defaultModel?.parameters.max_tokens
- ? { max_tokens: defaultMaxTokens ?? 2048 }
+ ? { max_tokens: 4096 }
: {}
const createdAt = Date.now()
diff --git a/web/hooks/useDeleteModel.ts b/web/hooks/useDeleteModel.ts
index 9736f8256..5a7a319b2 100644
--- a/web/hooks/useDeleteModel.ts
+++ b/web/hooks/useDeleteModel.ts
@@ -1,6 +1,6 @@
import { useCallback } from 'react'
-import { ExtensionTypeEnum, ModelExtension, Model } from '@janhq/core'
+import { ExtensionTypeEnum, ModelExtension, ModelFile } from '@janhq/core'
import { useSetAtom } from 'jotai'
@@ -13,8 +13,8 @@ export default function useDeleteModel() {
const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom)
const deleteModel = useCallback(
- async (model: Model) => {
- await localDeleteModel(model.id)
+ async (model: ModelFile) => {
+ await localDeleteModel(model)
removeDownloadedModel(model.id)
toaster({
title: 'Model Deletion Successful',
@@ -28,5 +28,7 @@ export default function useDeleteModel() {
return { deleteModel }
}
-const localDeleteModel = async (id: string) =>
- extensionManager.get(ExtensionTypeEnum.Model)?.deleteModel(id)
+const localDeleteModel = async (model: ModelFile) =>
+ extensionManager
+ .get(ExtensionTypeEnum.Model)
+ ?.deleteModel(model)
diff --git a/web/hooks/useModels.ts b/web/hooks/useModels.ts
index 5a6f13e03..8333c35c3 100644
--- a/web/hooks/useModels.ts
+++ b/web/hooks/useModels.ts
@@ -5,6 +5,7 @@ import {
Model,
ModelEvent,
ModelExtension,
+ ModelFile,
events,
} from '@janhq/core'
@@ -63,12 +64,12 @@ const getLocalDefaultModel = async (): Promise =>
.get(ExtensionTypeEnum.Model)
?.getDefaultModel()
-const getLocalConfiguredModels = async (): Promise =>
+const getLocalConfiguredModels = async (): Promise =>
extensionManager
.get(ExtensionTypeEnum.Model)
?.getConfiguredModels() ?? []
-const getLocalDownloadedModels = async (): Promise =>
+const getLocalDownloadedModels = async (): Promise =>
extensionManager
.get(ExtensionTypeEnum.Model)
?.getDownloadedModels() ?? []
diff --git a/web/hooks/useRecommendedModel.ts b/web/hooks/useRecommendedModel.ts
index 21a9c69e7..ed56efa55 100644
--- a/web/hooks/useRecommendedModel.ts
+++ b/web/hooks/useRecommendedModel.ts
@@ -1,6 +1,6 @@
import { useCallback, useEffect, useState } from 'react'
-import { Model, InferenceEngine } from '@janhq/core'
+import { Model, InferenceEngine, ModelFile } from '@janhq/core'
import { atom, useAtomValue } from 'jotai'
@@ -24,12 +24,16 @@ export const LAST_USED_MODEL_ID = 'last-used-model-id'
*/
export default function useRecommendedModel() {
const activeModel = useAtomValue(activeModelAtom)
- const [sortedModels, setSortedModels] = useState([])
- const [recommendedModel, setRecommendedModel] = useState()
+ const [sortedModels, setSortedModels] = useState([])
+ const [recommendedModel, setRecommendedModel] = useState<
+ ModelFile | undefined
+ >()
const activeThread = useAtomValue(activeThreadAtom)
const downloadedModels = useAtomValue(downloadedModelsAtom)
- const getAndSortDownloadedModels = useCallback(async (): Promise => {
+ const getAndSortDownloadedModels = useCallback(async (): Promise<
+ ModelFile[]
+ > => {
const models = downloadedModels.sort((a, b) =>
a.engine !== InferenceEngine.nitro && b.engine === InferenceEngine.nitro
? 1
diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts
index 46bf07cd5..af30210ad 100644
--- a/web/hooks/useUpdateModelParameters.ts
+++ b/web/hooks/useUpdateModelParameters.ts
@@ -4,8 +4,6 @@ import {
ConversationalExtension,
ExtensionTypeEnum,
InferenceEngine,
- Model,
- ModelExtension,
Thread,
ThreadAssistantInfo,
} from '@janhq/core'
@@ -17,14 +15,8 @@ import {
extractModelLoadParams,
} from '@/utils/modelParam'
-import useRecommendedModel from './useRecommendedModel'
-
import { extensionManager } from '@/extension'
-import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
-import {
- selectedModelAtom,
- updateDownloadedModelAtom,
-} from '@/helpers/atoms/Model.atom'
+import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
ModelParams,
getActiveThreadModelParamsAtom,
@@ -34,16 +26,14 @@ import {
export type UpdateModelParameter = {
params?: ModelParams
modelId?: string
+ modelPath?: string
engine?: InferenceEngine
}
export default function useUpdateModelParameters() {
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
- const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
+ const [selectedModel] = useAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
- const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom)
- const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom)
- const { recommendedModel, setRecommendedModel } = useRecommendedModel()
const updateModelParameter = useCallback(
async (thread: Thread, settings: UpdateModelParameter) => {
@@ -83,50 +73,8 @@ export default function useUpdateModelParameters() {
await extensionManager
.get(ExtensionTypeEnum.Conversational)
?.saveThread(updatedThread)
-
- // Persists default settings to model file
- // Do not overwrite ctx_len and max_tokens
- if (preserveModelFeatureEnabled) {
- const defaultContextLength = settingParams.ctx_len
- const defaultMaxTokens = runtimeParams.max_tokens
-
- // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars
- const { ctx_len, ...toSaveSettings } = settingParams
- // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars
- const { max_tokens, ...toSaveParams } = runtimeParams
-
- const updatedModel = {
- id: settings.modelId ?? selectedModel?.id,
- parameters: {
- ...toSaveSettings,
- },
- settings: {
- ...toSaveParams,
- },
- metadata: {
- default_ctx_len: defaultContextLength,
- default_max_tokens: defaultMaxTokens,
- },
- } as Partial
-
- const model = await extensionManager
- .get(ExtensionTypeEnum.Model)
- ?.updateModelInfo(updatedModel)
- if (model) updateDownloadedModel(model)
- if (selectedModel?.id === model?.id) setSelectedModel(model)
- if (recommendedModel?.id === model?.id) setRecommendedModel(model)
- }
},
- [
- activeModelParams,
- selectedModel,
- setThreadModelParams,
- preserveModelFeatureEnabled,
- updateDownloadedModel,
- setSelectedModel,
- recommendedModel,
- setRecommendedModel,
- ]
+ [activeModelParams, selectedModel, setThreadModelParams]
)
const processStopWords = (params: ModelParams): ModelParams => {
diff --git a/web/screens/Hub/ModelList/ModelHeader/index.tsx b/web/screens/Hub/ModelList/ModelHeader/index.tsx
index b20977aff..44a3fd278 100644
--- a/web/screens/Hub/ModelList/ModelHeader/index.tsx
+++ b/web/screens/Hub/ModelList/ModelHeader/index.tsx
@@ -1,6 +1,6 @@
import { useCallback } from 'react'
-import { Model } from '@janhq/core'
+import { ModelFile } from '@janhq/core'
import { Button, Badge, Tooltip } from '@janhq/joi'
import { useAtomValue, useSetAtom } from 'jotai'
@@ -38,7 +38,7 @@ import {
} from '@/helpers/atoms/SystemBar.atom'
type Props = {
- model: Model
+ model: ModelFile
onClick: () => void
open: string
}
diff --git a/web/screens/Hub/ModelList/ModelItem/index.tsx b/web/screens/Hub/ModelList/ModelItem/index.tsx
index c9b2f1329..ec9d885a1 100644
--- a/web/screens/Hub/ModelList/ModelItem/index.tsx
+++ b/web/screens/Hub/ModelList/ModelItem/index.tsx
@@ -1,6 +1,6 @@
import { useState } from 'react'
-import { Model } from '@janhq/core'
+import { ModelFile } from '@janhq/core'
import { Badge } from '@janhq/joi'
import { twMerge } from 'tailwind-merge'
@@ -12,7 +12,7 @@ import ModelItemHeader from '@/screens/Hub/ModelList/ModelHeader'
import { toGibibytes } from '@/utils/converter'
type Props = {
- model: Model
+ model: ModelFile
}
const ModelItem: React.FC = ({ model }) => {
diff --git a/web/screens/Hub/ModelList/index.tsx b/web/screens/Hub/ModelList/index.tsx
index aea67b4e3..8fc30d541 100644
--- a/web/screens/Hub/ModelList/index.tsx
+++ b/web/screens/Hub/ModelList/index.tsx
@@ -1,6 +1,6 @@
import { useMemo } from 'react'
-import { Model } from '@janhq/core'
+import { ModelFile } from '@janhq/core'
import { useAtomValue } from 'jotai'
@@ -9,16 +9,16 @@ import ModelItem from '@/screens/Hub/ModelList/ModelItem'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
type Props = {
- models: Model[]
+ models: ModelFile[]
}
const ModelList = ({ models }: Props) => {
const downloadedModels = useAtomValue(downloadedModelsAtom)
- const sortedModels: Model[] = useMemo(() => {
- const featuredModels: Model[] = []
- const remoteModels: Model[] = []
- const localModels: Model[] = []
- const remainingModels: Model[] = []
+ const sortedModels: ModelFile[] = useMemo(() => {
+ const featuredModels: ModelFile[] = []
+ const remoteModels: ModelFile[] = []
+ const localModels: ModelFile[] = []
+ const remainingModels: ModelFile[] = []
models.forEach((m) => {
if (m.metadata?.tags?.includes('Featured')) {
featuredModels.push(m)
diff --git a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx
index 951a11d59..c3f09f171 100644
--- a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx
+++ b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx
@@ -53,7 +53,7 @@ const ModelDownloadRow: React.FC = ({
const { requestCreateNewThread } = useCreateNewThread()
const setMainViewState = useSetAtom(mainViewStateAtom)
const assistants = useAtomValue(assistantsAtom)
- const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null
+ const downloadedModel = downloadedModels.find((md) => md.id === fileName)
const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom)
const defaultModel = useAtomValue(defaultModelAtom)
@@ -100,12 +100,12 @@ const ModelDownloadRow: React.FC = ({
alert('No assistant available')
return
}
- await requestCreateNewThread(assistants[0], model)
+ await requestCreateNewThread(assistants[0], downloadedModel)
setMainViewState(MainViewState.Thread)
setHfImportingStage('NONE')
}, [
assistants,
- model,
+ downloadedModel,
requestCreateNewThread,
setMainViewState,
setHfImportingStage,
@@ -139,7 +139,7 @@ const ModelDownloadRow: React.FC = ({
- {isDownloaded ? (
+ {downloadedModel ? (
Date: Tue, 17 Sep 2024 18:17:23 +0700
Subject: [PATCH 06/37] Fix: #1142 setting groups toggle does not turn off it's
nested settings (#3681)
* fix: #1142 - Toggle off experimental toggle does not turn off gated features
* test: add tests
---
.../Settings/Advanced/DataFolder/index.tsx | 1 +
.../Settings/Advanced/FactoryReset/index.tsx | 6 +-
web/screens/Settings/Advanced/index.test.tsx | 154 ++++++++++++++++++
web/screens/Settings/Advanced/index.tsx | 70 ++++++--
4 files changed, 214 insertions(+), 17 deletions(-)
create mode 100644 web/screens/Settings/Advanced/index.test.tsx
diff --git a/web/screens/Settings/Advanced/DataFolder/index.tsx b/web/screens/Settings/Advanced/DataFolder/index.tsx
index 3bb059a87..985dc65c3 100644
--- a/web/screens/Settings/Advanced/DataFolder/index.tsx
+++ b/web/screens/Settings/Advanced/DataFolder/index.tsx
@@ -100,6 +100,7 @@ const DataFolder = () => {
{
recommended only if the application is in a corrupted state.
-
setModalValidation(true)}>
+ setModalValidation(true)}
+ >
Reset
diff --git a/web/screens/Settings/Advanced/index.test.tsx b/web/screens/Settings/Advanced/index.test.tsx
new file mode 100644
index 000000000..10ea810b1
--- /dev/null
+++ b/web/screens/Settings/Advanced/index.test.tsx
@@ -0,0 +1,154 @@
+import React from 'react'
+import { render, screen, fireEvent, waitFor } from '@testing-library/react'
+import '@testing-library/jest-dom'
+import Advanced from '.'
+
+class ResizeObserverMock {
+ observe() {}
+ unobserve() {}
+ disconnect() {}
+}
+
+global.ResizeObserver = ResizeObserverMock
+// @ts-ignore
+global.window.core = {
+ api: {
+ getAppConfigurations: () => jest.fn(),
+ updateAppConfiguration: () => jest.fn(),
+ relaunch: () => jest.fn(),
+ },
+}
+
+const setSettingsMock = jest.fn()
+
+// Mock useSettings hook
+jest.mock('@/hooks/useSettings', () => ({
+ __esModule: true,
+ useSettings: () => ({
+ readSettings: () => ({
+ run_mode: 'gpu',
+ experimental: false,
+ proxy: false,
+ gpus: [{ name: 'gpu-1' }, { name: 'gpu-2' }],
+ gpus_in_use: ['0'],
+ quick_ask: false,
+ }),
+ setSettings: setSettingsMock,
+ }),
+}))
+
+import * as toast from '@/containers/Toast'
+
+jest.mock('@/containers/Toast')
+
+jest.mock('@janhq/core', () => ({
+ __esModule: true,
+ ...jest.requireActual('@janhq/core'),
+ fs: {
+ rm: jest.fn(),
+ },
+}))
+
+// Simulate a full advanced settings screen
+// @ts-ignore
+global.isMac = false
+// @ts-ignore
+global.isWindows = true
+
+describe('Advanced', () => {
+ it('renders the component', async () => {
+ render( )
+ await waitFor(() => {
+ expect(screen.getByText('Experimental Mode')).toBeInTheDocument()
+ expect(screen.getByText('HTTPS Proxy')).toBeInTheDocument()
+ expect(screen.getByText('Ignore SSL certificates')).toBeInTheDocument()
+ expect(screen.getByText('Jan Data Folder')).toBeInTheDocument()
+ expect(screen.getByText('Reset to Factory Settings')).toBeInTheDocument()
+ })
+ })
+
+ it('updates Experimental enabled', async () => {
+ render( )
+ let experimentalToggle
+ await waitFor(() => {
+ experimentalToggle = screen.getByTestId(/experimental-switch/i)
+ fireEvent.click(experimentalToggle!)
+ })
+ expect(experimentalToggle).toBeChecked()
+ })
+
+ it('updates Experimental disabled', async () => {
+ render( )
+
+ let experimentalToggle
+ await waitFor(() => {
+ experimentalToggle = screen.getByTestId(/experimental-switch/i)
+ fireEvent.click(experimentalToggle!)
+ })
+ expect(experimentalToggle).not.toBeChecked()
+ })
+
+ it('clears logs', async () => {
+ const jestMock = jest.fn()
+ jest.spyOn(toast, 'toaster').mockImplementation(jestMock)
+
+ render( )
+ let clearLogsButton
+ await waitFor(() => {
+ clearLogsButton = screen.getByTestId(/clear-logs/i)
+ fireEvent.click(clearLogsButton)
+ })
+ expect(clearLogsButton).toBeInTheDocument()
+ expect(jestMock).toHaveBeenCalled()
+ })
+
+ it('toggles proxy enabled', async () => {
+ render( )
+ let proxyToggle
+ await waitFor(() => {
+ expect(screen.getByText('HTTPS Proxy')).toBeInTheDocument()
+ proxyToggle = screen.getByTestId(/proxy-switch/i)
+ fireEvent.click(proxyToggle)
+ })
+ expect(proxyToggle).toBeChecked()
+ })
+
+ it('updates proxy settings', async () => {
+ render( )
+ let proxyInput
+ await waitFor(() => {
+ const proxyToggle = screen.getByTestId(/proxy-switch/i)
+ fireEvent.click(proxyToggle)
+ proxyInput = screen.getByTestId(/proxy-input/i)
+ fireEvent.change(proxyInput, { target: { value: 'http://proxy.com' } })
+ })
+ expect(proxyInput).toHaveValue('http://proxy.com')
+ })
+
+ it('toggles ignore SSL certificates', async () => {
+ render( )
+ let ignoreSslToggle
+ await waitFor(() => {
+ expect(screen.getByText('Ignore SSL certificates')).toBeInTheDocument()
+ ignoreSslToggle = screen.getByTestId(/ignore-ssl-switch/i)
+ fireEvent.click(ignoreSslToggle)
+ })
+ expect(ignoreSslToggle).toBeChecked()
+ })
+
+ it('renders DataFolder component', async () => {
+ render( )
+ await waitFor(() => {
+ expect(screen.getByText('Jan Data Folder')).toBeInTheDocument()
+ expect(screen.getByTestId(/jan-data-folder-input/i)).toBeInTheDocument()
+ })
+ })
+
+ it('renders FactoryReset component', async () => {
+ render( )
+ await waitFor(() => {
+ expect(screen.getByText('Reset to Factory Settings')).toBeInTheDocument()
+ expect(screen.getByTestId(/reset-button/i)).toBeInTheDocument()
+ })
+ })
+})
diff --git a/web/screens/Settings/Advanced/index.tsx b/web/screens/Settings/Advanced/index.tsx
index f132f81e7..1384f5688 100644
--- a/web/screens/Settings/Advanced/index.tsx
+++ b/web/screens/Settings/Advanced/index.tsx
@@ -43,19 +43,10 @@ type GPU = {
name: string
}
-const test = [
- {
- id: 'test a',
- vram: 2,
- name: 'nvidia A',
- },
- {
- id: 'test',
- vram: 2,
- name: 'nvidia B',
- },
-]
-
+/**
+ * Advanced Settings Screen
+ * @returns
+ */
const Advanced = () => {
const [experimentalEnabled, setExperimentalEnabled] = useAtom(
experimentalFeatureEnabledAtom
@@ -69,7 +60,7 @@ const Advanced = () => {
const [partialProxy, setPartialProxy] = useState(proxy)
const [gpuEnabled, setGpuEnabled] = useState(false)
- const [gpuList, setGpuList] = useState(test)
+ const [gpuList, setGpuList] = useState([])
const [gpusInUse, setGpusInUse] = useState([])
const [dropdownOptions, setDropdownOptions] = useState(
null
@@ -87,6 +78,9 @@ const Advanced = () => {
return y['name']
})
+ /**
+ * Handle proxy change
+ */
const onProxyChange = useCallback(
(event: ChangeEvent) => {
const value = event.target.value || ''
@@ -100,6 +94,12 @@ const Advanced = () => {
[setPartialProxy, setProxy]
)
+ /**
+ * Update Quick Ask Enabled
+ * @param e
+ * @param relaunch
+ * @returns void
+ */
const updateQuickAskEnabled = async (
e: boolean,
relaunch: boolean = true
@@ -111,6 +111,12 @@ const Advanced = () => {
if (relaunch) window.core?.api?.relaunch()
}
+ /**
+ * Update Vulkan Enabled
+ * @param e
+ * @param relaunch
+ * @returns void
+ */
const updateVulkanEnabled = async (e: boolean, relaunch: boolean = true) => {
toaster({
title: 'Reload',
@@ -123,11 +129,19 @@ const Advanced = () => {
if (relaunch) window.location.reload()
}
+ /**
+ * Update Experimental Enabled
+ * @param e
+ * @returns
+ */
const updateExperimentalEnabled = async (
e: ChangeEvent
) => {
setExperimentalEnabled(e.target.checked)
- if (e) return
+
+ // If it checked, we don't need to do anything else
+ // Otherwise have to reset other settings
+ if (e.target.checked) return
// It affects other settings, so we need to reset them
const isRelaunch = quickAskEnabled || vulkanEnabled
@@ -136,6 +150,9 @@ const Advanced = () => {
if (isRelaunch) window.core?.api?.relaunch()
}
+ /**
+ * useEffect to set GPU enabled if possible
+ */
useEffect(() => {
const setUseGpuIfPossible = async () => {
const settings = await readSettings()
@@ -149,6 +166,10 @@ const Advanced = () => {
setUseGpuIfPossible()
}, [readSettings, setGpuList, setGpuEnabled, setGpusInUse, setVulkanEnabled])
+ /**
+ * Clear logs
+ * @returns
+ */
const clearLogs = async () => {
try {
await fs.rm(`file://logs`)
@@ -163,6 +184,11 @@ const Advanced = () => {
})
}
+ /**
+ * Handle GPU Change
+ * @param gpuId
+ * @returns
+ */
const handleGPUChange = (gpuId: string) => {
let updatedGpusInUse = [...gpusInUse]
if (updatedGpusInUse.includes(gpuId)) {
@@ -188,6 +214,9 @@ const Advanced = () => {
const gpuSelectionPlaceHolder =
gpuList.length > 0 ? 'Select GPU' : "You don't have any compatible GPU"
+ /**
+ * Handle click outside
+ */
useClickOutside(() => setOpen(false), null, [dropdownOptions, toggle])
return (
@@ -204,6 +233,7 @@ const Advanced = () => {
@@ -401,11 +431,13 @@ const Advanced = () => {
setProxyEnabled(!proxyEnabled)}
/>
setIgnoreSSL(e.target.checked)}
/>
@@ -448,6 +481,7 @@ const Advanced = () => {
{
toaster({
@@ -471,7 +505,11 @@ const Advanced = () => {
Clear all logs from Jan app.
-
+
Clear
From c62b6e984282003d14160ce1b222c66fa4b79038 Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Tue, 17 Sep 2024 22:13:18 +0700
Subject: [PATCH 07/37] fix: small leftover issues with new starter screen
(#3661)
* fix: fix duplicate render progress component
* fix: minor ui issue
* chore: add manual recommend model
* chore: make button create thread invisible
* chore: fix conflict
* chore: remove selector create thread icon
* test: added unit test thread screen
---
electron/tests/e2e/thread.e2e.spec.ts | 29 ++++---
web/containers/Layout/RibbonPanel/index.tsx | 16 ++--
web/containers/Layout/TopPanel/index.tsx | 5 +-
web/helpers/atoms/Thread.atom.ts | 3 +
web/hooks/useStarterScreen.ts | 7 +-
.../ChatBody/OnDeviceStarterScreen/index.tsx | 78 +++++++++++--------
web/screens/Thread/index.test.tsx | 35 +++++++++
7 files changed, 109 insertions(+), 64 deletions(-)
create mode 100644 web/screens/Thread/index.test.tsx
diff --git a/electron/tests/e2e/thread.e2e.spec.ts b/electron/tests/e2e/thread.e2e.spec.ts
index c13e91119..5d7328053 100644
--- a/electron/tests/e2e/thread.e2e.spec.ts
+++ b/electron/tests/e2e/thread.e2e.spec.ts
@@ -1,32 +1,29 @@
import { expect } from '@playwright/test'
import { page, test, TIMEOUT } from '../config/fixtures'
-test('Select GPT model from Hub and Chat with Invalid API Key', async ({ hubPage }) => {
+test('Select GPT model from Hub and Chat with Invalid API Key', async ({
+ hubPage,
+}) => {
await hubPage.navigateByMenu()
await hubPage.verifyContainerVisible()
// Select the first GPT model
await page
.locator('[data-testid^="use-model-btn"][data-testid*="gpt"]')
- .first().click()
-
- // Attempt to create thread and chat in Thread page
- await page
- .getByTestId('btn-create-thread')
+ .first()
.click()
- await page
- .getByTestId('txt-input-chat')
- .fill('dummy value')
+ await page.getByTestId('txt-input-chat').fill('dummy value')
- await page
- .getByTestId('btn-send-chat')
- .click()
+ await page.getByTestId('btn-send-chat').click()
- await page.waitForFunction(() => {
- const loaders = document.querySelectorAll('[data-testid$="loader"]');
- return !loaders.length;
- }, { timeout: TIMEOUT });
+ await page.waitForFunction(
+ () => {
+ const loaders = document.querySelectorAll('[data-testid$="loader"]')
+ return !loaders.length
+ },
+ { timeout: TIMEOUT }
+ )
const APIKeyError = page.getByTestId('invalid-API-key-error')
await expect(APIKeyError).toBeVisible({
diff --git a/web/containers/Layout/RibbonPanel/index.tsx b/web/containers/Layout/RibbonPanel/index.tsx
index 6bed2b424..7613584e0 100644
--- a/web/containers/Layout/RibbonPanel/index.tsx
+++ b/web/containers/Layout/RibbonPanel/index.tsx
@@ -12,17 +12,18 @@ import { twMerge } from 'tailwind-merge'
import { MainViewState } from '@/constants/screens'
-import { localEngines } from '@/utils/modelEngine'
-
import { mainViewStateAtom, showLeftPanelAtom } from '@/helpers/atoms/App.atom'
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
-import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
+
import {
reduceTransparentAtom,
selectedSettingAtom,
} from '@/helpers/atoms/Setting.atom'
-import { threadsAtom } from '@/helpers/atoms/Thread.atom'
+import {
+ isDownloadALocalModelAtom,
+ threadsAtom,
+} from '@/helpers/atoms/Thread.atom'
export default function RibbonPanel() {
const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom)
@@ -32,8 +33,9 @@ export default function RibbonPanel() {
const matches = useMediaQuery('(max-width: 880px)')
const reduceTransparent = useAtomValue(reduceTransparentAtom)
const setSelectedSetting = useSetAtom(selectedSettingAtom)
- const downloadedModels = useAtomValue(downloadedModelsAtom)
+
const threads = useAtomValue(threadsAtom)
+ const isDownloadALocalModel = useAtomValue(isDownloadALocalModelAtom)
const onMenuClick = (state: MainViewState) => {
if (mainViewState === state) return
@@ -43,10 +45,6 @@ export default function RibbonPanel() {
setEditMessage('')
}
- const isDownloadALocalModel = downloadedModels.some((x) =>
- localEngines.includes(x.engine)
- )
-
const RibbonNavMenus = [
{
name: 'Thread',
diff --git a/web/containers/Layout/TopPanel/index.tsx b/web/containers/Layout/TopPanel/index.tsx
index 213f7dfa9..aff616973 100644
--- a/web/containers/Layout/TopPanel/index.tsx
+++ b/web/containers/Layout/TopPanel/index.tsx
@@ -23,6 +23,7 @@ import { toaster } from '@/containers/Toast'
import { MainViewState } from '@/constants/screens'
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
+import { useStarterScreen } from '@/hooks/useStarterScreen'
import {
mainViewStateAtom,
@@ -58,6 +59,8 @@ const TopPanel = () => {
requestCreateNewThread(assistants[0])
}
+ const { isShowStarterScreen } = useStarterScreen()
+
return (
{
)}
)}
- {mainViewState === MainViewState.Thread && (
+ {mainViewState === MainViewState.Thread && !isShowStarterScreen && (
(false)
+export const isAnyRemoteModelConfiguredAtom = atom(false)
diff --git a/web/hooks/useStarterScreen.ts b/web/hooks/useStarterScreen.ts
index fbd6ef578..1a6bbfbc7 100644
--- a/web/hooks/useStarterScreen.ts
+++ b/web/hooks/useStarterScreen.ts
@@ -63,14 +63,9 @@ export function useStarterScreen() {
(x) => x.apiKey.length > 1
)
- let isShowStarterScreen
-
- isShowStarterScreen =
+ const isShowStarterScreen =
!isAnyRemoteModelConfigured && !isDownloadALocalModel && !threads.length
- // Remove this part when we rework on starter screen
- isShowStarterScreen = false
-
return {
extensionHasSettings,
isShowStarterScreen,
diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
index 3ae32ca8c..26036a627 100644
--- a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
@@ -58,9 +58,21 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
const configuredModels = useAtomValue(configuredModelsAtom)
const setMainViewState = useSetAtom(mainViewStateAtom)
- const featuredModel = configuredModels.filter(
- (x) => x.metadata.tags.includes('Featured') && x.metadata.size < 5000000000
- )
+ const recommendModel = ['gemma-2-2b-it', 'llama3.1-8b-instruct']
+
+ const featuredModel = configuredModels.filter((x) => {
+ const manualRecommendModel = configuredModels.filter((x) =>
+ recommendModel.includes(x.id)
+ )
+
+ if (manualRecommendModel.length === 2) {
+ return x.id === recommendModel[0] || x.id === recommendModel[1]
+ } else {
+ return (
+ x.metadata.tags.includes('Featured') && x.metadata.size < 5000000000
+ )
+ }
+ })
const remoteModel = configuredModels.filter(
(x) => !localEngines.includes(x.engine)
@@ -105,7 +117,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
width={48}
height={48}
/>
- Select a model to start
+ Select a model to start
@@ -120,7 +132,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
/>
@@ -205,39 +217,41 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
return (
-
{featModel.name}
-
+
{featModel.name}
+
{featModel.metadata.author}
{isDownloading ? (
- {Object.values(downloadStates).map((item, i) => (
-
-
-
-
-
- {formatDownloadPercentage(item?.percent)}
-
+ {Object.values(downloadStates)
+ .filter((x) => x.modelId === featModel.id)
+ .map((item, i) => (
+
+
+
+
+
+ {formatDownloadPercentage(item?.percent)}
+
+
-
- ))}
+ ))}
) : (
@@ -248,7 +262,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
>
Download
-
+
{toGibibytes(featModel.metadata.size)}
@@ -257,7 +271,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
)
})}
-
+
Cloud Models
@@ -268,7 +282,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
return (
{row.map((remoteEngine) => {
const engineLogo = getLogoEngine(
@@ -298,7 +312,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
/>
)}
-
+
{getTitleByEngine(
remoteEngine as InferenceEngine
)}
diff --git a/web/screens/Thread/index.test.tsx b/web/screens/Thread/index.test.tsx
new file mode 100644
index 000000000..01af0ffc5
--- /dev/null
+++ b/web/screens/Thread/index.test.tsx
@@ -0,0 +1,35 @@
+import React from 'react'
+import { render, screen } from '@testing-library/react'
+import ThreadScreen from './index'
+import { useStarterScreen } from '../../hooks/useStarterScreen'
+import '@testing-library/jest-dom'
+
+global.ResizeObserver = class {
+ observe() {}
+ unobserve() {}
+ disconnect() {}
+}
+// Mock the useStarterScreen hook
+jest.mock('@/hooks/useStarterScreen')
+
+describe('ThreadScreen', () => {
+ it('renders OnDeviceStarterScreen when isShowStarterScreen is true', () => {
+ ;(useStarterScreen as jest.Mock).mockReturnValue({
+ isShowStarterScreen: true,
+ extensionHasSettings: false,
+ })
+
+ const { getByText } = render( )
+ expect(getByText('Select a model to start')).toBeInTheDocument()
+ })
+
+ it('renders Thread panels when isShowStarterScreen is false', () => {
+ ;(useStarterScreen as jest.Mock).mockReturnValue({
+ isShowStarterScreen: false,
+ extensionHasSettings: false,
+ })
+
+ const { getByText } = render( )
+ expect(getByText('Welcome!')).toBeInTheDocument()
+ })
+})
From 3949515c8a68dee16e2209b513bcf239d2b5343a Mon Sep 17 00:00:00 2001
From: 0xSage
Date: Wed, 18 Sep 2024 17:02:41 +0800
Subject: [PATCH 08/37] chore: copy nits
---
.../BottomPanel/SystemMonitor/TableActiveModel/index.tsx | 2 +-
web/screens/LocalServer/LocalServerLeftPanel/index.tsx | 2 +-
web/screens/Settings/Advanced/FactoryReset/index.tsx | 3 +--
3 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx
index c9d86e5e8..e68f843a9 100644
--- a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx
+++ b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx
@@ -79,7 +79,7 @@ const TableActiveModel = () => {
) : (
- No on-device model running
+ No models are loaded into memory
)}
diff --git a/web/screens/LocalServer/LocalServerLeftPanel/index.tsx b/web/screens/LocalServer/LocalServerLeftPanel/index.tsx
index f66945929..16aa75af5 100644
--- a/web/screens/LocalServer/LocalServerLeftPanel/index.tsx
+++ b/web/screens/LocalServer/LocalServerLeftPanel/index.tsx
@@ -130,7 +130,7 @@ const LocalServerLeftPanel = () => {
{serverEnabled && (
- API Reference
+ API Playground
)}
diff --git a/web/screens/Settings/Advanced/FactoryReset/index.tsx b/web/screens/Settings/Advanced/FactoryReset/index.tsx
index 3bbce39ef..181b0bd4b 100644
--- a/web/screens/Settings/Advanced/FactoryReset/index.tsx
+++ b/web/screens/Settings/Advanced/FactoryReset/index.tsx
@@ -17,8 +17,7 @@ const FactoryReset = () => {
- Reset the application to its initial state, deleting all your usage
- data, including conversation history. This action is irreversible and
+ Restore app to initial state, erasing all models and chat history. This action is irreversible and
recommended only if the application is in a corrupted state.
From 062af9bcda43256b6cc14d6c5dd0cbd927a43c7b Mon Sep 17 00:00:00 2001
From: 0xSage
Date: Wed, 18 Sep 2024 17:42:35 +0800
Subject: [PATCH 09/37] nits
---
.../Settings/Advanced/FactoryReset/ModalConfirmReset.tsx | 5 ++---
web/screens/Settings/Advanced/FactoryReset/index.tsx | 4 ++--
web/screens/Settings/CancelModelImportModal/index.tsx | 3 +--
3 files changed, 5 insertions(+), 7 deletions(-)
diff --git a/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx b/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx
index 8173574a9..268192627 100644
--- a/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx
+++ b/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx
@@ -30,9 +30,8 @@ const ModalConfirmReset = () => {
content={
- It will reset the application to its original state, deleting all
- your usage data, including model customizations and conversation
- history. This action is irreversible.
+ Restore app to initial state, erasing all models and chat history. This
+ action is irreversible and recommended only if the application is corrupted.
diff --git a/web/screens/Settings/Advanced/FactoryReset/index.tsx b/web/screens/Settings/Advanced/FactoryReset/index.tsx
index e79bfe54c..fb789e5b3 100644
--- a/web/screens/Settings/Advanced/FactoryReset/index.tsx
+++ b/web/screens/Settings/Advanced/FactoryReset/index.tsx
@@ -17,8 +17,8 @@ const FactoryReset = () => {
- Restore app to initial state, erasing all models and chat history. This action is irreversible and
- recommended only if the application is in a corrupted state.
+ Restore app to initial state, erasing all models and chat history. This
+ action is irreversible and recommended only if the application is corrupted.
{
The model import process is not complete. Are you sure you want to
- cancel all ongoing model imports? This action is irreversible and
- the progress will be lost.
+ cancel?
From cfe657faf55ef12100a8fbfd078bb5cca616853f Mon Sep 17 00:00:00 2001
From: 0xSage
Date: Wed, 18 Sep 2024 17:46:53 +0800
Subject: [PATCH 10/37] fix: linter
---
.../Settings/Advanced/FactoryReset/ModalConfirmReset.tsx | 5 +++--
web/screens/Settings/Advanced/FactoryReset/index.tsx | 5 +++--
web/screens/Settings/CoreExtensions/ExtensionItem.tsx | 4 ++--
.../Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx | 2 +-
4 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx b/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx
index 268192627..08ac35f04 100644
--- a/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx
+++ b/web/screens/Settings/Advanced/FactoryReset/ModalConfirmReset.tsx
@@ -30,8 +30,9 @@ const ModalConfirmReset = () => {
content={
- Restore app to initial state, erasing all models and chat history. This
- action is irreversible and recommended only if the application is corrupted.
+ Restore app to initial state, erasing all models and chat history.
+ This action is irreversible and recommended only if the application
+ is corrupted.
diff --git a/web/screens/Settings/Advanced/FactoryReset/index.tsx b/web/screens/Settings/Advanced/FactoryReset/index.tsx
index fb789e5b3..10e8cbc20 100644
--- a/web/screens/Settings/Advanced/FactoryReset/index.tsx
+++ b/web/screens/Settings/Advanced/FactoryReset/index.tsx
@@ -17,8 +17,9 @@ const FactoryReset = () => {
- Restore app to initial state, erasing all models and chat history. This
- action is irreversible and recommended only if the application is corrupted.
+ Restore app to initial state, erasing all models and chat history.
+ This action is irreversible and recommended only if the application is
+ corrupted.
= ({ item }) => {
)
const progress = isInstalling
- ? (installingExtensions.find((e) => e.extensionId === item.name)
- ?.percentage ?? -1)
+ ? installingExtensions.find((e) => e.extensionId === item.name)
+ ?.percentage ?? -1
: -1
useEffect(() => {
diff --git a/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx
index abbe6db43..d7d52a093 100644
--- a/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx
@@ -178,7 +178,7 @@ const SimpleTextMessage: React.FC = (props) => {
>
{isUser
? props.role
- : (activeThread?.assistants[0].assistant_name ?? props.role)}
+ : activeThread?.assistants[0].assistant_name ?? props.role}
{displayDate(props.created)}
From 8fe376340a4de0c9c1c998bdce32350ff6b3f23c Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Thu, 19 Sep 2024 10:06:27 +0700
Subject: [PATCH 11/37] chore: fix linter issue CI
---
web/screens/Settings/CoreExtensions/ExtensionItem.tsx | 4 ++--
.../Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/web/screens/Settings/CoreExtensions/ExtensionItem.tsx b/web/screens/Settings/CoreExtensions/ExtensionItem.tsx
index ec72f5f43..497b8ac4a 100644
--- a/web/screens/Settings/CoreExtensions/ExtensionItem.tsx
+++ b/web/screens/Settings/CoreExtensions/ExtensionItem.tsx
@@ -32,8 +32,8 @@ const ExtensionItem: React.FC = ({ item }) => {
)
const progress = isInstalling
- ? installingExtensions.find((e) => e.extensionId === item.name)
- ?.percentage ?? -1
+ ? (installingExtensions.find((e) => e.extensionId === item.name)
+ ?.percentage ?? -1)
: -1
useEffect(() => {
diff --git a/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx
index d7d52a093..abbe6db43 100644
--- a/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx
@@ -178,7 +178,7 @@ const SimpleTextMessage: React.FC = (props) => {
>
{isUser
? props.role
- : activeThread?.assistants[0].assistant_name ?? props.role}
+ : (activeThread?.assistants[0].assistant_name ?? props.role)}
{displayDate(props.created)}
From ba3c07eba8973b184cc0701f90f3c955a8a4b894 Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Thu, 19 Sep 2024 10:10:30 +0700
Subject: [PATCH 12/37] feat: textarea auto resize (#3695)
* feat: improve textarea user experience with autoresize
* chore: remove log
* chore: update test
* chore: update test and cleanup logic useEffect
---
joi/src/core/TextArea/TextArea.test.tsx | 39 ++++++++++++++++++-
joi/src/core/TextArea/index.tsx | 32 ++++++++++++---
web/containers/ModelConfigInput/index.tsx | 2 +-
web/screens/Thread/ThreadRightPanel/index.tsx | 2 +-
4 files changed, 66 insertions(+), 9 deletions(-)
diff --git a/joi/src/core/TextArea/TextArea.test.tsx b/joi/src/core/TextArea/TextArea.test.tsx
index 8bc64010f..e29eed5d0 100644
--- a/joi/src/core/TextArea/TextArea.test.tsx
+++ b/joi/src/core/TextArea/TextArea.test.tsx
@@ -1,9 +1,8 @@
import React from 'react'
-import { render, screen } from '@testing-library/react'
+import { render, screen, act } from '@testing-library/react'
import '@testing-library/jest-dom'
import { TextArea } from './index'
-// Mock the styles import
jest.mock('./styles.scss', () => ({}))
describe('@joi/core/TextArea', () => {
@@ -31,4 +30,40 @@ describe('@joi/core/TextArea', () => {
const textareaElement = screen.getByTestId('custom-textarea')
expect(textareaElement).toHaveAttribute('rows', '5')
})
+
+ it('should auto resize the textarea based on minResize', () => {
+ render()
+
+ const textarea = screen.getByRole('textbox') as HTMLTextAreaElement
+
+ Object.defineProperty(textarea, 'scrollHeight', {
+ value: 20,
+ writable: true,
+ })
+
+ act(() => {
+ textarea.value = 'Short text'
+ textarea.dispatchEvent(new Event('input', { bubbles: true }))
+ })
+
+ expect(textarea.style.height).toBe('10px')
+ })
+
+ it('should auto resize the textarea based on maxResize', () => {
+ render()
+
+ const textarea = screen.getByRole('textbox') as HTMLTextAreaElement
+
+ Object.defineProperty(textarea, 'scrollHeight', {
+ value: 100,
+ writable: true,
+ })
+
+ act(() => {
+ textarea.value = 'A very long text that should exceed max height'
+ textarea.dispatchEvent(new Event('input', { bubbles: true }))
+ })
+
+ expect(textarea.style.height).toBe('40px')
+ })
})
diff --git a/joi/src/core/TextArea/index.tsx b/joi/src/core/TextArea/index.tsx
index 33d6744ad..6807178ff 100644
--- a/joi/src/core/TextArea/index.tsx
+++ b/joi/src/core/TextArea/index.tsx
@@ -1,19 +1,41 @@
-import React, { ReactNode, forwardRef } from 'react'
+import React, { forwardRef, useRef, useEffect } from 'react'
import { twMerge } from 'tailwind-merge'
import './styles.scss'
-import { ScrollArea } from '../ScrollArea'
+
+type ResizeProps = {
+ autoResize?: boolean
+ minResize?: number
+ maxResize?: number
+}
export interface TextAreaProps
- extends React.TextareaHTMLAttributes {}
+ extends ResizeProps,
+ React.TextareaHTMLAttributes {}
const TextArea = forwardRef(
- ({ className, ...props }, ref) => {
+ (
+ { autoResize, minResize = 80, maxResize = 250, className, ...props },
+ ref
+ ) => {
+ const textareaRef = useRef(null)
+
+ useEffect(() => {
+ if (autoResize && textareaRef.current) {
+ const textarea = textareaRef.current
+ textarea.style.height = 'auto'
+ const scrollHeight = textarea.scrollHeight
+ const newHeight = Math.min(maxResize, Math.max(minResize, scrollHeight))
+ textarea.style.height = `${newHeight}px`
+ textarea.style.overflow = newHeight >= maxResize ? 'auto' : 'hidden'
+ }
+ }, [props.value, autoResize, minResize, maxResize])
+
return (
diff --git a/web/containers/ModelConfigInput/index.tsx b/web/containers/ModelConfigInput/index.tsx
index 840e2378e..f0e6ea1f2 100644
--- a/web/containers/ModelConfigInput/index.tsx
+++ b/web/containers/ModelConfigInput/index.tsx
@@ -36,7 +36,7 @@ const ModelConfigInput = ({
From 194093d95df5435c4609515ed73df62dfdc7906d Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Fri, 20 Sep 2024 10:06:20 +0700
Subject: [PATCH 13/37] fix: update the condition for generating the title
(#3702)
---
web/containers/Providers/EventHandler.tsx | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx
index 4809ce83e..2e4db4173 100644
--- a/web/containers/Providers/EventHandler.tsx
+++ b/web/containers/Providers/EventHandler.tsx
@@ -14,12 +14,14 @@ import {
ModelEvent,
Thread,
EngineManager,
+ InferenceEngine,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulidx'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
+import { localEngines } from '@/utils/modelEngine'
import { extractInferenceParams } from '@/utils/modelParam'
import { extensionManager } from '@/extension'
@@ -234,7 +236,11 @@ export default function EventHandler({ children }: { children: ReactNode }) {
return
}
- if (!activeModelRef.current) {
+ // Check model engine; we don't want to generate a title when it's not a local engine.
+ if (
+ !activeModelRef.current ||
+ !localEngines.includes(activeModelRef.current?.engine as InferenceEngine)
+ ) {
return
}
From 1aefb8f7abae2f6d237711d436ed55b001516971 Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Fri, 20 Sep 2024 10:06:37 +0700
Subject: [PATCH 14/37] fix: enhance several minor UI (#3706)
---
web/helpers/atoms/Setting.atom.ts | 2 +-
web/screens/Settings/Advanced/index.tsx | 4 ++--
web/screens/Settings/Appearance/index.tsx | 2 +-
.../Settings/MyModels/MyModelList/index.tsx | 16 +++++++++-------
.../SettingDetailTextInputItem/index.tsx | 1 +
5 files changed, 14 insertions(+), 11 deletions(-)
diff --git a/web/helpers/atoms/Setting.atom.ts b/web/helpers/atoms/Setting.atom.ts
index ced0fbe37..57ca87854 100644
--- a/web/helpers/atoms/Setting.atom.ts
+++ b/web/helpers/atoms/Setting.atom.ts
@@ -19,4 +19,4 @@ export const reduceTransparentAtom = atomWithStorage(
REDUCE_TRANSPARENT,
false
)
-export const spellCheckAtom = atomWithStorage(SPELL_CHECKING, true)
+export const spellCheckAtom = atomWithStorage(SPELL_CHECKING, false)
diff --git a/web/screens/Settings/Advanced/index.tsx b/web/screens/Settings/Advanced/index.tsx
index 1384f5688..2c444371a 100644
--- a/web/screens/Settings/Advanced/index.tsx
+++ b/web/screens/Settings/Advanced/index.tsx
@@ -321,7 +321,7 @@ const Advanced = () => {
Choose device(s)
-
+
{
/>
Spell Check
- Toggle to disable spell checking.
+ Turn on to enable spell check
diff --git a/web/screens/Settings/MyModels/MyModelList/index.tsx b/web/screens/Settings/MyModels/MyModelList/index.tsx
index 88fff73f7..329248923 100644
--- a/web/screens/Settings/MyModels/MyModelList/index.tsx
+++ b/web/screens/Settings/MyModels/MyModelList/index.tsx
@@ -49,11 +49,11 @@ const MyModelList = ({ model }: Props) => {
return (
-
-
+
+
{
{model.engine === InferenceEngine.nitro && (
{model.id}
@@ -76,9 +76,11 @@ const MyModelList = ({ model }: Props) => {
{localEngines.includes(model.engine) && (
-
- {toGibibytes(model.metadata.size)}
-
+
+
+ {toGibibytes(model.metadata.size)}
+
+
{stateModel.loading && stateModel.model?.id === model.id ? (
diff --git a/web/screens/Settings/SettingDetail/SettingDetailItem/SettingDetailTextInputItem/index.tsx b/web/screens/Settings/SettingDetail/SettingDetailItem/SettingDetailTextInputItem/index.tsx
index b6a204e2e..3127f1578 100644
--- a/web/screens/Settings/SettingDetail/SettingDetailItem/SettingDetailTextInputItem/index.tsx
+++ b/web/screens/Settings/SettingDetail/SettingDetailItem/SettingDetailTextInputItem/index.tsx
@@ -79,6 +79,7 @@ const SettingDetailTextInputItem = ({
textAlign={textAlign}
value={value}
onChange={(e) => onValueChanged?.(e.target.value)}
+ className="!pr-20"
suffixIcon={
Date: Fri, 20 Sep 2024 14:24:51 +0700
Subject: [PATCH 15/37] test: add web helpers, services, utils tests (#3669)
* test: add web helpers tests
* fix: coverage report
* test: add more tests
* test: add more generated tests
* chore: add more tests
* test: add more tests
---
.gitignore | 1 +
core/jest.config.js | 1 +
core/src/browser/extension.test.ts | 105 +++++++
.../extensions/engines/AIEngine.test.ts | 59 ++++
.../extensions/engines/EngineManager.test.ts | 43 +++
.../extensions/engines/LocalOAIEngine.test.ts | 100 +++++++
.../extensions/engines/OAIEngine.test.ts | 119 ++++++++
.../engines/RemoteOAIEngine.test.ts | 43 +++
.../extensions/engines/helpers/sse.test.ts | 64 +++++
.../browser/extensions/engines/index.test.ts | 6 +
core/src/browser/extensions/index.test.ts | 32 +++
core/src/browser/fs.test.ts | 97 +++++++
core/src/browser/tools/tool.test.ts | 55 ++++
.../src/node/api/processors/processor.test.ts | 0
.../node/api/restful/helper/builder.test.ts | 264 ++++++++++++++++++
core/src/node/api/restful/helper/builder.ts | 6 +-
.../api/restful/helper/startStopModel.test.ts | 16 ++
core/src/node/extension/index.test.ts | 7 +
core/src/types/api/index.test.ts | 24 ++
core/src/types/config/appConfigEvent.test.ts | 9 +
.../huggingface/huggingfaceEntity.test.ts | 28 ++
core/src/types/huggingface/index.test.ts | 8 +
core/src/types/index.test.ts | 28 ++
.../types/inference/inferenceEntity.test.ts | 13 +
.../types/inference/inferenceEvent.test.ts | 7 +
core/src/types/message/messageEvent.test.ts | 7 +
.../types/message/messageRequestType.test.ts | 7 +
core/src/types/model/modelEntity.test.ts | 30 ++
core/src/types/model/modelEvent.test.ts | 7 +
joi/jest.config.js | 1 +
package.json | 2 +-
testRunner.js | 19 ++
web/containers/Loader/Loader.test.tsx | 23 ++
web/extension/Extension.test.ts | 19 ++
web/extension/Extension.ts | 6 +-
web/extension/ExtensionManager.test.ts | 131 +++++++++
web/extension/index.test.ts | 9 +
web/helpers/atoms/ApiServer.atom.test.ts | 9 +
web/helpers/atoms/App.atom.test.ts | 8 +
web/helpers/atoms/AppConfig.atom.test.ts | 7 +
web/helpers/atoms/Assistant.atom.test.ts | 8 +
web/helpers/atoms/ChatMessage.atom.test.ts | 32 +++
web/helpers/atoms/HuggingFace.atom.test.ts | 14 +
web/helpers/atoms/LocalServer.atom.test.ts | 7 +
web/helpers/atoms/Setting.atom.test.ts | 7 +
.../atoms/ThreadRightPanel.atom.test.ts | 6 +
web/hooks/useDownloadState.test.ts | 109 ++++++++
web/jest.config.js | 3 +-
web/screens/Settings/Advanced/index.test.tsx | 5 +-
web/services/appService.test.ts | 30 ++
web/services/eventsService.test.ts | 47 ++++
web/services/extensionService.test.ts | 35 +++
web/services/restService.test.ts | 15 +
web/utils/json.test.ts | 22 ++
web/utils/modelParam.test.ts | 8 +
web/utils/threadMessageBuilder.test.ts | 27 ++
56 files changed, 1786 insertions(+), 9 deletions(-)
create mode 100644 core/src/browser/extensions/engines/AIEngine.test.ts
create mode 100644 core/src/browser/extensions/engines/EngineManager.test.ts
create mode 100644 core/src/browser/extensions/engines/LocalOAIEngine.test.ts
create mode 100644 core/src/browser/extensions/engines/OAIEngine.test.ts
create mode 100644 core/src/browser/extensions/engines/RemoteOAIEngine.test.ts
create mode 100644 core/src/browser/extensions/engines/index.test.ts
create mode 100644 core/src/browser/extensions/index.test.ts
create mode 100644 core/src/browser/fs.test.ts
create mode 100644 core/src/browser/tools/tool.test.ts
delete mode 100644 core/src/node/api/processors/processor.test.ts
create mode 100644 core/src/node/api/restful/helper/builder.test.ts
create mode 100644 core/src/node/api/restful/helper/startStopModel.test.ts
create mode 100644 core/src/node/extension/index.test.ts
create mode 100644 core/src/types/api/index.test.ts
create mode 100644 core/src/types/config/appConfigEvent.test.ts
create mode 100644 core/src/types/huggingface/huggingfaceEntity.test.ts
create mode 100644 core/src/types/huggingface/index.test.ts
create mode 100644 core/src/types/index.test.ts
create mode 100644 core/src/types/inference/inferenceEntity.test.ts
create mode 100644 core/src/types/inference/inferenceEvent.test.ts
create mode 100644 core/src/types/message/messageEvent.test.ts
create mode 100644 core/src/types/message/messageRequestType.test.ts
create mode 100644 core/src/types/model/modelEntity.test.ts
create mode 100644 core/src/types/model/modelEvent.test.ts
create mode 100644 testRunner.js
create mode 100644 web/containers/Loader/Loader.test.tsx
create mode 100644 web/extension/Extension.test.ts
create mode 100644 web/extension/ExtensionManager.test.ts
create mode 100644 web/extension/index.test.ts
create mode 100644 web/helpers/atoms/ApiServer.atom.test.ts
create mode 100644 web/helpers/atoms/App.atom.test.ts
create mode 100644 web/helpers/atoms/AppConfig.atom.test.ts
create mode 100644 web/helpers/atoms/Assistant.atom.test.ts
create mode 100644 web/helpers/atoms/ChatMessage.atom.test.ts
create mode 100644 web/helpers/atoms/HuggingFace.atom.test.ts
create mode 100644 web/helpers/atoms/LocalServer.atom.test.ts
create mode 100644 web/helpers/atoms/Setting.atom.test.ts
create mode 100644 web/helpers/atoms/ThreadRightPanel.atom.test.ts
create mode 100644 web/hooks/useDownloadState.test.ts
create mode 100644 web/services/appService.test.ts
create mode 100644 web/services/eventsService.test.ts
create mode 100644 web/services/extensionService.test.ts
create mode 100644 web/services/restService.test.ts
create mode 100644 web/utils/json.test.ts
create mode 100644 web/utils/threadMessageBuilder.test.ts
diff --git a/.gitignore b/.gitignore
index eaee28a62..f28d152d9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -45,4 +45,5 @@ core/test_results.html
coverage
.yarn
.yarnrc
+test_results.html
*.tsbuildinfo
diff --git a/core/jest.config.js b/core/jest.config.js
index 6c805f1c9..2f652dd39 100644
--- a/core/jest.config.js
+++ b/core/jest.config.js
@@ -1,6 +1,7 @@
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
+ collectCoverageFrom: ['src/**/*.{ts,tsx}'],
moduleNameMapper: {
'@/(.*)': '/src/$1',
},
diff --git a/core/src/browser/extension.test.ts b/core/src/browser/extension.test.ts
index 6c1cd8579..2db14a24e 100644
--- a/core/src/browser/extension.test.ts
+++ b/core/src/browser/extension.test.ts
@@ -1,4 +1,9 @@
import { BaseExtension } from './extension'
+import { SettingComponentProps } from '../types'
+import { getJanDataFolderPath, joinPath } from './core'
+import { fs } from './fs'
+jest.mock('./core')
+jest.mock('./fs')
class TestBaseExtension extends BaseExtension {
onLoad(): void {}
@@ -44,3 +49,103 @@ describe('BaseExtension', () => {
// Add your assertions here
})
})
+
+describe('BaseExtension', () => {
+ class TestBaseExtension extends BaseExtension {
+ onLoad(): void {}
+ onUnload(): void {}
+ }
+
+ let baseExtension: TestBaseExtension
+
+ beforeEach(() => {
+ baseExtension = new TestBaseExtension('https://example.com', 'TestExtension')
+ })
+
+ afterEach(() => {
+ jest.resetAllMocks()
+ })
+
+ it('should have the correct properties', () => {
+ expect(baseExtension.name).toBe('TestExtension')
+ expect(baseExtension.productName).toBeUndefined()
+ expect(baseExtension.url).toBe('https://example.com')
+ expect(baseExtension.active).toBeUndefined()
+ expect(baseExtension.description).toBeUndefined()
+ expect(baseExtension.version).toBeUndefined()
+ })
+
+ it('should return undefined for type()', () => {
+ expect(baseExtension.type()).toBeUndefined()
+ })
+
+ it('should have abstract methods onLoad() and onUnload()', () => {
+ expect(baseExtension.onLoad).toBeDefined()
+ expect(baseExtension.onUnload).toBeDefined()
+ })
+
+ it('should have installationState() return "NotRequired"', async () => {
+ const installationState = await baseExtension.installationState()
+ expect(installationState).toBe('NotRequired')
+ })
+
+ it('should install the extension', async () => {
+ await baseExtension.install()
+ // Add your assertions here
+ })
+
+ it('should register settings', async () => {
+ const settings: SettingComponentProps[] = [
+ { key: 'setting1', controllerProps: { value: 'value1' } } as any,
+ { key: 'setting2', controllerProps: { value: 'value2' } } as any,
+ ]
+
+ ;(getJanDataFolderPath as jest.Mock).mockResolvedValue('/data')
+ ;(joinPath as jest.Mock).mockResolvedValue('/data/settings/TestExtension')
+ ;(fs.existsSync as jest.Mock).mockResolvedValue(false)
+ ;(fs.mkdir as jest.Mock).mockResolvedValue(undefined)
+ ;(fs.writeFileSync as jest.Mock).mockResolvedValue(undefined)
+
+ await baseExtension.registerSettings(settings)
+
+ expect(fs.mkdir).toHaveBeenCalledWith('/data/settings/TestExtension')
+ expect(fs.writeFileSync).toHaveBeenCalledWith(
+ '/data/settings/TestExtension',
+ JSON.stringify(settings, null, 2)
+ )
+ })
+
+ it('should get setting with default value', async () => {
+ const settings: SettingComponentProps[] = [
+ { key: 'setting1', controllerProps: { value: 'value1' } } as any,
+ ]
+
+ jest.spyOn(baseExtension, 'getSettings').mockResolvedValue(settings)
+
+ const value = await baseExtension.getSetting('setting1', 'defaultValue')
+ expect(value).toBe('value1')
+
+ const defaultValue = await baseExtension.getSetting('setting2', 'defaultValue')
+ expect(defaultValue).toBe('defaultValue')
+ })
+
+ it('should update settings', async () => {
+ const settings: SettingComponentProps[] = [
+ { key: 'setting1', controllerProps: { value: 'value1' } } as any,
+ ]
+
+ jest.spyOn(baseExtension, 'getSettings').mockResolvedValue(settings)
+ ;(getJanDataFolderPath as jest.Mock).mockResolvedValue('/data')
+ ;(joinPath as jest.Mock).mockResolvedValue('/data/settings/TestExtension/settings.json')
+ ;(fs.writeFileSync as jest.Mock).mockResolvedValue(undefined)
+
+ await baseExtension.updateSettings([
+ { key: 'setting1', controllerProps: { value: 'newValue' } } as any,
+ ])
+
+ expect(fs.writeFileSync).toHaveBeenCalledWith(
+ '/data/settings/TestExtension/settings.json',
+ JSON.stringify([{ key: 'setting1', controllerProps: { value: 'newValue' } }], null, 2)
+ )
+ })
+})
diff --git a/core/src/browser/extensions/engines/AIEngine.test.ts b/core/src/browser/extensions/engines/AIEngine.test.ts
new file mode 100644
index 000000000..59dad280f
--- /dev/null
+++ b/core/src/browser/extensions/engines/AIEngine.test.ts
@@ -0,0 +1,59 @@
+import { AIEngine } from './AIEngine'
+import { events } from '../../events'
+import { ModelEvent, Model, ModelFile, InferenceEngine } from '../../../types'
+import { EngineManager } from './EngineManager'
+import { fs } from '../../fs'
+
+jest.mock('../../events')
+jest.mock('./EngineManager')
+jest.mock('../../fs')
+
+class TestAIEngine extends AIEngine {
+ onUnload(): void {}
+ provider = 'test-provider'
+
+ inference(data: any) {}
+
+ stopInference() {}
+}
+
+describe('AIEngine', () => {
+ let engine: TestAIEngine
+
+ beforeEach(() => {
+ engine = new TestAIEngine('', '')
+ jest.clearAllMocks()
+ })
+
+ it('should load model if provider matches', async () => {
+ const model: ModelFile = { id: 'model1', engine: 'test-provider' } as any
+
+ await engine.loadModel(model)
+
+ expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelReady, model)
+ })
+
+ it('should not load model if provider does not match', async () => {
+ const model: ModelFile = { id: 'model1', engine: 'other-provider' } as any
+
+ await engine.loadModel(model)
+
+ expect(events.emit).not.toHaveBeenCalledWith(ModelEvent.OnModelReady, model)
+ })
+
+ it('should unload model if provider matches', async () => {
+ const model: Model = { id: 'model1', version: '1.0', engine: 'test-provider' } as any
+
+ await engine.unloadModel(model)
+
+ expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelStopped, model)
+ })
+
+ it('should not unload model if provider does not match', async () => {
+ const model: Model = { id: 'model1', version: '1.0', engine: 'other-provider' } as any
+
+ await engine.unloadModel(model)
+
+ expect(events.emit).not.toHaveBeenCalledWith(ModelEvent.OnModelStopped, model)
+ })
+})
diff --git a/core/src/browser/extensions/engines/EngineManager.test.ts b/core/src/browser/extensions/engines/EngineManager.test.ts
new file mode 100644
index 000000000..c1f1fcb71
--- /dev/null
+++ b/core/src/browser/extensions/engines/EngineManager.test.ts
@@ -0,0 +1,43 @@
+/**
+ * @jest-environment jsdom
+ */
+import { EngineManager } from './EngineManager'
+import { AIEngine } from './AIEngine'
+
+// @ts-ignore
+class MockAIEngine implements AIEngine {
+ provider: string
+ constructor(provider: string) {
+ this.provider = provider
+ }
+}
+
+describe('EngineManager', () => {
+ let engineManager: EngineManager
+
+ beforeEach(() => {
+ engineManager = new EngineManager()
+ })
+
+ test('should register an engine', () => {
+ const engine = new MockAIEngine('testProvider')
+ // @ts-ignore
+ engineManager.register(engine)
+ expect(engineManager.engines.get('testProvider')).toBe(engine)
+ })
+
+ test('should retrieve a registered engine by provider', () => {
+ const engine = new MockAIEngine('testProvider')
+ // @ts-ignore
+ engineManager.register(engine)
+ // @ts-ignore
+ const retrievedEngine = engineManager.get('testProvider')
+ expect(retrievedEngine).toBe(engine)
+ })
+
+ test('should return undefined for an unregistered provider', () => {
+ // @ts-ignore
+ const retrievedEngine = engineManager.get('nonExistentProvider')
+ expect(retrievedEngine).toBeUndefined()
+ })
+})
diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.test.ts b/core/src/browser/extensions/engines/LocalOAIEngine.test.ts
new file mode 100644
index 000000000..4ae81496f
--- /dev/null
+++ b/core/src/browser/extensions/engines/LocalOAIEngine.test.ts
@@ -0,0 +1,100 @@
+/**
+ * @jest-environment jsdom
+ */
+import { LocalOAIEngine } from './LocalOAIEngine'
+import { events } from '../../events'
+import { ModelEvent, ModelFile, Model } from '../../../types'
+import { executeOnMain, systemInformation, dirName } from '../../core'
+
+jest.mock('../../core', () => ({
+ executeOnMain: jest.fn(),
+ systemInformation: jest.fn(),
+ dirName: jest.fn(),
+}))
+
+jest.mock('../../events', () => ({
+ events: {
+ on: jest.fn(),
+ emit: jest.fn(),
+ },
+}))
+
+class TestLocalOAIEngine extends LocalOAIEngine {
+ inferenceUrl = ''
+ nodeModule = 'testNodeModule'
+ provider = 'testProvider'
+}
+
+describe('LocalOAIEngine', () => {
+ let engine: TestLocalOAIEngine
+
+ beforeEach(() => {
+ engine = new TestLocalOAIEngine('', '')
+ })
+
+ afterEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('should subscribe to events on load', () => {
+ engine.onLoad()
+ expect(events.on).toHaveBeenCalledWith(ModelEvent.OnModelInit, expect.any(Function))
+ expect(events.on).toHaveBeenCalledWith(ModelEvent.OnModelStop, expect.any(Function))
+ })
+
+ it('should load model correctly', async () => {
+ const model: ModelFile = { engine: 'testProvider', file_path: 'path/to/model' } as any
+ const modelFolder = 'path/to'
+ const systemInfo = { os: 'testOS' }
+ const res = { error: null }
+
+ ;(dirName as jest.Mock).mockResolvedValue(modelFolder)
+ ;(systemInformation as jest.Mock).mockResolvedValue(systemInfo)
+ ;(executeOnMain as jest.Mock).mockResolvedValue(res)
+
+ await engine.loadModel(model)
+
+ expect(dirName).toHaveBeenCalledWith(model.file_path)
+ expect(systemInformation).toHaveBeenCalled()
+ expect(executeOnMain).toHaveBeenCalledWith(
+ engine.nodeModule,
+ engine.loadModelFunctionName,
+ { modelFolder, model },
+ systemInfo
+ )
+ expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelReady, model)
+ })
+
+ it('should handle load model error', async () => {
+ const model: ModelFile = { engine: 'testProvider', file_path: 'path/to/model' } as any
+ const modelFolder = 'path/to'
+ const systemInfo = { os: 'testOS' }
+ const res = { error: 'load error' }
+
+ ;(dirName as jest.Mock).mockResolvedValue(modelFolder)
+ ;(systemInformation as jest.Mock).mockResolvedValue(systemInfo)
+ ;(executeOnMain as jest.Mock).mockResolvedValue(res)
+
+ await expect(engine.loadModel(model)).rejects.toEqual('load error')
+
+ expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelFail, { error: res.error })
+ })
+
+ it('should unload model correctly', async () => {
+ const model: Model = { engine: 'testProvider' } as any
+
+ await engine.unloadModel(model)
+
+ expect(executeOnMain).toHaveBeenCalledWith(engine.nodeModule, engine.unloadModelFunctionName)
+ expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelStopped, {})
+ })
+
+ it('should not unload model if engine does not match', async () => {
+ const model: Model = { engine: 'otherProvider' } as any
+
+ await engine.unloadModel(model)
+
+ expect(executeOnMain).not.toHaveBeenCalled()
+ expect(events.emit).not.toHaveBeenCalledWith(ModelEvent.OnModelStopped, {})
+ })
+})
diff --git a/core/src/browser/extensions/engines/OAIEngine.test.ts b/core/src/browser/extensions/engines/OAIEngine.test.ts
new file mode 100644
index 000000000..81348786c
--- /dev/null
+++ b/core/src/browser/extensions/engines/OAIEngine.test.ts
@@ -0,0 +1,119 @@
+/**
+ * @jest-environment jsdom
+ */
+import { OAIEngine } from './OAIEngine'
+import { events } from '../../events'
+import {
+ MessageEvent,
+ InferenceEvent,
+ MessageRequest,
+ MessageRequestType,
+ MessageStatus,
+ ChatCompletionRole,
+ ContentType,
+} from '../../../types'
+import { requestInference } from './helpers/sse'
+import { ulid } from 'ulidx'
+
+jest.mock('./helpers/sse')
+jest.mock('ulidx')
+jest.mock('../../events')
+
+class TestOAIEngine extends OAIEngine {
+ inferenceUrl = 'http://test-inference-url'
+ provider = 'test-provider'
+
+ async headers() {
+ return { Authorization: 'Bearer test-token' }
+ }
+}
+
+describe('OAIEngine', () => {
+ let engine: TestOAIEngine
+
+ beforeEach(() => {
+ engine = new TestOAIEngine('', '')
+ jest.clearAllMocks()
+ })
+
+ it('should subscribe to events on load', () => {
+ engine.onLoad()
+ expect(events.on).toHaveBeenCalledWith(MessageEvent.OnMessageSent, expect.any(Function))
+ expect(events.on).toHaveBeenCalledWith(InferenceEvent.OnInferenceStopped, expect.any(Function))
+ })
+
+ it('should handle inference request', async () => {
+ const data: MessageRequest = {
+ model: { engine: 'test-provider', id: 'test-model' } as any,
+ threadId: 'test-thread',
+ type: MessageRequestType.Thread,
+ assistantId: 'test-assistant',
+ messages: [{ role: ChatCompletionRole.User, content: 'Hello' }],
+ }
+
+ ;(ulid as jest.Mock).mockReturnValue('test-id')
+ ;(requestInference as jest.Mock).mockReturnValue({
+ subscribe: ({ next, complete }: any) => {
+ next('test response')
+ complete()
+ },
+ })
+
+ await engine.inference(data)
+
+ expect(requestInference).toHaveBeenCalledWith(
+ 'http://test-inference-url',
+ expect.objectContaining({ model: 'test-model' }),
+ expect.any(Object),
+ expect.any(AbortController),
+ { Authorization: 'Bearer test-token' },
+ undefined
+ )
+
+ expect(events.emit).toHaveBeenCalledWith(
+ MessageEvent.OnMessageResponse,
+ expect.objectContaining({ id: 'test-id' })
+ )
+ expect(events.emit).toHaveBeenCalledWith(
+ MessageEvent.OnMessageUpdate,
+ expect.objectContaining({
+ content: [{ type: ContentType.Text, text: { value: 'test response', annotations: [] } }],
+ status: MessageStatus.Ready,
+ })
+ )
+ })
+
+ it('should handle inference error', async () => {
+ const data: MessageRequest = {
+ model: { engine: 'test-provider', id: 'test-model' } as any,
+ threadId: 'test-thread',
+ type: MessageRequestType.Thread,
+ assistantId: 'test-assistant',
+ messages: [{ role: ChatCompletionRole.User, content: 'Hello' }],
+ }
+
+ ;(ulid as jest.Mock).mockReturnValue('test-id')
+ ;(requestInference as jest.Mock).mockReturnValue({
+ subscribe: ({ error }: any) => {
+ error({ message: 'test error', code: 500 })
+ },
+ })
+
+ await engine.inference(data)
+
+ expect(events.emit).toHaveBeenCalledWith(
+ MessageEvent.OnMessageUpdate,
+ expect.objectContaining({
+ content: [{ type: ContentType.Text, text: { value: 'test error', annotations: [] } }],
+ status: MessageStatus.Error,
+ error_code: 500,
+ })
+ )
+ })
+
+ it('should stop inference', () => {
+ engine.stopInference()
+ expect(engine.isCancelled).toBe(true)
+ expect(engine.controller.signal.aborted).toBe(true)
+ })
+})
diff --git a/core/src/browser/extensions/engines/RemoteOAIEngine.test.ts b/core/src/browser/extensions/engines/RemoteOAIEngine.test.ts
new file mode 100644
index 000000000..871499f45
--- /dev/null
+++ b/core/src/browser/extensions/engines/RemoteOAIEngine.test.ts
@@ -0,0 +1,43 @@
+/**
+ * @jest-environment jsdom
+ */
+import { RemoteOAIEngine } from './'
+
+class TestRemoteOAIEngine extends RemoteOAIEngine {
+ inferenceUrl: string = ''
+ provider: string = 'TestRemoteOAIEngine'
+}
+
+describe('RemoteOAIEngine', () => {
+ let engine: TestRemoteOAIEngine
+
+ beforeEach(() => {
+ engine = new TestRemoteOAIEngine('', '')
+ })
+
+ test('should call onLoad and super.onLoad', () => {
+ const onLoadSpy = jest.spyOn(engine, 'onLoad')
+ const superOnLoadSpy = jest.spyOn(Object.getPrototypeOf(RemoteOAIEngine.prototype), 'onLoad')
+ engine.onLoad()
+
+ expect(onLoadSpy).toHaveBeenCalled()
+ expect(superOnLoadSpy).toHaveBeenCalled()
+ })
+
+ test('should return headers with apiKey', async () => {
+ engine.apiKey = 'test-api-key'
+ const headers = await engine.headers()
+
+ expect(headers).toEqual({
+ 'Authorization': 'Bearer test-api-key',
+ 'api-key': 'test-api-key',
+ })
+ })
+
+ test('should return empty headers when apiKey is not set', async () => {
+ engine.apiKey = undefined
+ const headers = await engine.headers()
+
+ expect(headers).toEqual({})
+ })
+})
diff --git a/core/src/browser/extensions/engines/helpers/sse.test.ts b/core/src/browser/extensions/engines/helpers/sse.test.ts
index cff5b93b3..0b78aa9b5 100644
--- a/core/src/browser/extensions/engines/helpers/sse.test.ts
+++ b/core/src/browser/extensions/engines/helpers/sse.test.ts
@@ -1,6 +1,7 @@
import { lastValueFrom, Observable } from 'rxjs'
import { requestInference } from './sse'
+import { ReadableStream } from 'stream/web';
describe('requestInference', () => {
it('should send a request to the inference server and return an Observable', () => {
// Mock the fetch function
@@ -58,3 +59,66 @@ describe('requestInference', () => {
expect(lastValueFrom(result)).rejects.toEqual({ message: 'Wrong API Key', code: 'invalid_api_key' })
})
})
+
+ it('should handle a successful response with a transformResponse function', () => {
+ // Mock the fetch function
+ const mockFetch: any = jest.fn(() =>
+ Promise.resolve({
+ ok: true,
+ json: () => Promise.resolve({ choices: [{ message: { content: 'Generated response' } }] }),
+ headers: new Headers(),
+ redirected: false,
+ status: 200,
+ statusText: 'OK',
+ })
+ )
+ jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
+
+ // Define the test inputs
+ const inferenceUrl = 'https://inference-server.com'
+ const requestBody = { message: 'Hello' }
+ const model = { id: 'model-id', parameters: { stream: false } }
+ const transformResponse = (data: any) => data.choices[0].message.content.toUpperCase()
+
+ // Call the function
+ const result = requestInference(inferenceUrl, requestBody, model, undefined, undefined, transformResponse)
+
+ // Assert the expected behavior
+ expect(result).toBeInstanceOf(Observable)
+ expect(lastValueFrom(result)).resolves.toEqual('GENERATED RESPONSE')
+ })
+
+
+ it('should handle a successful response with streaming enabled', () => {
+ // Mock the fetch function
+ const mockFetch: any = jest.fn(() =>
+ Promise.resolve({
+ ok: true,
+ body: new ReadableStream({
+ start(controller) {
+ controller.enqueue(new TextEncoder().encode('data: {"choices": [{"delta": {"content": "Streamed"}}]}'));
+ controller.enqueue(new TextEncoder().encode('data: [DONE]'));
+ controller.close();
+ }
+ }),
+ headers: new Headers(),
+ redirected: false,
+ status: 200,
+ statusText: 'OK',
+ })
+ );
+ jest.spyOn(global, 'fetch').mockImplementation(mockFetch);
+
+ // Define the test inputs
+ const inferenceUrl = 'https://inference-server.com';
+ const requestBody = { message: 'Hello' };
+ const model = { id: 'model-id', parameters: { stream: true } };
+
+ // Call the function
+ const result = requestInference(inferenceUrl, requestBody, model);
+
+ // Assert the expected behavior
+ expect(result).toBeInstanceOf(Observable);
+ expect(lastValueFrom(result)).resolves.toEqual('Streamed');
+ });
+
diff --git a/core/src/browser/extensions/engines/index.test.ts b/core/src/browser/extensions/engines/index.test.ts
new file mode 100644
index 000000000..4c0ef11d8
--- /dev/null
+++ b/core/src/browser/extensions/engines/index.test.ts
@@ -0,0 +1,6 @@
+
+import { expect } from '@jest/globals';
+
+it('should re-export all exports from ./AIEngine', () => {
+ expect(require('./index')).toHaveProperty('AIEngine');
+});
diff --git a/core/src/browser/extensions/index.test.ts b/core/src/browser/extensions/index.test.ts
new file mode 100644
index 000000000..26cbda8c5
--- /dev/null
+++ b/core/src/browser/extensions/index.test.ts
@@ -0,0 +1,32 @@
+import { ConversationalExtension } from './index';
+import { InferenceExtension } from './index';
+import { MonitoringExtension } from './index';
+import { AssistantExtension } from './index';
+import { ModelExtension } from './index';
+import * as Engines from './index';
+
+describe('index.ts exports', () => {
+ test('should export ConversationalExtension', () => {
+ expect(ConversationalExtension).toBeDefined();
+ });
+
+ test('should export InferenceExtension', () => {
+ expect(InferenceExtension).toBeDefined();
+ });
+
+ test('should export MonitoringExtension', () => {
+ expect(MonitoringExtension).toBeDefined();
+ });
+
+ test('should export AssistantExtension', () => {
+ expect(AssistantExtension).toBeDefined();
+ });
+
+ test('should export ModelExtension', () => {
+ expect(ModelExtension).toBeDefined();
+ });
+
+ test('should export Engines', () => {
+ expect(Engines).toBeDefined();
+ });
+});
\ No newline at end of file
diff --git a/core/src/browser/fs.test.ts b/core/src/browser/fs.test.ts
new file mode 100644
index 000000000..21da54874
--- /dev/null
+++ b/core/src/browser/fs.test.ts
@@ -0,0 +1,97 @@
+import { fs } from './fs'
+
+describe('fs module', () => {
+ beforeEach(() => {
+ globalThis.core = {
+ api: {
+ writeFileSync: jest.fn(),
+ writeBlob: jest.fn(),
+ readFileSync: jest.fn(),
+ existsSync: jest.fn(),
+ readdirSync: jest.fn(),
+ mkdir: jest.fn(),
+ rm: jest.fn(),
+ unlinkSync: jest.fn(),
+ appendFileSync: jest.fn(),
+ copyFile: jest.fn(),
+ getGgufFiles: jest.fn(),
+ fileStat: jest.fn(),
+ },
+ }
+ })
+
+ it('should call writeFileSync with correct arguments', () => {
+ const args = ['path/to/file', 'data']
+ fs.writeFileSync(...args)
+ expect(globalThis.core.api.writeFileSync).toHaveBeenCalledWith(...args)
+ })
+
+ it('should call writeBlob with correct arguments', async () => {
+ const path = 'path/to/file'
+ const data = 'blob data'
+ await fs.writeBlob(path, data)
+ expect(globalThis.core.api.writeBlob).toHaveBeenCalledWith(path, data)
+ })
+
+ it('should call readFileSync with correct arguments', () => {
+ const args = ['path/to/file']
+ fs.readFileSync(...args)
+ expect(globalThis.core.api.readFileSync).toHaveBeenCalledWith(...args)
+ })
+
+ it('should call existsSync with correct arguments', () => {
+ const args = ['path/to/file']
+ fs.existsSync(...args)
+ expect(globalThis.core.api.existsSync).toHaveBeenCalledWith(...args)
+ })
+
+ it('should call readdirSync with correct arguments', () => {
+ const args = ['path/to/directory']
+ fs.readdirSync(...args)
+ expect(globalThis.core.api.readdirSync).toHaveBeenCalledWith(...args)
+ })
+
+ it('should call mkdir with correct arguments', () => {
+ const args = ['path/to/directory']
+ fs.mkdir(...args)
+ expect(globalThis.core.api.mkdir).toHaveBeenCalledWith(...args)
+ })
+
+ it('should call rm with correct arguments', () => {
+ const args = ['path/to/directory']
+ fs.rm(...args)
+ expect(globalThis.core.api.rm).toHaveBeenCalledWith(...args, { recursive: true, force: true })
+ })
+
+ it('should call unlinkSync with correct arguments', () => {
+ const args = ['path/to/file']
+ fs.unlinkSync(...args)
+ expect(globalThis.core.api.unlinkSync).toHaveBeenCalledWith(...args)
+ })
+
+ it('should call appendFileSync with correct arguments', () => {
+ const args = ['path/to/file', 'data']
+ fs.appendFileSync(...args)
+ expect(globalThis.core.api.appendFileSync).toHaveBeenCalledWith(...args)
+ })
+
+ it('should call copyFile with correct arguments', async () => {
+ const src = 'path/to/src'
+ const dest = 'path/to/dest'
+ await fs.copyFile(src, dest)
+ expect(globalThis.core.api.copyFile).toHaveBeenCalledWith(src, dest)
+ })
+
+ it('should call getGgufFiles with correct arguments', async () => {
+ const paths = ['path/to/file1', 'path/to/file2']
+ await fs.getGgufFiles(paths)
+ expect(globalThis.core.api.getGgufFiles).toHaveBeenCalledWith(paths)
+ })
+
+ it('should call fileStat with correct arguments', async () => {
+ const path = 'path/to/file'
+ const outsideJanDataFolder = true
+ await fs.fileStat(path, outsideJanDataFolder)
+ expect(globalThis.core.api.fileStat).toHaveBeenCalledWith(path, outsideJanDataFolder)
+ })
+})
diff --git a/core/src/browser/tools/tool.test.ts b/core/src/browser/tools/tool.test.ts
new file mode 100644
index 000000000..ba918a3cb
--- /dev/null
+++ b/core/src/browser/tools/tool.test.ts
@@ -0,0 +1,55 @@
+import { ToolManager } from '../../browser/tools/manager'
+import { InferenceTool } from '../../browser/tools/tool'
+import { AssistantTool, MessageRequest } from '../../types'
+
+class MockInferenceTool implements InferenceTool {
+ name = 'mockTool'
+ process(request: MessageRequest, tool: AssistantTool): Promise {
+ return Promise.resolve(request)
+ }
+}
+
+it('should register a tool', () => {
+ const manager = new ToolManager()
+ const tool = new MockInferenceTool()
+ manager.register(tool)
+ expect(manager.get(tool.name)).toBe(tool)
+})
+
+it('should retrieve a tool by its name', () => {
+ const manager = new ToolManager()
+ const tool = new MockInferenceTool()
+ manager.register(tool)
+ const retrievedTool = manager.get(tool.name)
+ expect(retrievedTool).toBe(tool)
+})
+
+it('should return undefined for a non-existent tool', () => {
+ const manager = new ToolManager()
+ const retrievedTool = manager.get('nonExistentTool')
+ expect(retrievedTool).toBeUndefined()
+})
+
+it('should process the message request with enabled tools', async () => {
+ const manager = new ToolManager()
+ const tool = new MockInferenceTool()
+ manager.register(tool)
+
+ const request: MessageRequest = { message: 'test' } as any
+ const tools: AssistantTool[] = [{ type: 'mockTool', enabled: true }] as any
+
+ const result = await manager.process(request, tools)
+ expect(result).toBe(request)
+})
+
+it('should skip processing for disabled tools', async () => {
+ const manager = new ToolManager()
+ const tool = new MockInferenceTool()
+ manager.register(tool)
+
+ const request: MessageRequest = { message: 'test' } as any
+ const tools: AssistantTool[] = [{ type: 'mockTool', enabled: false }] as any
+
+ const result = await manager.process(request, tools)
+ expect(result).toBe(request)
+})
\ No newline at end of file
diff --git a/core/src/node/api/processors/processor.test.ts b/core/src/node/api/processors/processor.test.ts
deleted file mode 100644
index e69de29bb..000000000
diff --git a/core/src/node/api/restful/helper/builder.test.ts b/core/src/node/api/restful/helper/builder.test.ts
new file mode 100644
index 000000000..fef40c70a
--- /dev/null
+++ b/core/src/node/api/restful/helper/builder.test.ts
@@ -0,0 +1,264 @@
+import {
+ existsSync,
+ readdirSync,
+ readFileSync,
+ writeFileSync,
+ mkdirSync,
+ appendFileSync,
+ rmdirSync,
+} from 'fs'
+import { join } from 'path'
+import {
+ getBuilder,
+ retrieveBuilder,
+ deleteBuilder,
+ getMessages,
+ retrieveMessage,
+ createThread,
+ updateThread,
+ createMessage,
+ downloadModel,
+ chatCompletions,
+} from './builder'
+import { RouteConfiguration } from './configuration'
+
+jest.mock('fs')
+jest.mock('path')
+jest.mock('../../../helper', () => ({
+ getEngineConfiguration: jest.fn(),
+ getJanDataFolderPath: jest.fn().mockReturnValue('/mock/path'),
+}))
+jest.mock('request')
+jest.mock('request-progress')
+jest.mock('node-fetch')
+
+describe('builder helper functions', () => {
+ const mockConfiguration: RouteConfiguration = {
+ dirName: 'mockDir',
+ metadataFileName: 'metadata.json',
+ delete: {
+ object: 'mockObject',
+ },
+ }
+
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ describe('getBuilder', () => {
+ it('should return an empty array if directory does not exist', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(false)
+ const result = await getBuilder(mockConfiguration)
+ expect(result).toEqual([])
+ })
+
+ it('should return model data if directory exists', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['file1'])
+ ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
+
+ const result = await getBuilder(mockConfiguration)
+ expect(result).toEqual([{ id: 'model1' }])
+ })
+ })
+
+ describe('retrieveBuilder', () => {
+ it('should return undefined if no data matches the id', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['file1'])
+ ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
+
+ const result = await retrieveBuilder(mockConfiguration, 'nonexistentId')
+ expect(result).toBeUndefined()
+ })
+
+ it('should return the matching data', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['file1'])
+ ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
+
+ const result = await retrieveBuilder(mockConfiguration, 'model1')
+ expect(result).toEqual({ id: 'model1' })
+ })
+ })
+
+ describe('deleteBuilder', () => {
+ it('should return a message if trying to delete Jan assistant', async () => {
+ const result = await deleteBuilder({ ...mockConfiguration, dirName: 'assistants' }, 'jan')
+ expect(result).toEqual({ message: 'Cannot delete Jan assistant' })
+ })
+
+ it('should return a message if data is not found', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['file1'])
+ ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
+
+ const result = await deleteBuilder(mockConfiguration, 'nonexistentId')
+ expect(result).toEqual({ message: 'Not found' })
+ })
+
+ it('should delete the directory and return success message', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['file1'])
+ ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
+
+ const result = await deleteBuilder(mockConfiguration, 'model1')
+ expect(rmdirSync).toHaveBeenCalledWith(join('/mock/path', 'mockDir', 'model1'), {
+ recursive: true,
+ })
+ expect(result).toEqual({ id: 'model1', object: 'mockObject', deleted: true })
+ })
+ })
+
+ describe('getMessages', () => {
+ it('should return an empty array if message file does not exist', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(false)
+
+ const result = await getMessages('thread1')
+ expect(result).toEqual([])
+ })
+
+ it('should return messages if message file exists', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['messages.jsonl'])
+ ;(readFileSync as jest.Mock).mockReturnValue('{"id":"msg1"}\n{"id":"msg2"}\n')
+
+ const result = await getMessages('thread1')
+ expect(result).toEqual([{ id: 'msg1' }, { id: 'msg2' }])
+ })
+ })
+
+ describe('retrieveMessage', () => {
+ it('should return a message if no messages match the id', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['messages.jsonl'])
+ ;(readFileSync as jest.Mock).mockReturnValue('{"id":"msg1"}\n')
+
+ const result = await retrieveMessage('thread1', 'nonexistentId')
+ expect(result).toEqual({ message: 'Not found' })
+ })
+
+ it('should return the matching message', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['messages.jsonl'])
+ ;(readFileSync as jest.Mock).mockReturnValue('{"id":"msg1"}\n')
+
+ const result = await retrieveMessage('thread1', 'msg1')
+ expect(result).toEqual({ id: 'msg1' })
+ })
+ })
+
+ describe('createThread', () => {
+ it('should return a message if thread has no assistants', async () => {
+ const result = await createThread({})
+ expect(result).toEqual({ message: 'Thread must have at least one assistant' })
+ })
+
+ it('should create a thread and return the updated thread', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(false)
+
+ const thread = { assistants: [{ assistant_id: 'assistant1' }] }
+ const result = await createThread(thread)
+ expect(mkdirSync).toHaveBeenCalled()
+ expect(writeFileSync).toHaveBeenCalled()
+ expect(result.id).toBeDefined()
+ })
+ })
+
+ describe('updateThread', () => {
+ it('should return a message if thread is not found', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['file1'])
+ ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
+
+ const result = await updateThread('nonexistentId', {})
+ expect(result).toEqual({ message: 'Thread not found' })
+ })
+
+ it('should update the thread and return the updated thread', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['file1'])
+ ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
+
+ const result = await updateThread('model1', { name: 'updatedName' })
+ expect(writeFileSync).toHaveBeenCalled()
+ expect(result.name).toEqual('updatedName')
+ })
+ })
+
+ describe('createMessage', () => {
+ it('should create a message and return the created message', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(false)
+ const message = { role: 'user', content: 'Hello' }
+
+ const result = (await createMessage('thread1', message)) as any
+ expect(mkdirSync).toHaveBeenCalled()
+ expect(appendFileSync).toHaveBeenCalled()
+ expect(result.id).toBeDefined()
+ })
+ })
+
+ describe('downloadModel', () => {
+ it('should return a message if model is not found', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['file1'])
+ ;(readFileSync as jest.Mock).mockReturnValue(JSON.stringify({ id: 'model1' }))
+
+ const result = await downloadModel('nonexistentId')
+ expect(result).toEqual({ message: 'Model not found' })
+ })
+
+ it('should start downloading the model', async () => {
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['file1'])
+ ;(readFileSync as jest.Mock).mockReturnValue(
+ JSON.stringify({ id: 'model1', object: 'model', sources: ['http://example.com'] })
+ )
+ const result = await downloadModel('model1')
+ expect(result).toEqual({ message: 'Starting download model1' })
+ })
+ })
+
+ describe('chatCompletions', () => {
+ it('should return an error if model is not found', async () => {
+ const request = { body: { model: 'nonexistentModel' } }
+ const reply = { code: jest.fn().mockReturnThis(), send: jest.fn() }
+
+ await chatCompletions(request, reply)
+ expect(reply.code).toHaveBeenCalledWith(404)
+ expect(reply.send).toHaveBeenCalledWith({
+ error: {
+ message: 'The model nonexistentModel does not exist',
+ type: 'invalid_request_error',
+ param: null,
+ code: 'model_not_found',
+ },
+ })
+ })
+
+ it('should return the chat completions', async () => {
+ const request = { body: { model: 'model1' } }
+ const reply = {
+ code: jest.fn().mockReturnThis(),
+ send: jest.fn(),
+ raw: { writeHead: jest.fn(), pipe: jest.fn() },
+ }
+
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['file1'])
+ ;(readFileSync as jest.Mock).mockReturnValue(
+ JSON.stringify({ id: 'model1', engine: 'openai' })
+ )
+
+ // Mock fetch
+ const fetch = require('node-fetch')
+ fetch.mockResolvedValue({
+ status: 200,
+ body: { pipe: jest.fn() },
+ json: jest.fn().mockResolvedValue({ completions: ['completion1'] }),
+ })
+ await chatCompletions(request, reply)
+ expect(reply.raw.writeHead).toHaveBeenCalledWith(200, expect.any(Object))
+ })
+ })
+})
diff --git a/core/src/node/api/restful/helper/builder.ts b/core/src/node/api/restful/helper/builder.ts
index 08da0ff33..1a8120918 100644
--- a/core/src/node/api/restful/helper/builder.ts
+++ b/core/src/node/api/restful/helper/builder.ts
@@ -280,13 +280,13 @@ export const downloadModel = async (
for (const source of model.sources) {
const rq = request({ url: source, strictSSL, proxy })
progress(rq, {})
- .on('progress', function (state: any) {
+ ?.on('progress', function (state: any) {
console.debug('progress', JSON.stringify(state, null, 2))
})
- .on('error', function (err: Error) {
+ ?.on('error', function (err: Error) {
console.error('error', err)
})
- .on('end', function () {
+ ?.on('end', function () {
console.debug('end')
})
.pipe(createWriteStream(modelBinaryPath))
diff --git a/core/src/node/api/restful/helper/startStopModel.test.ts b/core/src/node/api/restful/helper/startStopModel.test.ts
new file mode 100644
index 000000000..a5475cc28
--- /dev/null
+++ b/core/src/node/api/restful/helper/startStopModel.test.ts
@@ -0,0 +1,16 @@
+
+
+ import { startModel } from './startStopModel'
+
+ describe('startModel', () => {
+ it('test_startModel_error', async () => {
+ const modelId = 'testModelId'
+ const settingParams = undefined
+
+ const result = await startModel(modelId, settingParams)
+
+ expect(result).toEqual({
+ error: expect.any(Error),
+ })
+ })
+ })
diff --git a/core/src/node/extension/index.test.ts b/core/src/node/extension/index.test.ts
new file mode 100644
index 000000000..ce9cb0d0a
--- /dev/null
+++ b/core/src/node/extension/index.test.ts
@@ -0,0 +1,7 @@
+
+
+ import { useExtensions } from './index'
+
+ test('testUseExtensionsMissingPath', () => {
+ expect(() => useExtensions(undefined as any)).toThrowError('A path to the extensions folder is required to use extensions')
+ })
diff --git a/core/src/types/api/index.test.ts b/core/src/types/api/index.test.ts
new file mode 100644
index 000000000..6f2f2dcdb
--- /dev/null
+++ b/core/src/types/api/index.test.ts
@@ -0,0 +1,24 @@
+
+
+import { NativeRoute } from '../index';
+
+test('testNativeRouteEnum', () => {
+ expect(NativeRoute.openExternalUrl).toBe('openExternalUrl');
+ expect(NativeRoute.openAppDirectory).toBe('openAppDirectory');
+ expect(NativeRoute.openFileExplore).toBe('openFileExplorer');
+ expect(NativeRoute.selectDirectory).toBe('selectDirectory');
+ expect(NativeRoute.selectFiles).toBe('selectFiles');
+ expect(NativeRoute.relaunch).toBe('relaunch');
+ expect(NativeRoute.setNativeThemeLight).toBe('setNativeThemeLight');
+ expect(NativeRoute.setNativeThemeDark).toBe('setNativeThemeDark');
+ expect(NativeRoute.setMinimizeApp).toBe('setMinimizeApp');
+ expect(NativeRoute.setCloseApp).toBe('setCloseApp');
+ expect(NativeRoute.setMaximizeApp).toBe('setMaximizeApp');
+ expect(NativeRoute.showOpenMenu).toBe('showOpenMenu');
+ expect(NativeRoute.hideQuickAskWindow).toBe('hideQuickAskWindow');
+ expect(NativeRoute.sendQuickAskInput).toBe('sendQuickAskInput');
+ expect(NativeRoute.hideMainWindow).toBe('hideMainWindow');
+ expect(NativeRoute.showMainWindow).toBe('showMainWindow');
+ expect(NativeRoute.quickAskSizeUpdated).toBe('quickAskSizeUpdated');
+ expect(NativeRoute.ackDeepLink).toBe('ackDeepLink');
+});
diff --git a/core/src/types/config/appConfigEvent.test.ts b/core/src/types/config/appConfigEvent.test.ts
new file mode 100644
index 000000000..6000156c7
--- /dev/null
+++ b/core/src/types/config/appConfigEvent.test.ts
@@ -0,0 +1,9 @@
+
+
+ import { AppConfigurationEventName } from './appConfigEvent';
+
+ describe('AppConfigurationEventName', () => {
+ it('should have the correct value for OnConfigurationUpdate', () => {
+ expect(AppConfigurationEventName.OnConfigurationUpdate).toBe('OnConfigurationUpdate');
+ });
+ });
diff --git a/core/src/types/huggingface/huggingfaceEntity.test.ts b/core/src/types/huggingface/huggingfaceEntity.test.ts
new file mode 100644
index 000000000..d57b484be
--- /dev/null
+++ b/core/src/types/huggingface/huggingfaceEntity.test.ts
@@ -0,0 +1,28 @@
+
+
+ import { AllQuantizations } from './huggingfaceEntity';
+
+ test('testAllQuantizationsArray', () => {
+ expect(AllQuantizations).toEqual([
+ 'Q3_K_S',
+ 'Q3_K_M',
+ 'Q3_K_L',
+ 'Q4_K_S',
+ 'Q4_K_M',
+ 'Q5_K_S',
+ 'Q5_K_M',
+ 'Q4_0',
+ 'Q4_1',
+ 'Q5_0',
+ 'Q5_1',
+ 'IQ2_XXS',
+ 'IQ2_XS',
+ 'Q2_K',
+ 'Q2_K_S',
+ 'Q6_K',
+ 'Q8_0',
+ 'F16',
+ 'F32',
+ 'COPY',
+ ]);
+ });
diff --git a/core/src/types/huggingface/index.test.ts b/core/src/types/huggingface/index.test.ts
new file mode 100644
index 000000000..9cb80a08f
--- /dev/null
+++ b/core/src/types/huggingface/index.test.ts
@@ -0,0 +1,8 @@
+
+
+ import * as huggingfaceEntity from './huggingfaceEntity';
+ import * as index from './index';
+
+ test('test_exports_from_huggingfaceEntity', () => {
+ expect(index).toEqual(huggingfaceEntity);
+ });
diff --git a/core/src/types/index.test.ts b/core/src/types/index.test.ts
new file mode 100644
index 000000000..9dc001c4d
--- /dev/null
+++ b/core/src/types/index.test.ts
@@ -0,0 +1,28 @@
+
+import * as assistant from './assistant';
+import * as model from './model';
+import * as thread from './thread';
+import * as message from './message';
+import * as inference from './inference';
+import * as monitoring from './monitoring';
+import * as file from './file';
+import * as config from './config';
+import * as huggingface from './huggingface';
+import * as miscellaneous from './miscellaneous';
+import * as api from './api';
+import * as setting from './setting';
+
+ test('test_module_exports', () => {
+ expect(assistant).toBeDefined();
+ expect(model).toBeDefined();
+ expect(thread).toBeDefined();
+ expect(message).toBeDefined();
+ expect(inference).toBeDefined();
+ expect(monitoring).toBeDefined();
+ expect(file).toBeDefined();
+ expect(config).toBeDefined();
+ expect(huggingface).toBeDefined();
+ expect(miscellaneous).toBeDefined();
+ expect(api).toBeDefined();
+ expect(setting).toBeDefined();
+ });
diff --git a/core/src/types/inference/inferenceEntity.test.ts b/core/src/types/inference/inferenceEntity.test.ts
new file mode 100644
index 000000000..a2c06e32b
--- /dev/null
+++ b/core/src/types/inference/inferenceEntity.test.ts
@@ -0,0 +1,13 @@
+
+
+ import { ChatCompletionMessage, ChatCompletionRole } from './inferenceEntity';
+
+ test('test_chatCompletionMessage_withStringContent_andSystemRole', () => {
+ const message: ChatCompletionMessage = {
+ content: 'Hello, world!',
+ role: ChatCompletionRole.System,
+ };
+
+ expect(message.content).toBe('Hello, world!');
+ expect(message.role).toBe(ChatCompletionRole.System);
+ });
diff --git a/core/src/types/inference/inferenceEvent.test.ts b/core/src/types/inference/inferenceEvent.test.ts
new file mode 100644
index 000000000..1cb44fdbb
--- /dev/null
+++ b/core/src/types/inference/inferenceEvent.test.ts
@@ -0,0 +1,7 @@
+
+
+ import { InferenceEvent } from './inferenceEvent';
+
+ test('testInferenceEventEnumContainsOnInferenceStopped', () => {
+ expect(InferenceEvent.OnInferenceStopped).toBe('OnInferenceStopped');
+ });
diff --git a/core/src/types/message/messageEvent.test.ts b/core/src/types/message/messageEvent.test.ts
new file mode 100644
index 000000000..80a943bb1
--- /dev/null
+++ b/core/src/types/message/messageEvent.test.ts
@@ -0,0 +1,7 @@
+
+
+ import { MessageEvent } from './messageEvent';
+
+ test('testOnMessageSentValue', () => {
+ expect(MessageEvent.OnMessageSent).toBe('OnMessageSent');
+ });
diff --git a/core/src/types/message/messageRequestType.test.ts b/core/src/types/message/messageRequestType.test.ts
new file mode 100644
index 000000000..41f53b2e0
--- /dev/null
+++ b/core/src/types/message/messageRequestType.test.ts
@@ -0,0 +1,7 @@
+
+
+ import { MessageRequestType } from './messageRequestType';
+
+ test('testMessageRequestTypeEnumContainsThread', () => {
+ expect(MessageRequestType.Thread).toBe('Thread');
+ });
diff --git a/core/src/types/model/modelEntity.test.ts b/core/src/types/model/modelEntity.test.ts
new file mode 100644
index 000000000..306316ac4
--- /dev/null
+++ b/core/src/types/model/modelEntity.test.ts
@@ -0,0 +1,30 @@
+
+
+ import { Model, ModelSettingParams, ModelRuntimeParams, InferenceEngine } from '../model'
+
+ test('testValidModelCreation', () => {
+ const model: Model = {
+ object: 'model',
+ version: '1.0',
+ format: 'format1',
+ sources: [{ filename: 'model.bin', url: 'http://example.com/model.bin' }],
+ id: 'model1',
+ name: 'Test Model',
+ created: Date.now(),
+ description: 'A cool model from Huggingface',
+ settings: { ctx_len: 100, ngl: 50, embedding: true },
+ parameters: { temperature: 0.5, token_limit: 100, top_k: 10 },
+ metadata: { author: 'Author', tags: ['tag1', 'tag2'], size: 100 },
+ engine: InferenceEngine.anthropic
+ };
+
+ expect(model).toBeDefined();
+ expect(model.object).toBe('model');
+ expect(model.version).toBe('1.0');
+ expect(model.sources).toHaveLength(1);
+ expect(model.sources[0].filename).toBe('model.bin');
+ expect(model.settings).toBeDefined();
+ expect(model.parameters).toBeDefined();
+ expect(model.metadata).toBeDefined();
+ expect(model.engine).toBe(InferenceEngine.anthropic);
+ });
diff --git a/core/src/types/model/modelEvent.test.ts b/core/src/types/model/modelEvent.test.ts
new file mode 100644
index 000000000..f9fa8cc6a
--- /dev/null
+++ b/core/src/types/model/modelEvent.test.ts
@@ -0,0 +1,7 @@
+
+
+ import { ModelEvent } from './modelEvent';
+
+ test('testOnModelInit', () => {
+ expect(ModelEvent.OnModelInit).toBe('OnModelInit');
+ });
diff --git a/joi/jest.config.js b/joi/jest.config.js
index 8543f24e3..676042491 100644
--- a/joi/jest.config.js
+++ b/joi/jest.config.js
@@ -3,6 +3,7 @@ module.exports = {
testEnvironment: 'node',
roots: ['/src'],
testMatch: ['**/*.test.*'],
+ collectCoverageFrom: ['src/**/*.{ts,tsx}'],
setupFilesAfterEnv: ['/jest.setup.js'],
testEnvironment: 'jsdom',
}
diff --git a/package.json b/package.json
index 2785ee3b5..255dda6c7 100644
--- a/package.json
+++ b/package.json
@@ -20,7 +20,7 @@
"scripts": {
"lint": "yarn workspace jan lint && yarn workspace @janhq/web lint",
"test:unit": "jest",
- "test:coverage": "jest --coverage --collectCoverageFrom='src/**/*.{ts,tsx}'",
+ "test:coverage": "jest --coverage",
"test": "yarn workspace jan test:e2e",
"test-local": "yarn lint && yarn build:test && yarn test",
"pre-install:darwin": "find extensions -type f -path \"**/*.tgz\" -exec cp {} pre-install \\;",
diff --git a/testRunner.js b/testRunner.js
new file mode 100644
index 000000000..1067f05a3
--- /dev/null
+++ b/testRunner.js
@@ -0,0 +1,19 @@
+const jestRunner = require('jest-runner')
+
+class EmptyTestFileRunner extends jestRunner.default {
+ async runTests(tests, watcher, onStart, onResult, onFailure, options) {
+ const nonEmptyTests = tests.filter(
+ (test) => test.context.hasteFS.getSize(test.path) > 0
+ )
+ return super.runTests(
+ nonEmptyTests,
+ watcher,
+ onStart,
+ onResult,
+ onFailure,
+ options
+ )
+ }
+}
+
+module.exports = EmptyTestFileRunner
diff --git a/web/containers/Loader/Loader.test.tsx b/web/containers/Loader/Loader.test.tsx
new file mode 100644
index 000000000..007d0eeba
--- /dev/null
+++ b/web/containers/Loader/Loader.test.tsx
@@ -0,0 +1,23 @@
+// Loader.test.tsx
+import '@testing-library/jest-dom';
+import React from 'react'
+import { render, screen } from '@testing-library/react'
+import Loader from './index'
+
+describe('Loader Component', () => {
+ it('renders without crashing', () => {
+ render( )
+ })
+
+ it('displays the correct description', () => {
+ const descriptionText = 'Loading...'
+ render( )
+ expect(screen.getByText(descriptionText)).toBeInTheDocument()
+ })
+
+ it('renders the correct number of loader elements', () => {
+ const { container } = render( )
+ const loaderElements = container.querySelectorAll('label')
+ expect(loaderElements).toHaveLength(6)
+ })
+})
diff --git a/web/extension/Extension.test.ts b/web/extension/Extension.test.ts
new file mode 100644
index 000000000..d7b4a1805
--- /dev/null
+++ b/web/extension/Extension.test.ts
@@ -0,0 +1,19 @@
+import Extension from "./Extension";
+
+test('should create an Extension instance with all properties', () => {
+ const url = 'https://example.com';
+ const name = 'Test Extension';
+ const productName = 'Test Product';
+ const active = true;
+ const description = 'Test Description';
+ const version = '1.0.0';
+
+ const extension = new Extension(url, name, productName, active, description, version);
+
+ expect(extension.url).toBe(url);
+ expect(extension.name).toBe(name);
+ expect(extension.productName).toBe(productName);
+ expect(extension.active).toBe(active);
+ expect(extension.description).toBe(description);
+ expect(extension.version).toBe(version);
+});
diff --git a/web/extension/Extension.ts b/web/extension/Extension.ts
index 9438238ca..7dfb72b43 100644
--- a/web/extension/Extension.ts
+++ b/web/extension/Extension.ts
@@ -12,13 +12,13 @@ class Extension {
url: string
/** @type {boolean} Whether the extension is activated or not. */
- active
+ active?: boolean
/** @type {string} Extension's description. */
- description
+ description?: string
/** @type {string} Extension's version. */
- version
+ version?: string
constructor(
url: string,
diff --git a/web/extension/ExtensionManager.test.ts b/web/extension/ExtensionManager.test.ts
new file mode 100644
index 000000000..58f784b07
--- /dev/null
+++ b/web/extension/ExtensionManager.test.ts
@@ -0,0 +1,131 @@
+// ExtensionManager.test.ts
+import { AIEngine, BaseExtension, ExtensionTypeEnum } from '@janhq/core'
+import { ExtensionManager } from './ExtensionManager'
+import Extension from './Extension'
+
+class TestExtension extends BaseExtension {
+ onLoad(): void {}
+ onUnload(): void {}
+}
+class TestEngine extends AIEngine {
+ provider: string = 'testEngine'
+ onUnload(): void {}
+}
+
+describe('ExtensionManager', () => {
+ let manager: ExtensionManager
+
+ beforeEach(() => {
+ manager = new ExtensionManager()
+ })
+
+ it('should register an extension', () => {
+ const extension = new TestExtension('', '')
+ manager.register('testExtension', extension)
+ expect(manager.getByName('testExtension')).toBe(extension)
+ })
+
+ it('should register an AI engine', () => {
+ const extension = { provider: 'testEngine' } as unknown as BaseExtension
+ manager.register('testExtension', extension)
+ expect(manager.getEngine('testEngine')).toBe(extension)
+ })
+
+ it('should retrieve an extension by type', () => {
+ const extension = new TestExtension('', '')
+ jest.spyOn(extension, 'type').mockReturnValue(ExtensionTypeEnum.Assistant)
+ manager.register('testExtension', extension)
+ expect(manager.get(ExtensionTypeEnum.Assistant)).toBe(extension)
+ })
+
+ it('should retrieve an extension by name', () => {
+ const extension = new TestExtension('', '')
+ manager.register('testExtension', extension)
+ expect(manager.getByName('testExtension')).toBe(extension)
+ })
+
+ it('should retrieve all extensions', () => {
+ const extension1 = new TestExtension('', '')
+ const extension2 = new TestExtension('', '')
+ manager.register('testExtension1', extension1)
+ manager.register('testExtension2', extension2)
+ expect(manager.getAll()).toEqual([extension1, extension2])
+ })
+
+ it('should retrieve an engine by name', () => {
+ const engine = new TestEngine('', '')
+ manager.register('anything', engine)
+ expect(manager.getEngine('testEngine')).toBe(engine)
+ })
+
+ it('should load all extensions', () => {
+ const extension = new TestExtension('', '')
+ jest.spyOn(extension, 'onLoad')
+ manager.register('testExtension', extension)
+ manager.load()
+ expect(extension.onLoad).toHaveBeenCalled()
+ })
+
+ it('should unload all extensions', () => {
+ const extension = new TestExtension('', '')
+ jest.spyOn(extension, 'onUnload')
+ manager.register('testExtension', extension)
+ manager.unload()
+ expect(extension.onUnload).toHaveBeenCalled()
+ })
+
+ it('should list all extensions', () => {
+ const extension1 = new TestExtension('', '')
+ const extension2 = new TestExtension('', '')
+ manager.register('testExtension1', extension1)
+ manager.register('testExtension2', extension2)
+ expect(manager.listExtensions()).toEqual([extension1, extension2])
+ })
+
+ it('should retrieve active extensions', async () => {
+ const extension = new Extension(
+ 'url',
+ 'name',
+ 'productName',
+ true,
+ 'description',
+ 'version'
+ )
+ window.core = {
+ api: {
+ getActiveExtensions: jest.fn(),
+ },
+ }
+ jest
+ .spyOn(window.core.api, 'getActiveExtensions')
+ .mockResolvedValue([extension])
+ const activeExtensions = await manager.getActive()
+ expect(activeExtensions).toEqual([extension])
+ })
+
+ it('should register all active extensions', async () => {
+ const extension = new Extension(
+ 'url',
+ 'name',
+ 'productName',
+ true,
+ 'description',
+ 'version'
+ )
+ jest.spyOn(manager, 'getActive').mockResolvedValue([extension])
+ jest.spyOn(manager, 'activateExtension').mockResolvedValue()
+ await manager.registerActive()
+ expect(manager.activateExtension).toHaveBeenCalledWith(extension)
+ })
+
+ it('should uninstall extensions', async () => {
+ window.core = {
+ api: {
+ uninstallExtension: jest.fn(),
+ },
+ }
+ jest.spyOn(window.core.api, 'uninstallExtension').mockResolvedValue(true)
+ const result = await manager.uninstall(['testExtension'])
+ expect(result).toBe(true)
+ })
+})
diff --git a/web/extension/index.test.ts b/web/extension/index.test.ts
new file mode 100644
index 000000000..50b6b59db
--- /dev/null
+++ b/web/extension/index.test.ts
@@ -0,0 +1,9 @@
+
+
+import { extensionManager } from './index';
+
+describe('index', () => {
+ it('should export extensionManager from ExtensionManager', () => {
+ expect(extensionManager).toBeDefined();
+ });
+});
diff --git a/web/helpers/atoms/ApiServer.atom.test.ts b/web/helpers/atoms/ApiServer.atom.test.ts
new file mode 100644
index 000000000..4c5d7fca4
--- /dev/null
+++ b/web/helpers/atoms/ApiServer.atom.test.ts
@@ -0,0 +1,9 @@
+
+import { hostOptions } from './ApiServer.atom';
+
+test('hostOptions correct values', () => {
+ expect(hostOptions).toEqual([
+ { name: '127.0.0.1', value: '127.0.0.1' },
+ { name: '0.0.0.0', value: '0.0.0.0' },
+ ]);
+});
diff --git a/web/helpers/atoms/App.atom.test.ts b/web/helpers/atoms/App.atom.test.ts
new file mode 100644
index 000000000..f3d58dfc1
--- /dev/null
+++ b/web/helpers/atoms/App.atom.test.ts
@@ -0,0 +1,8 @@
+
+import { mainViewStateAtom } from './App.atom';
+import { MainViewState } from '@/constants/screens';
+
+test('mainViewStateAtom initializes with Thread', () => {
+ const result = mainViewStateAtom.init;
+ expect(result).toBe(MainViewState.Thread);
+});
diff --git a/web/helpers/atoms/AppConfig.atom.test.ts b/web/helpers/atoms/AppConfig.atom.test.ts
new file mode 100644
index 000000000..28f085e53
--- /dev/null
+++ b/web/helpers/atoms/AppConfig.atom.test.ts
@@ -0,0 +1,7 @@
+
+import { hostAtom } from './AppConfig.atom';
+
+test('hostAtom default value', () => {
+ const result = hostAtom.init;
+ expect(result).toBe('http://localhost:1337/');
+});
diff --git a/web/helpers/atoms/Assistant.atom.test.ts b/web/helpers/atoms/Assistant.atom.test.ts
new file mode 100644
index 000000000..a5073d293
--- /dev/null
+++ b/web/helpers/atoms/Assistant.atom.test.ts
@@ -0,0 +1,8 @@
+
+import { assistantsAtom } from './Assistant.atom';
+
+test('assistantsAtom initializes as an empty array', () => {
+ const initialValue = assistantsAtom.init;
+ expect(Array.isArray(initialValue)).toBe(true);
+ expect(initialValue).toHaveLength(0);
+});
diff --git a/web/helpers/atoms/ChatMessage.atom.test.ts b/web/helpers/atoms/ChatMessage.atom.test.ts
new file mode 100644
index 000000000..6acf4283e
--- /dev/null
+++ b/web/helpers/atoms/ChatMessage.atom.test.ts
@@ -0,0 +1,32 @@
+
+import { getCurrentChatMessagesAtom } from './ChatMessage.atom';
+import { setConvoMessagesAtom, chatMessages, readyThreadsMessagesAtom } from './ChatMessage.atom';
+
+test('getCurrentChatMessagesAtom returns empty array when no active thread ID', () => {
+ const getMock = jest.fn().mockReturnValue(undefined);
+ expect(getCurrentChatMessagesAtom.read(getMock)).toEqual([]);
+});
+
+
+test('getCurrentChatMessagesAtom returns empty array when activeThreadId is undefined', () => {
+ const getMock = jest.fn().mockReturnValue({
+ activeThreadId: undefined,
+ chatMessages: {
+ threadId: [{ id: 1, content: 'message' }],
+ },
+ });
+ expect(getCurrentChatMessagesAtom.read(getMock)).toEqual([]);
+});
+
+test('setConvoMessagesAtom updates chatMessages and readyThreadsMessagesAtom', () => {
+ const getMock = jest.fn().mockReturnValue({});
+ const setMock = jest.fn();
+ const threadId = 'thread1';
+ const messages = [{ id: '1', content: 'Hello' }];
+
+ setConvoMessagesAtom.write(getMock, setMock, threadId, messages);
+
+ expect(setMock).toHaveBeenCalledWith(chatMessages, { [threadId]: messages });
+ expect(setMock).toHaveBeenCalledWith(readyThreadsMessagesAtom, { [threadId]: true });
+});
+
diff --git a/web/helpers/atoms/HuggingFace.atom.test.ts b/web/helpers/atoms/HuggingFace.atom.test.ts
new file mode 100644
index 000000000..134d19947
--- /dev/null
+++ b/web/helpers/atoms/HuggingFace.atom.test.ts
@@ -0,0 +1,14 @@
+
+import { importHuggingFaceModelStageAtom } from './HuggingFace.atom';
+import { importingHuggingFaceRepoDataAtom } from './HuggingFace.atom';
+
+test('importHuggingFaceModelStageAtom should have initial value of NONE', () => {
+ const result = importHuggingFaceModelStageAtom.init;
+ expect(result).toBe('NONE');
+});
+
+
+test('importingHuggingFaceRepoDataAtom should have initial value of undefined', () => {
+ const result = importingHuggingFaceRepoDataAtom.init;
+ expect(result).toBeUndefined();
+});
diff --git a/web/helpers/atoms/LocalServer.atom.test.ts b/web/helpers/atoms/LocalServer.atom.test.ts
new file mode 100644
index 000000000..b3c53ec07
--- /dev/null
+++ b/web/helpers/atoms/LocalServer.atom.test.ts
@@ -0,0 +1,7 @@
+
+import { serverEnabledAtom } from './LocalServer.atom';
+
+test('serverEnabledAtom_initialValue', () => {
+ const result = serverEnabledAtom.init;
+ expect(result).toBe(false);
+});
diff --git a/web/helpers/atoms/Setting.atom.test.ts b/web/helpers/atoms/Setting.atom.test.ts
new file mode 100644
index 000000000..7c5d7ce94
--- /dev/null
+++ b/web/helpers/atoms/Setting.atom.test.ts
@@ -0,0 +1,7 @@
+
+import { selectedSettingAtom } from './Setting.atom';
+
+test('selectedSettingAtom has correct initial value', () => {
+ const result = selectedSettingAtom.init;
+ expect(result).toBe('My Models');
+});
diff --git a/web/helpers/atoms/ThreadRightPanel.atom.test.ts b/web/helpers/atoms/ThreadRightPanel.atom.test.ts
new file mode 100644
index 000000000..162b059fd
--- /dev/null
+++ b/web/helpers/atoms/ThreadRightPanel.atom.test.ts
@@ -0,0 +1,6 @@
+
+import { activeTabThreadRightPanelAtom } from './ThreadRightPanel.atom';
+
+test('activeTabThreadRightPanelAtom can be imported', () => {
+ expect(activeTabThreadRightPanelAtom).toBeDefined();
+});
diff --git a/web/hooks/useDownloadState.test.ts b/web/hooks/useDownloadState.test.ts
new file mode 100644
index 000000000..893649e26
--- /dev/null
+++ b/web/hooks/useDownloadState.test.ts
@@ -0,0 +1,109 @@
+import {
+ setDownloadStateAtom,
+ modelDownloadStateAtom,
+} from './useDownloadState'
+
+// Mock dependencies
+jest.mock('jotai', () => ({
+ atom: jest.fn(),
+ useAtom: jest.fn(),
+}))
+jest.mock('@/containers/Toast', () => ({
+ toaster: jest.fn(),
+}))
+jest.mock('@/helpers/atoms/Model.atom', () => ({
+ configuredModelsAtom: jest.fn(),
+ downloadedModelsAtom: jest.fn(),
+ removeDownloadingModelAtom: jest.fn(),
+}))
+
+describe('setDownloadStateAtom', () => {
+ let get: jest.Mock
+ let set: jest.Mock
+
+ beforeEach(() => {
+ get = jest.fn()
+ set = jest.fn()
+ })
+
+ it('should handle download completion', () => {
+ const state = {
+ downloadState: 'end',
+ modelId: 'model1',
+ fileName: 'file1',
+ children: [],
+ }
+ const currentState = {
+ model1: {
+ children: [state],
+ },
+ }
+ get.mockReturnValueOnce(currentState)
+ get.mockReturnValueOnce([{ id: 'model1' }])
+
+ set(setDownloadStateAtom, state)
+
+ expect(set).toHaveBeenCalledWith(
+ undefined,
+ expect.objectContaining({ modelId: expect.stringContaining('model1') })
+ )
+ })
+
+ it('should handle download error', () => {
+ const state = {
+ downloadState: 'error',
+ modelId: 'model1',
+ error: 'some error',
+ }
+ const currentState = {
+ model1: {},
+ }
+ get.mockReturnValueOnce(currentState)
+
+ set(setDownloadStateAtom, state)
+
+ expect(set).toHaveBeenCalledWith(
+ undefined,
+ expect.objectContaining({ modelId: 'model1' })
+ )
+ })
+
+ it('should handle download error with certificate issue', () => {
+ const state = {
+ downloadState: 'error',
+ modelId: 'model1',
+ error: 'certificate error',
+ }
+ const currentState = {
+ model1: {},
+ }
+ get.mockReturnValueOnce(currentState)
+
+ set(setDownloadStateAtom, state)
+
+ expect(set).toHaveBeenCalledWith(
+ undefined,
+ expect.objectContaining({ modelId: 'model1' })
+ )
+ })
+
+ it('should handle download in progress', () => {
+ const state = {
+ downloadState: 'progress',
+ modelId: 'model1',
+ fileName: 'file1',
+ size: { total: 100, transferred: 50 },
+ }
+ const currentState = {
+ model1: {
+ children: [],
+ size: { total: 0, transferred: 0 },
+ },
+ }
+ get.mockReturnValueOnce(currentState)
+
+ set(setDownloadStateAtom, state)
+
+ expect(set).toHaveBeenCalledWith(modelDownloadStateAtom, expect.any(Object))
+ })
+})
diff --git a/web/jest.config.js b/web/jest.config.js
index 7601f1e43..8b2683e78 100644
--- a/web/jest.config.js
+++ b/web/jest.config.js
@@ -5,7 +5,6 @@ const createJestConfig = nextJest({})
// Add any custom config to be passed to Jest
const config = {
- coverageProvider: 'v8',
testEnvironment: 'jsdom',
transform: {
'^.+\\.(ts|tsx)$': 'ts-jest',
@@ -17,6 +16,8 @@ const config = {
},
// Add more setup options before each test is run
// setupFilesAfterEnv: ['/jest.setup.ts'],
+ runner: './testRunner.js',
+ collectCoverageFrom: ['./**/*.{ts,tsx}'],
}
// https://stackoverflow.com/a/72926763/5078746
diff --git a/web/screens/Settings/Advanced/index.test.tsx b/web/screens/Settings/Advanced/index.test.tsx
index 10ea810b1..e34626f6e 100644
--- a/web/screens/Settings/Advanced/index.test.tsx
+++ b/web/screens/Settings/Advanced/index.test.tsx
@@ -1,3 +1,7 @@
+/**
+ * @jest-environment jsdom
+ */
+
import React from 'react'
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import '@testing-library/jest-dom'
@@ -10,7 +14,6 @@ class ResizeObserverMock {
}
global.ResizeObserver = ResizeObserverMock
-// @ts-ignore
global.window.core = {
api: {
getAppConfigurations: () => jest.fn(),
diff --git a/web/services/appService.test.ts b/web/services/appService.test.ts
new file mode 100644
index 000000000..37053f930
--- /dev/null
+++ b/web/services/appService.test.ts
@@ -0,0 +1,30 @@
+
+import { ExtensionTypeEnum, extensionManager } from '@/extension';
+import { appService } from './appService';
+
+test('should return correct system information when monitoring extension is found', async () => {
+ const mockGpuSetting = { name: 'NVIDIA GeForce GTX 1080', memory: 8192 };
+ const mockOsInfo = { platform: 'win32', release: '10.0.19041' };
+ const mockMonitoringExtension = {
+ getGpuSetting: jest.fn().mockResolvedValue(mockGpuSetting),
+ getOsInfo: jest.fn().mockResolvedValue(mockOsInfo),
+ };
+ extensionManager.get = jest.fn().mockReturnValue(mockMonitoringExtension);
+
+ const result = await appService.systemInformation();
+
+ expect(mockMonitoringExtension.getGpuSetting).toHaveBeenCalled();
+ expect(mockMonitoringExtension.getOsInfo).toHaveBeenCalled();
+ expect(result).toEqual({ gpuSetting: mockGpuSetting, osInfo: mockOsInfo });
+});
+
+
+test('should log a warning when monitoring extension is not found', async () => {
+ const consoleWarnMock = jest.spyOn(console, 'warn').mockImplementation(() => {});
+ extensionManager.get = jest.fn().mockReturnValue(undefined);
+
+ await appService.systemInformation();
+
+ expect(consoleWarnMock).toHaveBeenCalledWith('System monitoring extension not found');
+ consoleWarnMock.mockRestore();
+});
diff --git a/web/services/eventsService.test.ts b/web/services/eventsService.test.ts
new file mode 100644
index 000000000..78b95167a
--- /dev/null
+++ b/web/services/eventsService.test.ts
@@ -0,0 +1,47 @@
+
+import { EventEmitter } from './eventsService';
+
+test('should do nothing when no handlers for event', () => {
+ const emitter = new EventEmitter();
+
+ expect(() => {
+ emitter.emit('nonExistentEvent', 'test data');
+ }).not.toThrow();
+});
+
+
+test('should call all handlers for event', () => {
+ const emitter = new EventEmitter();
+ const handler1 = jest.fn();
+ const handler2 = jest.fn();
+
+ emitter.on('testEvent', handler1);
+ emitter.on('testEvent', handler2);
+
+ emitter.emit('testEvent', 'test data');
+
+ expect(handler1).toHaveBeenCalledWith('test data');
+ expect(handler2).toHaveBeenCalledWith('test data');
+});
+
+
+test('should remove handler for event', () => {
+ const emitter = new EventEmitter();
+ const handler = jest.fn();
+
+ emitter.on('testEvent', handler);
+ emitter.off('testEvent', handler);
+
+ expect(emitter['handlers'].get('testEvent')).not.toContain(handler);
+});
+
+
+test('should add handler for event', () => {
+ const emitter = new EventEmitter();
+ const handler = jest.fn();
+
+ emitter.on('testEvent', handler);
+
+ expect(emitter['handlers'].has('testEvent')).toBe(true);
+ expect(emitter['handlers'].get('testEvent')).toContain(handler);
+});
diff --git a/web/services/extensionService.test.ts b/web/services/extensionService.test.ts
new file mode 100644
index 000000000..75bd4f78a
--- /dev/null
+++ b/web/services/extensionService.test.ts
@@ -0,0 +1,35 @@
+
+import { extensionManager } from '@/extension/ExtensionManager';
+import { ExtensionTypeEnum } from '@janhq/core';
+import { isCoreExtensionInstalled } from './extensionService';
+
+test('isCoreExtensionInstalled returns true when both extensions are installed', () => {
+ jest.spyOn(extensionManager, 'get').mockImplementation((type) => {
+ if (type === ExtensionTypeEnum.Conversational || type === ExtensionTypeEnum.Model) return {};
+ return undefined;
+ });
+
+ expect(isCoreExtensionInstalled()).toBe(true);
+});
+
+
+test('isCoreExtensionInstalled returns false when Model extension is not installed', () => {
+ jest.spyOn(extensionManager, 'get').mockImplementation((type) => {
+ if (type === ExtensionTypeEnum.Conversational) return {};
+ if (type === ExtensionTypeEnum.Model) return undefined;
+ return undefined;
+ });
+
+ expect(isCoreExtensionInstalled()).toBe(false);
+});
+
+
+test('isCoreExtensionInstalled returns false when Conversational extension is not installed', () => {
+ jest.spyOn(extensionManager, 'get').mockImplementation((type) => {
+ if (type === ExtensionTypeEnum.Conversational) return undefined;
+ if (type === ExtensionTypeEnum.Model) return {};
+ return undefined;
+ });
+
+ expect(isCoreExtensionInstalled()).toBe(false);
+});
diff --git a/web/services/restService.test.ts b/web/services/restService.test.ts
new file mode 100644
index 000000000..7782e7816
--- /dev/null
+++ b/web/services/restService.test.ts
@@ -0,0 +1,15 @@
+
+
+test('restAPI.baseApiUrl set correctly', () => {
+ const originalEnv = process.env.API_BASE_URL;
+ process.env.API_BASE_URL = 'http://test-api.com';
+
+ // Re-import to get the updated value
+ jest.resetModules();
+ const { restAPI } = require('./restService');
+
+ expect(restAPI.baseApiUrl).toBe('http://test-api.com');
+
+ // Clean up
+ process.env.API_BASE_URL = originalEnv;
+});
diff --git a/web/utils/json.test.ts b/web/utils/json.test.ts
new file mode 100644
index 000000000..47a37d5fd
--- /dev/null
+++ b/web/utils/json.test.ts
@@ -0,0 +1,22 @@
+// json.test.ts
+import { safeJsonParse } from './json';
+
+describe('safeJsonParse', () => {
+ it('should correctly parse a valid JSON string', () => {
+ const jsonString = '{"name": "John", "age": 30}';
+ const result = safeJsonParse<{ name: string; age: number }>(jsonString);
+ expect(result).toEqual({ name: 'John', age: 30 });
+ });
+
+ it('should return undefined for an invalid JSON string', () => {
+ const jsonString = '{"name": "John", "age": 30';
+ const result = safeJsonParse<{ name: string; age: number }>(jsonString);
+ expect(result).toBeUndefined();
+ });
+
+ it('should return undefined for an empty string', () => {
+ const jsonString = '';
+ const result = safeJsonParse(jsonString);
+ expect(result).toBeUndefined();
+ });
+});
\ No newline at end of file
diff --git a/web/utils/modelParam.test.ts b/web/utils/modelParam.test.ts
index f1b858955..994a5bd57 100644
--- a/web/utils/modelParam.test.ts
+++ b/web/utils/modelParam.test.ts
@@ -149,6 +149,14 @@ describe('validationRules', () => {
})
})
+
+ it('should normalize invalid values for keys not listed in validationRules', () => {
+ expect(normalizeValue('invalid_key', 'invalid')).toBe('invalid')
+ expect(normalizeValue('invalid_key', 123)).toBe(123)
+ expect(normalizeValue('invalid_key', true)).toBe(true)
+ expect(normalizeValue('invalid_key', false)).toBe(false)
+ })
+
describe('normalizeValue', () => {
it('should normalize ctx_len correctly', () => {
expect(normalizeValue('ctx_len', 100.5)).toBe(100)
diff --git a/web/utils/threadMessageBuilder.test.ts b/web/utils/threadMessageBuilder.test.ts
new file mode 100644
index 000000000..d938a2e03
--- /dev/null
+++ b/web/utils/threadMessageBuilder.test.ts
@@ -0,0 +1,27 @@
+
+import { ChatCompletionRole, MessageStatus } from '@janhq/core'
+
+ import { ThreadMessageBuilder } from './threadMessageBuilder'
+ import { MessageRequestBuilder } from './messageRequestBuilder'
+
+ describe('ThreadMessageBuilder', () => {
+ it('testBuildMethod', () => {
+ const msgRequest = new MessageRequestBuilder(
+ 'type',
+ { model: 'model' },
+ { id: 'thread-id' },
+ []
+ )
+ const builder = new ThreadMessageBuilder(msgRequest)
+ const result = builder.build()
+
+ expect(result.id).toBe(msgRequest.msgId)
+ expect(result.thread_id).toBe(msgRequest.thread.id)
+ expect(result.role).toBe(ChatCompletionRole.User)
+ expect(result.status).toBe(MessageStatus.Ready)
+ expect(result.created).toBeDefined()
+ expect(result.updated).toBeDefined()
+ expect(result.object).toBe('thread.message')
+ expect(result.content).toEqual([])
+ })
+ })
From 9471481e33989e5538af87d6735d191aaafa7de9 Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Fri, 20 Sep 2024 18:15:06 +0700
Subject: [PATCH 16/37] fix: thread title for remote model from first prompt
(#3712)
* fix: generate title using first prompt for remote model
* fix: disable react-hooks/exhaustive-deps
---
web/containers/Providers/EventHandler.tsx | 24 ++++++++++++++++++++---
1 file changed, 21 insertions(+), 3 deletions(-)
diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx
index 2e4db4173..1fbcd3919 100644
--- a/web/containers/Providers/EventHandler.tsx
+++ b/web/containers/Providers/EventHandler.tsx
@@ -213,6 +213,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
// Attempt to generate the title of the Thread when needed
generateThreadTitle(message, thread)
},
+ // eslint-disable-next-line react-hooks/exhaustive-deps
[setIsGeneratingResponse, updateMessage, updateThread, updateThreadWaiting]
)
@@ -236,12 +237,29 @@ export default function EventHandler({ children }: { children: ReactNode }) {
return
}
- // Check model engine; we don't want to generate a title when it's not a local engine.
+ if (!activeModelRef.current) {
+ return
+ }
+
+ // Check model engine; we don't want to generate a title when it's not a local engine. remote model using first promp
if (
- !activeModelRef.current ||
!localEngines.includes(activeModelRef.current?.engine as InferenceEngine)
) {
- return
+ const updatedThread: Thread = {
+ ...thread,
+ title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
+ metadata: thread.metadata,
+ }
+ return extensionManager
+ .get(ExtensionTypeEnum.Conversational)
+ ?.saveThread({
+ ...updatedThread,
+ })
+ .then(() => {
+ updateThread({
+ ...updatedThread,
+ })
+ })
}
// This is the first time message comes in on a new thread
From 3091bb0e5e895bb684950dfc5345ea1905fa25ed Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Mon, 23 Sep 2024 12:46:38 +0700
Subject: [PATCH 17/37] fix: remove title local API server page (#3710)
---
web/screens/LocalServer/LocalServerLeftPanel/index.tsx | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/web/screens/LocalServer/LocalServerLeftPanel/index.tsx b/web/screens/LocalServer/LocalServerLeftPanel/index.tsx
index 16aa75af5..ef2c2d76c 100644
--- a/web/screens/LocalServer/LocalServerLeftPanel/index.tsx
+++ b/web/screens/LocalServer/LocalServerLeftPanel/index.tsx
@@ -107,8 +107,7 @@ const LocalServerLeftPanel = () => {
-
Server Options
-
+
Start an OpenAI-compatible local HTTP server.
From 5b7f0c13082538548782cc3b90dd08264c2ab4cf Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Mon, 23 Sep 2024 12:46:50 +0700
Subject: [PATCH 18/37] test: update test coverage UI component Joi (#3707)
* test: update test coverage joi
* test: update test export all components
* test: update clear mock useOs
* test: remove delete global window during test case getInitialValue
* test: update getValueInEffect with mock userAgent
---
joi/src/hooks/useClickOutside/index.tsx | 36 +++---
.../useClickOutside/useClickOutside.test.tsx | 103 +++++++++++-------
joi/src/hooks/useMediaQuery/index.ts | 2 +-
.../hooks/useMediaQuery/useMediaQuery.test.ts | 72 +++++++++++-
joi/src/hooks/useOs/index.tsx | 2 +-
joi/src/hooks/useOs/useOs.test.ts | 21 +++-
joi/src/index.test.ts | 43 ++++++++
7 files changed, 224 insertions(+), 55 deletions(-)
create mode 100644 joi/src/index.test.ts
diff --git a/joi/src/hooks/useClickOutside/index.tsx b/joi/src/hooks/useClickOutside/index.tsx
index 75e2400cf..af47ba484 100644
--- a/joi/src/hooks/useClickOutside/index.tsx
+++ b/joi/src/hooks/useClickOutside/index.tsx
@@ -1,4 +1,3 @@
-/* eslint-disable @typescript-eslint/no-explicit-any */
import { useEffect, useRef } from 'react'
const DEFAULT_EVENTS = ['mousedown', 'touchstart']
@@ -8,34 +7,43 @@ export function useClickOutside(
events?: string[] | null,
nodes?: (HTMLElement | null)[]
) {
- const ref = useRef()
+ const ref = useRef(null)
useEffect(() => {
- const listener = (event: any) => {
- const { target } = event ?? {}
+ const listener = (event: Event) => {
+ const target = event.target as HTMLElement
+
+ // Check if the target or any ancestor has the data-ignore-outside-clicks attribute
+ const shouldIgnore =
+ target.closest('[data-ignore-outside-clicks]') !== null
+
if (Array.isArray(nodes)) {
- const shouldIgnore =
- target?.hasAttribute('data-ignore-outside-clicks') ||
- (!document.body.contains(target) && target.tagName !== 'HTML')
const shouldTrigger = nodes.every(
(node) => !!node && !event.composedPath().includes(node)
)
- shouldTrigger && !shouldIgnore && handler()
- } else if (ref.current && !ref.current.contains(target)) {
+ if (shouldTrigger && !shouldIgnore) {
+ handler()
+ }
+ } else if (
+ ref.current &&
+ !ref.current.contains(target) &&
+ !shouldIgnore
+ ) {
handler()
}
}
- ;(events || DEFAULT_EVENTS).forEach((fn) =>
- document.addEventListener(fn, listener)
+ const eventList = events || DEFAULT_EVENTS
+ eventList.forEach((event) =>
+ document.documentElement.addEventListener(event, listener)
)
return () => {
- ;(events || DEFAULT_EVENTS).forEach((fn) =>
- document.removeEventListener(fn, listener)
+ eventList.forEach((event) =>
+ document.documentElement.removeEventListener(event, listener)
)
}
- }, [ref, handler, nodes])
+ }, [handler, nodes, events])
return ref
}
diff --git a/joi/src/hooks/useClickOutside/useClickOutside.test.tsx b/joi/src/hooks/useClickOutside/useClickOutside.test.tsx
index ac73b280a..8997721cd 100644
--- a/joi/src/hooks/useClickOutside/useClickOutside.test.tsx
+++ b/joi/src/hooks/useClickOutside/useClickOutside.test.tsx
@@ -1,55 +1,84 @@
import React from 'react'
-import { render, fireEvent, act } from '@testing-library/react'
+import { render, screen, fireEvent, cleanup } from '@testing-library/react'
import { useClickOutside } from './index'
-// Mock component to test the hook
-const TestComponent: React.FC<{ onClickOutside: () => void }> = ({
- onClickOutside,
+const TestComponent = ({
+ handler,
+ nodes,
+}: {
+ handler: () => void
+ nodes?: (HTMLElement | null)[]
}) => {
- const ref = useClickOutside(onClickOutside)
- return }>Test
+ const ref = useClickOutside(handler, undefined, nodes)
+
+ return (
+
+ Click me
+
+ )
}
-describe('@joi/hooks/useClickOutside', () => {
- it('should call handler when clicking outside', () => {
- const handleClickOutside = jest.fn()
- const { container } = render(
-
- )
+describe('useClickOutside', () => {
+ afterEach(cleanup)
- act(() => {
- fireEvent.mouseDown(document.body)
- })
+ it('should call handler when clicking outside the element', () => {
+ const handler = jest.fn()
+ render( )
- expect(handleClickOutside).toHaveBeenCalledTimes(1)
+ fireEvent.mouseDown(document.body)
+ expect(handler).toHaveBeenCalledTimes(1)
})
- it('should not call handler when clicking inside', () => {
- const handleClickOutside = jest.fn()
- const { getByText } = render(
-
- )
+ it('should not call handler when clicking inside the element', () => {
+ const handler = jest.fn()
+ render( )
- act(() => {
- fireEvent.mouseDown(getByText('Test'))
- })
-
- expect(handleClickOutside).not.toHaveBeenCalled()
+ fireEvent.mouseDown(screen.getByTestId('clickable'))
+ expect(handler).not.toHaveBeenCalled()
})
- it('should work with custom events', () => {
- const handleClickOutside = jest.fn()
- const TestComponentWithCustomEvent: React.FC = () => {
- const ref = useClickOutside(handleClickOutside, ['click'])
- return }>Test
- }
+ it('should not call handler if target has data-ignore-outside-clicks attribute', () => {
+ const handler = jest.fn()
+ render(
+ <>
+
+ Ignore this
+ >
+ )
- render( )
+ // Ensure that the div with the attribute is correctly queried
+ fireEvent.mouseDown(screen.getByText('Ignore this'))
+ expect(handler).not.toHaveBeenCalled()
+ })
- act(() => {
- fireEvent.click(document.body)
- })
+ it('should call handler when clicking outside if nodes is an empty array', () => {
+ const handler = jest.fn()
+ render( )
- expect(handleClickOutside).toHaveBeenCalledTimes(1)
+ fireEvent.mouseDown(document.body)
+ expect(handler).toHaveBeenCalledTimes(1)
+ })
+
+ it('should not call handler if clicking inside nodes', () => {
+ const handler = jest.fn()
+ const node = document.createElement('div')
+ document.body.appendChild(node)
+
+ render(
+ <>
+
+ >
+ )
+
+ fireEvent.mouseDown(node)
+ expect(handler).not.toHaveBeenCalled()
+ })
+
+ it('should call handler if nodes is undefined', () => {
+ const handler = jest.fn()
+ render( )
+
+ fireEvent.mouseDown(document.body)
+ expect(handler).toHaveBeenCalledTimes(1)
})
})
diff --git a/joi/src/hooks/useMediaQuery/index.ts b/joi/src/hooks/useMediaQuery/index.ts
index 03010fc78..31b548db0 100644
--- a/joi/src/hooks/useMediaQuery/index.ts
+++ b/joi/src/hooks/useMediaQuery/index.ts
@@ -23,7 +23,7 @@ function attachMediaListener(
}
}
-function getInitialValue(query: string, initialValue?: boolean) {
+export function getInitialValue(query: string, initialValue?: boolean) {
if (typeof initialValue === 'boolean') {
return initialValue
}
diff --git a/joi/src/hooks/useMediaQuery/useMediaQuery.test.ts b/joi/src/hooks/useMediaQuery/useMediaQuery.test.ts
index 5813bd41d..1d0fa20be 100644
--- a/joi/src/hooks/useMediaQuery/useMediaQuery.test.ts
+++ b/joi/src/hooks/useMediaQuery/useMediaQuery.test.ts
@@ -1,5 +1,8 @@
import { renderHook, act } from '@testing-library/react'
-import { useMediaQuery } from './index'
+import { useMediaQuery, getInitialValue } from './index'
+
+const global = globalThis
+const originalWindow = global.window
describe('@joi/hooks/useMediaQuery', () => {
const matchMediaMock = jest.fn()
@@ -10,6 +13,39 @@ describe('@joi/hooks/useMediaQuery', () => {
afterEach(() => {
matchMediaMock.mockClear()
+ global.window = originalWindow
+ })
+
+ it('should return undetermined when window is undefined', () => {
+ delete (global as any).window
+ expect(getInitialValue('(max-width: 600px)', true)).toBe(true)
+ expect(getInitialValue('(max-width: 600px)', false)).toBe(false)
+ })
+
+ it('should return default return false', () => {
+ delete (global as any).window
+ expect(getInitialValue('(max-width: 600px)')).toBe(false)
+ })
+
+ it('should return matchMedia result when window is defined and matchMedia exists', () => {
+ // Mock window.matchMedia
+ const matchMediaMock = jest.fn().mockImplementation((query) => ({
+ matches: query === '(max-width: 600px)',
+ media: query,
+ addEventListener: jest.fn(),
+ removeEventListener: jest.fn(),
+ }))
+
+ // Mock window and matchMedia
+ ;(global as any).window = { matchMedia: matchMediaMock }
+
+ // Test the function behavior
+ expect(getInitialValue('(max-width: 600px)')).toBe(true) // Query should match
+ expect(matchMediaMock).toHaveBeenCalledWith('(max-width: 600px)')
+
+ // Test with a non-matching query
+ expect(getInitialValue('(min-width: 1200px)')).toBe(false) // Query should not match
+ expect(matchMediaMock).toHaveBeenCalledWith('(min-width: 1200px)')
})
it('should return initial value when getInitialValueInEffect is true', () => {
@@ -87,4 +123,38 @@ describe('@joi/hooks/useMediaQuery', () => {
expect(result.current).toBe(true)
})
+
+ it('should return undefined when matchMedia is not available', () => {
+ delete (global as any).window.matchMedia
+
+ const { result } = renderHook(() => useMediaQuery('(max-width: 600px)'))
+ expect(result.current).toBe(undefined)
+ })
+
+ it('should use initialValue when getInitialValueInEffect is true', () => {
+ const { result } = renderHook(() =>
+ useMediaQuery('(max-width: 600px)', true, {
+ getInitialValueInEffect: true,
+ })
+ )
+ expect(result.current).toBe(true)
+ })
+
+ it('should use getInitialValue when getInitialValueInEffect is false', () => {
+ const { result } = renderHook(() =>
+ useMediaQuery('(max-width: 600px)', undefined, {
+ getInitialValueInEffect: false,
+ })
+ )
+ expect(result.current).toBe(false)
+ })
+
+ it('should use initialValue as false when getInitialValueInEffect is true', () => {
+ const { result } = renderHook(() =>
+ useMediaQuery('(max-width: 600px)', false, {
+ getInitialValueInEffect: true,
+ })
+ )
+ expect(result.current).toBe(false)
+ })
})
diff --git a/joi/src/hooks/useOs/index.tsx b/joi/src/hooks/useOs/index.tsx
index fb7fd9028..12e3d2410 100644
--- a/joi/src/hooks/useOs/index.tsx
+++ b/joi/src/hooks/useOs/index.tsx
@@ -8,7 +8,7 @@ export type OS =
| 'android'
| 'linux'
-function getOS(): OS {
+export function getOS(): OS {
if (typeof window === 'undefined') {
return 'undetermined'
}
diff --git a/joi/src/hooks/useOs/useOs.test.ts b/joi/src/hooks/useOs/useOs.test.ts
index 037640b5e..b66ad1519 100644
--- a/joi/src/hooks/useOs/useOs.test.ts
+++ b/joi/src/hooks/useOs/useOs.test.ts
@@ -1,5 +1,6 @@
import { renderHook } from '@testing-library/react'
-import { useOs } from './index'
+import { useOs, getOS } from './index'
+import '@testing-library/jest-dom'
const platforms = {
windows: [
@@ -21,10 +22,28 @@ const platforms = {
} as const
describe('@joi/hooks/useOS', () => {
+ const global = globalThis
+ const originalWindow = global.window
+
afterEach(() => {
+ global.window = originalWindow
jest.clearAllMocks()
})
+ it('should return undetermined when window is undefined', () => {
+ delete (global as any).window
+ expect(getOS()).toBe('undetermined')
+ })
+
+ it('should return undetermined when getValueInEffect is false', () => {
+ jest
+ .spyOn(window.navigator, 'userAgent', 'get')
+ .mockReturnValueOnce('UNKNOWN_USER_AGENT')
+
+ const { result } = renderHook(() => useOs({ getValueInEffect: false }))
+ expect(result.current).toBe('undetermined')
+ })
+
Object.entries(platforms).forEach(([os, userAgents]) => {
it.each(userAgents)(`should detect %s platform on ${os}`, (userAgent) => {
jest
diff --git a/joi/src/index.test.ts b/joi/src/index.test.ts
new file mode 100644
index 000000000..8bfba8d93
--- /dev/null
+++ b/joi/src/index.test.ts
@@ -0,0 +1,43 @@
+import * as components from './index'
+
+// Mock styles globally for all components in this test
+jest.mock('./core/Tooltip/styles.scss', () => ({}))
+jest.mock('./core/ScrollArea/styles.scss', () => ({}))
+jest.mock('./core/Button/styles.scss', () => ({}))
+jest.mock('./core/Switch/styles.scss', () => ({}))
+jest.mock('./core/Progress/styles.scss', () => ({}))
+jest.mock('./core/Checkbox/styles.scss', () => ({}))
+jest.mock('./core/Badge/styles.scss', () => ({}))
+jest.mock('./core/Modal/styles.scss', () => ({}))
+jest.mock('./core/Slider/styles.scss', () => ({}))
+jest.mock('./core/Input/styles.scss', () => ({}))
+jest.mock('./core/Select/styles.scss', () => ({}))
+jest.mock('./core/TextArea/styles.scss', () => ({}))
+jest.mock('./core/Tabs/styles.scss', () => ({}))
+jest.mock('./core/Accordion/styles.scss', () => ({}))
+
+describe('Exports', () => {
+ it('exports all components and hooks', () => {
+ expect(components.Tooltip).toBeDefined()
+ expect(components.ScrollArea).toBeDefined()
+ expect(components.Button).toBeDefined()
+ expect(components.Switch).toBeDefined()
+ expect(components.Progress).toBeDefined()
+ expect(components.Checkbox).toBeDefined()
+ expect(components.Badge).toBeDefined()
+ expect(components.Modal).toBeDefined()
+ expect(components.Slider).toBeDefined()
+ expect(components.Input).toBeDefined()
+ expect(components.Select).toBeDefined()
+ expect(components.TextArea).toBeDefined()
+ expect(components.Tabs).toBeDefined()
+ expect(components.Accordion).toBeDefined()
+
+ expect(components.useClipboard).toBeDefined()
+ expect(components.usePageLeave).toBeDefined()
+ expect(components.useTextSelection).toBeDefined()
+ expect(components.useClickOutside).toBeDefined()
+ expect(components.useOs).toBeDefined()
+ expect(components.useMediaQuery).toBeDefined()
+ })
+})
From c5e0c93ab45bdbbc9ddec10beb6bd9f53ed0803e Mon Sep 17 00:00:00 2001
From: Louis
Date: Mon, 23 Sep 2024 13:54:52 +0700
Subject: [PATCH 19/37] test: add missing tests (#3716)
---
core/jest.config.js | 8 ++
core/src/browser/extensions/assistant.test.ts | 8 ++
core/src/browser/extensions/inference.test.ts | 45 ++++++++++++
.../src/browser/extensions/monitoring.test.ts | 42 +++++++++++
core/src/browser/tools/index.test.ts | 5 ++
core/src/browser/tools/tool.test.ts | 10 ++-
core/src/index.test.ts | 7 ++
core/src/node/api/index.test.ts | 7 ++
.../src/node/api/processors/Processor.test.ts | 6 ++
.../src/node/api/processors/extension.test.ts | 31 ++++++++
.../node/api/restful/helper/consts.test.ts | 6 ++
core/src/node/helper/config.test.ts | 14 ++++
core/src/types/message/messageEntity.test.ts | 9 +++
.../miscellaneous/systemResourceInfo.test.ts | 6 ++
core/src/types/monitoring/index.test.ts | 16 ++++
core/src/types/setting/index.test.ts | 5 ++
.../types/setting/settingComponent.test.ts | 19 +++++
core/src/types/thread/threadEvent.test.ts | 6 ++
electron/jest.config.js | 18 +++++
electron/testRunner.js | 10 +++
web/jest.config.js | 8 ++
web/utils/base64.test.ts | 8 ++
web/utils/converter.test.ts | 33 +++++++++
web/utils/modelParam.test.ts | 19 +++++
web/utils/threadMessageBuilder.test.ts | 73 +++++++++++++++++++
25 files changed, 418 insertions(+), 1 deletion(-)
create mode 100644 core/src/browser/extensions/assistant.test.ts
create mode 100644 core/src/browser/extensions/inference.test.ts
create mode 100644 core/src/browser/extensions/monitoring.test.ts
create mode 100644 core/src/browser/tools/index.test.ts
create mode 100644 core/src/index.test.ts
create mode 100644 core/src/node/api/index.test.ts
create mode 100644 core/src/node/api/processors/Processor.test.ts
create mode 100644 core/src/node/api/restful/helper/consts.test.ts
create mode 100644 core/src/types/message/messageEntity.test.ts
create mode 100644 core/src/types/miscellaneous/systemResourceInfo.test.ts
create mode 100644 core/src/types/monitoring/index.test.ts
create mode 100644 core/src/types/setting/index.test.ts
create mode 100644 core/src/types/setting/settingComponent.test.ts
create mode 100644 core/src/types/thread/threadEvent.test.ts
create mode 100644 electron/jest.config.js
create mode 100644 electron/testRunner.js
create mode 100644 web/utils/base64.test.ts
create mode 100644 web/utils/converter.test.ts
diff --git a/core/jest.config.js b/core/jest.config.js
index 2f652dd39..9b1dd2ade 100644
--- a/core/jest.config.js
+++ b/core/jest.config.js
@@ -6,4 +6,12 @@ module.exports = {
'@/(.*)': '/src/$1',
},
runner: './testRunner.js',
+ transform: {
+ "^.+\\.tsx?$": [
+ "ts-jest",
+ {
+ diagnostics: false,
+ },
+ ],
+ },
}
diff --git a/core/src/browser/extensions/assistant.test.ts b/core/src/browser/extensions/assistant.test.ts
new file mode 100644
index 000000000..ae81b0985
--- /dev/null
+++ b/core/src/browser/extensions/assistant.test.ts
@@ -0,0 +1,8 @@
+
+import { AssistantExtension } from './assistant';
+import { ExtensionTypeEnum } from '../extension';
+
+it('should return the correct type', () => {
+ const extension = new AssistantExtension();
+ expect(extension.type()).toBe(ExtensionTypeEnum.Assistant);
+});
diff --git a/core/src/browser/extensions/inference.test.ts b/core/src/browser/extensions/inference.test.ts
new file mode 100644
index 000000000..45ec9d172
--- /dev/null
+++ b/core/src/browser/extensions/inference.test.ts
@@ -0,0 +1,45 @@
+import { MessageRequest, ThreadMessage } from '../../types'
+import { BaseExtension, ExtensionTypeEnum } from '../extension'
+import { InferenceExtension } from './'
+
+// Mock the MessageRequest and ThreadMessage types
+type MockMessageRequest = {
+ text: string
+}
+
+type MockThreadMessage = {
+ text: string
+ userId: string
+}
+
+// Mock the BaseExtension class
+class MockBaseExtension extends BaseExtension {
+ type(): ExtensionTypeEnum | undefined {
+ return ExtensionTypeEnum.Base
+ }
+}
+
+// Create a mock implementation of InferenceExtension
+class MockInferenceExtension extends InferenceExtension {
+ async inference(data: MessageRequest): Promise {
+ return { text: 'Mock response', userId: '123' } as unknown as ThreadMessage
+ }
+}
+
+describe('InferenceExtension', () => {
+ let inferenceExtension: InferenceExtension
+
+ beforeEach(() => {
+ inferenceExtension = new MockInferenceExtension()
+ })
+
+ it('should have the correct type', () => {
+ expect(inferenceExtension.type()).toBe(ExtensionTypeEnum.Inference)
+ })
+
+ it('should implement the inference method', async () => {
+ const messageRequest: MessageRequest = { text: 'Hello' } as unknown as MessageRequest
+ const result = await inferenceExtension.inference(messageRequest)
+ expect(result).toEqual({ text: 'Mock response', userId: '123' } as unknown as ThreadMessage)
+ })
+})
diff --git a/core/src/browser/extensions/monitoring.test.ts b/core/src/browser/extensions/monitoring.test.ts
new file mode 100644
index 000000000..9bba89a8c
--- /dev/null
+++ b/core/src/browser/extensions/monitoring.test.ts
@@ -0,0 +1,42 @@
+
+import { ExtensionTypeEnum } from '../extension';
+import { MonitoringExtension } from './monitoring';
+
+it('should have the correct type', () => {
+ class TestMonitoringExtension extends MonitoringExtension {
+ getGpuSetting(): Promise {
+ throw new Error('Method not implemented.');
+ }
+ getResourcesInfo(): Promise {
+ throw new Error('Method not implemented.');
+ }
+ getCurrentLoad(): Promise {
+ throw new Error('Method not implemented.');
+ }
+ getOsInfo(): Promise {
+ throw new Error('Method not implemented.');
+ }
+ }
+ const monitoringExtension = new TestMonitoringExtension();
+ expect(monitoringExtension.type()).toBe(ExtensionTypeEnum.SystemMonitoring);
+});
+
+
+it('should create an instance of MonitoringExtension', () => {
+ class TestMonitoringExtension extends MonitoringExtension {
+ getGpuSetting(): Promise {
+ throw new Error('Method not implemented.');
+ }
+ getResourcesInfo(): Promise {
+ throw new Error('Method not implemented.');
+ }
+ getCurrentLoad(): Promise {
+ throw new Error('Method not implemented.');
+ }
+ getOsInfo(): Promise {
+ throw new Error('Method not implemented.');
+ }
+ }
+ const monitoringExtension = new TestMonitoringExtension();
+ expect(monitoringExtension).toBeInstanceOf(MonitoringExtension);
+});
diff --git a/core/src/browser/tools/index.test.ts b/core/src/browser/tools/index.test.ts
new file mode 100644
index 000000000..8a24d3bb6
--- /dev/null
+++ b/core/src/browser/tools/index.test.ts
@@ -0,0 +1,5 @@
+
+
+it('should not throw any errors when imported', () => {
+ expect(() => require('./index')).not.toThrow();
+})
diff --git a/core/src/browser/tools/tool.test.ts b/core/src/browser/tools/tool.test.ts
index ba918a3cb..dcb478478 100644
--- a/core/src/browser/tools/tool.test.ts
+++ b/core/src/browser/tools/tool.test.ts
@@ -52,4 +52,12 @@ it('should skip processing for disabled tools', async () => {
const result = await manager.process(request, tools)
expect(result).toBe(request)
-})
\ No newline at end of file
+})
+
+it('should throw an error when process is called without implementation', () => {
+ class TestTool extends InferenceTool {
+ name = 'testTool'
+ }
+ const tool = new TestTool()
+ expect(() => tool.process({} as MessageRequest)).toThrowError()
+})
diff --git a/core/src/index.test.ts b/core/src/index.test.ts
new file mode 100644
index 000000000..a1bd7c6b9
--- /dev/null
+++ b/core/src/index.test.ts
@@ -0,0 +1,7 @@
+
+
+it('should declare global object core when importing the module and then deleting it', () => {
+ import('./index');
+ delete globalThis.core;
+ expect(typeof globalThis.core).toBe('undefined');
+});
diff --git a/core/src/node/api/index.test.ts b/core/src/node/api/index.test.ts
new file mode 100644
index 000000000..c35d6e792
--- /dev/null
+++ b/core/src/node/api/index.test.ts
@@ -0,0 +1,7 @@
+
+import * as restfulV1 from './restful/v1';
+
+it('should re-export from restful/v1', () => {
+ const restfulV1Exports = require('./restful/v1');
+ expect(restfulV1Exports).toBeDefined();
+})
diff --git a/core/src/node/api/processors/Processor.test.ts b/core/src/node/api/processors/Processor.test.ts
new file mode 100644
index 000000000..fd913c481
--- /dev/null
+++ b/core/src/node/api/processors/Processor.test.ts
@@ -0,0 +1,6 @@
+
+import { Processor } from './Processor';
+
+it('should be defined', () => {
+ expect(Processor).toBeDefined();
+});
diff --git a/core/src/node/api/processors/extension.test.ts b/core/src/node/api/processors/extension.test.ts
index 917883499..2067c5c42 100644
--- a/core/src/node/api/processors/extension.test.ts
+++ b/core/src/node/api/processors/extension.test.ts
@@ -7,3 +7,34 @@ it('should call function associated with key in process method', () => {
extension.process('testKey', 'arg1', 'arg2');
expect(mockFunc).toHaveBeenCalledWith('arg1', 'arg2');
});
+
+
+it('should_handle_empty_extension_list_for_install', async () => {
+ jest.mock('../../extension/store', () => ({
+ installExtensions: jest.fn(() => Promise.resolve([])),
+ }));
+ const extension = new Extension();
+ const result = await extension.installExtension([]);
+ expect(result).toEqual([]);
+});
+
+
+it('should_handle_empty_extension_list_for_update', async () => {
+ jest.mock('../../extension/store', () => ({
+ getExtension: jest.fn(() => ({ update: jest.fn(() => Promise.resolve(true)) })),
+ }));
+ const extension = new Extension();
+ const result = await extension.updateExtension([]);
+ expect(result).toEqual([]);
+});
+
+
+it('should_handle_empty_extension_list', async () => {
+ jest.mock('../../extension/store', () => ({
+ getExtension: jest.fn(() => ({ uninstall: jest.fn(() => Promise.resolve(true)) })),
+ removeExtension: jest.fn(),
+ }));
+ const extension = new Extension();
+ const result = await extension.uninstallExtension([]);
+ expect(result).toBe(true);
+});
diff --git a/core/src/node/api/restful/helper/consts.test.ts b/core/src/node/api/restful/helper/consts.test.ts
new file mode 100644
index 000000000..34d42dcf0
--- /dev/null
+++ b/core/src/node/api/restful/helper/consts.test.ts
@@ -0,0 +1,6 @@
+
+import { NITRO_DEFAULT_PORT } from './consts';
+
+it('should test NITRO_DEFAULT_PORT', () => {
+ expect(NITRO_DEFAULT_PORT).toBe(3928);
+});
diff --git a/core/src/node/helper/config.test.ts b/core/src/node/helper/config.test.ts
index 201a98141..d46750d5f 100644
--- a/core/src/node/helper/config.test.ts
+++ b/core/src/node/helper/config.test.ts
@@ -1,6 +1,8 @@
import { getEngineConfiguration } from './config';
import { getAppConfigurations, defaultAppConfig } from './config';
+import { getJanExtensionsPath } from './config';
+import { getJanDataFolderPath } from './config';
it('should return undefined for invalid engine ID', async () => {
const config = await getEngineConfiguration('invalid_engine');
expect(config).toBeUndefined();
@@ -12,3 +14,15 @@ it('should return default config when CI is e2e', () => {
const config = getAppConfigurations();
expect(config).toEqual(defaultAppConfig());
});
+
+
+it('should return extensions path when retrieved successfully', () => {
+ const extensionsPath = getJanExtensionsPath();
+ expect(extensionsPath).not.toBeUndefined();
+});
+
+
+it('should return data folder path when retrieved successfully', () => {
+ const dataFolderPath = getJanDataFolderPath();
+ expect(dataFolderPath).not.toBeUndefined();
+});
diff --git a/core/src/types/message/messageEntity.test.ts b/core/src/types/message/messageEntity.test.ts
new file mode 100644
index 000000000..1d41d129a
--- /dev/null
+++ b/core/src/types/message/messageEntity.test.ts
@@ -0,0 +1,9 @@
+
+import { MessageStatus } from './messageEntity';
+
+it('should have correct values', () => {
+ expect(MessageStatus.Ready).toBe('ready');
+ expect(MessageStatus.Pending).toBe('pending');
+ expect(MessageStatus.Error).toBe('error');
+ expect(MessageStatus.Stopped).toBe('stopped');
+})
diff --git a/core/src/types/miscellaneous/systemResourceInfo.test.ts b/core/src/types/miscellaneous/systemResourceInfo.test.ts
new file mode 100644
index 000000000..35a459f0e
--- /dev/null
+++ b/core/src/types/miscellaneous/systemResourceInfo.test.ts
@@ -0,0 +1,6 @@
+
+import { SupportedPlatforms } from './systemResourceInfo';
+
+it('should contain the correct values', () => {
+ expect(SupportedPlatforms).toEqual(['win32', 'linux', 'darwin']);
+});
diff --git a/core/src/types/monitoring/index.test.ts b/core/src/types/monitoring/index.test.ts
new file mode 100644
index 000000000..010fcb97a
--- /dev/null
+++ b/core/src/types/monitoring/index.test.ts
@@ -0,0 +1,16 @@
+
+import * as monitoringInterface from './monitoringInterface';
+import * as resourceInfo from './resourceInfo';
+
+ import * as index from './index';
+ import * as monitoringInterface from './monitoringInterface';
+ import * as resourceInfo from './resourceInfo';
+
+ it('should re-export all symbols from monitoringInterface and resourceInfo', () => {
+ for (const key in monitoringInterface) {
+ expect(index[key]).toBe(monitoringInterface[key]);
+ }
+ for (const key in resourceInfo) {
+ expect(index[key]).toBe(resourceInfo[key]);
+ }
+ });
diff --git a/core/src/types/setting/index.test.ts b/core/src/types/setting/index.test.ts
new file mode 100644
index 000000000..699adfe4f
--- /dev/null
+++ b/core/src/types/setting/index.test.ts
@@ -0,0 +1,5 @@
+
+
+it('should not throw any errors', () => {
+ expect(() => require('./index')).not.toThrow();
+});
diff --git a/core/src/types/setting/settingComponent.test.ts b/core/src/types/setting/settingComponent.test.ts
new file mode 100644
index 000000000..c56550e19
--- /dev/null
+++ b/core/src/types/setting/settingComponent.test.ts
@@ -0,0 +1,19 @@
+
+import { createSettingComponent } from './settingComponent';
+
+ it('should throw an error when creating a setting component with invalid controller type', () => {
+ const props: SettingComponentProps = {
+ key: 'invalidControllerKey',
+ title: 'Invalid Controller Title',
+ description: 'Invalid Controller Description',
+ controllerType: 'invalid' as any,
+ controllerProps: {
+ placeholder: 'Enter text',
+ value: 'Initial Value',
+ type: 'text',
+ textAlign: 'left',
+ inputActions: ['unobscure'],
+ },
+ };
+ expect(() => createSettingComponent(props)).toThrowError();
+ });
diff --git a/core/src/types/thread/threadEvent.test.ts b/core/src/types/thread/threadEvent.test.ts
new file mode 100644
index 000000000..f892f1050
--- /dev/null
+++ b/core/src/types/thread/threadEvent.test.ts
@@ -0,0 +1,6 @@
+
+import { ThreadEvent } from './threadEvent';
+
+it('should have the correct values', () => {
+ expect(ThreadEvent.OnThreadStarted).toBe('OnThreadStarted');
+});
diff --git a/electron/jest.config.js b/electron/jest.config.js
new file mode 100644
index 000000000..ec5968ccd
--- /dev/null
+++ b/electron/jest.config.js
@@ -0,0 +1,18 @@
+module.exports = {
+ preset: 'ts-jest',
+ testEnvironment: 'node',
+ collectCoverageFrom: ['src/**/*.{ts,tsx}'],
+ modulePathIgnorePatterns: ['/tests'],
+ moduleNameMapper: {
+ '@/(.*)': '/src/$1',
+ },
+ runner: './testRunner.js',
+ transform: {
+ '^.+\\.tsx?$': [
+ 'ts-jest',
+ {
+ diagnostics: false,
+ },
+ ],
+ },
+}
diff --git a/electron/testRunner.js b/electron/testRunner.js
new file mode 100644
index 000000000..b0d108160
--- /dev/null
+++ b/electron/testRunner.js
@@ -0,0 +1,10 @@
+const jestRunner = require('jest-runner');
+
+class EmptyTestFileRunner extends jestRunner.default {
+ async runTests(tests, watcher, onStart, onResult, onFailure, options) {
+ const nonEmptyTests = tests.filter(test => test.context.hasteFS.getSize(test.path) > 0);
+ return super.runTests(nonEmptyTests, watcher, onStart, onResult, onFailure, options);
+ }
+}
+
+module.exports = EmptyTestFileRunner;
\ No newline at end of file
diff --git a/web/jest.config.js b/web/jest.config.js
index 8b2683e78..7d2bee9ee 100644
--- a/web/jest.config.js
+++ b/web/jest.config.js
@@ -18,6 +18,14 @@ const config = {
// setupFilesAfterEnv: ['/jest.setup.ts'],
runner: './testRunner.js',
collectCoverageFrom: ['./**/*.{ts,tsx}'],
+ transform: {
+ "^.+\\.tsx?$": [
+ "ts-jest",
+ {
+ diagnostics: false,
+ },
+ ],
+ },
}
// https://stackoverflow.com/a/72926763/5078746
diff --git a/web/utils/base64.test.ts b/web/utils/base64.test.ts
new file mode 100644
index 000000000..1067970d4
--- /dev/null
+++ b/web/utils/base64.test.ts
@@ -0,0 +1,8 @@
+
+import { getBase64 } from './base64';
+
+test('getBase64_converts_file_to_base64', async () => {
+ const file = new File(['test'], 'test.txt', { type: 'text/plain' });
+ const base64String = await getBase64(file);
+ expect(base64String).toBe('data:text/plain;base64,dGVzdA==');
+});
diff --git a/web/utils/converter.test.ts b/web/utils/converter.test.ts
new file mode 100644
index 000000000..e86923b30
--- /dev/null
+++ b/web/utils/converter.test.ts
@@ -0,0 +1,33 @@
+
+import { formatDownloadSpeed } from './converter';
+import { formatExtensionsName } from './converter';
+import { formatTwoDigits } from './converter';
+
+ test('formatDownloadSpeed_should_return_correct_output_when_input_is_undefined', () => {
+ expect(formatDownloadSpeed(undefined)).toBe('0B/s');
+ });
+
+
+ test('formatExtensionsName_should_return_correct_output_for_string_with_janhq_and_dash', () => {
+ expect(formatExtensionsName('@janhq/extension-name')).toBe('extension name');
+ });
+
+
+ test('formatTwoDigits_should_return_correct_output_for_single_digit_number', () => {
+ expect(formatTwoDigits(5)).toBe('5.00');
+ });
+
+
+ test('formatDownloadSpeed_should_return_correct_output_for_gigabytes', () => {
+ expect(formatDownloadSpeed(1500000000)).toBe('1.40GB/s');
+ });
+
+
+ test('formatDownloadSpeed_should_return_correct_output_for_megabytes', () => {
+ expect(formatDownloadSpeed(1500000)).toBe('1.43MB/s');
+ });
+
+
+ test('formatDownloadSpeed_should_return_correct_output_for_kilobytes', () => {
+ expect(formatDownloadSpeed(1500)).toBe('1.46KB/s');
+ });
diff --git a/web/utils/modelParam.test.ts b/web/utils/modelParam.test.ts
index 994a5bd57..97325d277 100644
--- a/web/utils/modelParam.test.ts
+++ b/web/utils/modelParam.test.ts
@@ -1,5 +1,7 @@
// web/utils/modelParam.test.ts
import { normalizeValue, validationRules } from './modelParam'
+import { extractModelLoadParams } from './modelParam';
+import { extractInferenceParams } from './modelParam';
describe('validationRules', () => {
it('should validate temperature correctly', () => {
@@ -189,3 +191,20 @@ describe('normalizeValue', () => {
expect(normalizeValue('cpu_threads', 0)).toBe(0)
})
})
+
+
+ it('should handle invalid values correctly by falling back to originParams', () => {
+ const modelParams = { temperature: 'invalid', token_limit: -1 };
+ const originParams = { temperature: 0.5, token_limit: 100 };
+ expect(extractInferenceParams(modelParams, originParams)).toEqual(originParams);
+ });
+
+
+ it('should return an empty object when no modelParams are provided', () => {
+ expect(extractModelLoadParams()).toEqual({});
+ });
+
+
+ it('should return an empty object when no modelParams are provided', () => {
+ expect(extractInferenceParams()).toEqual({});
+ });
diff --git a/web/utils/threadMessageBuilder.test.ts b/web/utils/threadMessageBuilder.test.ts
index d938a2e03..cc192a5c1 100644
--- a/web/utils/threadMessageBuilder.test.ts
+++ b/web/utils/threadMessageBuilder.test.ts
@@ -4,6 +4,7 @@ import { ChatCompletionRole, MessageStatus } from '@janhq/core'
import { ThreadMessageBuilder } from './threadMessageBuilder'
import { MessageRequestBuilder } from './messageRequestBuilder'
+import { ContentType } from '@janhq/core';
describe('ThreadMessageBuilder', () => {
it('testBuildMethod', () => {
const msgRequest = new MessageRequestBuilder(
@@ -25,3 +26,75 @@ import { ChatCompletionRole, MessageStatus } from '@janhq/core'
expect(result.content).toEqual([])
})
})
+
+ it('testPushMessageWithPromptOnly', () => {
+ const msgRequest = new MessageRequestBuilder(
+ 'type',
+ { model: 'model' },
+ { id: 'thread-id' },
+ []
+ );
+ const builder = new ThreadMessageBuilder(msgRequest);
+ const prompt = 'test prompt';
+ builder.pushMessage(prompt, undefined, []);
+ expect(builder.content).toEqual([
+ {
+ type: ContentType.Text,
+ text: {
+ value: prompt,
+ annotations: [],
+ },
+ },
+ ]);
+ });
+
+
+ it('testPushMessageWithPdf', () => {
+ const msgRequest = new MessageRequestBuilder(
+ 'type',
+ { model: 'model' },
+ { id: 'thread-id' },
+ []
+ );
+ const builder = new ThreadMessageBuilder(msgRequest);
+ const prompt = 'test prompt';
+ const base64 = 'test base64';
+ const fileUpload = [{ type: 'pdf', file: { name: 'test.pdf', size: 1000 } }];
+ builder.pushMessage(prompt, base64, fileUpload);
+ expect(builder.content).toEqual([
+ {
+ type: ContentType.Pdf,
+ text: {
+ value: prompt,
+ annotations: [base64],
+ name: fileUpload[0].file.name,
+ size: fileUpload[0].file.size,
+ },
+ },
+ ]);
+ });
+
+
+ it('testPushMessageWithImage', () => {
+ const msgRequest = new MessageRequestBuilder(
+ 'type',
+ { model: 'model' },
+ { id: 'thread-id' },
+ []
+ );
+ const builder = new ThreadMessageBuilder(msgRequest);
+ const prompt = 'test prompt';
+ const base64 = 'test base64';
+ const fileUpload = [{ type: 'image', file: { name: 'test.jpg', size: 1000 } }];
+ builder.pushMessage(prompt, base64, fileUpload);
+ expect(builder.content).toEqual([
+ {
+ type: ContentType.Image,
+ text: {
+ value: prompt,
+ annotations: [base64],
+ },
+ },
+ ]);
+ });
+
From aee86243388e5885a368bd1df6aa717888c3f051 Mon Sep 17 00:00:00 2001
From: Louis
Date: Mon, 23 Sep 2024 14:20:01 +0700
Subject: [PATCH 20/37] fix: #3693 broken thread.json should not break the
entire threads (#3709)
* fix: #3693 broken thread.json should not break the entire threads
* test: add tests
---
.../conversational-extension/jest.config.js | 5 +
.../conversational-extension/package.json | 1 +
.../src/Conversational.test.ts | 408 ++++++++++++++++++
.../conversational-extension/src/index.ts | 13 +-
.../conversational-extension/src/jsonUtil.ts | 14 +
.../conversational-extension/tsconfig.json | 3 +-
.../model-extension/src/helpers/path.test.ts | 87 ++++
.../model-extension/src/helpers/path.ts | 2 +
jest.config.js | 9 +-
web/containers/ErrorMessage/index.test.tsx | 107 +++++
web/containers/ListContainer/index.test.tsx | 69 +++
web/containers/ListContainer/index.tsx | 2 +-
.../Loader/GenerateResponse.test.tsx | 75 ++++
web/containers/Loader/ModelReload.test.tsx | 124 ++++++
web/containers/Loader/ProgressCircle.test.tsx | 22 +
web/containers/Loader/ProgressCircle.tsx | 1 +
.../ModalTroubleShoot/AppLogs.test.tsx | 105 +++++
.../AssistantSetting/index.test.tsx | 137 ++++++
18 files changed, 1176 insertions(+), 8 deletions(-)
create mode 100644 extensions/conversational-extension/jest.config.js
create mode 100644 extensions/conversational-extension/src/Conversational.test.ts
create mode 100644 extensions/conversational-extension/src/jsonUtil.ts
create mode 100644 extensions/model-extension/src/helpers/path.test.ts
create mode 100644 web/containers/ErrorMessage/index.test.tsx
create mode 100644 web/containers/ListContainer/index.test.tsx
create mode 100644 web/containers/Loader/GenerateResponse.test.tsx
create mode 100644 web/containers/Loader/ModelReload.test.tsx
create mode 100644 web/containers/Loader/ProgressCircle.test.tsx
create mode 100644 web/containers/ModalTroubleShoot/AppLogs.test.tsx
create mode 100644 web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx
diff --git a/extensions/conversational-extension/jest.config.js b/extensions/conversational-extension/jest.config.js
new file mode 100644
index 000000000..8bb37208d
--- /dev/null
+++ b/extensions/conversational-extension/jest.config.js
@@ -0,0 +1,5 @@
+/** @type {import('ts-jest').JestConfigWithTsJest} */
+module.exports = {
+ preset: 'ts-jest',
+ testEnvironment: 'node',
+}
diff --git a/extensions/conversational-extension/package.json b/extensions/conversational-extension/package.json
index d062ce9c3..036fcfab2 100644
--- a/extensions/conversational-extension/package.json
+++ b/extensions/conversational-extension/package.json
@@ -7,6 +7,7 @@
"author": "Jan ",
"license": "MIT",
"scripts": {
+ "test": "jest",
"build": "tsc -b . && webpack --config webpack.config.js",
"build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install"
},
diff --git a/extensions/conversational-extension/src/Conversational.test.ts b/extensions/conversational-extension/src/Conversational.test.ts
new file mode 100644
index 000000000..3d1d6fc60
--- /dev/null
+++ b/extensions/conversational-extension/src/Conversational.test.ts
@@ -0,0 +1,408 @@
+/**
+ * @jest-environment jsdom
+ */
+jest.mock('@janhq/core', () => ({
+ ...jest.requireActual('@janhq/core/node'),
+ fs: {
+ existsSync: jest.fn(),
+ mkdir: jest.fn(),
+ writeFileSync: jest.fn(),
+ readdirSync: jest.fn(),
+ readFileSync: jest.fn(),
+ appendFileSync: jest.fn(),
+ rm: jest.fn(),
+ writeBlob: jest.fn(),
+ joinPath: jest.fn(),
+ fileStat: jest.fn(),
+ },
+ joinPath: jest.fn(),
+ ConversationalExtension: jest.fn(),
+}))
+
+import { fs } from '@janhq/core'
+
+import JSONConversationalExtension from '.'
+
+describe('JSONConversationalExtension Tests', () => {
+ let extension: JSONConversationalExtension
+
+ beforeEach(() => {
+ // @ts-ignore
+ extension = new JSONConversationalExtension()
+ })
+
+ it('should create thread folder on load if it does not exist', async () => {
+ // @ts-ignore
+ jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
+ const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
+
+ await extension.onLoad()
+
+ expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
+ })
+
+ it('should log message on unload', () => {
+ const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
+
+ extension.onUnload()
+
+ expect(consoleSpy).toHaveBeenCalledWith(
+ 'JSONConversationalExtension unloaded'
+ )
+ })
+
+ it('should return sorted threads', async () => {
+ jest
+ .spyOn(extension, 'getValidThreadDirs')
+ .mockResolvedValue(['dir1', 'dir2'])
+ jest
+ .spyOn(extension, 'readThread')
+ .mockResolvedValueOnce({ updated: '2023-01-01' })
+ .mockResolvedValueOnce({ updated: '2023-01-02' })
+
+ const threads = await extension.getThreads()
+
+ expect(threads).toEqual([
+ { updated: '2023-01-02' },
+ { updated: '2023-01-01' },
+ ])
+ })
+
+ it('should ignore broken threads', async () => {
+ jest
+ .spyOn(extension, 'getValidThreadDirs')
+ .mockResolvedValue(['dir1', 'dir2'])
+ jest
+ .spyOn(extension, 'readThread')
+ .mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
+ .mockResolvedValueOnce('this_is_an_invalid_json_content')
+
+ const threads = await extension.getThreads()
+
+ expect(threads).toEqual([{ updated: '2023-01-01' }])
+ })
+
+ it('should save thread', async () => {
+ // @ts-ignore
+ jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
+ const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
+ const writeFileSyncSpy = jest
+ .spyOn(fs, 'writeFileSync')
+ .mockResolvedValue({})
+
+ const thread = { id: '1', updated: '2023-01-01' } as any
+ await extension.saveThread(thread)
+
+ expect(mkdirSpy).toHaveBeenCalled()
+ expect(writeFileSyncSpy).toHaveBeenCalled()
+ })
+
+ it('should delete thread', async () => {
+ const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
+
+ await extension.deleteThread('1')
+
+ expect(rmSpy).toHaveBeenCalled()
+ })
+
+ it('should add new message', async () => {
+ // @ts-ignore
+ jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
+ const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
+ const appendFileSyncSpy = jest
+ .spyOn(fs, 'appendFileSync')
+ .mockResolvedValue({})
+
+ const message = {
+ thread_id: '1',
+ content: [{ type: 'text', text: { annotations: [] } }],
+ } as any
+ await extension.addNewMessage(message)
+
+ expect(mkdirSpy).toHaveBeenCalled()
+ expect(appendFileSyncSpy).toHaveBeenCalled()
+ })
+
+ it('should store image', async () => {
+ const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
+
+ await extension.storeImage(
+ 'data:image/png;base64,abcd',
+ 'path/to/image.png'
+ )
+
+ expect(writeBlobSpy).toHaveBeenCalled()
+ })
+
+ it('should store file', async () => {
+ const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
+
+ await extension.storeFile(
+ 'data:application/pdf;base64,abcd',
+ 'path/to/file.pdf'
+ )
+
+ expect(writeBlobSpy).toHaveBeenCalled()
+ })
+
+ it('should write messages', async () => {
+ // @ts-ignore
+ jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
+ const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
+ const writeFileSyncSpy = jest
+ .spyOn(fs, 'writeFileSync')
+ .mockResolvedValue({})
+
+ const messages = [{ id: '1', thread_id: '1', content: [] }] as any
+ await extension.writeMessages('1', messages)
+
+ expect(mkdirSpy).toHaveBeenCalled()
+ expect(writeFileSyncSpy).toHaveBeenCalled()
+ })
+
+ it('should get all messages on string response', async () => {
+ jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
+ jest.spyOn(fs, 'readFileSync').mockResolvedValue('{"id":"1"}\n{"id":"2"}\n')
+
+ const messages = await extension.getAllMessages('1')
+
+ expect(messages).toEqual([{ id: '1' }, { id: '2' }])
+ })
+
+ it('should get all messages on object response', async () => {
+ jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
+ jest.spyOn(fs, 'readFileSync').mockResolvedValue({ id: 1 })
+
+ const messages = await extension.getAllMessages('1')
+
+ expect(messages).toEqual([{ id: 1 }])
+ })
+
+ it('get all messages return empty on error', async () => {
+ jest.spyOn(fs, 'readdirSync').mockRejectedValue(['messages.jsonl'])
+
+ const messages = await extension.getAllMessages('1')
+
+ expect(messages).toEqual([])
+ })
+
+ it('return empty messages on no messages file', async () => {
+ jest.spyOn(fs, 'readdirSync').mockResolvedValue([])
+
+ const messages = await extension.getAllMessages('1')
+
+ expect(messages).toEqual([])
+ })
+
+ it('should ignore error message', async () => {
+ jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl'])
+ jest
+ .spyOn(fs, 'readFileSync')
+ .mockResolvedValue('{"id":"1"}\nyolo\n{"id":"2"}\n')
+
+ const messages = await extension.getAllMessages('1')
+
+ expect(messages).toEqual([{ id: '1' }, { id: '2' }])
+ })
+
+ it('should create thread folder on load if it does not exist', async () => {
+ // @ts-ignore
+ jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
+ const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
+
+ await extension.onLoad()
+
+ expect(mkdirSpy).toHaveBeenCalledWith('file://threads')
+ })
+
+ it('should log message on unload', () => {
+ const consoleSpy = jest.spyOn(console, 'debug').mockImplementation()
+
+ extension.onUnload()
+
+ expect(consoleSpy).toHaveBeenCalledWith(
+ 'JSONConversationalExtension unloaded'
+ )
+ })
+
+ it('should return sorted threads', async () => {
+ jest
+ .spyOn(extension, 'getValidThreadDirs')
+ .mockResolvedValue(['dir1', 'dir2'])
+ jest
+ .spyOn(extension, 'readThread')
+ .mockResolvedValueOnce({ updated: '2023-01-01' })
+ .mockResolvedValueOnce({ updated: '2023-01-02' })
+
+ const threads = await extension.getThreads()
+
+ expect(threads).toEqual([
+ { updated: '2023-01-02' },
+ { updated: '2023-01-01' },
+ ])
+ })
+
+ it('should ignore broken threads', async () => {
+ jest
+ .spyOn(extension, 'getValidThreadDirs')
+ .mockResolvedValue(['dir1', 'dir2'])
+ jest
+ .spyOn(extension, 'readThread')
+ .mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' }))
+ .mockResolvedValueOnce('this_is_an_invalid_json_content')
+
+ const threads = await extension.getThreads()
+
+ expect(threads).toEqual([{ updated: '2023-01-01' }])
+ })
+
+ it('should save thread', async () => {
+ // @ts-ignore
+ jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
+ const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
+ const writeFileSyncSpy = jest
+ .spyOn(fs, 'writeFileSync')
+ .mockResolvedValue({})
+
+ const thread = { id: '1', updated: '2023-01-01' } as any
+ await extension.saveThread(thread)
+
+ expect(mkdirSpy).toHaveBeenCalled()
+ expect(writeFileSyncSpy).toHaveBeenCalled()
+ })
+
+ it('should delete thread', async () => {
+ const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({})
+
+ await extension.deleteThread('1')
+
+ expect(rmSpy).toHaveBeenCalled()
+ })
+
+ it('should add new message', async () => {
+ // @ts-ignore
+ jest.spyOn(fs, 'existsSync').mockResolvedValue(false)
+ const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
+ const appendFileSyncSpy = jest
+ .spyOn(fs, 'appendFileSync')
+ .mockResolvedValue({})
+
+ const message = {
+ thread_id: '1',
+ content: [{ type: 'text', text: { annotations: [] } }],
+ } as any
+ await extension.addNewMessage(message)
+
+ expect(mkdirSpy).toHaveBeenCalled()
+ expect(appendFileSyncSpy).toHaveBeenCalled()
+ })
+
+ it('should add new image message', async () => {
+ jest
+ .spyOn(fs, 'existsSync')
+ // @ts-ignore
+ .mockResolvedValueOnce(false)
+ // @ts-ignore
+ .mockResolvedValueOnce(false)
+ // @ts-ignore
+ .mockResolvedValueOnce(true)
+ const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
+ const appendFileSyncSpy = jest
+ .spyOn(fs, 'appendFileSync')
+ .mockResolvedValue({})
+ jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
+
+ const message = {
+ thread_id: '1',
+ content: [
+ { type: 'image', text: { annotations: ['data:image;base64,hehe'] } },
+ ],
+ } as any
+ await extension.addNewMessage(message)
+
+ expect(mkdirSpy).toHaveBeenCalled()
+ expect(appendFileSyncSpy).toHaveBeenCalled()
+ })
+
+ it('should add new pdf message', async () => {
+ jest
+ .spyOn(fs, 'existsSync')
+ // @ts-ignore
+ .mockResolvedValueOnce(false)
+ // @ts-ignore
+ .mockResolvedValueOnce(false)
+ // @ts-ignore
+ .mockResolvedValueOnce(true)
+ const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({})
+ const appendFileSyncSpy = jest
+ .spyOn(fs, 'appendFileSync')
+ .mockResolvedValue({})
+ jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
+
+ const message = {
+ thread_id: '1',
+ content: [
+ { type: 'pdf', text: { annotations: ['data:pdf;base64,hehe'] } },
+ ],
+ } as any
+ await extension.addNewMessage(message)
+
+ expect(mkdirSpy).toHaveBeenCalled()
+ expect(appendFileSyncSpy).toHaveBeenCalled()
+ })
+
+ it('should store image', async () => {
+ const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
+
+ await extension.storeImage(
+ 'data:image/png;base64,abcd',
+ 'path/to/image.png'
+ )
+
+ expect(writeBlobSpy).toHaveBeenCalled()
+ })
+
+ it('should store file', async () => {
+ const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({})
+
+ await extension.storeFile(
+ 'data:application/pdf;base64,abcd',
+ 'path/to/file.pdf'
+ )
+
+ expect(writeBlobSpy).toHaveBeenCalled()
+ })
+})
+
+describe('test readThread', () => {
+ let extension: JSONConversationalExtension
+
+ beforeEach(() => {
+ // @ts-ignore
+ extension = new JSONConversationalExtension()
+ })
+
+ it('should read thread', async () => {
+ jest
+ .spyOn(fs, 'readFileSync')
+ .mockResolvedValue(JSON.stringify({ id: '1' }))
+ const thread = await extension.readThread('1')
+ expect(thread).toEqual(`{"id":"1"}`)
+ })
+
+ it('getValidThreadDirs should return valid thread directories', async () => {
+ jest
+ .spyOn(fs, 'readdirSync')
+ .mockResolvedValueOnce(['1', '2', '3'])
+ .mockResolvedValueOnce(['thread.json'])
+ .mockResolvedValueOnce(['thread.json'])
+ .mockResolvedValueOnce([])
+ // @ts-ignore
+ jest.spyOn(fs, 'existsSync').mockResolvedValue(true)
+ jest.spyOn(fs, 'fileStat').mockResolvedValue({
+ isDirectory: true,
+ } as any)
+ const validThreadDirs = await extension.getValidThreadDirs()
+ expect(validThreadDirs).toEqual(['1', '2'])
+ })
+})
diff --git a/extensions/conversational-extension/src/index.ts b/extensions/conversational-extension/src/index.ts
index 1bca75347..b34f09181 100644
--- a/extensions/conversational-extension/src/index.ts
+++ b/extensions/conversational-extension/src/index.ts
@@ -5,6 +5,7 @@ import {
Thread,
ThreadMessage,
} from '@janhq/core'
+import { safelyParseJSON } from './jsonUtil'
/**
* JSONConversationalExtension is a ConversationalExtension implementation that provides
@@ -45,10 +46,11 @@ export default class JSONConversationalExtension extends ConversationalExtension
if (result.status === 'fulfilled') {
return typeof result.value === 'object'
? result.value
- : JSON.parse(result.value)
+ : safelyParseJSON(result.value)
}
+ return undefined
})
- .filter((convo) => convo != null)
+ .filter((convo) => !!convo)
convos.sort(
(a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime()
)
@@ -195,7 +197,7 @@ export default class JSONConversationalExtension extends ConversationalExtension
* @param threadDirName the thread dir we are reading from.
* @returns data of the thread
*/
- private async readThread(threadDirName: string): Promise {
+ async readThread(threadDirName: string): Promise {
return fs.readFileSync(
await joinPath([
JSONConversationalExtension._threadFolder,
@@ -210,7 +212,7 @@ export default class JSONConversationalExtension extends ConversationalExtension
* Returns a Promise that resolves to an array of thread directories.
* @private
*/
- private async getValidThreadDirs(): Promise {
+ async getValidThreadDirs(): Promise {
const fileInsideThread: string[] = await fs.readdirSync(
JSONConversationalExtension._threadFolder
)
@@ -266,7 +268,8 @@ export default class JSONConversationalExtension extends ConversationalExtension
const messages: ThreadMessage[] = []
result.forEach((line: string) => {
- messages.push(JSON.parse(line))
+ const message = safelyParseJSON(line)
+ if (message) messages.push(safelyParseJSON(line))
})
return messages
} catch (err) {
diff --git a/extensions/conversational-extension/src/jsonUtil.ts b/extensions/conversational-extension/src/jsonUtil.ts
new file mode 100644
index 000000000..7f83cadce
--- /dev/null
+++ b/extensions/conversational-extension/src/jsonUtil.ts
@@ -0,0 +1,14 @@
+// Note about performance
+// The v8 JavaScript engine used by Node.js cannot optimise functions which contain a try/catch block.
+// v8 4.5 and above can optimise try/catch
+export function safelyParseJSON(json) {
+ // This function cannot be optimised, it's best to
+ // keep it small!
+ var parsed
+ try {
+ parsed = JSON.parse(json)
+ } catch (e) {
+ return undefined
+ }
+ return parsed // Could be undefined!
+}
diff --git a/extensions/conversational-extension/tsconfig.json b/extensions/conversational-extension/tsconfig.json
index 2477d58ce..8427123e7 100644
--- a/extensions/conversational-extension/tsconfig.json
+++ b/extensions/conversational-extension/tsconfig.json
@@ -10,5 +10,6 @@
"skipLibCheck": true,
"rootDir": "./src"
},
- "include": ["./src"]
+ "include": ["./src"],
+ "exclude": ["src/**/*.test.ts"]
}
diff --git a/extensions/model-extension/src/helpers/path.test.ts b/extensions/model-extension/src/helpers/path.test.ts
new file mode 100644
index 000000000..64ca65d8a
--- /dev/null
+++ b/extensions/model-extension/src/helpers/path.test.ts
@@ -0,0 +1,87 @@
+import { extractFileName } from './path';
+
+describe('extractFileName Function', () => {
+ it('should correctly extract the file name with the provided file extension', () => {
+ const url = 'http://example.com/some/path/to/file.ext';
+ const fileExtension = '.ext';
+ const fileName = extractFileName(url, fileExtension);
+ expect(fileName).toBe('file.ext');
+ });
+
+ it('should correctly append the file extension if it does not already exist in the file name', () => {
+ const url = 'http://example.com/some/path/to/file';
+ const fileExtension = '.txt';
+ const fileName = extractFileName(url, fileExtension);
+ expect(fileName).toBe('file.txt');
+ });
+
+ it('should handle cases where the URL does not have a file extension correctly', () => {
+ const url = 'http://example.com/some/path/to/file';
+ const fileExtension = '.jpg';
+ const fileName = extractFileName(url, fileExtension);
+ expect(fileName).toBe('file.jpg');
+ });
+
+ it('should correctly handle URLs without a trailing slash', () => {
+ const url = 'http://example.com/some/path/tofile';
+ const fileExtension = '.txt';
+ const fileName = extractFileName(url, fileExtension);
+ expect(fileName).toBe('tofile.txt');
+ });
+
+ it('should correctly handle URLs with multiple file extensions', () => {
+ const url = 'http://example.com/some/path/tofile.tar.gz';
+ const fileExtension = '.gz';
+ const fileName = extractFileName(url, fileExtension);
+ expect(fileName).toBe('tofile.tar.gz');
+ });
+
+ it('should correctly handle URLs with special characters', () => {
+ const url = 'http://example.com/some/path/tófÃlë.extë';
+ const fileExtension = '.extë';
+ const fileName = extractFileName(url, fileExtension);
+ expect(fileName).toBe('tófÃlë.extë');
+ });
+
+ it('should correctly handle URLs that are just a file with no path', () => {
+ const url = 'http://example.com/file.txt';
+ const fileExtension = '.txt';
+ const fileName = extractFileName(url, fileExtension);
+ expect(fileName).toBe('file.txt');
+ });
+
+ it('should correctly handle URLs that have special query parameters', () => {
+ const url = 'http://example.com/some/path/tofile.ext?query=1';
+ const fileExtension = '.ext';
+ const fileName = extractFileName(url.split('?')[0], fileExtension);
+ expect(fileName).toBe('tofile.ext');
+ });
+
+ it('should correctly handle URLs that have uppercase characters', () => {
+ const url = 'http://EXAMPLE.COM/PATH/TO/FILE.EXT';
+ const fileExtension = '.ext';
+ const fileName = extractFileName(url, fileExtension);
+ expect(fileName).toBe('FILE.EXT');
+ });
+
+ it('should correctly handle invalid URLs', () => {
+ const url = 'invalid-url';
+ const fileExtension = '.txt';
+ const fileName = extractFileName(url, fileExtension);
+ expect(fileName).toBe('invalid-url.txt');
+ });
+
+ it('should correctly handle empty URLs', () => {
+ const url = '';
+ const fileExtension = '.txt';
+ const fileName = extractFileName(url, fileExtension);
+ expect(fileName).toBe('.txt');
+ });
+
+ it('should correctly handle undefined URLs', () => {
+ const url = undefined;
+ const fileExtension = '.txt';
+ const fileName = extractFileName(url as any, fileExtension);
+ expect(fileName).toBe('.txt');
+ });
+});
diff --git a/extensions/model-extension/src/helpers/path.ts b/extensions/model-extension/src/helpers/path.ts
index cbb151aa6..6091005b8 100644
--- a/extensions/model-extension/src/helpers/path.ts
+++ b/extensions/model-extension/src/helpers/path.ts
@@ -3,6 +3,8 @@
*/
export function extractFileName(url: string, fileExtension: string): string {
+ if(!url) return fileExtension
+
const extractedFileName = url.split('/').pop()
const fileName = extractedFileName.toLowerCase().endsWith(fileExtension)
? extractedFileName
diff --git a/jest.config.js b/jest.config.js
index a911a7f0a..a9f0f5938 100644
--- a/jest.config.js
+++ b/jest.config.js
@@ -1,3 +1,10 @@
module.exports = {
- projects: ['/core', '/web', '/joi'],
+ projects: [
+ '/core',
+ '/web',
+ '/joi',
+ '/extensions/inference-nitro-extension',
+ '/extensions/conversational-extension',
+ '/extensions/model-extension',
+ ],
}
diff --git a/web/containers/ErrorMessage/index.test.tsx b/web/containers/ErrorMessage/index.test.tsx
new file mode 100644
index 000000000..99dad5415
--- /dev/null
+++ b/web/containers/ErrorMessage/index.test.tsx
@@ -0,0 +1,107 @@
+// ErrorMessage.test.tsx
+import React from 'react';
+import { render, screen, fireEvent } from '@testing-library/react';
+import '@testing-library/jest-dom';
+import ErrorMessage from './index';
+import { ThreadMessage, MessageStatus, ErrorCode } from '@janhq/core';
+import { useAtomValue, useSetAtom } from 'jotai';
+import useSendChatMessage from '@/hooks/useSendChatMessage';
+
+// Mock the dependencies
+jest.mock('jotai', () => {
+ const originalModule = jest.requireActual('jotai')
+ return {
+ ...originalModule,
+ useAtomValue: jest.fn(),
+ useSetAtom: jest.fn(),
+ }
+ })
+
+jest.mock('@/hooks/useSendChatMessage', () => ({
+ __esModule: true,
+ default: jest.fn(),
+}));
+
+describe('ErrorMessage Component', () => {
+ const mockSetMainState = jest.fn();
+ const mockSetSelectedSettingScreen = jest.fn();
+ const mockSetModalTroubleShooting = jest.fn();
+ const mockResendChatMessage = jest.fn();
+
+ beforeEach(() => {
+ jest.clearAllMocks();
+ (useAtomValue as jest.Mock).mockReturnValue([]);
+ (useSetAtom as jest.Mock).mockReturnValue(mockSetMainState);
+ (useSetAtom as jest.Mock).mockReturnValue(mockSetSelectedSettingScreen);
+ (useSetAtom as jest.Mock).mockReturnValue(mockSetModalTroubleShooting);
+ (useSendChatMessage as jest.Mock).mockReturnValue({ resendChatMessage: mockResendChatMessage });
+ });
+
+ it('renders stopped message correctly', () => {
+ const message: ThreadMessage = {
+ id: '1',
+ status: MessageStatus.Stopped,
+ content: [{ text: { value: 'Test message' } }],
+ } as ThreadMessage;
+
+ render( );
+
+ expect(screen.getByText("Oops! The generation was interrupted. Let's give it another go!")).toBeInTheDocument();
+ expect(screen.getByText('Regenerate')).toBeInTheDocument();
+ });
+
+ it('renders error message with InvalidApiKey correctly', () => {
+ const message: ThreadMessage = {
+ id: '1',
+ status: MessageStatus.Error,
+ error_code: ErrorCode.InvalidApiKey,
+ content: [{ text: { value: 'Invalid API Key' } }],
+ } as ThreadMessage;
+
+ render( );
+
+ expect(screen.getByTestId('invalid-API-key-error')).toBeInTheDocument();
+ expect(screen.getByText('Settings')).toBeInTheDocument();
+ });
+
+ it('renders general error message correctly', () => {
+ const message: ThreadMessage = {
+ id: '1',
+ status: MessageStatus.Error,
+ error_code: ErrorCode.Unknown,
+ content: [{ text: { value: 'Unknown error occurred' } }],
+ } as ThreadMessage;
+
+ render( );
+
+ expect(screen.getByText("Apologies, something’s amiss!")).toBeInTheDocument();
+ expect(screen.getByText('troubleshooting assistance')).toBeInTheDocument();
+ });
+
+ it('calls regenerateMessage when Regenerate button is clicked', () => {
+ const message: ThreadMessage = {
+ id: '1',
+ status: MessageStatus.Stopped,
+ content: [{ text: { value: 'Test message' } }],
+ } as ThreadMessage;
+
+ render( );
+
+ fireEvent.click(screen.getByText('Regenerate'));
+ expect(mockResendChatMessage).toHaveBeenCalled();
+ });
+
+ it('opens troubleshooting modal when link is clicked', () => {
+ const message: ThreadMessage = {
+ id: '1',
+ status: MessageStatus.Error,
+ error_code: ErrorCode.Unknown,
+ content: [{ text: { value: 'Unknown error occurred' } }],
+ } as ThreadMessage;
+
+ render( );
+
+ fireEvent.click(screen.getByText('troubleshooting assistance'));
+ expect(mockSetModalTroubleShooting).toHaveBeenCalledWith(true);
+ });
+});
diff --git a/web/containers/ListContainer/index.test.tsx b/web/containers/ListContainer/index.test.tsx
new file mode 100644
index 000000000..866d8ff4e
--- /dev/null
+++ b/web/containers/ListContainer/index.test.tsx
@@ -0,0 +1,69 @@
+// ListContainer.test.tsx
+import React from 'react'
+import { render, screen, fireEvent } from '@testing-library/react'
+import '@testing-library/jest-dom'
+import ListContainer from './index'
+
+class ResizeObserverMock {
+ observe() {}
+ unobserve() {}
+ disconnect() {}
+}
+
+global.ResizeObserver = ResizeObserverMock
+
+describe('ListContainer', () => {
+ const scrollToMock = jest.fn()
+ Element.prototype.scrollTo = scrollToMock
+
+ it('renders children correctly', () => {
+ render(
+
+ Test Child
+
+ )
+ expect(screen.getByTestId('child')).toBeInTheDocument()
+ })
+
+ it('scrolls to bottom on initial render', () => {
+
+ render(
+
+ Long content
+
+ )
+
+ expect(scrollToMock).toHaveBeenCalledWith({
+ top: expect.any(Number),
+ behavior: 'instant',
+ })
+ })
+
+ it('sets isUserManuallyScrollingUp when scrolling up', () => {
+ const { container } = render(
+
+ Long content
+
+ )
+
+ const scrollArea = container.firstChild as HTMLElement
+
+ // Simulate scrolling down
+ fireEvent.scroll(scrollArea, { target: { scrollTop: 500 } })
+
+ // Simulate scrolling up
+ fireEvent.scroll(scrollArea, { target: { scrollTop: 300 } })
+
+ // We can't directly test the internal state, but we can check that
+ // subsequent scroll to bottom doesn't happen (as it would if isUserManuallyScrollingUp was false)
+
+ // Trigger a re-render
+ render(
+
+ Long content
+
+ )
+
+ expect(scrollToMock).toHaveBeenCalled()
+ })
+})
diff --git a/web/containers/ListContainer/index.tsx b/web/containers/ListContainer/index.tsx
index a48db5313..bd650e315 100644
--- a/web/containers/ListContainer/index.tsx
+++ b/web/containers/ListContainer/index.tsx
@@ -29,7 +29,7 @@ const ListContainer = ({ children }: Props) => {
}, [])
useEffect(() => {
- if (isUserManuallyScrollingUp.current === true) return
+ if (isUserManuallyScrollingUp.current === true || !listRef.current) return
const scrollHeight = listRef.current?.scrollHeight ?? 0
listRef.current?.scrollTo({
top: scrollHeight,
diff --git a/web/containers/Loader/GenerateResponse.test.tsx b/web/containers/Loader/GenerateResponse.test.tsx
new file mode 100644
index 000000000..7e3e5c3a4
--- /dev/null
+++ b/web/containers/Loader/GenerateResponse.test.tsx
@@ -0,0 +1,75 @@
+// GenerateResponse.test.tsx
+import React from 'react';
+import { render, screen, act } from '@testing-library/react';
+import '@testing-library/jest-dom';
+import GenerateResponse from './GenerateResponse';
+
+jest.useFakeTimers();
+
+describe('GenerateResponse Component', () => {
+ it('renders initially with 1% loader width', () => {
+ render( );
+ const loader = screen.getByTestId('response-loader');
+ expect(loader).toHaveStyle('width: 24%');
+ });
+
+ it('updates loader width over time', () => {
+ render( );
+ const loader = screen.getByTestId('response-loader');
+
+ // Advance timers to simulate time passing
+ act(() => {
+ jest.advanceTimersByTime(1000);
+ });
+
+ expect(loader).not.toHaveStyle('width: 1%');
+ expect(parseFloat(loader.style.width)).toBeGreaterThan(1);
+ });
+
+ it('pauses at specific percentages', () => {
+ render( );
+ const loader = screen.getByTestId('response-loader');
+
+ // Advance to 24%
+ act(() => {
+ for (let i = 0; i < 24; i++) {
+ jest.advanceTimersByTime(50);
+ }
+ });
+
+ expect(loader).toHaveStyle('width: 50%');
+
+ // Advance past the pause
+ act(() => {
+ jest.advanceTimersByTime(300);
+ });
+
+ expect(loader).toHaveStyle('width: 78%');
+ });
+
+ it('stops at 85%', () => {
+ render( );
+ const loader = screen.getByTestId('response-loader');
+
+ // Advance to 50%
+ act(() => {
+ for (let i = 0; i < 85; i++) {
+ jest.advanceTimersByTime(50);
+ }
+ });
+
+ expect(loader).toHaveStyle('width: 50%');
+
+ // Check if it stays at 78%
+ act(() => {
+ jest.advanceTimersByTime(1000);
+ });
+
+ expect(loader).toHaveStyle('width: 78%');
+ });
+
+ it('displays the correct text', () => {
+ render( );
+ expect(screen.getByText('Generating response...')).toBeInTheDocument();
+ });
+});
diff --git a/web/containers/Loader/ModelReload.test.tsx b/web/containers/Loader/ModelReload.test.tsx
new file mode 100644
index 000000000..2de2db4fd
--- /dev/null
+++ b/web/containers/Loader/ModelReload.test.tsx
@@ -0,0 +1,124 @@
+// ModelReload.test.tsx
+import React from 'react'
+import '@testing-library/jest-dom'
+import { render, screen, act } from '@testing-library/react'
+import ModelReload from './ModelReload'
+import { useActiveModel } from '@/hooks/useActiveModel'
+
+jest.mock('@/hooks/useActiveModel')
+
+describe('ModelReload Component', () => {
+ beforeEach(() => {
+ jest.useFakeTimers()
+ })
+
+ afterEach(() => {
+ jest.useRealTimers()
+ })
+
+ it('renders nothing when not loading', () => {
+ ;(useActiveModel as jest.Mock).mockReturnValue({
+ stateModel: { loading: false },
+ })
+
+ const { container } = render( )
+ expect(container).toBeEmptyDOMElement()
+ })
+
+ it('renders loading message when loading', () => {
+ ;(useActiveModel as jest.Mock).mockReturnValue({
+ stateModel: { loading: true, model: { id: 'test-model' } },
+ })
+
+ render( )
+ expect(screen.getByText(/Reloading model test-model/)).toBeInTheDocument()
+ })
+
+ it('updates loader percentage over time', () => {
+ ;(useActiveModel as jest.Mock).mockReturnValue({
+ stateModel: { loading: true, model: { id: 'test-model' } },
+ })
+
+ render( )
+
+ // Initial render
+ expect(screen.getByText(/Reloading model test-model/)).toBeInTheDocument()
+ const loaderElement = screen.getByText(
+ /Reloading model test-model/
+ ).parentElement
+
+ // Check initial width
+ expect(loaderElement?.firstChild).toHaveStyle('width: 50%')
+
+ // Advance timers and check updated width
+ act(() => {
+ jest.advanceTimersByTime(250)
+ })
+ expect(loaderElement?.firstChild).toHaveStyle('width: 78%')
+
+ // Advance to 99%
+ for (let i = 0; i < 27; i++) {
+ act(() => {
+ jest.advanceTimersByTime(250)
+ })
+ }
+ expect(loaderElement?.firstChild).toHaveStyle('width: 99%')
+
+ // Advance one more time to hit the 250ms delay
+ act(() => {
+ jest.advanceTimersByTime(250)
+ })
+ expect(loaderElement?.firstChild).toHaveStyle('width: 99%')
+ })
+
+ it('stops at 99%', () => {
+ ;(useActiveModel as jest.Mock).mockReturnValue({
+ stateModel: { loading: true, model: { id: 'test-model' } },
+ })
+
+ render( )
+
+ const loaderElement = screen.getByText(
+ /Reloading model test-model/
+ ).parentElement
+
+ // Advance to 99%
+ for (let i = 0; i < 50; i++) {
+ act(() => {
+ jest.advanceTimersByTime(250)
+ })
+ }
+ expect(loaderElement?.firstChild).toHaveStyle('width: 99%')
+
+ // Advance more and check it stays at 99%
+ act(() => {
+ jest.advanceTimersByTime(1000)
+ })
+ expect(loaderElement?.firstChild).toHaveStyle('width: 99%')
+ })
+
+ it('resets to 0% when loading completes', () => {
+ const { rerender } = render( )
+ ;(useActiveModel as jest.Mock).mockReturnValue({
+ stateModel: { loading: true, model: { id: 'test-model' } },
+ })
+
+ rerender( )
+
+ const loaderElement = screen.getByText(
+ /Reloading model test-model/
+ ).parentElement
+
+ expect(loaderElement?.firstChild).toHaveStyle('width: 50%')
+ // Set loading to false
+ ;(useActiveModel as jest.Mock).mockReturnValue({
+ stateModel: { loading: false },
+ })
+
+ rerender( )
+
+ expect(
+ screen.queryByText(/Reloading model test-model/)
+ ).not.toBeInTheDocument()
+ })
+})
diff --git a/web/containers/Loader/ProgressCircle.test.tsx b/web/containers/Loader/ProgressCircle.test.tsx
new file mode 100644
index 000000000..651f9a4f2
--- /dev/null
+++ b/web/containers/Loader/ProgressCircle.test.tsx
@@ -0,0 +1,22 @@
+// ProgressCircle.test.tsx
+import React from 'react'
+import { render, screen } from '@testing-library/react'
+import '@testing-library/jest-dom'
+import ProgressCircle from './ProgressCircle'
+
+describe('ProgressCircle Component', () => {
+ test('renders ProgressCircle with default props', () => {
+ render( )
+ const svg = screen.getByRole('img', { hidden: true })
+ expect(svg).toBeInTheDocument()
+ expect(svg).toHaveAttribute('width', '100')
+ expect(svg).toHaveAttribute('height', '100')
+ })
+
+ test('renders ProgressCircle with custom size', () => {
+ render( )
+ const svg = screen.getByRole('img', { hidden: true })
+ expect(svg).toHaveAttribute('width', '200')
+ expect(svg).toHaveAttribute('height', '200')
+ })
+})
diff --git a/web/containers/Loader/ProgressCircle.tsx b/web/containers/Loader/ProgressCircle.tsx
index e10434113..aec7c81cc 100644
--- a/web/containers/Loader/ProgressCircle.tsx
+++ b/web/containers/Loader/ProgressCircle.tsx
@@ -22,6 +22,7 @@ const ProgressCircle: React.FC = ({
width={size}
xmlns="http://www.w3.org/2000/svg"
viewBox={`0 0 ${size} ${size}`}
+ role="img"
>
{
+ const mockLogs = ['Log 1', 'Log 2', 'Log 3']
+
+ beforeEach(() => {
+ // Reset all mocks
+ jest.resetAllMocks()
+
+ // Setup default mock implementations
+ ;(useLogs as jest.Mock).mockReturnValue({
+ getLogs: jest.fn().mockResolvedValue(mockLogs.join('\n')),
+ })
+ ;(usePath as jest.Mock).mockReturnValue({
+ onRevealInFinder: jest.fn(),
+ })
+ ;(useClipboard as jest.Mock).mockReturnValue({
+ copy: jest.fn(),
+ copied: false,
+ })
+ })
+
+ test('renders AppLogs component with logs', async () => {
+ render( )
+
+ await waitFor(() => {
+ mockLogs.forEach((log) => {
+ expect(screen.getByText(log)).toBeInTheDocument()
+ })
+ })
+
+ expect(screen.getByText('Open')).toBeInTheDocument()
+ expect(screen.getByText('Copy All')).toBeInTheDocument()
+ })
+
+ test('renders empty state when no logs', async () => {
+ ;(useLogs as jest.Mock).mockReturnValue({
+ getLogs: jest.fn().mockResolvedValue(''),
+ })
+
+ render( )
+
+ await waitFor(() => {
+ expect(screen.getByText('Empty logs')).toBeInTheDocument()
+ })
+ })
+
+ test('calls onRevealInFinder when Open button is clicked', async () => {
+ const mockOnRevealInFinder = jest.fn()
+ ;(usePath as jest.Mock).mockReturnValue({
+ onRevealInFinder: mockOnRevealInFinder,
+ })
+
+ render( )
+
+ await waitFor(() => {
+ const openButton = screen.getByText('Open')
+ userEvent.click(openButton)
+
+ expect(mockOnRevealInFinder).toHaveBeenCalledWith('Logs')
+ })
+ })
+
+ test('calls copy function when Copy All button is clicked', async () => {
+ const mockCopy = jest.fn()
+ ;(useClipboard as jest.Mock).mockReturnValue({
+ copy: mockCopy,
+ copied: false,
+ })
+
+ render( )
+
+ await waitFor(() => {
+ const copyButton = screen.getByText('Copy All')
+ userEvent.click(copyButton)
+ expect(mockCopy).toHaveBeenCalled()
+ })
+ })
+
+ test('shows Copying... text when copied is true', async () => {
+ ;(useClipboard as jest.Mock).mockReturnValue({
+ copy: jest.fn(),
+ copied: true,
+ })
+
+ render( )
+
+ await waitFor(() => {
+ expect(screen.getByText('Copying...')).toBeInTheDocument()
+ })
+ })
+})
diff --git a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx
new file mode 100644
index 000000000..96ff6f559
--- /dev/null
+++ b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx
@@ -0,0 +1,137 @@
+// ./AssistantSetting.test.tsx
+import '@testing-library/jest-dom'
+import React from 'react'
+import { render, screen, waitFor } from '@testing-library/react'
+import userEvent from '@testing-library/user-event'
+import { useAtomValue, useSetAtom } from 'jotai'
+import { useActiveModel } from '@/hooks/useActiveModel'
+import { useCreateNewThread } from '@/hooks/useCreateNewThread'
+import AssistantSetting from './index'
+
+jest.mock('jotai', () => {
+ const originalModule = jest.requireActual('jotai')
+ return {
+ ...originalModule,
+ useAtomValue: jest.fn(),
+ useSetAtom: jest.fn(),
+ }
+})
+jest.mock('@/hooks/useActiveModel')
+jest.mock('@/hooks/useCreateNewThread')
+jest.mock('./../../../../containers/ModelSetting/SettingComponent', () => {
+ return jest.fn().mockImplementation(({ onValueUpdated }) => {
+ return (
+ onValueUpdated('chunk_size', e.target.value)}
+ />
+ )
+ })
+})
+
+describe('AssistantSetting Component', () => {
+ const mockActiveThread = {
+ id: '123',
+ assistants: [
+ {
+ id: '456',
+ tools: [
+ {
+ type: 'retrieval',
+ enabled: true,
+ settings: {
+ chunk_size: 100,
+ chunk_overlap: 50,
+ },
+ },
+ ],
+ },
+ ],
+ }
+ const ComponentPropsMock: any[] = [
+ {
+ key: 'chunk_size',
+ type: 'number',
+ title: 'Chunk Size',
+ value: 100,
+ controllerType: 'input',
+ },
+ {
+ key: 'chunk_overlap',
+ type: 'number',
+ title: 'Chunk Overlap',
+ value: 50,
+ controllerType: 'input',
+ },
+ ]
+
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ test('renders AssistantSetting component with proper data', async () => {
+ const setEngineParamsUpdate = jest.fn()
+ ;(useSetAtom as jest.Mock).mockImplementationOnce(
+ () => setEngineParamsUpdate
+ )
+ ;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread)
+ const updateThreadMetadata = jest.fn()
+ ;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel: jest.fn() })
+ ;(useCreateNewThread as jest.Mock).mockReturnValueOnce({
+ updateThreadMetadata,
+ })
+
+ render( )
+
+ await waitFor(() => {
+ const firstInput = screen.getByTestId('input')
+ expect(firstInput).toBeInTheDocument()
+
+ userEvent.type(firstInput, '200')
+ expect(updateThreadMetadata).toHaveBeenCalled()
+ expect(setEngineParamsUpdate).toHaveBeenCalledTimes(0)
+ })
+ })
+
+ test('triggers model reload with onValueChanged', async () => {
+ const setEngineParamsUpdate = jest.fn()
+ const updateThreadMetadata = jest.fn()
+ const stopModel = jest.fn()
+ ;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread)
+ ;(useSetAtom as jest.Mock).mockImplementation(() => setEngineParamsUpdate)
+ ;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel })
+ ;(useCreateNewThread as jest.Mock).mockReturnValueOnce({
+ updateThreadMetadata,
+ })
+ ;(useCreateNewThread as jest.Mock).mockReturnValueOnce({
+ updateThreadMetadata,
+ })
+
+ render(
+
+ )
+
+ await waitFor(() => {
+ const firstInput = screen.getByTestId('input')
+ expect(firstInput).toBeInTheDocument()
+
+ userEvent.type(firstInput, '200')
+ expect(setEngineParamsUpdate).toHaveBeenCalled()
+ expect(stopModel).toHaveBeenCalled()
+ })
+ })
+})
From 15f42fb2691c3f193f5c81af142712db1d5e5a09 Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Mon, 23 Sep 2024 14:42:52 +0700
Subject: [PATCH 21/37] fix: toolbar overlap chat input (#3720)
---
web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx
index bbecef10e..235ebeae6 100644
--- a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx
@@ -392,7 +392,7 @@ const ChatInput = () => {
{activeSettingInputBox && (
Date: Mon, 23 Sep 2024 16:13:45 +0700
Subject: [PATCH 22/37] fix: the monorepo jest configs should not cover
extensions - they would be moved to mini repositories soon (#3723)
---
jest.config.js | 9 +--------
1 file changed, 1 insertion(+), 8 deletions(-)
diff --git a/jest.config.js b/jest.config.js
index a9f0f5938..a911a7f0a 100644
--- a/jest.config.js
+++ b/jest.config.js
@@ -1,10 +1,3 @@
module.exports = {
- projects: [
- '
/core',
- '/web',
- '/joi',
- '/extensions/inference-nitro-extension',
- '/extensions/conversational-extension',
- '/extensions/model-extension',
- ],
+ projects: ['/core', '/web', '/joi'],
}
From c0b59ece4d41162f38b525a6a7802b77e1936508 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 24 Sep 2024 10:07:53 +0700
Subject: [PATCH 23/37] fix: #3558 wrong model metadata import or download from
HuggingFace (#3725)
* fix: #3558 wrong model metadata import
* chore: remove redundant metadata retrieval
---
extensions/model-extension/src/index.test.ts | 56 ++++++++++++-
extensions/model-extension/src/index.ts | 78 +++++++++++++------
extensions/model-extension/src/node/index.ts | 52 ++++++++-----
.../model-extension/src/node/node.test.ts | 53 +++++++++++++
4 files changed, 195 insertions(+), 44 deletions(-)
create mode 100644 extensions/model-extension/src/node/node.test.ts
diff --git a/extensions/model-extension/src/index.test.ts b/extensions/model-extension/src/index.test.ts
index 6816d7101..823b3a41d 100644
--- a/extensions/model-extension/src/index.test.ts
+++ b/extensions/model-extension/src/index.test.ts
@@ -1,6 +1,13 @@
+/**
+ * @jest-environment jsdom
+ */
const readDirSyncMock = jest.fn()
const existMock = jest.fn()
const readFileSyncMock = jest.fn()
+const downloadMock = jest.fn()
+const mkdirMock = jest.fn()
+const writeFileSyncMock = jest.fn()
+const copyFileMock = jest.fn()
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
@@ -8,6 +15,9 @@ jest.mock('@janhq/core', () => ({
existsSync: existMock,
readdirSync: readDirSyncMock,
readFileSync: readFileSyncMock,
+ writeFileSync: writeFileSyncMock,
+ mkdir: mkdirMock,
+ copyFile: copyFileMock,
fileStat: () => ({
isDirectory: false,
}),
@@ -15,10 +25,20 @@ jest.mock('@janhq/core', () => ({
dirName: jest.fn(),
joinPath: (paths) => paths.join('/'),
ModelExtension: jest.fn(),
+ downloadFile: downloadMock,
}))
+global.fetch = jest.fn(() =>
+ Promise.resolve({
+ json: () => Promise.resolve({ test: 100 }),
+ arrayBuffer: jest.fn(),
+ })
+) as jest.Mock
+
import JanModelExtension from '.'
import { fs, dirName } from '@janhq/core'
+import { renderJinjaTemplate } from './node/index'
+import { Template } from '@huggingface/jinja'
describe('JanModelExtension', () => {
let sut: JanModelExtension
@@ -187,7 +207,6 @@ describe('JanModelExtension', () => {
describe('no models downloaded', () => {
it('should return empty array', async () => {
// Mock downloaded models data
- const downloadedModels = []
existMock.mockReturnValue(true)
readDirSyncMock.mockReturnValue([])
@@ -557,8 +576,41 @@ describe('JanModelExtension', () => {
file_path: 'file://models/model1/model.json',
} as any)
- expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine')
+ expect(fs.unlinkSync).toHaveBeenCalledWith(
+ 'file://models/model1/test.engine'
+ )
})
})
})
+
+ describe('downloadModel', () => {
+ const model: any = {
+ id: 'model-id',
+ name: 'Test Model',
+ sources: [
+ { url: 'http://example.com/model.gguf', filename: 'model.gguf' },
+ ],
+ engine: 'test-engine',
+ }
+
+ const network = {
+ ignoreSSL: true,
+ proxy: 'http://proxy.example.com',
+ }
+
+ const gpuSettings: any = {
+ gpus: [{ name: 'nvidia-rtx-3080', arch: 'ampere' }],
+ }
+
+ it('should reject with invalid gguf metadata', async () => {
+ existMock.mockImplementation(() => false)
+
+ expect(
+ sut.downloadModel(model, gpuSettings, network)
+ ).rejects.toBeTruthy()
+ })
+
+
+ })
+
})
diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts
index ac9b06a09..beb9f1fed 100644
--- a/extensions/model-extension/src/index.ts
+++ b/extensions/model-extension/src/index.ts
@@ -24,6 +24,7 @@ import {
ModelEvent,
ModelFile,
dirName,
+ ModelSettingParams,
} from '@janhq/core'
import { extractFileName } from './helpers/path'
@@ -80,11 +81,27 @@ export default class JanModelExtension extends ModelExtension {
gpuSettings?: GpuSetting,
network?: { ignoreSSL?: boolean; proxy?: string }
): Promise {
- // create corresponding directory
+ // Create corresponding directory
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
const modelJsonPath = await joinPath([modelDirPath, 'model.json'])
+
+ // Download HF model - model.json not exist
if (!(await fs.existsSync(modelJsonPath))) {
+ // It supports only one source for HF download
+ const metadata = await this.fetchModelMetadata(model.sources[0].url)
+ const updatedModel = await this.retrieveGGUFMetadata(metadata)
+ if (updatedModel) {
+ // Update model settings
+ model.settings = {
+ ...model.settings,
+ ...updatedModel.settings,
+ }
+ model.parameters = {
+ ...model.parameters,
+ ...updatedModel.parameters,
+ }
+ }
await fs.writeFileSync(modelJsonPath, JSON.stringify(model, null, 2))
events.emit(ModelEvent.OnModelsUpdate, {})
}
@@ -327,7 +344,7 @@ export default class JanModelExtension extends ModelExtension {
// Should depend on sources?
const isUserImportModel =
modelInfo.metadata?.author?.toLowerCase() === 'user'
- if (isUserImportModel) {
+ if (isUserImportModel) {
// just delete the folder
return fs.rm(dirPath)
}
@@ -555,7 +572,7 @@ export default class JanModelExtension extends ModelExtension {
])
)
- const eos_id = metadata?.['tokenizer.ggml.eos_token_id']
+ const updatedModel = await this.retrieveGGUFMetadata(metadata)
if (!defaultModel) {
console.error('Unable to find default model')
@@ -575,18 +592,11 @@ export default class JanModelExtension extends ModelExtension {
],
parameters: {
...defaultModel.parameters,
- stop: eos_id
- ? [metadata['tokenizer.ggml.tokens'][eos_id] ?? '']
- : defaultModel.parameters.stop,
+ ...updatedModel.parameters,
},
settings: {
...defaultModel.settings,
- prompt_template:
- metadata?.parsed_chat_template ??
- defaultModel.settings.prompt_template,
- ctx_len:
- metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len,
- ngl: (metadata?.['llama.block_count'] ?? 32) + 1,
+ ...updatedModel.settings,
llama_model_path: binaryFileName,
},
created: Date.now(),
@@ -666,9 +676,9 @@ export default class JanModelExtension extends ModelExtension {
'retrieveGGUFMetadata',
modelBinaryPath
)
- const eos_id = metadata?.['tokenizer.ggml.eos_token_id']
const binaryFileName = await baseName(modelBinaryPath)
+ const updatedModel = await this.retrieveGGUFMetadata(metadata)
const model: Model = {
...defaultModel,
@@ -682,19 +692,12 @@ export default class JanModelExtension extends ModelExtension {
],
parameters: {
...defaultModel.parameters,
- stop: eos_id
- ? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? '']
- : defaultModel.parameters.stop,
+ ...updatedModel.parameters,
},
settings: {
...defaultModel.settings,
- prompt_template:
- metadata?.parsed_chat_template ??
- defaultModel.settings.prompt_template,
- ctx_len:
- metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len,
- ngl: (metadata?.['llama.block_count'] ?? 32) + 1,
+ ...updatedModel.settings,
llama_model_path: binaryFileName,
},
created: Date.now(),
@@ -860,4 +863,35 @@ export default class JanModelExtension extends ModelExtension {
importedModels
)
}
+
+ /**
+ * Retrieve Model Settings from GGUF Metadata
+ * @param metadata
+ * @returns
+ */
+ async retrieveGGUFMetadata(metadata: any): Promise> {
+ const template = await executeOnMain(NODE, 'renderJinjaTemplate', metadata)
+ const defaultModel = DEFAULT_MODEL as Model
+ const eos_id = metadata['tokenizer.ggml.eos_token_id']
+ const architecture = metadata['general.architecture']
+
+ return {
+ settings: {
+ prompt_template: template ?? defaultModel.settings.prompt_template,
+ ctx_len:
+ metadata[`${architecture}.context_length`] ??
+ metadata['llama.context_length'] ??
+ 4096,
+ ngl:
+ (metadata[`${architecture}.block_count`] ??
+ metadata['llama.block_count'] ??
+ 32) + 1,
+ },
+ parameters: {
+ stop: eos_id
+ ? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? '']
+ : defaultModel.parameters.stop,
+ },
+ }
+ }
}
diff --git a/extensions/model-extension/src/node/index.ts b/extensions/model-extension/src/node/index.ts
index 2b498f424..6323d7f97 100644
--- a/extensions/model-extension/src/node/index.ts
+++ b/extensions/model-extension/src/node/index.ts
@@ -16,27 +16,8 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => {
// Parse metadata and tensor info
const { metadata } = ggufMetadata(buffer.buffer)
- const template = new Template(metadata['tokenizer.chat_template'])
- const eos_id = metadata['tokenizer.ggml.eos_token_id']
- const bos_id = metadata['tokenizer.ggml.bos_token_id']
- const eos_token = metadata['tokenizer.ggml.tokens'][eos_id]
- const bos_token = metadata['tokenizer.ggml.tokens'][bos_id]
// Parse jinja template
- const renderedTemplate = template.render({
- add_generation_prompt: true,
- eos_token,
- bos_token,
- messages: [
- {
- role: 'system',
- content: '{system_message}',
- },
- {
- role: 'user',
- content: '{prompt}',
- },
- ],
- })
+ const renderedTemplate = renderJinjaTemplate(metadata)
return {
...metadata,
parsed_chat_template: renderedTemplate,
@@ -45,3 +26,34 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => {
console.log('[MODEL_EXT]', e)
}
}
+
+/**
+ * Convert metadata to jinja template
+ * @param metadata
+ */
+export const renderJinjaTemplate = (metadata: any): string => {
+ const template = new Template(metadata['tokenizer.chat_template'])
+ const eos_id = metadata['tokenizer.ggml.eos_token_id']
+ const bos_id = metadata['tokenizer.ggml.bos_token_id']
+ if (eos_id === undefined || bos_id === undefined) {
+ return ''
+ }
+ const eos_token = metadata['tokenizer.ggml.tokens'][eos_id]
+ const bos_token = metadata['tokenizer.ggml.tokens'][bos_id]
+ // Parse jinja template
+ return template.render({
+ add_generation_prompt: true,
+ eos_token,
+ bos_token,
+ messages: [
+ {
+ role: 'system',
+ content: '{system_message}',
+ },
+ {
+ role: 'user',
+ content: '{prompt}',
+ },
+ ],
+ })
+}
diff --git a/extensions/model-extension/src/node/node.test.ts b/extensions/model-extension/src/node/node.test.ts
new file mode 100644
index 000000000..afd2b8470
--- /dev/null
+++ b/extensions/model-extension/src/node/node.test.ts
@@ -0,0 +1,53 @@
+import { renderJinjaTemplate } from './index'
+import { Template } from '@huggingface/jinja'
+
+jest.mock('@huggingface/jinja', () => ({
+ Template: jest.fn((template: string) => ({
+ render: jest.fn(() => `${template}_rendered`),
+ })),
+}))
+
+describe('renderJinjaTemplate', () => {
+ beforeEach(() => {
+ jest.clearAllMocks() // Clear mocks between tests
+ })
+
+ it('should render the template with correct parameters', () => {
+ const metadata = {
+ 'tokenizer.chat_template': 'Hello, {{ messages }}!',
+ 'tokenizer.ggml.eos_token_id': 0,
+ 'tokenizer.ggml.bos_token_id': 1,
+ 'tokenizer.ggml.tokens': ['EOS', 'BOS'],
+ }
+
+ const renderedTemplate = renderJinjaTemplate(metadata)
+
+ expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!')
+
+ expect(renderedTemplate).toBe('Hello, {{ messages }}!_rendered')
+ })
+
+ it('should handle missing token IDs gracefully', () => {
+ const metadata = {
+ 'tokenizer.chat_template': 'Hello, {{ messages }}!',
+ 'tokenizer.ggml.eos_token_id': 0,
+ 'tokenizer.ggml.tokens': ['EOS'],
+ }
+
+ const renderedTemplate = renderJinjaTemplate(metadata)
+
+ expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!')
+
+ expect(renderedTemplate).toBe('')
+ })
+
+ it('should handle empty template gracefully', () => {
+ const metadata = {}
+
+ const renderedTemplate = renderJinjaTemplate(metadata)
+
+ expect(Template).toHaveBeenCalledWith(undefined)
+
+ expect(renderedTemplate).toBe("")
+ })
+})
From 6af17c6455df482c38101ed71c762cfd3ec7a2ba Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 24 Sep 2024 10:40:32 +0700
Subject: [PATCH 24/37] fix: #3513 - anthropic extension does not forward the
system prompt (#3724)
---
.../jest.config.js | 9 +++
.../package.json | 1 +
.../src/anthropic.test.ts | 77 +++++++++++++++++++
.../src/index.ts | 7 +-
.../tsconfig.json | 3 +-
5 files changed, 95 insertions(+), 2 deletions(-)
create mode 100644 extensions/inference-anthropic-extension/jest.config.js
create mode 100644 extensions/inference-anthropic-extension/src/anthropic.test.ts
diff --git a/extensions/inference-anthropic-extension/jest.config.js b/extensions/inference-anthropic-extension/jest.config.js
new file mode 100644
index 000000000..3e32adceb
--- /dev/null
+++ b/extensions/inference-anthropic-extension/jest.config.js
@@ -0,0 +1,9 @@
+/** @type {import('ts-jest').JestConfigWithTsJest} */
+module.exports = {
+ preset: 'ts-jest',
+ testEnvironment: 'node',
+ transform: {
+ 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
+ },
+ transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
+}
diff --git a/extensions/inference-anthropic-extension/package.json b/extensions/inference-anthropic-extension/package.json
index a9d30a8e5..19c0df5e8 100644
--- a/extensions/inference-anthropic-extension/package.json
+++ b/extensions/inference-anthropic-extension/package.json
@@ -9,6 +9,7 @@
"author": "Jan ",
"license": "AGPL-3.0",
"scripts": {
+ "test": "jest test",
"build": "tsc -b . && webpack --config webpack.config.js",
"build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install",
"sync:core": "cd ../.. && yarn build:core && cd extensions && rm yarn.lock && cd inference-anthropic-extension && yarn && yarn build:publish"
diff --git a/extensions/inference-anthropic-extension/src/anthropic.test.ts b/extensions/inference-anthropic-extension/src/anthropic.test.ts
new file mode 100644
index 000000000..703ead0fb
--- /dev/null
+++ b/extensions/inference-anthropic-extension/src/anthropic.test.ts
@@ -0,0 +1,77 @@
+// Import necessary modules
+import JanInferenceAnthropicExtension, { Settings } from '.'
+import { PayloadType, ChatCompletionRole } from '@janhq/core'
+
+// Mocks
+jest.mock('@janhq/core', () => ({
+ RemoteOAIEngine: jest.fn().mockImplementation(() => ({
+ registerSettings: jest.fn(),
+ registerModels: jest.fn(),
+ getSetting: jest.fn(),
+ onChange: jest.fn(),
+ onSettingUpdate: jest.fn(),
+ onLoad: jest.fn(),
+ headers: jest.fn(),
+ })),
+ PayloadType: jest.fn(),
+ ChatCompletionRole: {
+ User: 'user' as const,
+ Assistant: 'assistant' as const,
+ System: 'system' as const,
+ },
+}))
+
+// Helper functions
+const createMockPayload = (): PayloadType => ({
+ messages: [
+ { role: ChatCompletionRole.System, content: 'Meow' },
+ { role: ChatCompletionRole.User, content: 'Hello' },
+ { role: ChatCompletionRole.Assistant, content: 'Hi there' },
+ ],
+ model: 'claude-v1',
+ stream: false,
+})
+
+describe('JanInferenceAnthropicExtension', () => {
+ let extension: JanInferenceAnthropicExtension
+
+ beforeEach(() => {
+ extension = new JanInferenceAnthropicExtension('', '')
+ extension.apiKey = 'mock-api-key'
+ extension.inferenceUrl = 'mock-endpoint'
+ jest.clearAllMocks()
+ })
+
+ it('should initialize with correct settings', async () => {
+ await extension.onLoad()
+ expect(extension.apiKey).toBe('mock-api-key')
+ expect(extension.inferenceUrl).toBe('mock-endpoint')
+ })
+
+ it('should transform payload correctly', () => {
+ const payload = createMockPayload()
+ const transformedPayload = extension.transformPayload(payload)
+
+ expect(transformedPayload).toEqual({
+ max_tokens: 4096,
+ model: 'claude-v1',
+ stream: false,
+ system: 'Meow',
+ messages: [
+ { role: 'user', content: 'Hello' },
+ { role: 'assistant', content: 'Hi there' },
+ ],
+ })
+ })
+
+ it('should transform response correctly', () => {
+ const nonStreamResponse = { content: [{ text: 'Test response' }] }
+ const streamResponse =
+ 'data: {"type":"content_block_delta","delta":{"text":"Hello"}}'
+
+ expect(extension.transformResponse(nonStreamResponse)).toBe('Test response')
+ expect(extension.transformResponse(streamResponse)).toBe('Hello')
+ expect(extension.transformResponse('')).toBe('')
+ expect(extension.transformResponse('event: something')).toBe('')
+ })
+})
diff --git a/extensions/inference-anthropic-extension/src/index.ts b/extensions/inference-anthropic-extension/src/index.ts
index f28a584f2..94da26d94 100644
--- a/extensions/inference-anthropic-extension/src/index.ts
+++ b/extensions/inference-anthropic-extension/src/index.ts
@@ -13,7 +13,7 @@ import { ChatCompletionRole } from '@janhq/core'
declare const SETTINGS: Array
declare const MODELS: Array
-enum Settings {
+export enum Settings {
apiKey = 'anthropic-api-key',
chatCompletionsEndPoint = 'chat-completions-endpoint',
}
@@ -23,6 +23,7 @@ type AnthropicPayloadType = {
model?: string
max_tokens?: number
messages?: Array<{ role: string; content: string }>
+ system?: string
}
/**
@@ -113,6 +114,10 @@ export default class JanInferenceAnthropicExtension extends RemoteOAIEngine {
role: 'assistant',
content: item.content as string,
})
+ } else if (item.role === ChatCompletionRole.System) {
+ // When using Claude, you can dramatically improve its performance by using the system parameter to give it a role.
+ // This technique, known as role prompting, is the most powerful way to use system prompts with Claude.
+ convertedData.system = item.content as string
}
})
diff --git a/extensions/inference-anthropic-extension/tsconfig.json b/extensions/inference-anthropic-extension/tsconfig.json
index 2477d58ce..6db951c9e 100644
--- a/extensions/inference-anthropic-extension/tsconfig.json
+++ b/extensions/inference-anthropic-extension/tsconfig.json
@@ -10,5 +10,6 @@
"skipLibCheck": true,
"rootDir": "./src"
},
- "include": ["./src"]
+ "include": ["./src"],
+ "exclude": ["**/*.test.ts"]
}
From 36c1306390df9655b07097a7c00175a21c95cfe9 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 24 Sep 2024 10:40:45 +0700
Subject: [PATCH 25/37] fix: #3515 - The default assistant instructions are
ignored (#3721)
---
web/hooks/useCreateNewThread.test.ts | 224 +++++++++++++++++++++++++++
web/hooks/useCreateNewThread.ts | 2 +-
2 files changed, 225 insertions(+), 1 deletion(-)
create mode 100644 web/hooks/useCreateNewThread.test.ts
diff --git a/web/hooks/useCreateNewThread.test.ts b/web/hooks/useCreateNewThread.test.ts
new file mode 100644
index 000000000..25589c098
--- /dev/null
+++ b/web/hooks/useCreateNewThread.test.ts
@@ -0,0 +1,224 @@
+// useCreateNewThread.test.ts
+import { renderHook, act } from '@testing-library/react'
+import { useCreateNewThread } from './useCreateNewThread'
+import { useAtomValue, useSetAtom } from 'jotai'
+import { useActiveModel } from './useActiveModel'
+import useRecommendedModel from './useRecommendedModel'
+import useSetActiveThread from './useSetActiveThread'
+import { extensionManager } from '@/extension'
+import { toaster } from '@/containers/Toast'
+
+// Mock the dependencies
+jest.mock('jotai', () => {
+ const originalModule = jest.requireActual('jotai')
+ return {
+ ...originalModule,
+ useAtomValue: jest.fn(),
+ useSetAtom: jest.fn(),
+ }
+})
+jest.mock('./useActiveModel')
+jest.mock('./useRecommendedModel')
+jest.mock('./useSetActiveThread')
+jest.mock('@/extension')
+jest.mock('@/containers/Toast')
+
+describe('useCreateNewThread', () => {
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('should create a new thread', async () => {
+ const mockSetAtom = jest.fn()
+ ;(useSetAtom as jest.Mock).mockReturnValue(mockSetAtom)
+ ;(useAtomValue as jest.Mock).mockReturnValue({
+ metadata: {},
+ assistants: [
+ {
+ id: 'assistant1',
+ name: 'Assistant 1',
+ instructions: undefined,
+ },
+ ],
+ })
+ ;(useActiveModel as jest.Mock).mockReturnValue({ stopInference: jest.fn() })
+ ;(useRecommendedModel as jest.Mock).mockReturnValue({
+ recommendedModel: { id: 'model1', parameters: [], settings: [] },
+ downloadedModels: [],
+ })
+ ;(useSetActiveThread as jest.Mock).mockReturnValue({
+ setActiveThread: jest.fn(),
+ })
+ ;(extensionManager.get as jest.Mock).mockReturnValue({
+ saveThread: jest.fn(),
+ })
+
+ const { result } = renderHook(() => useCreateNewThread())
+
+ await act(async () => {
+ await result.current.requestCreateNewThread({
+ id: 'assistant1',
+ name: 'Assistant 1',
+ model: {
+ id: 'model1',
+ parameters: [],
+ settings: [],
+ },
+ } as any)
+ })
+
+ expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set
+ expect(extensionManager.get).toHaveBeenCalled()
+ })
+
+ it('should create a new thread with instructions', async () => {
+ const mockSetAtom = jest.fn()
+ ;(useSetAtom as jest.Mock).mockReturnValue(mockSetAtom)
+ ;(useAtomValue as jest.Mock).mockReturnValueOnce(false)
+ ;(useAtomValue as jest.Mock).mockReturnValue({
+ metadata: {},
+ assistants: [
+ {
+ id: 'assistant1',
+ name: 'Assistant 1',
+ instructions: 'Hello Jan',
+ },
+ ],
+ })
+ ;(useAtomValue as jest.Mock).mockReturnValueOnce(false)
+ ;(useActiveModel as jest.Mock).mockReturnValue({ stopInference: jest.fn() })
+ ;(useRecommendedModel as jest.Mock).mockReturnValue({
+ recommendedModel: { id: 'model1', parameters: [], settings: [] },
+ downloadedModels: [],
+ })
+ ;(useSetActiveThread as jest.Mock).mockReturnValue({
+ setActiveThread: jest.fn(),
+ })
+ ;(extensionManager.get as jest.Mock).mockReturnValue({
+ saveThread: jest.fn(),
+ })
+
+ const { result } = renderHook(() => useCreateNewThread())
+
+ await act(async () => {
+ await result.current.requestCreateNewThread({
+ id: 'assistant1',
+ name: 'Assistant 1',
+ instructions: "Hello Jan Assistant",
+ model: {
+ id: 'model1',
+ parameters: [],
+ settings: [],
+ },
+ } as any)
+ })
+
+ expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set
+ expect(extensionManager.get).toHaveBeenCalled()
+ expect(mockSetAtom).toHaveBeenNthCalledWith(
+ 2,
+ expect.objectContaining({
+ assistants: expect.arrayContaining([
+ expect.objectContaining({ instructions: 'Hello Jan Assistant' }),
+ ]),
+ })
+ )
+ })
+
+ it('should create a new thread with previous instructions', async () => {
+ const mockSetAtom = jest.fn()
+ ;(useSetAtom as jest.Mock).mockReturnValue(mockSetAtom)
+ ;(useAtomValue as jest.Mock).mockReturnValueOnce(true)
+ ;(useAtomValue as jest.Mock).mockReturnValueOnce({
+ metadata: {},
+ assistants: [
+ {
+ id: 'assistant1',
+ name: 'Assistant 1',
+ instructions: 'Hello Jan',
+ },
+ ],
+ })
+ ;(useAtomValue as jest.Mock).mockReturnValueOnce(true)
+ ;(useActiveModel as jest.Mock).mockReturnValue({ stopInference: jest.fn() })
+ ;(useRecommendedModel as jest.Mock).mockReturnValue({
+ recommendedModel: { id: 'model1', parameters: [], settings: [] },
+ downloadedModels: [],
+ })
+ ;(useSetActiveThread as jest.Mock).mockReturnValue({
+ setActiveThread: jest.fn(),
+ })
+ ;(extensionManager.get as jest.Mock).mockReturnValue({
+ saveThread: jest.fn(),
+ })
+
+ const { result } = renderHook(() => useCreateNewThread())
+
+ await act(async () => {
+ await result.current.requestCreateNewThread({
+ id: 'assistant1',
+ name: 'Assistant 1',
+ model: {
+ id: 'model1',
+ parameters: [],
+ settings: [],
+ },
+ } as any)
+ })
+
+ expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set
+ expect(extensionManager.get).toHaveBeenCalled()
+ expect(mockSetAtom).toHaveBeenNthCalledWith(
+ 2,
+ expect.objectContaining({
+ assistants: expect.arrayContaining([
+ expect.objectContaining({ instructions: 'Hello Jan' }),
+ ]),
+ })
+ )
+ })
+
+ it('should show a warning toast if trying to create an empty thread', async () => {
+ ;(useAtomValue as jest.Mock).mockReturnValue([{ metadata: {} }]) // Mock an empty thread
+ ;(useRecommendedModel as jest.Mock).mockReturnValue({
+ recommendedModel: null,
+ downloadedModels: [],
+ })
+
+ const { result } = renderHook(() => useCreateNewThread())
+
+ await act(async () => {
+ await result.current.requestCreateNewThread({
+ id: 'assistant1',
+ name: 'Assistant 1',
+ tools: [],
+ } as any)
+ })
+
+ expect(toaster).toHaveBeenCalledWith(
+ expect.objectContaining({
+ title: 'No new thread created.',
+ type: 'warning',
+ })
+ )
+ })
+
+ it('should update thread metadata', async () => {
+ const mockUpdateThread = jest.fn()
+ ;(useSetAtom as jest.Mock).mockReturnValue(mockUpdateThread)
+ ;(extensionManager.get as jest.Mock).mockReturnValue({
+ saveThread: jest.fn(),
+ })
+
+ const { result } = renderHook(() => useCreateNewThread())
+
+ const mockThread = { id: 'thread1', title: 'Test Thread' }
+
+ await act(async () => {
+ await result.current.updateThreadMetadata(mockThread as any)
+ })
+
+ expect(mockUpdateThread).toHaveBeenCalledWith(mockThread)
+ expect(extensionManager.get).toHaveBeenCalled()
+ })
+})
diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts
index 5548259fd..e65353753 100644
--- a/web/hooks/useCreateNewThread.ts
+++ b/web/hooks/useCreateNewThread.ts
@@ -115,7 +115,7 @@ export const useCreateNewThread = () => {
: {}
const createdAt = Date.now()
- let instructions: string | undefined = undefined
+ let instructions: string | undefined = assistant.instructions
if (copyOverInstructionEnabled) {
instructions = activeThread?.assistants[0]?.instructions ?? undefined
}
From acd3be3a2a096446eae3a723a13ab61efbc9828e Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 24 Sep 2024 16:35:08 +0700
Subject: [PATCH 26/37] fix: #3698 - o1 preview models do not work with
max_tokens (#3728)
---
.../inference-openai-extension/jest.config.js | 9 ++++
.../src/OpenAIExtension.test.ts | 54 +++++++++++++++++++
.../inference-openai-extension/src/index.ts | 28 ++++++++--
.../inference-openai-extension/tsconfig.json | 3 +-
4 files changed, 90 insertions(+), 4 deletions(-)
create mode 100644 extensions/inference-openai-extension/jest.config.js
create mode 100644 extensions/inference-openai-extension/src/OpenAIExtension.test.ts
diff --git a/extensions/inference-openai-extension/jest.config.js b/extensions/inference-openai-extension/jest.config.js
new file mode 100644
index 000000000..3e32adceb
--- /dev/null
+++ b/extensions/inference-openai-extension/jest.config.js
@@ -0,0 +1,9 @@
+/** @type {import('ts-jest').JestConfigWithTsJest} */
+module.exports = {
+ preset: 'ts-jest',
+ testEnvironment: 'node',
+ transform: {
+ 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
+ },
+ transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
+}
diff --git a/extensions/inference-openai-extension/src/OpenAIExtension.test.ts b/extensions/inference-openai-extension/src/OpenAIExtension.test.ts
new file mode 100644
index 000000000..4d46bc007
--- /dev/null
+++ b/extensions/inference-openai-extension/src/OpenAIExtension.test.ts
@@ -0,0 +1,54 @@
+/**
+ * @jest-environment jsdom
+ */
+jest.mock('@janhq/core', () => ({
+ ...jest.requireActual('@janhq/core/node'),
+ RemoteOAIEngine: jest.fn().mockImplementation(() => ({
+ onLoad: jest.fn(),
+ registerSettings: jest.fn(),
+ registerModels: jest.fn(),
+ getSetting: jest.fn(),
+ onSettingUpdate: jest.fn(),
+ })),
+}))
+import JanInferenceOpenAIExtension, { Settings } from '.'
+
+describe('JanInferenceOpenAIExtension', () => {
+ let extension: JanInferenceOpenAIExtension
+
+ beforeEach(() => {
+ // @ts-ignore
+ extension = new JanInferenceOpenAIExtension()
+ })
+
+ it('should initialize with settings and models', async () => {
+ await extension.onLoad()
+ // Assuming there are some default SETTINGS and MODELS being registered
+ expect(extension.apiKey).toBe(undefined)
+ expect(extension.inferenceUrl).toBe('')
+ })
+
+ it('should transform the payload for preview models', () => {
+ const payload: any = {
+ max_tokens: 100,
+ model: 'o1-mini',
+ // Add other required properties...
+ }
+
+ const transformedPayload = extension.transformPayload(payload)
+ expect(transformedPayload.max_completion_tokens).toBe(payload.max_tokens)
+ expect(transformedPayload).not.toHaveProperty('max_tokens')
+ expect(transformedPayload).toHaveProperty('max_completion_tokens')
+ })
+
+ it('should not transform the payload for non-preview models', () => {
+ const payload: any = {
+ max_tokens: 100,
+ model: 'non-preview-model',
+ // Add other required properties...
+ }
+
+ const transformedPayload = extension.transformPayload(payload)
+ expect(transformedPayload).toEqual(payload)
+ })
+})
diff --git a/extensions/inference-openai-extension/src/index.ts b/extensions/inference-openai-extension/src/index.ts
index 60446ccce..44c243adf 100644
--- a/extensions/inference-openai-extension/src/index.ts
+++ b/extensions/inference-openai-extension/src/index.ts
@@ -6,16 +6,17 @@
* @module inference-openai-extension/src/index
*/
-import { RemoteOAIEngine, SettingComponentProps } from '@janhq/core'
+import { ModelRuntimeParams, PayloadType, RemoteOAIEngine } from '@janhq/core'
declare const SETTINGS: Array
declare const MODELS: Array
-enum Settings {
+export enum Settings {
apiKey = 'openai-api-key',
chatCompletionsEndPoint = 'chat-completions-endpoint',
}
-
+type OpenAIPayloadType = PayloadType &
+ ModelRuntimeParams & { max_completion_tokens: number }
/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -24,6 +25,7 @@ enum Settings {
export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
inferenceUrl: string = ''
provider: string = 'openai'
+ previewModels = ['o1-mini', 'o1-preview']
override async onLoad(): Promise {
super.onLoad()
@@ -63,4 +65,24 @@ export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
}
}
}
+
+ /**
+ * Tranform the payload before sending it to the inference endpoint.
+ * The new preview models such as o1-mini and o1-preview replaced max_tokens by max_completion_tokens parameter.
+ * Others do not.
+ * @param payload
+ * @returns
+ */
+ transformPayload = (payload: OpenAIPayloadType): OpenAIPayloadType => {
+ // Transform the payload for preview models
+ if (this.previewModels.includes(payload.model)) {
+ const { max_tokens, ...params } = payload
+ return {
+ ...params,
+ max_completion_tokens: max_tokens,
+ }
+ }
+ // Pass through for non-preview models
+ return payload
+ }
}
diff --git a/extensions/inference-openai-extension/tsconfig.json b/extensions/inference-openai-extension/tsconfig.json
index 2477d58ce..6db951c9e 100644
--- a/extensions/inference-openai-extension/tsconfig.json
+++ b/extensions/inference-openai-extension/tsconfig.json
@@ -10,5 +10,6 @@
"skipLibCheck": true,
"rootDir": "./src"
},
- "include": ["./src"]
+ "include": ["./src"],
+ "exclude": ["**/*.test.ts"]
}
From 886b1cbc5484e1e61acdd6b0a63e72a9380df08e Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Tue, 24 Sep 2024 20:14:43 +0700
Subject: [PATCH 27/37] enhance: tabs component in model selection (#3730)
* ui: tabs-model-selection
* chore: updat tabs variant
* test: update test and render correct tab
---
joi/src/core/Tabs/Tabs.test.tsx | 16 ++++
joi/src/core/Tabs/index.tsx | 11 ++-
joi/src/core/Tabs/styles.scss | 24 +++++
web/containers/ModelDropdown/index.test.tsx | 101 ++++++++++++++++++++
web/containers/ModelDropdown/index.tsx | 63 ++++++------
5 files changed, 177 insertions(+), 38 deletions(-)
create mode 100644 web/containers/ModelDropdown/index.test.tsx
diff --git a/joi/src/core/Tabs/Tabs.test.tsx b/joi/src/core/Tabs/Tabs.test.tsx
index b6dcf8a7b..46bd48435 100644
--- a/joi/src/core/Tabs/Tabs.test.tsx
+++ b/joi/src/core/Tabs/Tabs.test.tsx
@@ -96,4 +96,20 @@ describe('@joi/core/Tabs', () => {
'Disabled tab'
)
})
+
+ it('applies the tabStyle if provided', () => {
+ render(
+ {}}
+ tabStyle="segmented"
+ />
+ )
+
+ const tabsContainer = screen.getByTestId('segmented-style')
+ expect(tabsContainer).toHaveClass('tabs')
+ expect(tabsContainer).toHaveClass('tabs--segmented')
+ })
})
diff --git a/joi/src/core/Tabs/index.tsx b/joi/src/core/Tabs/index.tsx
index af004e2ba..2dca19831 100644
--- a/joi/src/core/Tabs/index.tsx
+++ b/joi/src/core/Tabs/index.tsx
@@ -7,6 +7,8 @@ import { Tooltip } from '../Tooltip'
import './styles.scss'
import { twMerge } from 'tailwind-merge'
+type TabStyles = 'segmented'
+
type TabsProps = {
options: {
name: string
@@ -14,8 +16,10 @@ type TabsProps = {
disabled?: boolean
tooltipContent?: string
}[]
- children: ReactNode
+ children?: ReactNode
+
defaultValue?: string
+ tabStyle?: TabStyles
value: string
onValueChange?: (value: string) => void
}
@@ -40,15 +44,18 @@ const TabsContent = ({ value, children, className }: TabsContentProps) => {
const Tabs = ({
options,
children,
+ tabStyle,
defaultValue,
value,
onValueChange,
+ ...props
}: TabsProps) => (
{options.map((option, i) => {
diff --git a/joi/src/core/Tabs/styles.scss b/joi/src/core/Tabs/styles.scss
index a24585b4e..ce3df013b 100644
--- a/joi/src/core/Tabs/styles.scss
+++ b/joi/src/core/Tabs/styles.scss
@@ -3,6 +3,27 @@
flex-direction: column;
width: 100%;
+ &--segmented {
+ background-color: hsla(var(--secondary-bg));
+ border-radius: 6px;
+ height: 33px;
+
+ .tabs__list {
+ border: none;
+ justify-content: center;
+ align-items: center;
+ height: 33px;
+ }
+
+ .tabs__trigger[data-state='active'] {
+ background-color: hsla(var(--app-bg));
+ border: none;
+ height: 25px;
+ margin: 0 4px;
+ border-radius: 5px;
+ }
+ }
+
&__list {
flex-shrink: 0;
display: flex;
@@ -14,9 +35,11 @@
flex: 1;
height: 38px;
display: flex;
+ color: hsla(var(--text-secondary));
align-items: center;
justify-content: center;
line-height: 1;
+ font-weight: medium;
user-select: none;
&:focus {
position: relative;
@@ -38,4 +61,5 @@
.tabs__trigger[data-state='active'] {
border-bottom: 1px solid hsla(var(--primary-bg));
font-weight: 600;
+ color: hsla(var(--text-primary));
}
diff --git a/web/containers/ModelDropdown/index.test.tsx b/web/containers/ModelDropdown/index.test.tsx
new file mode 100644
index 000000000..7541f891b
--- /dev/null
+++ b/web/containers/ModelDropdown/index.test.tsx
@@ -0,0 +1,101 @@
+import { render, screen, waitFor } from '@testing-library/react'
+import { useAtomValue, useAtom, useSetAtom } from 'jotai'
+import ModelDropdown from './index'
+import useRecommendedModel from '@/hooks/useRecommendedModel'
+import '@testing-library/jest-dom'
+
+class ResizeObserverMock {
+ observe() {}
+ unobserve() {}
+ disconnect() {}
+}
+
+global.ResizeObserver = ResizeObserverMock
+
+jest.mock('jotai', () => {
+ const originalModule = jest.requireActual('jotai')
+ return {
+ ...originalModule,
+ useAtom: jest.fn(),
+ useAtomValue: jest.fn(),
+ useSetAtom: jest.fn(),
+ }
+})
+
+jest.mock('@/containers/ModelLabel')
+jest.mock('@/hooks/useRecommendedModel')
+
+describe('ModelDropdown', () => {
+ const remoteModel = {
+ metadata: { tags: ['Featured'], size: 100 },
+ name: 'Test Model',
+ engine: 'openai',
+ }
+
+ const localModel = {
+ metadata: { tags: ['Local'], size: 100 },
+ name: 'Local Model',
+ engine: 'nitro',
+ }
+
+ const configuredModels = [remoteModel, localModel]
+
+ const mockConfiguredModel = configuredModels
+ const selectedModel = { id: 'selectedModel', name: 'selectedModel' }
+ const setSelectedModel = jest.fn()
+ const showEngineListModel = ['nitro']
+ const showEngineListModelAtom = jest.fn()
+
+ beforeEach(() => {
+ jest.clearAllMocks()
+ ;(useAtom as jest.Mock).mockReturnValue([selectedModel, setSelectedModel])
+ ;(useAtom as jest.Mock).mockReturnValue([
+ showEngineListModel,
+ showEngineListModelAtom,
+ ])
+ ;(useAtomValue as jest.Mock).mockReturnValue(mockConfiguredModel)
+ ;(useRecommendedModel as jest.Mock).mockReturnValue({
+ recommendedModel: { id: 'model1', parameters: [], settings: [] },
+ downloadedModels: [],
+ })
+ })
+
+ it('renders the ModelDropdown component', async () => {
+ render( )
+
+ await waitFor(() => {
+ expect(screen.getByTestId('model-selector')).toBeInTheDocument()
+ })
+ })
+
+ it('renders the ModelDropdown component as disabled', async () => {
+ render( )
+
+ await waitFor(() => {
+ expect(screen.getByTestId('model-selector')).toBeInTheDocument()
+ expect(screen.getByTestId('model-selector')).toHaveClass(
+ 'pointer-events-none'
+ )
+ })
+ })
+
+ it('renders the ModelDropdown component as badge for chat Input', async () => {
+ render( )
+
+ await waitFor(() => {
+ expect(screen.getByTestId('model-selector')).toBeInTheDocument()
+ expect(screen.getByTestId('model-selector-badge')).toBeInTheDocument()
+ expect(screen.getByTestId('model-selector-badge')).toHaveClass('badge')
+ })
+ })
+
+ it('renders the Tab correctly', async () => {
+ render( )
+
+ await waitFor(() => {
+ expect(screen.getByTestId('model-selector')).toBeInTheDocument()
+ expect(screen.getByText('On-device'))
+ expect(screen.getByText('Cloud'))
+ })
+ })
+})
diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx
index d8743ddce..2a0c4ffaf 100644
--- a/web/containers/ModelDropdown/index.tsx
+++ b/web/containers/ModelDropdown/index.tsx
@@ -8,7 +8,7 @@ import {
Button,
Input,
ScrollArea,
- Select,
+ Tabs,
useClickOutside,
} from '@janhq/joi'
@@ -70,8 +70,8 @@ const ModelDropdown = ({
strictedThread = true,
}: Props) => {
const { downloadModel } = useDownloadModel()
- const [searchFilter, setSearchFilter] = useState('all')
- const [filterOptionsOpen, setFilterOptionsOpen] = useState(false)
+
+ const [searchFilter, setSearchFilter] = useState('local')
const [searchText, setSearchText] = useState('')
const [open, setOpen] = useState(false)
const activeThread = useAtomValue(activeThreadAtom)
@@ -92,10 +92,7 @@ const ModelDropdown = ({
)
const { updateThreadMetadata } = useCreateNewThread()
- useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
- dropdownOptions,
- toggle,
- ])
+ useClickOutside(() => setOpen(false), null, [dropdownOptions, toggle])
const [showEngineListModel, setShowEngineListModel] = useAtom(
showEngineListModelAtom
@@ -115,9 +112,6 @@ const ModelDropdown = ({
e.name.toLowerCase().includes(searchText.toLowerCase().trim())
)
.filter((e) => {
- if (searchFilter === 'all') {
- return e.engine
- }
if (searchFilter === 'local') {
return localEngines.includes(e.engine)
}
@@ -152,9 +146,9 @@ const ModelDropdown = ({
useEffect(() => {
if (!activeThread) return
- let model = downloadedModels.find(
- (model) => model.id === activeThread.assistants[0].model.id
- )
+ const modelId = activeThread?.assistants?.[0]?.model?.id
+
+ let model = downloadedModels.find((model) => model.id === modelId)
if (!model) {
model = recommendedModel
}
@@ -309,10 +303,14 @@ const ModelDropdown = ({
}
return (
-
+
{chatInputMode ? (
-
+
+ setSearchFilter(value)}
+ />
+
+
-
+
{groupByEngine.map((engine, i) => {
const apiKey = !localEngines.includes(engine)
? extensionHasSettings.filter((x) => x.provider === engine)[0]
From dbc4bed40fbc6ca4facca0830d045cf4909ddffb Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 24 Sep 2024 20:26:06 +0700
Subject: [PATCH 28/37] fix: #3673 - API responds with Request body is too
large (#3729)
---
.../node/api/restful/helper/builder.test.ts | 41 +++++++++++++++++++
core/src/node/api/restful/helper/builder.ts | 6 ++-
server/index.ts | 5 +++
3 files changed, 50 insertions(+), 2 deletions(-)
diff --git a/core/src/node/api/restful/helper/builder.test.ts b/core/src/node/api/restful/helper/builder.test.ts
index fef40c70a..eb21e9401 100644
--- a/core/src/node/api/restful/helper/builder.test.ts
+++ b/core/src/node/api/restful/helper/builder.test.ts
@@ -236,6 +236,47 @@ describe('builder helper functions', () => {
})
})
+ it('should return the error on status not ok', async () => {
+ const request = { body: { model: 'model1' } }
+ const mockSend = jest.fn()
+ const reply = {
+ code: jest.fn().mockReturnThis(),
+ send: jest.fn(),
+ headers: jest.fn().mockReturnValue({
+ send: mockSend,
+ }),
+ raw: {
+ writeHead: jest.fn(),
+ pipe: jest.fn(),
+ },
+ }
+
+ ;(existsSync as jest.Mock).mockReturnValue(true)
+ ;(readdirSync as jest.Mock).mockReturnValue(['file1'])
+ ;(readFileSync as jest.Mock).mockReturnValue(
+ JSON.stringify({ id: 'model1', engine: 'openai' })
+ )
+
+ // Mock fetch
+ const fetch = require('node-fetch')
+ fetch.mockResolvedValue({
+ status: 400,
+ headers: new Map([
+ ['content-type', 'application/json'],
+ ['x-request-id', '123456'],
+ ]),
+ body: { pipe: jest.fn() },
+ text: jest.fn().mockResolvedValue({ error: 'Mock error response' }),
+ })
+ await chatCompletions(request, reply)
+ expect(reply.code).toHaveBeenCalledWith(400)
+ expect(mockSend).toHaveBeenCalledWith(
+ expect.objectContaining({
+ error: 'Mock error response',
+ })
+ )
+ })
+
it('should return the chat completions', async () => {
const request = { body: { model: 'model1' } }
const reply = {
diff --git a/core/src/node/api/restful/helper/builder.ts b/core/src/node/api/restful/helper/builder.ts
index 1a8120918..db2000d69 100644
--- a/core/src/node/api/restful/helper/builder.ts
+++ b/core/src/node/api/restful/helper/builder.ts
@@ -353,8 +353,10 @@ export const chatCompletions = async (request: any, reply: any) => {
body: JSON.stringify(request.body),
})
if (response.status !== 200) {
- console.error(response)
- reply.code(400).send(response)
+ // Forward the error response to client via reply
+ const responseBody = await response.text()
+ const responseHeaders = Object.fromEntries(response.headers)
+ reply.code(response.status).headers(responseHeaders).send(responseBody)
} else {
reply.raw.writeHead(200, {
'Content-Type': request.body.stream === true ? 'text/event-stream' : 'application/json',
diff --git a/server/index.ts b/server/index.ts
index f82c4f5bc..e8a6eea78 100644
--- a/server/index.ts
+++ b/server/index.ts
@@ -67,6 +67,11 @@ export const startServer = async (configs?: ServerConfig): Promise => {
// Initialize Fastify server with logging
server = fastify({
logger: new Logger(),
+ // Set body limit to 100MB - Default is 1MB
+ // According to OpenAI - a batch input file can be up to 100 MB in size
+ // Whisper endpoints accept up to 25MB
+ // Vision endpoints accept up to 4MB
+ bodyLimit: 104_857_600
})
// Register CORS if enabled
From f46ab45e0e9dd2885ca30e9262c3edde9d0329b6 Mon Sep 17 00:00:00 2001
From: Louis
Date: Wed, 25 Sep 2024 09:46:46 +0700
Subject: [PATCH 29/37] fix: #3727 LLM model download fail can still be used
(#3731)
* fix: #3727 - LLM model download fail can still be used
* test: add tests
* test: fix path on Windows
---
core/src/node/api/processors/download.test.ts | 158 +++++++++++++-----
core/src/node/api/processors/download.ts | 6 +-
2 files changed, 120 insertions(+), 44 deletions(-)
diff --git a/core/src/node/api/processors/download.test.ts b/core/src/node/api/processors/download.test.ts
index 1dc0eefb8..370f1746f 100644
--- a/core/src/node/api/processors/download.test.ts
+++ b/core/src/node/api/processors/download.test.ts
@@ -1,59 +1,131 @@
-import { Downloader } from './download';
-import { DownloadEvent } from '../../../types/api';
-import { DownloadManager } from '../../helper/download';
+import { Downloader } from './download'
+import { DownloadEvent } from '../../../types/api'
+import { DownloadManager } from '../../helper/download'
-it('should handle getFileSize errors correctly', async () => {
- const observer = jest.fn();
- const url = 'http://example.com/file';
+jest.mock('../../helper', () => ({
+ getJanDataFolderPath: jest.fn().mockReturnValue('path/to/folder'),
+}))
- const downloader = new Downloader(observer);
- const requestMock = jest.fn((options, callback) => {
- callback(new Error('Test error'), null);
- });
- jest.mock('request', () => requestMock);
+jest.mock('../../helper/path', () => ({
+ validatePath: jest.fn().mockReturnValue('path/to/folder'),
+ normalizeFilePath: () => process.platform === 'win32' ? 'C:\\Users\path\\to\\file.gguf' : '/Users/path/to/file.gguf',
+}))
- await expect(downloader.getFileSize(observer, url)).rejects.toThrow('Test error');
-});
+jest.mock(
+ 'request',
+ jest.fn().mockReturnValue(() => ({
+ on: jest.fn(),
+ }))
+)
+jest.mock('fs', () => ({
+ createWriteStream: jest.fn(),
+}))
-it('should pause download correctly', () => {
- const observer = jest.fn();
- const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file';
+jest.mock('request-progress', () => {
+ return jest.fn().mockImplementation(() => {
+ return {
+ on: jest.fn().mockImplementation((event, callback) => {
+ if (event === 'error') {
+ callback(new Error('Download failed'))
+ }
+ return {
+ on: jest.fn().mockImplementation((event, callback) => {
+ if (event === 'error') {
+ callback(new Error('Download failed'))
+ }
+ return {
+ on: jest.fn().mockImplementation((event, callback) => {
+ if (event === 'error') {
+ callback(new Error('Download failed'))
+ }
+ return { pipe: jest.fn() }
+ }),
+ }
+ }),
+ }
+ }),
+ }
+ })
+})
- const downloader = new Downloader(observer);
- const pauseMock = jest.fn();
- DownloadManager.instance.networkRequests[fileName] = { pause: pauseMock };
+describe('Downloader', () => {
+ beforeEach(() => {
+ jest.resetAllMocks()
+ })
+ it('should handle getFileSize errors correctly', async () => {
+ const observer = jest.fn()
+ const url = 'http://example.com/file'
- downloader.pauseDownload(observer, fileName);
+ const downloader = new Downloader(observer)
+ const requestMock = jest.fn((options, callback) => {
+ callback(new Error('Test error'), null)
+ })
+ jest.mock('request', () => requestMock)
- expect(pauseMock).toHaveBeenCalled();
-});
+ await expect(downloader.getFileSize(observer, url)).rejects.toThrow('Test error')
+ })
-it('should resume download correctly', () => {
- const observer = jest.fn();
- const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file';
+ it('should pause download correctly', () => {
+ const observer = jest.fn()
+ const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file'
- const downloader = new Downloader(observer);
- const resumeMock = jest.fn();
- DownloadManager.instance.networkRequests[fileName] = { resume: resumeMock };
+ const downloader = new Downloader(observer)
+ const pauseMock = jest.fn()
+ DownloadManager.instance.networkRequests[fileName] = { pause: pauseMock }
- downloader.resumeDownload(observer, fileName);
+ downloader.pauseDownload(observer, fileName)
- expect(resumeMock).toHaveBeenCalled();
-});
+ expect(pauseMock).toHaveBeenCalled()
+ })
-it('should handle aborting a download correctly', () => {
- const observer = jest.fn();
- const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file';
+ it('should resume download correctly', () => {
+ const observer = jest.fn()
+ const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file'
- const downloader = new Downloader(observer);
- const abortMock = jest.fn();
- DownloadManager.instance.networkRequests[fileName] = { abort: abortMock };
+ const downloader = new Downloader(observer)
+ const resumeMock = jest.fn()
+ DownloadManager.instance.networkRequests[fileName] = { resume: resumeMock }
- downloader.abortDownload(observer, fileName);
+ downloader.resumeDownload(observer, fileName)
- expect(abortMock).toHaveBeenCalled();
- expect(observer).toHaveBeenCalledWith(DownloadEvent.onFileDownloadError, expect.objectContaining({
- error: 'aborted'
- }));
-});
+ expect(resumeMock).toHaveBeenCalled()
+ })
+
+ it('should handle aborting a download correctly', () => {
+ const observer = jest.fn()
+ const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file'
+
+ const downloader = new Downloader(observer)
+ const abortMock = jest.fn()
+ DownloadManager.instance.networkRequests[fileName] = { abort: abortMock }
+
+ downloader.abortDownload(observer, fileName)
+
+ expect(abortMock).toHaveBeenCalled()
+ expect(observer).toHaveBeenCalledWith(
+ DownloadEvent.onFileDownloadError,
+ expect.objectContaining({
+ error: 'aborted',
+ })
+ )
+ })
+
+ it('should handle download fail correctly', () => {
+ const observer = jest.fn()
+ const fileName = process.platform === 'win32' ? 'C:\\path\\to\\file' : 'path/to/file.gguf'
+
+ const downloader = new Downloader(observer)
+
+ downloader.downloadFile(observer, {
+ localPath: fileName,
+ url: 'http://127.0.0.1',
+ })
+ expect(observer).toHaveBeenCalledWith(
+ DownloadEvent.onFileDownloadError,
+ expect.objectContaining({
+ error: expect.anything(),
+ })
+ )
+ })
+})
diff --git a/core/src/node/api/processors/download.ts b/core/src/node/api/processors/download.ts
index 07486bdf8..21f7a6f1c 100644
--- a/core/src/node/api/processors/download.ts
+++ b/core/src/node/api/processors/download.ts
@@ -100,7 +100,11 @@ export class Downloader implements Processor {
})
.on('end', () => {
const currentDownloadState = DownloadManager.instance.downloadProgressMap[modelId]
- if (currentDownloadState && DownloadManager.instance.networkRequests[normalizedPath]) {
+ if (
+ currentDownloadState &&
+ DownloadManager.instance.networkRequests[normalizedPath] &&
+ DownloadManager.instance.downloadProgressMap[modelId]?.downloadState !== 'error'
+ ) {
// Finished downloading, rename temp file to actual file
renameSync(downloadingTempFile, destination)
const downloadState: DownloadState = {
From 7f08f0fa79301527d940aa2198978b1fbbe98bb1 Mon Sep 17 00:00:00 2001
From: Louis
Date: Thu, 26 Sep 2024 12:43:23 +0700
Subject: [PATCH 30/37] fix: #3703 - Deepseek-Coder-33B-Instruct is
incompatible (#3732)
---
extensions/inference-nitro-extension/package.json | 2 +-
.../resources/models/deepseek-coder-1.3b/model.json | 6 +++---
.../resources/models/deepseek-coder-34b/model.json | 10 +++++-----
3 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/extensions/inference-nitro-extension/package.json b/extensions/inference-nitro-extension/package.json
index ac3ed180a..f484b4511 100644
--- a/extensions/inference-nitro-extension/package.json
+++ b/extensions/inference-nitro-extension/package.json
@@ -1,7 +1,7 @@
{
"name": "@janhq/inference-cortex-extension",
"productName": "Cortex Inference Engine",
- "version": "1.0.17",
+ "version": "1.0.18",
"description": "This extension embeds cortex.cpp, a lightweight inference engine written in C++. See https://jan.ai.\nAdditional dependencies could be installed to run without Cuda Toolkit installation.",
"main": "dist/index.js",
"node": "dist/node/index.cjs.js",
diff --git a/extensions/inference-nitro-extension/resources/models/deepseek-coder-1.3b/model.json b/extensions/inference-nitro-extension/resources/models/deepseek-coder-1.3b/model.json
index 36fceaad2..4d825cfeb 100644
--- a/extensions/inference-nitro-extension/resources/models/deepseek-coder-1.3b/model.json
+++ b/extensions/inference-nitro-extension/resources/models/deepseek-coder-1.3b/model.json
@@ -8,7 +8,7 @@
"id": "deepseek-coder-1.3b",
"object": "model",
"name": "Deepseek Coder 1.3B Instruct Q8",
- "version": "1.3",
+ "version": "1.4",
"description": "Deepseek Coder excelled in project-level code completion with advanced capabilities across multiple programming languages.",
"format": "gguf",
"settings": {
@@ -22,13 +22,13 @@
"top_p": 0.95,
"stream": true,
"max_tokens": 16384,
- "stop": [],
+ "stop": ["<|EOT|>"],
"frequency_penalty": 0,
"presence_penalty": 0
},
"metadata": {
"author": "Deepseek, The Bloke",
- "tags": ["Tiny", "Foundational Model"],
+ "tags": ["Tiny"],
"size": 1430000000
},
"engine": "nitro"
diff --git a/extensions/inference-nitro-extension/resources/models/deepseek-coder-34b/model.json b/extensions/inference-nitro-extension/resources/models/deepseek-coder-34b/model.json
index 103c4cbcb..e87d6a643 100644
--- a/extensions/inference-nitro-extension/resources/models/deepseek-coder-34b/model.json
+++ b/extensions/inference-nitro-extension/resources/models/deepseek-coder-34b/model.json
@@ -2,13 +2,13 @@
"sources": [
{
"filename": "deepseek-coder-33b-instruct.Q4_K_M.gguf",
- "url": "https://huggingface.co/TheBloke/deepseek-coder-33B-instruct-GGUF/resolve/main/deepseek-coder-33b-instruct.Q4_K_M.gguf"
+ "url": "https://huggingface.co/mradermacher/deepseek-coder-33b-instruct-GGUF/resolve/main/deepseek-coder-33b-instruct.Q4_K_M.gguf"
}
],
"id": "deepseek-coder-34b",
"object": "model",
"name": "Deepseek Coder 33B Instruct Q4",
- "version": "1.3",
+ "version": "1.4",
"description": "Deepseek Coder excelled in project-level code completion with advanced capabilities across multiple programming languages.",
"format": "gguf",
"settings": {
@@ -22,13 +22,13 @@
"top_p": 0.95,
"stream": true,
"max_tokens": 16384,
- "stop": [],
+ "stop": ["<|EOT|>"],
"frequency_penalty": 0,
"presence_penalty": 0
},
"metadata": {
- "author": "Deepseek, The Bloke",
- "tags": ["34B", "Foundational Model"],
+ "author": "Deepseek",
+ "tags": ["33B"],
"size": 19940000000
},
"engine": "nitro"
From 143f2f5c585329a9569bfeaf43f4b3e3c1aa4196 Mon Sep 17 00:00:00 2001
From: Louis
Date: Thu, 26 Sep 2024 12:43:34 +0700
Subject: [PATCH 31/37] fix: wrong model download location when there is a
mismatch model_id (#3733)
---
core/src/node/api/processors/download.ts | 2 +-
core/src/types/file/index.ts | 8 +
core/src/types/model/modelInterface.ts | 6 +-
extensions/model-extension/src/index.test.ts | 184 ++++++++++++++++++-
extensions/model-extension/src/index.ts | 17 +-
5 files changed, 202 insertions(+), 15 deletions(-)
diff --git a/core/src/node/api/processors/download.ts b/core/src/node/api/processors/download.ts
index 21f7a6f1c..5db18a53a 100644
--- a/core/src/node/api/processors/download.ts
+++ b/core/src/node/api/processors/download.ts
@@ -34,7 +34,7 @@ export class Downloader implements Processor {
}
const array = normalizedPath.split(sep)
const fileName = array.pop() ?? ''
- const modelId = array.pop() ?? ''
+ const modelId = downloadRequest.modelId ?? array.pop() ?? ''
const destination = resolve(getJanDataFolderPath(), normalizedPath)
validatePath(destination)
diff --git a/core/src/types/file/index.ts b/core/src/types/file/index.ts
index 4db956b1e..9f3e32b3e 100644
--- a/core/src/types/file/index.ts
+++ b/core/src/types/file/index.ts
@@ -40,6 +40,14 @@ export type DownloadRequest = {
*/
extensionId?: string
+ /**
+ * The model ID of the model that initiated the download.
+ */
+ modelId?: string
+
+ /**
+ * The download type.
+ */
downloadType?: DownloadType | string
}
diff --git a/core/src/types/model/modelInterface.ts b/core/src/types/model/modelInterface.ts
index 5b5856231..08d456b7e 100644
--- a/core/src/types/model/modelInterface.ts
+++ b/core/src/types/model/modelInterface.ts
@@ -12,7 +12,7 @@ export interface ModelInterface {
* @returns A Promise that resolves when the model has been downloaded.
*/
downloadModel(
- model: Model,
+ model: ModelFile,
gpuSettings?: GpuSetting,
network?: { ignoreSSL?: boolean; proxy?: string }
): Promise
@@ -35,11 +35,11 @@ export interface ModelInterface {
* Gets a list of downloaded models.
* @returns A Promise that resolves with an array of downloaded models.
*/
- getDownloadedModels(): Promise
+ getDownloadedModels(): Promise
/**
* Gets a list of configured models.
* @returns A Promise that resolves with an array of configured models.
*/
- getConfiguredModels(): Promise
+ getConfiguredModels(): Promise
}
diff --git a/extensions/model-extension/src/index.test.ts b/extensions/model-extension/src/index.test.ts
index 823b3a41d..5b126d4cc 100644
--- a/extensions/model-extension/src/index.test.ts
+++ b/extensions/model-extension/src/index.test.ts
@@ -8,9 +8,14 @@ const downloadMock = jest.fn()
const mkdirMock = jest.fn()
const writeFileSyncMock = jest.fn()
const copyFileMock = jest.fn()
+const dirNameMock = jest.fn()
+const executeMock = jest.fn()
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
+ events: {
+ emit: jest.fn(),
+ },
fs: {
existsSync: existMock,
readdirSync: readDirSyncMock,
@@ -22,12 +27,15 @@ jest.mock('@janhq/core', () => ({
isDirectory: false,
}),
},
- dirName: jest.fn(),
+ dirName: dirNameMock,
joinPath: (paths) => paths.join('/'),
ModelExtension: jest.fn(),
downloadFile: downloadMock,
+ executeOnMain: executeMock,
}))
+jest.mock('@huggingface/gguf')
+
global.fetch = jest.fn(() =>
Promise.resolve({
json: () => Promise.resolve({ test: 100 }),
@@ -37,8 +45,7 @@ global.fetch = jest.fn(() =>
import JanModelExtension from '.'
import { fs, dirName } from '@janhq/core'
-import { renderJinjaTemplate } from './node/index'
-import { Template } from '@huggingface/jinja'
+import { gguf } from '@huggingface/gguf'
describe('JanModelExtension', () => {
let sut: JanModelExtension
@@ -48,7 +55,7 @@ describe('JanModelExtension', () => {
sut = new JanModelExtension()
})
- afterEach(() => {
+ beforeEach(() => {
jest.clearAllMocks()
})
@@ -610,7 +617,172 @@ describe('JanModelExtension', () => {
).rejects.toBeTruthy()
})
-
+ it('should download corresponding ID', async () => {
+ existMock.mockImplementation(() => true)
+ dirNameMock.mockImplementation(() => 'file://models/model1')
+ downloadMock.mockImplementation(() => {
+ return Promise.resolve({})
+ })
+
+ expect(
+ await sut.downloadModel(
+ { ...model, file_path: 'file://models/model1/model.json' },
+ gpuSettings,
+ network
+ )
+ ).toBeUndefined()
+
+ expect(downloadMock).toHaveBeenCalledWith(
+ {
+ localPath: 'file://models/model1/model.gguf',
+ modelId: 'model-id',
+ url: 'http://example.com/model.gguf',
+ },
+ { ignoreSSL: true, proxy: 'http://proxy.example.com' }
+ )
+ })
+
+ it('should handle invalid model file', async () => {
+ executeMock.mockResolvedValue({})
+
+ fs.readFileSync = jest.fn(() => {
+ return JSON.stringify({ metadata: { author: 'user' } })
+ })
+
+ expect(
+ sut.downloadModel(
+ { ...model, file_path: 'file://models/model1/model.json' },
+ gpuSettings,
+ network
+ )
+ ).resolves.not.toThrow()
+
+ expect(downloadMock).not.toHaveBeenCalled()
+ })
+ it('should handle model file with no sources', async () => {
+ executeMock.mockResolvedValue({})
+ const modelWithoutSources = { ...model, sources: [] }
+
+ expect(
+ sut.downloadModel(
+ {
+ ...modelWithoutSources,
+ file_path: 'file://models/model1/model.json',
+ },
+ gpuSettings,
+ network
+ )
+ ).resolves.toBe(undefined)
+
+ expect(downloadMock).not.toHaveBeenCalled()
+ })
+
+ it('should handle model file with multiple sources', async () => {
+ const modelWithMultipleSources = {
+ ...model,
+ sources: [
+ { url: 'http://example.com/model1.gguf', filename: 'model1.gguf' },
+ { url: 'http://example.com/model2.gguf', filename: 'model2.gguf' },
+ ],
+ }
+
+ executeMock.mockResolvedValue({
+ metadata: { 'tokenizer.ggml.eos_token_id': 0 },
+ })
+ ;(gguf as jest.Mock).mockResolvedValue({
+ metadata: { 'tokenizer.ggml.eos_token_id': 0 },
+ })
+ // @ts-ignore
+ global.NODE = 'node'
+ // @ts-ignore
+ global.DEFAULT_MODEL = {
+ parameters: { stop: [] },
+ }
+ downloadMock.mockImplementation(() => {
+ return Promise.resolve({})
+ })
+
+ expect(
+ await sut.downloadModel(
+ {
+ ...modelWithMultipleSources,
+ file_path: 'file://models/model1/model.json',
+ },
+ gpuSettings,
+ network
+ )
+ ).toBeUndefined()
+
+ expect(downloadMock).toHaveBeenCalledWith(
+ {
+ localPath: 'file://models/model1/model1.gguf',
+ modelId: 'model-id',
+ url: 'http://example.com/model1.gguf',
+ },
+ { ignoreSSL: true, proxy: 'http://proxy.example.com' }
+ )
+
+ expect(downloadMock).toHaveBeenCalledWith(
+ {
+ localPath: 'file://models/model1/model2.gguf',
+ modelId: 'model-id',
+ url: 'http://example.com/model2.gguf',
+ },
+ { ignoreSSL: true, proxy: 'http://proxy.example.com' }
+ )
+ })
+
+ it('should handle model file with no file_path', async () => {
+ executeMock.mockResolvedValue({
+ metadata: { 'tokenizer.ggml.eos_token_id': 0 },
+ })
+ ;(gguf as jest.Mock).mockResolvedValue({
+ metadata: { 'tokenizer.ggml.eos_token_id': 0 },
+ })
+ // @ts-ignore
+ global.NODE = 'node'
+ // @ts-ignore
+ global.DEFAULT_MODEL = {
+ parameters: { stop: [] },
+ }
+ const modelWithoutFilepath = { ...model, file_path: undefined }
+
+ await sut.downloadModel(modelWithoutFilepath, gpuSettings, network)
+
+ expect(downloadMock).toHaveBeenCalledWith(
+ expect.objectContaining({
+ localPath: 'file://models/model-id/model.gguf',
+ }),
+ expect.anything()
+ )
+ })
+
+ it('should handle model file with invalid file_path', async () => {
+ executeMock.mockResolvedValue({
+ metadata: { 'tokenizer.ggml.eos_token_id': 0 },
+ })
+ ;(gguf as jest.Mock).mockResolvedValue({
+ metadata: { 'tokenizer.ggml.eos_token_id': 0 },
+ })
+ // @ts-ignore
+ global.NODE = 'node'
+ // @ts-ignore
+ global.DEFAULT_MODEL = {
+ parameters: { stop: [] },
+ }
+ const modelWithInvalidFilepath = {
+ ...model,
+ file_path: 'file://models/invalid-model.json',
+ }
+
+ await sut.downloadModel(modelWithInvalidFilepath, gpuSettings, network)
+
+ expect(downloadMock).toHaveBeenCalledWith(
+ expect.objectContaining({
+ localPath: 'file://models/model1/model.gguf',
+ }),
+ expect.anything()
+ )
+ })
})
-
})
diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts
index beb9f1fed..20d23b747 100644
--- a/extensions/model-extension/src/index.ts
+++ b/extensions/model-extension/src/index.ts
@@ -24,7 +24,6 @@ import {
ModelEvent,
ModelFile,
dirName,
- ModelSettingParams,
} from '@janhq/core'
import { extractFileName } from './helpers/path'
@@ -77,14 +76,15 @@ export default class JanModelExtension extends ModelExtension {
* @returns A Promise that resolves when the model is downloaded.
*/
async downloadModel(
- model: Model,
+ model: ModelFile,
gpuSettings?: GpuSetting,
network?: { ignoreSSL?: boolean; proxy?: string }
): Promise {
// Create corresponding directory
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
- const modelJsonPath = await joinPath([modelDirPath, 'model.json'])
+ const modelJsonPath =
+ model.file_path ?? (await joinPath([modelDirPath, 'model.json']))
// Download HF model - model.json not exist
if (!(await fs.existsSync(modelJsonPath))) {
@@ -152,11 +152,15 @@ export default class JanModelExtension extends ModelExtension {
JanModelExtension._supportedModelFormat
)
if (source.filename) {
- path = await joinPath([modelDirPath, source.filename])
+ path = model.file_path
+ ? await joinPath([await dirName(model.file_path), source.filename])
+ : await joinPath([modelDirPath, source.filename])
}
+
const downloadRequest: DownloadRequest = {
url: source.url,
localPath: path,
+ modelId: model.id,
}
downloadFile(downloadRequest, network)
}
@@ -166,10 +170,13 @@ export default class JanModelExtension extends ModelExtension {
model.sources[0]?.url,
JanModelExtension._supportedModelFormat
)
- const path = await joinPath([modelDirPath, fileName])
+ const path = model.file_path
+ ? await joinPath([await dirName(model.file_path), fileName])
+ : await joinPath([modelDirPath, fileName])
const downloadRequest: DownloadRequest = {
url: model.sources[0]?.url,
localPath: path,
+ modelId: model.id,
}
downloadFile(downloadRequest, network)
From cf0a232001c918f87d6d8dba6fd18d0b0ce4e19a Mon Sep 17 00:00:00 2001
From: hiento09 <136591877+hiento09@users.noreply.github.com>
Date: Thu, 26 Sep 2024 15:45:39 +0700
Subject: [PATCH 32/37] ci: auto trigger jan docs ci for new release (#3734)
Co-authored-by: Hien To
---
.github/workflows/auto-trigger-jan-docs.yaml | 25 ++++++++++++++++++++
1 file changed, 25 insertions(+)
create mode 100644 .github/workflows/auto-trigger-jan-docs.yaml
diff --git a/.github/workflows/auto-trigger-jan-docs.yaml b/.github/workflows/auto-trigger-jan-docs.yaml
new file mode 100644
index 000000000..a3001a9e0
--- /dev/null
+++ b/.github/workflows/auto-trigger-jan-docs.yaml
@@ -0,0 +1,25 @@
+name: Trigger Docs Workflow
+
+on:
+ release:
+ types:
+ - published
+ workflow_dispatch:
+ push:
+ branches:
+ - ci/auto-trigger-jan-docs-for-new-release
+
+jobs:
+ trigger_docs_workflow:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Trigger external workflow using GitHub API
+ env:
+ GITHUB_TOKEN: ${{ secrets.PAT_SERVICE_ACCOUNT }}
+ run: |
+ curl -X POST \
+ -H "Accept: application/vnd.github.v3+json" \
+ -H "Authorization: token $GITHUB_TOKEN" \
+ https://api.github.com/repos/janhq/docs/actions/workflows/jan-docs.yml/dispatches \
+ -d '{"ref":"main"}'
From 8334076047f83f1a5e77e7b288530ab2dc8e984a Mon Sep 17 00:00:00 2001
From: Louis
Date: Mon, 30 Sep 2024 11:58:46 +0700
Subject: [PATCH 33/37] fix: #3491 - Unable to use tensorrt-llm (#3741)
* fix: #3491 - Unable to use tensorrt-llm
* fix: abortModelDownload input type
---
.../tensorrt-llm-extension/jest.config.js | 9 +
.../tensorrt-llm-extension/package.json | 8 +-
.../tensorrt-llm-extension/rollup.config.ts | 4 +-
.../tensorrt-llm-extension/src/index.test.ts | 186 ++++++++++++++++++
.../tensorrt-llm-extension/src/index.ts | 3 +-
.../tensorrt-llm-extension/tsconfig.json | 3 +-
web/hooks/useDownloadModel.ts | 9 +-
7 files changed, 214 insertions(+), 8 deletions(-)
create mode 100644 extensions/tensorrt-llm-extension/jest.config.js
create mode 100644 extensions/tensorrt-llm-extension/src/index.test.ts
diff --git a/extensions/tensorrt-llm-extension/jest.config.js b/extensions/tensorrt-llm-extension/jest.config.js
new file mode 100644
index 000000000..3e32adceb
--- /dev/null
+++ b/extensions/tensorrt-llm-extension/jest.config.js
@@ -0,0 +1,9 @@
+/** @type {import('ts-jest').JestConfigWithTsJest} */
+module.exports = {
+ preset: 'ts-jest',
+ testEnvironment: 'node',
+ transform: {
+ 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
+ },
+ transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
+}
diff --git a/extensions/tensorrt-llm-extension/package.json b/extensions/tensorrt-llm-extension/package.json
index c5cb54809..7a7ef6ef0 100644
--- a/extensions/tensorrt-llm-extension/package.json
+++ b/extensions/tensorrt-llm-extension/package.json
@@ -22,6 +22,7 @@
"tensorrtVersion": "0.1.8",
"provider": "nitro-tensorrt-llm",
"scripts": {
+ "test": "jest",
"build": "tsc --module commonjs && rollup -c rollup.config.ts",
"build:publish:win32": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
"build:publish:linux": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
@@ -49,7 +50,12 @@
"rollup-plugin-sourcemaps": "^0.6.3",
"rollup-plugin-typescript2": "^0.36.0",
"run-script-os": "^1.1.6",
- "typescript": "^5.2.2"
+ "typescript": "^5.2.2",
+ "@types/jest": "^29.5.12",
+ "jest": "^29.7.0",
+ "jest-junit": "^16.0.0",
+ "jest-runner": "^29.7.0",
+ "ts-jest": "^29.2.5"
},
"dependencies": {
"@janhq/core": "file:../../core",
diff --git a/extensions/tensorrt-llm-extension/rollup.config.ts b/extensions/tensorrt-llm-extension/rollup.config.ts
index 1fad0e711..50b4350e7 100644
--- a/extensions/tensorrt-llm-extension/rollup.config.ts
+++ b/extensions/tensorrt-llm-extension/rollup.config.ts
@@ -23,10 +23,10 @@ export default [
DOWNLOAD_RUNNER_URL:
process.platform === 'win32'
? JSON.stringify(
- 'https://github.com/janhq/nitro-tensorrt-llm/releases/download/windows-v-tensorrt-llm-v0.7.1/nitro-windows-v-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
+ 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v-tensorrt-llm-v0.7.1/nitro-windows-v-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
)
: JSON.stringify(
- 'https://github.com/janhq/nitro-tensorrt-llm/releases/download/linux-v/nitro-linux-v-amd64-tensorrt-llm-.tar.gz'
+ 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/linux-v/nitro-linux-v-amd64-tensorrt-llm-.tar.gz'
),
NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`),
INFERENCE_URL: JSON.stringify(
diff --git a/extensions/tensorrt-llm-extension/src/index.test.ts b/extensions/tensorrt-llm-extension/src/index.test.ts
new file mode 100644
index 000000000..48d6e71d7
--- /dev/null
+++ b/extensions/tensorrt-llm-extension/src/index.test.ts
@@ -0,0 +1,186 @@
+import TensorRTLLMExtension from '../src/index'
+import {
+ executeOnMain,
+ systemInformation,
+ fs,
+ baseName,
+ joinPath,
+ downloadFile,
+} from '@janhq/core'
+
+jest.mock('@janhq/core', () => ({
+ ...jest.requireActual('@janhq/core/node'),
+ LocalOAIEngine: jest.fn().mockImplementation(function () {
+ // @ts-ignore
+ this.registerModels = () => {
+ return Promise.resolve()
+ }
+ // @ts-ignore
+ return this
+ }),
+ systemInformation: jest.fn(),
+ fs: {
+ existsSync: jest.fn(),
+ mkdir: jest.fn(),
+ },
+ joinPath: jest.fn(),
+ baseName: jest.fn(),
+ downloadFile: jest.fn(),
+ executeOnMain: jest.fn(),
+ showToast: jest.fn(),
+ events: {
+ emit: jest.fn(),
+ // @ts-ignore
+ on: (event, func) => {
+ func({ fileName: './' })
+ },
+ off: jest.fn(),
+ },
+}))
+
+// @ts-ignore
+global.COMPATIBILITY = {
+ platform: ['win32'],
+}
+// @ts-ignore
+global.PROVIDER = 'tensorrt-llm'
+// @ts-ignore
+global.INFERENCE_URL = 'http://localhost:5000'
+// @ts-ignore
+global.NODE = 'node'
+// @ts-ignore
+global.MODELS = []
+// @ts-ignore
+global.TENSORRT_VERSION = ''
+// @ts-ignore
+global.DOWNLOAD_RUNNER_URL = ''
+
+describe('TensorRTLLMExtension', () => {
+ let extension: TensorRTLLMExtension
+
+ beforeEach(() => {
+ // @ts-ignore
+ extension = new TensorRTLLMExtension()
+ jest.clearAllMocks()
+ })
+
+ describe('compatibility', () => {
+ it('should return the correct compatibility', () => {
+ const result = extension.compatibility()
+ expect(result).toEqual({
+ platform: ['win32'],
+ })
+ })
+ })
+
+ describe('install', () => {
+ it('should install if compatible', async () => {
+ const mockSystemInfo: any = {
+ osInfo: { platform: 'win32' },
+ gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
+ }
+ ;(executeOnMain as jest.Mock).mockResolvedValue({})
+ ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
+ ;(fs.existsSync as jest.Mock).mockResolvedValue(false)
+ ;(fs.mkdir as jest.Mock).mockResolvedValue(undefined)
+ ;(baseName as jest.Mock).mockResolvedValue('./')
+ ;(joinPath as jest.Mock).mockResolvedValue('./')
+ ;(downloadFile as jest.Mock).mockResolvedValue({})
+
+ await extension.install()
+
+ expect(executeOnMain).toHaveBeenCalled()
+ })
+
+ it('should not install if not compatible', async () => {
+ const mockSystemInfo: any = {
+ osInfo: { platform: 'linux' },
+ gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
+ }
+ ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
+
+ jest.spyOn(extension, 'registerModels').mockReturnValue(Promise.resolve())
+ await extension.install()
+
+ expect(executeOnMain).not.toHaveBeenCalled()
+ })
+ })
+
+ describe('installationState', () => {
+ it('should return NotCompatible if not compatible', async () => {
+ const mockSystemInfo: any = {
+ osInfo: { platform: 'linux' },
+ gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
+ }
+ ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
+
+ const result = await extension.installationState()
+
+ expect(result).toBe('NotCompatible')
+ })
+
+ it('should return Installed if executable exists', async () => {
+ const mockSystemInfo: any = {
+ osInfo: { platform: 'win32' },
+ gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
+ }
+ ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
+ ;(fs.existsSync as jest.Mock).mockResolvedValue(true)
+
+ const result = await extension.installationState()
+
+ expect(result).toBe('Installed')
+ })
+
+ it('should return NotInstalled if executable does not exist', async () => {
+ const mockSystemInfo: any = {
+ osInfo: { platform: 'win32' },
+ gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
+ }
+ ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
+ ;(fs.existsSync as jest.Mock).mockResolvedValue(false)
+
+ const result = await extension.installationState()
+
+ expect(result).toBe('NotInstalled')
+ })
+ })
+
+ describe('isCompatible', () => {
+ it('should return true for compatible system', () => {
+ const mockInfo: any = {
+ osInfo: { platform: 'win32' },
+ gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
+ }
+
+ const result = extension.isCompatible(mockInfo)
+
+ expect(result).toBe(true)
+ })
+
+ it('should return false for incompatible system', () => {
+ const mockInfo: any = {
+ osInfo: { platform: 'linux' },
+ gpuSetting: { gpus: [{ arch: 'pascal', name: 'AMD GPU' }] },
+ }
+
+ const result = extension.isCompatible(mockInfo)
+
+ expect(result).toBe(false)
+ })
+ })
+})
+
+describe('GitHub Release File URL Test', () => {
+ const url = 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v0.1.8-tensorrt-llm-v0.7.1/nitro-windows-v0.1.8-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz';
+
+ it('should return a status code 200 for the release file URL', async () => {
+ const response = await fetch(url, { method: 'HEAD' });
+ expect(response.status).toBe(200);
+ });
+
+ it('should not return a 404 status', async () => {
+ const response = await fetch(url, { method: 'HEAD' });
+ expect(response.status).not.toBe(404);
+ });
+});
diff --git a/extensions/tensorrt-llm-extension/src/index.ts b/extensions/tensorrt-llm-extension/src/index.ts
index 7f68c43bd..11c86a9a7 100644
--- a/extensions/tensorrt-llm-extension/src/index.ts
+++ b/extensions/tensorrt-llm-extension/src/index.ts
@@ -41,7 +41,6 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
override nodeModule = NODE
private supportedGpuArch = ['ampere', 'ada']
- private supportedPlatform = ['win32', 'linux']
override compatibility() {
return COMPATIBILITY as unknown as Compatibility
@@ -191,7 +190,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
!!info.gpuSetting &&
!!firstGpu &&
info.gpuSetting.gpus.length > 0 &&
- this.supportedPlatform.includes(info.osInfo.platform) &&
+ this.compatibility().platform.includes(info.osInfo.platform) &&
!!firstGpu.arch &&
firstGpu.name.toLowerCase().includes('nvidia') &&
this.supportedGpuArch.includes(firstGpu.arch)
diff --git a/extensions/tensorrt-llm-extension/tsconfig.json b/extensions/tensorrt-llm-extension/tsconfig.json
index 478a05728..be07e716c 100644
--- a/extensions/tensorrt-llm-extension/tsconfig.json
+++ b/extensions/tensorrt-llm-extension/tsconfig.json
@@ -16,5 +16,6 @@
"resolveJsonModule": true,
"typeRoots": ["node_modules/@types"]
},
- "include": ["src"]
+ "include": ["src"],
+ "exclude": ["**/*.test.ts"]
}
diff --git a/web/hooks/useDownloadModel.ts b/web/hooks/useDownloadModel.ts
index d0d13d93b..0cd21ea83 100644
--- a/web/hooks/useDownloadModel.ts
+++ b/web/hooks/useDownloadModel.ts
@@ -9,6 +9,8 @@ import {
ModelArtifact,
DownloadState,
GpuSetting,
+ ModelFile,
+ dirName,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
@@ -91,9 +93,12 @@ export default function useDownloadModel() {
]
)
- const abortModelDownload = useCallback(async (model: Model) => {
+ const abortModelDownload = useCallback(async (model: Model | ModelFile) => {
for (const source of model.sources) {
- const path = await joinPath(['models', model.id, source.filename])
+ const path =
+ 'file_path' in model
+ ? await joinPath([await dirName(model.file_path), source.filename])
+ : await joinPath(['models', model.id, source.filename])
await abortDownload(path)
}
}, [])
From ba1ddacde3d18ac6546cb1366745a837ed9704b7 Mon Sep 17 00:00:00 2001
From: Louis
Date: Mon, 30 Sep 2024 11:58:55 +0700
Subject: [PATCH 34/37] fix: correct model dropdown for local models (#3736)
* fix: correct model dropdown for local models
* fix: clean unused import
* test: add Model.atom and model.Engine tests
---
.../SystemMonitor/TableActiveModel/index.tsx | 4 +-
web/containers/ModelDropdown/index.test.tsx | 22 +-
web/containers/ModelDropdown/index.tsx | 27 +-
web/containers/Providers/EventHandler.tsx | 6 +-
web/containers/SetupRemoteModel/index.tsx | 4 +-
web/helpers/atoms/Model.atom.test.ts | 298 ++++++++++++++++++
web/helpers/atoms/Model.atom.ts | 6 +-
web/hooks/useStarterScreen.ts | 4 +-
.../Settings/MyModels/MyModelList/index.tsx | 4 +-
web/screens/Settings/MyModels/index.tsx | 4 +-
.../ChatBody/OnDeviceStarterScreen/index.tsx | 8 +-
.../ThreadCenterPanel/ChatInput/index.tsx | 4 +-
web/screens/Thread/ThreadRightPanel/index.tsx | 4 +-
web/utils/modelEngine.test.ts | 185 +++++++++++
web/utils/modelEngine.ts | 22 +-
15 files changed, 545 insertions(+), 57 deletions(-)
create mode 100644 web/helpers/atoms/Model.atom.test.ts
create mode 100644 web/utils/modelEngine.test.ts
diff --git a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx
index e68f843a9..5e8549c7f 100644
--- a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx
+++ b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx
@@ -6,7 +6,7 @@ import { useActiveModel } from '@/hooks/useActiveModel'
import { toGibibytes } from '@/utils/converter'
-import { localEngines } from '@/utils/modelEngine'
+import { isLocalEngine } from '@/utils/modelEngine'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
@@ -35,7 +35,7 @@ const TableActiveModel = () => {
})}
- {activeModel && localEngines.includes(activeModel.engine) ? (
+ {activeModel && isLocalEngine(activeModel.engine) ? (
{
engine: 'nitro',
}
- const configuredModels = [remoteModel, localModel]
+ const configuredModels = [remoteModel, localModel, localModel]
const mockConfiguredModel = configuredModels
const selectedModel = { id: 'selectedModel', name: 'selectedModel' }
@@ -94,8 +94,20 @@ describe('ModelDropdown', () => {
await waitFor(() => {
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
- expect(screen.getByText('On-device'))
- expect(screen.getByText('Cloud'))
+ expect(screen.getByText('On-device')).toBeInTheDocument()
+ expect(screen.getByText('Cloud')).toBeInTheDocument()
+ })
+ })
+
+ it('filters models correctly', async () => {
+ render( )
+
+ await waitFor(() => {
+ expect(screen.getByTestId('model-selector')).toBeInTheDocument()
+ fireEvent.click(screen.getByText('Cloud'))
+ fireEvent.change(screen.getByText('Cloud'), {
+ target: { value: 'remote' },
+ })
})
})
})
diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx
index 2a0c4ffaf..9ebcf4fa2 100644
--- a/web/containers/ModelDropdown/index.tsx
+++ b/web/containers/ModelDropdown/index.tsx
@@ -40,7 +40,7 @@ import { formatDownloadPercentage, toGibibytes } from '@/utils/converter'
import {
getLogoEngine,
getTitleByEngine,
- localEngines,
+ isLocalEngine,
priorityEngine,
} from '@/utils/modelEngine'
@@ -101,7 +101,7 @@ const ModelDropdown = ({
const isModelSupportRagAndTools = useCallback((model: Model) => {
return (
model?.engine === InferenceEngine.openai ||
- localEngines.includes(model?.engine as InferenceEngine)
+ isLocalEngine(model?.engine as InferenceEngine)
)
}, [])
@@ -113,10 +113,10 @@ const ModelDropdown = ({
)
.filter((e) => {
if (searchFilter === 'local') {
- return localEngines.includes(e.engine)
+ return isLocalEngine(e.engine)
}
if (searchFilter === 'remote') {
- return !localEngines.includes(e.engine)
+ return !isLocalEngine(e.engine)
}
})
.sort((a, b) => a.name.localeCompare(b.name))
@@ -236,7 +236,6 @@ const ModelDropdown = ({
for (const extension of extensions) {
if (typeof extension.getSettings === 'function') {
const settings = await extension.getSettings()
-
if (
(settings && settings.length > 0) ||
(await extension.installationState()) !== 'NotRequired'
@@ -295,7 +294,7 @@ const ModelDropdown = ({
}, [setShowEngineListModel, extensionHasSettings])
const isDownloadALocalModel = downloadedModels.some((x) =>
- localEngines.includes(x.engine)
+ isLocalEngine(x.engine)
)
if (strictedThread && !activeThread) {
@@ -377,7 +376,7 @@ const ModelDropdown = ({
{groupByEngine.map((engine, i) => {
- const apiKey = !localEngines.includes(engine)
+ const apiKey = !isLocalEngine(engine)
? extensionHasSettings.filter((x) => x.provider === engine)[0]
?.apiKey.length > 1
: true
@@ -417,7 +416,7 @@ const ModelDropdown = ({
- {!localEngines.includes(engine) && (
+ {!isLocalEngine(engine) && (
)}
{!showModel ? (
@@ -438,7 +437,7 @@ const ModelDropdown = ({
- {engine === InferenceEngine.nitro &&
+ {isLocalEngine(engine) &&
!isDownloadALocalModel &&
showModel &&
!searchText.length && (
@@ -503,10 +502,7 @@ const ModelDropdown = ({
{filteredDownloadedModels
.filter((x) => x.engine === engine)
.filter((y) => {
- if (
- localEngines.includes(y.engine) &&
- !searchText.length
- ) {
+ if (isLocalEngine(y.engine) && !searchText.length) {
return downloadedModels.find((c) => c.id === y.id)
} else {
return y
@@ -530,10 +526,7 @@ const ModelDropdown = ({
: 'text-[hsla(var(--text-primary))]'
)}
onClick={() => {
- if (
- !apiKey &&
- !localEngines.includes(model.engine)
- )
+ if (!apiKey && !isLocalEngine(model.engine))
return null
if (isdDownloaded) {
onClickModelItem(model.id)
diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx
index 1fbcd3919..5cc92219c 100644
--- a/web/containers/Providers/EventHandler.tsx
+++ b/web/containers/Providers/EventHandler.tsx
@@ -21,7 +21,7 @@ import { ulid } from 'ulidx'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
-import { localEngines } from '@/utils/modelEngine'
+import { isLocalEngine } from '@/utils/modelEngine'
import { extractInferenceParams } from '@/utils/modelParam'
import { extensionManager } from '@/extension'
@@ -242,9 +242,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
}
// Check model engine; we don't want to generate a title when it's not a local engine. remote model using first promp
- if (
- !localEngines.includes(activeModelRef.current?.engine as InferenceEngine)
- ) {
+ if (!isLocalEngine(activeModelRef.current?.engine as InferenceEngine)) {
const updatedThread: Thread = {
...thread,
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
diff --git a/web/containers/SetupRemoteModel/index.tsx b/web/containers/SetupRemoteModel/index.tsx
index 914f240de..1f5478d73 100644
--- a/web/containers/SetupRemoteModel/index.tsx
+++ b/web/containers/SetupRemoteModel/index.tsx
@@ -8,7 +8,7 @@ import { SettingsIcon, PlusIcon } from 'lucide-react'
import { MainViewState } from '@/constants/screens'
-import { localEngines } from '@/utils/modelEngine'
+import { isLocalEngine } from '@/utils/modelEngine'
import { extensionManager } from '@/extension'
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
@@ -74,7 +74,7 @@ const SetupRemoteModel = ({ engine }: Props) => {
)
}
- const apiKey = !localEngines.includes(engine)
+ const apiKey = !isLocalEngine(engine)
? extensionHasSettings.filter((x) => x.provider === engine)[0]?.apiKey
.length > 1
: true
diff --git a/web/helpers/atoms/Model.atom.test.ts b/web/helpers/atoms/Model.atom.test.ts
new file mode 100644
index 000000000..36f2ce71c
--- /dev/null
+++ b/web/helpers/atoms/Model.atom.test.ts
@@ -0,0 +1,298 @@
+import { act, renderHook, waitFor } from '@testing-library/react'
+import * as ModelAtoms from './Model.atom'
+import { useAtom, useAtomValue, useSetAtom } from 'jotai'
+
+describe('Model.atom.ts', () => {
+ let mockJotaiGet: jest.Mock
+ let mockJotaiSet: jest.Mock
+
+ beforeEach(() => {
+ mockJotaiGet = jest.fn()
+ mockJotaiSet = jest.fn()
+ })
+
+ afterEach(() => {
+ jest.clearAllMocks()
+ })
+
+ describe('stateModel', () => {
+ it('should initialize with correct default values', () => {
+ expect(ModelAtoms.stateModel.init).toEqual({
+ state: 'start',
+ loading: false,
+ model: '',
+ })
+ })
+ })
+ describe('activeAssistantModelAtom', () => {
+ it('should initialize as undefined', () => {
+ expect(ModelAtoms.activeAssistantModelAtom.init).toBeUndefined()
+ })
+ })
+
+ describe('selectedModelAtom', () => {
+ it('should initialize as undefined', () => {
+ expect(ModelAtoms.selectedModelAtom.init).toBeUndefined()
+ })
+ })
+
+ describe('showEngineListModelAtom', () => {
+ it('should initialize as an empty array', () => {
+ expect(ModelAtoms.showEngineListModelAtom.init).toEqual([])
+ })
+ })
+
+ describe('addDownloadingModelAtom', () => {
+ it('should add downloading model', async () => {
+ const { result: setAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.addDownloadingModelAtom)
+ )
+ const { result: getAtom } = renderHook(() =>
+ useAtomValue(ModelAtoms.getDownloadingModelAtom)
+ )
+ act(() => {
+ setAtom.current({ id: '1' } as any)
+ })
+ expect(getAtom.current).toEqual([{ id: '1' }])
+ })
+ })
+
+ describe('removeDownloadingModelAtom', () => {
+ it('should remove downloading model', async () => {
+ const { result: setAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.addDownloadingModelAtom)
+ )
+ const { result: removeAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.removeDownloadingModelAtom)
+ )
+ const { result: getAtom } = renderHook(() =>
+ useAtomValue(ModelAtoms.getDownloadingModelAtom)
+ )
+ act(() => {
+ setAtom.current({ id: '1' } as any)
+ removeAtom.current('1')
+ })
+ expect(getAtom.current).toEqual([])
+ })
+ })
+
+ describe('removeDownloadedModelAtom', () => {
+ it('should remove downloaded model', async () => {
+ const { result: setAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.downloadedModelsAtom)
+ )
+ const { result: removeAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.removeDownloadedModelAtom)
+ )
+ const { result: getAtom } = renderHook(() =>
+ useAtomValue(ModelAtoms.downloadedModelsAtom)
+ )
+ act(() => {
+ setAtom.current([{ id: '1' }] as any)
+ })
+ expect(getAtom.current).toEqual([
+ {
+ id: '1',
+ },
+ ])
+ act(() => {
+ removeAtom.current('1')
+ })
+ expect(getAtom.current).toEqual([])
+ })
+ })
+
+ describe('importingModelAtom', () => {
+ afterEach(() => {
+ jest.resetAllMocks()
+ jest.clearAllMocks()
+ })
+ it('should not update for non-existing import', async () => {
+ const { result: importAtom } = renderHook(() =>
+ useAtom(ModelAtoms.importingModelsAtom)
+ )
+ const { result: updateAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.updateImportingModelProgressAtom)
+ )
+ act(() => {
+ importAtom.current[1]([])
+ updateAtom.current('2', 50)
+ })
+ expect(importAtom.current[0]).toEqual([])
+ })
+ it('should update progress for existing import', async () => {
+ const { result: importAtom } = renderHook(() =>
+ useAtom(ModelAtoms.importingModelsAtom)
+ )
+ const { result: updateAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.updateImportingModelProgressAtom)
+ )
+
+ act(() => {
+ importAtom.current[1]([
+ { importId: '1', status: 'MODEL_SELECTED' },
+ ] as any)
+ updateAtom.current('1', 50)
+ })
+ expect(importAtom.current[0]).toEqual([
+ {
+ importId: '1',
+ status: 'IMPORTING',
+ percentage: 50,
+ },
+ ])
+ })
+
+ it('should not update with invalid data', async () => {
+ const { result: importAtom } = renderHook(() =>
+ useAtom(ModelAtoms.importingModelsAtom)
+ )
+ const { result: updateAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.updateImportingModelProgressAtom)
+ )
+
+ act(() => {
+ importAtom.current[1]([
+ { importId: '1', status: 'MODEL_SELECTED' },
+ ] as any)
+ updateAtom.current('2', 50)
+ })
+ expect(importAtom.current[0]).toEqual([
+ {
+ importId: '1',
+ status: 'MODEL_SELECTED',
+ },
+ ])
+ })
+ it('should update import error', async () => {
+ const { result: importAtom } = renderHook(() =>
+ useAtom(ModelAtoms.importingModelsAtom)
+ )
+ const { result: errorAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.setImportingModelErrorAtom)
+ )
+ act(() => {
+ importAtom.current[1]([
+ { importId: '1', status: 'IMPORTING', percentage: 50 },
+ ] as any)
+ errorAtom.current('1', 'unknown')
+ })
+ expect(importAtom.current[0]).toEqual([
+ {
+ importId: '1',
+ status: 'FAILED',
+ percentage: 50,
+ },
+ ])
+ })
+ it('should not update import error on invalid import ID', async () => {
+ const { result: importAtom } = renderHook(() =>
+ useAtom(ModelAtoms.importingModelsAtom)
+ )
+ const { result: errorAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.setImportingModelErrorAtom)
+ )
+ act(() => {
+ importAtom.current[1]([
+ { importId: '1', status: 'IMPORTING', percentage: 50 },
+ ] as any)
+ errorAtom.current('2', 'unknown')
+ })
+ expect(importAtom.current[0]).toEqual([
+ {
+ importId: '1',
+ status: 'IMPORTING',
+ percentage: 50,
+ },
+ ])
+ })
+
+ it('should update import success', async () => {
+ const { result: importAtom } = renderHook(() =>
+ useAtom(ModelAtoms.importingModelsAtom)
+ )
+ const { result: successAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.setImportingModelSuccessAtom)
+ )
+
+ act(() => {
+ importAtom.current[1]([{ importId: '1', status: 'IMPORTING' }] as any)
+ successAtom.current('1', 'id')
+ })
+ expect(importAtom.current[0]).toEqual([
+ {
+ importId: '1',
+ status: 'IMPORTED',
+ percentage: 1,
+ modelId: 'id',
+ },
+ ])
+ })
+
+ it('should update with invalid import ID', async () => {
+ const { result: importAtom } = renderHook(() =>
+ useAtom(ModelAtoms.importingModelsAtom)
+ )
+ const { result: successAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.setImportingModelSuccessAtom)
+ )
+
+ act(() => {
+ importAtom.current[1]([{ importId: '1', status: 'IMPORTING' }] as any)
+ successAtom.current('2', 'id')
+ })
+ expect(importAtom.current[0]).toEqual([
+ {
+ importId: '1',
+ status: 'IMPORTING',
+ },
+ ])
+ })
+ it('should not update with valid data', async () => {
+ const { result: importAtom } = renderHook(() =>
+ useAtom(ModelAtoms.importingModelsAtom)
+ )
+ const { result: updateAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.updateImportingModelAtom)
+ )
+
+ act(() => {
+ importAtom.current[1]([
+ { importId: '1', status: 'IMPORTING', percentage: 1 },
+ ] as any)
+ updateAtom.current('1', 'name', 'description', ['tag'])
+ })
+ expect(importAtom.current[0]).toEqual([
+ {
+ importId: '1',
+ percentage: 1,
+ status: 'IMPORTING',
+ name: 'name',
+ tags: ['tag'],
+ description: 'description',
+ },
+ ])
+ })
+
+ it('should not update when there is no importing model', async () => {
+ const { result: importAtom } = renderHook(() =>
+ useAtom(ModelAtoms.importingModelsAtom)
+ )
+ const { result: updateAtom } = renderHook(() =>
+ useSetAtom(ModelAtoms.updateImportingModelAtom)
+ )
+
+ act(() => {
+ importAtom.current[1]([])
+ updateAtom.current('1', 'name', 'description', ['tag'])
+ })
+ expect(importAtom.current[0]).toEqual([])
+ })
+ })
+
+ describe('defaultModelAtom', () => {
+ it('should initialize as undefined', () => {
+ expect(ModelAtoms.defaultModelAtom.init).toBeUndefined()
+ })
+ })
+})
diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts
index d2d0ca9f4..28a6384eb 100644
--- a/web/helpers/atoms/Model.atom.ts
+++ b/web/helpers/atoms/Model.atom.ts
@@ -1,8 +1,6 @@
-import { ImportingModel, Model, InferenceEngine, ModelFile } from '@janhq/core'
+import { ImportingModel, Model, ModelFile } from '@janhq/core'
import { atom } from 'jotai'
-import { localEngines } from '@/utils/modelEngine'
-
export const stateModel = atom({ state: 'start', loading: false, model: '' })
export const activeAssistantModelAtom = atom(undefined)
@@ -135,4 +133,4 @@ export const updateImportingModelAtom = atom(
export const selectedModelAtom = atom(undefined)
-export const showEngineListModelAtom = atom(localEngines)
+export const showEngineListModelAtom = atom([])
diff --git a/web/hooks/useStarterScreen.ts b/web/hooks/useStarterScreen.ts
index 1a6bbfbc7..3305c0072 100644
--- a/web/hooks/useStarterScreen.ts
+++ b/web/hooks/useStarterScreen.ts
@@ -2,7 +2,7 @@ import { useState, useEffect } from 'react'
import { useAtomValue } from 'jotai'
-import { localEngines } from '@/utils/modelEngine'
+import { isLocalEngine } from '@/utils/modelEngine'
import { extensionManager } from '@/extension'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
@@ -13,7 +13,7 @@ export function useStarterScreen() {
const threads = useAtomValue(threadsAtom)
const isDownloadALocalModel = downloadedModels.some((x) =>
- localEngines.includes(x.engine)
+ isLocalEngine(x.engine)
)
const [extensionHasSettings, setExtensionHasSettings] = useState<
diff --git a/web/screens/Settings/MyModels/MyModelList/index.tsx b/web/screens/Settings/MyModels/MyModelList/index.tsx
index 329248923..c9ca6e867 100644
--- a/web/screens/Settings/MyModels/MyModelList/index.tsx
+++ b/web/screens/Settings/MyModels/MyModelList/index.tsx
@@ -16,7 +16,7 @@ import useDeleteModel from '@/hooks/useDeleteModel'
import { toGibibytes } from '@/utils/converter'
-import { localEngines } from '@/utils/modelEngine'
+import { isLocalEngine } from '@/utils/modelEngine'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
@@ -74,7 +74,7 @@ const MyModelList = ({ model }: Props) => {
- {localEngines.includes(model.engine) && (
+ {isLocalEngine(model.engine) && (
diff --git a/web/screens/Settings/MyModels/index.tsx b/web/screens/Settings/MyModels/index.tsx
index 8dafd6e20..547e6153b 100644
--- a/web/screens/Settings/MyModels/index.tsx
+++ b/web/screens/Settings/MyModels/index.tsx
@@ -29,7 +29,7 @@ import { setImportModelStageAtom } from '@/hooks/useImportModel'
import {
getLogoEngine,
getTitleByEngine,
- localEngines,
+ isLocalEngine,
priorityEngine,
} from '@/utils/modelEngine'
@@ -222,7 +222,7 @@ const MyModels = () => {
- {!localEngines.includes(engine) && (
+ {!isLocalEngine(engine) && (
)}
{!showModel ? (
diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
index 26036a627..b1e9d081a 100644
--- a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
@@ -28,7 +28,7 @@ import { formatDownloadPercentage, toGibibytes } from '@/utils/converter'
import {
getLogoEngine,
getTitleByEngine,
- localEngines,
+ isLocalEngine,
} from '@/utils/modelEngine'
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
@@ -74,13 +74,11 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
}
})
- const remoteModel = configuredModels.filter(
- (x) => !localEngines.includes(x.engine)
- )
+ const remoteModel = configuredModels.filter((x) => !isLocalEngine(x.engine))
const filteredModels = configuredModels.filter((model) => {
return (
- localEngines.includes(model.engine) &&
+ isLocalEngine(model.engine) &&
model.name.toLowerCase().includes(searchValue.toLowerCase())
)
})
diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx
index 235ebeae6..a7c5ad121 100644
--- a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx
@@ -24,7 +24,7 @@ import { useActiveModel } from '@/hooks/useActiveModel'
import useSendChatMessage from '@/hooks/useSendChatMessage'
-import { localEngines } from '@/utils/modelEngine'
+import { isLocalEngine } from '@/utils/modelEngine'
import FileUploadPreview from '../FileUploadPreview'
import ImageUploadPreview from '../ImageUploadPreview'
@@ -130,7 +130,7 @@ const ChatInput = () => {
const isModelSupportRagAndTools =
selectedModel?.engine === InferenceEngine.openai ||
- localEngines.includes(selectedModel?.engine as InferenceEngine)
+ isLocalEngine(selectedModel?.engine as InferenceEngine)
/**
* Handles the change event of the extension file input element by setting the file name state.
diff --git a/web/screens/Thread/ThreadRightPanel/index.tsx b/web/screens/Thread/ThreadRightPanel/index.tsx
index 78119ba6d..027d1b0b6 100644
--- a/web/screens/Thread/ThreadRightPanel/index.tsx
+++ b/web/screens/Thread/ThreadRightPanel/index.tsx
@@ -28,7 +28,7 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
import { getConfigurationsData } from '@/utils/componentSettings'
-import { localEngines } from '@/utils/modelEngine'
+import { isLocalEngine } from '@/utils/modelEngine'
import {
extractInferenceParams,
extractModelLoadParams,
@@ -63,7 +63,7 @@ const ThreadRightPanel = () => {
const isModelSupportRagAndTools =
selectedModel?.engine === InferenceEngine.openai ||
- localEngines.includes(selectedModel?.engine as InferenceEngine)
+ isLocalEngine(selectedModel?.engine as InferenceEngine)
const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom)
const { stopModel } = useActiveModel()
diff --git a/web/utils/modelEngine.test.ts b/web/utils/modelEngine.test.ts
new file mode 100644
index 000000000..738e04c2a
--- /dev/null
+++ b/web/utils/modelEngine.test.ts
@@ -0,0 +1,185 @@
+import { EngineManager, InferenceEngine, LocalOAIEngine } from '@janhq/core'
+import {
+ getTitleByEngine,
+ isLocalEngine,
+ priorityEngine,
+ getLogoEngine,
+} from './modelEngine'
+
+jest.mock('@janhq/core', () => ({
+ ...jest.requireActual('@janhq/core'),
+ EngineManager: {
+ instance: jest.fn().mockReturnValue({
+ get: jest.fn(),
+ }),
+ },
+}))
+
+describe('isLocalEngine', () => {
+ const mockEngineManagerInstance = EngineManager.instance()
+ const mockGet = mockEngineManagerInstance.get as jest.Mock
+
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('should return false if engine is not found', () => {
+ mockGet.mockReturnValue(null)
+ const result = isLocalEngine('nonexistentEngine')
+ expect(result).toBe(false)
+ })
+
+ it('should return true if engine is an instance of LocalOAIEngine', () => {
+ const mockEngineObj = {
+ __proto__: {
+ constructor: {
+ __proto__: {
+ name: LocalOAIEngine.name,
+ },
+ },
+ },
+ }
+ mockGet.mockReturnValue(mockEngineObj)
+ const result = isLocalEngine('localEngine')
+ expect(result).toBe(true)
+ })
+
+ it('should return false if engine is not an instance of LocalOAIEngine', () => {
+ const mockEngineObj = {
+ __proto__: {
+ constructor: {
+ __proto__: {
+ name: 'SomeOtherEngine',
+ },
+ },
+ },
+ }
+ mockGet.mockReturnValue(mockEngineObj)
+ const result = isLocalEngine('someOtherEngine')
+ expect(result).toBe(false)
+ })
+
+ jest.mock('@janhq/core', () => ({
+ ...jest.requireActual('@janhq/core'),
+ EngineManager: {
+ instance: jest.fn().mockReturnValue({
+ get: jest.fn(),
+ }),
+ },
+ }))
+
+ describe('getTitleByEngine', () => {
+ it('should return correct title for InferenceEngine.nitro', () => {
+ const result = getTitleByEngine(InferenceEngine.nitro)
+ expect(result).toBe('Llama.cpp (Nitro)')
+ })
+
+ it('should return correct title for InferenceEngine.nitro_tensorrt_llm', () => {
+ const result = getTitleByEngine(InferenceEngine.nitro_tensorrt_llm)
+ expect(result).toBe('TensorRT-LLM (Nitro)')
+ })
+
+ it('should return correct title for InferenceEngine.cortex_llamacpp', () => {
+ const result = getTitleByEngine(InferenceEngine.cortex_llamacpp)
+ expect(result).toBe('Llama.cpp (Cortex)')
+ })
+
+ it('should return correct title for InferenceEngine.cortex_onnx', () => {
+ const result = getTitleByEngine(InferenceEngine.cortex_onnx)
+ expect(result).toBe('Onnx (Cortex)')
+ })
+
+ it('should return correct title for InferenceEngine.cortex_tensorrtllm', () => {
+ const result = getTitleByEngine(InferenceEngine.cortex_tensorrtllm)
+ expect(result).toBe('TensorRT-LLM (Cortex)')
+ })
+
+ it('should return correct title for InferenceEngine.openai', () => {
+ const result = getTitleByEngine(InferenceEngine.openai)
+ expect(result).toBe('OpenAI')
+ })
+
+ it('should return correct title for InferenceEngine.openrouter', () => {
+ const result = getTitleByEngine(InferenceEngine.openrouter)
+ expect(result).toBe('OpenRouter')
+ })
+
+ it('should return capitalized engine name for unknown engine', () => {
+ const result = getTitleByEngine('unknownEngine' as InferenceEngine)
+ expect(result).toBe('UnknownEngine')
+ })
+ })
+
+ describe('priorityEngine', () => {
+ it('should contain the correct engines in the correct order', () => {
+ expect(priorityEngine).toEqual([
+ InferenceEngine.cortex_llamacpp,
+ InferenceEngine.cortex_onnx,
+ InferenceEngine.cortex_tensorrtllm,
+ InferenceEngine.nitro,
+ ])
+ })
+ })
+
+ describe('getLogoEngine', () => {
+ it('should return correct logo path for InferenceEngine.anthropic', () => {
+ const result = getLogoEngine(InferenceEngine.anthropic)
+ expect(result).toBe('images/ModelProvider/anthropic.svg')
+ })
+
+ it('should return correct logo path for InferenceEngine.nitro_tensorrt_llm', () => {
+ const result = getLogoEngine(InferenceEngine.nitro_tensorrt_llm)
+ expect(result).toBe('images/ModelProvider/nitro.svg')
+ })
+
+ it('should return correct logo path for InferenceEngine.cortex_llamacpp', () => {
+ const result = getLogoEngine(InferenceEngine.cortex_llamacpp)
+ expect(result).toBe('images/ModelProvider/cortex.svg')
+ })
+
+ it('should return correct logo path for InferenceEngine.mistral', () => {
+ const result = getLogoEngine(InferenceEngine.mistral)
+ expect(result).toBe('images/ModelProvider/mistral.svg')
+ })
+
+ it('should return correct logo path for InferenceEngine.martian', () => {
+ const result = getLogoEngine(InferenceEngine.martian)
+ expect(result).toBe('images/ModelProvider/martian.svg')
+ })
+
+ it('should return correct logo path for InferenceEngine.openrouter', () => {
+ const result = getLogoEngine(InferenceEngine.openrouter)
+ expect(result).toBe('images/ModelProvider/openRouter.svg')
+ })
+
+ it('should return correct logo path for InferenceEngine.openai', () => {
+ const result = getLogoEngine(InferenceEngine.openai)
+ expect(result).toBe('images/ModelProvider/openai.svg')
+ })
+
+ it('should return correct logo path for InferenceEngine.groq', () => {
+ const result = getLogoEngine(InferenceEngine.groq)
+ expect(result).toBe('images/ModelProvider/groq.svg')
+ })
+
+ it('should return correct logo path for InferenceEngine.triton_trtllm', () => {
+ const result = getLogoEngine(InferenceEngine.triton_trtllm)
+ expect(result).toBe('images/ModelProvider/triton_trtllm.svg')
+ })
+
+ it('should return correct logo path for InferenceEngine.cohere', () => {
+ const result = getLogoEngine(InferenceEngine.cohere)
+ expect(result).toBe('images/ModelProvider/cohere.svg')
+ })
+
+ it('should return correct logo path for InferenceEngine.nvidia', () => {
+ const result = getLogoEngine(InferenceEngine.nvidia)
+ expect(result).toBe('images/ModelProvider/nvidia.svg')
+ })
+
+ it('should return undefined for unknown engine', () => {
+ const result = getLogoEngine('unknownEngine' as InferenceEngine)
+ expect(result).toBeUndefined()
+ })
+ })
+})
diff --git a/web/utils/modelEngine.ts b/web/utils/modelEngine.ts
index 3d132c5d5..33b3ec3e1 100644
--- a/web/utils/modelEngine.ts
+++ b/web/utils/modelEngine.ts
@@ -1,4 +1,4 @@
-import { InferenceEngine } from '@janhq/core'
+import { EngineManager, InferenceEngine, LocalOAIEngine } from '@janhq/core'
export const getLogoEngine = (engine: InferenceEngine) => {
switch (engine) {
@@ -32,13 +32,19 @@ export const getLogoEngine = (engine: InferenceEngine) => {
}
}
-export const localEngines = [
- InferenceEngine.nitro,
- InferenceEngine.nitro_tensorrt_llm,
- InferenceEngine.cortex_llamacpp,
- InferenceEngine.cortex_onnx,
- InferenceEngine.cortex_tensorrtllm,
-]
+/**
+ * Check whether the engine is conform to LocalOAIEngine
+ * @param engine
+ * @returns
+ */
+export const isLocalEngine = (engine: string) => {
+ const engineObj = EngineManager.instance().get(engine)
+ if (!engineObj) return false
+ return (
+ Object.getPrototypeOf(engineObj).constructor.__proto__.name ===
+ LocalOAIEngine.name
+ )
+}
export const getTitleByEngine = (engine: InferenceEngine) => {
switch (engine) {
From 7c63914e64885b4570c475e1b1c653a78faeda76 Mon Sep 17 00:00:00 2001
From: Louis
Date: Mon, 30 Sep 2024 14:34:27 +0700
Subject: [PATCH 35/37] fix: model dropdown should show recommended models to
download (#3742)
---
web/helpers/atoms/Model.atom.test.ts | 2 +-
web/helpers/atoms/Model.atom.ts | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/web/helpers/atoms/Model.atom.test.ts b/web/helpers/atoms/Model.atom.test.ts
index 36f2ce71c..4ab02cad9 100644
--- a/web/helpers/atoms/Model.atom.test.ts
+++ b/web/helpers/atoms/Model.atom.test.ts
@@ -38,7 +38,7 @@ describe('Model.atom.ts', () => {
describe('showEngineListModelAtom', () => {
it('should initialize as an empty array', () => {
- expect(ModelAtoms.showEngineListModelAtom.init).toEqual([])
+ expect(ModelAtoms.showEngineListModelAtom.init).toEqual(['nitro'])
})
})
diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts
index 28a6384eb..c817ee74b 100644
--- a/web/helpers/atoms/Model.atom.ts
+++ b/web/helpers/atoms/Model.atom.ts
@@ -1,4 +1,4 @@
-import { ImportingModel, Model, ModelFile } from '@janhq/core'
+import { ImportingModel, InferenceEngine, Model, ModelFile } from '@janhq/core'
import { atom } from 'jotai'
export const stateModel = atom({ state: 'start', loading: false, model: '' })
@@ -133,4 +133,4 @@ export const updateImportingModelAtom = atom(
export const selectedModelAtom = atom(undefined)
-export const showEngineListModelAtom = atom([])
+export const showEngineListModelAtom = atom([InferenceEngine.nitro])
From 87a8bc7359e86416671d0d09834bbb8e4c77b776 Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Mon, 30 Sep 2024 15:30:02 +0700
Subject: [PATCH 36/37] fix: xml not render correctly (#3743)
---
.../ThreadCenterPanel/SimpleTextMessage/index.tsx | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx
index abbe6db43..da10300dc 100644
--- a/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx
@@ -53,6 +53,15 @@ const SimpleTextMessage: React.FC = (props) => {
const clipboard = useClipboard({ timeout: 1000 })
+ function escapeHtml(html: string): string {
+ return html
+ .replace(/&/g, '&')
+ .replace(//g, '>')
+ .replace(/"/g, '"')
+ .replace(/'/g, ''')
+ }
+
const marked: Marked = new Marked(
markedHighlight({
langPrefix: 'hljs',
@@ -69,6 +78,9 @@ const SimpleTextMessage: React.FC = (props) => {
}),
{
renderer: {
+ html: (html: string) => {
+ return escapeHtml(html) // Escape any HTML
+ },
link: (href, title, text) => {
return Renderer.prototype.link
?.apply(this, [href, title, text])
From 87e1754e3af92a48cb6e41fb547da20cccec821c Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 1 Oct 2024 10:15:30 +0700
Subject: [PATCH 37/37] chore: improve models and threads caching (#3744)
* chore: managing and maintaining models and threads in the cache
* test: add tests for hooks
---
web/containers/Layout/RibbonPanel/index.tsx | 6 +-
web/containers/ModelDropdown/index.tsx | 10 +-
web/helpers/atoms/BottomPanel.atom.ts | 0
web/helpers/atoms/Extension.atom.test.ts | 78 ++++++++
web/helpers/atoms/Model.atom.test.ts | 7 +-
web/helpers/atoms/Model.atom.ts | 95 ++++++++--
web/helpers/atoms/SystemBar.atom.test.ts | 146 +++++++++++++++
web/helpers/atoms/Thread.atom.test.ts | 187 +++++++++++++++++++
web/helpers/atoms/Thread.atom.ts | 157 +++++++++++-----
web/hooks/useAssistant.test.ts | 95 ++++++++++
web/hooks/useClipboard.test.ts | 105 +++++++++++
web/hooks/useDeleteModel.test.ts | 73 ++++++++
web/hooks/useDeleteThread.test.ts | 106 +++++++++++
web/hooks/useDownloadModel.test.ts | 98 ++++++++++
web/hooks/useDropModelBinaries.test.ts | 129 +++++++++++++
web/hooks/useFactoryReset.test.ts | 89 +++++++++
web/hooks/useGetHFRepoData.test.ts | 39 ++++
web/hooks/useGetSystemResources.test.ts | 103 +++++++++++
web/hooks/useGpuSetting.test.ts | 87 +++++++++
web/hooks/useImportModel.test.ts | 70 +++++++
web/hooks/useLoadTheme.test.ts | 111 +++++++++++
web/hooks/useLogs.test.ts | 103 +++++++++++
web/hooks/useModels.test.ts | 61 +++++++
web/hooks/useModels.ts | 9 +
web/hooks/useSetActiveThread.ts | 2 +-
web/hooks/useThread.test.ts | 192 ++++++++++++++++++++
web/hooks/useThreads.ts | 2 +-
web/hooks/useUpdateModelParameters.ts | 2 +-
web/types/model.d.ts | 4 +
web/utils/modelParam.ts | 2 +-
30 files changed, 2084 insertions(+), 84 deletions(-)
delete mode 100644 web/helpers/atoms/BottomPanel.atom.ts
create mode 100644 web/helpers/atoms/Extension.atom.test.ts
create mode 100644 web/helpers/atoms/SystemBar.atom.test.ts
create mode 100644 web/helpers/atoms/Thread.atom.test.ts
create mode 100644 web/hooks/useAssistant.test.ts
create mode 100644 web/hooks/useClipboard.test.ts
create mode 100644 web/hooks/useDeleteModel.test.ts
create mode 100644 web/hooks/useDeleteThread.test.ts
create mode 100644 web/hooks/useDownloadModel.test.ts
create mode 100644 web/hooks/useDropModelBinaries.test.ts
create mode 100644 web/hooks/useFactoryReset.test.ts
create mode 100644 web/hooks/useGetHFRepoData.test.ts
create mode 100644 web/hooks/useGetSystemResources.test.ts
create mode 100644 web/hooks/useGpuSetting.test.ts
create mode 100644 web/hooks/useImportModel.test.ts
create mode 100644 web/hooks/useLoadTheme.test.ts
create mode 100644 web/hooks/useLogs.test.ts
create mode 100644 web/hooks/useModels.test.ts
create mode 100644 web/hooks/useThread.test.ts
create mode 100644 web/types/model.d.ts
diff --git a/web/containers/Layout/RibbonPanel/index.tsx b/web/containers/Layout/RibbonPanel/index.tsx
index 7613584e0..2eb1bad70 100644
--- a/web/containers/Layout/RibbonPanel/index.tsx
+++ b/web/containers/Layout/RibbonPanel/index.tsx
@@ -16,14 +16,12 @@ import { mainViewStateAtom, showLeftPanelAtom } from '@/helpers/atoms/App.atom'
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
+import { isDownloadALocalModelAtom } from '@/helpers/atoms/Model.atom'
import {
reduceTransparentAtom,
selectedSettingAtom,
} from '@/helpers/atoms/Setting.atom'
-import {
- isDownloadALocalModelAtom,
- threadsAtom,
-} from '@/helpers/atoms/Thread.atom'
+import { threadsAtom } from '@/helpers/atoms/Thread.atom'
export default function RibbonPanel() {
const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom)
diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx
index 9ebcf4fa2..192c18131 100644
--- a/web/containers/ModelDropdown/index.tsx
+++ b/web/containers/ModelDropdown/index.tsx
@@ -513,7 +513,7 @@ const ModelDropdown = ({
const isDownloading = downloadingModels.some(
(md) => md.id === model.id
)
- const isdDownloaded = downloadedModels.some(
+ const isDownloaded = downloadedModels.some(
(c) => c.id === model.id
)
return (
@@ -528,7 +528,7 @@ const ModelDropdown = ({
onClick={() => {
if (!apiKey && !isLocalEngine(model.engine))
return null
- if (isdDownloaded) {
+ if (isDownloaded) {
onClickModelItem(model.id)
}
}}
@@ -537,7 +537,7 @@ const ModelDropdown = ({
- {!isdDownloaded && (
+ {!isDownloaded && (
{toGibibytes(model.metadata.size)}
)}
- {!isDownloading && !isdDownloaded ? (
+ {!isDownloading && !isDownloaded ? (
{
+ afterEach(() => {
+ jest.clearAllMocks()
+ })
+
+ describe('installingExtensionAtom', () => {
+ it('should initialize as an empty array', () => {
+ const { result } = renderHook(() => useAtomValue(ExtensionAtoms.installingExtensionAtom))
+ expect(result.current).toEqual([])
+ })
+ })
+
+ describe('setInstallingExtensionAtom', () => {
+ it('should add a new installing extension', () => {
+ const { result: setAtom } = renderHook(() => useSetAtom(ExtensionAtoms.setInstallingExtensionAtom))
+ const { result: getAtom } = renderHook(() => useAtomValue(ExtensionAtoms.installingExtensionAtom))
+
+ act(() => {
+ setAtom.current('ext1', { extensionId: 'ext1', percentage: 0 })
+ })
+
+ expect(getAtom.current).toEqual([{ extensionId: 'ext1', percentage: 0 }])
+ })
+
+ it('should update an existing installing extension', () => {
+ const { result: setAtom } = renderHook(() => useSetAtom(ExtensionAtoms.setInstallingExtensionAtom))
+ const { result: getAtom } = renderHook(() => useAtomValue(ExtensionAtoms.installingExtensionAtom))
+
+ act(() => {
+ setAtom.current('ext1', { extensionId: 'ext1', percentage: 0 })
+ setAtom.current('ext1', { extensionId: 'ext1', percentage: 50 })
+ })
+
+ expect(getAtom.current).toEqual([{ extensionId: 'ext1', percentage: 50 }])
+ })
+ })
+
+ describe('removeInstallingExtensionAtom', () => {
+ it('should remove an installing extension', () => {
+ const { result: setAtom } = renderHook(() => useSetAtom(ExtensionAtoms.setInstallingExtensionAtom))
+ const { result: removeAtom } = renderHook(() => useSetAtom(ExtensionAtoms.removeInstallingExtensionAtom))
+ const { result: getAtom } = renderHook(() => useAtomValue(ExtensionAtoms.installingExtensionAtom))
+
+ act(() => {
+ setAtom.current('ext1', { extensionId: 'ext1', percentage: 0 })
+ setAtom.current('ext2', { extensionId: 'ext2', percentage: 50 })
+ removeAtom.current('ext1')
+ })
+
+ expect(getAtom.current).toEqual([{ extensionId: 'ext2', percentage: 50 }])
+ })
+ })
+
+ describe('inActiveEngineProviderAtom', () => {
+ it('should initialize as an empty array', () => {
+ const { result } = renderHook(() => useAtomValue(ExtensionAtoms.inActiveEngineProviderAtom))
+ expect(result.current).toEqual([])
+ })
+
+ it('should persist value in storage', () => {
+ const { result } = renderHook(() => useAtom(ExtensionAtoms.inActiveEngineProviderAtom))
+
+ act(() => {
+ result.current[1](['provider1', 'provider2'])
+ })
+
+ // Simulate a re-render to check if the value persists
+ const { result: newResult } = renderHook(() => useAtomValue(ExtensionAtoms.inActiveEngineProviderAtom))
+ expect(newResult.current).toEqual(['provider1', 'provider2'])
+ })
+ })
+})
diff --git a/web/helpers/atoms/Model.atom.test.ts b/web/helpers/atoms/Model.atom.test.ts
index 4ab02cad9..57827efec 100644
--- a/web/helpers/atoms/Model.atom.test.ts
+++ b/web/helpers/atoms/Model.atom.test.ts
@@ -1,4 +1,4 @@
-import { act, renderHook, waitFor } from '@testing-library/react'
+import { act, renderHook } from '@testing-library/react'
import * as ModelAtoms from './Model.atom'
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
@@ -24,11 +24,6 @@ describe('Model.atom.ts', () => {
})
})
})
- describe('activeAssistantModelAtom', () => {
- it('should initialize as undefined', () => {
- expect(ModelAtoms.activeAssistantModelAtom.init).toBeUndefined()
- })
- })
describe('selectedModelAtom', () => {
it('should initialize as undefined', () => {
diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts
index c817ee74b..6abc42c9e 100644
--- a/web/helpers/atoms/Model.atom.ts
+++ b/web/helpers/atoms/Model.atom.ts
@@ -1,8 +1,59 @@
import { ImportingModel, InferenceEngine, Model, ModelFile } from '@janhq/core'
import { atom } from 'jotai'
+import { atomWithStorage } from 'jotai/utils'
+
+/**
+ * Enum for the keys used to store models in the local storage.
+ */
+enum ModelStorageAtomKeys {
+ DownloadedModels = 'downloadedModels',
+ AvailableModels = 'availableModels',
+}
+//// Models Atom
+/**
+ * Downloaded Models Atom
+ * This atom stores the list of models that have been downloaded.
+ */
+export const downloadedModelsAtom = atomWithStorage(
+ ModelStorageAtomKeys.DownloadedModels,
+ []
+)
+
+/**
+ * Configured Models Atom
+ * This atom stores the list of models that have been configured and available to download
+ */
+export const configuredModelsAtom = atomWithStorage(
+ ModelStorageAtomKeys.AvailableModels,
+ []
+)
+
+export const removeDownloadedModelAtom = atom(
+ null,
+ (get, set, modelId: string) => {
+ const downloadedModels = get(downloadedModelsAtom)
+
+ set(
+ downloadedModelsAtom,
+ downloadedModels.filter((e) => e.id !== modelId)
+ )
+ }
+)
+
+/**
+ * Atom to store the selected model (from ModelDropdown)
+ */
+export const selectedModelAtom = atom(undefined)
+
+/**
+ * Atom to store the expanded engine sections (from ModelDropdown)
+ */
+export const showEngineListModelAtom = atom([InferenceEngine.nitro])
+
+/// End Models Atom
+/// Model Download Atom
export const stateModel = atom({ state: 'start', loading: false, model: '' })
-export const activeAssistantModelAtom = atom(undefined)
/**
* Stores the list of models which are being downloaded.
@@ -30,28 +81,20 @@ export const removeDownloadingModelAtom = atom(
}
)
-export const downloadedModelsAtom = atom([])
-
-export const removeDownloadedModelAtom = atom(
- null,
- (get, set, modelId: string) => {
- const downloadedModels = get(downloadedModelsAtom)
-
- set(
- downloadedModelsAtom,
- downloadedModels.filter((e) => e.id !== modelId)
- )
- }
-)
-
-export const configuredModelsAtom = atom([])
-
-export const defaultModelAtom = atom(undefined)
+/// End Model Download Atom
+/// Model Import Atom
/// TODO: move this part to another atom
// store the paths of the models that are being imported
export const importingModelsAtom = atom([])
+// DEPRECATED: Remove when moving to cortex.cpp
+// Default model template when importing
+export const defaultModelAtom = atom(undefined)
+
+/**
+ * Importing progress Atom
+ */
export const updateImportingModelProgressAtom = atom(
null,
(get, set, importId: string, percentage: number) => {
@@ -69,6 +112,9 @@ export const updateImportingModelProgressAtom = atom(
}
)
+/**
+ * Importing error Atom
+ */
export const setImportingModelErrorAtom = atom(
null,
(get, set, importId: string, error: string) => {
@@ -87,6 +133,9 @@ export const setImportingModelErrorAtom = atom(
}
)
+/**
+ * Importing success Atom
+ */
export const setImportingModelSuccessAtom = atom(
null,
(get, set, importId: string, modelId: string) => {
@@ -105,6 +154,9 @@ export const setImportingModelSuccessAtom = atom(
}
)
+/**
+ * Update importing model metadata Atom
+ */
export const updateImportingModelAtom = atom(
null,
(
@@ -131,6 +183,9 @@ export const updateImportingModelAtom = atom(
}
)
-export const selectedModelAtom = atom(undefined)
+/// End Model Import Atom
-export const showEngineListModelAtom = atom([InferenceEngine.nitro])
+/// ModelDropdown States Atom
+export const isDownloadALocalModelAtom = atom(false)
+export const isAnyRemoteModelConfiguredAtom = atom(false)
+/// End ModelDropdown States Atom
diff --git a/web/helpers/atoms/SystemBar.atom.test.ts b/web/helpers/atoms/SystemBar.atom.test.ts
new file mode 100644
index 000000000..57a7c2ada
--- /dev/null
+++ b/web/helpers/atoms/SystemBar.atom.test.ts
@@ -0,0 +1,146 @@
+import { renderHook, act } from '@testing-library/react'
+import { useAtom } from 'jotai'
+import * as SystemBarAtoms from './SystemBar.atom'
+
+describe('SystemBar.atom.ts', () => {
+ afterEach(() => {
+ jest.clearAllMocks()
+ })
+
+ describe('totalRamAtom', () => {
+ it('should initialize as 0', () => {
+ const { result } = renderHook(() => useAtom(SystemBarAtoms.totalRamAtom))
+ expect(result.current[0]).toBe(0)
+ })
+
+ it('should update correctly', () => {
+ const { result } = renderHook(() => useAtom(SystemBarAtoms.totalRamAtom))
+ act(() => {
+ result.current[1](16384)
+ })
+ expect(result.current[0]).toBe(16384)
+ })
+ })
+
+ describe('usedRamAtom', () => {
+ it('should initialize as 0', () => {
+ const { result } = renderHook(() => useAtom(SystemBarAtoms.usedRamAtom))
+ expect(result.current[0]).toBe(0)
+ })
+
+ it('should update correctly', () => {
+ const { result } = renderHook(() => useAtom(SystemBarAtoms.usedRamAtom))
+ act(() => {
+ result.current[1](8192)
+ })
+ expect(result.current[0]).toBe(8192)
+ })
+ })
+
+ describe('cpuUsageAtom', () => {
+ it('should initialize as 0', () => {
+ const { result } = renderHook(() => useAtom(SystemBarAtoms.cpuUsageAtom))
+ expect(result.current[0]).toBe(0)
+ })
+
+ it('should update correctly', () => {
+ const { result } = renderHook(() => useAtom(SystemBarAtoms.cpuUsageAtom))
+ act(() => {
+ result.current[1](50)
+ })
+ expect(result.current[0]).toBe(50)
+ })
+ })
+
+ describe('ramUtilitizedAtom', () => {
+ it('should initialize as 0', () => {
+ const { result } = renderHook(() =>
+ useAtom(SystemBarAtoms.ramUtilitizedAtom)
+ )
+ expect(result.current[0]).toBe(0)
+ })
+
+ it('should update correctly', () => {
+ const { result } = renderHook(() =>
+ useAtom(SystemBarAtoms.ramUtilitizedAtom)
+ )
+ act(() => {
+ result.current[1](75)
+ })
+ expect(result.current[0]).toBe(75)
+ })
+ })
+
+ describe('gpusAtom', () => {
+ it('should initialize as an empty array', () => {
+ const { result } = renderHook(() => useAtom(SystemBarAtoms.gpusAtom))
+ expect(result.current[0]).toEqual([])
+ })
+
+ it('should update correctly', () => {
+ const { result } = renderHook(() => useAtom(SystemBarAtoms.gpusAtom))
+ const gpus = [{ id: 'gpu1' }, { id: 'gpu2' }]
+ act(() => {
+ result.current[1](gpus as any)
+ })
+ expect(result.current[0]).toEqual(gpus)
+ })
+ })
+
+ describe('nvidiaTotalVramAtom', () => {
+ it('should initialize as 0', () => {
+ const { result } = renderHook(() =>
+ useAtom(SystemBarAtoms.nvidiaTotalVramAtom)
+ )
+ expect(result.current[0]).toBe(0)
+ })
+
+ it('should update correctly', () => {
+ const { result } = renderHook(() =>
+ useAtom(SystemBarAtoms.nvidiaTotalVramAtom)
+ )
+ act(() => {
+ result.current[1](8192)
+ })
+ expect(result.current[0]).toBe(8192)
+ })
+ })
+
+ describe('availableVramAtom', () => {
+ it('should initialize as 0', () => {
+ const { result } = renderHook(() =>
+ useAtom(SystemBarAtoms.availableVramAtom)
+ )
+ expect(result.current[0]).toBe(0)
+ })
+
+ it('should update correctly', () => {
+ const { result } = renderHook(() =>
+ useAtom(SystemBarAtoms.availableVramAtom)
+ )
+ act(() => {
+ result.current[1](4096)
+ })
+ expect(result.current[0]).toBe(4096)
+ })
+ })
+
+ describe('systemMonitorCollapseAtom', () => {
+ it('should initialize as false', () => {
+ const { result } = renderHook(() =>
+ useAtom(SystemBarAtoms.systemMonitorCollapseAtom)
+ )
+ expect(result.current[0]).toBe(false)
+ })
+
+ it('should update correctly', () => {
+ const { result } = renderHook(() =>
+ useAtom(SystemBarAtoms.systemMonitorCollapseAtom)
+ )
+ act(() => {
+ result.current[1](true)
+ })
+ expect(result.current[0]).toBe(true)
+ })
+ })
+})
diff --git a/web/helpers/atoms/Thread.atom.test.ts b/web/helpers/atoms/Thread.atom.test.ts
new file mode 100644
index 000000000..cc88dd66e
--- /dev/null
+++ b/web/helpers/atoms/Thread.atom.test.ts
@@ -0,0 +1,187 @@
+// Thread.atom.test.ts
+
+import { act, renderHook } from '@testing-library/react'
+import * as ThreadAtoms from './Thread.atom'
+import { useAtom, useAtomValue, useSetAtom } from 'jotai'
+
+describe('Thread.atom.ts', () => {
+ afterEach(() => {
+ jest.clearAllMocks()
+ })
+
+ describe('threadStatesAtom', () => {
+ it('should initialize as an empty object', () => {
+ const { result: threadStatesAtom } = renderHook(() =>
+ useAtom(ThreadAtoms.threadsAtom)
+ )
+ expect(threadStatesAtom.current[0]).toEqual([])
+ })
+ })
+
+ describe('threadsAtom', () => {
+ it('should initialize as an empty array', () => {
+ const { result: threadsAtom } = renderHook(() =>
+ useAtom(ThreadAtoms.threadsAtom)
+ )
+ expect(threadsAtom.current[0]).toEqual([])
+ })
+ })
+
+ describe('threadDataReadyAtom', () => {
+ it('should initialize as false', () => {
+ const { result: threadDataReadyAtom } = renderHook(() =>
+ useAtom(ThreadAtoms.threadsAtom)
+ )
+ expect(threadDataReadyAtom.current[0]).toEqual([])
+ })
+ })
+
+ describe('activeThreadIdAtom', () => {
+ it('should set and get active thread id', () => {
+ const { result: getAtom } = renderHook(() =>
+ useAtomValue(ThreadAtoms.getActiveThreadIdAtom)
+ )
+ const { result: setAtom } = renderHook(() =>
+ useSetAtom(ThreadAtoms.setActiveThreadIdAtom)
+ )
+
+ expect(getAtom.current).toBeUndefined()
+
+ act(() => {
+ setAtom.current('thread-1')
+ })
+
+ expect(getAtom.current).toBe('thread-1')
+ })
+ })
+
+ describe('activeThreadAtom', () => {
+ it('should return the active thread', () => {
+ const { result: threadsAtom } = renderHook(() =>
+ useAtom(ThreadAtoms.threadsAtom)
+ )
+ const { result: setActiveThreadId } = renderHook(() =>
+ useSetAtom(ThreadAtoms.setActiveThreadIdAtom)
+ )
+ const { result: activeThread } = renderHook(() =>
+ useAtomValue(ThreadAtoms.activeThreadAtom)
+ )
+
+ act(() => {
+ threadsAtom.current[1]([
+ { id: 'thread-1', title: 'Test Thread' },
+ ] as any)
+ setActiveThreadId.current('thread-1')
+ })
+
+ expect(activeThread.current).toEqual({
+ id: 'thread-1',
+ title: 'Test Thread',
+ })
+ })
+ })
+
+ describe('updateThreadAtom', () => {
+ it('should update an existing thread', () => {
+ const { result: threadsAtom } = renderHook(() =>
+ useAtom(ThreadAtoms.threadsAtom)
+ )
+ const { result: updateThread } = renderHook(() =>
+ useSetAtom(ThreadAtoms.updateThreadAtom)
+ )
+
+ act(() => {
+ threadsAtom.current[1]([
+ {
+ id: 'thread-1',
+ title: 'Old Title',
+ updated: new Date('2023-01-01').toISOString(),
+ },
+ {
+ id: 'thread-2',
+ title: 'Thread 2',
+ updated: new Date('2023-01-02').toISOString(),
+ },
+ ] as any)
+ })
+
+ act(() => {
+ updateThread.current({
+ id: 'thread-1',
+ title: 'New Title',
+ updated: new Date('2023-01-03').toISOString(),
+ } as any)
+ })
+
+ expect(threadsAtom.current[0]).toEqual([
+ {
+ id: 'thread-1',
+ title: 'New Title',
+ updated: new Date('2023-01-03').toISOString(),
+ },
+ {
+ id: 'thread-2',
+ title: 'Thread 2',
+ updated: new Date('2023-01-02').toISOString(),
+ },
+ ])
+ })
+ })
+
+ describe('setThreadModelParamsAtom', () => {
+ it('should set thread model params', () => {
+ const { result: paramsAtom } = renderHook(() =>
+ useAtom(ThreadAtoms.threadModelParamsAtom)
+ )
+ const { result: setParams } = renderHook(() =>
+ useSetAtom(ThreadAtoms.setThreadModelParamsAtom)
+ )
+
+ act(() => {
+ setParams.current('thread-1', { modelName: 'gpt-3' } as any)
+ })
+
+ expect(paramsAtom.current[0]).toEqual({
+ 'thread-1': { modelName: 'gpt-3' },
+ })
+ })
+ })
+
+ describe('deleteThreadStateAtom', () => {
+ it('should delete a thread state', () => {
+ const { result: statesAtom } = renderHook(() =>
+ useAtom(ThreadAtoms.threadStatesAtom)
+ )
+ const { result: deleteState } = renderHook(() =>
+ useSetAtom(ThreadAtoms.deleteThreadStateAtom)
+ )
+
+ act(() => {
+ statesAtom.current[1]({
+ 'thread-1': { lastMessage: 'Hello' },
+ 'thread-2': { lastMessage: 'Hi' },
+ } as any)
+ })
+
+ act(() => {
+ deleteState.current('thread-1')
+ })
+
+ expect(statesAtom.current[0]).toEqual({
+ 'thread-2': { lastMessage: 'Hi' },
+ })
+ })
+ })
+
+ describe('modalActionThreadAtom', () => {
+ it('should initialize with undefined values', () => {
+ const { result } = renderHook(() =>
+ useAtomValue(ThreadAtoms.modalActionThreadAtom)
+ )
+ expect(result.current).toEqual({
+ showModal: undefined,
+ thread: undefined,
+ })
+ })
+ })
+})
diff --git a/web/helpers/atoms/Thread.atom.ts b/web/helpers/atoms/Thread.atom.ts
index 6e94c9e17..1945fea45 100644
--- a/web/helpers/atoms/Thread.atom.ts
+++ b/web/helpers/atoms/Thread.atom.ts
@@ -1,45 +1,91 @@
-import {
- ModelRuntimeParams,
- ModelSettingParams,
- Thread,
- ThreadContent,
- ThreadState,
-} from '@janhq/core'
+import { Thread, ThreadContent, ThreadState } from '@janhq/core'
import { atom } from 'jotai'
import { atomWithStorage } from 'jotai/utils'
+import { ModelParams } from '@/types/model'
+
+/**
+ * Thread Modal Action Enum
+ */
export enum ThreadModalAction {
Clean = 'clean',
Delete = 'delete',
EditTitle = 'edit-title',
}
-export const engineParamsUpdateAtom = atom(false)
+const ACTIVE_SETTING_INPUT_BOX = 'activeSettingInputBox'
+/**
+ * Enum for the keys used to store models in the local storage.
+ */
+enum ThreadStorageAtomKeys {
+ ThreadStates = 'threadStates',
+ ThreadList = 'threadList',
+ ThreadListReady = 'threadListReady',
+}
+
+//// Threads Atom
+/**
+ * Stores all thread states for the current user
+ */
+export const threadStatesAtom = atomWithStorage>(
+ ThreadStorageAtomKeys.ThreadStates,
+ {}
+)
+
+/**
+ * Stores all threads for the current user
+ */
+export const threadsAtom = atomWithStorage(
+ ThreadStorageAtomKeys.ThreadList,
+ []
+)
+
+/**
+ * Whether thread data is ready or not
+ * */
+export const threadDataReadyAtom = atomWithStorage(
+ ThreadStorageAtomKeys.ThreadListReady,
+ false
+)
+
+/**
+ * Store model params at thread level settings
+ */
+export const threadModelParamsAtom = atom>({})
+
+//// End Thread Atom
+
+/// Active Thread Atom
/**
* Stores the current active thread id.
*/
const activeThreadIdAtom = atom(undefined)
+/**
+ * Get the active thread id
+ */
export const getActiveThreadIdAtom = atom((get) => get(activeThreadIdAtom))
+/**
+ * Set the active thread id
+ */
export const setActiveThreadIdAtom = atom(
null,
(_get, set, threadId: string | undefined) => set(activeThreadIdAtom, threadId)
)
-export const waitingToSendMessage = atom(undefined)
-
-export const isGeneratingResponseAtom = atom(undefined)
/**
- * Stores all thread states for the current user
+ * Get the current active thread metadata
*/
-export const threadStatesAtom = atom>({})
-
-// Whether thread data is ready or not
-export const threadDataReadyAtom = atom(false)
+export const activeThreadAtom = atom((get) =>
+ get(threadsAtom).find((c) => c.id === get(getActiveThreadIdAtom))
+)
+/**
+ * Get the active thread state
+ */
export const activeThreadStateAtom = atom((get) => {
const threadId = get(activeThreadIdAtom)
if (!threadId) {
@@ -50,6 +96,38 @@ export const activeThreadStateAtom = atom((get) => {
return get(threadStatesAtom)[threadId]
})
+/**
+ * Get the active thread model params
+ */
+export const getActiveThreadModelParamsAtom = atom(
+ (get) => {
+ const threadId = get(activeThreadIdAtom)
+ if (!threadId) {
+ console.debug('Active thread id is undefined')
+ return undefined
+ }
+
+ return get(threadModelParamsAtom)[threadId]
+ }
+)
+/// End Active Thread Atom
+
+/// Threads State Atom
+export const engineParamsUpdateAtom = atom(false)
+
+/**
+ * Whether the thread is waiting to send a message
+ */
+export const waitingToSendMessage = atom(undefined)
+
+/**
+ * Whether the thread is generating a response
+ */
+export const isGeneratingResponseAtom = atom(undefined)
+
+/**
+ * Remove a thread state from the atom
+ */
export const deleteThreadStateAtom = atom(
null,
(get, set, threadId: string) => {
@@ -59,6 +137,9 @@ export const deleteThreadStateAtom = atom(
}
)
+/**
+ * Update the thread state with the new state
+ */
export const updateThreadWaitingForResponseAtom = atom(
null,
(get, set, threadId: string, waitingForResponse: boolean) => {
@@ -71,6 +152,10 @@ export const updateThreadWaitingForResponseAtom = atom(
set(threadStatesAtom, currentState)
}
)
+
+/**
+ * Update the thread last message
+ */
export const updateThreadStateLastMessageAtom = atom(
null,
(get, set, threadId: string, lastContent?: ThreadContent[]) => {
@@ -84,6 +169,9 @@ export const updateThreadStateLastMessageAtom = atom(
}
)
+/**
+ * Update a thread with the new thread metadata
+ */
export const updateThreadAtom = atom(
null,
(get, set, updatedThread: Thread) => {
@@ -103,33 +191,8 @@ export const updateThreadAtom = atom(
)
/**
- * Stores all threads for the current user
+ * Update the thread model params
*/
-export const threadsAtom = atom([])
-
-export const activeThreadAtom = atom((get) =>
- get(threadsAtom).find((c) => c.id === get(getActiveThreadIdAtom))
-)
-
-/**
- * Store model params at thread level settings
- */
-export const threadModelParamsAtom = atom>({})
-
-export type ModelParams = ModelRuntimeParams | ModelSettingParams
-
-export const getActiveThreadModelParamsAtom = atom(
- (get) => {
- const threadId = get(activeThreadIdAtom)
- if (!threadId) {
- console.debug('Active thread id is undefined')
- return undefined
- }
-
- return get(threadModelParamsAtom)[threadId]
- }
-)
-
export const setThreadModelParamsAtom = atom(
null,
(get, set, threadId: string, params: ModelParams) => {
@@ -139,12 +202,17 @@ export const setThreadModelParamsAtom = atom(
}
)
-const ACTIVE_SETTING_INPUT_BOX = 'activeSettingInputBox'
+/**
+ * Settings input box active state
+ */
export const activeSettingInputBoxAtom = atomWithStorage(
ACTIVE_SETTING_INPUT_BOX,
false
)
+/**
+ * Whether thread thread is presenting a Modal or not
+ */
export const modalActionThreadAtom = atom<{
showModal: ThreadModalAction | undefined
thread: Thread | undefined
@@ -153,5 +221,4 @@ export const modalActionThreadAtom = atom<{
thread: undefined,
})
-export const isDownloadALocalModelAtom = atom(false)
-export const isAnyRemoteModelConfiguredAtom = atom(false)
+/// Ebd Threads State Atom
diff --git a/web/hooks/useAssistant.test.ts b/web/hooks/useAssistant.test.ts
new file mode 100644
index 000000000..e029bb7f6
--- /dev/null
+++ b/web/hooks/useAssistant.test.ts
@@ -0,0 +1,95 @@
+import { renderHook, act } from '@testing-library/react'
+import { useSetAtom } from 'jotai'
+import { events, AssistantEvent, ExtensionTypeEnum } from '@janhq/core'
+
+// Mock dependencies
+jest.mock('jotai', () => ({
+ useAtomValue: jest.fn(),
+ useSetAtom: jest.fn(),
+ useAtom: jest.fn(),
+ atom: jest.fn(),
+}))
+jest.mock('@janhq/core')
+jest.mock('@/extension')
+
+import useAssistants from './useAssistants'
+import { extensionManager } from '@/extension'
+
+// Mock data
+const mockAssistants = [
+ { id: 'assistant-1', name: 'Assistant 1' },
+ { id: 'assistant-2', name: 'Assistant 2' },
+]
+
+const mockAssistantExtension = {
+ getAssistants: jest.fn().mockResolvedValue(mockAssistants),
+} as any
+
+describe('useAssistants', () => {
+ beforeEach(() => {
+ jest.clearAllMocks()
+ jest.spyOn(extensionManager, 'get').mockReturnValue(mockAssistantExtension)
+ })
+
+ it('should fetch and set assistants on mount', async () => {
+ const mockSetAssistants = jest.fn()
+ ;(useSetAtom as jest.Mock).mockReturnValue(mockSetAssistants)
+
+ renderHook(() => useAssistants())
+
+ // Wait for useEffect to complete
+ await act(async () => {})
+
+ expect(mockAssistantExtension.getAssistants).toHaveBeenCalled()
+ expect(mockSetAssistants).toHaveBeenCalledWith(mockAssistants)
+ })
+
+ it('should update assistants when AssistantEvent.OnAssistantsUpdate is emitted', async () => {
+ const mockSetAssistants = jest.fn()
+ ;(useSetAtom as jest.Mock).mockReturnValue(mockSetAssistants)
+
+ renderHook(() => useAssistants())
+
+ // Wait for initial useEffect to complete
+ await act(async () => {})
+
+ // Clear previous calls
+ mockSetAssistants.mockClear()
+
+ // Simulate AssistantEvent.OnAssistantsUpdate event
+ await act(async () => {
+ events.emit(AssistantEvent.OnAssistantsUpdate, '')
+ })
+
+ expect(mockAssistantExtension.getAssistants).toHaveBeenCalledTimes(1)
+ })
+
+ it('should unsubscribe from events on unmount', async () => {
+ const { unmount } = renderHook(() => useAssistants())
+
+ // Wait for useEffect to complete
+ await act(async () => {})
+
+ const offSpy = jest.spyOn(events, 'off')
+
+ unmount()
+
+ expect(offSpy).toHaveBeenCalledWith(
+ AssistantEvent.OnAssistantsUpdate,
+ expect.any(Function)
+ )
+ })
+
+ it('should handle case when AssistantExtension is not available', async () => {
+ const mockSetAssistants = jest.fn()
+ ;(useSetAtom as jest.Mock).mockReturnValue(mockSetAssistants)
+ ;(extensionManager.get as jest.Mock).mockReturnValue(undefined)
+
+ renderHook(() => useAssistants())
+
+ // Wait for useEffect to complete
+ await act(async () => {})
+
+ expect(mockSetAssistants).toHaveBeenCalledWith([])
+ })
+})
diff --git a/web/hooks/useClipboard.test.ts b/web/hooks/useClipboard.test.ts
new file mode 100644
index 000000000..a79f8132b
--- /dev/null
+++ b/web/hooks/useClipboard.test.ts
@@ -0,0 +1,105 @@
+import { renderHook, act } from '@testing-library/react'
+import { useClipboard } from './useClipboard'
+
+describe('useClipboard', () => {
+ let originalClipboard: any
+
+ beforeAll(() => {
+ originalClipboard = { ...global.navigator.clipboard }
+ const mockClipboard = {
+ writeText: jest.fn(() => Promise.resolve()),
+ }
+ // @ts-ignore
+ global.navigator.clipboard = mockClipboard
+ })
+
+ afterAll(() => {
+ // @ts-ignore
+ global.navigator.clipboard = originalClipboard
+ })
+
+ beforeEach(() => {
+ jest.useFakeTimers()
+ })
+
+ afterEach(() => {
+ jest.clearAllTimers()
+ jest.useRealTimers()
+ })
+
+ it('should copy text to clipboard', async () => {
+ const { result } = renderHook(() => useClipboard())
+
+ await act(async () => {
+ result.current.copy('Test text')
+ })
+
+ expect(navigator.clipboard.writeText).toHaveBeenCalledWith('Test text')
+ expect(result.current.copied).toBe(true)
+ expect(result.current.error).toBe(null)
+ })
+
+ it('should set copied to false after timeout', async () => {
+ const { result } = renderHook(() => useClipboard({ timeout: 1000 }))
+
+ await act(async () => {
+ result.current.copy('Test text')
+ })
+
+ expect(result.current.copied).toBe(true)
+
+ act(() => {
+ jest.advanceTimersByTime(1000)
+ })
+
+ expect(result.current.copied).toBe(false)
+ })
+
+ it('should handle clipboard errors', async () => {
+ const mockError = new Error('Clipboard error')
+ // @ts-ignore
+ navigator.clipboard.writeText.mockRejectedValueOnce(mockError)
+
+ const { result } = renderHook(() => useClipboard())
+
+ await act(async () => {
+ result.current.copy('Test text')
+ })
+
+ expect(result.current.error).toEqual(mockError)
+ expect(result.current.copied).toBe(false)
+ })
+
+ it('should reset state', async () => {
+ const { result } = renderHook(() => useClipboard())
+
+ await act(async () => {
+ result.current.copy('Test text')
+ })
+
+ expect(result.current.copied).toBe(true)
+
+ act(() => {
+ result.current.reset()
+ })
+
+ expect(result.current.copied).toBe(false)
+ expect(result.current.error).toBe(null)
+ })
+
+ it('should handle missing clipboard API', () => {
+ // @ts-ignore
+ delete global.navigator.clipboard
+
+ const { result } = renderHook(() => useClipboard())
+
+ act(() => {
+ result.current.copy('Test text')
+ })
+
+ expect(result.current.error).toEqual(
+ new Error('useClipboard: navigator.clipboard is not supported')
+ )
+ expect(result.current.copied).toBe(false)
+ })
+})
diff --git a/web/hooks/useDeleteModel.test.ts b/web/hooks/useDeleteModel.test.ts
new file mode 100644
index 000000000..336a1cd0c
--- /dev/null
+++ b/web/hooks/useDeleteModel.test.ts
@@ -0,0 +1,73 @@
+import { renderHook, act } from '@testing-library/react'
+import { extensionManager } from '@/extension/ExtensionManager'
+import useDeleteModel from './useDeleteModel'
+import { toaster } from '@/containers/Toast'
+import { useSetAtom } from 'jotai'
+
+// Mock the dependencies
+jest.mock('@/extension/ExtensionManager')
+jest.mock('@/containers/Toast')
+jest.mock('jotai', () => ({
+ useSetAtom: jest.fn(() => jest.fn()),
+ atom: jest.fn(),
+}))
+
+describe('useDeleteModel', () => {
+ const mockModel: any = {
+ id: 'test-model',
+ name: 'Test Model',
+ // Add other required properties of ModelFile
+ }
+
+ const mockDeleteModel = jest.fn()
+
+ beforeEach(() => {
+ jest.clearAllMocks()
+ ;(extensionManager.get as jest.Mock).mockReturnValue({
+ deleteModel: mockDeleteModel,
+ })
+ })
+
+ it('should delete a model successfully', async () => {
+ const { result } = renderHook(() => useDeleteModel())
+
+ await act(async () => {
+ await result.current.deleteModel(mockModel)
+ })
+
+ expect(mockDeleteModel).toHaveBeenCalledWith(mockModel)
+ expect(toaster).toHaveBeenCalledWith({
+ title: 'Model Deletion Successful',
+ description: `Model ${mockModel.name} has been successfully deleted.`,
+ type: 'success',
+ })
+ })
+
+ it('should call removeDownloadedModel with the model id', async () => {
+ const { result } = renderHook(() => useDeleteModel())
+
+ await act(async () => {
+ await result.current.deleteModel(mockModel)
+ })
+
+ // Assuming useSetAtom returns a mock function
+ ;(useSetAtom as jest.Mock).mockReturnValue(jest.fn())
+ expect(useSetAtom).toHaveBeenCalled()
+ })
+
+ it('should handle errors during model deletion', async () => {
+ const error = new Error('Deletion failed')
+ mockDeleteModel.mockRejectedValue(error)
+
+ const { result } = renderHook(() => useDeleteModel())
+
+ await act(async () => {
+ await expect(result.current.deleteModel(mockModel)).rejects.toThrow(
+ 'Deletion failed'
+ )
+ })
+
+ expect(mockDeleteModel).toHaveBeenCalledWith(mockModel)
+ expect(toaster).not.toHaveBeenCalled()
+ })
+})
diff --git a/web/hooks/useDeleteThread.test.ts b/web/hooks/useDeleteThread.test.ts
new file mode 100644
index 000000000..d3a6138d0
--- /dev/null
+++ b/web/hooks/useDeleteThread.test.ts
@@ -0,0 +1,106 @@
+import { renderHook, act } from '@testing-library/react'
+import { useAtom, useAtomValue, useSetAtom } from 'jotai'
+import useDeleteThread from './useDeleteThread'
+import { extensionManager } from '@/extension/ExtensionManager'
+import { toaster } from '@/containers/Toast'
+
+// Mock the necessary dependencies
+// Mock dependencies
+jest.mock('jotai', () => ({
+ useAtomValue: jest.fn(),
+ useSetAtom: jest.fn(),
+ useAtom: jest.fn(),
+ atom: jest.fn(),
+}))
+jest.mock('@/extension/ExtensionManager')
+jest.mock('@/containers/Toast')
+
+describe('useDeleteThread', () => {
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('should delete a thread successfully', async () => {
+ const mockThreads = [
+ { id: 'thread1', title: 'Thread 1' },
+ { id: 'thread2', title: 'Thread 2' },
+ ]
+ const mockSetThreads = jest.fn()
+ ;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads])
+
+ const mockDeleteThread = jest.fn()
+ extensionManager.get = jest.fn().mockReturnValue({
+ deleteThread: mockDeleteThread,
+ })
+
+ const { result } = renderHook(() => useDeleteThread())
+
+ await act(async () => {
+ await result.current.deleteThread('thread1')
+ })
+
+ expect(mockDeleteThread).toHaveBeenCalledWith('thread1')
+ expect(mockSetThreads).toHaveBeenCalledWith([mockThreads[1]])
+ })
+
+ it('should clean a thread successfully', async () => {
+ const mockThreads = [{ id: 'thread1', title: 'Thread 1', metadata: {} }]
+ const mockSetThreads = jest.fn()
+ ;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads])
+ const mockCleanMessages = jest.fn()
+ ;(useSetAtom as jest.Mock).mockReturnValue(() => mockCleanMessages)
+ ;(useAtomValue as jest.Mock).mockReturnValue(['thread 1'])
+
+ const mockWriteMessages = jest.fn()
+ const mockSaveThread = jest.fn()
+ extensionManager.get = jest.fn().mockReturnValue({
+ writeMessages: mockWriteMessages,
+ saveThread: mockSaveThread,
+ })
+
+ const { result } = renderHook(() => useDeleteThread())
+
+ await act(async () => {
+ await result.current.cleanThread('thread1')
+ })
+
+ expect(mockWriteMessages).toHaveBeenCalled()
+ expect(mockSaveThread).toHaveBeenCalledWith(
+ expect.objectContaining({
+ id: 'thread1',
+ title: 'New Thread',
+ metadata: expect.objectContaining({ lastMessage: undefined }),
+ })
+ )
+ })
+
+ it('should handle errors when deleting a thread', async () => {
+ const mockThreads = [{ id: 'thread1', title: 'Thread 1' }]
+ const mockSetThreads = jest.fn()
+ ;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads])
+
+ const mockDeleteThread = jest
+ .fn()
+ .mockRejectedValue(new Error('Delete error'))
+ extensionManager.get = jest.fn().mockReturnValue({
+ deleteThread: mockDeleteThread,
+ })
+
+ const consoleErrorSpy = jest
+ .spyOn(console, 'error')
+ .mockImplementation(() => {})
+
+ const { result } = renderHook(() => useDeleteThread())
+
+ await act(async () => {
+ await result.current.deleteThread('thread1')
+ })
+
+ expect(mockDeleteThread).toHaveBeenCalledWith('thread1')
+ expect(consoleErrorSpy).toHaveBeenCalledWith(expect.any(Error))
+ expect(mockSetThreads).not.toHaveBeenCalled()
+ expect(toaster).not.toHaveBeenCalled()
+
+ consoleErrorSpy.mockRestore()
+ })
+})
diff --git a/web/hooks/useDownloadModel.test.ts b/web/hooks/useDownloadModel.test.ts
new file mode 100644
index 000000000..fc0b7c21f
--- /dev/null
+++ b/web/hooks/useDownloadModel.test.ts
@@ -0,0 +1,98 @@
+import { renderHook, act } from '@testing-library/react'
+import { useAtom, useSetAtom } from 'jotai'
+import useDownloadModel from './useDownloadModel'
+import * as core from '@janhq/core'
+import { extensionManager } from '@/extension/ExtensionManager'
+
+// Mock the necessary dependencies
+jest.mock('jotai', () => ({
+ useAtomValue: jest.fn(),
+ useSetAtom: jest.fn(),
+ useAtom: jest.fn(),
+ atom: jest.fn(),
+}))
+jest.mock('@janhq/core')
+jest.mock('@/extension/ExtensionManager')
+jest.mock('./useGpuSetting', () => ({
+ __esModule: true,
+ default: () => ({
+ getGpuSettings: jest.fn().mockResolvedValue({ some: 'gpuSettings' }),
+ }),
+}))
+
+describe('useDownloadModel', () => {
+ beforeEach(() => {
+ jest.clearAllMocks()
+ ;(useAtom as jest.Mock).mockReturnValue([false, jest.fn()])
+ })
+
+ it('should download a model', async () => {
+ const mockModel: core.Model = {
+ id: 'test-model',
+ sources: [{ filename: 'test.bin' }],
+ } as core.Model
+
+ const mockExtension = {
+ downloadModel: jest.fn().mockResolvedValue(undefined),
+ }
+ ;(useSetAtom as jest.Mock).mockReturnValue(() => undefined)
+ ;(extensionManager.get as jest.Mock).mockReturnValue(mockExtension)
+
+ const { result } = renderHook(() => useDownloadModel())
+
+ await act(async () => {
+ await result.current.downloadModel(mockModel)
+ })
+
+ expect(mockExtension.downloadModel).toHaveBeenCalledWith(
+ mockModel,
+ { some: 'gpuSettings' },
+ { ignoreSSL: undefined, proxy: '' }
+ )
+ })
+
+ it('should abort model download', async () => {
+ const mockModel: core.Model = {
+ id: 'test-model',
+ sources: [{ filename: 'test.bin' }],
+ } as core.Model
+
+ ;(core.joinPath as jest.Mock).mockResolvedValue('/path/to/model/test.bin')
+ ;(core.abortDownload as jest.Mock).mockResolvedValue(undefined)
+ ;(useSetAtom as jest.Mock).mockReturnValue(() => undefined)
+ const { result } = renderHook(() => useDownloadModel())
+
+ await act(async () => {
+ await result.current.abortModelDownload(mockModel)
+ })
+
+ expect(core.abortDownload).toHaveBeenCalledWith('/path/to/model/test.bin')
+ })
+
+ it('should handle proxy settings', async () => {
+ const mockModel: core.Model = {
+ id: 'test-model',
+ sources: [{ filename: 'test.bin' }],
+ } as core.Model
+
+ const mockExtension = {
+ downloadModel: jest.fn().mockResolvedValue(undefined),
+ }
+ ;(useSetAtom as jest.Mock).mockReturnValue(() => undefined)
+ ;(extensionManager.get as jest.Mock).mockReturnValue(mockExtension)
+ ;(useAtom as jest.Mock).mockReturnValueOnce([true, jest.fn()]) // proxyEnabled
+ ;(useAtom as jest.Mock).mockReturnValueOnce(['http://proxy.com', jest.fn()]) // proxy
+
+ const { result } = renderHook(() => useDownloadModel())
+
+ await act(async () => {
+ await result.current.downloadModel(mockModel)
+ })
+
+ expect(mockExtension.downloadModel).toHaveBeenCalledWith(
+ mockModel,
+ expect.objectContaining({ some: 'gpuSettings' }),
+ expect.anything()
+ )
+ })
+})
diff --git a/web/hooks/useDropModelBinaries.test.ts b/web/hooks/useDropModelBinaries.test.ts
new file mode 100644
index 000000000..dad8c6178
--- /dev/null
+++ b/web/hooks/useDropModelBinaries.test.ts
@@ -0,0 +1,129 @@
+// useDropModelBinaries.test.ts
+
+import { renderHook, act } from '@testing-library/react'
+import { useSetAtom } from 'jotai'
+import { v4 as uuidv4 } from 'uuid'
+import useDropModelBinaries from './useDropModelBinaries'
+import { getFileInfoFromFile } from '@/utils/file'
+import { snackbar } from '@/containers/Toast'
+
+// Mock dependencies
+// Mock the necessary dependencies
+jest.mock('jotai', () => ({
+ useAtomValue: jest.fn(),
+ useSetAtom: jest.fn(),
+ useAtom: jest.fn(),
+ atom: jest.fn(),
+}))
+jest.mock('uuid')
+jest.mock('@/utils/file')
+jest.mock('@/containers/Toast')
+
+describe('useDropModelBinaries', () => {
+ const mockSetImportingModels = jest.fn()
+ const mockSetImportModelStage = jest.fn()
+
+ beforeEach(() => {
+ jest.clearAllMocks()
+ ;(useSetAtom as jest.Mock).mockReturnValueOnce(mockSetImportingModels)
+ ;(useSetAtom as jest.Mock).mockReturnValueOnce(mockSetImportModelStage)
+ ;(uuidv4 as jest.Mock).mockReturnValue('mock-uuid')
+ ;(getFileInfoFromFile as jest.Mock).mockResolvedValue([])
+ })
+
+ it('should handle dropping supported files', async () => {
+ const { result } = renderHook(() => useDropModelBinaries())
+
+ const mockFiles = [
+ { name: 'model1.gguf', path: '/path/to/model1.gguf', size: 1000 },
+ { name: 'model2.gguf', path: '/path/to/model2.gguf', size: 2000 },
+ ]
+
+ ;(getFileInfoFromFile as jest.Mock).mockResolvedValue(mockFiles)
+
+ await act(async () => {
+ await result.current.onDropModels([])
+ })
+
+ expect(mockSetImportingModels).toHaveBeenCalledWith([
+ {
+ importId: 'mock-uuid',
+ modelId: undefined,
+ name: 'model1',
+ description: '',
+ path: '/path/to/model1.gguf',
+ tags: [],
+ size: 1000,
+ status: 'PREPARING',
+ format: 'gguf',
+ },
+ {
+ importId: 'mock-uuid',
+ modelId: undefined,
+ name: 'model2',
+ description: '',
+ path: '/path/to/model2.gguf',
+ tags: [],
+ size: 2000,
+ status: 'PREPARING',
+ format: 'gguf',
+ },
+ ])
+ expect(mockSetImportModelStage).toHaveBeenCalledWith('MODEL_SELECTED')
+ })
+
+ it('should handle dropping unsupported files', async () => {
+ const { result } = renderHook(() => useDropModelBinaries())
+
+ const mockFiles = [
+ { name: 'unsupported.txt', path: '/path/to/unsupported.txt', size: 500 },
+ ]
+
+ ;(getFileInfoFromFile as jest.Mock).mockResolvedValue(mockFiles)
+
+ await act(async () => {
+ await result.current.onDropModels([])
+ })
+
+ expect(snackbar).toHaveBeenCalledWith({
+ description: 'Only files with .gguf extension can be imported.',
+ type: 'error',
+ })
+ expect(mockSetImportingModels).not.toHaveBeenCalled()
+ expect(mockSetImportModelStage).not.toHaveBeenCalled()
+ })
+
+ it('should handle dropping both supported and unsupported files', async () => {
+ const { result } = renderHook(() => useDropModelBinaries())
+
+ const mockFiles = [
+ { name: 'model.gguf', path: '/path/to/model.gguf', size: 1000 },
+ { name: 'unsupported.txt', path: '/path/to/unsupported.txt', size: 500 },
+ ]
+
+ ;(getFileInfoFromFile as jest.Mock).mockResolvedValue(mockFiles)
+
+ await act(async () => {
+ await result.current.onDropModels([])
+ })
+
+ expect(snackbar).toHaveBeenCalledWith({
+ description: 'Only files with .gguf extension can be imported.',
+ type: 'error',
+ })
+ expect(mockSetImportingModels).toHaveBeenCalledWith([
+ {
+ importId: 'mock-uuid',
+ modelId: undefined,
+ name: 'model',
+ description: '',
+ path: '/path/to/model.gguf',
+ tags: [],
+ size: 1000,
+ status: 'PREPARING',
+ format: 'gguf',
+ },
+ ])
+ expect(mockSetImportModelStage).toHaveBeenCalledWith('MODEL_SELECTED')
+ })
+})
diff --git a/web/hooks/useFactoryReset.test.ts b/web/hooks/useFactoryReset.test.ts
new file mode 100644
index 000000000..b9ec10d6b
--- /dev/null
+++ b/web/hooks/useFactoryReset.test.ts
@@ -0,0 +1,89 @@
+import { renderHook, act } from '@testing-library/react'
+import { useAtomValue, useSetAtom } from 'jotai'
+import useFactoryReset, { FactoryResetState } from './useFactoryReset'
+import { useActiveModel } from './useActiveModel'
+import { fs } from '@janhq/core'
+
+// Mock the dependencies
+jest.mock('jotai', () => ({
+ atom: jest.fn(),
+ useAtomValue: jest.fn(),
+ useSetAtom: jest.fn(),
+}))
+jest.mock('./useActiveModel', () => ({
+ useActiveModel: jest.fn(),
+}))
+jest.mock('@janhq/core', () => ({
+ fs: {
+ rm: jest.fn(),
+ },
+}))
+
+describe('useFactoryReset', () => {
+ const mockStopModel = jest.fn()
+ const mockSetFactoryResetState = jest.fn()
+ const mockGetAppConfigurations = jest.fn()
+ const mockUpdateAppConfiguration = jest.fn()
+ const mockRelaunch = jest.fn()
+
+ beforeEach(() => {
+ jest.clearAllMocks()
+ ;(useAtomValue as jest.Mock).mockReturnValue('/default/jan/data/folder')
+ ;(useSetAtom as jest.Mock).mockReturnValue(mockSetFactoryResetState)
+ ;(useActiveModel as jest.Mock).mockReturnValue({ stopModel: mockStopModel })
+ global.window ??= Object.create(window)
+ global.window.core = {
+ api: {
+ getAppConfigurations: mockGetAppConfigurations,
+ updateAppConfiguration: mockUpdateAppConfiguration,
+ relaunch: mockRelaunch,
+ },
+ }
+ mockGetAppConfigurations.mockResolvedValue({
+ data_folder: '/current/jan/data/folder',
+ quick_ask: false,
+ })
+ jest.spyOn(global, 'setTimeout')
+ })
+
+ it('should reset all correctly', async () => {
+ const { result } = renderHook(() => useFactoryReset())
+
+ await act(async () => {
+ await result.current.resetAll()
+ })
+
+ expect(mockSetFactoryResetState).toHaveBeenCalledWith(
+ FactoryResetState.Starting
+ )
+ expect(mockSetFactoryResetState).toHaveBeenCalledWith(
+ FactoryResetState.StoppingModel
+ )
+ expect(mockStopModel).toHaveBeenCalled()
+ expect(setTimeout).toHaveBeenCalledWith(expect.any(Function), 4000)
+ expect(mockSetFactoryResetState).toHaveBeenCalledWith(
+ FactoryResetState.DeletingData
+ )
+ expect(fs.rm).toHaveBeenCalledWith('/current/jan/data/folder')
+ expect(mockUpdateAppConfiguration).toHaveBeenCalledWith({
+ data_folder: '/default/jan/data/folder',
+ quick_ask: false,
+ })
+ expect(mockSetFactoryResetState).toHaveBeenCalledWith(
+ FactoryResetState.ClearLocalStorage
+ )
+ expect(mockRelaunch).toHaveBeenCalled()
+ })
+
+ it('should keep current folder when specified', async () => {
+ const { result } = renderHook(() => useFactoryReset())
+
+ await act(async () => {
+ await result.current.resetAll(true)
+ })
+
+ expect(mockUpdateAppConfiguration).not.toHaveBeenCalled()
+ })
+
+ // Add more tests as needed for error cases, edge cases, etc.
+})
diff --git a/web/hooks/useGetHFRepoData.test.ts b/web/hooks/useGetHFRepoData.test.ts
new file mode 100644
index 000000000..eaf86d79a
--- /dev/null
+++ b/web/hooks/useGetHFRepoData.test.ts
@@ -0,0 +1,39 @@
+import { renderHook, act } from '@testing-library/react'
+import { useGetHFRepoData } from './useGetHFRepoData'
+import { extensionManager } from '@/extension'
+
+jest.mock('@/extension', () => ({
+ extensionManager: {
+ get: jest.fn(),
+ },
+}))
+
+describe('useGetHFRepoData', () => {
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('should fetch HF repo data successfully', async () => {
+ const mockData = { name: 'Test Repo', stars: 100 }
+ const mockFetchHuggingFaceRepoData = jest.fn().mockResolvedValue(mockData)
+ ;(extensionManager.get as jest.Mock).mockReturnValue({
+ fetchHuggingFaceRepoData: mockFetchHuggingFaceRepoData,
+ })
+
+ const { result } = renderHook(() => useGetHFRepoData())
+
+ expect(result.current.loading).toBe(false)
+ expect(result.current.error).toBeUndefined()
+
+ let data
+ act(() => {
+ data = result.current.getHfRepoData('test-repo')
+ })
+
+ expect(result.current.loading).toBe(true)
+
+ expect(result.current.error).toBeUndefined()
+ expect(await data).toEqual(mockData)
+ expect(mockFetchHuggingFaceRepoData).toHaveBeenCalledWith('test-repo')
+ })
+})
diff --git a/web/hooks/useGetSystemResources.test.ts b/web/hooks/useGetSystemResources.test.ts
new file mode 100644
index 000000000..10e539e07
--- /dev/null
+++ b/web/hooks/useGetSystemResources.test.ts
@@ -0,0 +1,103 @@
+// useGetSystemResources.test.ts
+
+import { renderHook, act } from '@testing-library/react'
+import useGetSystemResources from './useGetSystemResources'
+import { extensionManager } from '@/extension/ExtensionManager'
+
+// Mock the extensionManager
+jest.mock('@/extension/ExtensionManager', () => ({
+ extensionManager: {
+ get: jest.fn(),
+ },
+}))
+
+// Mock the necessary dependencies
+jest.mock('jotai', () => ({
+ useAtomValue: jest.fn(),
+ useSetAtom: () => jest.fn(),
+ useAtom: jest.fn(),
+ atom: jest.fn(),
+}))
+
+describe('useGetSystemResources', () => {
+ const mockMonitoringExtension = {
+ getResourcesInfo: jest.fn(),
+ getCurrentLoad: jest.fn(),
+ }
+
+ beforeEach(() => {
+ jest.useFakeTimers()
+ ;(extensionManager.get as jest.Mock).mockReturnValue(
+ mockMonitoringExtension
+ )
+ })
+
+ afterEach(() => {
+ jest.clearAllMocks()
+ jest.useRealTimers()
+ })
+
+ it('should fetch system resources on initial render', async () => {
+ mockMonitoringExtension.getResourcesInfo.mockResolvedValue({
+ mem: { usedMemory: 4000, totalMemory: 8000 },
+ })
+ mockMonitoringExtension.getCurrentLoad.mockResolvedValue({
+ cpu: { usage: 50 },
+ gpu: [],
+ })
+
+ const { result } = renderHook(() => useGetSystemResources())
+
+ expect(mockMonitoringExtension.getResourcesInfo).toHaveBeenCalledTimes(1)
+ })
+
+ it('should start watching system resources when watch is called', () => {
+ const { result } = renderHook(() => useGetSystemResources())
+
+ act(() => {
+ result.current.watch()
+ })
+
+ expect(mockMonitoringExtension.getResourcesInfo).toHaveBeenCalled()
+
+ // Fast-forward time by 2 seconds
+ act(() => {
+ jest.advanceTimersByTime(2000)
+ })
+
+ expect(mockMonitoringExtension.getResourcesInfo).toHaveBeenCalled()
+ })
+
+ it('should stop watching when stopWatching is called', () => {
+ const { result } = renderHook(() => useGetSystemResources())
+
+ act(() => {
+ result.current.watch()
+ })
+
+ act(() => {
+ result.current.stopWatching()
+ })
+
+ // Fast-forward time by 2 seconds
+ act(() => {
+ jest.advanceTimersByTime(2000)
+ })
+
+ // Expect no additional calls after stopping
+ expect(mockMonitoringExtension.getResourcesInfo).toHaveBeenCalled()
+ })
+
+ it('should not fetch resources if monitoring extension is not available', async () => {
+ ;(extensionManager.get as jest.Mock).mockReturnValue(null)
+
+ const { result } = renderHook(() => useGetSystemResources())
+
+ await act(async () => {
+ result.current.getSystemResources()
+ })
+
+ expect(mockMonitoringExtension.getResourcesInfo).not.toHaveBeenCalled()
+ expect(mockMonitoringExtension.getCurrentLoad).not.toHaveBeenCalled()
+ })
+})
diff --git a/web/hooks/useGpuSetting.test.ts b/web/hooks/useGpuSetting.test.ts
new file mode 100644
index 000000000..f52f07af8
--- /dev/null
+++ b/web/hooks/useGpuSetting.test.ts
@@ -0,0 +1,87 @@
+// useGpuSetting.test.ts
+
+import { renderHook, act } from '@testing-library/react'
+import { ExtensionTypeEnum, MonitoringExtension } from '@janhq/core'
+
+// Mock dependencies
+jest.mock('@/extension')
+
+import useGpuSetting from './useGpuSetting'
+import { extensionManager } from '@/extension'
+
+describe('useGpuSetting', () => {
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('should return GPU settings when available', async () => {
+ const mockGpuSettings = {
+ gpuCount: 2,
+ gpuNames: ['NVIDIA GeForce RTX 3080', 'NVIDIA GeForce RTX 3070'],
+ totalMemory: 20000,
+ freeMemory: 15000,
+ }
+
+ const mockMonitoringExtension: Partial = {
+ getGpuSetting: jest.fn().mockResolvedValue(mockGpuSettings),
+ }
+
+ jest
+ .spyOn(extensionManager, 'get')
+ .mockReturnValue(mockMonitoringExtension as MonitoringExtension)
+
+ const { result } = renderHook(() => useGpuSetting())
+
+ let gpuSettings
+ await act(async () => {
+ gpuSettings = await result.current.getGpuSettings()
+ })
+
+ expect(gpuSettings).toEqual(mockGpuSettings)
+ expect(extensionManager.get).toHaveBeenCalledWith(
+ ExtensionTypeEnum.SystemMonitoring
+ )
+ expect(mockMonitoringExtension.getGpuSetting).toHaveBeenCalled()
+ })
+
+ it('should return undefined when no GPU settings are found', async () => {
+ const mockMonitoringExtension: Partial = {
+ getGpuSetting: jest.fn().mockResolvedValue(undefined),
+ }
+
+ jest
+ .spyOn(extensionManager, 'get')
+ .mockReturnValue(mockMonitoringExtension as MonitoringExtension)
+
+ const { result } = renderHook(() => useGpuSetting())
+
+ let gpuSettings
+ await act(async () => {
+ gpuSettings = await result.current.getGpuSettings()
+ })
+
+ expect(gpuSettings).toBeUndefined()
+ expect(extensionManager.get).toHaveBeenCalledWith(
+ ExtensionTypeEnum.SystemMonitoring
+ )
+ expect(mockMonitoringExtension.getGpuSetting).toHaveBeenCalled()
+ })
+
+ it('should handle missing MonitoringExtension', async () => {
+ jest.spyOn(extensionManager, 'get').mockReturnValue(undefined)
+ jest.spyOn(console, 'debug').mockImplementation(() => {})
+
+ const { result } = renderHook(() => useGpuSetting())
+
+ let gpuSettings
+ await act(async () => {
+ gpuSettings = await result.current.getGpuSettings()
+ })
+
+ expect(gpuSettings).toBeUndefined()
+ expect(extensionManager.get).toHaveBeenCalledWith(
+ ExtensionTypeEnum.SystemMonitoring
+ )
+ expect(console.debug).toHaveBeenCalledWith('No GPU setting found')
+ })
+})
diff --git a/web/hooks/useImportModel.test.ts b/web/hooks/useImportModel.test.ts
new file mode 100644
index 000000000..2148f581b
--- /dev/null
+++ b/web/hooks/useImportModel.test.ts
@@ -0,0 +1,70 @@
+// useImportModel.test.ts
+
+import { renderHook, act } from '@testing-library/react'
+import { extensionManager } from '@/extension'
+import useImportModel from './useImportModel'
+
+// Mock dependencies
+jest.mock('@janhq/core')
+jest.mock('@/extension')
+jest.mock('@/containers/Toast')
+jest.mock('uuid', () => ({ v4: () => 'mocked-uuid' }))
+
+describe('useImportModel', () => {
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('should import models successfully', async () => {
+ const mockImportModels = jest.fn().mockResolvedValue(undefined)
+ const mockExtension = {
+ importModels: mockImportModels,
+ } as any
+
+ jest.spyOn(extensionManager, 'get').mockReturnValue(mockExtension)
+
+ const { result } = renderHook(() => useImportModel())
+
+ const models = [
+ { importId: '1', name: 'Model 1', path: '/path/to/model1' },
+ { importId: '2', name: 'Model 2', path: '/path/to/model2' },
+ ] as any
+
+ await act(async () => {
+ await result.current.importModels(models, 'local' as any)
+ })
+
+ expect(mockImportModels).toHaveBeenCalledWith(models, 'local')
+ })
+
+ it('should update model info successfully', async () => {
+ const mockUpdateModelInfo = jest
+ .fn()
+ .mockResolvedValue({ id: 'model-1', name: 'Updated Model' })
+ const mockExtension = {
+ updateModelInfo: mockUpdateModelInfo,
+ } as any
+
+ jest.spyOn(extensionManager, 'get').mockReturnValue(mockExtension)
+
+ const { result } = renderHook(() => useImportModel())
+
+ const modelInfo = { id: 'model-1', name: 'Updated Model' }
+
+ await act(async () => {
+ await result.current.updateModelInfo(modelInfo)
+ })
+
+ expect(mockUpdateModelInfo).toHaveBeenCalledWith(modelInfo)
+ })
+
+ it('should handle empty file paths', async () => {
+ const { result } = renderHook(() => useImportModel())
+
+ await act(async () => {
+ await result.current.sanitizeFilePaths([])
+ })
+
+ // Expect no state changes or side effects
+ })
+})
diff --git a/web/hooks/useLoadTheme.test.ts b/web/hooks/useLoadTheme.test.ts
new file mode 100644
index 000000000..a0d117fc5
--- /dev/null
+++ b/web/hooks/useLoadTheme.test.ts
@@ -0,0 +1,111 @@
+import { renderHook, act } from '@testing-library/react'
+import { useTheme } from 'next-themes'
+import { fs, joinPath } from '@janhq/core'
+import { useAtom, useAtomValue, useSetAtom } from 'jotai'
+
+import { useLoadTheme } from './useLoadTheme'
+
+// Mock dependencies
+jest.mock('next-themes')
+jest.mock('@janhq/core')
+
+// Mock dependencies
+jest.mock('jotai', () => ({
+ useAtomValue: jest.fn(),
+ useSetAtom: jest.fn(),
+ useAtom: jest.fn(),
+ atom: jest.fn(),
+}))
+
+describe('useLoadTheme', () => {
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ const mockJanDataFolderPath = '/mock/path'
+ const mockThemesPath = '/mock/path/themes'
+ const mockSelectedThemeId = 'joi-light'
+ const mockThemeData = {
+ id: 'joi-light',
+ displayName: 'Joi Light',
+ nativeTheme: 'light',
+ variables: {
+ '--primary-color': '#007bff',
+ },
+ }
+
+ it('should load theme and set variables', async () => {
+ // Mock Jotai hooks
+ ;(useAtomValue as jest.Mock).mockReturnValue(mockJanDataFolderPath)
+ ;(useSetAtom as jest.Mock).mockReturnValue(jest.fn())
+ ;(useAtom as jest.Mock).mockReturnValue([mockSelectedThemeId, jest.fn()])
+ ;(useAtom as jest.Mock).mockReturnValue([mockThemeData, jest.fn()])
+
+ // Mock fs and joinPath
+ ;(fs.readdirSync as jest.Mock).mockResolvedValue(['joi-light', 'joi-dark'])
+ ;(fs.readFileSync as jest.Mock).mockResolvedValue(
+ JSON.stringify(mockThemeData)
+ )
+ ;(joinPath as jest.Mock).mockImplementation((paths) => paths.join('/'))
+
+ // Mock setTheme from next-themes
+ const mockSetTheme = jest.fn()
+ ;(useTheme as jest.Mock).mockReturnValue({ setTheme: mockSetTheme })
+
+ // Mock window.electronAPI
+ Object.defineProperty(window, 'electronAPI', {
+ value: {
+ setNativeThemeLight: jest.fn(),
+ setNativeThemeDark: jest.fn(),
+ },
+ writable: true,
+ })
+
+ const { result } = renderHook(() => useLoadTheme())
+
+ await act(async () => {
+ await result.current
+ })
+
+ // Assertions
+ expect(fs.readdirSync).toHaveBeenCalledWith(mockThemesPath)
+ expect(fs.readFileSync).toHaveBeenCalledWith(
+ `${mockThemesPath}/${mockSelectedThemeId}/theme.json`,
+ 'utf-8'
+ )
+ expect(mockSetTheme).toHaveBeenCalledWith('light')
+ expect(window.electronAPI.setNativeThemeLight).toHaveBeenCalled()
+ })
+
+ it('should set default theme if no selected theme', async () => {
+ // Mock Jotai hooks with empty selected theme
+ ;(useAtomValue as jest.Mock).mockReturnValue(mockJanDataFolderPath)
+ ;(useSetAtom as jest.Mock).mockReturnValue(jest.fn())
+ ;(useAtom as jest.Mock).mockReturnValue(['', jest.fn()])
+ ;(useAtom as jest.Mock).mockReturnValue([{}, jest.fn()])
+
+ const mockSetSelectedThemeId = jest.fn()
+ ;(useAtom as jest.Mock).mockReturnValue(['', mockSetSelectedThemeId])
+
+ const { result } = renderHook(() => useLoadTheme())
+
+ await act(async () => {
+ await result.current
+ })
+
+ expect(mockSetSelectedThemeId).toHaveBeenCalledWith('joi-light')
+ })
+
+ it('should handle missing janDataFolderPath', async () => {
+ // Mock Jotai hooks with empty janDataFolderPath
+ ;(useAtomValue as jest.Mock).mockReturnValue('')
+
+ const { result } = renderHook(() => useLoadTheme())
+
+ await act(async () => {
+ await result.current
+ })
+
+ expect(fs.readdirSync).not.toHaveBeenCalled()
+ })
+})
diff --git a/web/hooks/useLogs.test.ts b/web/hooks/useLogs.test.ts
new file mode 100644
index 000000000..a7a055bbd
--- /dev/null
+++ b/web/hooks/useLogs.test.ts
@@ -0,0 +1,103 @@
+// useLogs.test.ts
+
+import { renderHook, act } from '@testing-library/react'
+import { useAtomValue } from 'jotai'
+import { fs, joinPath, openFileExplorer } from '@janhq/core'
+
+import { useLogs } from './useLogs'
+
+// Mock dependencies
+jest.mock('jotai', () => ({
+ useAtomValue: jest.fn(),
+ atom: jest.fn(),
+}))
+
+jest.mock('@janhq/core', () => ({
+ fs: {
+ existsSync: jest.fn(),
+ readFileSync: jest.fn(),
+ writeFileSync: jest.fn(),
+ },
+ joinPath: jest.fn(),
+ openFileExplorer: jest.fn(),
+}))
+
+describe('useLogs', () => {
+ beforeEach(() => {
+ jest.clearAllMocks()
+ ;(useAtomValue as jest.Mock).mockReturnValue('/mock/jan/data/folder')
+ })
+
+ it('should get logs and sanitize them', async () => {
+ const mockLogs = '/mock/jan/data/folder/some/log/content'
+ const expectedSanitizedLogs = 'jan-data-folder/some/log/content'
+
+ ;(joinPath as jest.Mock).mockResolvedValue('file://logs/test.log')
+ ;(fs.existsSync as jest.Mock).mockResolvedValue(true)
+ ;(fs.readFileSync as jest.Mock).mockResolvedValue(mockLogs)
+
+ const { result } = renderHook(() => useLogs())
+
+ await act(async () => {
+ const logs = await result.current.getLogs('test')
+ expect(logs).toBe(expectedSanitizedLogs)
+ })
+
+ expect(joinPath).toHaveBeenCalledWith(['file://logs', 'test.log'])
+ expect(fs.existsSync).toHaveBeenCalledWith('file://logs/test.log')
+ expect(fs.readFileSync).toHaveBeenCalledWith(
+ 'file://logs/test.log',
+ 'utf-8'
+ )
+ })
+
+ it('should return empty string if log file does not exist', async () => {
+ ;(joinPath as jest.Mock).mockResolvedValue('file://logs/nonexistent.log')
+ ;(fs.existsSync as jest.Mock).mockResolvedValue(false)
+
+ const { result } = renderHook(() => useLogs())
+
+ await act(async () => {
+ const logs = await result.current.getLogs('nonexistent')
+ expect(logs).toBe('')
+ })
+
+ expect(fs.readFileSync).not.toHaveBeenCalled()
+ })
+
+ it('should open server log', async () => {
+ ;(joinPath as jest.Mock).mockResolvedValue(
+ '/mock/jan/data/folder/logs/app.log'
+ )
+ ;(openFileExplorer as jest.Mock).mockResolvedValue(undefined)
+
+ const { result } = renderHook(() => useLogs())
+
+ await act(async () => {
+ await result.current.openServerLog()
+ })
+
+ expect(joinPath).toHaveBeenCalledWith([
+ '/mock/jan/data/folder',
+ 'logs',
+ 'app.log',
+ ])
+ expect(openFileExplorer).toHaveBeenCalledWith(
+ '/mock/jan/data/folder/logs/app.log'
+ )
+ })
+
+ it('should clear server log', async () => {
+ ;(joinPath as jest.Mock).mockResolvedValue('file://logs/app.log')
+ ;(fs.writeFileSync as jest.Mock).mockResolvedValue(undefined)
+
+ const { result } = renderHook(() => useLogs())
+
+ await act(async () => {
+ await result.current.clearServerLog()
+ })
+
+ expect(joinPath).toHaveBeenCalledWith(['file://logs', 'app.log'])
+ expect(fs.writeFileSync).toHaveBeenCalledWith('file://logs/app.log', '')
+ })
+})
diff --git a/web/hooks/useModels.test.ts b/web/hooks/useModels.test.ts
new file mode 100644
index 000000000..4c53ffaa7
--- /dev/null
+++ b/web/hooks/useModels.test.ts
@@ -0,0 +1,61 @@
+// useModels.test.ts
+
+import { renderHook, act } from '@testing-library/react'
+import { events, ModelEvent } from '@janhq/core'
+import { extensionManager } from '@/extension'
+
+// Mock dependencies
+jest.mock('@janhq/core')
+jest.mock('@/extension')
+
+import useModels from './useModels'
+
+// Mock data
+const mockDownloadedModels = [
+ { id: 'model-1', name: 'Model 1' },
+ { id: 'model-2', name: 'Model 2' },
+]
+
+const mockConfiguredModels = [
+ { id: 'model-3', name: 'Model 3' },
+ { id: 'model-4', name: 'Model 4' },
+]
+
+const mockDefaultModel = { id: 'default-model', name: 'Default Model' }
+
+describe('useModels', () => {
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('should fetch and set models on mount', async () => {
+ const mockModelExtension = {
+ getDownloadedModels: jest.fn().mockResolvedValue(mockDownloadedModels),
+ getConfiguredModels: jest.fn().mockResolvedValue(mockConfiguredModels),
+ getDefaultModel: jest.fn().mockResolvedValue(mockDefaultModel),
+ } as any
+
+ jest.spyOn(extensionManager, 'get').mockReturnValue(mockModelExtension)
+
+ await act(async () => {
+ renderHook(() => useModels())
+ })
+
+ expect(mockModelExtension.getDownloadedModels).toHaveBeenCalled()
+ expect(mockModelExtension.getConfiguredModels).toHaveBeenCalled()
+ expect(mockModelExtension.getDefaultModel).toHaveBeenCalled()
+ })
+
+ it('should remove event listener on unmount', async () => {
+ const removeListenerSpy = jest.spyOn(events, 'off')
+
+ const { unmount } = renderHook(() => useModels())
+
+ unmount()
+
+ expect(removeListenerSpy).toHaveBeenCalledWith(
+ ModelEvent.OnModelsUpdate,
+ expect.any(Function)
+ )
+ })
+})
diff --git a/web/hooks/useModels.ts b/web/hooks/useModels.ts
index 8333c35c3..58def79c6 100644
--- a/web/hooks/useModels.ts
+++ b/web/hooks/useModels.ts
@@ -18,6 +18,11 @@ import {
downloadedModelsAtom,
} from '@/helpers/atoms/Model.atom'
+/**
+ * useModels hook - Handles the state of models
+ * It fetches the downloaded models, configured models and default model from Model Extension
+ * and updates the atoms accordingly.
+ */
const useModels = () => {
const setDownloadedModels = useSetAtom(downloadedModelsAtom)
const setConfiguredModels = useSetAtom(configuredModelsAtom)
@@ -39,6 +44,7 @@ const useModels = () => {
setDefaultModel(defaultModel)
}
+ // Fetch all data
Promise.all([
getDownloadedModels(),
getConfiguredModels(),
@@ -59,16 +65,19 @@ const useModels = () => {
}, [getData])
}
+// TODO: Deprecated - Remove when moving to cortex.cpp
const getLocalDefaultModel = async (): Promise =>
extensionManager
.get(ExtensionTypeEnum.Model)
?.getDefaultModel()
+// TODO: Deprecated - Remove when moving to cortex.cpp
const getLocalConfiguredModels = async (): Promise =>
extensionManager
.get(ExtensionTypeEnum.Model)
?.getConfiguredModels() ?? []
+// TODO: Deprecated - Remove when moving to cortex.cpp
const getLocalDownloadedModels = async (): Promise =>
extensionManager
.get(ExtensionTypeEnum.Model)
diff --git a/web/hooks/useSetActiveThread.ts b/web/hooks/useSetActiveThread.ts
index 8e9268065..6b306224d 100644
--- a/web/hooks/useSetActiveThread.ts
+++ b/web/hooks/useSetActiveThread.ts
@@ -8,10 +8,10 @@ import {
setConvoMessagesAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import {
- ModelParams,
setActiveThreadIdAtom,
setThreadModelParamsAtom,
} from '@/helpers/atoms/Thread.atom'
+import { ModelParams } from '@/types/model'
export default function useSetActiveThread() {
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
diff --git a/web/hooks/useThread.test.ts b/web/hooks/useThread.test.ts
new file mode 100644
index 000000000..a40c709be
--- /dev/null
+++ b/web/hooks/useThread.test.ts
@@ -0,0 +1,192 @@
+// useThreads.test.ts
+
+import { renderHook, act } from '@testing-library/react'
+import { useSetAtom } from 'jotai'
+import { ExtensionTypeEnum } from '@janhq/core'
+import { extensionManager } from '@/extension/ExtensionManager'
+import useThreads from './useThreads'
+import {
+ threadDataReadyAtom,
+ threadModelParamsAtom,
+ threadsAtom,
+ threadStatesAtom,
+} from '@/helpers/atoms/Thread.atom'
+
+// Mock the necessary dependencies
+jest.mock('jotai', () => ({
+ useAtomValue: jest.fn(),
+ useSetAtom: jest.fn(),
+ useAtom: jest.fn(),
+ atom: jest.fn(),
+}))
+jest.mock('@/extension/ExtensionManager')
+
+describe('useThreads', () => {
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ const mockThreads = [
+ {
+ id: 'thread1',
+ metadata: { lastMessage: 'Hello' },
+ assistants: [
+ {
+ model: {
+ parameters: { param1: 'value1' },
+ settings: { setting1: 'value1' },
+ },
+ },
+ ],
+ },
+ {
+ id: 'thread2',
+ metadata: { lastMessage: 'Hi there' },
+ assistants: [
+ {
+ model: {
+ parameters: { param2: 'value2' },
+ settings: { setting2: 'value2' },
+ },
+ },
+ ],
+ },
+ ]
+
+ it('should fetch and set threads data', async () => {
+ // Mock Jotai hooks
+ const mockSetThreadStates = jest.fn()
+ const mockSetThreads = jest.fn()
+ const mockSetThreadModelRuntimeParams = jest.fn()
+ const mockSetThreadDataReady = jest.fn()
+
+ ;(useSetAtom as jest.Mock).mockImplementation((atom) => {
+ switch (atom) {
+ case threadStatesAtom:
+ return mockSetThreadStates
+ case threadsAtom:
+ return mockSetThreads
+ case threadModelParamsAtom:
+ return mockSetThreadModelRuntimeParams
+ case threadDataReadyAtom:
+ return mockSetThreadDataReady
+ default:
+ return jest.fn()
+ }
+ })
+
+ // Mock extensionManager
+ const mockGetThreads = jest.fn().mockResolvedValue(mockThreads)
+ ;(extensionManager.get as jest.Mock).mockReturnValue({
+ getThreads: mockGetThreads,
+ })
+
+ const { result } = renderHook(() => useThreads())
+
+ await act(async () => {
+ // Wait for useEffect to complete
+ })
+
+ // Assertions
+ expect(extensionManager.get).toHaveBeenCalledWith(
+ ExtensionTypeEnum.Conversational
+ )
+ expect(mockGetThreads).toHaveBeenCalled()
+
+ expect(mockSetThreadStates).toHaveBeenCalledWith({
+ thread1: {
+ hasMore: false,
+ waitingForResponse: false,
+ lastMessage: 'Hello',
+ },
+ thread2: {
+ hasMore: false,
+ waitingForResponse: false,
+ lastMessage: 'Hi there',
+ },
+ })
+
+ expect(mockSetThreads).toHaveBeenCalledWith(mockThreads)
+
+ expect(mockSetThreadModelRuntimeParams).toHaveBeenCalledWith({
+ thread1: { param1: 'value1', setting1: 'value1' },
+ thread2: { param2: 'value2', setting2: 'value2' },
+ })
+
+ expect(mockSetThreadDataReady).toHaveBeenCalledWith(true)
+ })
+
+ it('should handle empty threads', async () => {
+ // Mock empty threads
+ ;(extensionManager.get as jest.Mock).mockReturnValue({
+ getThreads: jest.fn().mockResolvedValue([]),
+ })
+
+ const mockSetThreadStates = jest.fn()
+ const mockSetThreads = jest.fn()
+ const mockSetThreadModelRuntimeParams = jest.fn()
+ const mockSetThreadDataReady = jest.fn()
+
+ ;(useSetAtom as jest.Mock).mockImplementation((atom) => {
+ switch (atom) {
+ case threadStatesAtom:
+ return mockSetThreadStates
+ case threadsAtom:
+ return mockSetThreads
+ case threadModelParamsAtom:
+ return mockSetThreadModelRuntimeParams
+ case threadDataReadyAtom:
+ return mockSetThreadDataReady
+ default:
+ return jest.fn()
+ }
+ })
+
+ const { result } = renderHook(() => useThreads())
+
+ await act(async () => {
+ // Wait for useEffect to complete
+ })
+
+ expect(mockSetThreadStates).toHaveBeenCalledWith({})
+ expect(mockSetThreads).toHaveBeenCalledWith([])
+ expect(mockSetThreadModelRuntimeParams).toHaveBeenCalledWith({})
+ expect(mockSetThreadDataReady).toHaveBeenCalledWith(true)
+ })
+
+ it('should handle missing ConversationalExtension', async () => {
+ // Mock missing ConversationalExtension
+ ;(extensionManager.get as jest.Mock).mockReturnValue(null)
+
+ const mockSetThreadStates = jest.fn()
+ const mockSetThreads = jest.fn()
+ const mockSetThreadModelRuntimeParams = jest.fn()
+ const mockSetThreadDataReady = jest.fn()
+
+ ;(useSetAtom as jest.Mock).mockImplementation((atom) => {
+ switch (atom) {
+ case threadStatesAtom:
+ return mockSetThreadStates
+ case threadsAtom:
+ return mockSetThreads
+ case threadModelParamsAtom:
+ return mockSetThreadModelRuntimeParams
+ case threadDataReadyAtom:
+ return mockSetThreadDataReady
+ default:
+ return jest.fn()
+ }
+ })
+
+ const { result } = renderHook(() => useThreads())
+
+ await act(async () => {
+ // Wait for useEffect to complete
+ })
+
+ expect(mockSetThreadStates).toHaveBeenCalledWith({})
+ expect(mockSetThreads).toHaveBeenCalledWith([])
+ expect(mockSetThreadModelRuntimeParams).toHaveBeenCalledWith({})
+ expect(mockSetThreadDataReady).toHaveBeenCalledWith(true)
+ })
+})
diff --git a/web/hooks/useThreads.ts b/web/hooks/useThreads.ts
index fd0b3456d..9366101c3 100644
--- a/web/hooks/useThreads.ts
+++ b/web/hooks/useThreads.ts
@@ -11,12 +11,12 @@ import { useSetAtom } from 'jotai'
import { extensionManager } from '@/extension/ExtensionManager'
import {
- ModelParams,
threadDataReadyAtom,
threadModelParamsAtom,
threadStatesAtom,
threadsAtom,
} from '@/helpers/atoms/Thread.atom'
+import { ModelParams } from '@/types/model'
const useThreads = () => {
const setThreadStates = useSetAtom(threadStatesAtom)
diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts
index af30210ad..2af6e3323 100644
--- a/web/hooks/useUpdateModelParameters.ts
+++ b/web/hooks/useUpdateModelParameters.ts
@@ -18,10 +18,10 @@ import {
import { extensionManager } from '@/extension'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
- ModelParams,
getActiveThreadModelParamsAtom,
setThreadModelParamsAtom,
} from '@/helpers/atoms/Thread.atom'
+import { ModelParams } from '@/types/model'
export type UpdateModelParameter = {
params?: ModelParams
diff --git a/web/types/model.d.ts b/web/types/model.d.ts
new file mode 100644
index 000000000..bbe9d2cc6
--- /dev/null
+++ b/web/types/model.d.ts
@@ -0,0 +1,4 @@
+/**
+ * ModelParams types
+ */
+export type ModelParams = ModelRuntimeParams | ModelSettingParams
diff --git a/web/utils/modelParam.ts b/web/utils/modelParam.ts
index dda9cf761..315aeaeb3 100644
--- a/web/utils/modelParam.ts
+++ b/web/utils/modelParam.ts
@@ -2,7 +2,7 @@
/* eslint-disable @typescript-eslint/naming-convention */
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
-import { ModelParams } from '@/helpers/atoms/Thread.atom'
+import { ModelParams } from '@/types/model'
/**
* Validation rules for model parameters