feat: Deprecate model.json ready state in favor of .download ext (#1238)
* feat: Deprecate model.json ready state in favor of .download ext * refactor: resolve ts ignore * chore: fix warning * fix: path polyfill on Windows
This commit is contained in:
parent
cbc63da831
commit
7feaf9694d
@ -67,13 +67,6 @@ export type Model = {
|
||||
*/
|
||||
description: string
|
||||
|
||||
/**
|
||||
* The model state.
|
||||
* Default: "to_download"
|
||||
* Enum: "to_download" "downloading" "ready" "running"
|
||||
*/
|
||||
state?: ModelState
|
||||
|
||||
/**
|
||||
* The model settings.
|
||||
*/
|
||||
@ -101,15 +94,6 @@ export type ModelMetadata = {
|
||||
cover?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* The Model transition states.
|
||||
*/
|
||||
export enum ModelState {
|
||||
Downloading = 'downloading',
|
||||
Ready = 'ready',
|
||||
Running = 'running',
|
||||
}
|
||||
|
||||
/**
|
||||
* The available model settings.
|
||||
*/
|
||||
|
||||
@ -3,7 +3,7 @@ import { DownloadManager } from './../managers/download'
|
||||
import { resolve, join } from 'path'
|
||||
import { WindowManager } from './../managers/window'
|
||||
import request from 'request'
|
||||
import { createWriteStream } from 'fs'
|
||||
import { createWriteStream, renameSync } from 'fs'
|
||||
import { DownloadEvent, DownloadRoute } from '@janhq/core'
|
||||
const progress = require('request-progress')
|
||||
|
||||
@ -48,6 +48,8 @@ export function handleDownloaderIPCs() {
|
||||
const userDataPath = join(app.getPath('home'), 'jan')
|
||||
const destination = resolve(userDataPath, fileName)
|
||||
const rq = request(url)
|
||||
// downloading file to a temp file first
|
||||
const downloadingTempFile = `${destination}.download`
|
||||
|
||||
progress(rq, {})
|
||||
.on('progress', function (state: any) {
|
||||
@ -70,6 +72,9 @@ export function handleDownloaderIPCs() {
|
||||
})
|
||||
.on('end', function () {
|
||||
if (DownloadManager.instance.networkRequests[fileName]) {
|
||||
// Finished downloading, rename temp file to actual file
|
||||
renameSync(downloadingTempFile, destination)
|
||||
|
||||
WindowManager?.instance.currentWindow?.webContents.send(
|
||||
DownloadEvent.onFileDownloadSuccess,
|
||||
{
|
||||
@ -87,7 +92,7 @@ export function handleDownloaderIPCs() {
|
||||
)
|
||||
}
|
||||
})
|
||||
.pipe(createWriteStream(destination))
|
||||
.pipe(createWriteStream(downloadingTempFile))
|
||||
|
||||
DownloadManager.instance.setRequest(fileName, rq)
|
||||
})
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import { ExtensionType, fs } from '@janhq/core'
|
||||
import { ExtensionType, fs, joinPath } from '@janhq/core'
|
||||
import { ConversationalExtension } from '@janhq/core'
|
||||
import { Thread, ThreadMessage } from '@janhq/core'
|
||||
import { join } from 'path'
|
||||
|
||||
/**
|
||||
* JSONConversationalExtension is a ConversationalExtension implementation that provides
|
||||
@ -69,14 +68,14 @@ export default class JSONConversationalExtension
|
||||
*/
|
||||
async saveThread(thread: Thread): Promise<void> {
|
||||
try {
|
||||
const threadDirPath = join(
|
||||
const threadDirPath = await joinPath([
|
||||
JSONConversationalExtension._homeDir,
|
||||
thread.id
|
||||
)
|
||||
const threadJsonPath = join(
|
||||
thread.id,
|
||||
])
|
||||
const threadJsonPath = await joinPath([
|
||||
threadDirPath,
|
||||
JSONConversationalExtension._threadInfoFileName
|
||||
)
|
||||
JSONConversationalExtension._threadInfoFileName,
|
||||
])
|
||||
await fs.mkdir(threadDirPath)
|
||||
await fs.writeFile(threadJsonPath, JSON.stringify(thread, null, 2))
|
||||
Promise.resolve()
|
||||
@ -89,20 +88,22 @@ export default class JSONConversationalExtension
|
||||
* Delete a thread with the specified ID.
|
||||
* @param threadId The ID of the thread to delete.
|
||||
*/
|
||||
deleteThread(threadId: string): Promise<void> {
|
||||
return fs.rmdir(join(JSONConversationalExtension._homeDir, `${threadId}`))
|
||||
async deleteThread(threadId: string): Promise<void> {
|
||||
return fs.rmdir(
|
||||
await joinPath([JSONConversationalExtension._homeDir, `${threadId}`])
|
||||
)
|
||||
}
|
||||
|
||||
async addNewMessage(message: ThreadMessage): Promise<void> {
|
||||
try {
|
||||
const threadDirPath = join(
|
||||
const threadDirPath = await joinPath([
|
||||
JSONConversationalExtension._homeDir,
|
||||
message.thread_id
|
||||
)
|
||||
const threadMessagePath = join(
|
||||
message.thread_id,
|
||||
])
|
||||
const threadMessagePath = await joinPath([
|
||||
threadDirPath,
|
||||
JSONConversationalExtension._threadMessagesFileName
|
||||
)
|
||||
JSONConversationalExtension._threadMessagesFileName,
|
||||
])
|
||||
await fs.mkdir(threadDirPath)
|
||||
await fs.appendFile(threadMessagePath, JSON.stringify(message) + '\n')
|
||||
Promise.resolve()
|
||||
@ -116,11 +117,14 @@ export default class JSONConversationalExtension
|
||||
messages: ThreadMessage[]
|
||||
): Promise<void> {
|
||||
try {
|
||||
const threadDirPath = join(JSONConversationalExtension._homeDir, threadId)
|
||||
const threadMessagePath = join(
|
||||
const threadDirPath = await joinPath([
|
||||
JSONConversationalExtension._homeDir,
|
||||
threadId,
|
||||
])
|
||||
const threadMessagePath = await joinPath([
|
||||
threadDirPath,
|
||||
JSONConversationalExtension._threadMessagesFileName
|
||||
)
|
||||
JSONConversationalExtension._threadMessagesFileName,
|
||||
])
|
||||
await fs.mkdir(threadDirPath)
|
||||
await fs.writeFile(
|
||||
threadMessagePath,
|
||||
@ -140,11 +144,11 @@ export default class JSONConversationalExtension
|
||||
*/
|
||||
private async readThread(threadDirName: string): Promise<any> {
|
||||
return fs.readFile(
|
||||
join(
|
||||
await joinPath([
|
||||
JSONConversationalExtension._homeDir,
|
||||
threadDirName,
|
||||
JSONConversationalExtension._threadInfoFileName
|
||||
)
|
||||
JSONConversationalExtension._threadInfoFileName,
|
||||
])
|
||||
)
|
||||
}
|
||||
|
||||
@ -159,10 +163,10 @@ export default class JSONConversationalExtension
|
||||
|
||||
const threadDirs: string[] = []
|
||||
for (let i = 0; i < fileInsideThread.length; i++) {
|
||||
const path = join(
|
||||
const path = await joinPath([
|
||||
JSONConversationalExtension._homeDir,
|
||||
fileInsideThread[i]
|
||||
)
|
||||
fileInsideThread[i],
|
||||
])
|
||||
const isDirectory = await fs.isDirectory(path)
|
||||
if (!isDirectory) {
|
||||
console.debug(`Ignore ${path} because it is not a directory`)
|
||||
@ -184,7 +188,10 @@ export default class JSONConversationalExtension
|
||||
|
||||
async getAllMessages(threadId: string): Promise<ThreadMessage[]> {
|
||||
try {
|
||||
const threadDirPath = join(JSONConversationalExtension._homeDir, threadId)
|
||||
const threadDirPath = await joinPath([
|
||||
JSONConversationalExtension._homeDir,
|
||||
threadId,
|
||||
])
|
||||
const isDir = await fs.isDirectory(threadDirPath)
|
||||
if (!isDir) {
|
||||
throw Error(`${threadDirPath} is not directory`)
|
||||
@ -197,10 +204,10 @@ export default class JSONConversationalExtension
|
||||
throw Error(`${threadDirPath} not contains message file`)
|
||||
}
|
||||
|
||||
const messageFilePath = join(
|
||||
const messageFilePath = await joinPath([
|
||||
threadDirPath,
|
||||
JSONConversationalExtension._threadMessagesFileName
|
||||
)
|
||||
JSONConversationalExtension._threadMessagesFileName,
|
||||
])
|
||||
|
||||
const result = await fs.readLineByLine(messageFilePath)
|
||||
|
||||
|
||||
@ -111,7 +111,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
|
||||
return;
|
||||
}
|
||||
const userSpacePath = await getUserSpace();
|
||||
const modelFullPath = join(userSpacePath, "models", model.id, model.id);
|
||||
const modelFullPath = join(userSpacePath, "models", model.id);
|
||||
|
||||
const nitroInitResult = await executeOnMain(MODULE, "initModel", {
|
||||
modelFullPath: modelFullPath,
|
||||
|
||||
@ -13,10 +13,11 @@ const NITRO_HTTP_LOAD_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/
|
||||
const NITRO_HTTP_UNLOAD_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/unloadModel`;
|
||||
const NITRO_HTTP_VALIDATE_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/modelstatus`;
|
||||
const NITRO_HTTP_KILL_URL = `${NITRO_HTTP_SERVER_URL}/processmanager/destroy`;
|
||||
const SUPPORTED_MODEL_FORMAT = ".gguf";
|
||||
|
||||
// The subprocess instance for Nitro
|
||||
let subprocess = undefined;
|
||||
let currentModelFile = undefined;
|
||||
let currentModelFile: string = undefined;
|
||||
let currentSettings = undefined;
|
||||
|
||||
/**
|
||||
@ -37,6 +38,17 @@ function stopModel(): Promise<void> {
|
||||
*/
|
||||
async function initModel(wrapper: any): Promise<ModelOperationResponse> {
|
||||
currentModelFile = wrapper.modelFullPath;
|
||||
const files: string[] = fs.readdirSync(currentModelFile);
|
||||
|
||||
// Look for GGUF model file
|
||||
const ggufBinFile = files.find(
|
||||
(file) =>
|
||||
file === path.basename(currentModelFile) ||
|
||||
file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT)
|
||||
);
|
||||
|
||||
currentModelFile = path.join(currentModelFile, ggufBinFile);
|
||||
|
||||
if (wrapper.model.engine !== "nitro") {
|
||||
return Promise.resolve({ error: "Not a nitro model" });
|
||||
} else {
|
||||
@ -66,25 +78,26 @@ async function initModel(wrapper: any): Promise<ModelOperationResponse> {
|
||||
async function loadModel(nitroResourceProbe: any | undefined) {
|
||||
// Gather system information for CPU physical cores and memory
|
||||
if (!nitroResourceProbe) nitroResourceProbe = await getResourcesInfo();
|
||||
return killSubprocess()
|
||||
.then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000))
|
||||
// wait for 500ms to make sure the port is free for windows platform
|
||||
.then(() => {
|
||||
if (process.platform === "win32") {
|
||||
return sleep(500);
|
||||
}
|
||||
else {
|
||||
return sleep(0);
|
||||
}
|
||||
})
|
||||
.then(() => spawnNitroProcess(nitroResourceProbe))
|
||||
.then(() => loadLLMModel(currentSettings))
|
||||
.then(validateModelStatus)
|
||||
.catch((err) => {
|
||||
console.error("error: ", err);
|
||||
// TODO: Broadcast error so app could display proper error message
|
||||
return { error: err, currentModelFile };
|
||||
});
|
||||
return (
|
||||
killSubprocess()
|
||||
.then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000))
|
||||
// wait for 500ms to make sure the port is free for windows platform
|
||||
.then(() => {
|
||||
if (process.platform === "win32") {
|
||||
return sleep(500);
|
||||
} else {
|
||||
return sleep(0);
|
||||
}
|
||||
})
|
||||
.then(() => spawnNitroProcess(nitroResourceProbe))
|
||||
.then(() => loadLLMModel(currentSettings))
|
||||
.then(validateModelStatus)
|
||||
.catch((err) => {
|
||||
console.error("error: ", err);
|
||||
// TODO: Broadcast error so app could display proper error message
|
||||
return { error: err, currentModelFile };
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Add function sleep
|
||||
|
||||
@ -5,9 +5,11 @@ import {
|
||||
abortDownload,
|
||||
getResourcePath,
|
||||
getUserSpace,
|
||||
InferenceEngine,
|
||||
joinPath,
|
||||
} from '@janhq/core'
|
||||
import { ModelExtension, Model, ModelState } from '@janhq/core'
|
||||
import { join } from 'path'
|
||||
import { basename } from 'path'
|
||||
import { ModelExtension, Model } from '@janhq/core'
|
||||
|
||||
/**
|
||||
* A extension for models
|
||||
@ -15,6 +17,9 @@ import { join } from 'path'
|
||||
export default class JanModelExtension implements ModelExtension {
|
||||
private static readonly _homeDir = 'models'
|
||||
private static readonly _modelMetadataFileName = 'model.json'
|
||||
private static readonly _supportedModelFormat = '.gguf'
|
||||
private static readonly _incompletedModelFileName = '.download'
|
||||
private static readonly _offlineInferenceEngine = InferenceEngine.nitro
|
||||
|
||||
/**
|
||||
* Implements type from JanExtension.
|
||||
@ -54,10 +59,10 @@ export default class JanModelExtension implements ModelExtension {
|
||||
|
||||
// copy models folder from resources to home directory
|
||||
const resourePath = await getResourcePath()
|
||||
const srcPath = join(resourePath, 'models')
|
||||
const srcPath = await joinPath([resourePath, 'models'])
|
||||
|
||||
const userSpace = await getUserSpace()
|
||||
const destPath = join(userSpace, JanModelExtension._homeDir)
|
||||
const destPath = await joinPath([userSpace, JanModelExtension._homeDir])
|
||||
|
||||
await fs.syncFile(srcPath, destPath)
|
||||
|
||||
@ -88,11 +93,18 @@ export default class JanModelExtension implements ModelExtension {
|
||||
*/
|
||||
async downloadModel(model: Model): Promise<void> {
|
||||
// create corresponding directory
|
||||
const directoryPath = join(JanModelExtension._homeDir, model.id)
|
||||
await fs.mkdir(directoryPath)
|
||||
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
|
||||
await fs.mkdir(modelDirPath)
|
||||
|
||||
// path to model binary
|
||||
const path = join(directoryPath, model.id)
|
||||
// try to retrieve the download file name from the source url
|
||||
// if it fails, use the model ID as the file name
|
||||
const extractedFileName = basename(model.source_url)
|
||||
const fileName = extractedFileName
|
||||
.toLowerCase()
|
||||
.endsWith(JanModelExtension._supportedModelFormat)
|
||||
? extractedFileName
|
||||
: model.id
|
||||
const path = await joinPath([modelDirPath, fileName])
|
||||
downloadFile(model.source_url, path)
|
||||
}
|
||||
|
||||
@ -103,10 +115,12 @@ export default class JanModelExtension implements ModelExtension {
|
||||
*/
|
||||
async cancelModelDownload(modelId: string): Promise<void> {
|
||||
return abortDownload(
|
||||
join(JanModelExtension._homeDir, modelId, modelId)
|
||||
).then(() => {
|
||||
fs.deleteFile(join(JanModelExtension._homeDir, modelId, modelId))
|
||||
})
|
||||
await joinPath([JanModelExtension._homeDir, modelId, modelId])
|
||||
).then(async () =>
|
||||
fs.deleteFile(
|
||||
await joinPath([JanModelExtension._homeDir, modelId, modelId])
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -116,27 +130,16 @@ export default class JanModelExtension implements ModelExtension {
|
||||
*/
|
||||
async deleteModel(modelId: string): Promise<void> {
|
||||
try {
|
||||
const dirPath = join(JanModelExtension._homeDir, modelId)
|
||||
const dirPath = await joinPath([JanModelExtension._homeDir, modelId])
|
||||
|
||||
// remove all files under dirPath except model.json
|
||||
const files = await fs.listFiles(dirPath)
|
||||
const deletePromises = files.map((fileName: string) => {
|
||||
const deletePromises = files.map(async (fileName: string) => {
|
||||
if (fileName !== JanModelExtension._modelMetadataFileName) {
|
||||
return fs.deleteFile(join(dirPath, fileName))
|
||||
return fs.deleteFile(await joinPath([dirPath, fileName]))
|
||||
}
|
||||
})
|
||||
await Promise.allSettled(deletePromises)
|
||||
|
||||
// update the state as default
|
||||
const jsonFilePath = join(
|
||||
dirPath,
|
||||
JanModelExtension._modelMetadataFileName
|
||||
)
|
||||
const json = await fs.readFile(jsonFilePath)
|
||||
const model = JSON.parse(json) as Model
|
||||
delete model.state
|
||||
|
||||
await fs.writeFile(jsonFilePath, JSON.stringify(model, null, 2))
|
||||
} catch (err) {
|
||||
console.error(err)
|
||||
}
|
||||
@ -148,24 +151,14 @@ export default class JanModelExtension implements ModelExtension {
|
||||
* @returns A Promise that resolves when the model is saved.
|
||||
*/
|
||||
async saveModel(model: Model): Promise<void> {
|
||||
const jsonFilePath = join(
|
||||
const jsonFilePath = await joinPath([
|
||||
JanModelExtension._homeDir,
|
||||
model.id,
|
||||
JanModelExtension._modelMetadataFileName
|
||||
)
|
||||
JanModelExtension._modelMetadataFileName,
|
||||
])
|
||||
|
||||
try {
|
||||
await fs.writeFile(
|
||||
jsonFilePath,
|
||||
JSON.stringify(
|
||||
{
|
||||
...model,
|
||||
state: ModelState.Ready,
|
||||
},
|
||||
null,
|
||||
2
|
||||
)
|
||||
)
|
||||
await fs.writeFile(jsonFilePath, JSON.stringify(model, null, 2))
|
||||
} catch (err) {
|
||||
console.error(err)
|
||||
}
|
||||
@ -176,11 +169,34 @@ export default class JanModelExtension implements ModelExtension {
|
||||
* @returns A Promise that resolves with an array of all models.
|
||||
*/
|
||||
async getDownloadedModels(): Promise<Model[]> {
|
||||
const models = await this.getModelsMetadata()
|
||||
return models.filter((model) => model.state === ModelState.Ready)
|
||||
return await this.getModelsMetadata(
|
||||
async (modelDir: string, model: Model) => {
|
||||
if (model.engine !== JanModelExtension._offlineInferenceEngine) {
|
||||
return true
|
||||
}
|
||||
return await fs
|
||||
.listFiles(await joinPath([JanModelExtension._homeDir, modelDir]))
|
||||
.then((files: string[]) => {
|
||||
// or model binary exists in the directory
|
||||
// model binary name can match model ID or be a .gguf file and not be an incompleted model file
|
||||
return (
|
||||
files.includes(modelDir) ||
|
||||
files.some(
|
||||
(file) =>
|
||||
file
|
||||
.toLowerCase()
|
||||
.includes(JanModelExtension._supportedModelFormat) &&
|
||||
!file.endsWith(JanModelExtension._incompletedModelFileName)
|
||||
)
|
||||
)
|
||||
})
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
private async getModelsMetadata(): Promise<Model[]> {
|
||||
private async getModelsMetadata(
|
||||
selector?: (path: string, model: Model) => Promise<boolean>
|
||||
): Promise<Model[]> {
|
||||
try {
|
||||
const filesUnderJanRoot = await fs.listFiles('')
|
||||
if (!filesUnderJanRoot.includes(JanModelExtension._homeDir)) {
|
||||
@ -193,26 +209,35 @@ export default class JanModelExtension implements ModelExtension {
|
||||
const allDirectories: string[] = []
|
||||
for (const file of files) {
|
||||
const isDirectory = await fs.isDirectory(
|
||||
join(JanModelExtension._homeDir, file)
|
||||
await joinPath([JanModelExtension._homeDir, file])
|
||||
)
|
||||
if (isDirectory) {
|
||||
allDirectories.push(file)
|
||||
}
|
||||
}
|
||||
|
||||
const readJsonPromises = allDirectories.map((dirName) => {
|
||||
const jsonPath = join(
|
||||
const readJsonPromises = allDirectories.map(async (dirName) => {
|
||||
// filter out directories that don't match the selector
|
||||
|
||||
// read model.json
|
||||
const jsonPath = await joinPath([
|
||||
JanModelExtension._homeDir,
|
||||
dirName,
|
||||
JanModelExtension._modelMetadataFileName
|
||||
)
|
||||
return this.readModelMetadata(jsonPath)
|
||||
JanModelExtension._modelMetadataFileName,
|
||||
])
|
||||
let model = await this.readModelMetadata(jsonPath)
|
||||
model = typeof model === 'object' ? model : JSON.parse(model)
|
||||
|
||||
if (selector && !(await selector?.(dirName, model))) {
|
||||
return
|
||||
}
|
||||
return model
|
||||
})
|
||||
const results = await Promise.allSettled(readJsonPromises)
|
||||
const modelData = results.map((result) => {
|
||||
if (result.status === 'fulfilled') {
|
||||
try {
|
||||
return JSON.parse(result.value) as Model
|
||||
return result.value as Model
|
||||
} catch {
|
||||
console.debug(`Unable to parse model metadata: ${result.value}`)
|
||||
return undefined
|
||||
@ -230,7 +255,7 @@ export default class JanModelExtension implements ModelExtension {
|
||||
}
|
||||
|
||||
private readModelMetadata(path: string) {
|
||||
return fs.readFile(join(path))
|
||||
return fs.readFile(path)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import { Fragment } from 'react'
|
||||
|
||||
import { ExtensionType } from '@janhq/core'
|
||||
import { ModelExtension } from '@janhq/core'
|
||||
import {
|
||||
Progress,
|
||||
Modal,
|
||||
@ -12,14 +10,19 @@ import {
|
||||
ModalTrigger,
|
||||
} from '@janhq/uikit'
|
||||
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
import useDownloadModel from '@/hooks/useDownloadModel'
|
||||
import { useDownloadState } from '@/hooks/useDownloadState'
|
||||
|
||||
import { formatDownloadPercentage } from '@/utils/converter'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import { downloadingModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
|
||||
export default function DownloadingState() {
|
||||
const { downloadStates } = useDownloadState()
|
||||
const downloadingModels = useAtomValue(downloadingModelsAtom)
|
||||
const { abortModelDownload } = useDownloadModel()
|
||||
|
||||
const totalCurrentProgress = downloadStates
|
||||
.map((a) => a.size.transferred + a.size.transferred)
|
||||
@ -73,9 +76,10 @@ export default function DownloadingState() {
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
if (item?.modelId) {
|
||||
extensionManager
|
||||
.get<ModelExtension>(ExtensionType.Model)
|
||||
?.cancelModelDownload(item.modelId)
|
||||
const model = downloadingModels.find(
|
||||
(model) => model.id === item.modelId
|
||||
)
|
||||
if (model) abortModelDownload(model)
|
||||
}
|
||||
}}
|
||||
>
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import { useMemo } from 'react'
|
||||
|
||||
import { ModelExtension, ExtensionType } from '@janhq/core'
|
||||
import { Model } from '@janhq/core'
|
||||
|
||||
import {
|
||||
@ -17,11 +16,12 @@ import {
|
||||
|
||||
import { atom, useAtomValue } from 'jotai'
|
||||
|
||||
import useDownloadModel from '@/hooks/useDownloadModel'
|
||||
import { useDownloadState } from '@/hooks/useDownloadState'
|
||||
|
||||
import { formatDownloadPercentage } from '@/utils/converter'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import { downloadingModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
|
||||
type Props = {
|
||||
model: Model
|
||||
@ -30,6 +30,7 @@ type Props = {
|
||||
|
||||
export default function ModalCancelDownload({ model, isFromList }: Props) {
|
||||
const { modelDownloadStateAtom } = useDownloadState()
|
||||
const downloadingModels = useAtomValue(downloadingModelsAtom)
|
||||
const downloadAtom = useMemo(
|
||||
() => atom((get) => get(modelDownloadStateAtom)[model.id]),
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
@ -37,6 +38,7 @@ export default function ModalCancelDownload({ model, isFromList }: Props) {
|
||||
)
|
||||
const downloadState = useAtomValue(downloadAtom)
|
||||
const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}`
|
||||
const { abortModelDownload } = useDownloadModel()
|
||||
|
||||
return (
|
||||
<Modal>
|
||||
@ -80,9 +82,10 @@ export default function ModalCancelDownload({ model, isFromList }: Props) {
|
||||
themes="danger"
|
||||
onClick={() => {
|
||||
if (downloadState?.modelId) {
|
||||
extensionManager
|
||||
.get<ModelExtension>(ExtensionType.Model)
|
||||
?.cancelModelDownload(downloadState.modelId)
|
||||
const model = downloadingModels.find(
|
||||
(model) => model.id === downloadState.modelId
|
||||
)
|
||||
if (model) abortModelDownload(model)
|
||||
}
|
||||
}}
|
||||
>
|
||||
|
||||
@ -1,34 +1,35 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
|
||||
import { PropsWithChildren, useEffect, useRef } from 'react'
|
||||
import { basename } from 'path'
|
||||
|
||||
import { ExtensionType } from '@janhq/core'
|
||||
import { ModelExtension } from '@janhq/core'
|
||||
import { PropsWithChildren, useEffect, useRef } from 'react'
|
||||
|
||||
import { useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { useDownloadState } from '@/hooks/useDownloadState'
|
||||
import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels'
|
||||
|
||||
import { modelBinFileName } from '@/utils/model'
|
||||
|
||||
import EventHandler from './EventHandler'
|
||||
|
||||
import { appDownloadProgress } from './Jotai'
|
||||
|
||||
import { extensionManager } from '@/extension/ExtensionManager'
|
||||
import { downloadingModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
|
||||
export default function EventListenerWrapper({ children }: PropsWithChildren) {
|
||||
const setProgress = useSetAtom(appDownloadProgress)
|
||||
const models = useAtomValue(downloadingModelsAtom)
|
||||
const modelsRef = useRef(models)
|
||||
useEffect(() => {
|
||||
modelsRef.current = models
|
||||
}, [models])
|
||||
|
||||
const { setDownloadedModels, downloadedModels } = useGetDownloadedModels()
|
||||
const { setDownloadState, setDownloadStateSuccess, setDownloadStateFailed } =
|
||||
useDownloadState()
|
||||
const downloadedModelRef = useRef(downloadedModels)
|
||||
|
||||
useEffect(() => {
|
||||
modelsRef.current = models
|
||||
}, [models])
|
||||
useEffect(() => {
|
||||
downloadedModelRef.current = downloadedModels
|
||||
}, [downloadedModels])
|
||||
@ -38,40 +39,36 @@ export default function EventListenerWrapper({ children }: PropsWithChildren) {
|
||||
window.electronAPI.onFileDownloadUpdate(
|
||||
(_event: string, state: any | undefined) => {
|
||||
if (!state) return
|
||||
setDownloadState({
|
||||
...state,
|
||||
modelId: state.fileName.split('/').pop() ?? '',
|
||||
})
|
||||
const model = modelsRef.current.find(
|
||||
(model) => modelBinFileName(model) === basename(state.fileName)
|
||||
)
|
||||
if (model)
|
||||
setDownloadState({
|
||||
...state,
|
||||
modelId: model.id,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
window.electronAPI.onFileDownloadError(
|
||||
(_event: string, callback: any) => {
|
||||
console.error('Download error', callback)
|
||||
const modelId = callback.fileName.split('/').pop() ?? ''
|
||||
setDownloadStateFailed(modelId)
|
||||
}
|
||||
)
|
||||
window.electronAPI.onFileDownloadError((_event: string, state: any) => {
|
||||
console.error('Download error', state)
|
||||
const model = modelsRef.current.find(
|
||||
(model) => modelBinFileName(model) === basename(state.fileName)
|
||||
)
|
||||
if (model) setDownloadStateFailed(model.id)
|
||||
})
|
||||
|
||||
window.electronAPI.onFileDownloadSuccess(
|
||||
(_event: string, callback: any) => {
|
||||
if (callback && callback.fileName) {
|
||||
const modelId = callback.fileName.split('/').pop() ?? ''
|
||||
|
||||
const model = modelsRef.current.find((e) => e.id === modelId)
|
||||
|
||||
setDownloadStateSuccess(modelId)
|
||||
|
||||
if (model)
|
||||
extensionManager
|
||||
.get<ModelExtension>(ExtensionType.Model)
|
||||
?.saveModel(model)
|
||||
.then(() => {
|
||||
setDownloadedModels([...downloadedModelRef.current, model])
|
||||
})
|
||||
window.electronAPI.onFileDownloadSuccess((_event: string, state: any) => {
|
||||
if (state && state.fileName) {
|
||||
const model = modelsRef.current.find(
|
||||
(model) => modelBinFileName(model) === basename(state.fileName)
|
||||
)
|
||||
if (model) {
|
||||
setDownloadStateSuccess(model.id)
|
||||
setDownloadedModels([...downloadedModelRef.current, model])
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
window.electronAPI.onAppUpdateDownloadUpdate(
|
||||
(_event: string, progress: any) => {
|
||||
|
||||
@ -1,7 +1,15 @@
|
||||
import { Model, ExtensionType, ModelExtension } from '@janhq/core'
|
||||
import {
|
||||
Model,
|
||||
ExtensionType,
|
||||
ModelExtension,
|
||||
abortDownload,
|
||||
joinPath,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { useSetAtom } from 'jotai'
|
||||
|
||||
import { modelBinFileName } from '@/utils/model'
|
||||
|
||||
import { useDownloadState } from './useDownloadState'
|
||||
|
||||
import { extensionManager } from '@/extension/ExtensionManager'
|
||||
@ -33,8 +41,14 @@ export default function useDownloadModel() {
|
||||
.get<ModelExtension>(ExtensionType.Model)
|
||||
?.downloadModel(model)
|
||||
}
|
||||
const abortModelDownload = async (model: Model) => {
|
||||
await abortDownload(
|
||||
await joinPath(['models', model.id, modelBinFileName(model)])
|
||||
)
|
||||
}
|
||||
|
||||
return {
|
||||
downloadModel,
|
||||
abortModelDownload,
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import { join } from 'path'
|
||||
|
||||
import { fs } from '@janhq/core'
|
||||
import { fs, joinPath } from '@janhq/core'
|
||||
|
||||
export const useEngineSettings = () => {
|
||||
const readOpenAISettings = async () => {
|
||||
const settings = await fs.readFile(join('engines', 'openai.json'))
|
||||
const settings = await fs.readFile(
|
||||
await joinPath(['engines', 'openai.json'])
|
||||
)
|
||||
if (settings) {
|
||||
return JSON.parse(settings)
|
||||
}
|
||||
@ -17,7 +17,10 @@ export const useEngineSettings = () => {
|
||||
}) => {
|
||||
const settings = await readOpenAISettings()
|
||||
settings.api_key = apiKey
|
||||
await fs.writeFile(join('engines', 'openai.json'), JSON.stringify(settings))
|
||||
await fs.writeFile(
|
||||
await joinPath(['engines', 'openai.json']),
|
||||
JSON.stringify(settings)
|
||||
)
|
||||
}
|
||||
return { readOpenAISettings, saveOpenAISettings }
|
||||
}
|
||||
|
||||
@ -116,7 +116,7 @@ const ChatBody: React.FC = () => {
|
||||
) : (
|
||||
<ScrollToBottom className="flex h-full w-full flex-col">
|
||||
{messages.map((message, index) => (
|
||||
<>
|
||||
<div key={message.id}>
|
||||
<ChatItem {...message} key={message.id} />
|
||||
|
||||
{message.status === MessageStatus.Error &&
|
||||
@ -126,8 +126,8 @@ const ChatBody: React.FC = () => {
|
||||
className="mt-10 flex flex-col items-center"
|
||||
>
|
||||
<span className="mb-3 text-center text-sm font-medium text-gray-500">
|
||||
Oops! The generation was interrupted. Let's
|
||||
give it another go!
|
||||
Oops! The generation was interrupted. Let's give it
|
||||
another go!
|
||||
</span>
|
||||
<Button
|
||||
className="w-min"
|
||||
@ -140,7 +140,7 @@ const ChatBody: React.FC = () => {
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
</div>
|
||||
))}
|
||||
</ScrollToBottom>
|
||||
)}
|
||||
|
||||
12
web/utils/model.ts
Normal file
12
web/utils/model.ts
Normal file
@ -0,0 +1,12 @@
|
||||
import { basename } from 'path'
|
||||
|
||||
import { Model } from '@janhq/core'
|
||||
|
||||
export const modelBinFileName = (model: Model) => {
|
||||
const modelFormatExt = '.gguf'
|
||||
const extractedFileName = basename(model.source_url) ?? model.id
|
||||
const fileName = extractedFileName.toLowerCase().endsWith(modelFormatExt)
|
||||
? extractedFileName
|
||||
: model.id
|
||||
return fileName
|
||||
}
|
||||
@ -22,8 +22,7 @@ export const toRuntimeParams = (
|
||||
|
||||
for (const [key, value] of Object.entries(modelParams)) {
|
||||
if (key in defaultModelParams) {
|
||||
// @ts-ignore
|
||||
runtimeParams[key] = value
|
||||
runtimeParams[key as keyof ModelRuntimeParams] = value
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,8 +45,7 @@ export const toSettingParams = (
|
||||
|
||||
for (const [key, value] of Object.entries(modelParams)) {
|
||||
if (key in defaultSettingParams) {
|
||||
// @ts-ignore
|
||||
settingParams[key] = value
|
||||
settingParams[key as keyof ModelSettingParams] = value
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user