Merge pull request #4046 from janhq/fix/model-pull-options-http-proxy

feat: model pull options - HTTP Proxy requests route to cortex.cpp
This commit is contained in:
Louis 2024-11-19 21:50:54 +07:00 committed by GitHub
commit c501641c91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 170 additions and 19 deletions

View File

@ -12,6 +12,7 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
return ExtensionTypeEnum.Model return ExtensionTypeEnum.Model
} }
abstract configurePullOptions(configs: { [key: string]: any }): Promise<any>
abstract getModels(): Promise<Model[]> abstract getModels(): Promise<Model[]>
abstract pullModel(model: string, id?: string, name?: string): Promise<void> abstract pullModel(model: string, id?: string, name?: string): Promise<void>
abstract cancelModelPull(modelId: string): Promise<void> abstract cancelModelPull(modelId: string): Promise<void>

View File

@ -1 +1 @@
1.0.3-rc1 1.0.3-rc4

View File

@ -69,11 +69,11 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
super.onLoad() super.onLoad()
await this.queue.add(() => this.clean())
this.queue.add(() => this.healthz()) this.queue.add(() => this.healthz())
this.queue.add(() => this.setDefaultEngine(systemInfo)) this.queue.add(() => this.setDefaultEngine(systemInfo))
// Run the process watchdog // Run the process watchdog
const systemInfo = await systemInformation() const systemInfo = await systemInformation()
await this.clean()
await executeOnMain(NODE, 'run', systemInfo) await executeOnMain(NODE, 'run', systemInfo)
this.subscribeToEvents() this.subscribeToEvents()
@ -160,7 +160,8 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
return ky return ky
.get(`${CORTEX_API_URL}/healthz`, { .get(`${CORTEX_API_URL}/healthz`, {
retry: { retry: {
limit: 10, limit: 20,
delay: () => 500,
methods: ['get'], methods: ['get'],
}, },
}) })
@ -192,6 +193,9 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
return ky return ky
.delete(`${CORTEX_API_URL}/processmanager/destroy`, { .delete(`${CORTEX_API_URL}/processmanager/destroy`, {
timeout: 2000, // maximum 2 seconds timeout: 2000, // maximum 2 seconds
retry: {
limit: 0,
},
}) })
.catch(() => { .catch(() => {
// Do nothing // Do nothing

View File

@ -1,6 +1,6 @@
import PQueue from 'p-queue' import PQueue from 'p-queue'
import ky from 'ky' import ky from 'ky'
import { extractModelLoadParams, Model } from '@janhq/core' import { extractModelLoadParams, Model } from '@janhq/core'
import { extractInferenceParams } from '@janhq/core' import { extractInferenceParams } from '@janhq/core'
/** /**
* cortex.cpp Model APIs interface * cortex.cpp Model APIs interface
@ -18,6 +18,7 @@ interface ICortexAPI {
deleteModel(model: string): Promise<void> deleteModel(model: string): Promise<void>
updateModel(model: object): Promise<void> updateModel(model: object): Promise<void>
cancelModelPull(model: string): Promise<void> cancelModelPull(model: string): Promise<void>
configs(body: { [key: string]: any }): Promise<void>
} }
type ModelList = { type ModelList = {
@ -52,7 +53,7 @@ export class CortexAPI implements ICortexAPI {
*/ */
getModels(): Promise<Model[]> { getModels(): Promise<Model[]> {
return this.queue return this.queue
.add(() => ky.get(`${API_URL}/models`).json<ModelList>()) .add(() => ky.get(`${API_URL}/v1/models`).json<ModelList>())
.then((e) => .then((e) =>
typeof e === 'object' ? e.data.map((e) => this.transformModel(e)) : [] typeof e === 'object' ? e.data.map((e) => this.transformModel(e)) : []
) )
@ -104,7 +105,7 @@ export class CortexAPI implements ICortexAPI {
*/ */
deleteModel(model: string): Promise<void> { deleteModel(model: string): Promise<void> {
return this.queue.add(() => return this.queue.add(() =>
ky.delete(`${API_URL}/models/${model}`).json().then() ky.delete(`${API_URL}/v1/models/${model}`).json().then()
) )
} }
@ -130,7 +131,7 @@ export class CortexAPI implements ICortexAPI {
cancelModelPull(model: string): Promise<void> { cancelModelPull(model: string): Promise<void> {
return this.queue.add(() => return this.queue.add(() =>
ky ky
.delete(`${API_URL}/models/pull`, { json: { taskId: model } }) .delete(`${API_URL}/v1/models/pull`, { json: { taskId: model } })
.json() .json()
.then() .then()
) )
@ -142,7 +143,7 @@ export class CortexAPI implements ICortexAPI {
*/ */
async getModelStatus(model: string): Promise<boolean> { async getModelStatus(model: string): Promise<boolean> {
return this.queue return this.queue
.add(() => ky.get(`${API_URL}/models/status/${model}`)) .add(() => ky.get(`${API_URL}/v1/models/status/${model}`))
.then((e) => true) .then((e) => true)
.catch(() => false) .catch(() => false)
} }
@ -155,13 +156,24 @@ export class CortexAPI implements ICortexAPI {
return ky return ky
.get(`${API_URL}/healthz`, { .get(`${API_URL}/healthz`, {
retry: { retry: {
limit: 10, limit: 20,
delay: () => 500,
methods: ['get'], methods: ['get'],
}, },
}) })
.then(() => {}) .then(() => {})
} }
/**
* Configure model pull options
* @param body
*/
configs(body: { [key: string]: any }): Promise<void> {
return this.queue.add(() =>
ky.patch(`${API_URL}/v1/configs`, { json: body }).then(() => {})
)
}
/** /**
* TRansform model to the expected format (e.g. parameters, settings, metadata) * TRansform model to the expected format (e.g. parameters, settings, metadata)
* @param model * @param model

View File

@ -227,6 +227,13 @@ export default class JanModelExtension extends ModelExtension {
return this.cortexAPI.getModelStatus(model) return this.cortexAPI.getModelStatus(model)
} }
/**
* Configure pull options such as proxy, headers, etc.
*/
async configurePullOptions(options: { [key: string]: any }): Promise<any> {
return this.cortexAPI.configs(options).catch((e) => console.debug(e))
}
/** /**
* Handle download state from main app * Handle download state from main app
*/ */

View File

@ -267,7 +267,7 @@ const updateGpuInfo = async () =>
} }
data = await updateCudaExistence(data) data = await updateCudaExistence(data)
console.log(data) console.log('[MONITORING]::Cuda info: ', data)
writeFileSync(GPU_INFO_FILE, JSON.stringify(data, null, 2)) writeFileSync(GPU_INFO_FILE, JSON.stringify(data, null, 2))
log(`[APP]::${JSON.stringify(data)}`) log(`[APP]::${JSON.stringify(data)}`)
resolve({}) resolve({})
@ -344,7 +344,7 @@ const updateCudaExistence = async (
data.cuda.version = match[1] data.cuda.version = match[1]
} }
} }
console.log(data) console.log('[MONITORING]::Finalized cuda info update: ', data)
resolve() resolve()
}) })
}) })

View File

@ -29,7 +29,7 @@ const DataLoader: React.FC<Props> = ({ children }) => {
const setQuickAskEnabled = useSetAtom(quickAskEnabledAtom) const setQuickAskEnabled = useSetAtom(quickAskEnabledAtom)
const setJanDefaultDataFolder = useSetAtom(defaultJanDataFolderAtom) const setJanDefaultDataFolder = useSetAtom(defaultJanDataFolderAtom)
const setJanSettingScreen = useSetAtom(janSettingScreenAtom) const setJanSettingScreen = useSetAtom(janSettingScreenAtom)
const { loadDataModel } = useModels() const { loadDataModel, configurePullOptions } = useModels()
useThreads() useThreads()
useAssistants() useAssistants()
@ -39,6 +39,8 @@ const DataLoader: React.FC<Props> = ({ children }) => {
useEffect(() => { useEffect(() => {
// Load data once // Load data once
loadDataModel() loadDataModel()
// Configure pull options once
configurePullOptions()
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, []) }, [])

View File

@ -1,5 +1,5 @@
// useModels.test.ts // useModels.test.ts
import { renderHook, act } from '@testing-library/react' import { renderHook, act, waitFor } from '@testing-library/react'
import { events, ModelEvent, ModelManager } from '@janhq/core' import { events, ModelEvent, ModelManager } from '@janhq/core'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
@ -36,7 +36,6 @@ describe('useModels', () => {
}), }),
get: () => undefined, get: () => undefined,
has: () => true, has: () => true,
// set: () => {}
}, },
}) })
@ -50,6 +49,85 @@ describe('useModels', () => {
expect(mockModelExtension.getModels).toHaveBeenCalled() expect(mockModelExtension.getModels).toHaveBeenCalled()
}) })
it('should return empty on error', async () => {
const mockModelExtension = {
getModels: jest.fn().mockRejectedValue(new Error('Error')),
} as any
;(ModelManager.instance as jest.Mock).mockReturnValue({
models: {
values: () => ({
toArray: () => ({
filter: () => models,
}),
}),
get: () => undefined,
has: () => true,
},
})
jest.spyOn(extensionManager, 'get').mockReturnValue(mockModelExtension)
const { result } = renderHook(() => useModels())
await act(() => {
result.current?.loadDataModel()
})
expect(mockModelExtension.getModels()).rejects.toThrow()
})
it('should update states on models update', async () => {
const mockModelExtension = {
getModels: jest.fn().mockResolvedValue(models),
} as any
;(ModelManager.instance as jest.Mock).mockReturnValue({
models: {
values: () => ({
toArray: () => ({
filter: () => models,
}),
}),
get: () => undefined,
has: () => true,
},
})
jest.spyOn(extensionManager, 'get').mockReturnValue(mockModelExtension)
jest.spyOn(events, 'on').mockImplementationOnce((event, cb) => {
cb({ fetch: false })
})
renderHook(() => useModels())
expect(mockModelExtension.getModels).not.toHaveBeenCalled()
})
it('should update states on models update', async () => {
const mockModelExtension = {
getModels: jest.fn().mockResolvedValue(models),
} as any
;(ModelManager.instance as jest.Mock).mockReturnValue({
models: {
values: () => ({
toArray: () => ({
filter: () => models,
}),
}),
get: () => undefined,
has: () => true,
},
})
jest.spyOn(extensionManager, 'get').mockReturnValue(mockModelExtension)
jest.spyOn(events, 'on').mockImplementationOnce((event, cb) => {
cb({ fetch: true })
})
renderHook(() => useModels())
expect(mockModelExtension.getModels).toHaveBeenCalled()
})
it('should remove event listener on unmount', async () => { it('should remove event listener on unmount', async () => {
const removeListenerSpy = jest.spyOn(events, 'off') const removeListenerSpy = jest.spyOn(events, 'off')

View File

@ -9,13 +9,18 @@ import {
ModelManager, ModelManager,
} from '@janhq/core' } from '@janhq/core'
import { useSetAtom, useAtom } from 'jotai' import { useSetAtom, useAtom, useAtomValue } from 'jotai'
import { useDebouncedCallback } from 'use-debounce' import { useDebouncedCallback } from 'use-debounce'
import { isLocalEngine } from '@/utils/modelEngine' import { isLocalEngine } from '@/utils/modelEngine'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import {
ignoreSslAtom,
proxyAtom,
proxyEnabledAtom,
} from '@/helpers/atoms/AppConfig.atom'
import { import {
configuredModelsAtom, configuredModelsAtom,
downloadedModelsAtom, downloadedModelsAtom,
@ -29,6 +34,9 @@ import {
const useModels = () => { const useModels = () => {
const [downloadedModels, setDownloadedModels] = useAtom(downloadedModelsAtom) const [downloadedModels, setDownloadedModels] = useAtom(downloadedModelsAtom)
const setExtensionModels = useSetAtom(configuredModelsAtom) const setExtensionModels = useSetAtom(configuredModelsAtom)
const proxyEnabled = useAtomValue(proxyEnabledAtom)
const proxyUrl = useAtomValue(proxyAtom)
const proxyIgnoreSSL = useAtomValue(ignoreSslAtom)
const getData = useCallback(() => { const getData = useCallback(() => {
const getDownloadedModels = async () => { const getDownloadedModels = async () => {
@ -107,8 +115,25 @@ const useModels = () => {
} }
}, [reloadData, updateStates]) }, [reloadData, updateStates])
const configurePullOptions = useCallback(() => {
extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.configurePullOptions(
proxyEnabled
? {
proxy_url: proxyUrl,
verify_peer_ssl: !proxyIgnoreSSL,
}
: {
proxy_url: '',
verify_peer_ssl: false,
}
)
}, [proxyEnabled, proxyUrl, proxyIgnoreSSL])
return { return {
loadDataModel: getData, loadDataModel: getData,
configurePullOptions,
} }
} }

View File

@ -20,9 +20,12 @@ import { AlertTriangleIcon, AlertCircleIcon } from 'lucide-react'
import { twMerge } from 'tailwind-merge' import { twMerge } from 'tailwind-merge'
import { useDebouncedCallback } from 'use-debounce'
import { snackbar, toaster } from '@/containers/Toast' import { snackbar, toaster } from '@/containers/Toast'
import { useActiveModel } from '@/hooks/useActiveModel' import { useActiveModel } from '@/hooks/useActiveModel'
import useModels from '@/hooks/useModels'
import { useSettings } from '@/hooks/useSettings' import { useSettings } from '@/hooks/useSettings'
import DataFolder from './DataFolder' import DataFolder from './DataFolder'
@ -65,6 +68,7 @@ const Advanced = () => {
const [dropdownOptions, setDropdownOptions] = useState<HTMLDivElement | null>( const [dropdownOptions, setDropdownOptions] = useState<HTMLDivElement | null>(
null null
) )
const { configurePullOptions } = useModels()
const [toggle, setToggle] = useState<HTMLDivElement | null>(null) const [toggle, setToggle] = useState<HTMLDivElement | null>(null)
@ -78,6 +82,15 @@ const Advanced = () => {
return y['name'] return y['name']
}) })
/**
* There could be a case where the state update is not synced
* so that retrieving state value from other hooks would not be accurate
* there is also a case where state update persist everytime user type in the input
*/
const updatePullOptions = useDebouncedCallback(
() => configurePullOptions(),
300
)
/** /**
* Handle proxy change * Handle proxy change
*/ */
@ -90,8 +103,9 @@ const Advanced = () => {
} else { } else {
setProxy('') setProxy('')
} }
updatePullOptions()
}, },
[setPartialProxy, setProxy] [setPartialProxy, setProxy, updatePullOptions]
) )
/** /**
@ -209,7 +223,9 @@ const Advanced = () => {
} }
setGpusInUse(updatedGpusInUse) setGpusInUse(updatedGpusInUse)
await saveSettings({ gpusInUse: updatedGpusInUse }) await saveSettings({ gpusInUse: updatedGpusInUse })
window.core?.api?.relaunch() // Reload window to apply changes
// This will trigger engine servers to restart
window.location.reload()
} }
const gpuSelectionPlaceHolder = const gpuSelectionPlaceHolder =
@ -452,7 +468,10 @@ const Advanced = () => {
<Switch <Switch
data-testid="proxy-switch" data-testid="proxy-switch"
checked={proxyEnabled} checked={proxyEnabled}
onChange={() => setProxyEnabled(!proxyEnabled)} onChange={() => {
setProxyEnabled(!proxyEnabled)
updatePullOptions()
}}
/> />
<div className="w-full"> <div className="w-full">
<Input <Input
@ -481,7 +500,10 @@ const Advanced = () => {
<Switch <Switch
data-testid="ignore-ssl-switch" data-testid="ignore-ssl-switch"
checked={ignoreSSL} checked={ignoreSSL}
onChange={(e) => setIgnoreSSL(e.target.checked)} onChange={(e) => {
setIgnoreSSL(e.target.checked)
updatePullOptions()
}}
/> />
</div> </div>