fix(Thread): #1042 allow create new thread by clicking Use in Jan Hub (#1103)

Signed-off-by: James <james@jan.ai>
Co-authored-by: James <james@jan.ai>
This commit is contained in:
NamH 2023-12-19 17:06:57 +07:00 committed by GitHub
parent 4653030bc1
commit 84fb5ef346
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 168 additions and 119 deletions

View File

@ -55,7 +55,7 @@ export default class JSONConversationalExtension
convos.sort( convos.sort(
(a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime() (a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime()
) )
console.debug('getThreads', JSON.stringify(convos, null, 2))
return convos return convos
} catch (error) { } catch (error) {
console.error(error) console.error(error)

View File

@ -71,7 +71,7 @@ async function loadModel(nitroResourceProbe: any | undefined) {
.then(() => loadLLMModel(currentSettings)) .then(() => loadLLMModel(currentSettings))
.then(validateModelStatus) .then(validateModelStatus)
.catch((err) => { .catch((err) => {
console.log("error: ", err); console.error("error: ", err);
// TODO: Broadcast error so app could display proper error message // TODO: Broadcast error so app could display proper error message
return { error: err, currentModelFile }; return { error: err, currentModelFile };
}); });
@ -172,7 +172,7 @@ async function validateModelStatus(): Promise<ModelOperationResponse> {
async function killSubprocess(): Promise<void> { async function killSubprocess(): Promise<void> {
const controller = new AbortController(); const controller = new AbortController();
setTimeout(() => controller.abort(), 5000); setTimeout(() => controller.abort(), 5000);
console.log("Start requesting to kill Nitro..."); console.debug("Start requesting to kill Nitro...");
return fetch(NITRO_HTTP_KILL_URL, { return fetch(NITRO_HTTP_KILL_URL, {
method: "DELETE", method: "DELETE",
signal: controller.signal, signal: controller.signal,
@ -183,7 +183,7 @@ async function killSubprocess(): Promise<void> {
}) })
.catch(() => {}) .catch(() => {})
.then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000)) .then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000))
.then(() => console.log("Nitro is killed")); .then(() => console.debug("Nitro is killed"));
} }
/** /**
* Look for the Nitro binary and execute it * Look for the Nitro binary and execute it
@ -191,7 +191,7 @@ async function killSubprocess(): Promise<void> {
* Should run exactly platform specified Nitro binary version * Should run exactly platform specified Nitro binary version
*/ */
function spawnNitroProcess(nitroResourceProbe: any): Promise<any> { function spawnNitroProcess(nitroResourceProbe: any): Promise<any> {
console.log("Starting Nitro subprocess..."); console.debug("Starting Nitro subprocess...");
return new Promise(async (resolve, reject) => { return new Promise(async (resolve, reject) => {
let binaryFolder = path.join(__dirname, "bin"); // Current directory by default let binaryFolder = path.join(__dirname, "bin"); // Current directory by default
let binaryName; let binaryName;
@ -221,7 +221,7 @@ function spawnNitroProcess(nitroResourceProbe: any): Promise<any> {
}); });
subprocess.stderr.on("data", (data) => { subprocess.stderr.on("data", (data) => {
console.log("subprocess error:" + data.toString()); console.error("subprocess error:" + data.toString());
console.error(`stderr: ${data}`); console.error(`stderr: ${data}`);
}); });

View File

@ -5,11 +5,14 @@ import {
Thread, Thread,
ThreadAssistantInfo, ThreadAssistantInfo,
ThreadState, ThreadState,
Model,
} from '@janhq/core' } from '@janhq/core'
import { atom, useAtomValue, useSetAtom } from 'jotai' import { atom, useAtomValue, useSetAtom } from 'jotai'
import { generateThreadId } from '@/utils/thread' import { generateThreadId } from '@/utils/thread'
import useDeleteThread from './useDeleteThread'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { import {
threadsAtom, threadsAtom,
@ -46,27 +49,33 @@ export const useCreateNewThread = () => {
setThreadModelRuntimeParamsAtom setThreadModelRuntimeParamsAtom
) )
const requestCreateNewThread = async (assistant: Assistant) => { const { deleteThread } = useDeleteThread()
const requestCreateNewThread = async (
assistant: Assistant,
model?: Model | undefined
) => {
// loop through threads state and filter if there's any thread that is not finish init // loop through threads state and filter if there's any thread that is not finish init
let hasUnfinishedInitThread = false let unfinishedInitThreadId: string | undefined = undefined
for (const key in threadStates) { for (const key in threadStates) {
const isFinishInit = threadStates[key].isFinishInit ?? true const isFinishInit = threadStates[key].isFinishInit ?? true
if (!isFinishInit) { if (!isFinishInit) {
hasUnfinishedInitThread = true unfinishedInitThreadId = key
break break
} }
} }
if (hasUnfinishedInitThread) { if (unfinishedInitThreadId) {
return await deleteThread(unfinishedInitThreadId)
} }
const modelId = model ? model.id : '*'
const createdAt = Date.now() const createdAt = Date.now()
const assistantInfo: ThreadAssistantInfo = { const assistantInfo: ThreadAssistantInfo = {
assistant_id: assistant.id, assistant_id: assistant.id,
assistant_name: assistant.name, assistant_name: assistant.name,
model: { model: {
id: '*', id: modelId,
settings: {}, settings: {},
parameters: { parameters: {
stream: true, stream: true,

View File

@ -1,5 +1,9 @@
import { ChatCompletionRole, ExtensionType } from '@janhq/core' import {
import { ConversationalExtension } from '@janhq/core' ChatCompletionRole,
ExtensionType,
ConversationalExtension,
} from '@janhq/core'
import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { useAtom, useAtomValue, useSetAtom } from 'jotai'
import { currentPromptAtom } from '@/containers/Providers/Jotai' import { currentPromptAtom } from '@/containers/Providers/Jotai'
@ -19,6 +23,7 @@ import {
threadsAtom, threadsAtom,
setActiveThreadIdAtom, setActiveThreadIdAtom,
deleteThreadStateAtom, deleteThreadStateAtom,
threadStatesAtom,
} from '@/helpers/atoms/Thread.atom' } from '@/helpers/atoms/Thread.atom'
export default function useDeleteThread() { export default function useDeleteThread() {
@ -32,6 +37,8 @@ export default function useDeleteThread() {
const cleanMessages = useSetAtom(cleanChatMessagesAtom) const cleanMessages = useSetAtom(cleanChatMessagesAtom)
const deleteThreadState = useSetAtom(deleteThreadStateAtom) const deleteThreadState = useSetAtom(deleteThreadStateAtom)
const threadStates = useAtomValue(threadStatesAtom)
const cleanThread = async (threadId: string) => { const cleanThread = async (threadId: string) => {
if (threadId) { if (threadId) {
const thread = threads.filter((c) => c.id === threadId)[0] const thread = threads.filter((c) => c.id === threadId)[0]
@ -59,15 +66,21 @@ export default function useDeleteThread() {
const availableThreads = threads.filter((c) => c.id !== threadId) const availableThreads = threads.filter((c) => c.id !== threadId)
setThreads(availableThreads) setThreads(availableThreads)
const deletingThreadState = threadStates[threadId]
const isFinishInit = deletingThreadState?.isFinishInit ?? true
// delete the thread state // delete the thread state
deleteThreadState(threadId) deleteThreadState(threadId)
deleteMessages(threadId) if (isFinishInit) {
setCurrentPrompt('') deleteMessages(threadId)
toaster({ setCurrentPrompt('')
title: 'Thread successfully deleted.', toaster({
description: `Thread with ${activeModel?.name} has been successfully deleted.`, title: 'Thread successfully deleted.',
}) description: `Thread with ${activeModel?.name} has been successfully deleted.`,
})
}
if (availableThreads.length > 0) { if (availableThreads.length > 0) {
setActiveThreadId(availableThreads[0].id) setActiveThreadId(availableThreads[0].id)
} else { } else {

View File

@ -1,58 +0,0 @@
import { ExtensionType, ModelRuntimeParams, ThreadState } from '@janhq/core'
import { ConversationalExtension } from '@janhq/core'
import { useSetAtom } from 'jotai'
import { extensionManager } from '@/extension/ExtensionManager'
import {
threadModelRuntimeParamsAtom,
threadStatesAtom,
threadsAtom,
} from '@/helpers/atoms/Thread.atom'
const useGetAllThreads = () => {
const setThreadStates = useSetAtom(threadStatesAtom)
const setThreads = useSetAtom(threadsAtom)
const setThreadModelRuntimeParams = useSetAtom(threadModelRuntimeParamsAtom)
const getAllThreads = async () => {
try {
const threads =
(await extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.getThreads()) ?? []
const threadStates: Record<string, ThreadState> = {}
const threadModelParams: Record<string, ModelRuntimeParams> = {}
threads.forEach((thread) => {
if (thread.id != null) {
const lastMessage = (thread.metadata?.lastMessage as string) ?? ''
threadStates[thread.id] = {
hasMore: true,
waitingForResponse: false,
lastMessage,
isFinishInit: true,
}
// model params
const modelParams = thread.assistants?.[0]?.model?.parameters
threadModelParams[thread.id] = modelParams
}
})
// updating app states
setThreadStates(threadStates)
setThreads(threads)
setThreadModelRuntimeParams(threadModelParams)
} catch (error) {
console.error(error)
}
}
return {
getAllThreads,
}
}
export default useGetAllThreads

View File

@ -4,13 +4,10 @@ import { Assistant, ExtensionType, AssistantExtension } from '@janhq/core'
import { extensionManager } from '@/extension/ExtensionManager' import { extensionManager } from '@/extension/ExtensionManager'
export const getAssistants = async (): Promise<Assistant[]> => { export const getAssistants = async (): Promise<Assistant[]> =>
return ( extensionManager
extensionManager .get<AssistantExtension>(ExtensionType.Assistant)
.get<AssistantExtension>(ExtensionType.Assistant) ?.getAssistants() ?? []
?.getAssistants() ?? []
)
}
/** /**
* Hooks for get assistants * Hooks for get assistants

View File

@ -57,6 +57,17 @@ export default function useRecommendedModel() {
} }
return return
} else {
const modelId = activeThread.assistants[0]?.model.id
if (modelId !== '*') {
const models = await getAndSortDownloadedModels()
const model = models.find((model) => model.id === modelId)
if (model) {
setRecommendedModel(model)
}
return
}
} }
if (activeModel) { if (activeModel) {

95
web/hooks/useThreads.ts Normal file
View File

@ -0,0 +1,95 @@
import {
ExtensionType,
ModelRuntimeParams,
Thread,
ThreadState,
} from '@janhq/core'
import { ConversationalExtension } from '@janhq/core'
import { useAtom } from 'jotai'
import { extensionManager } from '@/extension/ExtensionManager'
import {
threadModelRuntimeParamsAtom,
threadStatesAtom,
threadsAtom,
} from '@/helpers/atoms/Thread.atom'
const useThreads = () => {
const [threadStates, setThreadStates] = useAtom(threadStatesAtom)
const [threads, setThreads] = useAtom(threadsAtom)
const [threadModelRuntimeParams, setThreadModelRuntimeParams] = useAtom(
threadModelRuntimeParamsAtom
)
const getThreads = async () => {
try {
const localThreads = await getLocalThreads()
const localThreadStates: Record<string, ThreadState> = {}
const threadModelParams: Record<string, ModelRuntimeParams> = {}
localThreads.forEach((thread) => {
if (thread.id != null) {
const lastMessage = (thread.metadata?.lastMessage as string) ?? ''
localThreadStates[thread.id] = {
hasMore: false,
waitingForResponse: false,
lastMessage,
isFinishInit: true,
}
// model params
const modelParams = thread.assistants?.[0]?.model?.parameters
threadModelParams[thread.id] = modelParams
}
})
// allow at max 1 unfinished init thread and it should be at the top of the list
let unfinishedThreadId: string | undefined = undefined
const unfinishedThreadState: Record<string, ThreadState> = {}
for (const key of Object.keys(threadStates)) {
const threadState = threadStates[key]
if (threadState.isFinishInit === false) {
unfinishedThreadState[key] = threadState
unfinishedThreadId = key
break
}
}
const unfinishedThread: Thread | undefined = threads.find(
(thread) => thread.id === unfinishedThreadId
)
let allThreads: Thread[] = [...localThreads]
if (unfinishedThread) {
allThreads = [unfinishedThread, ...localThreads]
}
if (unfinishedThreadId) {
localThreadStates[unfinishedThreadId] =
unfinishedThreadState[unfinishedThreadId]
threadModelParams[unfinishedThreadId] =
threadModelRuntimeParams[unfinishedThreadId]
}
// updating app states
setThreadStates(localThreadStates)
setThreads(allThreads)
setThreadModelRuntimeParams(threadModelParams)
} catch (error) {
console.error(error)
}
}
return {
getAllThreads: getThreads,
}
}
const getLocalThreads = async (): Promise<Thread[]> =>
(await extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.getThreads()) ?? []
export default useThreads

View File

@ -24,12 +24,13 @@ import { twMerge } from 'tailwind-merge'
import { useCreateNewThread } from '@/hooks/useCreateNewThread' import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import useDeleteThread from '@/hooks/useDeleteThread' import useDeleteThread from '@/hooks/useDeleteThread'
import useGetAllThreads from '@/hooks/useGetAllThreads'
import useGetAssistants from '@/hooks/useGetAssistants' import useGetAssistants from '@/hooks/useGetAssistants'
import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels' import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels'
import useSetActiveThread from '@/hooks/useSetActiveThread' import useSetActiveThread from '@/hooks/useSetActiveThread'
import useThreads from '@/hooks/useThreads'
import { displayDate } from '@/utils/datetime' import { displayDate } from '@/utils/datetime'
import { import {
@ -41,7 +42,7 @@ import {
export default function ThreadList() { export default function ThreadList() {
const threads = useAtomValue(threadsAtom) const threads = useAtomValue(threadsAtom)
const threadStates = useAtomValue(threadStatesAtom) const threadStates = useAtomValue(threadStatesAtom)
const { getAllThreads } = useGetAllThreads() const { getAllThreads } = useThreads()
const { assistants } = useGetAssistants() const { assistants } = useGetAssistants()
const { requestCreateNewThread } = useCreateNewThread() const { requestCreateNewThread } = useCreateNewThread()
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)

View File

@ -75,14 +75,12 @@ const ChatScreen = () => {
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [waitingToSendMessage, activeThreadId]) }, [waitingToSendMessage, activeThreadId])
const resizeTextArea = () => { useEffect(() => {
if (textareaRef.current) { if (textareaRef.current) {
textareaRef.current.style.height = '40px' textareaRef.current.style.height = '40px'
textareaRef.current.style.height = textareaRef.current.scrollHeight + 'px' textareaRef.current.style.height = textareaRef.current.scrollHeight + 'px'
} }
} }, [currentPrompt])
useEffect(resizeTextArea, [currentPrompt])
const onKeyDown = async (e: React.KeyboardEvent<HTMLTextAreaElement>) => { const onKeyDown = async (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter') { if (e.key === 'Enter') {

View File

@ -14,11 +14,10 @@ import ModalCancelDownload from '@/containers/ModalCancelDownload'
import { MainViewState } from '@/constants/screens' import { MainViewState } from '@/constants/screens'
// import { ModelPerformance, TagType } from '@/constants/tagType' import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import { useActiveModel } from '@/hooks/useActiveModel'
import useDownloadModel from '@/hooks/useDownloadModel' import useDownloadModel from '@/hooks/useDownloadModel'
import { useDownloadState } from '@/hooks/useDownloadState' import { useDownloadState } from '@/hooks/useDownloadState'
import { getAssistants } from '@/hooks/useGetAssistants'
import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels' import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels'
import { useMainViewState } from '@/hooks/useMainViewState' import { useMainViewState } from '@/hooks/useMainViewState'
@ -34,12 +33,7 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
const { downloadModel } = useDownloadModel() const { downloadModel } = useDownloadModel()
const { downloadedModels } = useGetDownloadedModels() const { downloadedModels } = useGetDownloadedModels()
const { modelDownloadStateAtom, downloadStates } = useDownloadState() const { modelDownloadStateAtom, downloadStates } = useDownloadState()
const { startModel } = useActiveModel() const { requestCreateNewThread } = useCreateNewThread()
// const [title, setTitle] = useState<string>('Recommended')
// const [performanceTag, setPerformanceTag] = useState<TagType>(
// ModelPerformance.PerformancePositive
// )
const downloadAtom = useMemo( const downloadAtom = useMemo(
() => atom((get) => get(modelDownloadStateAtom)[model.id]), () => atom((get) => get(modelDownloadStateAtom)[model.id]),
@ -59,10 +53,15 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
<Button onClick={() => onDownloadClick()}>Download</Button> <Button onClick={() => onDownloadClick()}>Download</Button>
) )
const onUseModelClick = () => { const onUseModelClick = useCallback(async () => {
startModel(model.id) const assistants = await getAssistants()
if (assistants.length === 0) {
alert('No assistant available')
return
}
await requestCreateNewThread(assistants[0], model)
setMainViewState(MainViewState.Thread) setMainViewState(MainViewState.Thread)
} }, [])
if (isDownloaded) { if (isDownloaded) {
downloadButton = ( downloadButton = (
@ -80,22 +79,6 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
downloadButton = <ModalCancelDownload model={model} /> downloadButton = <ModalCancelDownload model={model} />
} }
// const renderBadge = (performance: TagType) => {
// switch (performance) {
// case ModelPerformance.PerformancePositive:
// return <Badge themes="success">{title}</Badge>
// case ModelPerformance.PerformanceNeutral:
// return <Badge themes="secondary">{title}</Badge>
// case ModelPerformance.PerformanceNegative:
// return <Badge themes="danger">{title}</Badge>
// default:
// break
// }
// }
return ( return (
<div <div
className="cursor-pointer rounded-t-md bg-background/50" className="cursor-pointer rounded-t-md bg-background/50"