feat: Deprecate model.json ready state in favor of .download ext (#1238)

* feat: Deprecate model.json ready state in favor of .download ext

* refactor: resolve ts ignore

* chore: fix warning

* fix: path polyfill on Windows
This commit is contained in:
Louis 2023-12-28 14:06:13 +07:00 committed by GitHub
parent cbc63da831
commit 7feaf9694d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 245 additions and 180 deletions

View File

@ -67,13 +67,6 @@ export type Model = {
*/ */
description: string description: string
/**
* The model state.
* Default: "to_download"
* Enum: "to_download" "downloading" "ready" "running"
*/
state?: ModelState
/** /**
* The model settings. * The model settings.
*/ */
@ -101,15 +94,6 @@ export type ModelMetadata = {
cover?: string cover?: string
} }
/**
* The Model transition states.
*/
export enum ModelState {
Downloading = 'downloading',
Ready = 'ready',
Running = 'running',
}
/** /**
* The available model settings. * The available model settings.
*/ */

View File

@ -3,7 +3,7 @@ import { DownloadManager } from './../managers/download'
import { resolve, join } from 'path' import { resolve, join } from 'path'
import { WindowManager } from './../managers/window' import { WindowManager } from './../managers/window'
import request from 'request' import request from 'request'
import { createWriteStream } from 'fs' import { createWriteStream, renameSync } from 'fs'
import { DownloadEvent, DownloadRoute } from '@janhq/core' import { DownloadEvent, DownloadRoute } from '@janhq/core'
const progress = require('request-progress') const progress = require('request-progress')
@ -48,6 +48,8 @@ export function handleDownloaderIPCs() {
const userDataPath = join(app.getPath('home'), 'jan') const userDataPath = join(app.getPath('home'), 'jan')
const destination = resolve(userDataPath, fileName) const destination = resolve(userDataPath, fileName)
const rq = request(url) const rq = request(url)
// downloading file to a temp file first
const downloadingTempFile = `${destination}.download`
progress(rq, {}) progress(rq, {})
.on('progress', function (state: any) { .on('progress', function (state: any) {
@ -70,6 +72,9 @@ export function handleDownloaderIPCs() {
}) })
.on('end', function () { .on('end', function () {
if (DownloadManager.instance.networkRequests[fileName]) { if (DownloadManager.instance.networkRequests[fileName]) {
// Finished downloading, rename temp file to actual file
renameSync(downloadingTempFile, destination)
WindowManager?.instance.currentWindow?.webContents.send( WindowManager?.instance.currentWindow?.webContents.send(
DownloadEvent.onFileDownloadSuccess, DownloadEvent.onFileDownloadSuccess,
{ {
@ -87,7 +92,7 @@ export function handleDownloaderIPCs() {
) )
} }
}) })
.pipe(createWriteStream(destination)) .pipe(createWriteStream(downloadingTempFile))
DownloadManager.instance.setRequest(fileName, rq) DownloadManager.instance.setRequest(fileName, rq)
}) })

View File

@ -1,7 +1,6 @@
import { ExtensionType, fs } from '@janhq/core' import { ExtensionType, fs, joinPath } from '@janhq/core'
import { ConversationalExtension } from '@janhq/core' import { ConversationalExtension } from '@janhq/core'
import { Thread, ThreadMessage } from '@janhq/core' import { Thread, ThreadMessage } from '@janhq/core'
import { join } from 'path'
/** /**
* JSONConversationalExtension is a ConversationalExtension implementation that provides * JSONConversationalExtension is a ConversationalExtension implementation that provides
@ -69,14 +68,14 @@ export default class JSONConversationalExtension
*/ */
async saveThread(thread: Thread): Promise<void> { async saveThread(thread: Thread): Promise<void> {
try { try {
const threadDirPath = join( const threadDirPath = await joinPath([
JSONConversationalExtension._homeDir, JSONConversationalExtension._homeDir,
thread.id thread.id,
) ])
const threadJsonPath = join( const threadJsonPath = await joinPath([
threadDirPath, threadDirPath,
JSONConversationalExtension._threadInfoFileName JSONConversationalExtension._threadInfoFileName,
) ])
await fs.mkdir(threadDirPath) await fs.mkdir(threadDirPath)
await fs.writeFile(threadJsonPath, JSON.stringify(thread, null, 2)) await fs.writeFile(threadJsonPath, JSON.stringify(thread, null, 2))
Promise.resolve() Promise.resolve()
@ -89,20 +88,22 @@ export default class JSONConversationalExtension
* Delete a thread with the specified ID. * Delete a thread with the specified ID.
* @param threadId The ID of the thread to delete. * @param threadId The ID of the thread to delete.
*/ */
deleteThread(threadId: string): Promise<void> { async deleteThread(threadId: string): Promise<void> {
return fs.rmdir(join(JSONConversationalExtension._homeDir, `${threadId}`)) return fs.rmdir(
await joinPath([JSONConversationalExtension._homeDir, `${threadId}`])
)
} }
async addNewMessage(message: ThreadMessage): Promise<void> { async addNewMessage(message: ThreadMessage): Promise<void> {
try { try {
const threadDirPath = join( const threadDirPath = await joinPath([
JSONConversationalExtension._homeDir, JSONConversationalExtension._homeDir,
message.thread_id message.thread_id,
) ])
const threadMessagePath = join( const threadMessagePath = await joinPath([
threadDirPath, threadDirPath,
JSONConversationalExtension._threadMessagesFileName JSONConversationalExtension._threadMessagesFileName,
) ])
await fs.mkdir(threadDirPath) await fs.mkdir(threadDirPath)
await fs.appendFile(threadMessagePath, JSON.stringify(message) + '\n') await fs.appendFile(threadMessagePath, JSON.stringify(message) + '\n')
Promise.resolve() Promise.resolve()
@ -116,11 +117,14 @@ export default class JSONConversationalExtension
messages: ThreadMessage[] messages: ThreadMessage[]
): Promise<void> { ): Promise<void> {
try { try {
const threadDirPath = join(JSONConversationalExtension._homeDir, threadId) const threadDirPath = await joinPath([
const threadMessagePath = join( JSONConversationalExtension._homeDir,
threadId,
])
const threadMessagePath = await joinPath([
threadDirPath, threadDirPath,
JSONConversationalExtension._threadMessagesFileName JSONConversationalExtension._threadMessagesFileName,
) ])
await fs.mkdir(threadDirPath) await fs.mkdir(threadDirPath)
await fs.writeFile( await fs.writeFile(
threadMessagePath, threadMessagePath,
@ -140,11 +144,11 @@ export default class JSONConversationalExtension
*/ */
private async readThread(threadDirName: string): Promise<any> { private async readThread(threadDirName: string): Promise<any> {
return fs.readFile( return fs.readFile(
join( await joinPath([
JSONConversationalExtension._homeDir, JSONConversationalExtension._homeDir,
threadDirName, threadDirName,
JSONConversationalExtension._threadInfoFileName JSONConversationalExtension._threadInfoFileName,
) ])
) )
} }
@ -159,10 +163,10 @@ export default class JSONConversationalExtension
const threadDirs: string[] = [] const threadDirs: string[] = []
for (let i = 0; i < fileInsideThread.length; i++) { for (let i = 0; i < fileInsideThread.length; i++) {
const path = join( const path = await joinPath([
JSONConversationalExtension._homeDir, JSONConversationalExtension._homeDir,
fileInsideThread[i] fileInsideThread[i],
) ])
const isDirectory = await fs.isDirectory(path) const isDirectory = await fs.isDirectory(path)
if (!isDirectory) { if (!isDirectory) {
console.debug(`Ignore ${path} because it is not a directory`) console.debug(`Ignore ${path} because it is not a directory`)
@ -184,7 +188,10 @@ export default class JSONConversationalExtension
async getAllMessages(threadId: string): Promise<ThreadMessage[]> { async getAllMessages(threadId: string): Promise<ThreadMessage[]> {
try { try {
const threadDirPath = join(JSONConversationalExtension._homeDir, threadId) const threadDirPath = await joinPath([
JSONConversationalExtension._homeDir,
threadId,
])
const isDir = await fs.isDirectory(threadDirPath) const isDir = await fs.isDirectory(threadDirPath)
if (!isDir) { if (!isDir) {
throw Error(`${threadDirPath} is not directory`) throw Error(`${threadDirPath} is not directory`)
@ -197,10 +204,10 @@ export default class JSONConversationalExtension
throw Error(`${threadDirPath} not contains message file`) throw Error(`${threadDirPath} not contains message file`)
} }
const messageFilePath = join( const messageFilePath = await joinPath([
threadDirPath, threadDirPath,
JSONConversationalExtension._threadMessagesFileName JSONConversationalExtension._threadMessagesFileName,
) ])
const result = await fs.readLineByLine(messageFilePath) const result = await fs.readLineByLine(messageFilePath)

View File

@ -111,7 +111,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
return; return;
} }
const userSpacePath = await getUserSpace(); const userSpacePath = await getUserSpace();
const modelFullPath = join(userSpacePath, "models", model.id, model.id); const modelFullPath = join(userSpacePath, "models", model.id);
const nitroInitResult = await executeOnMain(MODULE, "initModel", { const nitroInitResult = await executeOnMain(MODULE, "initModel", {
modelFullPath: modelFullPath, modelFullPath: modelFullPath,

View File

@ -13,10 +13,11 @@ const NITRO_HTTP_LOAD_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/
const NITRO_HTTP_UNLOAD_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/unloadModel`; const NITRO_HTTP_UNLOAD_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/unloadModel`;
const NITRO_HTTP_VALIDATE_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/modelstatus`; const NITRO_HTTP_VALIDATE_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/modelstatus`;
const NITRO_HTTP_KILL_URL = `${NITRO_HTTP_SERVER_URL}/processmanager/destroy`; const NITRO_HTTP_KILL_URL = `${NITRO_HTTP_SERVER_URL}/processmanager/destroy`;
const SUPPORTED_MODEL_FORMAT = ".gguf";
// The subprocess instance for Nitro // The subprocess instance for Nitro
let subprocess = undefined; let subprocess = undefined;
let currentModelFile = undefined; let currentModelFile: string = undefined;
let currentSettings = undefined; let currentSettings = undefined;
/** /**
@ -37,6 +38,17 @@ function stopModel(): Promise<void> {
*/ */
async function initModel(wrapper: any): Promise<ModelOperationResponse> { async function initModel(wrapper: any): Promise<ModelOperationResponse> {
currentModelFile = wrapper.modelFullPath; currentModelFile = wrapper.modelFullPath;
const files: string[] = fs.readdirSync(currentModelFile);
// Look for GGUF model file
const ggufBinFile = files.find(
(file) =>
file === path.basename(currentModelFile) ||
file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT)
);
currentModelFile = path.join(currentModelFile, ggufBinFile);
if (wrapper.model.engine !== "nitro") { if (wrapper.model.engine !== "nitro") {
return Promise.resolve({ error: "Not a nitro model" }); return Promise.resolve({ error: "Not a nitro model" });
} else { } else {
@ -66,14 +78,14 @@ async function initModel(wrapper: any): Promise<ModelOperationResponse> {
async function loadModel(nitroResourceProbe: any | undefined) { async function loadModel(nitroResourceProbe: any | undefined) {
// Gather system information for CPU physical cores and memory // Gather system information for CPU physical cores and memory
if (!nitroResourceProbe) nitroResourceProbe = await getResourcesInfo(); if (!nitroResourceProbe) nitroResourceProbe = await getResourcesInfo();
return killSubprocess() return (
killSubprocess()
.then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000)) .then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000))
// wait for 500ms to make sure the port is free for windows platform // wait for 500ms to make sure the port is free for windows platform
.then(() => { .then(() => {
if (process.platform === "win32") { if (process.platform === "win32") {
return sleep(500); return sleep(500);
} } else {
else {
return sleep(0); return sleep(0);
} }
}) })
@ -84,7 +96,8 @@ async function loadModel(nitroResourceProbe: any | undefined) {
console.error("error: ", err); console.error("error: ", err);
// TODO: Broadcast error so app could display proper error message // TODO: Broadcast error so app could display proper error message
return { error: err, currentModelFile }; return { error: err, currentModelFile };
}); })
);
} }
// Add function sleep // Add function sleep

View File

@ -5,9 +5,11 @@ import {
abortDownload, abortDownload,
getResourcePath, getResourcePath,
getUserSpace, getUserSpace,
InferenceEngine,
joinPath,
} from '@janhq/core' } from '@janhq/core'
import { ModelExtension, Model, ModelState } from '@janhq/core' import { basename } from 'path'
import { join } from 'path' import { ModelExtension, Model } from '@janhq/core'
/** /**
* A extension for models * A extension for models
@ -15,6 +17,9 @@ import { join } from 'path'
export default class JanModelExtension implements ModelExtension { export default class JanModelExtension implements ModelExtension {
private static readonly _homeDir = 'models' private static readonly _homeDir = 'models'
private static readonly _modelMetadataFileName = 'model.json' private static readonly _modelMetadataFileName = 'model.json'
private static readonly _supportedModelFormat = '.gguf'
private static readonly _incompletedModelFileName = '.download'
private static readonly _offlineInferenceEngine = InferenceEngine.nitro
/** /**
* Implements type from JanExtension. * Implements type from JanExtension.
@ -54,10 +59,10 @@ export default class JanModelExtension implements ModelExtension {
// copy models folder from resources to home directory // copy models folder from resources to home directory
const resourePath = await getResourcePath() const resourePath = await getResourcePath()
const srcPath = join(resourePath, 'models') const srcPath = await joinPath([resourePath, 'models'])
const userSpace = await getUserSpace() const userSpace = await getUserSpace()
const destPath = join(userSpace, JanModelExtension._homeDir) const destPath = await joinPath([userSpace, JanModelExtension._homeDir])
await fs.syncFile(srcPath, destPath) await fs.syncFile(srcPath, destPath)
@ -88,11 +93,18 @@ export default class JanModelExtension implements ModelExtension {
*/ */
async downloadModel(model: Model): Promise<void> { async downloadModel(model: Model): Promise<void> {
// create corresponding directory // create corresponding directory
const directoryPath = join(JanModelExtension._homeDir, model.id) const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
await fs.mkdir(directoryPath) await fs.mkdir(modelDirPath)
// path to model binary // try to retrieve the download file name from the source url
const path = join(directoryPath, model.id) // if it fails, use the model ID as the file name
const extractedFileName = basename(model.source_url)
const fileName = extractedFileName
.toLowerCase()
.endsWith(JanModelExtension._supportedModelFormat)
? extractedFileName
: model.id
const path = await joinPath([modelDirPath, fileName])
downloadFile(model.source_url, path) downloadFile(model.source_url, path)
} }
@ -103,10 +115,12 @@ export default class JanModelExtension implements ModelExtension {
*/ */
async cancelModelDownload(modelId: string): Promise<void> { async cancelModelDownload(modelId: string): Promise<void> {
return abortDownload( return abortDownload(
join(JanModelExtension._homeDir, modelId, modelId) await joinPath([JanModelExtension._homeDir, modelId, modelId])
).then(() => { ).then(async () =>
fs.deleteFile(join(JanModelExtension._homeDir, modelId, modelId)) fs.deleteFile(
}) await joinPath([JanModelExtension._homeDir, modelId, modelId])
)
)
} }
/** /**
@ -116,27 +130,16 @@ export default class JanModelExtension implements ModelExtension {
*/ */
async deleteModel(modelId: string): Promise<void> { async deleteModel(modelId: string): Promise<void> {
try { try {
const dirPath = join(JanModelExtension._homeDir, modelId) const dirPath = await joinPath([JanModelExtension._homeDir, modelId])
// remove all files under dirPath except model.json // remove all files under dirPath except model.json
const files = await fs.listFiles(dirPath) const files = await fs.listFiles(dirPath)
const deletePromises = files.map((fileName: string) => { const deletePromises = files.map(async (fileName: string) => {
if (fileName !== JanModelExtension._modelMetadataFileName) { if (fileName !== JanModelExtension._modelMetadataFileName) {
return fs.deleteFile(join(dirPath, fileName)) return fs.deleteFile(await joinPath([dirPath, fileName]))
} }
}) })
await Promise.allSettled(deletePromises) await Promise.allSettled(deletePromises)
// update the state as default
const jsonFilePath = join(
dirPath,
JanModelExtension._modelMetadataFileName
)
const json = await fs.readFile(jsonFilePath)
const model = JSON.parse(json) as Model
delete model.state
await fs.writeFile(jsonFilePath, JSON.stringify(model, null, 2))
} catch (err) { } catch (err) {
console.error(err) console.error(err)
} }
@ -148,24 +151,14 @@ export default class JanModelExtension implements ModelExtension {
* @returns A Promise that resolves when the model is saved. * @returns A Promise that resolves when the model is saved.
*/ */
async saveModel(model: Model): Promise<void> { async saveModel(model: Model): Promise<void> {
const jsonFilePath = join( const jsonFilePath = await joinPath([
JanModelExtension._homeDir, JanModelExtension._homeDir,
model.id, model.id,
JanModelExtension._modelMetadataFileName JanModelExtension._modelMetadataFileName,
) ])
try { try {
await fs.writeFile( await fs.writeFile(jsonFilePath, JSON.stringify(model, null, 2))
jsonFilePath,
JSON.stringify(
{
...model,
state: ModelState.Ready,
},
null,
2
)
)
} catch (err) { } catch (err) {
console.error(err) console.error(err)
} }
@ -176,11 +169,34 @@ export default class JanModelExtension implements ModelExtension {
* @returns A Promise that resolves with an array of all models. * @returns A Promise that resolves with an array of all models.
*/ */
async getDownloadedModels(): Promise<Model[]> { async getDownloadedModels(): Promise<Model[]> {
const models = await this.getModelsMetadata() return await this.getModelsMetadata(
return models.filter((model) => model.state === ModelState.Ready) async (modelDir: string, model: Model) => {
if (model.engine !== JanModelExtension._offlineInferenceEngine) {
return true
}
return await fs
.listFiles(await joinPath([JanModelExtension._homeDir, modelDir]))
.then((files: string[]) => {
// or model binary exists in the directory
// model binary name can match model ID or be a .gguf file and not be an incompleted model file
return (
files.includes(modelDir) ||
files.some(
(file) =>
file
.toLowerCase()
.includes(JanModelExtension._supportedModelFormat) &&
!file.endsWith(JanModelExtension._incompletedModelFileName)
)
)
})
}
)
} }
private async getModelsMetadata(): Promise<Model[]> { private async getModelsMetadata(
selector?: (path: string, model: Model) => Promise<boolean>
): Promise<Model[]> {
try { try {
const filesUnderJanRoot = await fs.listFiles('') const filesUnderJanRoot = await fs.listFiles('')
if (!filesUnderJanRoot.includes(JanModelExtension._homeDir)) { if (!filesUnderJanRoot.includes(JanModelExtension._homeDir)) {
@ -193,26 +209,35 @@ export default class JanModelExtension implements ModelExtension {
const allDirectories: string[] = [] const allDirectories: string[] = []
for (const file of files) { for (const file of files) {
const isDirectory = await fs.isDirectory( const isDirectory = await fs.isDirectory(
join(JanModelExtension._homeDir, file) await joinPath([JanModelExtension._homeDir, file])
) )
if (isDirectory) { if (isDirectory) {
allDirectories.push(file) allDirectories.push(file)
} }
} }
const readJsonPromises = allDirectories.map((dirName) => { const readJsonPromises = allDirectories.map(async (dirName) => {
const jsonPath = join( // filter out directories that don't match the selector
// read model.json
const jsonPath = await joinPath([
JanModelExtension._homeDir, JanModelExtension._homeDir,
dirName, dirName,
JanModelExtension._modelMetadataFileName JanModelExtension._modelMetadataFileName,
) ])
return this.readModelMetadata(jsonPath) let model = await this.readModelMetadata(jsonPath)
model = typeof model === 'object' ? model : JSON.parse(model)
if (selector && !(await selector?.(dirName, model))) {
return
}
return model
}) })
const results = await Promise.allSettled(readJsonPromises) const results = await Promise.allSettled(readJsonPromises)
const modelData = results.map((result) => { const modelData = results.map((result) => {
if (result.status === 'fulfilled') { if (result.status === 'fulfilled') {
try { try {
return JSON.parse(result.value) as Model return result.value as Model
} catch { } catch {
console.debug(`Unable to parse model metadata: ${result.value}`) console.debug(`Unable to parse model metadata: ${result.value}`)
return undefined return undefined
@ -230,7 +255,7 @@ export default class JanModelExtension implements ModelExtension {
} }
private readModelMetadata(path: string) { private readModelMetadata(path: string) {
return fs.readFile(join(path)) return fs.readFile(path)
} }
/** /**

View File

@ -1,7 +1,5 @@
import { Fragment } from 'react' import { Fragment } from 'react'
import { ExtensionType } from '@janhq/core'
import { ModelExtension } from '@janhq/core'
import { import {
Progress, Progress,
Modal, Modal,
@ -12,14 +10,19 @@ import {
ModalTrigger, ModalTrigger,
} from '@janhq/uikit' } from '@janhq/uikit'
import { useAtomValue } from 'jotai'
import useDownloadModel from '@/hooks/useDownloadModel'
import { useDownloadState } from '@/hooks/useDownloadState' import { useDownloadState } from '@/hooks/useDownloadState'
import { formatDownloadPercentage } from '@/utils/converter' import { formatDownloadPercentage } from '@/utils/converter'
import { extensionManager } from '@/extension' import { downloadingModelsAtom } from '@/helpers/atoms/Model.atom'
export default function DownloadingState() { export default function DownloadingState() {
const { downloadStates } = useDownloadState() const { downloadStates } = useDownloadState()
const downloadingModels = useAtomValue(downloadingModelsAtom)
const { abortModelDownload } = useDownloadModel()
const totalCurrentProgress = downloadStates const totalCurrentProgress = downloadStates
.map((a) => a.size.transferred + a.size.transferred) .map((a) => a.size.transferred + a.size.transferred)
@ -73,9 +76,10 @@ export default function DownloadingState() {
size="sm" size="sm"
onClick={() => { onClick={() => {
if (item?.modelId) { if (item?.modelId) {
extensionManager const model = downloadingModels.find(
.get<ModelExtension>(ExtensionType.Model) (model) => model.id === item.modelId
?.cancelModelDownload(item.modelId) )
if (model) abortModelDownload(model)
} }
}} }}
> >

View File

@ -1,6 +1,5 @@
import { useMemo } from 'react' import { useMemo } from 'react'
import { ModelExtension, ExtensionType } from '@janhq/core'
import { Model } from '@janhq/core' import { Model } from '@janhq/core'
import { import {
@ -17,11 +16,12 @@ import {
import { atom, useAtomValue } from 'jotai' import { atom, useAtomValue } from 'jotai'
import useDownloadModel from '@/hooks/useDownloadModel'
import { useDownloadState } from '@/hooks/useDownloadState' import { useDownloadState } from '@/hooks/useDownloadState'
import { formatDownloadPercentage } from '@/utils/converter' import { formatDownloadPercentage } from '@/utils/converter'
import { extensionManager } from '@/extension' import { downloadingModelsAtom } from '@/helpers/atoms/Model.atom'
type Props = { type Props = {
model: Model model: Model
@ -30,6 +30,7 @@ type Props = {
export default function ModalCancelDownload({ model, isFromList }: Props) { export default function ModalCancelDownload({ model, isFromList }: Props) {
const { modelDownloadStateAtom } = useDownloadState() const { modelDownloadStateAtom } = useDownloadState()
const downloadingModels = useAtomValue(downloadingModelsAtom)
const downloadAtom = useMemo( const downloadAtom = useMemo(
() => atom((get) => get(modelDownloadStateAtom)[model.id]), () => atom((get) => get(modelDownloadStateAtom)[model.id]),
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
@ -37,6 +38,7 @@ export default function ModalCancelDownload({ model, isFromList }: Props) {
) )
const downloadState = useAtomValue(downloadAtom) const downloadState = useAtomValue(downloadAtom)
const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}` const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}`
const { abortModelDownload } = useDownloadModel()
return ( return (
<Modal> <Modal>
@ -80,9 +82,10 @@ export default function ModalCancelDownload({ model, isFromList }: Props) {
themes="danger" themes="danger"
onClick={() => { onClick={() => {
if (downloadState?.modelId) { if (downloadState?.modelId) {
extensionManager const model = downloadingModels.find(
.get<ModelExtension>(ExtensionType.Model) (model) => model.id === downloadState.modelId
?.cancelModelDownload(downloadState.modelId) )
if (model) abortModelDownload(model)
} }
}} }}
> >

View File

@ -1,34 +1,35 @@
/* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-explicit-any */
import { PropsWithChildren, useEffect, useRef } from 'react' import { basename } from 'path'
import { ExtensionType } from '@janhq/core' import { PropsWithChildren, useEffect, useRef } from 'react'
import { ModelExtension } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai' import { useAtomValue, useSetAtom } from 'jotai'
import { useDownloadState } from '@/hooks/useDownloadState' import { useDownloadState } from '@/hooks/useDownloadState'
import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels' import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels'
import { modelBinFileName } from '@/utils/model'
import EventHandler from './EventHandler' import EventHandler from './EventHandler'
import { appDownloadProgress } from './Jotai' import { appDownloadProgress } from './Jotai'
import { extensionManager } from '@/extension/ExtensionManager'
import { downloadingModelsAtom } from '@/helpers/atoms/Model.atom' import { downloadingModelsAtom } from '@/helpers/atoms/Model.atom'
export default function EventListenerWrapper({ children }: PropsWithChildren) { export default function EventListenerWrapper({ children }: PropsWithChildren) {
const setProgress = useSetAtom(appDownloadProgress) const setProgress = useSetAtom(appDownloadProgress)
const models = useAtomValue(downloadingModelsAtom) const models = useAtomValue(downloadingModelsAtom)
const modelsRef = useRef(models) const modelsRef = useRef(models)
useEffect(() => {
modelsRef.current = models
}, [models])
const { setDownloadedModels, downloadedModels } = useGetDownloadedModels() const { setDownloadedModels, downloadedModels } = useGetDownloadedModels()
const { setDownloadState, setDownloadStateSuccess, setDownloadStateFailed } = const { setDownloadState, setDownloadStateSuccess, setDownloadStateFailed } =
useDownloadState() useDownloadState()
const downloadedModelRef = useRef(downloadedModels) const downloadedModelRef = useRef(downloadedModels)
useEffect(() => {
modelsRef.current = models
}, [models])
useEffect(() => { useEffect(() => {
downloadedModelRef.current = downloadedModels downloadedModelRef.current = downloadedModels
}, [downloadedModels]) }, [downloadedModels])
@ -38,40 +39,36 @@ export default function EventListenerWrapper({ children }: PropsWithChildren) {
window.electronAPI.onFileDownloadUpdate( window.electronAPI.onFileDownloadUpdate(
(_event: string, state: any | undefined) => { (_event: string, state: any | undefined) => {
if (!state) return if (!state) return
const model = modelsRef.current.find(
(model) => modelBinFileName(model) === basename(state.fileName)
)
if (model)
setDownloadState({ setDownloadState({
...state, ...state,
modelId: state.fileName.split('/').pop() ?? '', modelId: model.id,
}) })
} }
) )
window.electronAPI.onFileDownloadError( window.electronAPI.onFileDownloadError((_event: string, state: any) => {
(_event: string, callback: any) => { console.error('Download error', state)
console.error('Download error', callback) const model = modelsRef.current.find(
const modelId = callback.fileName.split('/').pop() ?? '' (model) => modelBinFileName(model) === basename(state.fileName)
setDownloadStateFailed(modelId)
}
) )
if (model) setDownloadStateFailed(model.id)
})
window.electronAPI.onFileDownloadSuccess( window.electronAPI.onFileDownloadSuccess((_event: string, state: any) => {
(_event: string, callback: any) => { if (state && state.fileName) {
if (callback && callback.fileName) { const model = modelsRef.current.find(
const modelId = callback.fileName.split('/').pop() ?? '' (model) => modelBinFileName(model) === basename(state.fileName)
)
const model = modelsRef.current.find((e) => e.id === modelId) if (model) {
setDownloadStateSuccess(model.id)
setDownloadStateSuccess(modelId)
if (model)
extensionManager
.get<ModelExtension>(ExtensionType.Model)
?.saveModel(model)
.then(() => {
setDownloadedModels([...downloadedModelRef.current, model]) setDownloadedModels([...downloadedModelRef.current, model])
}
}
}) })
}
}
)
window.electronAPI.onAppUpdateDownloadUpdate( window.electronAPI.onAppUpdateDownloadUpdate(
(_event: string, progress: any) => { (_event: string, progress: any) => {

View File

@ -1,7 +1,15 @@
import { Model, ExtensionType, ModelExtension } from '@janhq/core' import {
Model,
ExtensionType,
ModelExtension,
abortDownload,
joinPath,
} from '@janhq/core'
import { useSetAtom } from 'jotai' import { useSetAtom } from 'jotai'
import { modelBinFileName } from '@/utils/model'
import { useDownloadState } from './useDownloadState' import { useDownloadState } from './useDownloadState'
import { extensionManager } from '@/extension/ExtensionManager' import { extensionManager } from '@/extension/ExtensionManager'
@ -33,8 +41,14 @@ export default function useDownloadModel() {
.get<ModelExtension>(ExtensionType.Model) .get<ModelExtension>(ExtensionType.Model)
?.downloadModel(model) ?.downloadModel(model)
} }
const abortModelDownload = async (model: Model) => {
await abortDownload(
await joinPath(['models', model.id, modelBinFileName(model)])
)
}
return { return {
downloadModel, downloadModel,
abortModelDownload,
} }
} }

View File

@ -1,10 +1,10 @@
import { join } from 'path' import { fs, joinPath } from '@janhq/core'
import { fs } from '@janhq/core'
export const useEngineSettings = () => { export const useEngineSettings = () => {
const readOpenAISettings = async () => { const readOpenAISettings = async () => {
const settings = await fs.readFile(join('engines', 'openai.json')) const settings = await fs.readFile(
await joinPath(['engines', 'openai.json'])
)
if (settings) { if (settings) {
return JSON.parse(settings) return JSON.parse(settings)
} }
@ -17,7 +17,10 @@ export const useEngineSettings = () => {
}) => { }) => {
const settings = await readOpenAISettings() const settings = await readOpenAISettings()
settings.api_key = apiKey settings.api_key = apiKey
await fs.writeFile(join('engines', 'openai.json'), JSON.stringify(settings)) await fs.writeFile(
await joinPath(['engines', 'openai.json']),
JSON.stringify(settings)
)
} }
return { readOpenAISettings, saveOpenAISettings } return { readOpenAISettings, saveOpenAISettings }
} }

View File

@ -116,7 +116,7 @@ const ChatBody: React.FC = () => {
) : ( ) : (
<ScrollToBottom className="flex h-full w-full flex-col"> <ScrollToBottom className="flex h-full w-full flex-col">
{messages.map((message, index) => ( {messages.map((message, index) => (
<> <div key={message.id}>
<ChatItem {...message} key={message.id} /> <ChatItem {...message} key={message.id} />
{message.status === MessageStatus.Error && {message.status === MessageStatus.Error &&
@ -126,8 +126,8 @@ const ChatBody: React.FC = () => {
className="mt-10 flex flex-col items-center" className="mt-10 flex flex-col items-center"
> >
<span className="mb-3 text-center text-sm font-medium text-gray-500"> <span className="mb-3 text-center text-sm font-medium text-gray-500">
Oops! The generation was interrupted. Let&apos;s Oops! The generation was interrupted. Let&apos;s give it
give it another go! another go!
</span> </span>
<Button <Button
className="w-min" className="w-min"
@ -140,7 +140,7 @@ const ChatBody: React.FC = () => {
</Button> </Button>
</div> </div>
)} )}
</> </div>
))} ))}
</ScrollToBottom> </ScrollToBottom>
)} )}

12
web/utils/model.ts Normal file
View File

@ -0,0 +1,12 @@
import { basename } from 'path'
import { Model } from '@janhq/core'
export const modelBinFileName = (model: Model) => {
const modelFormatExt = '.gguf'
const extractedFileName = basename(model.source_url) ?? model.id
const fileName = extractedFileName.toLowerCase().endsWith(modelFormatExt)
? extractedFileName
: model.id
return fileName
}

View File

@ -22,8 +22,7 @@ export const toRuntimeParams = (
for (const [key, value] of Object.entries(modelParams)) { for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultModelParams) { if (key in defaultModelParams) {
// @ts-ignore runtimeParams[key as keyof ModelRuntimeParams] = value
runtimeParams[key] = value
} }
} }
@ -46,8 +45,7 @@ export const toSettingParams = (
for (const [key, value] of Object.entries(modelParams)) { for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultSettingParams) { if (key in defaultSettingParams) {
// @ts-ignore settingParams[key as keyof ModelSettingParams] = value
settingParams[key] = value
} }
} }