feat: app supports cortex.cpp model downloader and legacy downloader - maintain legacy JSON models
This commit is contained in:
parent
5f075c8554
commit
8f778ee90f
@ -69,11 +69,11 @@ export enum DownloadRoute {
|
||||
}
|
||||
|
||||
export enum DownloadEvent {
|
||||
onFileDownloadUpdate = 'DownloadUpdated',
|
||||
onFileDownloadError = 'DownloadError',
|
||||
onFileDownloadSuccess = 'DownloadSuccess',
|
||||
onFileDownloadStopped = 'DownloadStopped',
|
||||
onFileDownloadStarted = 'DownloadStarted',
|
||||
onFileDownloadUpdate = 'onFileDownloadUpdate',
|
||||
onFileDownloadError = 'onFileDownloadError',
|
||||
onFileDownloadSuccess = 'onFileDownloadSuccess',
|
||||
onFileDownloadStopped = 'onFileDownloadStopped',
|
||||
onFileDownloadStarted = 'onFileDownloadStarted',
|
||||
onFileUnzipSuccess = 'onFileUnzipSuccess',
|
||||
}
|
||||
|
||||
|
||||
@ -25,6 +25,14 @@ type ModelList = {
|
||||
data: any[]
|
||||
}
|
||||
|
||||
enum DownloadTypes {
|
||||
DownloadUpdated = 'onFileDownloadUpdate',
|
||||
DownloadError = 'onFileDownloadError',
|
||||
DownloadSuccess = 'onFileDownloadSuccess',
|
||||
DownloadStopped = 'onFileDownloadStopped',
|
||||
DownloadStarted = 'onFileDownloadStarted',
|
||||
}
|
||||
|
||||
export class CortexAPI implements ICortexAPI {
|
||||
queue = new PQueue({ concurrency: 1 })
|
||||
socket?: WebSocket = undefined
|
||||
@ -159,17 +167,16 @@ export class CortexAPI implements ICortexAPI {
|
||||
this.socket.addEventListener('message', (event) => {
|
||||
const data = JSON.parse(event.data)
|
||||
const transferred = data.task.items.reduce(
|
||||
(accumulator, currentValue) =>
|
||||
accumulator + currentValue.downloadedBytes,
|
||||
(acc, cur) => acc + cur.downloadedBytes,
|
||||
0
|
||||
)
|
||||
const total = data.task.items.reduce(
|
||||
(accumulator, currentValue) => accumulator + currentValue.bytes,
|
||||
(acc, cur) => acc + cur.bytes,
|
||||
0
|
||||
)
|
||||
const percent = (transferred / total || 0) * 100
|
||||
|
||||
events.emit(data.type, {
|
||||
events.emit(DownloadTypes[data.type], {
|
||||
modelId: data.task.id,
|
||||
percent: percent,
|
||||
size: {
|
||||
@ -178,7 +185,7 @@ export class CortexAPI implements ICortexAPI {
|
||||
},
|
||||
})
|
||||
// Update models list from Hub
|
||||
if (data.type === DownloadEvent.onFileDownloadSuccess) {
|
||||
if (data.type === DownloadTypes.DownloadSuccess) {
|
||||
// Delay for the state update from cortex.cpp
|
||||
// Just to be sure
|
||||
setTimeout(() => {
|
||||
|
||||
@ -4,9 +4,15 @@ import {
|
||||
InferenceEngine,
|
||||
joinPath,
|
||||
dirName,
|
||||
ModelManager,
|
||||
abortDownload,
|
||||
DownloadState,
|
||||
events,
|
||||
DownloadEvent,
|
||||
} from '@janhq/core'
|
||||
import { CortexAPI } from './cortex'
|
||||
import { scanModelsFolder } from './model-json'
|
||||
import { scanModelsFolder } from './legacy/model-json'
|
||||
import { downloadModel } from './legacy/download'
|
||||
|
||||
declare const SETTINGS: Array<any>
|
||||
|
||||
@ -34,6 +40,9 @@ export default class JanModelExtension extends ModelExtension {
|
||||
this.getModels().then((models) => {
|
||||
this.registerModels(models)
|
||||
})
|
||||
|
||||
// Listen to app download events
|
||||
this.handleDesktopEvents()
|
||||
}
|
||||
|
||||
/**
|
||||
@ -48,6 +57,17 @@ export default class JanModelExtension extends ModelExtension {
|
||||
* @returns A Promise that resolves when the model is downloaded.
|
||||
*/
|
||||
async pullModel(model: string, id?: string): Promise<void> {
|
||||
if (id) {
|
||||
const model: Model = ModelManager.instance().get(id)
|
||||
// Clip vision model - should not be handled by cortex.cpp
|
||||
// TensorRT model - should not be handled by cortex.cpp
|
||||
if (
|
||||
model.engine === InferenceEngine.nitro_tensorrt_llm ||
|
||||
model.settings.vision_model
|
||||
) {
|
||||
return downloadModel(model)
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Sending POST to /models/pull/{id} endpoint to pull the model
|
||||
*/
|
||||
@ -61,10 +81,24 @@ export default class JanModelExtension extends ModelExtension {
|
||||
* @returns {Promise<void>} A promise that resolves when the download has been cancelled.
|
||||
*/
|
||||
async cancelModelPull(model: string): Promise<void> {
|
||||
if (model) {
|
||||
const modelDto: Model = ModelManager.instance().get(model)
|
||||
// Clip vision model - should not be handled by cortex.cpp
|
||||
// TensorRT model - should not be handled by cortex.cpp
|
||||
if (
|
||||
modelDto.engine === InferenceEngine.nitro_tensorrt_llm ||
|
||||
modelDto.settings.vision_model
|
||||
) {
|
||||
for (const source of modelDto.sources) {
|
||||
const path = await joinPath(['models', modelDto.id, source.filename])
|
||||
return abortDownload(path)
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Sending DELETE to /models/pull/{id} endpoint to cancel a model pull
|
||||
*/
|
||||
this.cortexAPI.cancelModelPull(model)
|
||||
return this.cortexAPI.cancelModelPull(model)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -87,14 +121,18 @@ export default class JanModelExtension extends ModelExtension {
|
||||
* should compare and try import
|
||||
*/
|
||||
let currentModels: Model[] = []
|
||||
|
||||
/**
|
||||
* Legacy models should be supported
|
||||
*/
|
||||
let legacyModels = await scanModelsFolder()
|
||||
|
||||
try {
|
||||
if (!localStorage.getItem(ExtensionEnum.downloadedModels)) {
|
||||
// Updated from an older version than 0.5.5
|
||||
// Scan through the models folder and import them (Legacy flow)
|
||||
// Return models immediately
|
||||
currentModels = await scanModelsFolder().then((models) => {
|
||||
return models ?? []
|
||||
})
|
||||
currentModels = legacyModels
|
||||
} else {
|
||||
currentModels = JSON.parse(
|
||||
localStorage.getItem(ExtensionEnum.downloadedModels)
|
||||
@ -116,7 +154,7 @@ export default class JanModelExtension extends ModelExtension {
|
||||
await this.cortexAPI.getModels().then((models) => {
|
||||
const existingIds = models.map((e) => e.id)
|
||||
toImportModels = toImportModels.filter(
|
||||
(e: Model) => !existingIds.includes(e.id)
|
||||
(e: Model) => !existingIds.includes(e.id) && !e.settings?.vision_model
|
||||
)
|
||||
})
|
||||
|
||||
@ -147,13 +185,15 @@ export default class JanModelExtension extends ModelExtension {
|
||||
}
|
||||
|
||||
/**
|
||||
* All models are imported successfully before
|
||||
* just return models from cortex.cpp
|
||||
* Models are imported successfully before
|
||||
* Now return models from cortex.cpp and merge with legacy models which are not imported
|
||||
*/
|
||||
return (
|
||||
this.cortexAPI.getModels().then((models) => {
|
||||
return models
|
||||
}) ?? Promise.resolve([])
|
||||
return models.concat(
|
||||
legacyModels.filter((e) => !models.some((x) => x.id === e.id))
|
||||
)
|
||||
}) ?? Promise.resolve(legacyModels)
|
||||
)
|
||||
}
|
||||
|
||||
@ -175,4 +215,31 @@ export default class JanModelExtension extends ModelExtension {
|
||||
async importModel(model: string, modelPath: string): Promise<void> {
|
||||
return this.cortexAPI.importModel(model, modelPath)
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle download state from main app
|
||||
*/
|
||||
handleDesktopEvents() {
|
||||
if (window && window.electronAPI) {
|
||||
window.electronAPI.onFileDownloadUpdate(
|
||||
async (_event: string, state: DownloadState | undefined) => {
|
||||
if (!state) return
|
||||
state.downloadState = 'downloading'
|
||||
events.emit(DownloadEvent.onFileDownloadUpdate, state)
|
||||
}
|
||||
)
|
||||
window.electronAPI.onFileDownloadError(
|
||||
async (_event: string, state: DownloadState) => {
|
||||
state.downloadState = 'error'
|
||||
events.emit(DownloadEvent.onFileDownloadError, state)
|
||||
}
|
||||
)
|
||||
window.electronAPI.onFileDownloadSuccess(
|
||||
async (_event: string, state: DownloadState) => {
|
||||
state.downloadState = 'end'
|
||||
events.emit(DownloadEvent.onFileDownloadSuccess, state)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
97
extensions/model-extension/src/legacy/download.ts
Normal file
97
extensions/model-extension/src/legacy/download.ts
Normal file
@ -0,0 +1,97 @@
|
||||
import {
|
||||
downloadFile,
|
||||
DownloadRequest,
|
||||
fs,
|
||||
GpuSetting,
|
||||
InferenceEngine,
|
||||
joinPath,
|
||||
Model,
|
||||
} from '@janhq/core'
|
||||
|
||||
export const downloadModel = async (
|
||||
model: Model,
|
||||
gpuSettings?: GpuSetting,
|
||||
network?: { ignoreSSL?: boolean; proxy?: string }
|
||||
): Promise<void> => {
|
||||
const homedir = 'file://models'
|
||||
const supportedGpuArch = ['ampere', 'ada']
|
||||
// Create corresponding directory
|
||||
const modelDirPath = await joinPath([homedir, model.id])
|
||||
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
|
||||
|
||||
if (model.engine === InferenceEngine.nitro_tensorrt_llm) {
|
||||
if (!gpuSettings || gpuSettings.gpus.length === 0) {
|
||||
console.error('No GPU found. Please check your GPU setting.')
|
||||
return
|
||||
}
|
||||
const firstGpu = gpuSettings.gpus[0]
|
||||
if (!firstGpu.name.toLowerCase().includes('nvidia')) {
|
||||
console.error('No Nvidia GPU found. Please check your GPU setting.')
|
||||
return
|
||||
}
|
||||
const gpuArch = firstGpu.arch
|
||||
if (gpuArch === undefined) {
|
||||
console.error('No GPU architecture found. Please check your GPU setting.')
|
||||
return
|
||||
}
|
||||
|
||||
if (!supportedGpuArch.includes(gpuArch)) {
|
||||
console.debug(
|
||||
`Your GPU: ${JSON.stringify(firstGpu)} is not supported. Only 30xx, 40xx series are supported.`
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
const os = 'windows' // TODO: remove this hard coded value
|
||||
|
||||
const newSources = model.sources.map((source) => {
|
||||
const newSource = { ...source }
|
||||
newSource.url = newSource.url
|
||||
.replace(/<os>/g, os)
|
||||
.replace(/<gpuarch>/g, gpuArch)
|
||||
return newSource
|
||||
})
|
||||
model.sources = newSources
|
||||
}
|
||||
|
||||
console.debug(`Download sources: ${JSON.stringify(model.sources)}`)
|
||||
|
||||
if (model.sources.length > 1) {
|
||||
// path to model binaries
|
||||
for (const source of model.sources) {
|
||||
let path = extractFileName(source.url, '.gguf')
|
||||
if (source.filename) {
|
||||
path = await joinPath([modelDirPath, source.filename])
|
||||
}
|
||||
|
||||
const downloadRequest: DownloadRequest = {
|
||||
url: source.url,
|
||||
localPath: path,
|
||||
modelId: model.id,
|
||||
}
|
||||
downloadFile(downloadRequest, network)
|
||||
}
|
||||
} else {
|
||||
const fileName = extractFileName(model.sources[0]?.url, '.gguf')
|
||||
const path = await joinPath([modelDirPath, fileName])
|
||||
const downloadRequest: DownloadRequest = {
|
||||
url: model.sources[0]?.url,
|
||||
localPath: path,
|
||||
modelId: model.id,
|
||||
}
|
||||
downloadFile(downloadRequest, network)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* try to retrieve the download file name from the source url
|
||||
*/
|
||||
function extractFileName(url: string, fileExtension: string): string {
|
||||
if (!url) return fileExtension
|
||||
|
||||
const extractedFileName = url.split('/').pop()
|
||||
const fileName = extractedFileName.toLowerCase().endsWith(fileExtension)
|
||||
? extractedFileName
|
||||
: extractedFileName + fileExtension
|
||||
return fileName
|
||||
}
|
||||
@ -71,7 +71,7 @@ export const scanModelsFolder = async (): Promise<Model[]> => {
|
||||
file.toLowerCase().endsWith('.gguf') || // GGUF
|
||||
file.toLowerCase().endsWith('.engine') // Tensort-LLM
|
||||
)
|
||||
})?.length > 0 // TODO: find better way (can use basename to check the file name with source url)
|
||||
})?.length >= (model.sources?.length ?? 1)
|
||||
)
|
||||
})
|
||||
|
||||
@ -50,7 +50,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
||||
setDownloadState(state)
|
||||
}
|
||||
},
|
||||
[setDownloadState, setInstallingExtension]
|
||||
[addDownloadingModel, setDownloadState, setInstallingExtension]
|
||||
)
|
||||
|
||||
const onFileDownloadError = useCallback(
|
||||
@ -64,7 +64,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
||||
removeDownloadingModel(state.modelId)
|
||||
}
|
||||
},
|
||||
[setDownloadState, removeInstallingExtension]
|
||||
[removeInstallingExtension, setDownloadState, removeDownloadingModel]
|
||||
)
|
||||
|
||||
const onFileDownloadStopped = useCallback(
|
||||
@ -79,7 +79,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
||||
removeDownloadingModel(state.modelId)
|
||||
}
|
||||
},
|
||||
[setDownloadState, removeInstallingExtension]
|
||||
[removeInstallingExtension, setDownloadState, removeDownloadingModel]
|
||||
)
|
||||
|
||||
const onFileDownloadSuccess = useCallback(
|
||||
@ -92,7 +92,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
||||
}
|
||||
events.emit(ModelEvent.OnModelsUpdate, {})
|
||||
},
|
||||
[setDownloadState]
|
||||
[removeDownloadingModel, setDownloadState]
|
||||
)
|
||||
|
||||
const onFileUnzipSuccess = useCallback(
|
||||
@ -121,7 +121,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
|
||||
events.off(DownloadEvent.onFileDownloadUpdate, onFileDownloadUpdate)
|
||||
events.off(DownloadEvent.onFileDownloadError, onFileDownloadError)
|
||||
events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
|
||||
events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
|
||||
events.off(DownloadEvent.onFileDownloadStopped, onFileDownloadStopped)
|
||||
events.off(DownloadEvent.onFileUnzipSuccess, onFileUnzipSuccess)
|
||||
}
|
||||
}, [
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user