feat: app supports cortex.cpp model downloader and legacy downloader - maintain legacy JSON models

This commit is contained in:
Louis 2024-10-24 14:49:18 +07:00
parent 5f075c8554
commit 8f778ee90f
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
7 changed files with 197 additions and 26 deletions

View File

@ -69,11 +69,11 @@ export enum DownloadRoute {
}
export enum DownloadEvent {
onFileDownloadUpdate = 'DownloadUpdated',
onFileDownloadError = 'DownloadError',
onFileDownloadSuccess = 'DownloadSuccess',
onFileDownloadStopped = 'DownloadStopped',
onFileDownloadStarted = 'DownloadStarted',
onFileDownloadUpdate = 'onFileDownloadUpdate',
onFileDownloadError = 'onFileDownloadError',
onFileDownloadSuccess = 'onFileDownloadSuccess',
onFileDownloadStopped = 'onFileDownloadStopped',
onFileDownloadStarted = 'onFileDownloadStarted',
onFileUnzipSuccess = 'onFileUnzipSuccess',
}

View File

@ -25,6 +25,14 @@ type ModelList = {
data: any[]
}
enum DownloadTypes {
DownloadUpdated = 'onFileDownloadUpdate',
DownloadError = 'onFileDownloadError',
DownloadSuccess = 'onFileDownloadSuccess',
DownloadStopped = 'onFileDownloadStopped',
DownloadStarted = 'onFileDownloadStarted',
}
export class CortexAPI implements ICortexAPI {
queue = new PQueue({ concurrency: 1 })
socket?: WebSocket = undefined
@ -159,17 +167,16 @@ export class CortexAPI implements ICortexAPI {
this.socket.addEventListener('message', (event) => {
const data = JSON.parse(event.data)
const transferred = data.task.items.reduce(
(accumulator, currentValue) =>
accumulator + currentValue.downloadedBytes,
(acc, cur) => acc + cur.downloadedBytes,
0
)
const total = data.task.items.reduce(
(accumulator, currentValue) => accumulator + currentValue.bytes,
(acc, cur) => acc + cur.bytes,
0
)
const percent = (transferred / total || 0) * 100
events.emit(data.type, {
events.emit(DownloadTypes[data.type], {
modelId: data.task.id,
percent: percent,
size: {
@ -178,7 +185,7 @@ export class CortexAPI implements ICortexAPI {
},
})
// Update models list from Hub
if (data.type === DownloadEvent.onFileDownloadSuccess) {
if (data.type === DownloadTypes.DownloadSuccess) {
// Delay for the state update from cortex.cpp
// Just to be sure
setTimeout(() => {

View File

@ -4,9 +4,15 @@ import {
InferenceEngine,
joinPath,
dirName,
ModelManager,
abortDownload,
DownloadState,
events,
DownloadEvent,
} from '@janhq/core'
import { CortexAPI } from './cortex'
import { scanModelsFolder } from './model-json'
import { scanModelsFolder } from './legacy/model-json'
import { downloadModel } from './legacy/download'
declare const SETTINGS: Array<any>
@ -34,6 +40,9 @@ export default class JanModelExtension extends ModelExtension {
this.getModels().then((models) => {
this.registerModels(models)
})
// Listen to app download events
this.handleDesktopEvents()
}
/**
@ -48,6 +57,17 @@ export default class JanModelExtension extends ModelExtension {
* @returns A Promise that resolves when the model is downloaded.
*/
async pullModel(model: string, id?: string): Promise<void> {
if (id) {
const model: Model = ModelManager.instance().get(id)
// Clip vision model - should not be handled by cortex.cpp
// TensorRT model - should not be handled by cortex.cpp
if (
model.engine === InferenceEngine.nitro_tensorrt_llm ||
model.settings.vision_model
) {
return downloadModel(model)
}
}
/**
* Sending POST to /models/pull/{id} endpoint to pull the model
*/
@ -61,10 +81,24 @@ export default class JanModelExtension extends ModelExtension {
* @returns {Promise<void>} A promise that resolves when the download has been cancelled.
*/
async cancelModelPull(model: string): Promise<void> {
if (model) {
const modelDto: Model = ModelManager.instance().get(model)
// Clip vision model - should not be handled by cortex.cpp
// TensorRT model - should not be handled by cortex.cpp
if (
modelDto.engine === InferenceEngine.nitro_tensorrt_llm ||
modelDto.settings.vision_model
) {
for (const source of modelDto.sources) {
const path = await joinPath(['models', modelDto.id, source.filename])
return abortDownload(path)
}
}
}
/**
* Sending DELETE to /models/pull/{id} endpoint to cancel a model pull
*/
this.cortexAPI.cancelModelPull(model)
return this.cortexAPI.cancelModelPull(model)
}
/**
@ -87,14 +121,18 @@ export default class JanModelExtension extends ModelExtension {
* should compare and try import
*/
let currentModels: Model[] = []
/**
* Legacy models should be supported
*/
let legacyModels = await scanModelsFolder()
try {
if (!localStorage.getItem(ExtensionEnum.downloadedModels)) {
// Updated from an older version than 0.5.5
// Scan through the models folder and import them (Legacy flow)
// Return models immediately
currentModels = await scanModelsFolder().then((models) => {
return models ?? []
})
currentModels = legacyModels
} else {
currentModels = JSON.parse(
localStorage.getItem(ExtensionEnum.downloadedModels)
@ -116,7 +154,7 @@ export default class JanModelExtension extends ModelExtension {
await this.cortexAPI.getModels().then((models) => {
const existingIds = models.map((e) => e.id)
toImportModels = toImportModels.filter(
(e: Model) => !existingIds.includes(e.id)
(e: Model) => !existingIds.includes(e.id) && !e.settings?.vision_model
)
})
@ -147,13 +185,15 @@ export default class JanModelExtension extends ModelExtension {
}
/**
* All models are imported successfully before
* just return models from cortex.cpp
* Models are imported successfully before
* Now return models from cortex.cpp and merge with legacy models which are not imported
*/
return (
this.cortexAPI.getModels().then((models) => {
return models
}) ?? Promise.resolve([])
return models.concat(
legacyModels.filter((e) => !models.some((x) => x.id === e.id))
)
}) ?? Promise.resolve(legacyModels)
)
}
@ -175,4 +215,31 @@ export default class JanModelExtension extends ModelExtension {
async importModel(model: string, modelPath: string): Promise<void> {
return this.cortexAPI.importModel(model, modelPath)
}
/**
* Handle download state from main app
*/
handleDesktopEvents() {
if (window && window.electronAPI) {
window.electronAPI.onFileDownloadUpdate(
async (_event: string, state: DownloadState | undefined) => {
if (!state) return
state.downloadState = 'downloading'
events.emit(DownloadEvent.onFileDownloadUpdate, state)
}
)
window.electronAPI.onFileDownloadError(
async (_event: string, state: DownloadState) => {
state.downloadState = 'error'
events.emit(DownloadEvent.onFileDownloadError, state)
}
)
window.electronAPI.onFileDownloadSuccess(
async (_event: string, state: DownloadState) => {
state.downloadState = 'end'
events.emit(DownloadEvent.onFileDownloadSuccess, state)
}
)
}
}
}

View File

@ -0,0 +1,97 @@
import {
downloadFile,
DownloadRequest,
fs,
GpuSetting,
InferenceEngine,
joinPath,
Model,
} from '@janhq/core'
export const downloadModel = async (
model: Model,
gpuSettings?: GpuSetting,
network?: { ignoreSSL?: boolean; proxy?: string }
): Promise<void> => {
const homedir = 'file://models'
const supportedGpuArch = ['ampere', 'ada']
// Create corresponding directory
const modelDirPath = await joinPath([homedir, model.id])
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
if (model.engine === InferenceEngine.nitro_tensorrt_llm) {
if (!gpuSettings || gpuSettings.gpus.length === 0) {
console.error('No GPU found. Please check your GPU setting.')
return
}
const firstGpu = gpuSettings.gpus[0]
if (!firstGpu.name.toLowerCase().includes('nvidia')) {
console.error('No Nvidia GPU found. Please check your GPU setting.')
return
}
const gpuArch = firstGpu.arch
if (gpuArch === undefined) {
console.error('No GPU architecture found. Please check your GPU setting.')
return
}
if (!supportedGpuArch.includes(gpuArch)) {
console.debug(
`Your GPU: ${JSON.stringify(firstGpu)} is not supported. Only 30xx, 40xx series are supported.`
)
return
}
const os = 'windows' // TODO: remove this hard coded value
const newSources = model.sources.map((source) => {
const newSource = { ...source }
newSource.url = newSource.url
.replace(/<os>/g, os)
.replace(/<gpuarch>/g, gpuArch)
return newSource
})
model.sources = newSources
}
console.debug(`Download sources: ${JSON.stringify(model.sources)}`)
if (model.sources.length > 1) {
// path to model binaries
for (const source of model.sources) {
let path = extractFileName(source.url, '.gguf')
if (source.filename) {
path = await joinPath([modelDirPath, source.filename])
}
const downloadRequest: DownloadRequest = {
url: source.url,
localPath: path,
modelId: model.id,
}
downloadFile(downloadRequest, network)
}
} else {
const fileName = extractFileName(model.sources[0]?.url, '.gguf')
const path = await joinPath([modelDirPath, fileName])
const downloadRequest: DownloadRequest = {
url: model.sources[0]?.url,
localPath: path,
modelId: model.id,
}
downloadFile(downloadRequest, network)
}
}
/**
* try to retrieve the download file name from the source url
*/
function extractFileName(url: string, fileExtension: string): string {
if (!url) return fileExtension
const extractedFileName = url.split('/').pop()
const fileName = extractedFileName.toLowerCase().endsWith(fileExtension)
? extractedFileName
: extractedFileName + fileExtension
return fileName
}

View File

@ -71,7 +71,7 @@ export const scanModelsFolder = async (): Promise<Model[]> => {
file.toLowerCase().endsWith('.gguf') || // GGUF
file.toLowerCase().endsWith('.engine') // Tensort-LLM
)
})?.length > 0 // TODO: find better way (can use basename to check the file name with source url)
})?.length >= (model.sources?.length ?? 1)
)
})

View File

@ -50,7 +50,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
setDownloadState(state)
}
},
[setDownloadState, setInstallingExtension]
[addDownloadingModel, setDownloadState, setInstallingExtension]
)
const onFileDownloadError = useCallback(
@ -64,7 +64,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
removeDownloadingModel(state.modelId)
}
},
[setDownloadState, removeInstallingExtension]
[removeInstallingExtension, setDownloadState, removeDownloadingModel]
)
const onFileDownloadStopped = useCallback(
@ -79,7 +79,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
removeDownloadingModel(state.modelId)
}
},
[setDownloadState, removeInstallingExtension]
[removeInstallingExtension, setDownloadState, removeDownloadingModel]
)
const onFileDownloadSuccess = useCallback(
@ -92,7 +92,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
}
events.emit(ModelEvent.OnModelsUpdate, {})
},
[setDownloadState]
[removeDownloadingModel, setDownloadState]
)
const onFileUnzipSuccess = useCallback(
@ -121,7 +121,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
events.off(DownloadEvent.onFileDownloadUpdate, onFileDownloadUpdate)
events.off(DownloadEvent.onFileDownloadError, onFileDownloadError)
events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
events.off(DownloadEvent.onFileDownloadStopped, onFileDownloadStopped)
events.off(DownloadEvent.onFileUnzipSuccess, onFileUnzipSuccess)
}
}, [