diff --git a/extensions/conversational-extension/src/index.ts b/extensions/conversational-extension/src/index.ts index 5a2537239..0fdf0b2d4 100644 --- a/extensions/conversational-extension/src/index.ts +++ b/extensions/conversational-extension/src/index.ts @@ -55,7 +55,7 @@ export default class JSONConversationalExtension convos.sort( (a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime() ) - console.debug('getThreads', JSON.stringify(convos, null, 2)) + return convos } catch (error) { console.error(error) diff --git a/extensions/inference-nitro-extension/src/module.ts b/extensions/inference-nitro-extension/src/module.ts index bc39e8fca..bca0b6fcc 100644 --- a/extensions/inference-nitro-extension/src/module.ts +++ b/extensions/inference-nitro-extension/src/module.ts @@ -71,7 +71,7 @@ async function loadModel(nitroResourceProbe: any | undefined) { .then(() => loadLLMModel(currentSettings)) .then(validateModelStatus) .catch((err) => { - console.log("error: ", err); + console.error("error: ", err); // TODO: Broadcast error so app could display proper error message return { error: err, currentModelFile }; }); @@ -172,7 +172,7 @@ async function validateModelStatus(): Promise { async function killSubprocess(): Promise { const controller = new AbortController(); setTimeout(() => controller.abort(), 5000); - console.log("Start requesting to kill Nitro..."); + console.debug("Start requesting to kill Nitro..."); return fetch(NITRO_HTTP_KILL_URL, { method: "DELETE", signal: controller.signal, @@ -183,7 +183,7 @@ async function killSubprocess(): Promise { }) .catch(() => {}) .then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000)) - .then(() => console.log("Nitro is killed")); + .then(() => console.debug("Nitro is killed")); } /** * Look for the Nitro binary and execute it @@ -191,7 +191,7 @@ async function killSubprocess(): Promise { * Should run exactly platform specified Nitro binary version */ function spawnNitroProcess(nitroResourceProbe: any): Promise { - console.log("Starting Nitro subprocess..."); + console.debug("Starting Nitro subprocess..."); return new Promise(async (resolve, reject) => { let binaryFolder = path.join(__dirname, "bin"); // Current directory by default let binaryName; @@ -221,7 +221,7 @@ function spawnNitroProcess(nitroResourceProbe: any): Promise { }); subprocess.stderr.on("data", (data) => { - console.log("subprocess error:" + data.toString()); + console.error("subprocess error:" + data.toString()); console.error(`stderr: ${data}`); }); diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts index 954929553..b7544f74e 100644 --- a/web/hooks/useCreateNewThread.ts +++ b/web/hooks/useCreateNewThread.ts @@ -5,11 +5,14 @@ import { Thread, ThreadAssistantInfo, ThreadState, + Model, } from '@janhq/core' import { atom, useAtomValue, useSetAtom } from 'jotai' import { generateThreadId } from '@/utils/thread' +import useDeleteThread from './useDeleteThread' + import { extensionManager } from '@/extension' import { threadsAtom, @@ -46,27 +49,33 @@ export const useCreateNewThread = () => { setThreadModelRuntimeParamsAtom ) - const requestCreateNewThread = async (assistant: Assistant) => { + const { deleteThread } = useDeleteThread() + + const requestCreateNewThread = async ( + assistant: Assistant, + model?: Model | undefined + ) => { // loop through threads state and filter if there's any thread that is not finish init - let hasUnfinishedInitThread = false + let unfinishedInitThreadId: string | undefined = undefined for (const key in threadStates) { const isFinishInit = threadStates[key].isFinishInit ?? true if (!isFinishInit) { - hasUnfinishedInitThread = true + unfinishedInitThreadId = key break } } - if (hasUnfinishedInitThread) { - return + if (unfinishedInitThreadId) { + await deleteThread(unfinishedInitThreadId) } + const modelId = model ? model.id : '*' const createdAt = Date.now() const assistantInfo: ThreadAssistantInfo = { assistant_id: assistant.id, assistant_name: assistant.name, model: { - id: '*', + id: modelId, settings: {}, parameters: { stream: true, diff --git a/web/hooks/useDeleteThread.ts b/web/hooks/useDeleteThread.ts index 8822b6aa8..320fe045c 100644 --- a/web/hooks/useDeleteThread.ts +++ b/web/hooks/useDeleteThread.ts @@ -1,5 +1,9 @@ -import { ChatCompletionRole, ExtensionType } from '@janhq/core' -import { ConversationalExtension } from '@janhq/core' +import { + ChatCompletionRole, + ExtensionType, + ConversationalExtension, +} from '@janhq/core' + import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { currentPromptAtom } from '@/containers/Providers/Jotai' @@ -19,6 +23,7 @@ import { threadsAtom, setActiveThreadIdAtom, deleteThreadStateAtom, + threadStatesAtom, } from '@/helpers/atoms/Thread.atom' export default function useDeleteThread() { @@ -32,6 +37,8 @@ export default function useDeleteThread() { const cleanMessages = useSetAtom(cleanChatMessagesAtom) const deleteThreadState = useSetAtom(deleteThreadStateAtom) + const threadStates = useAtomValue(threadStatesAtom) + const cleanThread = async (threadId: string) => { if (threadId) { const thread = threads.filter((c) => c.id === threadId)[0] @@ -59,15 +66,21 @@ export default function useDeleteThread() { const availableThreads = threads.filter((c) => c.id !== threadId) setThreads(availableThreads) + const deletingThreadState = threadStates[threadId] + const isFinishInit = deletingThreadState?.isFinishInit ?? true + // delete the thread state deleteThreadState(threadId) - deleteMessages(threadId) - setCurrentPrompt('') - toaster({ - title: 'Thread successfully deleted.', - description: `Thread with ${activeModel?.name} has been successfully deleted.`, - }) + if (isFinishInit) { + deleteMessages(threadId) + setCurrentPrompt('') + toaster({ + title: 'Thread successfully deleted.', + description: `Thread with ${activeModel?.name} has been successfully deleted.`, + }) + } + if (availableThreads.length > 0) { setActiveThreadId(availableThreads[0].id) } else { diff --git a/web/hooks/useGetAllThreads.ts b/web/hooks/useGetAllThreads.ts deleted file mode 100644 index 867434617..000000000 --- a/web/hooks/useGetAllThreads.ts +++ /dev/null @@ -1,58 +0,0 @@ -import { ExtensionType, ModelRuntimeParams, ThreadState } from '@janhq/core' -import { ConversationalExtension } from '@janhq/core' -import { useSetAtom } from 'jotai' - -import { extensionManager } from '@/extension/ExtensionManager' -import { - threadModelRuntimeParamsAtom, - threadStatesAtom, - threadsAtom, -} from '@/helpers/atoms/Thread.atom' - -const useGetAllThreads = () => { - const setThreadStates = useSetAtom(threadStatesAtom) - const setThreads = useSetAtom(threadsAtom) - const setThreadModelRuntimeParams = useSetAtom(threadModelRuntimeParamsAtom) - - const getAllThreads = async () => { - try { - const threads = - (await extensionManager - .get(ExtensionType.Conversational) - ?.getThreads()) ?? [] - - const threadStates: Record = {} - const threadModelParams: Record = {} - - threads.forEach((thread) => { - if (thread.id != null) { - const lastMessage = (thread.metadata?.lastMessage as string) ?? '' - - threadStates[thread.id] = { - hasMore: true, - waitingForResponse: false, - lastMessage, - isFinishInit: true, - } - - // model params - const modelParams = thread.assistants?.[0]?.model?.parameters - threadModelParams[thread.id] = modelParams - } - }) - - // updating app states - setThreadStates(threadStates) - setThreads(threads) - setThreadModelRuntimeParams(threadModelParams) - } catch (error) { - console.error(error) - } - } - - return { - getAllThreads, - } -} - -export default useGetAllThreads diff --git a/web/hooks/useGetAssistants.ts b/web/hooks/useGetAssistants.ts index c40e2861e..0fa66c9c9 100644 --- a/web/hooks/useGetAssistants.ts +++ b/web/hooks/useGetAssistants.ts @@ -4,13 +4,10 @@ import { Assistant, ExtensionType, AssistantExtension } from '@janhq/core' import { extensionManager } from '@/extension/ExtensionManager' -export const getAssistants = async (): Promise => { - return ( - extensionManager - .get(ExtensionType.Assistant) - ?.getAssistants() ?? [] - ) -} +export const getAssistants = async (): Promise => + extensionManager + .get(ExtensionType.Assistant) + ?.getAssistants() ?? [] /** * Hooks for get assistants diff --git a/web/hooks/useRecommendedModel.ts b/web/hooks/useRecommendedModel.ts index 944faf83b..6dc5771f1 100644 --- a/web/hooks/useRecommendedModel.ts +++ b/web/hooks/useRecommendedModel.ts @@ -57,6 +57,17 @@ export default function useRecommendedModel() { } return + } else { + const modelId = activeThread.assistants[0]?.model.id + if (modelId !== '*') { + const models = await getAndSortDownloadedModels() + const model = models.find((model) => model.id === modelId) + + if (model) { + setRecommendedModel(model) + } + return + } } if (activeModel) { diff --git a/web/hooks/useThreads.ts b/web/hooks/useThreads.ts new file mode 100644 index 000000000..69145de05 --- /dev/null +++ b/web/hooks/useThreads.ts @@ -0,0 +1,95 @@ +import { + ExtensionType, + ModelRuntimeParams, + Thread, + ThreadState, +} from '@janhq/core' +import { ConversationalExtension } from '@janhq/core' +import { useAtom } from 'jotai' + +import { extensionManager } from '@/extension/ExtensionManager' +import { + threadModelRuntimeParamsAtom, + threadStatesAtom, + threadsAtom, +} from '@/helpers/atoms/Thread.atom' + +const useThreads = () => { + const [threadStates, setThreadStates] = useAtom(threadStatesAtom) + const [threads, setThreads] = useAtom(threadsAtom) + const [threadModelRuntimeParams, setThreadModelRuntimeParams] = useAtom( + threadModelRuntimeParamsAtom + ) + + const getThreads = async () => { + try { + const localThreads = await getLocalThreads() + const localThreadStates: Record = {} + const threadModelParams: Record = {} + + localThreads.forEach((thread) => { + if (thread.id != null) { + const lastMessage = (thread.metadata?.lastMessage as string) ?? '' + + localThreadStates[thread.id] = { + hasMore: false, + waitingForResponse: false, + lastMessage, + isFinishInit: true, + } + + // model params + const modelParams = thread.assistants?.[0]?.model?.parameters + threadModelParams[thread.id] = modelParams + } + }) + + // allow at max 1 unfinished init thread and it should be at the top of the list + let unfinishedThreadId: string | undefined = undefined + const unfinishedThreadState: Record = {} + + for (const key of Object.keys(threadStates)) { + const threadState = threadStates[key] + if (threadState.isFinishInit === false) { + unfinishedThreadState[key] = threadState + unfinishedThreadId = key + break + } + } + const unfinishedThread: Thread | undefined = threads.find( + (thread) => thread.id === unfinishedThreadId + ) + + let allThreads: Thread[] = [...localThreads] + if (unfinishedThread) { + allThreads = [unfinishedThread, ...localThreads] + } + + if (unfinishedThreadId) { + localThreadStates[unfinishedThreadId] = + unfinishedThreadState[unfinishedThreadId] + + threadModelParams[unfinishedThreadId] = + threadModelRuntimeParams[unfinishedThreadId] + } + + // updating app states + setThreadStates(localThreadStates) + setThreads(allThreads) + setThreadModelRuntimeParams(threadModelParams) + } catch (error) { + console.error(error) + } + } + + return { + getAllThreads: getThreads, + } +} + +const getLocalThreads = async (): Promise => + (await extensionManager + .get(ExtensionType.Conversational) + ?.getThreads()) ?? [] + +export default useThreads diff --git a/web/screens/Chat/ThreadList/index.tsx b/web/screens/Chat/ThreadList/index.tsx index c3da35c42..a32faa92f 100644 --- a/web/screens/Chat/ThreadList/index.tsx +++ b/web/screens/Chat/ThreadList/index.tsx @@ -24,12 +24,13 @@ import { twMerge } from 'tailwind-merge' import { useCreateNewThread } from '@/hooks/useCreateNewThread' import useDeleteThread from '@/hooks/useDeleteThread' -import useGetAllThreads from '@/hooks/useGetAllThreads' import useGetAssistants from '@/hooks/useGetAssistants' import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels' import useSetActiveThread from '@/hooks/useSetActiveThread' +import useThreads from '@/hooks/useThreads' + import { displayDate } from '@/utils/datetime' import { @@ -41,7 +42,7 @@ import { export default function ThreadList() { const threads = useAtomValue(threadsAtom) const threadStates = useAtomValue(threadStatesAtom) - const { getAllThreads } = useGetAllThreads() + const { getAllThreads } = useThreads() const { assistants } = useGetAssistants() const { requestCreateNewThread } = useCreateNewThread() const activeThread = useAtomValue(activeThreadAtom) diff --git a/web/screens/Chat/index.tsx b/web/screens/Chat/index.tsx index 8f2e30e09..1de97f394 100644 --- a/web/screens/Chat/index.tsx +++ b/web/screens/Chat/index.tsx @@ -75,14 +75,12 @@ const ChatScreen = () => { // eslint-disable-next-line react-hooks/exhaustive-deps }, [waitingToSendMessage, activeThreadId]) - const resizeTextArea = () => { + useEffect(() => { if (textareaRef.current) { textareaRef.current.style.height = '40px' textareaRef.current.style.height = textareaRef.current.scrollHeight + 'px' } - } - - useEffect(resizeTextArea, [currentPrompt]) + }, [currentPrompt]) const onKeyDown = async (e: React.KeyboardEvent) => { if (e.key === 'Enter') { diff --git a/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx b/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx index ba23056c6..40813225f 100644 --- a/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx +++ b/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx @@ -14,11 +14,10 @@ import ModalCancelDownload from '@/containers/ModalCancelDownload' import { MainViewState } from '@/constants/screens' -// import { ModelPerformance, TagType } from '@/constants/tagType' - -import { useActiveModel } from '@/hooks/useActiveModel' +import { useCreateNewThread } from '@/hooks/useCreateNewThread' import useDownloadModel from '@/hooks/useDownloadModel' import { useDownloadState } from '@/hooks/useDownloadState' +import { getAssistants } from '@/hooks/useGetAssistants' import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels' import { useMainViewState } from '@/hooks/useMainViewState' @@ -34,12 +33,7 @@ const ExploreModelItemHeader: React.FC = ({ model, onClick, open }) => { const { downloadModel } = useDownloadModel() const { downloadedModels } = useGetDownloadedModels() const { modelDownloadStateAtom, downloadStates } = useDownloadState() - const { startModel } = useActiveModel() - // const [title, setTitle] = useState('Recommended') - - // const [performanceTag, setPerformanceTag] = useState( - // ModelPerformance.PerformancePositive - // ) + const { requestCreateNewThread } = useCreateNewThread() const downloadAtom = useMemo( () => atom((get) => get(modelDownloadStateAtom)[model.id]), @@ -59,10 +53,15 @@ const ExploreModelItemHeader: React.FC = ({ model, onClick, open }) => { ) - const onUseModelClick = () => { - startModel(model.id) + const onUseModelClick = useCallback(async () => { + const assistants = await getAssistants() + if (assistants.length === 0) { + alert('No assistant available') + return + } + await requestCreateNewThread(assistants[0], model) setMainViewState(MainViewState.Thread) - } + }, []) if (isDownloaded) { downloadButton = ( @@ -80,22 +79,6 @@ const ExploreModelItemHeader: React.FC = ({ model, onClick, open }) => { downloadButton = } - // const renderBadge = (performance: TagType) => { - // switch (performance) { - // case ModelPerformance.PerformancePositive: - // return {title} - - // case ModelPerformance.PerformanceNeutral: - // return {title} - - // case ModelPerformance.PerformanceNegative: - // return {title} - - // default: - // break - // } - // } - return (