diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts index 1dd12c89b..23d27935e 100644 --- a/core/src/types/model/modelEntity.ts +++ b/core/src/types/model/modelEntity.ts @@ -67,13 +67,6 @@ export type Model = { */ description: string - /** - * The model state. - * Default: "to_download" - * Enum: "to_download" "downloading" "ready" "running" - */ - state?: ModelState - /** * The model settings. */ @@ -101,15 +94,6 @@ export type ModelMetadata = { cover?: string } -/** - * The Model transition states. - */ -export enum ModelState { - Downloading = 'downloading', - Ready = 'ready', - Running = 'running', -} - /** * The available model settings. */ diff --git a/electron/handlers/download.ts b/electron/handlers/download.ts index 6e64d23e2..145174ac2 100644 --- a/electron/handlers/download.ts +++ b/electron/handlers/download.ts @@ -3,7 +3,7 @@ import { DownloadManager } from './../managers/download' import { resolve, join } from 'path' import { WindowManager } from './../managers/window' import request from 'request' -import { createWriteStream } from 'fs' +import { createWriteStream, renameSync } from 'fs' import { DownloadEvent, DownloadRoute } from '@janhq/core' const progress = require('request-progress') @@ -48,6 +48,8 @@ export function handleDownloaderIPCs() { const userDataPath = join(app.getPath('home'), 'jan') const destination = resolve(userDataPath, fileName) const rq = request(url) + // downloading file to a temp file first + const downloadingTempFile = `${destination}.download` progress(rq, {}) .on('progress', function (state: any) { @@ -70,6 +72,9 @@ export function handleDownloaderIPCs() { }) .on('end', function () { if (DownloadManager.instance.networkRequests[fileName]) { + // Finished downloading, rename temp file to actual file + renameSync(downloadingTempFile, destination) + WindowManager?.instance.currentWindow?.webContents.send( DownloadEvent.onFileDownloadSuccess, { @@ -87,7 +92,7 @@ export function handleDownloaderIPCs() { ) } }) - .pipe(createWriteStream(destination)) + .pipe(createWriteStream(downloadingTempFile)) DownloadManager.instance.setRequest(fileName, rq) }) diff --git a/extensions/conversational-extension/src/index.ts b/extensions/conversational-extension/src/index.ts index 0fdf0b2d4..8aae791e8 100644 --- a/extensions/conversational-extension/src/index.ts +++ b/extensions/conversational-extension/src/index.ts @@ -1,7 +1,6 @@ -import { ExtensionType, fs } from '@janhq/core' +import { ExtensionType, fs, joinPath } from '@janhq/core' import { ConversationalExtension } from '@janhq/core' import { Thread, ThreadMessage } from '@janhq/core' -import { join } from 'path' /** * JSONConversationalExtension is a ConversationalExtension implementation that provides @@ -69,14 +68,14 @@ export default class JSONConversationalExtension */ async saveThread(thread: Thread): Promise { try { - const threadDirPath = join( + const threadDirPath = await joinPath([ JSONConversationalExtension._homeDir, - thread.id - ) - const threadJsonPath = join( + thread.id, + ]) + const threadJsonPath = await joinPath([ threadDirPath, - JSONConversationalExtension._threadInfoFileName - ) + JSONConversationalExtension._threadInfoFileName, + ]) await fs.mkdir(threadDirPath) await fs.writeFile(threadJsonPath, JSON.stringify(thread, null, 2)) Promise.resolve() @@ -89,20 +88,22 @@ export default class JSONConversationalExtension * Delete a thread with the specified ID. * @param threadId The ID of the thread to delete. */ - deleteThread(threadId: string): Promise { - return fs.rmdir(join(JSONConversationalExtension._homeDir, `${threadId}`)) + async deleteThread(threadId: string): Promise { + return fs.rmdir( + await joinPath([JSONConversationalExtension._homeDir, `${threadId}`]) + ) } async addNewMessage(message: ThreadMessage): Promise { try { - const threadDirPath = join( + const threadDirPath = await joinPath([ JSONConversationalExtension._homeDir, - message.thread_id - ) - const threadMessagePath = join( + message.thread_id, + ]) + const threadMessagePath = await joinPath([ threadDirPath, - JSONConversationalExtension._threadMessagesFileName - ) + JSONConversationalExtension._threadMessagesFileName, + ]) await fs.mkdir(threadDirPath) await fs.appendFile(threadMessagePath, JSON.stringify(message) + '\n') Promise.resolve() @@ -116,11 +117,14 @@ export default class JSONConversationalExtension messages: ThreadMessage[] ): Promise { try { - const threadDirPath = join(JSONConversationalExtension._homeDir, threadId) - const threadMessagePath = join( + const threadDirPath = await joinPath([ + JSONConversationalExtension._homeDir, + threadId, + ]) + const threadMessagePath = await joinPath([ threadDirPath, - JSONConversationalExtension._threadMessagesFileName - ) + JSONConversationalExtension._threadMessagesFileName, + ]) await fs.mkdir(threadDirPath) await fs.writeFile( threadMessagePath, @@ -140,11 +144,11 @@ export default class JSONConversationalExtension */ private async readThread(threadDirName: string): Promise { return fs.readFile( - join( + await joinPath([ JSONConversationalExtension._homeDir, threadDirName, - JSONConversationalExtension._threadInfoFileName - ) + JSONConversationalExtension._threadInfoFileName, + ]) ) } @@ -159,10 +163,10 @@ export default class JSONConversationalExtension const threadDirs: string[] = [] for (let i = 0; i < fileInsideThread.length; i++) { - const path = join( + const path = await joinPath([ JSONConversationalExtension._homeDir, - fileInsideThread[i] - ) + fileInsideThread[i], + ]) const isDirectory = await fs.isDirectory(path) if (!isDirectory) { console.debug(`Ignore ${path} because it is not a directory`) @@ -184,7 +188,10 @@ export default class JSONConversationalExtension async getAllMessages(threadId: string): Promise { try { - const threadDirPath = join(JSONConversationalExtension._homeDir, threadId) + const threadDirPath = await joinPath([ + JSONConversationalExtension._homeDir, + threadId, + ]) const isDir = await fs.isDirectory(threadDirPath) if (!isDir) { throw Error(`${threadDirPath} is not directory`) @@ -197,10 +204,10 @@ export default class JSONConversationalExtension throw Error(`${threadDirPath} not contains message file`) } - const messageFilePath = join( + const messageFilePath = await joinPath([ threadDirPath, - JSONConversationalExtension._threadMessagesFileName - ) + JSONConversationalExtension._threadMessagesFileName, + ]) const result = await fs.readLineByLine(messageFilePath) diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index 946d526dd..d19f3853c 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -111,7 +111,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension { return; } const userSpacePath = await getUserSpace(); - const modelFullPath = join(userSpacePath, "models", model.id, model.id); + const modelFullPath = join(userSpacePath, "models", model.id); const nitroInitResult = await executeOnMain(MODULE, "initModel", { modelFullPath: modelFullPath, diff --git a/extensions/inference-nitro-extension/src/module.ts b/extensions/inference-nitro-extension/src/module.ts index 25836a875..37b9e5b3b 100644 --- a/extensions/inference-nitro-extension/src/module.ts +++ b/extensions/inference-nitro-extension/src/module.ts @@ -13,10 +13,11 @@ const NITRO_HTTP_LOAD_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/ const NITRO_HTTP_UNLOAD_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/unloadModel`; const NITRO_HTTP_VALIDATE_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/modelstatus`; const NITRO_HTTP_KILL_URL = `${NITRO_HTTP_SERVER_URL}/processmanager/destroy`; +const SUPPORTED_MODEL_FORMAT = ".gguf"; // The subprocess instance for Nitro let subprocess = undefined; -let currentModelFile = undefined; +let currentModelFile: string = undefined; let currentSettings = undefined; /** @@ -37,6 +38,17 @@ function stopModel(): Promise { */ async function initModel(wrapper: any): Promise { currentModelFile = wrapper.modelFullPath; + const files: string[] = fs.readdirSync(currentModelFile); + + // Look for GGUF model file + const ggufBinFile = files.find( + (file) => + file === path.basename(currentModelFile) || + file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT) + ); + + currentModelFile = path.join(currentModelFile, ggufBinFile); + if (wrapper.model.engine !== "nitro") { return Promise.resolve({ error: "Not a nitro model" }); } else { @@ -66,25 +78,26 @@ async function initModel(wrapper: any): Promise { async function loadModel(nitroResourceProbe: any | undefined) { // Gather system information for CPU physical cores and memory if (!nitroResourceProbe) nitroResourceProbe = await getResourcesInfo(); - return killSubprocess() - .then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000)) - // wait for 500ms to make sure the port is free for windows platform - .then(() => { - if (process.platform === "win32") { - return sleep(500); - } - else { - return sleep(0); - } - }) - .then(() => spawnNitroProcess(nitroResourceProbe)) - .then(() => loadLLMModel(currentSettings)) - .then(validateModelStatus) - .catch((err) => { - console.error("error: ", err); - // TODO: Broadcast error so app could display proper error message - return { error: err, currentModelFile }; - }); + return ( + killSubprocess() + .then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000)) + // wait for 500ms to make sure the port is free for windows platform + .then(() => { + if (process.platform === "win32") { + return sleep(500); + } else { + return sleep(0); + } + }) + .then(() => spawnNitroProcess(nitroResourceProbe)) + .then(() => loadLLMModel(currentSettings)) + .then(validateModelStatus) + .catch((err) => { + console.error("error: ", err); + // TODO: Broadcast error so app could display proper error message + return { error: err, currentModelFile }; + }) + ); } // Add function sleep diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index 16adced5d..9580afd9b 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -5,9 +5,11 @@ import { abortDownload, getResourcePath, getUserSpace, + InferenceEngine, + joinPath, } from '@janhq/core' -import { ModelExtension, Model, ModelState } from '@janhq/core' -import { join } from 'path' +import { basename } from 'path' +import { ModelExtension, Model } from '@janhq/core' /** * A extension for models @@ -15,6 +17,9 @@ import { join } from 'path' export default class JanModelExtension implements ModelExtension { private static readonly _homeDir = 'models' private static readonly _modelMetadataFileName = 'model.json' + private static readonly _supportedModelFormat = '.gguf' + private static readonly _incompletedModelFileName = '.download' + private static readonly _offlineInferenceEngine = InferenceEngine.nitro /** * Implements type from JanExtension. @@ -54,10 +59,10 @@ export default class JanModelExtension implements ModelExtension { // copy models folder from resources to home directory const resourePath = await getResourcePath() - const srcPath = join(resourePath, 'models') + const srcPath = await joinPath([resourePath, 'models']) const userSpace = await getUserSpace() - const destPath = join(userSpace, JanModelExtension._homeDir) + const destPath = await joinPath([userSpace, JanModelExtension._homeDir]) await fs.syncFile(srcPath, destPath) @@ -88,11 +93,18 @@ export default class JanModelExtension implements ModelExtension { */ async downloadModel(model: Model): Promise { // create corresponding directory - const directoryPath = join(JanModelExtension._homeDir, model.id) - await fs.mkdir(directoryPath) + const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id]) + await fs.mkdir(modelDirPath) - // path to model binary - const path = join(directoryPath, model.id) + // try to retrieve the download file name from the source url + // if it fails, use the model ID as the file name + const extractedFileName = basename(model.source_url) + const fileName = extractedFileName + .toLowerCase() + .endsWith(JanModelExtension._supportedModelFormat) + ? extractedFileName + : model.id + const path = await joinPath([modelDirPath, fileName]) downloadFile(model.source_url, path) } @@ -103,10 +115,12 @@ export default class JanModelExtension implements ModelExtension { */ async cancelModelDownload(modelId: string): Promise { return abortDownload( - join(JanModelExtension._homeDir, modelId, modelId) - ).then(() => { - fs.deleteFile(join(JanModelExtension._homeDir, modelId, modelId)) - }) + await joinPath([JanModelExtension._homeDir, modelId, modelId]) + ).then(async () => + fs.deleteFile( + await joinPath([JanModelExtension._homeDir, modelId, modelId]) + ) + ) } /** @@ -116,27 +130,16 @@ export default class JanModelExtension implements ModelExtension { */ async deleteModel(modelId: string): Promise { try { - const dirPath = join(JanModelExtension._homeDir, modelId) + const dirPath = await joinPath([JanModelExtension._homeDir, modelId]) // remove all files under dirPath except model.json const files = await fs.listFiles(dirPath) - const deletePromises = files.map((fileName: string) => { + const deletePromises = files.map(async (fileName: string) => { if (fileName !== JanModelExtension._modelMetadataFileName) { - return fs.deleteFile(join(dirPath, fileName)) + return fs.deleteFile(await joinPath([dirPath, fileName])) } }) await Promise.allSettled(deletePromises) - - // update the state as default - const jsonFilePath = join( - dirPath, - JanModelExtension._modelMetadataFileName - ) - const json = await fs.readFile(jsonFilePath) - const model = JSON.parse(json) as Model - delete model.state - - await fs.writeFile(jsonFilePath, JSON.stringify(model, null, 2)) } catch (err) { console.error(err) } @@ -148,24 +151,14 @@ export default class JanModelExtension implements ModelExtension { * @returns A Promise that resolves when the model is saved. */ async saveModel(model: Model): Promise { - const jsonFilePath = join( + const jsonFilePath = await joinPath([ JanModelExtension._homeDir, model.id, - JanModelExtension._modelMetadataFileName - ) + JanModelExtension._modelMetadataFileName, + ]) try { - await fs.writeFile( - jsonFilePath, - JSON.stringify( - { - ...model, - state: ModelState.Ready, - }, - null, - 2 - ) - ) + await fs.writeFile(jsonFilePath, JSON.stringify(model, null, 2)) } catch (err) { console.error(err) } @@ -176,11 +169,34 @@ export default class JanModelExtension implements ModelExtension { * @returns A Promise that resolves with an array of all models. */ async getDownloadedModels(): Promise { - const models = await this.getModelsMetadata() - return models.filter((model) => model.state === ModelState.Ready) + return await this.getModelsMetadata( + async (modelDir: string, model: Model) => { + if (model.engine !== JanModelExtension._offlineInferenceEngine) { + return true + } + return await fs + .listFiles(await joinPath([JanModelExtension._homeDir, modelDir])) + .then((files: string[]) => { + // or model binary exists in the directory + // model binary name can match model ID or be a .gguf file and not be an incompleted model file + return ( + files.includes(modelDir) || + files.some( + (file) => + file + .toLowerCase() + .includes(JanModelExtension._supportedModelFormat) && + !file.endsWith(JanModelExtension._incompletedModelFileName) + ) + ) + }) + } + ) } - private async getModelsMetadata(): Promise { + private async getModelsMetadata( + selector?: (path: string, model: Model) => Promise + ): Promise { try { const filesUnderJanRoot = await fs.listFiles('') if (!filesUnderJanRoot.includes(JanModelExtension._homeDir)) { @@ -193,26 +209,35 @@ export default class JanModelExtension implements ModelExtension { const allDirectories: string[] = [] for (const file of files) { const isDirectory = await fs.isDirectory( - join(JanModelExtension._homeDir, file) + await joinPath([JanModelExtension._homeDir, file]) ) if (isDirectory) { allDirectories.push(file) } } - const readJsonPromises = allDirectories.map((dirName) => { - const jsonPath = join( + const readJsonPromises = allDirectories.map(async (dirName) => { + // filter out directories that don't match the selector + + // read model.json + const jsonPath = await joinPath([ JanModelExtension._homeDir, dirName, - JanModelExtension._modelMetadataFileName - ) - return this.readModelMetadata(jsonPath) + JanModelExtension._modelMetadataFileName, + ]) + let model = await this.readModelMetadata(jsonPath) + model = typeof model === 'object' ? model : JSON.parse(model) + + if (selector && !(await selector?.(dirName, model))) { + return + } + return model }) const results = await Promise.allSettled(readJsonPromises) const modelData = results.map((result) => { if (result.status === 'fulfilled') { try { - return JSON.parse(result.value) as Model + return result.value as Model } catch { console.debug(`Unable to parse model metadata: ${result.value}`) return undefined @@ -230,7 +255,7 @@ export default class JanModelExtension implements ModelExtension { } private readModelMetadata(path: string) { - return fs.readFile(join(path)) + return fs.readFile(path) } /** diff --git a/web/containers/Layout/BottomBar/DownloadingState/index.tsx b/web/containers/Layout/BottomBar/DownloadingState/index.tsx index 0648508d0..7aef36caf 100644 --- a/web/containers/Layout/BottomBar/DownloadingState/index.tsx +++ b/web/containers/Layout/BottomBar/DownloadingState/index.tsx @@ -1,7 +1,5 @@ import { Fragment } from 'react' -import { ExtensionType } from '@janhq/core' -import { ModelExtension } from '@janhq/core' import { Progress, Modal, @@ -12,14 +10,19 @@ import { ModalTrigger, } from '@janhq/uikit' +import { useAtomValue } from 'jotai' + +import useDownloadModel from '@/hooks/useDownloadModel' import { useDownloadState } from '@/hooks/useDownloadState' import { formatDownloadPercentage } from '@/utils/converter' -import { extensionManager } from '@/extension' +import { downloadingModelsAtom } from '@/helpers/atoms/Model.atom' export default function DownloadingState() { const { downloadStates } = useDownloadState() + const downloadingModels = useAtomValue(downloadingModelsAtom) + const { abortModelDownload } = useDownloadModel() const totalCurrentProgress = downloadStates .map((a) => a.size.transferred + a.size.transferred) @@ -73,9 +76,10 @@ export default function DownloadingState() { size="sm" onClick={() => { if (item?.modelId) { - extensionManager - .get(ExtensionType.Model) - ?.cancelModelDownload(item.modelId) + const model = downloadingModels.find( + (model) => model.id === item.modelId + ) + if (model) abortModelDownload(model) } }} > diff --git a/web/containers/ModalCancelDownload/index.tsx b/web/containers/ModalCancelDownload/index.tsx index d1a6f1a44..2a5626183 100644 --- a/web/containers/ModalCancelDownload/index.tsx +++ b/web/containers/ModalCancelDownload/index.tsx @@ -1,6 +1,5 @@ import { useMemo } from 'react' -import { ModelExtension, ExtensionType } from '@janhq/core' import { Model } from '@janhq/core' import { @@ -17,11 +16,12 @@ import { import { atom, useAtomValue } from 'jotai' +import useDownloadModel from '@/hooks/useDownloadModel' import { useDownloadState } from '@/hooks/useDownloadState' import { formatDownloadPercentage } from '@/utils/converter' -import { extensionManager } from '@/extension' +import { downloadingModelsAtom } from '@/helpers/atoms/Model.atom' type Props = { model: Model @@ -30,6 +30,7 @@ type Props = { export default function ModalCancelDownload({ model, isFromList }: Props) { const { modelDownloadStateAtom } = useDownloadState() + const downloadingModels = useAtomValue(downloadingModelsAtom) const downloadAtom = useMemo( () => atom((get) => get(modelDownloadStateAtom)[model.id]), // eslint-disable-next-line react-hooks/exhaustive-deps @@ -37,6 +38,7 @@ export default function ModalCancelDownload({ model, isFromList }: Props) { ) const downloadState = useAtomValue(downloadAtom) const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}` + const { abortModelDownload } = useDownloadModel() return ( @@ -80,9 +82,10 @@ export default function ModalCancelDownload({ model, isFromList }: Props) { themes="danger" onClick={() => { if (downloadState?.modelId) { - extensionManager - .get(ExtensionType.Model) - ?.cancelModelDownload(downloadState.modelId) + const model = downloadingModels.find( + (model) => model.id === downloadState.modelId + ) + if (model) abortModelDownload(model) } }} > diff --git a/web/containers/Providers/EventListener.tsx b/web/containers/Providers/EventListener.tsx index d73e5732d..046f2ecd2 100644 --- a/web/containers/Providers/EventListener.tsx +++ b/web/containers/Providers/EventListener.tsx @@ -1,34 +1,35 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { PropsWithChildren, useEffect, useRef } from 'react' +import { basename } from 'path' -import { ExtensionType } from '@janhq/core' -import { ModelExtension } from '@janhq/core' +import { PropsWithChildren, useEffect, useRef } from 'react' import { useAtomValue, useSetAtom } from 'jotai' import { useDownloadState } from '@/hooks/useDownloadState' import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels' +import { modelBinFileName } from '@/utils/model' + import EventHandler from './EventHandler' import { appDownloadProgress } from './Jotai' -import { extensionManager } from '@/extension/ExtensionManager' import { downloadingModelsAtom } from '@/helpers/atoms/Model.atom' export default function EventListenerWrapper({ children }: PropsWithChildren) { const setProgress = useSetAtom(appDownloadProgress) const models = useAtomValue(downloadingModelsAtom) const modelsRef = useRef(models) - useEffect(() => { - modelsRef.current = models - }, [models]) + const { setDownloadedModels, downloadedModels } = useGetDownloadedModels() const { setDownloadState, setDownloadStateSuccess, setDownloadStateFailed } = useDownloadState() const downloadedModelRef = useRef(downloadedModels) + useEffect(() => { + modelsRef.current = models + }, [models]) useEffect(() => { downloadedModelRef.current = downloadedModels }, [downloadedModels]) @@ -38,40 +39,36 @@ export default function EventListenerWrapper({ children }: PropsWithChildren) { window.electronAPI.onFileDownloadUpdate( (_event: string, state: any | undefined) => { if (!state) return - setDownloadState({ - ...state, - modelId: state.fileName.split('/').pop() ?? '', - }) + const model = modelsRef.current.find( + (model) => modelBinFileName(model) === basename(state.fileName) + ) + if (model) + setDownloadState({ + ...state, + modelId: model.id, + }) } ) - window.electronAPI.onFileDownloadError( - (_event: string, callback: any) => { - console.error('Download error', callback) - const modelId = callback.fileName.split('/').pop() ?? '' - setDownloadStateFailed(modelId) - } - ) + window.electronAPI.onFileDownloadError((_event: string, state: any) => { + console.error('Download error', state) + const model = modelsRef.current.find( + (model) => modelBinFileName(model) === basename(state.fileName) + ) + if (model) setDownloadStateFailed(model.id) + }) - window.electronAPI.onFileDownloadSuccess( - (_event: string, callback: any) => { - if (callback && callback.fileName) { - const modelId = callback.fileName.split('/').pop() ?? '' - - const model = modelsRef.current.find((e) => e.id === modelId) - - setDownloadStateSuccess(modelId) - - if (model) - extensionManager - .get(ExtensionType.Model) - ?.saveModel(model) - .then(() => { - setDownloadedModels([...downloadedModelRef.current, model]) - }) + window.electronAPI.onFileDownloadSuccess((_event: string, state: any) => { + if (state && state.fileName) { + const model = modelsRef.current.find( + (model) => modelBinFileName(model) === basename(state.fileName) + ) + if (model) { + setDownloadStateSuccess(model.id) + setDownloadedModels([...downloadedModelRef.current, model]) } } - ) + }) window.electronAPI.onAppUpdateDownloadUpdate( (_event: string, progress: any) => { diff --git a/web/hooks/useDownloadModel.ts b/web/hooks/useDownloadModel.ts index 7d5b2d1bd..bd587981c 100644 --- a/web/hooks/useDownloadModel.ts +++ b/web/hooks/useDownloadModel.ts @@ -1,7 +1,15 @@ -import { Model, ExtensionType, ModelExtension } from '@janhq/core' +import { + Model, + ExtensionType, + ModelExtension, + abortDownload, + joinPath, +} from '@janhq/core' import { useSetAtom } from 'jotai' +import { modelBinFileName } from '@/utils/model' + import { useDownloadState } from './useDownloadState' import { extensionManager } from '@/extension/ExtensionManager' @@ -33,8 +41,14 @@ export default function useDownloadModel() { .get(ExtensionType.Model) ?.downloadModel(model) } + const abortModelDownload = async (model: Model) => { + await abortDownload( + await joinPath(['models', model.id, modelBinFileName(model)]) + ) + } return { downloadModel, + abortModelDownload, } } diff --git a/web/hooks/useEngineSettings.ts b/web/hooks/useEngineSettings.ts index 50dcd0518..14f32d4b4 100644 --- a/web/hooks/useEngineSettings.ts +++ b/web/hooks/useEngineSettings.ts @@ -1,10 +1,10 @@ -import { join } from 'path' - -import { fs } from '@janhq/core' +import { fs, joinPath } from '@janhq/core' export const useEngineSettings = () => { const readOpenAISettings = async () => { - const settings = await fs.readFile(join('engines', 'openai.json')) + const settings = await fs.readFile( + await joinPath(['engines', 'openai.json']) + ) if (settings) { return JSON.parse(settings) } @@ -17,7 +17,10 @@ export const useEngineSettings = () => { }) => { const settings = await readOpenAISettings() settings.api_key = apiKey - await fs.writeFile(join('engines', 'openai.json'), JSON.stringify(settings)) + await fs.writeFile( + await joinPath(['engines', 'openai.json']), + JSON.stringify(settings) + ) } return { readOpenAISettings, saveOpenAISettings } } diff --git a/web/screens/Chat/ChatBody/index.tsx b/web/screens/Chat/ChatBody/index.tsx index 4d86e9e44..e1ae98b6e 100644 --- a/web/screens/Chat/ChatBody/index.tsx +++ b/web/screens/Chat/ChatBody/index.tsx @@ -116,7 +116,7 @@ const ChatBody: React.FC = () => { ) : ( {messages.map((message, index) => ( - <> +
{message.status === MessageStatus.Error && @@ -126,8 +126,8 @@ const ChatBody: React.FC = () => { className="mt-10 flex flex-col items-center" > - Oops! The generation was interrupted. Let's - give it another go! + Oops! The generation was interrupted. Let's give it + another go!
)} - + ))}
)} diff --git a/web/utils/model.ts b/web/utils/model.ts new file mode 100644 index 000000000..5c5ef1264 --- /dev/null +++ b/web/utils/model.ts @@ -0,0 +1,12 @@ +import { basename } from 'path' + +import { Model } from '@janhq/core' + +export const modelBinFileName = (model: Model) => { + const modelFormatExt = '.gguf' + const extractedFileName = basename(model.source_url) ?? model.id + const fileName = extractedFileName.toLowerCase().endsWith(modelFormatExt) + ? extractedFileName + : model.id + return fileName +} diff --git a/web/utils/model_param.ts b/web/utils/model_param.ts index 3288fb40b..7d559c313 100644 --- a/web/utils/model_param.ts +++ b/web/utils/model_param.ts @@ -22,8 +22,7 @@ export const toRuntimeParams = ( for (const [key, value] of Object.entries(modelParams)) { if (key in defaultModelParams) { - // @ts-ignore - runtimeParams[key] = value + runtimeParams[key as keyof ModelRuntimeParams] = value } } @@ -46,8 +45,7 @@ export const toSettingParams = ( for (const [key, value] of Object.entries(modelParams)) { if (key in defaultSettingParams) { - // @ts-ignore - settingParams[key] = value + settingParams[key as keyof ModelSettingParams] = value } }