From 98bef7b7cffa811a67945d8c8f4659862c15026c Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 17 Sep 2024 08:34:58 +0700
Subject: [PATCH 1/7] test: add model parameter validation rules and
persistence tests (#3618)
* test: add model parameter validation rules and persistence tests
* chore: fix CI cov step
* fix: invalid model settings should fallback to origin value
* test: support fallback integer settings
---
.../src/node/index.ts | 32 +-
.../tensorrt-llm-extension/src/node/index.ts | 4 +-
web/containers/Providers/EventHandler.tsx | 4 +-
web/containers/SliderRightPanel/index.tsx | 28 +-
web/hooks/useSendChatMessage.ts | 9 +-
web/hooks/useUpdateModelParameters.test.ts | 314 ++++++++++++++++++
web/hooks/useUpdateModelParameters.ts | 16 +-
.../LocalServerRightPanel/index.tsx | 13 +-
web/screens/Thread/ThreadRightPanel/index.tsx | 26 +-
web/utils/modelParam.test.ts | 183 ++++++++++
web/utils/modelParam.ts | 106 +++++-
11 files changed, 681 insertions(+), 54 deletions(-)
create mode 100644 web/hooks/useUpdateModelParameters.test.ts
create mode 100644 web/utils/modelParam.test.ts
diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts
index edc2d013d..3a969ad5e 100644
--- a/extensions/inference-nitro-extension/src/node/index.ts
+++ b/extensions/inference-nitro-extension/src/node/index.ts
@@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise<Response> {
if (!settings?.ngl) {
settings.ngl = 100
}
- log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`)
+ log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`)
return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, {
method: 'POST',
headers: {
@@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise<Response> {
})
.then((res) => {
log(
- `[CORTEX]::Debug: Load model success with response ${JSON.stringify(
+ `[CORTEX]:: Load model success with response ${JSON.stringify(
res
)}`
)
@@ -260,7 +260,7 @@ function loadLLMModel(settings: any): Promise<Response> {
async function validateModelStatus(modelId: string): Promise<void> {
// Send a request to the validation URL.
// Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries.
- log(`[CORTEX]::Debug: Validating model ${modelId}`)
+ log(`[CORTEX]:: Validating model ${modelId}`)
return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, {
method: 'POST',
body: JSON.stringify({
@@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
retryDelay: 300,
}).then(async (res: Response) => {
log(
- `[CORTEX]::Debug: Validate model state with response ${JSON.stringify(
+ `[CORTEX]:: Validate model state with response ${JSON.stringify(
res.status
)}`
)
@@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
// Otherwise, return an object with an error message.
if (body.model_loaded) {
log(
- `[CORTEX]::Debug: Validate model state success with response ${JSON.stringify(
+ `[CORTEX]:: Validate model state success with response ${JSON.stringify(
body
)}`
)
@@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
}
const errorBody = await res.text()
log(
- `[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
+ `[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
res.statusText
)}`
)
@@ -310,7 +310,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
async function killSubprocess(): Promise<void> {
const controller = new AbortController()
setTimeout(() => controller.abort(), 5000)
- log(`[CORTEX]::Debug: Request to kill cortex`)
+ log(`[CORTEX]:: Request to kill cortex`)
const killRequest = () => {
return fetch(NITRO_HTTP_KILL_URL, {
@@ -321,17 +321,17 @@ async function killSubprocess(): Promise<void> {
.then(() =>
tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
)
- .then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
+ .then(() => log(`[CORTEX]:: cortex process is terminated`))
.catch((err) => {
log(
- `[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
+ `[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
)
throw 'PORT_NOT_AVAILABLE'
})
}
if (subprocess?.pid && process.platform !== 'darwin') {
- log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
+ log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
const pid = subprocess.pid
return new Promise((resolve, reject) => {
terminate(pid, function (err) {
@@ -341,7 +341,7 @@ async function killSubprocess(): Promise<void> {
} else {
tcpPortUsed
.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
- .then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
+ .then(() => log(`[CORTEX]:: cortex process is terminated`))
.then(() => resolve())
.catch(() => {
log(
@@ -362,7 +362,7 @@ async function killSubprocess(): Promise<void> {
* @returns A promise that resolves when the Nitro subprocess is started.
*/
function spawnNitroProcess(systemInfo?: SystemInformation): Promise<void> {
- log(`[CORTEX]::Debug: Spawning cortex subprocess...`)
+ log(`[CORTEX]:: Spawning cortex subprocess...`)
return new Promise(async (resolve, reject) => {
let executableOptions = executableNitroFile(
@@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<void> {
const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
// Execute the binary
log(
- `[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
+ `[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
)
log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`)
@@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<void> {
// Handle subprocess output
subprocess.stdout.on('data', (data: any) => {
- log(`[CORTEX]::Debug: ${data}`)
+ log(`[CORTEX]:: ${data}`)
})
subprocess.stderr.on('data', (data: any) => {
@@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<void> {
})
subprocess.on('close', (code: any) => {
- log(`[CORTEX]::Debug: cortex exited with code: ${code}`)
+ log(`[CORTEX]:: cortex exited with code: ${code}`)
subprocess = undefined
reject(`child process exited with code ${code}`)
})
@@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<void> {
tcpPortUsed
.waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000)
.then(() => {
- log(`[CORTEX]::Debug: cortex is ready`)
+ log(`[CORTEX]:: cortex is ready`)
resolve()
})
})
diff --git a/extensions/tensorrt-llm-extension/src/node/index.ts b/extensions/tensorrt-llm-extension/src/node/index.ts
index c8bc48459..77003389f 100644
--- a/extensions/tensorrt-llm-extension/src/node/index.ts
+++ b/extensions/tensorrt-llm-extension/src/node/index.ts
@@ -97,7 +97,7 @@ function unloadModel(): Promise<void> {
}
if (subprocess?.pid) {
- log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
+ log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
const pid = subprocess.pid
return new Promise((resolve, reject) => {
terminate(pid, function (err) {
@@ -107,7 +107,7 @@ function unloadModel(): Promise<void> {
return tcpPortUsed
.waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000)
.then(() => resolve())
- .then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
+ .then(() => log(`[CORTEX]:: cortex process is terminated`))
.catch(() => {
killRequest()
})
diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx
index e4c96aeb7..4809ce83e 100644
--- a/web/containers/Providers/EventHandler.tsx
+++ b/web/containers/Providers/EventHandler.tsx
@@ -20,7 +20,7 @@ import { ulid } from 'ulidx'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
-import { toRuntimeParams } from '@/utils/modelParam'
+import { extractInferenceParams } from '@/utils/modelParam'
import { extensionManager } from '@/extension'
import {
@@ -256,7 +256,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
},
]
- const runtimeParams = toRuntimeParams(activeModelParamsRef.current)
+ const runtimeParams = extractInferenceParams(activeModelParamsRef.current)
const messageRequest: MessageRequest = {
id: msgId,
diff --git a/web/containers/SliderRightPanel/index.tsx b/web/containers/SliderRightPanel/index.tsx
index df415ffb5..c00d9f002 100644
--- a/web/containers/SliderRightPanel/index.tsx
+++ b/web/containers/SliderRightPanel/index.tsx
@@ -87,26 +87,28 @@ const SliderRightPanel = ({
onValueChanged?.(Number(min))
setVal(min.toString())
setShowTooltip({ max: false, min: true })
+ } else {
+ setVal(Number(e.target.value).toString()) // normalizes inputs like ".5" to "0.5"
}
}}
onChange={(e) => {
- // Should not accept invalid value or NaN
- // E.g. anything changes that trigger onValueChanged
- // Which is incorrect
- if (Number(e.target.value) > Number(max)) {
- setVal(max.toString())
- } else if (
- Number(e.target.value) < Number(min) ||
- !e.target.value.length
- ) {
- setVal(min.toString())
- } else if (Number.isNaN(Number(e.target.value))) return
-
- onValueChanged?.(Number(e.target.value))
// TODO: How to support negative number input?
+ // Pass the raw value through; it is validated again on blur
if (/^\d*\.?\d*$/.test(e.target.value)) {
setVal(e.target.value)
}
+
+ // Reject out-of-range or NaN input here so that
+ // onValueChanged is never fired with an
+ // invalid number
+ if (
+ Number(e.target.value) > Number(max) ||
+ Number(e.target.value) < Number(min) ||
+ Number.isNaN(Number(e.target.value))
+ ) {
+ return
+ }
+ onValueChanged?.(Number(e.target.value))
}}
/>
}
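The hunk above moves range clamping from `onChange` to `onBlur`: while typing, input only has to pass a decimal-pattern filter, and out-of-range or NaN values are silently withheld from the consumer until the field loses focus. A minimal standalone sketch of that pattern; `setVal` and `onValueChanged` are illustrative stand-ins, not the component's actual props:

```ts
const DECIMAL_RE = /^\d*\.?\d*$/

function handleChange(
  raw: string,
  min: number,
  max: number,
  setVal: (v: string) => void,
  onValueChanged?: (v: number) => void
) {
  // Let partial input such as "0." through; it is re-validated on blur.
  if (DECIMAL_RE.test(raw)) setVal(raw)
  const n = Number(raw)
  // Never propagate out-of-range or NaN values to the consumer.
  if (Number.isNaN(n) || n > max || n < min) return
  onValueChanged?.(n)
}

function handleBlur(
  raw: string,
  min: number,
  max: number,
  setVal: (v: string) => void
) {
  const n = Number(raw)
  if (Number.isNaN(n) || n < min) setVal(min.toString())
  else if (n > max) setVal(max.toString())
  else setVal(n.toString()) // normalizes ".5" to "0.5"
}
```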
diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts
index 8c6013505..1dbd5b45e 100644
--- a/web/hooks/useSendChatMessage.ts
+++ b/web/hooks/useSendChatMessage.ts
@@ -23,7 +23,10 @@ import {
import { Stack } from '@/utils/Stack'
import { compressImage, getBase64 } from '@/utils/base64'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
-import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
+import {
+ extractInferenceParams,
+ extractModelLoadParams,
+} from '@/utils/modelParam'
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
@@ -189,8 +192,8 @@ export default function useSendChatMessage() {
if (engineParamsUpdate) setReloadModel(true)
- const runtimeParams = toRuntimeParams(activeModelParams)
- const settingParams = toSettingParams(activeModelParams)
+ const runtimeParams = extractInferenceParams(activeModelParams)
+ const settingParams = extractModelLoadParams(activeModelParams)
const prompt = message.trim()
diff --git a/web/hooks/useUpdateModelParameters.test.ts b/web/hooks/useUpdateModelParameters.test.ts
new file mode 100644
index 000000000..bc60aa631
--- /dev/null
+++ b/web/hooks/useUpdateModelParameters.test.ts
@@ -0,0 +1,314 @@
+import { renderHook, act } from '@testing-library/react'
+// Mock dependencies
+jest.mock('ulidx')
+jest.mock('@/extension')
+
+import useUpdateModelParameters from './useUpdateModelParameters'
+import { extensionManager } from '@/extension'
+
+// Mock data
+let model: any = {
+ id: 'model-1',
+ engine: 'nitro',
+}
+
+let extension: any = {
+ saveThread: jest.fn(),
+}
+
+const mockThread: any = {
+ id: 'thread-1',
+ assistants: [
+ {
+ model: {
+ parameters: {},
+ settings: {},
+ },
+ },
+ ],
+ object: 'thread',
+ title: 'New Thread',
+ created: 0,
+ updated: 0,
+}
+
+describe('useUpdateModelParameters', () => {
+ beforeAll(() => {
+ jest.clearAllMocks()
+ jest.mock('./useRecommendedModel', () => ({
+ useRecommendedModel: () => ({
+ recommendedModel: model,
+ setRecommendedModel: jest.fn(),
+ downloadedModels: [],
+ }),
+ }))
+ })
+
+ it('should update model parameters and save thread when params are valid', async () => {
+ const mockValidParameters: any = {
+ params: {
+ // Inference
+ stop: ['', ''],
+ temperature: 0.5,
+ token_limit: 1000,
+ top_k: 0.7,
+ top_p: 0.1,
+ stream: true,
+ max_tokens: 1000,
+ frequency_penalty: 0.3,
+ presence_penalty: 0.2,
+
+ // Load model
+ ctx_len: 1024,
+ ngl: 12,
+ embedding: true,
+ n_parallel: 2,
+ cpu_threads: 4,
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ vision_model: 'vision',
+ text_model: 'text',
+ },
+ modelId: 'model-1',
+ engine: 'nitro',
+ }
+
+ // Spy functions
+ jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
+ jest.spyOn(extension, 'saveThread').mockReturnValue({})
+
+ const { result } = renderHook(() => useUpdateModelParameters())
+
+ await act(async () => {
+ await result.current.updateModelParameter(mockThread, mockValidParameters)
+ })
+
+ // Check if the model parameters are valid before persisting
+ expect(extension.saveThread).toHaveBeenCalledWith({
+ assistants: [
+ {
+ model: {
+ parameters: {
+ stop: ['', ''],
+ temperature: 0.5,
+ token_limit: 1000,
+ top_k: 0.7,
+ top_p: 0.1,
+ stream: true,
+ max_tokens: 1000,
+ frequency_penalty: 0.3,
+ presence_penalty: 0.2,
+ },
+ settings: {
+ ctx_len: 1024,
+ ngl: 12,
+ embedding: true,
+ n_parallel: 2,
+ cpu_threads: 4,
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ },
+ },
+ },
+ ],
+ created: 0,
+ id: 'thread-1',
+ object: 'thread',
+ title: 'New Thread',
+ updated: 0,
+ })
+ })
+
+ it('should not update invalid model parameters', async () => {
+ const mockInvalidParameters: any = {
+ params: {
+ // Inference
+ stop: [1, ''],
+ temperature: '0.5',
+ token_limit: '1000',
+ top_k: '0.7',
+ top_p: '0.1',
+ stream: 'true',
+ max_tokens: '1000',
+ frequency_penalty: '0.3',
+ presence_penalty: '0.2',
+
+ // Load model
+ ctx_len: '1024',
+ ngl: '12',
+ embedding: 'true',
+ n_parallel: '2',
+ cpu_threads: '4',
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ vision_model: 'vision',
+ text_model: 'text',
+ },
+ modelId: 'model-1',
+ engine: 'nitro',
+ }
+
+ // Spy functions
+ jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
+ jest.spyOn(extension, 'saveThread').mockReturnValue({})
+
+ const { result } = renderHook(() => useUpdateModelParameters())
+
+ await act(async () => {
+ await result.current.updateModelParameter(
+ mockThread,
+ mockInvalidParameters
+ )
+ })
+
+ // Check if the model parameters are valid before persisting
+ expect(extension.saveThread).toHaveBeenCalledWith({
+ assistants: [
+ {
+ model: {
+ parameters: {
+ max_tokens: 1000,
+ token_limit: 1000,
+ },
+ settings: {
+ cpu_threads: 4,
+ ctx_len: 1024,
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ n_parallel: 2,
+ ngl: 12,
+ },
+ },
+ },
+ ],
+ created: 0,
+ id: 'thread-1',
+ object: 'thread',
+ title: 'New Thread',
+ updated: 0,
+ })
+ })
+
+ it('should update valid model parameters only', async () => {
+ const mockInvalidParameters: any = {
+ params: {
+ // Inference
+ stop: [''],
+ temperature: -0.5,
+ token_limit: 100.2,
+ top_k: 0.7,
+ top_p: 0.1,
+ stream: true,
+ max_tokens: 1000,
+ frequency_penalty: 1.2,
+ presence_penalty: 0.2,
+
+ // Load model
+ ctx_len: 1024,
+ ngl: 0,
+ embedding: 'true',
+ n_parallel: 2,
+ cpu_threads: 4,
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ vision_model: 'vision',
+ text_model: 'text',
+ },
+ modelId: 'model-1',
+ engine: 'nitro',
+ }
+
+ // Spy functions
+ jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
+ jest.spyOn(extension, 'saveThread').mockReturnValue({})
+
+ const { result } = renderHook(() => useUpdateModelParameters())
+
+ await act(async () => {
+ await result.current.updateModelParameter(
+ mockThread,
+ mockInvalidParameters
+ )
+ })
+
+ // Check if the model parameters are valid before persisting
+ expect(extension.saveThread).toHaveBeenCalledWith({
+ assistants: [
+ {
+ model: {
+ parameters: {
+ stop: [''],
+ top_k: 0.7,
+ top_p: 0.1,
+ stream: true,
+ token_limit: 100,
+ max_tokens: 1000,
+ presence_penalty: 0.2,
+ },
+ settings: {
+ ctx_len: 1024,
+ ngl: 0,
+ n_parallel: 2,
+ cpu_threads: 4,
+ prompt_template: 'template',
+ llama_model_path: 'path',
+ mmproj: 'mmproj',
+ },
+ },
+ },
+ ],
+ created: 0,
+ id: 'thread-1',
+ object: 'thread',
+ title: 'New Thread',
+ updated: 0,
+ })
+ })
+
+ it('should handle missing modelId and engine gracefully', async () => {
+ const mockParametersWithoutModelIdAndEngine: any = {
+ params: {
+ stop: ['', ''],
+ temperature: 0.5,
+ },
+ }
+
+ // Spy functions
+ jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
+ jest.spyOn(extension, 'saveThread').mockReturnValue({})
+
+ const { result } = renderHook(() => useUpdateModelParameters())
+
+ await act(async () => {
+ await result.current.updateModelParameter(
+ mockThread,
+ mockParametersWithoutModelIdAndEngine
+ )
+ })
+
+ // Check if the model parameters are valid before persisting
+ expect(extension.saveThread).toHaveBeenCalledWith({
+ assistants: [
+ {
+ model: {
+ parameters: {
+ stop: ['', ''],
+ temperature: 0.5,
+ },
+ settings: {},
+ },
+ },
+ ],
+ created: 0,
+ id: 'thread-1',
+ object: 'thread',
+ title: 'New Thread',
+ updated: 0,
+ })
+ })
+})
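For context, a hypothetical caller of the hook exercised above; the shape of the settings argument mirrors the test fixtures, and the surrounding component is omitted:

```ts
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'

function useTemperatureSetter(thread: any) {
  const { updateModelParameter } = useUpdateModelParameters()
  return async (temperature: number) => {
    // Invalid values are dropped (or replaced by origin values)
    // before the thread is persisted through the extension.
    await updateModelParameter(thread, { params: { temperature } })
  }
}
```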
diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts
index 79d877456..46bf07cd5 100644
--- a/web/hooks/useUpdateModelParameters.ts
+++ b/web/hooks/useUpdateModelParameters.ts
@@ -12,7 +12,10 @@ import {
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
-import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
+import {
+ extractInferenceParams,
+ extractModelLoadParams,
+} from '@/utils/modelParam'
import useRecommendedModel from './useRecommendedModel'
@@ -47,12 +50,17 @@ export default function useUpdateModelParameters() {
const toUpdateSettings = processStopWords(settings.params ?? {})
const updatedModelParams = settings.modelId
? toUpdateSettings
- : { ...activeModelParams, ...toUpdateSettings }
+ : {
+ ...selectedModel?.parameters,
+ ...selectedModel?.settings,
+ ...activeModelParams,
+ ...toUpdateSettings,
+ }
// update the state
setThreadModelParams(thread.id, updatedModelParams)
- const runtimeParams = toRuntimeParams(updatedModelParams)
- const settingParams = toSettingParams(updatedModelParams)
+ const runtimeParams = extractInferenceParams(updatedModelParams)
+ const settingParams = extractModelLoadParams(updatedModelParams)
const assistants = thread.assistants.map(
(assistant: ThreadAssistantInfo) => {
diff --git a/web/screens/LocalServer/LocalServerRightPanel/index.tsx b/web/screens/LocalServer/LocalServerRightPanel/index.tsx
index 309709c26..13e3cad57 100644
--- a/web/screens/LocalServer/LocalServerRightPanel/index.tsx
+++ b/web/screens/LocalServer/LocalServerRightPanel/index.tsx
@@ -14,7 +14,10 @@ import { loadModelErrorAtom } from '@/hooks/useActiveModel'
import { getConfigurationsData } from '@/utils/componentSettings'
-import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
+import {
+ extractInferenceParams,
+ extractModelLoadParams,
+} from '@/utils/modelParam'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
@@ -27,16 +30,18 @@ const LocalServerRightPanel = () => {
const selectedModel = useAtomValue(selectedModelAtom)
const [currentModelSettingParams, setCurrentModelSettingParams] = useState(
- toSettingParams(selectedModel?.settings)
+ extractModelLoadParams(selectedModel?.settings)
)
useEffect(() => {
if (selectedModel) {
- setCurrentModelSettingParams(toSettingParams(selectedModel?.settings))
+ setCurrentModelSettingParams(
+ extractModelLoadParams(selectedModel?.settings)
+ )
}
}, [selectedModel])
- const modelRuntimeParams = toRuntimeParams(selectedModel?.settings)
+ const modelRuntimeParams = extractInferenceParams(selectedModel?.settings)
const componentDataRuntimeSetting = getConfigurationsData(
modelRuntimeParams,
diff --git a/web/screens/Thread/ThreadRightPanel/index.tsx b/web/screens/Thread/ThreadRightPanel/index.tsx
index 9e7cdf7d8..e7d0a27b9 100644
--- a/web/screens/Thread/ThreadRightPanel/index.tsx
+++ b/web/screens/Thread/ThreadRightPanel/index.tsx
@@ -29,7 +29,10 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
import { getConfigurationsData } from '@/utils/componentSettings'
import { localEngines } from '@/utils/modelEngine'
-import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
+import {
+ extractInferenceParams,
+ extractModelLoadParams,
+} from '@/utils/modelParam'
import PromptTemplateSetting from './PromptTemplateSetting'
import Tools from './Tools'
@@ -68,14 +71,26 @@ const ThreadRightPanel = () => {
const settings = useMemo(() => {
// runtime setting
- const modelRuntimeParams = toRuntimeParams(activeModelParams)
+ const modelRuntimeParams = extractInferenceParams(
+ {
+ ...selectedModel?.parameters,
+ ...activeModelParams,
+ },
+ selectedModel?.parameters
+ )
const componentDataRuntimeSetting = getConfigurationsData(
modelRuntimeParams,
selectedModel
).filter((x) => x.key !== 'prompt_template')
// engine setting
- const modelEngineParams = toSettingParams(activeModelParams)
+ const modelEngineParams = extractModelLoadParams(
+ {
+ ...selectedModel?.settings,
+ ...activeModelParams,
+ },
+ selectedModel?.settings
+ )
const componentDataEngineSetting = getConfigurationsData(
modelEngineParams,
selectedModel
@@ -126,7 +141,10 @@ const ThreadRightPanel = () => {
}, [activeModelParams, selectedModel])
const promptTemplateSettings = useMemo(() => {
- const modelEngineParams = toSettingParams(activeModelParams)
+ const modelEngineParams = extractModelLoadParams({
+ ...selectedModel?.settings,
+ ...activeModelParams,
+ })
const componentDataEngineSetting = getConfigurationsData(
modelEngineParams,
selectedModel
diff --git a/web/utils/modelParam.test.ts b/web/utils/modelParam.test.ts
new file mode 100644
index 000000000..f1b858955
--- /dev/null
+++ b/web/utils/modelParam.test.ts
@@ -0,0 +1,183 @@
+// web/utils/modelParam.test.ts
+import { normalizeValue, validationRules } from './modelParam'
+
+describe('validationRules', () => {
+ it('should validate temperature correctly', () => {
+ expect(validationRules.temperature(0.5)).toBe(true)
+ expect(validationRules.temperature(2)).toBe(true)
+ expect(validationRules.temperature(0)).toBe(true)
+ expect(validationRules.temperature(-0.1)).toBe(false)
+ expect(validationRules.temperature(2.3)).toBe(false)
+ expect(validationRules.temperature('0.5')).toBe(false)
+ })
+
+ it('should validate token_limit correctly', () => {
+ expect(validationRules.token_limit(100)).toBe(true)
+ expect(validationRules.token_limit(1)).toBe(true)
+ expect(validationRules.token_limit(0)).toBe(true)
+ expect(validationRules.token_limit(-1)).toBe(false)
+ expect(validationRules.token_limit('100')).toBe(false)
+ })
+
+ it('should validate top_k correctly', () => {
+ expect(validationRules.top_k(0.5)).toBe(true)
+ expect(validationRules.top_k(1)).toBe(true)
+ expect(validationRules.top_k(0)).toBe(true)
+ expect(validationRules.top_k(-0.1)).toBe(false)
+ expect(validationRules.top_k(1.1)).toBe(false)
+ expect(validationRules.top_k('0.5')).toBe(false)
+ })
+
+ it('should validate top_p correctly', () => {
+ expect(validationRules.top_p(0.5)).toBe(true)
+ expect(validationRules.top_p(1)).toBe(true)
+ expect(validationRules.top_p(0)).toBe(true)
+ expect(validationRules.top_p(-0.1)).toBe(false)
+ expect(validationRules.top_p(1.1)).toBe(false)
+ expect(validationRules.top_p('0.5')).toBe(false)
+ })
+
+ it('should validate stream correctly', () => {
+ expect(validationRules.stream(true)).toBe(true)
+ expect(validationRules.stream(false)).toBe(true)
+ expect(validationRules.stream('true')).toBe(false)
+ expect(validationRules.stream(1)).toBe(false)
+ })
+
+ it('should validate max_tokens correctly', () => {
+ expect(validationRules.max_tokens(100)).toBe(true)
+ expect(validationRules.max_tokens(1)).toBe(true)
+ expect(validationRules.max_tokens(0)).toBe(true)
+ expect(validationRules.max_tokens(-1)).toBe(false)
+ expect(validationRules.max_tokens('100')).toBe(false)
+ })
+
+ it('should validate stop correctly', () => {
+ expect(validationRules.stop(['word1', 'word2'])).toBe(true)
+ expect(validationRules.stop([])).toBe(true)
+ expect(validationRules.stop(['word1', 2])).toBe(false)
+ expect(validationRules.stop('word1')).toBe(false)
+ })
+
+ it('should validate frequency_penalty correctly', () => {
+ expect(validationRules.frequency_penalty(0.5)).toBe(true)
+ expect(validationRules.frequency_penalty(1)).toBe(true)
+ expect(validationRules.frequency_penalty(0)).toBe(true)
+ expect(validationRules.frequency_penalty(-0.1)).toBe(false)
+ expect(validationRules.frequency_penalty(1.1)).toBe(false)
+ expect(validationRules.frequency_penalty('0.5')).toBe(false)
+ })
+
+ it('should validate presence_penalty correctly', () => {
+ expect(validationRules.presence_penalty(0.5)).toBe(true)
+ expect(validationRules.presence_penalty(1)).toBe(true)
+ expect(validationRules.presence_penalty(0)).toBe(true)
+ expect(validationRules.presence_penalty(-0.1)).toBe(false)
+ expect(validationRules.presence_penalty(1.1)).toBe(false)
+ expect(validationRules.presence_penalty('0.5')).toBe(false)
+ })
+
+ it('should validate ctx_len correctly', () => {
+ expect(validationRules.ctx_len(1024)).toBe(true)
+ expect(validationRules.ctx_len(1)).toBe(true)
+ expect(validationRules.ctx_len(0)).toBe(true)
+ expect(validationRules.ctx_len(-1)).toBe(false)
+ expect(validationRules.ctx_len('1024')).toBe(false)
+ })
+
+ it('should validate ngl correctly', () => {
+ expect(validationRules.ngl(12)).toBe(true)
+ expect(validationRules.ngl(1)).toBe(true)
+ expect(validationRules.ngl(0)).toBe(true)
+ expect(validationRules.ngl(-1)).toBe(false)
+ expect(validationRules.ngl('12')).toBe(false)
+ })
+
+ it('should validate embedding correctly', () => {
+ expect(validationRules.embedding(true)).toBe(true)
+ expect(validationRules.embedding(false)).toBe(true)
+ expect(validationRules.embedding('true')).toBe(false)
+ expect(validationRules.embedding(1)).toBe(false)
+ })
+
+ it('should validate n_parallel correctly', () => {
+ expect(validationRules.n_parallel(2)).toBe(true)
+ expect(validationRules.n_parallel(1)).toBe(true)
+ expect(validationRules.n_parallel(0)).toBe(true)
+ expect(validationRules.n_parallel(-1)).toBe(false)
+ expect(validationRules.n_parallel('2')).toBe(false)
+ })
+
+ it('should validate cpu_threads correctly', () => {
+ expect(validationRules.cpu_threads(4)).toBe(true)
+ expect(validationRules.cpu_threads(1)).toBe(true)
+ expect(validationRules.cpu_threads(0)).toBe(true)
+ expect(validationRules.cpu_threads(-1)).toBe(false)
+ expect(validationRules.cpu_threads('4')).toBe(false)
+ })
+
+ it('should validate prompt_template correctly', () => {
+ expect(validationRules.prompt_template('template')).toBe(true)
+ expect(validationRules.prompt_template('')).toBe(true)
+ expect(validationRules.prompt_template(123)).toBe(false)
+ })
+
+ it('should validate llama_model_path correctly', () => {
+ expect(validationRules.llama_model_path('path')).toBe(true)
+ expect(validationRules.llama_model_path('')).toBe(true)
+ expect(validationRules.llama_model_path(123)).toBe(false)
+ })
+
+ it('should validate mmproj correctly', () => {
+ expect(validationRules.mmproj('mmproj')).toBe(true)
+ expect(validationRules.mmproj('')).toBe(true)
+ expect(validationRules.mmproj(123)).toBe(false)
+ })
+
+ it('should validate vision_model correctly', () => {
+ expect(validationRules.vision_model(true)).toBe(true)
+ expect(validationRules.vision_model(false)).toBe(true)
+ expect(validationRules.vision_model('true')).toBe(false)
+ expect(validationRules.vision_model(1)).toBe(false)
+ })
+
+ it('should validate text_model correctly', () => {
+ expect(validationRules.text_model(true)).toBe(true)
+ expect(validationRules.text_model(false)).toBe(true)
+ expect(validationRules.text_model('true')).toBe(false)
+ expect(validationRules.text_model(1)).toBe(false)
+ })
+})
+
+describe('normalizeValue', () => {
+ it('should normalize ctx_len correctly', () => {
+ expect(normalizeValue('ctx_len', 100.5)).toBe(100)
+ expect(normalizeValue('ctx_len', '2')).toBe(2)
+ expect(normalizeValue('ctx_len', 100)).toBe(100)
+ })
+ it('should normalize token_limit correctly', () => {
+ expect(normalizeValue('token_limit', 100.5)).toBe(100)
+ expect(normalizeValue('token_limit', '1')).toBe(1)
+ expect(normalizeValue('token_limit', 0)).toBe(0)
+ })
+ it('should normalize max_tokens correctly', () => {
+ expect(normalizeValue('max_tokens', 100.5)).toBe(100)
+ expect(normalizeValue('max_tokens', '1')).toBe(1)
+ expect(normalizeValue('max_tokens', 0)).toBe(0)
+ })
+ it('should normalize ngl correctly', () => {
+ expect(normalizeValue('ngl', 12.5)).toBe(12)
+ expect(normalizeValue('ngl', '2')).toBe(2)
+ expect(normalizeValue('ngl', 0)).toBe(0)
+ })
+ it('should normalize n_parallel correctly', () => {
+ expect(normalizeValue('n_parallel', 2.5)).toBe(2)
+ expect(normalizeValue('n_parallel', '2')).toBe(2)
+ expect(normalizeValue('n_parallel', 0)).toBe(0)
+ })
+ it('should normalize cpu_threads correctly', () => {
+ expect(normalizeValue('cpu_threads', 4.5)).toBe(4)
+ expect(normalizeValue('cpu_threads', '4')).toBe(4)
+ expect(normalizeValue('cpu_threads', 0)).toBe(0)
+ })
+})
diff --git a/web/utils/modelParam.ts b/web/utils/modelParam.ts
index a6d144c3e..dda9cf761 100644
--- a/web/utils/modelParam.ts
+++ b/web/utils/modelParam.ts
@@ -1,9 +1,69 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
+/* eslint-disable @typescript-eslint/naming-convention */
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
import { ModelParams } from '@/helpers/atoms/Thread.atom'
-export const toRuntimeParams = (
- modelParams?: ModelParams
+/**
+ * Validation rules for model parameters
+ */
+export const validationRules: { [key: string]: (value: any) => boolean } = {
+ temperature: (value: any) =>
+ typeof value === 'number' && value >= 0 && value <= 2,
+ token_limit: (value: any) => Number.isInteger(value) && value >= 0,
+ top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
+ top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
+ stream: (value: any) => typeof value === 'boolean',
+ max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
+ stop: (value: any) =>
+ Array.isArray(value) && value.every((v) => typeof v === 'string'),
+ frequency_penalty: (value: any) =>
+ typeof value === 'number' && value >= 0 && value <= 1,
+ presence_penalty: (value: any) =>
+ typeof value === 'number' && value >= 0 && value <= 1,
+
+ ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
+ ngl: (value: any) => Number.isInteger(value) && value >= 0,
+ embedding: (value: any) => typeof value === 'boolean',
+ n_parallel: (value: any) => Number.isInteger(value) && value >= 0,
+ cpu_threads: (value: any) => Number.isInteger(value) && value >= 0,
+ prompt_template: (value: any) => typeof value === 'string',
+ llama_model_path: (value: any) => typeof value === 'string',
+ mmproj: (value: any) => typeof value === 'string',
+ vision_model: (value: any) => typeof value === 'boolean',
+ text_model: (value: any) => typeof value === 'boolean',
+}
+
+/**
+ * Some parameters must be normalized before being sent to the server.
+ * E.g. ctx_len must be an integer, but the input field may yield a float.
+ * @param key
+ * @param value
+ * @returns
+ */
+export const normalizeValue = (key: string, value: any) => {
+ if (
+ key === 'token_limit' ||
+ key === 'max_tokens' ||
+ key === 'ctx_len' ||
+ key === 'ngl' ||
+ key === 'n_parallel' ||
+ key === 'cpu_threads'
+ ) {
+ // Convert to integer
+ return Math.floor(Number(value))
+ }
+ return value
+}
+
+/**
+ * Extract inference parameters from flat model parameters
+ * @param modelParams
+ * @returns
+ */
+export const extractInferenceParams = (
+ modelParams?: ModelParams,
+ originParams?: ModelParams
): ModelRuntimeParams => {
if (!modelParams) return {}
const defaultModelParams: ModelRuntimeParams = {
@@ -22,15 +82,35 @@ export const toRuntimeParams = (
for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultModelParams) {
- Object.assign(runtimeParams, { ...runtimeParams, [key]: value })
+ const validate = validationRules[key]
+ if (validate && !validate(normalizeValue(key, value))) {
+ // Invalid value - fall back to the origin value
+ if (originParams && key in originParams) {
+ Object.assign(runtimeParams, {
+ ...runtimeParams,
+ [key]: originParams[key as keyof typeof originParams],
+ })
+ }
+ } else {
+ Object.assign(runtimeParams, {
+ ...runtimeParams,
+ [key]: normalizeValue(key, value),
+ })
+ }
}
}
return runtimeParams
}
-export const toSettingParams = (
- modelParams?: ModelParams
+/**
+ * Extract model load parameters from flat model parameters
+ * @param modelParams
+ * @returns
+ */
+export const extractModelLoadParams = (
+ modelParams?: ModelParams,
+ originParams?: ModelParams
): ModelSettingParams => {
if (!modelParams) return {}
const defaultSettingParams: ModelSettingParams = {
@@ -49,7 +129,21 @@ export const toSettingParams = (
for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultSettingParams) {
- Object.assign(settingParams, { ...settingParams, [key]: value })
+ const validate = validationRules[key]
+ if (validate && !validate(normalizeValue(key, value))) {
+ // Invalid value - fall back to the origin value
+ if (originParams && key in originParams) {
+ Object.assign(settingParams, {
+ ...settingParams,
+ [key]: originParams[key as keyof typeof originParams],
+ })
+ }
+ } else {
+ Object.assign(settingParams, {
+ ...settingParams,
+ [key]: normalizeValue(key, value),
+ })
+ }
}
}
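Taken together, the rules, the normalizer, and the fallback path behave roughly as follows (illustrative values; `origin` stands in for the previously persisted parameters):

```ts
import {
  extractInferenceParams,
  normalizeValue,
  validationRules,
} from '@/utils/modelParam'

validationRules.temperature(2.5) // false - outside [0, 2]
normalizeValue('ctx_len', 100.5) // 100 - coerced to an integer

const origin = { temperature: 0.7 }
// '2.5' is a string, so the temperature rule fails and the
// origin value is kept instead of the invalid input.
extractInferenceParams({ temperature: '2.5' as any, stream: true }, origin)
// => { temperature: 0.7, stream: true }
```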
From 670013baa037003f82c29607345446a03dc07c0c Mon Sep 17 00:00:00 2001
From: Ronnie Ghose <1313566+RONNCC@users.noreply.github.com>
Date: Mon, 16 Sep 2024 19:25:08 -0700
Subject: [PATCH 2/7] Add support for 'o1-preview' and 'o1-mini' models (#3659)
Add support for 'o1-preview' and 'o1-mini' model names in the OpenAI API.
* **Update `models.json`**:
- Add 'o1-preview' model details with appropriate parameters and metadata.
- Add 'o1-mini' model details with appropriate parameters and metadata.
---
For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/janhq/jan?shareId=XXXX-XXXX-XXXX-XXXX).
---
.../resources/models.json | 60 +++++++++++++++++++
1 file changed, 60 insertions(+)
diff --git a/extensions/inference-openai-extension/resources/models.json b/extensions/inference-openai-extension/resources/models.json
index 6852a1892..72517d540 100644
--- a/extensions/inference-openai-extension/resources/models.json
+++ b/extensions/inference-openai-extension/resources/models.json
@@ -119,5 +119,65 @@
]
},
"engine": "openai"
+ },
+ {
+ "sources": [
+ {
+ "url": "https://openai.com"
+ }
+ ],
+ "id": "o1-preview",
+ "object": "model",
+ "name": "OpenAI o1-preview",
+ "version": "1.0",
+ "description": "OpenAI o1-preview is a new model with complex reasoning",
+ "format": "api",
+ "settings": {},
+ "parameters": {
+ "max_tokens": 4096,
+ "temperature": 0.7,
+ "top_p": 0.95,
+ "stream": true,
+ "stop": [],
+ "frequency_penalty": 0,
+ "presence_penalty": 0
+ },
+ "metadata": {
+ "author": "OpenAI",
+ "tags": [
+ "General"
+ ]
+ },
+ "engine": "openai"
+ },
+ {
+ "sources": [
+ {
+ "url": "https://openai.com"
+ }
+ ],
+ "id": "o1-mini",
+ "object": "model",
+ "name": "OpenAI o1-mini",
+ "version": "1.0",
+ "description": "OpenAI o1-mini is a lightweight reasoning model",
+ "format": "api",
+ "settings": {},
+ "parameters": {
+ "max_tokens": 4096,
+ "temperature": 0.7,
+ "top_p": 0.95,
+ "stream": true,
+ "stop": [],
+ "frequency_penalty": 0,
+ "presence_penalty": 0
+ },
+ "metadata": {
+ "author": "OpenAI",
+ "tags": [
+ "General"
+ ]
+ },
+ "engine": "openai"
}
]
From c8a08f11155a64a2789a233f1518a0df041df623 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 17 Sep 2024 09:25:55 +0700
Subject: [PATCH 3/7] fix: correct prompt template for Phi3 Medium model
(#3670)
---
extensions/inference-nitro-extension/package.json | 2 +-
.../resources/models/phi3-medium/model.json | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/extensions/inference-nitro-extension/package.json b/extensions/inference-nitro-extension/package.json
index 425e4b49c..ac3ed180a 100644
--- a/extensions/inference-nitro-extension/package.json
+++ b/extensions/inference-nitro-extension/package.json
@@ -1,7 +1,7 @@
{
"name": "@janhq/inference-cortex-extension",
"productName": "Cortex Inference Engine",
- "version": "1.0.16",
+ "version": "1.0.17",
"description": "This extension embeds cortex.cpp, a lightweight inference engine written in C++. See https://jan.ai.\nAdditional dependencies could be installed to run without Cuda Toolkit installation.",
"main": "dist/index.js",
"node": "dist/node/index.cjs.js",
diff --git a/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json b/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json
index 50944b9fe..7331b2fd8 100644
--- a/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json
+++ b/extensions/inference-nitro-extension/resources/models/phi3-medium/model.json
@@ -8,12 +8,12 @@
"id": "phi3-medium",
"object": "model",
"name": "Phi-3 Medium Instruct Q4",
- "version": "1.3",
+ "version": "1.4",
"description": "Phi-3 Medium is Microsoft's latest SOTA model.",
"format": "gguf",
"settings": {
"ctx_len": 128000,
- "prompt_template": "<|user|> {prompt}<|end|><|assistant|><|end|>",
+ "prompt_template": "<|user|> {prompt}<|end|><|assistant|>",
"llama_model_path": "Phi-3-medium-128k-instruct-Q4_K_M.gguf",
"ngl": 33
},
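The fix drops the stray `<|end|>` that followed `<|assistant|>`. A sketch of why that matters, assuming the usual placeholder substitution (the helper below is illustrative, not the extension's actual code):

```ts
const oldTemplate = '<|user|> {prompt}<|end|><|assistant|><|end|>'
const newTemplate = '<|user|> {prompt}<|end|><|assistant|>'

const apply = (template: string, prompt: string) =>
  template.replace('{prompt}', prompt)

apply(newTemplate, 'Hello')
// => '<|user|> Hello<|end|><|assistant|>'
// The old template emitted '<|end|>' immediately after the assistant
// tag, closing the assistant turn before the model generated anything.
```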
From c3cb1924866e3b94c868d9eacc5c43190a955452 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 17 Sep 2024 16:09:38 +0700
Subject: [PATCH 4/7] fix: #3667 - The recommended label should be hidden
(#3687)
---
web/containers/ModelLabel/ModelLabel.test.tsx | 100 ++++++++++++++++++
web/containers/ModelLabel/index.tsx | 6 +-
2 files changed, 101 insertions(+), 5 deletions(-)
create mode 100644 web/containers/ModelLabel/ModelLabel.test.tsx
diff --git a/web/containers/ModelLabel/ModelLabel.test.tsx b/web/containers/ModelLabel/ModelLabel.test.tsx
new file mode 100644
index 000000000..48504ff6a
--- /dev/null
+++ b/web/containers/ModelLabel/ModelLabel.test.tsx
@@ -0,0 +1,100 @@
+import React from 'react'
+import { render, waitFor, screen } from '@testing-library/react'
+import { useAtomValue } from 'jotai'
+import { useActiveModel } from '@/hooks/useActiveModel'
+import { useSettings } from '@/hooks/useSettings'
+import ModelLabel from '@/containers/ModelLabel'
+
+jest.mock('jotai', () => ({
+ useAtomValue: jest.fn(),
+ atom: jest.fn(),
+}))
+
+jest.mock('@/hooks/useActiveModel', () => ({
+ useActiveModel: jest.fn(),
+}))
+
+jest.mock('@/hooks/useSettings', () => ({
+ useSettings: jest.fn(),
+}))
+
+describe('ModelLabel', () => {
+ const mockUseAtomValue = useAtomValue as jest.Mock
+ const mockUseActiveModel = useActiveModel as jest.Mock
+ const mockUseSettings = useSettings as jest.Mock
+
+ const defaultProps: any = {
+ metadata: {
+ author: 'John Doe', // required 'author' metadata field
+ tags: ['8B'],
+ size: 100,
+ },
+ compact: false,
+ }
+
+ beforeEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('renders NotEnoughMemoryLabel when minimumRamModel is greater than totalRam', async () => {
+ mockUseAtomValue
+ .mockReturnValueOnce(0)
+ .mockReturnValueOnce(0)
+ .mockReturnValueOnce(0)
+ mockUseActiveModel.mockReturnValue({
+ activeModel: { metadata: { size: 0 } },
+ })
+ mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
+
+ render(<ModelLabel {...defaultProps} />)
+ await waitFor(() => {
+ expect(screen.getByText('Not enough RAM')).toBeDefined()
+ })
+ })
+
+ it('renders SlowOnYourDeviceLabel when minimumRamModel is less than totalRam but greater than availableRam', async () => {
+ mockUseAtomValue
+ .mockReturnValueOnce(100)
+ .mockReturnValueOnce(50)
+ .mockReturnValueOnce(10)
+ mockUseActiveModel.mockReturnValue({
+ activeModel: { metadata: { size: 0 } },
+ })
+ mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
+
+ const props = {
+ ...defaultProps,
+ metadata: {
+ ...defaultProps.metadata,
+ size: 50,
+ },
+ }
+
+ render(<ModelLabel {...props} />)
+ await waitFor(() => {
+ expect(screen.getByText('Slow on your device')).toBeDefined()
+ })
+ })
+
+ it('renders nothing when minimumRamModel is less than availableRam', () => {
+ mockUseAtomValue
+ .mockReturnValueOnce(100)
+ .mockReturnValueOnce(50)
+ .mockReturnValueOnce(0)
+ mockUseActiveModel.mockReturnValue({
+ activeModel: { metadata: { size: 0 } },
+ })
+ mockUseSettings.mockReturnValue({ settings: { run_mode: 'cpu' } })
+
+ const props = {
+ ...defaultProps,
+ metadata: {
+ ...defaultProps.metadata,
+ size: 10,
+ },
+ }
+
+ const { container } = render(<ModelLabel {...props} />)
+ expect(container.firstChild).toBeNull()
+ })
+})
diff --git a/web/containers/ModelLabel/index.tsx b/web/containers/ModelLabel/index.tsx
index 2c32e288c..b0a3da96f 100644
--- a/web/containers/ModelLabel/index.tsx
+++ b/web/containers/ModelLabel/index.tsx
@@ -10,8 +10,6 @@ import { useSettings } from '@/hooks/useSettings'
import NotEnoughMemoryLabel from './NotEnoughMemoryLabel'
-import RecommendedLabel from './RecommendedLabel'
-
import SlowOnYourDeviceLabel from './SlowOnYourDeviceLabel'
import {
@@ -53,9 +51,7 @@ const ModelLabel = ({ metadata, compact }: Props) => {
/>
)
}
- if (minimumRamModel < availableRam && !compact) {
- return <RecommendedLabel />
- }
+
if (minimumRamModel < totalRam && minimumRamModel > availableRam) {
return <SlowOnYourDeviceLabel compact={compact} />
}
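With the recommended branch removed, the label selection reduces to the following (a simplified paraphrase of the component logic, with branch conditions taken from the tests above):

```ts
function labelFor(
  minimumRamModel: number,
  totalRam: number,
  availableRam: number
): string | null {
  if (minimumRamModel > totalRam) return 'NotEnoughMemoryLabel'
  if (minimumRamModel < totalRam && minimumRamModel > availableRam) {
    return 'SlowOnYourDeviceLabel'
  }
  return null // the RecommendedLabel is no longer rendered
}
```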
From 8e603bd5dbb80ef3050e313a0b046101ac81cc03 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 17 Sep 2024 16:43:47 +0700
Subject: [PATCH 5/7] fix: #3476 - Mismatch id between model json and path
(#3645)
* fix: mismatch between model json and path
* chore: revert preserve model settings
* test: add tests
---
.gitignore | 1 +
core/src/browser/core.test.ts | 179 +++---
core/src/browser/core.ts | 8 +
.../browser/extensions/engines/AIEngine.ts | 6 +-
.../extensions/engines/LocalOAIEngine.ts | 16 +-
core/src/browser/extensions/model.ts | 10 +-
core/src/node/api/processors/app.test.ts | 75 ++-
core/src/node/api/processors/app.ts | 16 +-
core/src/types/api/index.ts | 1 +
core/src/types/file/index.ts | 15 +
core/src/types/model/modelEntity.ts | 7 +
core/src/types/model/modelInterface.ts | 11 +-
.../inference-nitro-extension/src/index.ts | 3 +-
.../src/node/index.ts | 4 +-
extensions/model-extension/jest.config.js | 9 +
extensions/model-extension/package.json | 1 +
extensions/model-extension/rollup.config.ts | 4 +-
extensions/model-extension/src/index.test.ts | 564 ++++++++++++++++++
extensions/model-extension/src/index.ts | 85 ++-
extensions/model-extension/tsconfig.json | 3 +-
.../tensorrt-llm-extension/src/index.ts | 3 +-
web/containers/ModelDropdown/index.tsx | 22 +-
web/helpers/atoms/AppConfig.atom.ts | 7 -
web/helpers/atoms/Model.atom.ts | 19 +-
web/hooks/useActiveModel.ts | 6 +-
web/hooks/useCreateNewThread.ts | 20 +-
web/hooks/useDeleteModel.ts | 12 +-
web/hooks/useModels.ts | 5 +-
web/hooks/useRecommendedModel.ts | 12 +-
web/hooks/useUpdateModelParameters.ts | 60 +-
.../Hub/ModelList/ModelHeader/index.tsx | 4 +-
web/screens/Hub/ModelList/ModelItem/index.tsx | 4 +-
web/screens/Hub/ModelList/index.tsx | 14 +-
.../ModelDownloadRow/index.tsx | 8 +-
.../Settings/MyModels/MyModelList/index.tsx | 4 +-
35 files changed, 879 insertions(+), 339 deletions(-)
create mode 100644 extensions/model-extension/jest.config.js
create mode 100644 extensions/model-extension/src/index.test.ts
diff --git a/.gitignore b/.gitignore
index 646e6842a..eaee28a62 100644
--- a/.gitignore
+++ b/.gitignore
@@ -45,3 +45,4 @@ core/test_results.html
coverage
.yarn
.yarnrc
+*.tsbuildinfo
diff --git a/core/src/browser/core.test.ts b/core/src/browser/core.test.ts
index 84250888e..f38cc0b40 100644
--- a/core/src/browser/core.test.ts
+++ b/core/src/browser/core.test.ts
@@ -1,98 +1,109 @@
-import { openExternalUrl } from './core';
-import { joinPath } from './core';
-import { openFileExplorer } from './core';
-import { getJanDataFolderPath } from './core';
-import { abortDownload } from './core';
-import { getFileSize } from './core';
-import { executeOnMain } from './core';
+import { openExternalUrl } from './core'
+import { joinPath } from './core'
+import { openFileExplorer } from './core'
+import { getJanDataFolderPath } from './core'
+import { abortDownload } from './core'
+import { getFileSize } from './core'
+import { executeOnMain } from './core'
-it('should open external url', async () => {
- const url = 'http://example.com';
- globalThis.core = {
- api: {
- openExternalUrl: jest.fn().mockResolvedValue('opened')
+describe('test core apis', () => {
+ it('should open external url', async () => {
+ const url = 'http://example.com'
+ globalThis.core = {
+ api: {
+ openExternalUrl: jest.fn().mockResolvedValue('opened'),
+ },
}
- };
- const result = await openExternalUrl(url);
- expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url);
- expect(result).toBe('opened');
-});
+ const result = await openExternalUrl(url)
+ expect(globalThis.core.api.openExternalUrl).toHaveBeenCalledWith(url)
+ expect(result).toBe('opened')
+ })
-
-it('should join paths', async () => {
- const paths = ['/path/one', '/path/two'];
- globalThis.core = {
- api: {
- joinPath: jest.fn().mockResolvedValue('/path/one/path/two')
+ it('should join paths', async () => {
+ const paths = ['/path/one', '/path/two']
+ globalThis.core = {
+ api: {
+ joinPath: jest.fn().mockResolvedValue('/path/one/path/two'),
+ },
}
- };
- const result = await joinPath(paths);
- expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths);
- expect(result).toBe('/path/one/path/two');
-});
+ const result = await joinPath(paths)
+ expect(globalThis.core.api.joinPath).toHaveBeenCalledWith(paths)
+ expect(result).toBe('/path/one/path/two')
+ })
-
-it('should open file explorer', async () => {
- const path = '/path/to/open';
- globalThis.core = {
- api: {
- openFileExplorer: jest.fn().mockResolvedValue('opened')
+ it('should open file explorer', async () => {
+ const path = '/path/to/open'
+ globalThis.core = {
+ api: {
+ openFileExplorer: jest.fn().mockResolvedValue('opened'),
+ },
}
- };
- const result = await openFileExplorer(path);
- expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path);
- expect(result).toBe('opened');
-});
+ const result = await openFileExplorer(path)
+ expect(globalThis.core.api.openFileExplorer).toHaveBeenCalledWith(path)
+ expect(result).toBe('opened')
+ })
-
-it('should get jan data folder path', async () => {
- globalThis.core = {
- api: {
- getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data')
+ it('should get jan data folder path', async () => {
+ globalThis.core = {
+ api: {
+ getJanDataFolderPath: jest.fn().mockResolvedValue('/path/to/jan/data'),
+ },
}
- };
- const result = await getJanDataFolderPath();
- expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled();
- expect(result).toBe('/path/to/jan/data');
-});
+ const result = await getJanDataFolderPath()
+ expect(globalThis.core.api.getJanDataFolderPath).toHaveBeenCalled()
+ expect(result).toBe('/path/to/jan/data')
+ })
-
-it('should abort download', async () => {
- const fileName = 'testFile';
- globalThis.core = {
- api: {
- abortDownload: jest.fn().mockResolvedValue('aborted')
+ it('should abort download', async () => {
+ const fileName = 'testFile'
+ globalThis.core = {
+ api: {
+ abortDownload: jest.fn().mockResolvedValue('aborted'),
+ },
}
- };
- const result = await abortDownload(fileName);
- expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName);
- expect(result).toBe('aborted');
-});
+ const result = await abortDownload(fileName)
+ expect(globalThis.core.api.abortDownload).toHaveBeenCalledWith(fileName)
+ expect(result).toBe('aborted')
+ })
-
-it('should get file size', async () => {
- const url = 'http://example.com/file';
- globalThis.core = {
- api: {
- getFileSize: jest.fn().mockResolvedValue(1024)
+ it('should get file size', async () => {
+ const url = 'http://example.com/file'
+ globalThis.core = {
+ api: {
+ getFileSize: jest.fn().mockResolvedValue(1024),
+ },
}
- };
- const result = await getFileSize(url);
- expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url);
- expect(result).toBe(1024);
-});
+ const result = await getFileSize(url)
+ expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url)
+ expect(result).toBe(1024)
+ })
-
-it('should execute function on main process', async () => {
- const extension = 'testExtension';
- const method = 'testMethod';
- const args = ['arg1', 'arg2'];
- globalThis.core = {
- api: {
- invokeExtensionFunc: jest.fn().mockResolvedValue('result')
+ it('should execute function on main process', async () => {
+ const extension = 'testExtension'
+ const method = 'testMethod'
+ const args = ['arg1', 'arg2']
+ globalThis.core = {
+ api: {
+ invokeExtensionFunc: jest.fn().mockResolvedValue('result'),
+ },
}
- };
- const result = await executeOnMain(extension, method, ...args);
- expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args);
- expect(result).toBe('result');
-});
+ const result = await executeOnMain(extension, method, ...args)
+ expect(globalThis.core.api.invokeExtensionFunc).toHaveBeenCalledWith(extension, method, ...args)
+ expect(result).toBe('result')
+ })
+})
+
+describe('dirName - just a pass-through API', () => {
+ it('should retrieve the directory name from a file path', async () => {
+ const mockDirName = jest.fn()
+ globalThis.core = {
+ api: {
+ dirName: mockDirName.mockResolvedValue('/path/to'),
+ },
+ }
+ // Normal file path with extension
+ const path = '/path/to/file.txt'
+ await globalThis.core.api.dirName(path)
+ expect(mockDirName).toHaveBeenCalledWith(path)
+ })
+})
diff --git a/core/src/browser/core.ts b/core/src/browser/core.ts
index fdbceb06b..b19e0b339 100644
--- a/core/src/browser/core.ts
+++ b/core/src/browser/core.ts
@@ -68,6 +68,13 @@ const openFileExplorer: (path: string) => Promise<any> = (path) =>
const joinPath: (paths: string[]) => Promise<string> = (paths) =>
globalThis.core.api?.joinPath(paths)
+/**
+ * Get dirname of a file path.
+ * @param path - The file path to retrieve dirname.
+ * @returns {Promise<string>} A promise that resolves with the dirname.
+ */
+const dirName: (path: string) => Promise<string> = (path) => globalThis.core.api?.dirName(path)
+
/**
* Retrieve the basename from an url.
* @param path - The path to retrieve.
@@ -161,5 +168,6 @@ export {
systemInformation,
showToast,
getFileSize,
+ dirName,
FileStat,
}
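A hypothetical renderer-side use of the new helper; resolution of `file:` paths against the Jan data folder happens in the main process (see `app.ts` below):

```ts
import { dirName } from '@janhq/core'

async function resolveModelFolder(modelFilePath: string): Promise<string> {
  // e.g. 'file:/models/llama/model.gguf' -> '<jan-data-folder>/models/llama'
  return dirName(modelFilePath)
}
```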
diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts
index 7cd9f513e..75354de88 100644
--- a/core/src/browser/extensions/engines/AIEngine.ts
+++ b/core/src/browser/extensions/engines/AIEngine.ts
@@ -2,7 +2,7 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { fs } from '../../fs'
-import { MessageRequest, Model, ModelEvent } from '../../../types'
+import { MessageRequest, Model, ModelEvent, ModelFile } from '../../../types'
import { EngineManager } from './EngineManager'
/**
@@ -21,7 +21,7 @@ export abstract class AIEngine extends BaseExtension {
override onLoad() {
this.registerEngine()
- events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
+ events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}
@@ -78,7 +78,7 @@ export abstract class AIEngine extends BaseExtension {
/**
* Loads the model.
*/
- async loadModel(model: Model): Promise<any> {
+ async loadModel(model: ModelFile): Promise<any> {
if (model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts
index fb9e4962c..123b9a593 100644
--- a/core/src/browser/extensions/engines/LocalOAIEngine.ts
+++ b/core/src/browser/extensions/engines/LocalOAIEngine.ts
@@ -1,6 +1,6 @@
-import { executeOnMain, getJanDataFolderPath, joinPath, systemInformation } from '../../core'
+import { executeOnMain, systemInformation, dirName } from '../../core'
import { events } from '../../events'
-import { Model, ModelEvent } from '../../../types'
+import { Model, ModelEvent, ModelFile } from '../../../types'
import { OAIEngine } from './OAIEngine'
/**
@@ -14,22 +14,24 @@ export abstract class LocalOAIEngine extends OAIEngine {
unloadModelFunctionName: string = 'unloadModel'
/**
- * On extension load, subscribe to events.
+ * This class represents a base for local inference providers in the OpenAI architecture.
+ * It extends the OAIEngine class and provides the implementation of loading and unloading models locally.
+ * The loadModel function subscribes to the ModelEvent.OnModelInit event, loading models when initiated.
+ * The unloadModel function subscribes to the ModelEvent.OnModelStop event, unloading models when stopped.
*/
override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
- events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
+ events.on(ModelEvent.OnModelInit, (model: ModelFile) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}
/**
* Load the model.
*/
- override async loadModel(model: Model): Promise<void> {
+ override async loadModel(model: ModelFile): Promise<void> {
if (model.engine.toString() !== this.provider) return
- const modelFolderName = 'models'
- const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
+ const modelFolder = await dirName(model.file_path)
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,
diff --git a/core/src/browser/extensions/model.ts b/core/src/browser/extensions/model.ts
index 5b3089403..040542927 100644
--- a/core/src/browser/extensions/model.ts
+++ b/core/src/browser/extensions/model.ts
@@ -4,6 +4,7 @@ import {
HuggingFaceRepoData,
ImportingModel,
Model,
+ ModelFile,
ModelInterface,
OptionType,
} from '../../types'
@@ -25,12 +26,11 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
network?: { proxy: string; ignoreSSL?: boolean }
): Promise<void>
abstract cancelModelDownload(modelId: string): Promise<void>
- abstract deleteModel(modelId: string): Promise<void>
- abstract saveModel(model: Model): Promise<void>
- abstract getDownloadedModels(): Promise<Model[]>
- abstract getConfiguredModels(): Promise<Model[]>
+ abstract deleteModel(model: ModelFile): Promise<void>
+ abstract getDownloadedModels(): Promise<ModelFile[]>
+ abstract getConfiguredModels(): Promise<ModelFile[]>
abstract importModels(models: ImportingModel[], optionType: OptionType): Promise<void>
- abstract updateModelInfo(modelInfo: Partial<Model>): Promise<Model>
+ abstract updateModelInfo(modelInfo: Partial<ModelFile>): Promise<ModelFile>
abstract fetchHuggingFaceRepoData(repoId: string): Promise<HuggingFaceRepoData>
abstract getDefaultModel(): Promise<Model>
}
diff --git a/core/src/node/api/processors/app.test.ts b/core/src/node/api/processors/app.test.ts
index 3ada5df1e..5c4daef29 100644
--- a/core/src/node/api/processors/app.test.ts
+++ b/core/src/node/api/processors/app.test.ts
@@ -1,40 +1,57 @@
-import { App } from './app';
+jest.mock('../../helper', () => ({
+ ...jest.requireActual('../../helper'),
+ getJanDataFolderPath: () => './app',
+}))
+import { dirname } from 'path'
+import { App } from './app'
it('should call stopServer', () => {
- const app = new App();
- const stopServerMock = jest.fn().mockResolvedValue('Server stopped');
+ const app = new App()
+ const stopServerMock = jest.fn().mockResolvedValue('Server stopped')
jest.mock('@janhq/server', () => ({
- stopServer: stopServerMock
- }));
- const result = app.stopServer();
- expect(stopServerMock).toHaveBeenCalled();
-});
+ stopServer: stopServerMock,
+ }))
+ app.stopServer()
+ expect(stopServerMock).toHaveBeenCalled()
+})
it('should correctly retrieve basename', () => {
- const app = new App();
- const result = app.baseName('/path/to/file.txt');
- expect(result).toBe('file.txt');
-});
+ const app = new App()
+ const result = app.baseName('/path/to/file.txt')
+ expect(result).toBe('file.txt')
+})
it('should correctly identify subdirectories', () => {
- const app = new App();
- const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to';
- const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir';
- const result = app.isSubdirectory(basePath, subPath);
- expect(result).toBe(true);
-});
+ const app = new App()
+ const basePath = process.platform === 'win32' ? 'C:\\path\\to' : '/path/to'
+ const subPath = process.platform === 'win32' ? 'C:\\path\\to\\subdir' : '/path/to/subdir'
+ const result = app.isSubdirectory(basePath, subPath)
+ expect(result).toBe(true)
+})
it('should correctly join multiple paths', () => {
- const app = new App();
- const result = app.joinPath(['path', 'to', 'file']);
- const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file';
- expect(result).toBe(expectedPath);
-});
+ const app = new App()
+ const result = app.joinPath(['path', 'to', 'file'])
+ const expectedPath = process.platform === 'win32' ? 'path\\to\\file' : 'path/to/file'
+ expect(result).toBe(expectedPath)
+})
it('should call correct function with provided arguments using process method', () => {
- const app = new App();
- const mockFunc = jest.fn();
- app.joinPath = mockFunc;
- app.process('joinPath', ['path1', 'path2']);
- expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2']);
-});
+ const app = new App()
+ const mockFunc = jest.fn()
+ app.joinPath = mockFunc
+ app.process('joinPath', ['path1', 'path2'])
+ expect(mockFunc).toHaveBeenCalledWith(['path1', 'path2'])
+})
+
+it('should retrieve the directory name from a file path (Unix/Windows)', async () => {
+ const app = new App()
+ const path = 'C:/Users/John Doe/Desktop/file.txt'
+ expect(await app.dirName(path)).toBe('C:/Users/John Doe/Desktop')
+})
+
+it('should retrieve the directory name when using file protocol', async () => {
+ const app = new App()
+ const path = 'file:/models/file.txt'
+ expect(await app.dirName(path)).toBe(process.platform === 'win32' ? 'app\\models' : 'app/models')
+})
diff --git a/core/src/node/api/processors/app.ts b/core/src/node/api/processors/app.ts
index 15460ba56..a0808c5ac 100644
--- a/core/src/node/api/processors/app.ts
+++ b/core/src/node/api/processors/app.ts
@@ -1,4 +1,4 @@
-import { basename, isAbsolute, join, relative } from 'path'
+import { basename, dirname, isAbsolute, join, relative } from 'path'
import { Processor } from './Processor'
import {
@@ -6,6 +6,8 @@ import {
appResourcePath,
getAppConfigurations as appConfiguration,
updateAppConfiguration,
+ normalizeFilePath,
+ getJanDataFolderPath,
} from '../../helper'
export class App implements Processor {
@@ -28,6 +30,18 @@ export class App implements Processor {
return join(...args)
}
+ /**
+ * Get dirname of a file path.
+ * @param path - The file path to retrieve dirname.
+ */
+ dirName(path: string) {
+ const arg =
+ path.startsWith(`file:/`) || path.startsWith(`file:\\`)
+ ? join(getJanDataFolderPath(), normalizeFilePath(path))
+ : path
+ return dirname(arg)
+ }
+
/**
* Checks if the given path is a subdirectory of the given directory.
*
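
Note: the new `dirName` treats `file:`-prefixed paths as relative to the Jan data folder and resolves everything else as-is, which is what the two tests above exercise. A minimal sketch of the expected behavior, assuming a data folder of `/home/user/jan`; the `normalizeFilePath` body here is an illustrative stand-in, not the real helper:

import { dirname, join } from 'path'

// Illustrative stand-ins for the helpers imported in app.ts.
const getJanDataFolderPath = () => '/home/user/jan'
const normalizeFilePath = (p: string) => p.replace(/^file:[\\/]+/, '')

function dirName(path: string) {
  // file:-prefixed paths are app-relative; resolve them under the data folder.
  const arg =
    path.startsWith('file:/') || path.startsWith('file:\\')
      ? join(getJanDataFolderPath(), normalizeFilePath(path))
      : path
  return dirname(arg)
}

console.log(dirName('file:/models/llama/model.json')) // /home/user/jan/models/llama
console.log(dirName('/tmp/abs/model.json'))           // /tmp/abs
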
diff --git a/core/src/types/api/index.ts b/core/src/types/api/index.ts
index bca11c0a8..8f1ff70bf 100644
--- a/core/src/types/api/index.ts
+++ b/core/src/types/api/index.ts
@@ -37,6 +37,7 @@ export enum AppRoute {
getAppConfigurations = 'getAppConfigurations',
updateAppConfiguration = 'updateAppConfiguration',
joinPath = 'joinPath',
+ dirName = 'dirName',
isSubdirectory = 'isSubdirectory',
baseName = 'baseName',
startServer = 'startServer',
diff --git a/core/src/types/file/index.ts b/core/src/types/file/index.ts
index 1b36a5777..4db956b1e 100644
--- a/core/src/types/file/index.ts
+++ b/core/src/types/file/index.ts
@@ -52,3 +52,18 @@ type DownloadSize = {
total: number
transferred: number
}
+
+/**
+ * The file metadata
+ */
+export type FileMetadata = {
+ /**
+ * The origin file path.
+ */
+ file_path: string
+
+ /**
+ * The file name.
+ */
+ file_name: string
+}
diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts
index f154f7f04..933c698c3 100644
--- a/core/src/types/model/modelEntity.ts
+++ b/core/src/types/model/modelEntity.ts
@@ -1,3 +1,5 @@
+import { FileMetadata } from '../file'
+
/**
* Represents the information about a model.
* @stored
@@ -151,3 +153,8 @@ export type ModelRuntimeParams = {
export type ModelInitFailed = Model & {
error: Error
}
+
+/**
+ * ModelFile is the model.json entity plus its file metadata
+ */
+export type ModelFile = Model & FileMetadata
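
Note: `ModelFile` is a plain intersection type, so any existing `Model` can be promoted by attaching the two metadata fields. A minimal sketch; the helper and the path value are made up for illustration:

import { Model, ModelFile } from '@janhq/core'

// Hypothetical: promote a parsed model.json entity to a ModelFile.
const toModelFile = (model: Model, jsonPath: string): ModelFile => ({
  ...model,
  file_path: jsonPath, // e.g. 'file://models/llama/model.json'
  file_name: 'model.json',
})
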
diff --git a/core/src/types/model/modelInterface.ts b/core/src/types/model/modelInterface.ts
index 639c7c8d3..5b5856231 100644
--- a/core/src/types/model/modelInterface.ts
+++ b/core/src/types/model/modelInterface.ts
@@ -1,5 +1,5 @@
import { GpuSetting } from '../miscellaneous'
-import { Model } from './modelEntity'
+import { Model, ModelFile } from './modelEntity'
/**
* Model extension for managing models.
@@ -29,14 +29,7 @@ export interface ModelInterface {
* @param modelId - The ID of the model to delete.
* @returns A Promise that resolves when the model has been deleted.
*/
- deleteModel(modelId: string): Promise<void>
-
- /**
- * Saves a model.
- * @param model - The model to save.
- * @returns A Promise that resolves when the model has been saved.
- */
- saveModel(model: Model): Promise<void>
+ deleteModel(model: ModelFile): Promise<void>
/**
* Gets a list of downloaded models.
diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts
index d79e076d4..6e825e8fd 100644
--- a/extensions/inference-nitro-extension/src/index.ts
+++ b/extensions/inference-nitro-extension/src/index.ts
@@ -22,6 +22,7 @@ import {
downloadFile,
DownloadState,
DownloadEvent,
+ ModelFile,
} from '@janhq/core'
declare const CUDA_DOWNLOAD_URL: string
@@ -94,7 +95,7 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
this.nitroProcessInfo = health
}
- override loadModel(model: Model): Promise<void> {
+ override loadModel(model: ModelFile): Promise<void> {
if (model.engine !== this.provider) return Promise.resolve()
this.getNitroProcessHealthIntervalId = setInterval(
() => this.periodicallyGetNitroHealth(),
diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts
index 3a969ad5e..98ca4572f 100644
--- a/extensions/inference-nitro-extension/src/node/index.ts
+++ b/extensions/inference-nitro-extension/src/node/index.ts
@@ -6,12 +6,12 @@ import fetchRT from 'fetch-retry'
import {
log,
getSystemResourceInfo,
- Model,
InferenceEngine,
ModelSettingParams,
PromptTemplate,
SystemInformation,
getJanDataFolderPath,
+ ModelFile,
} from '@janhq/core/node'
import { executableNitroFile } from './execute'
import terminate from 'terminate'
@@ -25,7 +25,7 @@ const fetchRetry = fetchRT(fetch)
*/
interface ModelInitOptions {
modelFolder: string
- model: Model
+ model: ModelFile
}
// The PORT to use for the Nitro subprocess
const PORT = 3928
diff --git a/extensions/model-extension/jest.config.js b/extensions/model-extension/jest.config.js
new file mode 100644
index 000000000..3e32adceb
--- /dev/null
+++ b/extensions/model-extension/jest.config.js
@@ -0,0 +1,9 @@
+/** @type {import('ts-jest').JestConfigWithTsJest} */
+module.exports = {
+ preset: 'ts-jest',
+ testEnvironment: 'node',
+ transform: {
+ 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
+ },
+ transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
+}
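
Note: `transformIgnorePatterns` defaults to skipping everything under node_modules; the negative lookahead re-enables transformation for `@janhq/core` only, which the extra `transform` entry then routes through ts-jest. The regex reads as:

// Matches (and therefore ignores) any node_modules path that is NOT
// inside @janhq/core, so core itself still gets transformed.
const ignore = /node_modules\/(?!@janhq\/core\/.*)/
ignore.test('node_modules/lodash/index.js')      // true  -> ignored
ignore.test('node_modules/@janhq/core/src/x.ts') // false -> transformed
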
diff --git a/extensions/model-extension/package.json b/extensions/model-extension/package.json
index 4a2c61b71..9a406dcf4 100644
--- a/extensions/model-extension/package.json
+++ b/extensions/model-extension/package.json
@@ -8,6 +8,7 @@
"author": "Jan ",
"license": "AGPL-3.0",
"scripts": {
+ "test": "jest",
"build": "tsc --module commonjs && rollup -c rollup.config.ts --configPlugin @rollup/plugin-typescript --bundleConfigAsCjs",
"build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install"
},
diff --git a/extensions/model-extension/rollup.config.ts b/extensions/model-extension/rollup.config.ts
index c3f3acc77..d36d8ffac 100644
--- a/extensions/model-extension/rollup.config.ts
+++ b/extensions/model-extension/rollup.config.ts
@@ -27,7 +27,7 @@ export default [
// Allow json resolution
json(),
// Compile TypeScript files
- typescript({ useTsconfigDeclarationDir: true }),
+ typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
// Compile TypeScript files
// Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
// commonjs(),
@@ -62,7 +62,7 @@ export default [
// Allow json resolution
json(),
// Compile TypeScript files
- typescript({ useTsconfigDeclarationDir: true }),
+ typescript({ useTsconfigDeclarationDir: true, exclude: ['**/__tests__', '**/*.test.ts'], }),
// Allow bundling cjs modules (unlike webpack, rollup doesn't understand cjs)
commonjs(),
// Allow node_modules resolution, so you can use 'external' to control
diff --git a/extensions/model-extension/src/index.test.ts b/extensions/model-extension/src/index.test.ts
new file mode 100644
index 000000000..6816d7101
--- /dev/null
+++ b/extensions/model-extension/src/index.test.ts
@@ -0,0 +1,564 @@
+const readDirSyncMock = jest.fn()
+const existMock = jest.fn()
+const readFileSyncMock = jest.fn()
+
+jest.mock('@janhq/core', () => ({
+ ...jest.requireActual('@janhq/core/node'),
+ fs: {
+ existsSync: existMock,
+ readdirSync: readDirSyncMock,
+ readFileSync: readFileSyncMock,
+ fileStat: () => ({
+ isDirectory: false,
+ }),
+ },
+ dirName: jest.fn(),
+ joinPath: (paths) => paths.join('/'),
+ ModelExtension: jest.fn(),
+}))
+
+import JanModelExtension from '.'
+import { fs, dirName } from '@janhq/core'
+
+describe('JanModelExtension', () => {
+ let sut: JanModelExtension
+
+ beforeAll(() => {
+ // @ts-ignore
+ sut = new JanModelExtension()
+ })
+
+ afterEach(() => {
+ jest.clearAllMocks()
+ })
+
+ describe('getConfiguredModels', () => {
+ describe("when there's no models are pre-populated", () => {
+ it('should return empty array', async () => {
+ // Mock configured models data
+ const configuredModels = []
+ existMock.mockReturnValue(true)
+ readDirSyncMock.mockReturnValue([])
+
+ const result = await sut.getConfiguredModels()
+ expect(result).toEqual([])
+ })
+ })
+
+ describe("when there's are pre-populated models - all flattened", () => {
+ it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2']
+ else return ['model.json']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getConfiguredModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ expect.objectContaining({
+ file_path: 'file://models/model2/model.json',
+ id: '2',
+ }),
+ ])
+ )
+ })
+ })
+
+ describe("when there's are pre-populated models - there are nested folders", () => {
+ it('returns configured models data - flatten folder - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2/model2-1']
+ else return ['model.json']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else if (path.includes('model2/model2-1'))
+ return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getConfiguredModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ expect.objectContaining({
+ file_path: 'file://models/model2/model2-1/model.json',
+ id: '2',
+ }),
+ ])
+ )
+ })
+ })
+ })
+
+ describe('getDownloadedModels', () => {
+ describe('no models downloaded', () => {
+ it('should return empty array', async () => {
+ // Mock downloaded models data
+ const downloadedModels = []
+ existMock.mockReturnValue(true)
+ readDirSyncMock.mockReturnValue([])
+
+ const result = await sut.getDownloadedModels()
+ expect(result).toEqual([])
+ })
+ })
+ describe('only one model is downloaded', () => {
+ describe('flattened folder', () => {
+ it('returns downloaded models - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2']
+ else if (path === 'file://models/model1')
+ return ['model.json', 'test.gguf']
+ else return ['model.json']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getDownloadedModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ ])
+ )
+ })
+ })
+ })
+
+ describe('all models are downloaded', () => {
+ describe('nested folders', () => {
+ it('returns downloaded models - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2/model2-1']
+ else return ['model.json', 'test.gguf']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getDownloadedModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ expect.objectContaining({
+ file_path: 'file://models/model2/model2-1/model.json',
+ id: '2',
+ }),
+ ])
+ )
+ })
+ })
+ })
+
+ describe('all models are downloaded with uppercased GGUF files', () => {
+ it('returns downloaded models - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2/model2-1']
+ else if (path === 'file://models/model1')
+ return ['model.json', 'test.GGUF']
+ else return ['model.json', 'test.gguf']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getDownloadedModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ expect.objectContaining({
+ file_path: 'file://models/model2/model2-1/model.json',
+ id: '2',
+ }),
+ ])
+ )
+ })
+ })
+
+ describe('all models are downloaded - GGUF & TensorRT', () => {
+ it('returns downloaded models - with correct file_path and model id', async () => {
+ // Mock configured models data
+ const configuredModels = [
+ {
+ id: '1',
+ name: 'Model 1',
+ version: '1.0.0',
+ description: 'Model 1 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model1',
+ },
+ format: 'onnx',
+ sources: [],
+ created: new Date(),
+ updated: new Date(),
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ {
+ id: '2',
+ name: 'Model 2',
+ version: '2.0.0',
+ description: 'Model 2 description',
+ object: {
+ type: 'model',
+ uri: 'http://localhost:5000/models/model2',
+ },
+ format: 'onnx',
+ sources: [],
+ parameters: {},
+ settings: {},
+ metadata: {},
+ engine: 'test',
+ } as any,
+ ]
+ existMock.mockReturnValue(true)
+
+ readDirSyncMock.mockImplementation((path) => {
+ if (path === 'file://models') return ['model1', 'model2/model2-1']
+ else if (path === 'file://models/model1')
+ return ['model.json', 'test.gguf']
+ else return ['model.json', 'test.engine']
+ })
+
+ readFileSyncMock.mockImplementation((path) => {
+ if (path.includes('model1'))
+ return JSON.stringify(configuredModels[0])
+ else return JSON.stringify(configuredModels[1])
+ })
+
+ const result = await sut.getDownloadedModels()
+ expect(result).toEqual(
+ expect.arrayContaining([
+ expect.objectContaining({
+ file_path: 'file://models/model1/model.json',
+ id: '1',
+ }),
+ expect.objectContaining({
+ file_path: 'file://models/model2/model2-1/model.json',
+ id: '2',
+ }),
+ ])
+ )
+ })
+ })
+ })
+
+ describe('deleteModel', () => {
+ describe('model is a GGUF model', () => {
+ it('should delete the GGUF file', async () => {
+ fs.unlinkSync = jest.fn()
+ const dirMock = dirName as jest.Mock
+ dirMock.mockReturnValue('file://models/model1')
+
+ fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
+
+ readDirSyncMock.mockImplementation((path) => {
+ return ['model.json', 'test.gguf']
+ })
+
+ existMock.mockReturnValue(true)
+
+ await sut.deleteModel({
+ file_path: 'file://models/model1/model.json',
+ } as any)
+
+ expect(fs.unlinkSync).toHaveBeenCalledWith(
+ 'file://models/model1/test.gguf'
+ )
+ })
+
+ it('does not delete anything when no GGUF file is present', async () => {
+ fs.unlinkSync = jest.fn()
+ const dirMock = dirName as jest.Mock
+ dirMock.mockReturnValue('file://models/model1')
+
+ fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
+
+ readDirSyncMock.mockReturnValue(['model.json'])
+
+ existMock.mockReturnValue(true)
+
+ await sut.deleteModel({
+ file_path: 'file://models/model1/model.json',
+ } as any)
+
+ expect(fs.unlinkSync).toHaveBeenCalledTimes(0)
+ })
+
+ it('deletes an imported model', async () => {
+ fs.rm = jest.fn()
+ const dirMock = dirName as jest.Mock
+ dirMock.mockReturnValue('file://models/model1')
+
+ readDirSyncMock.mockReturnValue(['model.json', 'test.gguf'])
+
+ // MARK: This deletion logic is tricky to get right.
+ // Adding a test for now; we should align it with the legacy implementation.
+ fs.readFileSync = jest.fn().mockReturnValue(
+ JSON.stringify({
+ metadata: {
+ author: 'user',
+ },
+ })
+ )
+
+ existMock.mockReturnValue(true)
+
+ await sut.deleteModel({
+ file_path: 'file://models/model1/model.json',
+ } as any)
+
+ expect(fs.rm).toHaveBeenCalledWith('file://models/model1')
+ })
+
+ it('deletes TensorRT models', async () => {
+ fs.rm = jest.fn()
+ const dirMock = dirName as jest.Mock
+ dirMock.mockReturnValue('file://models/model1')
+
+ readDirSyncMock.mockReturnValue(['model.json', 'test.engine'])
+
+ fs.readFileSync = jest.fn().mockReturnValue(JSON.stringify({}))
+
+ existMock.mockReturnValue(true)
+
+ await sut.deleteModel({
+ file_path: 'file://models/model1/model.json',
+ } as any)
+
+ expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine')
+ })
+ })
+ })
+})
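
Note: the mock wiring at the top of this test file depends on evaluation order. `jest.mock` factories run lazily, on the first require of the mocked module, and ts-jest keeps require calls in source position, so the `*Mock` consts declared above the `import` lines are already initialized when the factory reads them. A minimal sketch of the same pattern, using node's `fs` as the mocked module:

// 1. Create the jest.fn() handles first.
const readFileSyncMock = jest.fn()

// 2. Register the factory; it does NOT run here.
jest.mock('fs', () => ({ readFileSync: readFileSyncMock }))

// 3. This import triggers the factory, which closes over the
//    already-initialized const above.
import { readFileSync } from 'fs'

it('routes reads through the mock', () => {
  readFileSyncMock.mockReturnValue('hello')
  expect(readFileSync('/any/path')).toBe('hello')
})
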
diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts
index e2f68a58c..ac9b06a09 100644
--- a/extensions/model-extension/src/index.ts
+++ b/extensions/model-extension/src/index.ts
@@ -22,6 +22,8 @@ import {
getFileSize,
AllQuantizations,
ModelEvent,
+ ModelFile,
+ dirName,
} from '@janhq/core'
import { extractFileName } from './helpers/path'
@@ -48,16 +50,7 @@ export default class JanModelExtension extends ModelExtension {
]
private static readonly _tensorRtEngineFormat = '.engine'
private static readonly _supportedGpuArch = ['ampere', 'ada']
- private static readonly _safetensorsRegexs = [
- /model\.safetensors$/,
- /model-[0-9]+-of-[0-9]+\.safetensors$/,
- ]
- private static readonly _pytorchRegexs = [
- /pytorch_model\.bin$/,
- /consolidated\.[0-9]+\.pth$/,
- /pytorch_model-[0-9]+-of-[0-9]+\.bin$/,
- /.*\.pt$/,
- ]
+
interrupted = false
/**
@@ -319,9 +312,9 @@ export default class JanModelExtension extends ModelExtension {
* @param filePath - The path to the model file to delete.
* @returns A Promise that resolves when the model is deleted.
*/
- async deleteModel(modelId: string): Promise<void> {
+ async deleteModel(model: ModelFile): Promise<void> {
try {
- const dirPath = await joinPath([JanModelExtension._homeDir, modelId])
+ const dirPath = await dirName(model.file_path)
const jsonFilePath = await joinPath([
dirPath,
JanModelExtension._modelMetadataFileName,
@@ -330,9 +323,11 @@ export default class JanModelExtension extends ModelExtension {
await this.readModelMetadata(jsonFilePath)
) as Model
+ // TODO: This check is fragile.
+ // It should probably depend on the model's sources instead.
const isUserImportModel =
modelInfo.metadata?.author?.toLowerCase() === 'user'
- if (isUserImportModel) {
+ if (isUserImportModel) {
// just delete the folder
return fs.rm(dirPath)
}
@@ -350,30 +345,11 @@ export default class JanModelExtension extends ModelExtension {
}
}
- /**
- * Saves a model file.
- * @param model - The model to save.
- * @returns A Promise that resolves when the model is saved.
- */
- async saveModel(model: Model): Promise {
- const jsonFilePath = await joinPath([
- JanModelExtension._homeDir,
- model.id,
- JanModelExtension._modelMetadataFileName,
- ])
-
- try {
- await fs.writeFileSync(jsonFilePath, JSON.stringify(model, null, 2))
- } catch (err) {
- console.error(err)
- }
- }
-
/**
* Gets all downloaded models.
* @returns A Promise that resolves with an array of all models.
*/
- async getDownloadedModels(): Promise<Model[]> {
+ async getDownloadedModels(): Promise<ModelFile[]> {
return await this.getModelsMetadata(
async (modelDir: string, model: Model) => {
if (!JanModelExtension._offlineInferenceEngine.includes(model.engine))
@@ -425,8 +401,10 @@ export default class JanModelExtension extends ModelExtension {
): Promise {
// try to find model.json recursively inside each folder
if (!(await fs.existsSync(folderFullPath))) return undefined
+
const files: string[] = await fs.readdirSync(folderFullPath)
if (files.length === 0) return undefined
+
if (files.includes(JanModelExtension._modelMetadataFileName)) {
return joinPath([
folderFullPath,
@@ -446,7 +424,7 @@ export default class JanModelExtension extends ModelExtension {
private async getModelsMetadata(
selector?: (path: string, model: Model) => Promise<boolean>
- ): Promise<Model[]> {
+ ): Promise<ModelFile[]> {
try {
if (!(await fs.existsSync(JanModelExtension._homeDir))) {
console.debug('Model folder not found')
@@ -469,6 +447,7 @@ export default class JanModelExtension extends ModelExtension {
JanModelExtension._homeDir,
dirName,
])
+
const jsonPath = await this.getModelJsonPath(folderFullPath)
if (await fs.existsSync(jsonPath)) {
@@ -486,6 +465,8 @@ export default class JanModelExtension extends ModelExtension {
},
]
}
+ model.file_path = jsonPath
+ model.file_name = JanModelExtension._modelMetadataFileName
if (selector && !(await selector?.(dirName, model))) {
return
@@ -506,7 +487,7 @@ export default class JanModelExtension extends ModelExtension {
typeof result.value === 'object'
? result.value
: JSON.parse(result.value)
- return model as Model
+ return model as ModelFile
} catch {
console.debug(`Unable to parse model metadata: ${result.value}`)
}
@@ -637,7 +618,7 @@ export default class JanModelExtension extends ModelExtension {
* Gets all available models.
* @returns A Promise that resolves with an array of all models.
*/
- async getConfiguredModels(): Promise<Model[]> {
+ async getConfiguredModels(): Promise<ModelFile[]> {
return this.getModelsMetadata()
}
@@ -669,7 +650,7 @@ export default class JanModelExtension extends ModelExtension {
modelBinaryPath: string,
modelFolderName: string,
modelFolderPath: string
- ): Promise<Model> {
+ ): Promise<ModelFile> {
const fileStats = await fs.fileStat(modelBinaryPath, true)
const binaryFileSize = fileStats.size
@@ -732,25 +713,21 @@ export default class JanModelExtension extends ModelExtension {
await fs.writeFileSync(modelFilePath, JSON.stringify(model, null, 2))
- return model
+ return {
+ ...model,
+ file_path: modelFilePath,
+ file_name: JanModelExtension._modelMetadataFileName,
+ }
}
- async updateModelInfo(modelInfo: Partial<Model>): Promise<Model> {
- const modelId = modelInfo.id
+ async updateModelInfo(modelInfo: Partial<ModelFile>): Promise<ModelFile> {
if (modelInfo.id == null) throw new Error('Model ID is required')
- const janDataFolderPath = await getJanDataFolderPath()
- const jsonFilePath = await joinPath([
- janDataFolderPath,
- 'models',
- modelId,
- JanModelExtension._modelMetadataFileName,
- ])
const model = JSON.parse(
- await this.readModelMetadata(jsonFilePath)
- ) as Model
+ await this.readModelMetadata(modelInfo.file_path)
+ ) as ModelFile
- const updatedModel: Model = {
+ const updatedModel: ModelFile = {
...model,
...modelInfo,
parameters: {
@@ -765,9 +742,15 @@ export default class JanModelExtension extends ModelExtension {
...model.metadata,
...modelInfo.metadata,
},
+ // Should not persist file_path & file_name
+ file_path: undefined,
+ file_name: undefined,
}
- await fs.writeFileSync(jsonFilePath, JSON.stringify(updatedModel, null, 2))
+ await fs.writeFileSync(
+ modelInfo.file_path,
+ JSON.stringify(updatedModel, null, 2)
+ )
return updatedModel
}
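
Note: blanking `file_path` and `file_name` with `undefined` works because `JSON.stringify` omits object properties whose value is `undefined`, so the runtime-only fields never reach the persisted model.json. A minimal illustration (field values are made up):

const updatedModel = {
  id: 'llama3.1-8b-instruct',
  ctx_len: 4096,
  // Runtime-only fields, blanked before persisting.
  file_path: undefined,
  file_name: undefined,
}

// Keys with undefined values are omitted from the output.
console.log(JSON.stringify(updatedModel, null, 2))
// {
//   "id": "llama3.1-8b-instruct",
//   "ctx_len": 4096
// }
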
diff --git a/extensions/model-extension/tsconfig.json b/extensions/model-extension/tsconfig.json
index addd8e127..0d3252934 100644
--- a/extensions/model-extension/tsconfig.json
+++ b/extensions/model-extension/tsconfig.json
@@ -10,5 +10,6 @@
"skipLibCheck": true,
"rootDir": "./src"
},
- "include": ["./src"]
+ "include": ["./src"],
+ "exclude": ["**/*.test.ts"]
}
diff --git a/extensions/tensorrt-llm-extension/src/index.ts b/extensions/tensorrt-llm-extension/src/index.ts
index 189abc706..7f68c43bd 100644
--- a/extensions/tensorrt-llm-extension/src/index.ts
+++ b/extensions/tensorrt-llm-extension/src/index.ts
@@ -23,6 +23,7 @@ import {
ModelEvent,
getJanDataFolderPath,
SystemInformation,
+ ModelFile,
} from '@janhq/core'
/**
@@ -137,7 +138,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
events.emit(ModelEvent.OnModelsUpdate, {})
}
- override async loadModel(model: Model): Promise<void> {
+ override async loadModel(model: ModelFile): Promise<void> {
if ((await this.installationState()) === 'Installed')
return super.loadModel(model)
diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx
index 92d8addd0..d8743ddce 100644
--- a/web/containers/ModelDropdown/index.tsx
+++ b/web/containers/ModelDropdown/index.tsx
@@ -46,7 +46,6 @@ import {
import { extensionManager } from '@/extension'
-import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
import {
configuredModelsAtom,
@@ -91,8 +90,6 @@ const ModelDropdown = ({
const featuredModel = configuredModels.filter((x) =>
x.metadata.tags.includes('Featured')
)
- const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
-
const { updateThreadMetadata } = useCreateNewThread()
useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
@@ -191,27 +188,14 @@ const ModelDropdown = ({
],
})
- // Default setting ctx_len for the model for a better onboarding experience
- // TODO: When Cortex support hardware instructions, we should remove this
- const defaultContextLength = preserveModelSettings
- ? model?.metadata?.default_ctx_len
- : 2048
- const defaultMaxTokens = preserveModelSettings
- ? model?.metadata?.default_max_tokens
- : 2048
const overriddenSettings =
- model?.settings.ctx_len && model.settings.ctx_len > 2048
- ? { ctx_len: defaultContextLength ?? 2048 }
- : {}
- const overriddenParameters =
- model?.parameters.max_tokens && model.parameters.max_tokens
- ? { max_tokens: defaultMaxTokens ?? 2048 }
+ model?.settings.ctx_len && model.settings.ctx_len > 4096
+ ? { ctx_len: 4096 }
: {}
const modelParams = {
...model?.parameters,
...model?.settings,
- ...overriddenParameters,
...overriddenSettings,
}
@@ -222,6 +206,7 @@ const ModelDropdown = ({
if (model)
updateModelParameter(activeThread, {
params: modelParams,
+ modelPath: model.file_path,
modelId: model.id,
engine: model.engine,
})
@@ -235,7 +220,6 @@ const ModelDropdown = ({
setThreadModelParams,
updateModelParameter,
updateThreadMetadata,
- preserveModelSettings,
]
)
diff --git a/web/helpers/atoms/AppConfig.atom.ts b/web/helpers/atoms/AppConfig.atom.ts
index e7b7efaec..f4acc7dc2 100644
--- a/web/helpers/atoms/AppConfig.atom.ts
+++ b/web/helpers/atoms/AppConfig.atom.ts
@@ -7,7 +7,6 @@ const VULKAN_ENABLED = 'vulkanEnabled'
const IGNORE_SSL = 'ignoreSSLFeature'
const HTTPS_PROXY_FEATURE = 'httpsProxyFeature'
const QUICK_ASK_ENABLED = 'quickAskEnabled'
-const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings'
export const janDataFolderPathAtom = atom('')
@@ -24,9 +23,3 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false)
export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false)
export const hostAtom = atom('http://localhost:1337/')
-
-// This feature is to allow user to cache model settings on thread creation
-export const preserveModelSettingsAtom = atomWithStorage(
- PRESERVE_MODEL_SETTINGS,
- false
-)
diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts
index 77b1bfa4e..d2d0ca9f4 100644
--- a/web/helpers/atoms/Model.atom.ts
+++ b/web/helpers/atoms/Model.atom.ts
@@ -1,4 +1,4 @@
-import { ImportingModel, Model, InferenceEngine } from '@janhq/core'
+import { ImportingModel, Model, InferenceEngine, ModelFile } from '@janhq/core'
import { atom } from 'jotai'
import { localEngines } from '@/utils/modelEngine'
@@ -32,18 +32,7 @@ export const removeDownloadingModelAtom = atom(
}
)
-export const downloadedModelsAtom = atom<Model[]>([])
-
-export const updateDownloadedModelAtom = atom(
- null,
- (get, set, updatedModel: Model) => {
- const models: Model[] = get(downloadedModelsAtom).map((c) =>
- c.id === updatedModel.id ? updatedModel : c
- )
-
- set(downloadedModelsAtom, models)
- }
-)
+export const downloadedModelsAtom = atom<ModelFile[]>([])
export const removeDownloadedModelAtom = atom(
null,
@@ -57,7 +46,7 @@ export const removeDownloadedModelAtom = atom(
}
)
-export const configuredModelsAtom = atom<Model[]>([])
+export const configuredModelsAtom = atom<ModelFile[]>([])
export const defaultModelAtom = atom(undefined)
@@ -144,6 +133,6 @@ export const updateImportingModelAtom = atom(
}
)
-export const selectedModelAtom = atom<Model | undefined>(undefined)
+export const selectedModelAtom = atom<ModelFile | undefined>(undefined)
export const showEngineListModelAtom = atom(localEngines)
diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts
index 9768ac4c4..2d53678c3 100644
--- a/web/hooks/useActiveModel.ts
+++ b/web/hooks/useActiveModel.ts
@@ -1,6 +1,6 @@
import { useCallback, useEffect, useRef } from 'react'
-import { EngineManager, Model } from '@janhq/core'
+import { EngineManager, Model, ModelFile } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { toaster } from '@/containers/Toast'
@@ -11,7 +11,7 @@ import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
-export const activeModelAtom = atom<Model | undefined>(undefined)
+export const activeModelAtom = atom<ModelFile | undefined>(undefined)
export const loadModelErrorAtom = atom(undefined)
type ModelState = {
@@ -37,7 +37,7 @@ export function useActiveModel() {
const [pendingModelLoad, setPendingModelLoad] = useAtom(pendingModelLoadAtom)
const isVulkanEnabled = useAtomValue(vulkanEnabledAtom)
- const downloadedModelsRef = useRef<Model[]>([])
+ const downloadedModelsRef = useRef<ModelFile[]>([])
useEffect(() => {
downloadedModelsRef.current = downloadedModels
diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts
index 80acfa3cc..5548259fd 100644
--- a/web/hooks/useCreateNewThread.ts
+++ b/web/hooks/useCreateNewThread.ts
@@ -7,8 +7,8 @@ import {
Thread,
ThreadAssistantInfo,
ThreadState,
- Model,
AssistantTool,
+ ModelFile,
} from '@janhq/core'
import { atom, useAtomValue, useSetAtom } from 'jotai'
@@ -26,10 +26,7 @@ import useSetActiveThread from './useSetActiveThread'
import { extensionManager } from '@/extension'
-import {
- experimentalFeatureEnabledAtom,
- preserveModelSettingsAtom,
-} from '@/helpers/atoms/AppConfig.atom'
+import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
threadsAtom,
@@ -67,7 +64,6 @@ export const useCreateNewThread = () => {
const copyOverInstructionEnabled = useAtomValue(
copyOverInstructionEnabledAtom
)
- const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
const activeThread = useAtomValue(activeThreadAtom)
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
@@ -80,7 +76,7 @@ export const useCreateNewThread = () => {
const requestCreateNewThread = async (
assistant: Assistant,
- model?: Model | undefined
+ model?: ModelFile | undefined
) => {
// Stop generating if any
setIsGeneratingResponse(false)
@@ -109,19 +105,13 @@ export const useCreateNewThread = () => {
enabled: true,
settings: assistant.tools && assistant.tools[0].settings,
}
- const defaultContextLength = preserveModelSettings
- ? defaultModel?.metadata?.default_ctx_len
- : 2048
- const defaultMaxTokens = preserveModelSettings
- ? defaultModel?.metadata?.default_max_tokens
- : 2048
const overriddenSettings =
defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048
- ? { ctx_len: defaultContextLength ?? 2048 }
+ ? { ctx_len: 4096 }
: {}
const overriddenParameters = defaultModel?.parameters.max_tokens
- ? { max_tokens: defaultMaxTokens ?? 2048 }
+ ? { max_tokens: 4096 }
: {}
const createdAt = Date.now()
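
Note: with the preserve-settings feature removed, the defaults collapse into fixed overrides: a configured `ctx_len` above the call site's threshold (2048 here, 4096 in ModelDropdown) is forced to 4096, and any configured `max_tokens` is likewise forced to 4096. A standalone sketch of the rule; the helper names are made up:

// Hypothetical helpers mirroring the overrides introduced in this patch.
const overrideCtxLen = (ctxLen: number | undefined, threshold: number) =>
  ctxLen && ctxLen > threshold ? { ctx_len: 4096 } : {}

const overrideMaxTokens = (maxTokens: number | undefined) =>
  maxTokens ? { max_tokens: 4096 } : {}

overrideCtxLen(8192, 2048) // { ctx_len: 4096 }
overrideCtxLen(1024, 2048) // {} - left untouched
overrideMaxTokens(2048)    // { max_tokens: 4096 }
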
diff --git a/web/hooks/useDeleteModel.ts b/web/hooks/useDeleteModel.ts
index 9736f8256..5a7a319b2 100644
--- a/web/hooks/useDeleteModel.ts
+++ b/web/hooks/useDeleteModel.ts
@@ -1,6 +1,6 @@
import { useCallback } from 'react'
-import { ExtensionTypeEnum, ModelExtension, Model } from '@janhq/core'
+import { ExtensionTypeEnum, ModelExtension, ModelFile } from '@janhq/core'
import { useSetAtom } from 'jotai'
@@ -13,8 +13,8 @@ export default function useDeleteModel() {
const removeDownloadedModel = useSetAtom(removeDownloadedModelAtom)
const deleteModel = useCallback(
- async (model: Model) => {
- await localDeleteModel(model.id)
+ async (model: ModelFile) => {
+ await localDeleteModel(model)
removeDownloadedModel(model.id)
toaster({
title: 'Model Deletion Successful',
@@ -28,5 +28,7 @@ export default function useDeleteModel() {
return { deleteModel }
}
-const localDeleteModel = async (id: string) =>
- extensionManager.get<ModelExtension>(ExtensionTypeEnum.Model)?.deleteModel(id)
+const localDeleteModel = async (model: ModelFile) =>
+ extensionManager
+ .get<ModelExtension>(ExtensionTypeEnum.Model)
+ ?.deleteModel(model)
diff --git a/web/hooks/useModels.ts b/web/hooks/useModels.ts
index 5a6f13e03..8333c35c3 100644
--- a/web/hooks/useModels.ts
+++ b/web/hooks/useModels.ts
@@ -5,6 +5,7 @@ import {
Model,
ModelEvent,
ModelExtension,
+ ModelFile,
events,
} from '@janhq/core'
@@ -63,12 +64,12 @@ const getLocalDefaultModel = async (): Promise =>
.get(ExtensionTypeEnum.Model)
?.getDefaultModel()
-const getLocalConfiguredModels = async (): Promise<Model[]> =>
+const getLocalConfiguredModels = async (): Promise<ModelFile[]> =>
extensionManager
.get(ExtensionTypeEnum.Model)
?.getConfiguredModels() ?? []
-const getLocalDownloadedModels = async (): Promise<Model[]> =>
+const getLocalDownloadedModels = async (): Promise<ModelFile[]> =>
extensionManager
.get(ExtensionTypeEnum.Model)
?.getDownloadedModels() ?? []
diff --git a/web/hooks/useRecommendedModel.ts b/web/hooks/useRecommendedModel.ts
index 21a9c69e7..ed56efa55 100644
--- a/web/hooks/useRecommendedModel.ts
+++ b/web/hooks/useRecommendedModel.ts
@@ -1,6 +1,6 @@
import { useCallback, useEffect, useState } from 'react'
-import { Model, InferenceEngine } from '@janhq/core'
+import { Model, InferenceEngine, ModelFile } from '@janhq/core'
import { atom, useAtomValue } from 'jotai'
@@ -24,12 +24,16 @@ export const LAST_USED_MODEL_ID = 'last-used-model-id'
*/
export default function useRecommendedModel() {
const activeModel = useAtomValue(activeModelAtom)
- const [sortedModels, setSortedModels] = useState<Model[]>([])
- const [recommendedModel, setRecommendedModel] = useState<Model | undefined>()
+ const [sortedModels, setSortedModels] = useState([])
+ const [recommendedModel, setRecommendedModel] = useState<
+ ModelFile | undefined
+ >()
const activeThread = useAtomValue(activeThreadAtom)
const downloadedModels = useAtomValue(downloadedModelsAtom)
- const getAndSortDownloadedModels = useCallback(async (): Promise<Model[]> => {
+ const getAndSortDownloadedModels = useCallback(async (): Promise<
+ ModelFile[]
+ > => {
const models = downloadedModels.sort((a, b) =>
a.engine !== InferenceEngine.nitro && b.engine === InferenceEngine.nitro
? 1
diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts
index 46bf07cd5..af30210ad 100644
--- a/web/hooks/useUpdateModelParameters.ts
+++ b/web/hooks/useUpdateModelParameters.ts
@@ -4,8 +4,6 @@ import {
ConversationalExtension,
ExtensionTypeEnum,
InferenceEngine,
- Model,
- ModelExtension,
Thread,
ThreadAssistantInfo,
} from '@janhq/core'
@@ -17,14 +15,8 @@ import {
extractModelLoadParams,
} from '@/utils/modelParam'
-import useRecommendedModel from './useRecommendedModel'
-
import { extensionManager } from '@/extension'
-import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
-import {
- selectedModelAtom,
- updateDownloadedModelAtom,
-} from '@/helpers/atoms/Model.atom'
+import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
ModelParams,
getActiveThreadModelParamsAtom,
@@ -34,16 +26,14 @@ import {
export type UpdateModelParameter = {
params?: ModelParams
modelId?: string
+ modelPath?: string
engine?: InferenceEngine
}
export default function useUpdateModelParameters() {
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
- const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
+ const [selectedModel] = useAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
- const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom)
- const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom)
- const { recommendedModel, setRecommendedModel } = useRecommendedModel()
const updateModelParameter = useCallback(
async (thread: Thread, settings: UpdateModelParameter) => {
@@ -83,50 +73,8 @@ export default function useUpdateModelParameters() {
await extensionManager
.get(ExtensionTypeEnum.Conversational)
?.saveThread(updatedThread)
-
- // Persists default settings to model file
- // Do not overwrite ctx_len and max_tokens
- if (preserveModelFeatureEnabled) {
- const defaultContextLength = settingParams.ctx_len
- const defaultMaxTokens = runtimeParams.max_tokens
-
- // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars
- const { ctx_len, ...toSaveSettings } = settingParams
- // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-unused-vars
- const { max_tokens, ...toSaveParams } = runtimeParams
-
- const updatedModel = {
- id: settings.modelId ?? selectedModel?.id,
- parameters: {
- ...toSaveSettings,
- },
- settings: {
- ...toSaveParams,
- },
- metadata: {
- default_ctx_len: defaultContextLength,
- default_max_tokens: defaultMaxTokens,
- },
- } as Partial
-
- const model = await extensionManager
- .get(ExtensionTypeEnum.Model)
- ?.updateModelInfo(updatedModel)
- if (model) updateDownloadedModel(model)
- if (selectedModel?.id === model?.id) setSelectedModel(model)
- if (recommendedModel?.id === model?.id) setRecommendedModel(model)
- }
},
- [
- activeModelParams,
- selectedModel,
- setThreadModelParams,
- preserveModelFeatureEnabled,
- updateDownloadedModel,
- setSelectedModel,
- recommendedModel,
- setRecommendedModel,
- ]
+ [activeModelParams, selectedModel, setThreadModelParams]
)
const processStopWords = (params: ModelParams): ModelParams => {
diff --git a/web/screens/Hub/ModelList/ModelHeader/index.tsx b/web/screens/Hub/ModelList/ModelHeader/index.tsx
index b20977aff..44a3fd278 100644
--- a/web/screens/Hub/ModelList/ModelHeader/index.tsx
+++ b/web/screens/Hub/ModelList/ModelHeader/index.tsx
@@ -1,6 +1,6 @@
import { useCallback } from 'react'
-import { Model } from '@janhq/core'
+import { ModelFile } from '@janhq/core'
import { Button, Badge, Tooltip } from '@janhq/joi'
import { useAtomValue, useSetAtom } from 'jotai'
@@ -38,7 +38,7 @@ import {
} from '@/helpers/atoms/SystemBar.atom'
type Props = {
- model: Model
+ model: ModelFile
onClick: () => void
open: string
}
diff --git a/web/screens/Hub/ModelList/ModelItem/index.tsx b/web/screens/Hub/ModelList/ModelItem/index.tsx
index c9b2f1329..ec9d885a1 100644
--- a/web/screens/Hub/ModelList/ModelItem/index.tsx
+++ b/web/screens/Hub/ModelList/ModelItem/index.tsx
@@ -1,6 +1,6 @@
import { useState } from 'react'
-import { Model } from '@janhq/core'
+import { ModelFile } from '@janhq/core'
import { Badge } from '@janhq/joi'
import { twMerge } from 'tailwind-merge'
@@ -12,7 +12,7 @@ import ModelItemHeader from '@/screens/Hub/ModelList/ModelHeader'
import { toGibibytes } from '@/utils/converter'
type Props = {
- model: Model
+ model: ModelFile
}
const ModelItem: React.FC = ({ model }) => {
diff --git a/web/screens/Hub/ModelList/index.tsx b/web/screens/Hub/ModelList/index.tsx
index aea67b4e3..8fc30d541 100644
--- a/web/screens/Hub/ModelList/index.tsx
+++ b/web/screens/Hub/ModelList/index.tsx
@@ -1,6 +1,6 @@
import { useMemo } from 'react'
-import { Model } from '@janhq/core'
+import { ModelFile } from '@janhq/core'
import { useAtomValue } from 'jotai'
@@ -9,16 +9,16 @@ import ModelItem from '@/screens/Hub/ModelList/ModelItem'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
type Props = {
- models: Model[]
+ models: ModelFile[]
}
const ModelList = ({ models }: Props) => {
const downloadedModels = useAtomValue(downloadedModelsAtom)
- const sortedModels: Model[] = useMemo(() => {
- const featuredModels: Model[] = []
- const remoteModels: Model[] = []
- const localModels: Model[] = []
- const remainingModels: Model[] = []
+ const sortedModels: ModelFile[] = useMemo(() => {
+ const featuredModels: ModelFile[] = []
+ const remoteModels: ModelFile[] = []
+ const localModels: ModelFile[] = []
+ const remainingModels: ModelFile[] = []
models.forEach((m) => {
if (m.metadata?.tags?.includes('Featured')) {
featuredModels.push(m)
diff --git a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx
index 951a11d59..c3f09f171 100644
--- a/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx
+++ b/web/screens/Settings/HuggingFaceRepoDetailModal/ModelDownloadRow/index.tsx
@@ -53,7 +53,7 @@ const ModelDownloadRow: React.FC = ({
const { requestCreateNewThread } = useCreateNewThread()
const setMainViewState = useSetAtom(mainViewStateAtom)
const assistants = useAtomValue(assistantsAtom)
- const isDownloaded = downloadedModels.find((md) => md.id === fileName) != null
+ const downloadedModel = downloadedModels.find((md) => md.id === fileName)
const setHfImportingStage = useSetAtom(importHuggingFaceModelStageAtom)
const defaultModel = useAtomValue(defaultModelAtom)
@@ -100,12 +100,12 @@ const ModelDownloadRow: React.FC = ({
alert('No assistant available')
return
}
- await requestCreateNewThread(assistants[0], model)
+ await requestCreateNewThread(assistants[0], downloadedModel)
setMainViewState(MainViewState.Thread)
setHfImportingStage('NONE')
}, [
assistants,
- model,
+ downloadedModel,
requestCreateNewThread,
setMainViewState,
setHfImportingStage,
@@ -139,7 +139,7 @@ const ModelDownloadRow: React.FC = ({
- {isDownloaded ? (
+ {downloadedModel ? (
-
+
Clear
From c62b6e984282003d14160ce1b222c66fa4b79038 Mon Sep 17 00:00:00 2001
From: Faisal Amir
Date: Tue, 17 Sep 2024 22:13:18 +0700
Subject: [PATCH 7/7] fix: small leftover issues with new starter screen
(#3661)
* fix: fix duplicate render progress component
* fix: minor ui issue
* chore: add manual recommend model
* chore: make button create thread invisible
* chore: fix conflict
* chore: remove selector create thread icon
* test: added unit test thread screen
---
electron/tests/e2e/thread.e2e.spec.ts | 29 ++++---
web/containers/Layout/RibbonPanel/index.tsx | 16 ++--
web/containers/Layout/TopPanel/index.tsx | 5 +-
web/helpers/atoms/Thread.atom.ts | 3 +
web/hooks/useStarterScreen.ts | 7 +-
.../ChatBody/OnDeviceStarterScreen/index.tsx | 78 +++++++++++--------
web/screens/Thread/index.test.tsx | 35 +++++++++
7 files changed, 109 insertions(+), 64 deletions(-)
create mode 100644 web/screens/Thread/index.test.tsx
diff --git a/electron/tests/e2e/thread.e2e.spec.ts b/electron/tests/e2e/thread.e2e.spec.ts
index c13e91119..5d7328053 100644
--- a/electron/tests/e2e/thread.e2e.spec.ts
+++ b/electron/tests/e2e/thread.e2e.spec.ts
@@ -1,32 +1,29 @@
import { expect } from '@playwright/test'
import { page, test, TIMEOUT } from '../config/fixtures'
-test('Select GPT model from Hub and Chat with Invalid API Key', async ({ hubPage }) => {
+test('Select GPT model from Hub and Chat with Invalid API Key', async ({
+ hubPage,
+}) => {
await hubPage.navigateByMenu()
await hubPage.verifyContainerVisible()
// Select the first GPT model
await page
.locator('[data-testid^="use-model-btn"][data-testid*="gpt"]')
- .first().click()
-
- // Attempt to create thread and chat in Thread page
- await page
- .getByTestId('btn-create-thread')
+ .first()
.click()
- await page
- .getByTestId('txt-input-chat')
- .fill('dummy value')
+ await page.getByTestId('txt-input-chat').fill('dummy value')
- await page
- .getByTestId('btn-send-chat')
- .click()
+ await page.getByTestId('btn-send-chat').click()
- await page.waitForFunction(() => {
- const loaders = document.querySelectorAll('[data-testid$="loader"]');
- return !loaders.length;
- }, { timeout: TIMEOUT });
+ await page.waitForFunction(
+ () => {
+ const loaders = document.querySelectorAll('[data-testid$="loader"]')
+ return !loaders.length
+ },
+ { timeout: TIMEOUT }
+ )
const APIKeyError = page.getByTestId('invalid-API-key-error')
await expect(APIKeyError).toBeVisible({
diff --git a/web/containers/Layout/RibbonPanel/index.tsx b/web/containers/Layout/RibbonPanel/index.tsx
index 6bed2b424..7613584e0 100644
--- a/web/containers/Layout/RibbonPanel/index.tsx
+++ b/web/containers/Layout/RibbonPanel/index.tsx
@@ -12,17 +12,18 @@ import { twMerge } from 'tailwind-merge'
import { MainViewState } from '@/constants/screens'
-import { localEngines } from '@/utils/modelEngine'
-
import { mainViewStateAtom, showLeftPanelAtom } from '@/helpers/atoms/App.atom'
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
-import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
+
import {
reduceTransparentAtom,
selectedSettingAtom,
} from '@/helpers/atoms/Setting.atom'
-import { threadsAtom } from '@/helpers/atoms/Thread.atom'
+import {
+ isDownloadALocalModelAtom,
+ threadsAtom,
+} from '@/helpers/atoms/Thread.atom'
export default function RibbonPanel() {
const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom)
@@ -32,8 +33,9 @@ export default function RibbonPanel() {
const matches = useMediaQuery('(max-width: 880px)')
const reduceTransparent = useAtomValue(reduceTransparentAtom)
const setSelectedSetting = useSetAtom(selectedSettingAtom)
- const downloadedModels = useAtomValue(downloadedModelsAtom)
+
const threads = useAtomValue(threadsAtom)
+ const isDownloadALocalModel = useAtomValue(isDownloadALocalModelAtom)
const onMenuClick = (state: MainViewState) => {
if (mainViewState === state) return
@@ -43,10 +45,6 @@ export default function RibbonPanel() {
setEditMessage('')
}
- const isDownloadALocalModel = downloadedModels.some((x) =>
- localEngines.includes(x.engine)
- )
-
const RibbonNavMenus = [
{
name: 'Thread',
diff --git a/web/containers/Layout/TopPanel/index.tsx b/web/containers/Layout/TopPanel/index.tsx
index 213f7dfa9..aff616973 100644
--- a/web/containers/Layout/TopPanel/index.tsx
+++ b/web/containers/Layout/TopPanel/index.tsx
@@ -23,6 +23,7 @@ import { toaster } from '@/containers/Toast'
import { MainViewState } from '@/constants/screens'
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
+import { useStarterScreen } from '@/hooks/useStarterScreen'
import {
mainViewStateAtom,
@@ -58,6 +59,8 @@ const TopPanel = () => {
requestCreateNewThread(assistants[0])
}
+ const { isShowStarterScreen } = useStarterScreen()
+
return (
{
)}
)}
- {mainViewState === MainViewState.Thread && (
+ {mainViewState === MainViewState.Thread && !isShowStarterScreen && (
diff --git a/web/helpers/atoms/Thread.atom.ts b/web/helpers/atoms/Thread.atom.ts
+export const isDownloadALocalModelAtom = atom<boolean>(false)
+export const isAnyRemoteModelConfiguredAtom = atom<boolean>(false)
diff --git a/web/hooks/useStarterScreen.ts b/web/hooks/useStarterScreen.ts
index fbd6ef578..1a6bbfbc7 100644
--- a/web/hooks/useStarterScreen.ts
+++ b/web/hooks/useStarterScreen.ts
@@ -63,14 +63,9 @@ export function useStarterScreen() {
(x) => x.apiKey.length > 1
)
- let isShowStarterScreen
-
- isShowStarterScreen =
+ const isShowStarterScreen =
!isAnyRemoteModelConfigured && !isDownloadALocalModel && !threads.length
- // Remove this part when we rework on starter screen
- isShowStarterScreen = false
-
return {
extensionHasSettings,
isShowStarterScreen,
diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
index 3ae32ca8c..26036a627 100644
--- a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
@@ -58,9 +58,21 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
const configuredModels = useAtomValue(configuredModelsAtom)
const setMainViewState = useSetAtom(mainViewStateAtom)
- const featuredModel = configuredModels.filter(
- (x) => x.metadata.tags.includes('Featured') && x.metadata.size < 5000000000
- )
+ const recommendModel = ['gemma-2-2b-it', 'llama3.1-8b-instruct']
+
+ const featuredModel = configuredModels.filter((x) => {
+ const manualRecommendModel = configuredModels.filter((x) =>
+ recommendModel.includes(x.id)
+ )
+
+ if (manualRecommendModel.length === 2) {
+ return x.id === recommendModel[0] || x.id === recommendModel[1]
+ } else {
+ return (
+ x.metadata.tags.includes('Featured') && x.metadata.size < 5000000000
+ )
+ }
+ })
const remoteModel = configuredModels.filter(
(x) => !localEngines.includes(x.engine)
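
Note: the featured-model filter above recomputes `manualRecommendModel` on every iteration of the outer `filter`; an equivalent, cheaper shape hoists that lookup out. A sketch under the same assumptions (same recommended ids, same 5 GB size cap; `configuredModels` comes from the surrounding component):

import { ModelFile } from '@janhq/core'

declare const configuredModels: ModelFile[] // from the surrounding component

const recommendModel = ['gemma-2-2b-it', 'llama3.1-8b-instruct']

// Hoisted once instead of being recomputed per filtered element.
const manualRecommendModel = configuredModels.filter((x) =>
  recommendModel.includes(x.id)
)

const featuredModel =
  manualRecommendModel.length === 2
    ? manualRecommendModel
    : configuredModels.filter(
        (x) =>
          x.metadata.tags.includes('Featured') && x.metadata.size < 5000000000
      )
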
@@ -105,7 +117,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
width={48}
height={48}
/>
- Select a model to start
+ Select a model to start
@@ -120,7 +132,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
/>
@@ -205,39 +217,41 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
return (
-
{featModel.name}
-
+
{featModel.name}
+
{featModel.metadata.author}
{isDownloading ? (
- {Object.values(downloadStates).map((item, i) => (
-
-
-
-
-
- {formatDownloadPercentage(item?.percent)}
-
+ {Object.values(downloadStates)
+ .filter((x) => x.modelId === featModel.id)
+ .map((item, i) => (
+
+
+
+
+
+ {formatDownloadPercentage(item?.percent)}
+
+
-
- ))}
+ ))}
) : (
@@ -248,7 +262,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
>
Download
-
+
{toGibibytes(featModel.metadata.size)}
@@ -257,7 +271,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
)
})}
-
+
Cloud Models
@@ -268,7 +282,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
return (
{row.map((remoteEngine) => {
const engineLogo = getLogoEngine(
@@ -298,7 +312,7 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
/>
)}
-
+
{getTitleByEngine(
remoteEngine as InferenceEngine
)}
diff --git a/web/screens/Thread/index.test.tsx b/web/screens/Thread/index.test.tsx
new file mode 100644
index 000000000..01af0ffc5
--- /dev/null
+++ b/web/screens/Thread/index.test.tsx
@@ -0,0 +1,35 @@
+import React from 'react'
+import { render, screen } from '@testing-library/react'
+import ThreadScreen from './index'
+import { useStarterScreen } from '../../hooks/useStarterScreen'
+import '@testing-library/jest-dom'
+
+global.ResizeObserver = class {
+ observe() {}
+ unobserve() {}
+ disconnect() {}
+}
+// Mock the useStarterScreen hook
+jest.mock('@/hooks/useStarterScreen')
+
+describe('ThreadScreen', () => {
+ it('renders OnDeviceStarterScreen when isShowStarterScreen is true', () => {
+ ;(useStarterScreen as jest.Mock).mockReturnValue({
+ isShowStarterScreen: true,
+ extensionHasSettings: false,
+ })
+
+ const { getByText } = render(<ThreadScreen />)
+ expect(getByText('Select a model to start')).toBeInTheDocument()
+ })
+
+ it('renders Thread panels when isShowStarterScreen is false', () => {
+ ;(useStarterScreen as jest.Mock).mockReturnValue({
+ isShowStarterScreen: false,
+ extensionHasSettings: false,
+ })
+
+ const { getByText } = render(<ThreadScreen />)
+ expect(getByText('Welcome!')).toBeInTheDocument()
+ })
+})