feat: sync model hub and download progress from cortex.cpp
This commit is contained in:
parent
f44f291bd8
commit
03e15fb70f
@ -1 +1 @@
|
|||||||
npm run lint --fix
|
oxlint --fix || npm run lint --fix
|
||||||
@ -13,7 +13,7 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter
|
|||||||
}
|
}
|
||||||
|
|
||||||
abstract getModels(): Promise<Model[]>
|
abstract getModels(): Promise<Model[]>
|
||||||
abstract pullModel(model: string): Promise<void>
|
abstract pullModel(model: string, id?: string): Promise<void>
|
||||||
abstract cancelModelPull(modelId: string): Promise<void>
|
abstract cancelModelPull(modelId: string): Promise<void>
|
||||||
abstract importModel(model: string, modePath: string): Promise<void>
|
abstract importModel(model: string, modePath: string): Promise<void>
|
||||||
abstract updateModel(modelInfo: Partial<Model>): Promise<Model>
|
abstract updateModel(modelInfo: Partial<Model>): Promise<Model>
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import { Model } from './modelEntity'
|
import { Model } from './modelEntity'
|
||||||
import { OptionType } from './modelImport'
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Model extension for managing models.
|
* Model extension for managing models.
|
||||||
@ -10,14 +9,14 @@ export interface ModelInterface {
|
|||||||
* @param model - The model to download.
|
* @param model - The model to download.
|
||||||
* @returns A Promise that resolves when the model has been downloaded.
|
* @returns A Promise that resolves when the model has been downloaded.
|
||||||
*/
|
*/
|
||||||
pullModel(model: string): Promise<void>
|
pullModel(model: string, id?: string): Promise<void>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cancels the download of a specific model.
|
* Cancels the download of a specific model.
|
||||||
* @param {string} modelId - The ID of the model to cancel the download for.
|
* @param {string} modelId - The ID of the model to cancel the download for.
|
||||||
* @returns {Promise<void>} A promise that resolves when the download has been cancelled.
|
* @returns {Promise<void>} A promise that resolves when the download has been cancelled.
|
||||||
*/
|
*/
|
||||||
cancelModelPull(modelId: string): Promise<void>
|
cancelModelPull(model: string): Promise<void>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deletes a model.
|
* Deletes a model.
|
||||||
|
|||||||
@ -1,16 +1,13 @@
|
|||||||
|
import * as monitoringInterface from './monitoringInterface'
|
||||||
|
import * as resourceInfo from './resourceInfo'
|
||||||
|
|
||||||
import * as monitoringInterface from './monitoringInterface';
|
import * as index from './index'
|
||||||
import * as resourceInfo from './resourceInfo';
|
|
||||||
|
|
||||||
import * as index from './index';
|
it('should re-export all symbols from monitoringInterface and resourceInfo', () => {
|
||||||
import * as monitoringInterface from './monitoringInterface';
|
|
||||||
import * as resourceInfo from './resourceInfo';
|
|
||||||
|
|
||||||
it('should re-export all symbols from monitoringInterface and resourceInfo', () => {
|
|
||||||
for (const key in monitoringInterface) {
|
for (const key in monitoringInterface) {
|
||||||
expect(index[key]).toBe(monitoringInterface[key]);
|
expect(index[key]).toBe(monitoringInterface[key])
|
||||||
}
|
}
|
||||||
for (const key in resourceInfo) {
|
for (const key in resourceInfo) {
|
||||||
expect(index[key]).toBe(resourceInfo[key]);
|
expect(index[key]).toBe(resourceInfo[key])
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import path from 'path'
|
import path from 'path'
|
||||||
import { log, SystemInformation } from '@janhq/core/node'
|
import { getJanDataFolderPath, log, SystemInformation } from '@janhq/core/node'
|
||||||
import { executableCortexFile } from './execute'
|
import { executableCortexFile } from './execute'
|
||||||
import { ProcessWatchdog } from './watchdog'
|
import { ProcessWatchdog } from './watchdog'
|
||||||
|
|
||||||
@ -40,9 +40,18 @@ function run(systemInfo?: SystemInformation): Promise<any> {
|
|||||||
executableOptions.enginePath
|
executableOptions.enginePath
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const dataFolderPath = getJanDataFolderPath()
|
||||||
watchdog = new ProcessWatchdog(
|
watchdog = new ProcessWatchdog(
|
||||||
executableOptions.executablePath,
|
executableOptions.executablePath,
|
||||||
['--start-server', '--port', LOCAL_PORT.toString()],
|
[
|
||||||
|
'--start-server',
|
||||||
|
'--port',
|
||||||
|
LOCAL_PORT.toString(),
|
||||||
|
'--config_file_path',
|
||||||
|
`${path.join(dataFolderPath, '.janrc')}`,
|
||||||
|
'--data_folder_path',
|
||||||
|
dataFolderPath,
|
||||||
|
],
|
||||||
{
|
{
|
||||||
cwd: executableOptions.enginePath,
|
cwd: executableOptions.enginePath,
|
||||||
env: {
|
env: {
|
||||||
|
|||||||
@ -20,6 +20,8 @@ export default [
|
|||||||
replace({
|
replace({
|
||||||
preventAssignment: true,
|
preventAssignment: true,
|
||||||
SETTINGS: JSON.stringify(settingJson),
|
SETTINGS: JSON.stringify(settingJson),
|
||||||
|
API_URL: 'http://127.0.0.1:39291',
|
||||||
|
SOCKET_URL: 'ws://127.0.0.1:39291',
|
||||||
}),
|
}),
|
||||||
// Allow json resolution
|
// Allow json resolution
|
||||||
json(),
|
json(),
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
export {}
|
export {}
|
||||||
declare global {
|
declare global {
|
||||||
declare const NODE: string
|
declare const NODE: string
|
||||||
|
declare const API_URL: string
|
||||||
|
declare const SOCKET_URL: string
|
||||||
|
|
||||||
interface Core {
|
interface Core {
|
||||||
api: APIFunctions
|
api: APIFunctions
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import PQueue from 'p-queue'
|
import PQueue from 'p-queue'
|
||||||
import ky from 'ky'
|
import ky from 'ky'
|
||||||
import {
|
import {
|
||||||
|
DownloadEvent,
|
||||||
events,
|
events,
|
||||||
Model,
|
Model,
|
||||||
ModelEvent,
|
ModelEvent,
|
||||||
@ -13,18 +14,12 @@ import {
|
|||||||
interface ICortexAPI {
|
interface ICortexAPI {
|
||||||
getModel(model: string): Promise<Model>
|
getModel(model: string): Promise<Model>
|
||||||
getModels(): Promise<Model[]>
|
getModels(): Promise<Model[]>
|
||||||
pullModel(model: string): Promise<void>
|
pullModel(model: string, id?: string): Promise<void>
|
||||||
importModel(path: string, modelPath: string): Promise<void>
|
importModel(path: string, modelPath: string): Promise<void>
|
||||||
deleteModel(model: string): Promise<void>
|
deleteModel(model: string): Promise<void>
|
||||||
updateModel(model: object): Promise<void>
|
updateModel(model: object): Promise<void>
|
||||||
cancelModelPull(model: string): Promise<void>
|
cancelModelPull(model: string): Promise<void>
|
||||||
}
|
}
|
||||||
/**
|
|
||||||
* Simple CortexAPI service
|
|
||||||
* It could be replaced by cortex client sdk later on
|
|
||||||
*/
|
|
||||||
const API_URL = 'http://127.0.0.1:39291'
|
|
||||||
const SOCKET_URL = 'ws://127.0.0.1:39291'
|
|
||||||
|
|
||||||
type ModelList = {
|
type ModelList = {
|
||||||
data: any[]
|
data: any[]
|
||||||
@ -71,10 +66,10 @@ export class CortexAPI implements ICortexAPI {
|
|||||||
* @param model
|
* @param model
|
||||||
* @returns
|
* @returns
|
||||||
*/
|
*/
|
||||||
pullModel(model: string): Promise<void> {
|
pullModel(model: string, id?: string): Promise<void> {
|
||||||
return this.queue.add(() =>
|
return this.queue.add(() =>
|
||||||
ky
|
ky
|
||||||
.post(`${API_URL}/v1/models/pull`, { json: { model } })
|
.post(`${API_URL}/v1/models/pull`, { json: { model, id } })
|
||||||
.json()
|
.json()
|
||||||
.catch(async (e) => {
|
.catch(async (e) => {
|
||||||
throw (await e.response?.json()) ?? e
|
throw (await e.response?.json()) ?? e
|
||||||
@ -160,7 +155,6 @@ export class CortexAPI implements ICortexAPI {
|
|||||||
() =>
|
() =>
|
||||||
new Promise<void>((resolve) => {
|
new Promise<void>((resolve) => {
|
||||||
this.socket = new WebSocket(`${SOCKET_URL}/events`)
|
this.socket = new WebSocket(`${SOCKET_URL}/events`)
|
||||||
console.log('Socket connected')
|
|
||||||
|
|
||||||
this.socket.addEventListener('message', (event) => {
|
this.socket.addEventListener('message', (event) => {
|
||||||
const data = JSON.parse(event.data)
|
const data = JSON.parse(event.data)
|
||||||
@ -173,7 +167,7 @@ export class CortexAPI implements ICortexAPI {
|
|||||||
(accumulator, currentValue) => accumulator + currentValue.bytes,
|
(accumulator, currentValue) => accumulator + currentValue.bytes,
|
||||||
0
|
0
|
||||||
)
|
)
|
||||||
const percent = ((transferred ?? 1) / (total ?? 1)) * 100
|
const percent = (transferred / total || 0) * 100
|
||||||
|
|
||||||
events.emit(data.type, {
|
events.emit(data.type, {
|
||||||
modelId: data.task.id,
|
modelId: data.task.id,
|
||||||
@ -184,7 +178,13 @@ export class CortexAPI implements ICortexAPI {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
// Update models list from Hub
|
// Update models list from Hub
|
||||||
|
if (data.type === DownloadEvent.onFileDownloadSuccess) {
|
||||||
|
// Delay for the state update from cortex.cpp
|
||||||
|
// Just to be sure
|
||||||
|
setTimeout(() => {
|
||||||
events.emit(ModelEvent.OnModelsUpdate, {})
|
events.emit(ModelEvent.OnModelsUpdate, {})
|
||||||
|
}, 500)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
resolve()
|
resolve()
|
||||||
})
|
})
|
||||||
|
|||||||
@ -47,11 +47,11 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
* @param model - The model to download.
|
* @param model - The model to download.
|
||||||
* @returns A Promise that resolves when the model is downloaded.
|
* @returns A Promise that resolves when the model is downloaded.
|
||||||
*/
|
*/
|
||||||
async pullModel(model: string): Promise<void> {
|
async pullModel(model: string, id?: string): Promise<void> {
|
||||||
/**
|
/**
|
||||||
* Sending POST to /models/pull/{id} endpoint to pull the model
|
* Sending POST to /models/pull/{id} endpoint to pull the model
|
||||||
*/
|
*/
|
||||||
return this.cortexAPI.pullModel(model)
|
return this.cortexAPI.pullModel(model, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import { Model } from '@janhq/core'
|
|||||||
|
|
||||||
import { Modal, Button, Progress, ModalClose } from '@janhq/joi'
|
import { Modal, Button, Progress, ModalClose } from '@janhq/joi'
|
||||||
|
|
||||||
import { useAtomValue } from 'jotai'
|
import { useAtomValue, useSetAtom } from 'jotai'
|
||||||
|
|
||||||
import useDownloadModel from '@/hooks/useDownloadModel'
|
import useDownloadModel from '@/hooks/useDownloadModel'
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ import { modelDownloadStateAtom } from '@/hooks/useDownloadState'
|
|||||||
|
|
||||||
import { formatDownloadPercentage } from '@/utils/converter'
|
import { formatDownloadPercentage } from '@/utils/converter'
|
||||||
|
|
||||||
import { getDownloadingModelAtom } from '@/helpers/atoms/Model.atom'
|
import { removeDownloadingModelAtom } from '@/helpers/atoms/Model.atom'
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
model: Model
|
model: Model
|
||||||
@ -21,20 +21,16 @@ type Props = {
|
|||||||
|
|
||||||
const ModalCancelDownload = ({ model, isFromList }: Props) => {
|
const ModalCancelDownload = ({ model, isFromList }: Props) => {
|
||||||
const { abortModelDownload } = useDownloadModel()
|
const { abortModelDownload } = useDownloadModel()
|
||||||
const downloadingModels = useAtomValue(getDownloadingModelAtom)
|
const removeModelDownload = useSetAtom(removeDownloadingModelAtom)
|
||||||
const allDownloadStates = useAtomValue(modelDownloadStateAtom)
|
const allDownloadStates = useAtomValue(modelDownloadStateAtom)
|
||||||
const downloadState = allDownloadStates[model.id]
|
const downloadState = allDownloadStates[model.id]
|
||||||
|
|
||||||
const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}`
|
const cancelText = `Cancel ${formatDownloadPercentage(downloadState?.percent ?? 0)}`
|
||||||
|
|
||||||
const onAbortDownloadClick = useCallback(() => {
|
const onAbortDownloadClick = useCallback(() => {
|
||||||
if (downloadState?.modelId) {
|
removeModelDownload(model.id)
|
||||||
const model = downloadingModels.find(
|
abortModelDownload(downloadState?.modelId ?? model.id)
|
||||||
(model) => model === downloadState.modelId
|
}, [downloadState, abortModelDownload, removeModelDownload, model])
|
||||||
)
|
|
||||||
if (model) abortModelDownload(model)
|
|
||||||
}
|
|
||||||
}, [downloadState, downloadingModels, abortModelDownload])
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Modal
|
<Modal
|
||||||
@ -51,13 +47,13 @@ const ModalCancelDownload = ({ model, isFromList }: Props) => {
|
|||||||
<Progress
|
<Progress
|
||||||
className="w-[80px]"
|
className="w-[80px]"
|
||||||
value={
|
value={
|
||||||
formatDownloadPercentage(downloadState?.percent, {
|
formatDownloadPercentage(downloadState?.percent ?? 0, {
|
||||||
hidePercentage: true,
|
hidePercentage: true,
|
||||||
}) as number
|
}) as number
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
<span className="tabular-nums">
|
<span className="tabular-nums">
|
||||||
{formatDownloadPercentage(downloadState.percent)}
|
{formatDownloadPercentage(downloadState?.percent ?? 0)}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@ -472,7 +472,10 @@ const ModelDropdown = ({
|
|||||||
size={18}
|
size={18}
|
||||||
className="cursor-pointer text-[hsla(var(--app-link))]"
|
className="cursor-pointer text-[hsla(var(--app-link))]"
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
downloadModel(model.sources[0].url)
|
downloadModel(
|
||||||
|
model.sources[0].url,
|
||||||
|
model.id
|
||||||
|
)
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
@ -559,7 +562,10 @@ const ModelDropdown = ({
|
|||||||
size={18}
|
size={18}
|
||||||
className="cursor-pointer text-[hsla(var(--app-link))]"
|
className="cursor-pointer text-[hsla(var(--app-link))]"
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
downloadModel(model.sources[0].url)
|
downloadModel(
|
||||||
|
model.sources[0].url,
|
||||||
|
model.id
|
||||||
|
)
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
|
|||||||
@ -23,11 +23,17 @@ import {
|
|||||||
removeInstallingExtensionAtom,
|
removeInstallingExtensionAtom,
|
||||||
setInstallingExtensionAtom,
|
setInstallingExtensionAtom,
|
||||||
} from '@/helpers/atoms/Extension.atom'
|
} from '@/helpers/atoms/Extension.atom'
|
||||||
|
import {
|
||||||
|
addDownloadingModelAtom,
|
||||||
|
removeDownloadingModelAtom,
|
||||||
|
} from '@/helpers/atoms/Model.atom'
|
||||||
|
|
||||||
const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
||||||
const setDownloadState = useSetAtom(setDownloadStateAtom)
|
const setDownloadState = useSetAtom(setDownloadStateAtom)
|
||||||
const setInstallingExtension = useSetAtom(setInstallingExtensionAtom)
|
const setInstallingExtension = useSetAtom(setInstallingExtensionAtom)
|
||||||
const removeInstallingExtension = useSetAtom(removeInstallingExtensionAtom)
|
const removeInstallingExtension = useSetAtom(removeInstallingExtensionAtom)
|
||||||
|
const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
|
||||||
|
const removeDownloadingModel = useSetAtom(removeDownloadingModelAtom)
|
||||||
|
|
||||||
const onFileDownloadUpdate = useCallback(
|
const onFileDownloadUpdate = useCallback(
|
||||||
async (state: DownloadState) => {
|
async (state: DownloadState) => {
|
||||||
@ -40,6 +46,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
|||||||
}
|
}
|
||||||
setInstallingExtension(state.extensionId!, installingExtensionState)
|
setInstallingExtension(state.extensionId!, installingExtensionState)
|
||||||
} else {
|
} else {
|
||||||
|
addDownloadingModel(state.modelId)
|
||||||
setDownloadState(state)
|
setDownloadState(state)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -54,6 +61,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
|||||||
} else {
|
} else {
|
||||||
state.downloadState = 'error'
|
state.downloadState = 'error'
|
||||||
setDownloadState(state)
|
setDownloadState(state)
|
||||||
|
removeDownloadingModel(state.modelId)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[setDownloadState, removeInstallingExtension]
|
[setDownloadState, removeInstallingExtension]
|
||||||
@ -68,6 +76,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
|||||||
state.downloadState = 'error'
|
state.downloadState = 'error'
|
||||||
state.error = 'aborted'
|
state.error = 'aborted'
|
||||||
setDownloadState(state)
|
setDownloadState(state)
|
||||||
|
removeDownloadingModel(state.modelId)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[setDownloadState, removeInstallingExtension]
|
[setDownloadState, removeInstallingExtension]
|
||||||
@ -79,6 +88,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
|||||||
if (state.downloadType !== 'extension') {
|
if (state.downloadType !== 'extension') {
|
||||||
state.downloadState = 'end'
|
state.downloadState = 'end'
|
||||||
setDownloadState(state)
|
setDownloadState(state)
|
||||||
|
removeDownloadingModel(state.modelId)
|
||||||
}
|
}
|
||||||
events.emit(ModelEvent.OnModelsUpdate, {})
|
events.emit(ModelEvent.OnModelsUpdate, {})
|
||||||
},
|
},
|
||||||
|
|||||||
@ -1,11 +1,6 @@
|
|||||||
import { useCallback } from 'react'
|
import { useCallback } from 'react'
|
||||||
|
|
||||||
import {
|
import { ExtensionTypeEnum, ModelExtension } from '@janhq/core'
|
||||||
events,
|
|
||||||
ExtensionTypeEnum,
|
|
||||||
ModelEvent,
|
|
||||||
ModelExtension,
|
|
||||||
} from '@janhq/core'
|
|
||||||
|
|
||||||
import { useSetAtom } from 'jotai'
|
import { useSetAtom } from 'jotai'
|
||||||
|
|
||||||
@ -19,13 +14,13 @@ import {
|
|||||||
} from '@/helpers/atoms/Model.atom'
|
} from '@/helpers/atoms/Model.atom'
|
||||||
|
|
||||||
export default function useDownloadModel() {
|
export default function useDownloadModel() {
|
||||||
const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
|
|
||||||
const removeDownloadingModel = useSetAtom(removeDownloadingModelAtom)
|
const removeDownloadingModel = useSetAtom(removeDownloadingModelAtom)
|
||||||
|
const addDownloadingModel = useSetAtom(addDownloadingModelAtom)
|
||||||
|
|
||||||
const downloadModel = useCallback(
|
const downloadModel = useCallback(
|
||||||
async (model: string) => {
|
async (model: string, id?: string) => {
|
||||||
addDownloadingModel(model)
|
addDownloadingModel(id ?? model)
|
||||||
localDownloadModel(model).catch((error) => {
|
downloadLocalModel(model, id).catch((error) => {
|
||||||
if (error.message) {
|
if (error.message) {
|
||||||
toaster({
|
toaster({
|
||||||
title: 'Download failed',
|
title: 'Download failed',
|
||||||
@ -37,7 +32,7 @@ export default function useDownloadModel() {
|
|||||||
removeDownloadingModel(model)
|
removeDownloadingModel(model)
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
[addDownloadingModel]
|
[removeDownloadingModel, addDownloadingModel]
|
||||||
)
|
)
|
||||||
|
|
||||||
const abortModelDownload = useCallback(async (model: string) => {
|
const abortModelDownload = useCallback(async (model: string) => {
|
||||||
@ -50,10 +45,10 @@ export default function useDownloadModel() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const localDownloadModel = async (model: string) =>
|
const downloadLocalModel = async (model: string, id?: string) =>
|
||||||
extensionManager
|
extensionManager
|
||||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||||
?.pullModel(model)
|
?.pullModel(model, id)
|
||||||
|
|
||||||
const cancelModelDownload = async (model: string) =>
|
const cancelModelDownload = async (model: string) =>
|
||||||
extensionManager
|
extensionManager
|
||||||
|
|||||||
@ -2,6 +2,8 @@ import { useCallback, useState } from 'react'
|
|||||||
|
|
||||||
import { HuggingFaceRepoData } from '@janhq/core'
|
import { HuggingFaceRepoData } from '@janhq/core'
|
||||||
|
|
||||||
|
import { fetchHuggingFaceRepoData } from '@/utils/huggingface'
|
||||||
|
|
||||||
export const useGetHFRepoData = () => {
|
export const useGetHFRepoData = () => {
|
||||||
const [error, setError] = useState<string | undefined>(undefined)
|
const [error, setError] = useState<string | undefined>(undefined)
|
||||||
const [loading, setLoading] = useState(false)
|
const [loading, setLoading] = useState(false)
|
||||||
@ -29,8 +31,5 @@ export const useGetHFRepoData = () => {
|
|||||||
const extensionGetHfRepoData = async (
|
const extensionGetHfRepoData = async (
|
||||||
repoId: string
|
repoId: string
|
||||||
): Promise<HuggingFaceRepoData | undefined> => {
|
): Promise<HuggingFaceRepoData | undefined> => {
|
||||||
return Promise.resolve(undefined)
|
return fetchHuggingFaceRepoData(repoId)
|
||||||
// return extensionManager
|
|
||||||
// .get<ModelExtension>(ExtensionTypeEnum.Model)
|
|
||||||
// ?.fetchHuggingFaceRepoData(repoId)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -216,7 +216,7 @@ export default function useSendChatMessage() {
|
|||||||
...activeThreadRef.current,
|
...activeThreadRef.current,
|
||||||
updated: newMessage.created,
|
updated: newMessage.created,
|
||||||
metadata: {
|
metadata: {
|
||||||
...(activeThreadRef.current.metadata ?? {}),
|
...activeThreadRef.current.metadata,
|
||||||
lastMessage: prompt,
|
lastMessage: prompt,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -256,7 +256,6 @@ export default function useSendChatMessage() {
|
|||||||
)
|
)
|
||||||
request.messages = normalizeMessages(request.messages ?? [])
|
request.messages = normalizeMessages(request.messages ?? [])
|
||||||
|
|
||||||
console.log(requestBuilder.model?.engine ?? modelRequest.engine, request)
|
|
||||||
// Request for inference
|
// Request for inference
|
||||||
EngineManager.instance()
|
EngineManager.instance()
|
||||||
.get(requestBuilder.model?.engine ?? modelRequest.engine ?? '')
|
.get(requestBuilder.model?.engine ?? modelRequest.engine ?? '')
|
||||||
|
|||||||
@ -64,7 +64,7 @@ const ModelItemHeader = ({ model, onClick, open }: Props) => {
|
|||||||
const assistants = useAtomValue(assistantsAtom)
|
const assistants = useAtomValue(assistantsAtom)
|
||||||
|
|
||||||
const onDownloadClick = useCallback(() => {
|
const onDownloadClick = useCallback(() => {
|
||||||
downloadModel(model.sources[0].url)
|
downloadModel(model.sources[0].url, model.id)
|
||||||
}, [model, downloadModel])
|
}, [model, downloadModel])
|
||||||
|
|
||||||
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null
|
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null
|
||||||
@ -123,17 +123,6 @@ const ModelItemHeader = ({ model, onClick, open }: Props) => {
|
|||||||
className="cursor-pointer rounded-t-md bg-[hsla(var(--app-bg))]"
|
className="cursor-pointer rounded-t-md bg-[hsla(var(--app-bg))]"
|
||||||
onClick={onClick}
|
onClick={onClick}
|
||||||
>
|
>
|
||||||
{/* TODO: @faisal are we still using cover? */}
|
|
||||||
{/* {model.metadata.cover && imageLoaded && (
|
|
||||||
<div className="relative h-full w-full">
|
|
||||||
<img
|
|
||||||
onError={() => setImageLoaded(false)}
|
|
||||||
src={model.metadata.cover}
|
|
||||||
className="h-[250px] w-full object-cover"
|
|
||||||
alt={`Cover - ${model.id}`}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
)} */}
|
|
||||||
<div className="flex items-center justify-between px-4 py-2">
|
<div className="flex items-center justify-between px-4 py-2">
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<span className="line-clamp-1 text-base font-semibold">
|
<span className="line-clamp-1 text-base font-semibold">
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
|
|||||||
|
|
||||||
import { importHuggingFaceModelStageAtom } from '@/helpers/atoms/HuggingFace.atom'
|
import { importHuggingFaceModelStageAtom } from '@/helpers/atoms/HuggingFace.atom'
|
||||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||||
|
import { normalizeModelId } from '@/utils/model'
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
index: number
|
index: number
|
||||||
@ -50,13 +51,13 @@ const ModelDownloadRow: React.FC<Props> = ({
|
|||||||
|
|
||||||
const onAbortDownloadClick = useCallback(() => {
|
const onAbortDownloadClick = useCallback(() => {
|
||||||
if (downloadUrl) {
|
if (downloadUrl) {
|
||||||
abortModelDownload(downloadUrl)
|
abortModelDownload(normalizeModelId(downloadUrl))
|
||||||
}
|
}
|
||||||
}, [downloadUrl, abortModelDownload])
|
}, [downloadUrl, abortModelDownload])
|
||||||
|
|
||||||
const onDownloadClick = useCallback(async () => {
|
const onDownloadClick = useCallback(async () => {
|
||||||
if (downloadUrl) {
|
if (downloadUrl) {
|
||||||
downloadModel(downloadUrl)
|
downloadModel(downloadUrl, normalizeModelId(downloadUrl))
|
||||||
}
|
}
|
||||||
}, [downloadUrl, downloadModel])
|
}, [downloadUrl, downloadModel])
|
||||||
|
|
||||||
|
|||||||
@ -168,7 +168,10 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
|
|||||||
size={18}
|
size={18}
|
||||||
className="cursor-pointer text-[hsla(var(--app-link))]"
|
className="cursor-pointer text-[hsla(var(--app-link))]"
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
downloadModel(model.sources[0].url)
|
downloadModel(
|
||||||
|
model.sources[0].url,
|
||||||
|
model.id
|
||||||
|
)
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
@ -256,7 +259,10 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
|
|||||||
theme="ghost"
|
theme="ghost"
|
||||||
className="!bg-[hsla(var(--secondary-bg))]"
|
className="!bg-[hsla(var(--secondary-bg))]"
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
downloadModel(featModel.sources[0].url)
|
downloadModel(
|
||||||
|
featModel.sources[0].url,
|
||||||
|
featModel.id
|
||||||
|
)
|
||||||
}
|
}
|
||||||
>
|
>
|
||||||
Download
|
Download
|
||||||
|
|||||||
@ -9,7 +9,7 @@ export function openExternalUrl(url: string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Define API routes based on different route types
|
// Define API routes based on different route types
|
||||||
export const APIRoutes = [...CoreRoutes.map((r) => ({ path: `app`, route: r }))]
|
export const APIRoutes = CoreRoutes.map((r) => ({ path: `app`, route: r }))
|
||||||
|
|
||||||
// Define the restAPI object with methods for each API route
|
// Define the restAPI object with methods for each API route
|
||||||
export const restAPI = {
|
export const restAPI = {
|
||||||
|
|||||||
93
web/utils/huggingface.ts
Normal file
93
web/utils/huggingface.ts
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import { AllQuantizations, getFileSize, HuggingFaceRepoData } from '@janhq/core'
|
||||||
|
|
||||||
|
export const fetchHuggingFaceRepoData = async (
|
||||||
|
repoId: string,
|
||||||
|
huggingFaceAccessToken?: string
|
||||||
|
): Promise<HuggingFaceRepoData> => {
|
||||||
|
const sanitizedUrl = toHuggingFaceUrl(repoId)
|
||||||
|
console.debug('sanitizedUrl', sanitizedUrl)
|
||||||
|
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
Accept: 'application/json',
|
||||||
|
}
|
||||||
|
|
||||||
|
if (huggingFaceAccessToken && huggingFaceAccessToken.length > 0) {
|
||||||
|
headers['Authorization'] = `Bearer ${huggingFaceAccessToken}`
|
||||||
|
}
|
||||||
|
|
||||||
|
const res = await fetch(sanitizedUrl, {
|
||||||
|
headers: headers,
|
||||||
|
})
|
||||||
|
const response = await res.json()
|
||||||
|
if (response['error'] != null) {
|
||||||
|
throw new Error(response['error'])
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = response as HuggingFaceRepoData
|
||||||
|
|
||||||
|
if (data.tags.indexOf('gguf') === -1) {
|
||||||
|
throw new Error(
|
||||||
|
`${repoId} is not supported. Only GGUF models are supported.`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const promises: Promise<number>[] = []
|
||||||
|
|
||||||
|
// fetching file sizes
|
||||||
|
const url = new URL(sanitizedUrl)
|
||||||
|
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)
|
||||||
|
|
||||||
|
for (const sibling of data.siblings) {
|
||||||
|
const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${sibling.rfilename}`
|
||||||
|
sibling.downloadUrl = downloadUrl
|
||||||
|
promises.push(getFileSize(downloadUrl))
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await Promise.all(promises)
|
||||||
|
for (let i = 0; i < data.siblings.length; i++) {
|
||||||
|
data.siblings[i].fileSize = result[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
AllQuantizations.forEach((quantization) => {
|
||||||
|
data.siblings.forEach((sibling) => {
|
||||||
|
if (!sibling.quantization && sibling.rfilename.includes(quantization)) {
|
||||||
|
sibling.quantization = quantization
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
data.modelUrl = `https://huggingface.co/${paths[2]}/${paths[3]}`
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
function toHuggingFaceUrl(repoId: string): string {
|
||||||
|
try {
|
||||||
|
const url = new URL(repoId)
|
||||||
|
if (url.host !== 'huggingface.co') {
|
||||||
|
throw new InvalidHostError(`Invalid Hugging Face repo URL: ${repoId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)
|
||||||
|
if (paths.length < 2) {
|
||||||
|
throw new InvalidHostError(`Invalid Hugging Face repo URL: ${repoId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return `${url.origin}/api/models/${paths[0]}/${paths[1]}`
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof InvalidHostError) {
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
|
||||||
|
if (repoId.startsWith('https')) {
|
||||||
|
throw new Error(`Cannot parse url: ${repoId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return `https://huggingface.co/api/models/${repoId}`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
class InvalidHostError extends Error {
|
||||||
|
constructor(message: string) {
|
||||||
|
super(message)
|
||||||
|
this.name = 'InvalidHostError'
|
||||||
|
}
|
||||||
|
}
|
||||||
3
web/utils/model.ts
Normal file
3
web/utils/model.ts
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
export const normalizeModelId = (downloadUrl: string): string => {
|
||||||
|
return downloadUrl.split('/').pop() ?? downloadUrl
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user