Merge pull request #4046 from janhq/fix/model-pull-options-http-proxy

feat: model pull options - HTTP Proxy requests route to cortex.cpp
This commit is contained in:
Louis 2024-11-19 21:50:54 +07:00 committed by GitHub
commit c501641c91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 170 additions and 19 deletions

View File

@ -12,6 +12,7 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
return ExtensionTypeEnum.Model
}
/**
 * Configure the options used when pulling models.
 * @param configs free-form key/value options; the accepted keys are defined by the concrete extension
 * @returns a promise whose resolved value is implementation-defined
 */
abstract configurePullOptions(configs: { [key: string]: any }): Promise<any>
abstract getModels(): Promise<Model[]>
abstract pullModel(model: string, id?: string, name?: string): Promise<void>
abstract cancelModelPull(modelId: string): Promise<void>

View File

@ -1 +1 @@
1.0.3-rc1
1.0.3-rc4

View File

@ -69,11 +69,11 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
super.onLoad()
await this.queue.add(() => this.clean())
this.queue.add(() => this.healthz())
this.queue.add(() => this.setDefaultEngine(systemInfo))
// Run the process watchdog
const systemInfo = await systemInformation()
await this.clean()
await executeOnMain(NODE, 'run', systemInfo)
this.subscribeToEvents()
@ -160,7 +160,8 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
return ky
.get(`${CORTEX_API_URL}/healthz`, {
retry: {
limit: 10,
limit: 20,
delay: () => 500,
methods: ['get'],
},
})
@ -192,6 +193,9 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
return ky
.delete(`${CORTEX_API_URL}/processmanager/destroy`, {
timeout: 2000, // maximum 2 seconds
retry: {
limit: 0,
},
})
.catch(() => {
// Do nothing

View File

@ -1,6 +1,6 @@
import PQueue from 'p-queue'
import ky from 'ky'
import { extractModelLoadParams, Model } from '@janhq/core'
import { extractModelLoadParams, Model } from '@janhq/core'
import { extractInferenceParams } from '@janhq/core'
/**
* cortex.cpp Model APIs interface
@ -18,6 +18,7 @@ interface ICortexAPI {
deleteModel(model: string): Promise<void>
updateModel(model: object): Promise<void>
cancelModelPull(model: string): Promise<void>
configs(body: { [key: string]: any }): Promise<void>
}
type ModelList = {
@ -52,7 +53,7 @@ export class CortexAPI implements ICortexAPI {
*/
getModels(): Promise<Model[]> {
return this.queue
.add(() => ky.get(`${API_URL}/models`).json<ModelList>())
.add(() => ky.get(`${API_URL}/v1/models`).json<ModelList>())
.then((e) =>
typeof e === 'object' ? e.data.map((e) => this.transformModel(e)) : []
)
@ -104,7 +105,7 @@ export class CortexAPI implements ICortexAPI {
*/
deleteModel(model: string): Promise<void> {
return this.queue.add(() =>
ky.delete(`${API_URL}/models/${model}`).json().then()
ky.delete(`${API_URL}/v1/models/${model}`).json().then()
)
}
@ -130,7 +131,7 @@ export class CortexAPI implements ICortexAPI {
cancelModelPull(model: string): Promise<void> {
return this.queue.add(() =>
ky
.delete(`${API_URL}/models/pull`, { json: { taskId: model } })
.delete(`${API_URL}/v1/models/pull`, { json: { taskId: model } })
.json()
.then()
)
@ -142,7 +143,7 @@ export class CortexAPI implements ICortexAPI {
*/
async getModelStatus(model: string): Promise<boolean> {
return this.queue
.add(() => ky.get(`${API_URL}/models/status/${model}`))
.add(() => ky.get(`${API_URL}/v1/models/status/${model}`))
.then((e) => true)
.catch(() => false)
}
@ -155,13 +156,24 @@ export class CortexAPI implements ICortexAPI {
return ky
.get(`${API_URL}/healthz`, {
retry: {
limit: 10,
limit: 20,
delay: () => 500,
methods: ['get'],
},
})
.then(() => {})
}
/**
 * Configure model pull options.
 * Sends `body` as the JSON payload of a PATCH request to `/v1/configs`;
 * the request is scheduled through this instance's queue.
 * @param body key/value options to send
 * @returns a promise that resolves (with no value) once the request has completed
 */
configs(body: { [key: string]: any }): Promise<void> {
  const patchConfigs = async (): Promise<void> => {
    await ky.patch(`${API_URL}/v1/configs`, { json: body })
  }
  return this.queue.add(patchConfigs)
}
/**
* Transform model to the expected format (e.g. parameters, settings, metadata)
* @param model

View File

@ -227,6 +227,13 @@ export default class JanModelExtension extends ModelExtension {
return this.cortexAPI.getModelStatus(model)
}
/**
 * Configure pull options such as proxy, headers, etc.
 * Delegates to `cortexAPI.configs`. A rejected request is logged with
 * `console.debug` and does not reject the returned promise.
 * @param options key/value options passed through unchanged
 */
async configurePullOptions(options: { [key: string]: any }): Promise<any> {
  // Start the request outside the try block so only its rejection is swallowed
  const pendingRequest = this.cortexAPI.configs(options)
  try {
    return await pendingRequest
  } catch (e) {
    return console.debug(e)
  }
}
/**
* Handle download state from main app
*/

View File

@ -267,7 +267,7 @@ const updateGpuInfo = async () =>
}
data = await updateCudaExistence(data)
console.log(data)
console.log('[MONITORING]::Cuda info: ', data)
writeFileSync(GPU_INFO_FILE, JSON.stringify(data, null, 2))
log(`[APP]::${JSON.stringify(data)}`)
resolve({})
@ -344,7 +344,7 @@ const updateCudaExistence = async (
data.cuda.version = match[1]
}
}
console.log(data)
console.log('[MONITORING]::Finalized cuda info update: ', data)
resolve()
})
})

View File

@ -29,7 +29,7 @@ const DataLoader: React.FC<Props> = ({ children }) => {
const setQuickAskEnabled = useSetAtom(quickAskEnabledAtom)
const setJanDefaultDataFolder = useSetAtom(defaultJanDataFolderAtom)
const setJanSettingScreen = useSetAtom(janSettingScreenAtom)
const { loadDataModel } = useModels()
const { loadDataModel, configurePullOptions } = useModels()
useThreads()
useAssistants()
@ -39,6 +39,8 @@ const DataLoader: React.FC<Props> = ({ children }) => {
// Startup effect: the empty dependency array means it runs only after the first render.
useEffect(() => {
// Load data once
loadDataModel()
// Configure pull options once
configurePullOptions()
// The dependency list is intentionally empty; the lint rule is silenced for that reason.
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])

View File

@ -1,5 +1,5 @@
// useModels.test.ts
import { renderHook, act } from '@testing-library/react'
import { renderHook, act, waitFor } from '@testing-library/react'
import { events, ModelEvent, ModelManager } from '@janhq/core'
import { extensionManager } from '@/extension'
@ -36,7 +36,6 @@ describe('useModels', () => {
}),
get: () => undefined,
has: () => true,
// set: () => {}
},
})
@ -50,6 +49,85 @@ describe('useModels', () => {
expect(mockModelExtension.getModels).toHaveBeenCalled()
})
// Loading models with an extension whose `getModels` rejects.
it('should return empty on error', async () => {
  const mockModelExtension = {
    getModels: jest.fn().mockRejectedValue(new Error('Error')),
  } as any
  ;(ModelManager.instance as jest.Mock).mockReturnValue({
    models: {
      values: () => ({
        toArray: () => ({
          filter: () => models,
        }),
      }),
      get: () => undefined,
      has: () => true,
    },
  })

  jest.spyOn(extensionManager, 'get').mockReturnValue(mockModelExtension)

  const { result } = renderHook(() => useModels())

  await act(() => {
    result.current?.loadDataModel()
  })

  // `.rejects` yields a promise: it must be awaited, otherwise the test can
  // finish before the assertion has been evaluated.
  await expect(mockModelExtension.getModels()).rejects.toThrow()
})
// The first `events.on` registration is invoked immediately with `fetch: false`;
// in that case the hook must not ask the extension for models.
// The title is specific so it does not collide with the `fetch: true` case.
it('should not fetch models when the update event has fetch set to false', async () => {
  const mockModelExtension = {
    getModels: jest.fn().mockResolvedValue(models),
  } as any

  ;(ModelManager.instance as jest.Mock).mockReturnValue({
    models: {
      values: () => ({
        toArray: () => ({
          filter: () => models,
        }),
      }),
      get: () => undefined,
      has: () => true,
    },
  })

  jest.spyOn(extensionManager, 'get').mockReturnValue(mockModelExtension)
  jest.spyOn(events, 'on').mockImplementationOnce((event, cb) => {
    cb({ fetch: false })
  })

  renderHook(() => useModels())

  expect(mockModelExtension.getModels).not.toHaveBeenCalled()
})
// The first `events.on` registration is invoked immediately with `fetch: true`;
// in that case the hook must ask the extension for models.
// The title is specific so it does not collide with the `fetch: false` case.
it('should fetch models when the update event has fetch set to true', async () => {
  const mockModelExtension = {
    getModels: jest.fn().mockResolvedValue(models),
  } as any

  ;(ModelManager.instance as jest.Mock).mockReturnValue({
    models: {
      values: () => ({
        toArray: () => ({
          filter: () => models,
        }),
      }),
      get: () => undefined,
      has: () => true,
    },
  })

  jest.spyOn(extensionManager, 'get').mockReturnValue(mockModelExtension)
  jest.spyOn(events, 'on').mockImplementationOnce((event, cb) => {
    cb({ fetch: true })
  })

  renderHook(() => useModels())

  expect(mockModelExtension.getModels).toHaveBeenCalled()
})
it('should remove event listener on unmount', async () => {
const removeListenerSpy = jest.spyOn(events, 'off')

View File

@ -9,13 +9,18 @@ import {
ModelManager,
} from '@janhq/core'
import { useSetAtom, useAtom } from 'jotai'
import { useSetAtom, useAtom, useAtomValue } from 'jotai'
import { useDebouncedCallback } from 'use-debounce'
import { isLocalEngine } from '@/utils/modelEngine'
import { extensionManager } from '@/extension'
import {
ignoreSslAtom,
proxyAtom,
proxyEnabledAtom,
} from '@/helpers/atoms/AppConfig.atom'
import {
configuredModelsAtom,
downloadedModelsAtom,
@ -29,6 +34,9 @@ import {
const useModels = () => {
const [downloadedModels, setDownloadedModels] = useAtom(downloadedModelsAtom)
const setExtensionModels = useSetAtom(configuredModelsAtom)
const proxyEnabled = useAtomValue(proxyEnabledAtom)
const proxyUrl = useAtomValue(proxyAtom)
const proxyIgnoreSSL = useAtomValue(ignoreSslAtom)
const getData = useCallback(() => {
const getDownloadedModels = async () => {
@ -107,8 +115,25 @@ const useModels = () => {
}
}, [reloadData, updateStates])
/**
 * Passes the current proxy settings to the model extension as pull options.
 * With the proxy enabled, the configured URL is sent and `verify_peer_ssl`
 * is the inverse of the "ignore SSL" setting; with it disabled, an empty URL
 * and `verify_peer_ssl: false` are sent. Does nothing when no model
 * extension is registered.
 */
const configurePullOptions = useCallback(() => {
  const modelExtension = extensionManager.get<ModelExtension>(
    ExtensionTypeEnum.Model
  )
  if (!modelExtension) return

  const pullOptions = proxyEnabled
    ? { proxy_url: proxyUrl, verify_peer_ssl: !proxyIgnoreSSL }
    : { proxy_url: '', verify_peer_ssl: false }

  modelExtension.configurePullOptions(pullOptions)
}, [proxyEnabled, proxyUrl, proxyIgnoreSSL])
return {
loadDataModel: getData,
configurePullOptions,
}
}

View File

@ -20,9 +20,12 @@ import { AlertTriangleIcon, AlertCircleIcon } from 'lucide-react'
import { twMerge } from 'tailwind-merge'
import { useDebouncedCallback } from 'use-debounce'
import { snackbar, toaster } from '@/containers/Toast'
import { useActiveModel } from '@/hooks/useActiveModel'
import useModels from '@/hooks/useModels'
import { useSettings } from '@/hooks/useSettings'
import DataFolder from './DataFolder'
@ -65,6 +68,7 @@ const Advanced = () => {
const [dropdownOptions, setDropdownOptions] = useState<HTMLDivElement | null>(
null
)
const { configurePullOptions } = useModels()
const [toggle, setToggle] = useState<HTMLDivElement | null>(null)
@ -78,6 +82,15 @@ const Advanced = () => {
return y['name']
})
/**
 * Debounced (300 ms) trigger for `configurePullOptions`.
 * Two reasons for debouncing:
 * - a state update may not be synced yet, so state values read from other
 *   hooks right after a setter call would not be accurate;
 * - state updates persist every time the user types in the input.
 */
const updatePullOptions = useDebouncedCallback(() => {
  return configurePullOptions()
}, 300)
/**
* Handle proxy change
*/
@ -90,8 +103,9 @@ const Advanced = () => {
} else {
setProxy('')
}
updatePullOptions()
},
[setPartialProxy, setProxy]
[setPartialProxy, setProxy, updatePullOptions]
)
/**
@ -209,7 +223,9 @@ const Advanced = () => {
}
setGpusInUse(updatedGpusInUse)
await saveSettings({ gpusInUse: updatedGpusInUse })
window.core?.api?.relaunch()
// Reload window to apply changes
// This will trigger engine servers to restart
window.location.reload()
}
const gpuSelectionPlaceHolder =
@ -452,7 +468,10 @@ const Advanced = () => {
<Switch
data-testid="proxy-switch"
checked={proxyEnabled}
onChange={() => setProxyEnabled(!proxyEnabled)}
onChange={() => {
setProxyEnabled(!proxyEnabled)
updatePullOptions()
}}
/>
<div className="w-full">
<Input
@ -481,7 +500,10 @@ const Advanced = () => {
<Switch
data-testid="ignore-ssl-switch"
checked={ignoreSSL}
onChange={(e) => setIgnoreSSL(e.target.checked)}
onChange={(e) => {
setIgnoreSSL(e.target.checked)
updatePullOptions()
}}
/>
</div>