feat: sync model hub and download progress from cortex.cpp

This commit is contained in:
Louis 2024-10-21 12:18:14 +07:00
parent f44f291bd8
commit 03e15fb70f
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
21 changed files with 192 additions and 86 deletions

View File

@ -1 +1 @@
npm run lint --fix
oxlint --fix || npm run lint --fix

View File

@ -13,7 +13,7 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
}
abstract getModels(): Promise<Model[]>
abstract pullModel(model: string): Promise<void>
abstract pullModel(model: string, id?: string): Promise<void>
abstract cancelModelPull(modelId: string): Promise<void>
abstract importModel(model: string, modePath: string): Promise<void>
abstract updateModel(modelInfo: Partial<Model>): Promise<Model>

View File

@ -1,5 +1,4 @@
import { Model } from './modelEntity'
import { OptionType } from './modelImport'
/**
* Model extension for managing models.
@ -10,14 +9,14 @@ export interface ModelInterface {
* @param model - The model to download.
* @returns A Promise that resolves when the model has been downloaded.
*/
pullModel(model: string): Promise<void>
pullModel(model: string, id?: string): Promise<void>
/**
* Cancels the download of a specific model.
* @param {string} modelId - The ID of the model to cancel the download for.
* @returns {Promise<void>} A promise that resolves when the download has been cancelled.
*/
cancelModelPull(modelId: string): Promise<void>
cancelModelPull(model: string): Promise<void>
/**
* Deletes a model.

View File

@ -1,16 +1,13 @@
import * as monitoringInterface from './monitoringInterface'
import * as resourceInfo from './resourceInfo'
import * as monitoringInterface from './monitoringInterface';
import * as resourceInfo from './resourceInfo';
import * as index from './index'
import * as index from './index';
import * as monitoringInterface from './monitoringInterface';
import * as resourceInfo from './resourceInfo';
it('should re-export all symbols from monitoringInterface and resourceInfo', () => {
for (const key in monitoringInterface) {
expect(index[key]).toBe(monitoringInterface[key]);
}
for (const key in resourceInfo) {
expect(index[key]).toBe(resourceInfo[key]);
}
});
it('should re-export all symbols from monitoringInterface and resourceInfo', () => {
for (const key in monitoringInterface) {
expect(index[key]).toBe(monitoringInterface[key])
}
for (const key in resourceInfo) {
expect(index[key]).toBe(resourceInfo[key])
}
})

View File

@ -1,5 +1,5 @@
import path from 'path'
import { log, SystemInformation } from '@janhq/core/node'
import { getJanDataFolderPath, log, SystemInformation } from '@janhq/core/node'
import { executableCortexFile } from './execute'
import { ProcessWatchdog } from './watchdog'
@ -40,9 +40,18 @@ function run(systemInfo?: SystemInformation): Promise<any> {
executableOptions.enginePath
)
const dataFolderPath = getJanDataFolderPath()
watchdog = new ProcessWatchdog(
executableOptions.executablePath,
['--start-server', '--port', LOCAL_PORT.toString()],
[
'--start-server',
'--port',
LOCAL_PORT.toString(),
'--config_file_path',
`${path.join(dataFolderPath, '.janrc')}`,
'--data_folder_path',
dataFolderPath,
],
{
cwd: executableOptions.enginePath,
env: {

View File

@ -20,6 +20,8 @@ export default [
replace({
preventAssignment: true,
SETTINGS: JSON.stringify(settingJson),
API_URL: 'http://127.0.0.1:39291',
SOCKET_URL: 'ws://127.0.0.1:39291',
}),
// Allow json resolution
json(),

View File

@ -1,6 +1,8 @@
export {}
declare global {
declare const NODE: string
declare const API_URL: string
declare const SOCKET_URL: string
interface Core {
api: APIFunctions

View File

@ -1,6 +1,7 @@
import PQueue from 'p-queue'
import ky from 'ky'
import {
DownloadEvent,
events,
Model,
ModelEvent,
@ -13,18 +14,12 @@ import {
interface ICortexAPI {
getModel(model: string): Promise<Model>
getModels(): Promise<Model[]>
pullModel(model: string): Promise<void>
pullModel(model: string, id?: string): Promise<void>
importModel(path: string, modelPath: string): Promise<void>
deleteModel(model: string): Promise<void>
updateModel(model: object): Promise<void>
cancelModelPull(model: string): Promise<void>
}
/**
* Simple CortexAPI service
* It could be replaced by cortex client sdk later on
*/
const API_URL = 'http://127.0.0.1:39291'
const SOCKET_URL = 'ws://127.0.0.1:39291'
type ModelList = {
data: any[]
@ -71,10 +66,10 @@ export class CortexAPI implements ICortexAPI {
* @param model
* @returns
*/
pullModel(model: string): Promise<void> {
pullModel(model: string, id?: string): Promise<void> {
return this.queue.add(() =>
ky
.post(`${API_URL}/v1/models/pull`, { json: { model } })
.post(`${API_URL}/v1/models/pull`, { json: { model, id } })
.json()
.catch(async (e) => {
throw (await e.response?.json()) ?? e
@ -160,7 +155,6 @@ export class CortexAPI implements ICortexAPI {
() =>
new Promise<void>((resolve) => {
this.socket = new WebSocket(`${SOCKET_URL}/events`)
console.log('Socket connected')
this.socket.addEventListener('message', (event) => {
const data = JSON.parse(event.data)
@ -173,7 +167,7 @@ export class CortexAPI implements ICortexAPI {
(accumulator, currentValue) => accumulator + currentValue.bytes,
0
)
const percent = ((transferred ?? 1) / (total ?? 1)) * 100
const percent = (transferred / total || 0) * 100
events.emit(data.type, {
modelId: data.task.id,
@ -184,7 +178,13 @@ export class CortexAPI implements ICortexAPI {
},
})
// Update models list from Hub
events.emit(ModelEvent.OnModelsUpdate, {})
if (data.type === DownloadEvent.onFileDownloadSuccess) {
// Delay for the state update from cortex.cpp
// Just to be sure
setTimeout(() => {
events.emit(ModelEvent.OnModelsUpdate, {})
}, 500)
}
})
resolve()
})

View File

@ -47,11 +47,11 @@ export default class JanModelExtension extends ModelExtension {
* @param model - The model to download.
* @returns A Promise that resolves when the model is downloaded.
*/
async pullModel(model: string): Promise<void> {
async pullModel(model: string, id?: string): Promise<void> {
/**
* Sending POST to /models/pull/{id} endpoint to pull the model
*/
return this.cortexAPI.pullModel(model)
return this.cortexAPI.pullModel(model, id)
}
/**

View File

@ -4,7 +4,7 @@ import { Model } from '@janhq/core'
import { Modal, Button, Progress, ModalClose } from '@janhq/joi'
import { useAtomValue } from 'jotai'
import { useAtomValue, useSetAtom } from 'jotai'
import useDownloadModel from '@/hooks/useDownloadModel'
@ -12,7 +12,7 @@ import { modelDownloadStateAtom } from '@/hooks/useDownloadState'
import { formatDownloadPercentage } from '@/utils/converter'
import { getDownloadingModelAtom } from '@/helpers/atoms/Model.atom'
import { removeDownloadingModelAtom } from '@/helpers/atoms/Model.atom'
type Props = {
model: Model
@ -21,20 +21,16 @@ type Props = {
const ModalCancelDownload = ({ model, isFromList }: Props) => {
const { abortModelDownload } = useDownloadModel()
const downloadingModels = useAtomValue(getDownloadingModelAtom)
const removeModelDownload = useSetAtom(removeDownloadingModelAtom)
const allDownloadStates = useAtomValue(modelDownloadStateAtom)
const downloadState = allDownloadStates[model.id]
const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}`
const cancelText = `Cancel ${formatDownloadPercentage(downloadState?.percent ?? 0)}`
const onAbortDownloadClick = useCallback(() => {
if (downloadState?.modelId) {
const model = downloadingModels.find(
(model) => model === downloadState.modelId
)
if (model) abortModelDownload(model)
}
}, [downloadState, downloadingModels, abortModelDownload])
removeModelDownload(model.id)
abortModelDownload(downloadState?.modelId ?? model.id)
}, [downloadState, abortModelDownload, removeModelDownload, model])
return (
<Modal
@ -51,13 +47,13 @@ const ModalCancelDownload = ({ model, isFromList }: Props) => {
<Progress
className="w-[80px]"
value={
formatDownloadPercentage(downloadState?.percent, {
formatDownloadPercentage(downloadState?.percent ?? 0, {
hidePercentage: true,
}) as number
}
/>
<span className="tabular-nums">
{formatDownloadPercentage(downloadState.percent)}
{formatDownloadPercentage(downloadState?.percent ?? 0)}
</span>
</div>
</Button>

View File

@ -472,7 +472,10 @@ const ModelDropdown = ({
size={18}
className="cursor-pointer text-[hsla(var(--app-link))]"
onClick={() =>
downloadModel(model.sources[0].url)
downloadModel(
model.sources[0].url,
model.id
)
}
/>
) : (
@ -559,7 +562,10 @@ const ModelDropdown = ({
size={18}
className="cursor-pointer text-[hsla(var(--app-link))]"
onClick={() =>
downloadModel(model.sources[0].url)
downloadModel(
model.sources[0].url,
model.id
)
}
/>
) : (

View File

@ -23,11 +23,17 @@ import {
removeInstallingExtensionAtom,
setInstallingExtensionAtom,
} from '@/helpers/atoms/Extension.atom'
import {
addDownloadingModelAtom,
removeDownloadingModelAtom,
} from '@/helpers/atoms/Model.atom'
const EventListenerWrapper = ({ children }: PropsWithChildren) => {
const setDownloadState = useSetAtom(setDownloadStateAtom)
const setInstallingExtension = useSetAtom(setInstallingExtensionAtom)
const removeInstallingExtension = useSetAtom(removeInstallingExtensionAtom)
const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
const removeDownloadingModel = useSetAtom(removeDownloadingModelAtom)
const onFileDownloadUpdate = useCallback(
async (state: DownloadState) => {
@ -40,6 +46,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
}
setInstallingExtension(state.extensionId!, installingExtensionState)
} else {
addDownloadingModel(state.modelId)
setDownloadState(state)
}
},
@ -54,6 +61,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
} else {
state.downloadState = 'error'
setDownloadState(state)
removeDownloadingModel(state.modelId)
}
},
[setDownloadState, removeInstallingExtension]
@ -68,6 +76,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
state.downloadState = 'error'
state.error = 'aborted'
setDownloadState(state)
removeDownloadingModel(state.modelId)
}
},
[setDownloadState, removeInstallingExtension]
@ -79,6 +88,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
if (state.downloadType !== 'extension') {
state.downloadState = 'end'
setDownloadState(state)
removeDownloadingModel(state.modelId)
}
events.emit(ModelEvent.OnModelsUpdate, {})
},

View File

@ -1,11 +1,6 @@
import { useCallback } from 'react'
import {
events,
ExtensionTypeEnum,
ModelEvent,
ModelExtension,
} from '@janhq/core'
import { ExtensionTypeEnum, ModelExtension } from '@janhq/core'
import { useSetAtom } from 'jotai'
@ -19,13 +14,13 @@ import {
} from '@/helpers/atoms/Model.atom'
export default function useDownloadModel() {
const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
const removeDownloadingModel = useSetAtom(removeDownloadingModelAtom)
const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
const downloadModel = useCallback(
async (model: string) => {
addDownloadingModel(model)
localDownloadModel(model).catch((error) => {
async (model: string, id?: string) => {
addDownloadingModel(id ?? model)
downloadLocalModel(model, id).catch((error) => {
if (error.message) {
toaster({
title: 'Download failed',
@ -37,7 +32,7 @@ export default function useDownloadModel() {
removeDownloadingModel(model)
})
},
[addDownloadingModel]
[removeDownloadingModel, addDownloadingModel]
)
const abortModelDownload = useCallback(async (model: string) => {
@ -50,10 +45,10 @@ export default function useDownloadModel() {
}
}
const localDownloadModel = async (model: string) =>
const downloadLocalModel = async (model: string, id?: string) =>
extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.pullModel(model)
?.pullModel(model, id)
const cancelModelDownload = async (model: string) =>
extensionManager

View File

@ -2,6 +2,8 @@ import { useCallback, useState } from 'react'
import { HuggingFaceRepoData } from '@janhq/core'
import { fetchHuggingFaceRepoData } from '@/utils/huggingface'
export const useGetHFRepoData = () => {
const [error, setError] = useState<string | undefined>(undefined)
const [loading, setLoading] = useState(false)
@ -29,8 +31,5 @@ export const useGetHFRepoData = () => {
const extensionGetHfRepoData = async (
repoId: string
): Promise<HuggingFaceRepoData | undefined> => {
return Promise.resolve(undefined)
// return extensionManager
// .get<ModelExtension>(ExtensionTypeEnum.Model)
// ?.fetchHuggingFaceRepoData(repoId)
return fetchHuggingFaceRepoData(repoId)
}

View File

@ -216,7 +216,7 @@ export default function useSendChatMessage() {
...activeThreadRef.current,
updated: newMessage.created,
metadata: {
...(activeThreadRef.current.metadata ?? {}),
...activeThreadRef.current.metadata,
lastMessage: prompt,
},
}
@ -256,7 +256,6 @@ export default function useSendChatMessage() {
)
request.messages = normalizeMessages(request.messages ?? [])
console.log(requestBuilder.model?.engine ?? modelRequest.engine, request)
// Request for inference
EngineManager.instance()
.get(requestBuilder.model?.engine ?? modelRequest.engine ?? '')

View File

@ -64,7 +64,7 @@ const ModelItemHeader = ({ model, onClick, open }: Props) => {
const assistants = useAtomValue(assistantsAtom)
const onDownloadClick = useCallback(() => {
downloadModel(model.sources[0].url)
downloadModel(model.sources[0].url, model.id)
}, [model, downloadModel])
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null
@ -123,17 +123,6 @@ const ModelItemHeader = ({ model, onClick, open }: Props) => {
className="cursor-pointer rounded-t-md bg-[hsla(var(--app-bg))]"
onClick={onClick}
>
{/* TODO: @faisal are we still using cover? */}
{/* {model.metadata.cover && imageLoaded && (
<div className="relative h-full w-full">
<img
onError={() => setImageLoaded(false)}
src={model.metadata.cover}
className="h-[250px] w-full object-cover"
alt={`Cover - ${model.id}`}
/>
</div>
)} */}
<div className="flex items-center justify-between px-4 py-2">
<div className="flex items-center gap-2">
<span className="line-clamp-1 text-base font-semibold">

View File

@ -20,6 +20,7 @@ import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
import { importHuggingFaceModelStageAtom } from '@/helpers/atoms/HuggingFace.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { normalizeModelId } from '@/utils/model'
type Props = {
index: number
@ -50,13 +51,13 @@ const ModelDownloadRow: React.FC<Props> = ({
const onAbortDownloadClick = useCallback(() => {
if (downloadUrl) {
abortModelDownload(downloadUrl)
abortModelDownload(normalizeModelId(downloadUrl))
}
}, [downloadUrl, abortModelDownload])
const onDownloadClick = useCallback(async () => {
if (downloadUrl) {
downloadModel(downloadUrl)
downloadModel(downloadUrl, normalizeModelId(downloadUrl))
}
}, [downloadUrl, downloadModel])

View File

@ -168,7 +168,10 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
size={18}
className="cursor-pointer text-[hsla(var(--app-link))]"
onClick={() =>
downloadModel(model.sources[0].url)
downloadModel(
model.sources[0].url,
model.id
)
}
/>
) : (
@ -256,7 +259,10 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
theme="ghost"
className="!bg-[hsla(var(--secondary-bg))]"
onClick={() =>
downloadModel(featModel.sources[0].url)
downloadModel(
featModel.sources[0].url,
featModel.id
)
}
>
Download

View File

@ -9,7 +9,7 @@ export function openExternalUrl(url: string) {
}
// Define API routes based on different route types
export const APIRoutes = [...CoreRoutes.map((r) => ({ path: `app`, route: r }))]
export const APIRoutes = CoreRoutes.map((r) => ({ path: `app`, route: r }))
// Define the restAPI object with methods for each API route
export const restAPI = {

93
web/utils/huggingface.ts Normal file
View File

@ -0,0 +1,93 @@
import { AllQuantizations, getFileSize, HuggingFaceRepoData } from '@janhq/core'
export const fetchHuggingFaceRepoData = async (
repoId: string,
huggingFaceAccessToken?: string
): Promise<HuggingFaceRepoData> => {
const sanitizedUrl = toHuggingFaceUrl(repoId)
console.debug('sanitizedUrl', sanitizedUrl)
const headers: Record<string, string> = {
Accept: 'application/json',
}
if (huggingFaceAccessToken && huggingFaceAccessToken.length > 0) {
headers['Authorization'] = `Bearer ${huggingFaceAccessToken}`
}
const res = await fetch(sanitizedUrl, {
headers: headers,
})
const response = await res.json()
if (response['error'] != null) {
throw new Error(response['error'])
}
const data = response as HuggingFaceRepoData
if (data.tags.indexOf('gguf') === -1) {
throw new Error(
`${repoId} is not supported. Only GGUF models are supported.`
)
}
const promises: Promise<number>[] = []
// fetching file sizes
const url = new URL(sanitizedUrl)
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)
for (const sibling of data.siblings) {
const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${sibling.rfilename}`
sibling.downloadUrl = downloadUrl
promises.push(getFileSize(downloadUrl))
}
const result = await Promise.all(promises)
for (let i = 0; i < data.siblings.length; i++) {
data.siblings[i].fileSize = result[i]
}
AllQuantizations.forEach((quantization) => {
data.siblings.forEach((sibling) => {
if (!sibling.quantization && sibling.rfilename.includes(quantization)) {
sibling.quantization = quantization
}
})
})
data.modelUrl = `https://huggingface.co/${paths[2]}/${paths[3]}`
return data
}
function toHuggingFaceUrl(repoId: string): string {
try {
const url = new URL(repoId)
if (url.host !== 'huggingface.co') {
throw new InvalidHostError(`Invalid Hugging Face repo URL: ${repoId}`)
}
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)
if (paths.length < 2) {
throw new InvalidHostError(`Invalid Hugging Face repo URL: ${repoId}`)
}
return `${url.origin}/api/models/${paths[0]}/${paths[1]}`
} catch (err) {
if (err instanceof InvalidHostError) {
throw err
}
if (repoId.startsWith('https')) {
throw new Error(`Cannot parse url: ${repoId}`)
}
return `https://huggingface.co/api/models/${repoId}`
}
}
class InvalidHostError extends Error {
constructor(message: string) {
super(message)
this.name = 'InvalidHostError'
}
}

3
web/utils/model.ts Normal file
View File

@ -0,0 +1,3 @@
export const normalizeModelId = (downloadUrl: string): string => {
return downloadUrl.split('/').pop() ?? downloadUrl
}