chore: handle chat functions (#5009)

This commit is contained in:
Louis 2025-05-18 20:41:10 +07:00 committed by GitHub
parent c1091ce812
commit 74c2c59c90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 201 additions and 162 deletions

View File

@ -3,7 +3,7 @@
import TextareaAutosize from 'react-textarea-autosize' import TextareaAutosize from 'react-textarea-autosize'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
import { usePrompt } from '@/hooks/usePrompt' import { usePrompt } from '@/hooks/usePrompt'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useEffect, useRef, useState } from 'react'
import { Button } from '@/components/ui/button' import { Button } from '@/components/ui/button'
import { ArrowRight } from 'lucide-react' import { ArrowRight } from 'lucide-react'
import { import {
@ -20,28 +20,14 @@ import {
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useGeneralSetting } from '@/hooks/useGeneralSetting' import { useGeneralSetting } from '@/hooks/useGeneralSetting'
import { useModelProvider } from '@/hooks/useModelProvider' import { useModelProvider } from '@/hooks/useModelProvider'
import {
emptyThreadContent,
extractToolCall,
newAssistantThreadContent,
newUserThreadContent,
postMessageProcessing,
sendCompletion,
startModel,
} from '@/lib/completion'
import { useThreads } from '@/hooks/useThreads'
import { defaultModel } from '@/lib/models'
import { useMessages } from '@/hooks/useMessages'
import { useRouter } from '@tanstack/react-router'
import { route } from '@/constants/routes'
import { useAppState } from '@/hooks/useAppState' import { useAppState } from '@/hooks/useAppState'
import { MovingBorder } from './MovingBorder' import { MovingBorder } from './MovingBorder'
import { MCPTool } from '@/types/completion' import { MCPTool } from '@/types/completion'
import { listen } from '@tauri-apps/api/event' import { listen } from '@tauri-apps/api/event'
import { SystemEvent } from '@/types/events' import { SystemEvent } from '@/types/events'
import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'
import { getTools } from '@/services/mcp' import { getTools } from '@/services/mcp'
import { useChat } from '@/hooks/useChat'
type ChatInputProps = { type ChatInputProps = {
className?: string className?: string
@ -52,24 +38,14 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
const textareaRef = useRef<HTMLTextAreaElement>(null) const textareaRef = useRef<HTMLTextAreaElement>(null)
const [isFocused, setIsFocused] = useState(false) const [isFocused, setIsFocused] = useState(false)
const [rows, setRows] = useState(1) const [rows, setRows] = useState(1)
const [tools, setTools] = useState<MCPTool[]>([]) const { streamingContent, updateTools } = useAppState()
const { prompt, setPrompt } = usePrompt() const { prompt, setPrompt } = usePrompt()
const { t } = useTranslation() const { t } = useTranslation()
const { spellCheckChatInput } = useGeneralSetting() const { spellCheckChatInput } = useGeneralSetting()
const maxRows = 10 const maxRows = 10
const { getProviderByName, selectedModel, selectedProvider } = const { selectedModel } = useModelProvider()
useModelProvider() const { sendMessage } = useChat()
const { getCurrentThread: retrieveThread, createThread } = useThreads()
const { streamingContent, updateStreamingContent, updateLoadingModel } =
useAppState()
const { addMessage } = useMessages()
const router = useRouter()
const provider = useMemo(() => {
return getProviderByName(selectedProvider)
}, [selectedProvider, getProviderByName])
useEffect(() => { useEffect(() => {
const handleFocusIn = () => { const handleFocusIn = () => {
@ -94,20 +70,20 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
}, []) }, [])
useEffect(() => { useEffect(() => {
function updateTools() { function setTools() {
getTools().then((data: MCPTool[]) => { getTools().then((data: MCPTool[]) => {
setTools(data) updateTools(data)
}) })
} }
updateTools() setTools()
let unsubscribe = () => {} let unsubscribe = () => {}
listen(SystemEvent.MCP_UPDATE, updateTools).then((unsub) => { listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
// Unsubscribe from the event when the component unmounts // Unsubscribe from the event when the component unmounts
unsubscribe = unsub unsubscribe = unsub
}) })
return unsubscribe return unsubscribe
}, []) }, [updateTools])
useEffect(() => { useEffect(() => {
if (textareaRef.current) { if (textareaRef.current) {
@ -115,115 +91,6 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
} }
}, []) }, [])
const getCurrentThread = useCallback(async () => {
let currentThread = retrieveThread()
if (!currentThread) {
currentThread = await createThread(
{
id: selectedModel?.id ?? defaultModel(selectedProvider),
provider: selectedProvider,
},
prompt
)
router.navigate({
to: route.threadsDetail,
params: { threadId: currentThread.id },
})
}
return currentThread
}, [
createThread,
prompt,
retrieveThread,
router,
selectedModel?.id,
selectedProvider,
])
const sendMessage = useCallback(async () => {
const activeThread = await getCurrentThread()
if (!activeThread || !provider) return
updateStreamingContent(emptyThreadContent)
addMessage(newUserThreadContent(activeThread.id, prompt))
setPrompt('')
try {
if (selectedModel?.id) {
updateLoadingModel(true)
await startModel(provider.provider, selectedModel.id).catch(
console.error
)
updateLoadingModel(false)
}
const builder = new CompletionMessagesBuilder()
// REMARK: Would it possible to not attach the entire message history to the request?
// TODO: If not amend messages history here
builder.addUserMessage(prompt)
let isCompleted = false
while (!isCompleted) {
const completion = await sendCompletion(
activeThread,
provider,
builder.getMessages(),
tools
)
if (!completion) throw new Error('No completion received')
let accumulatedText = ''
const currentCall: ChatCompletionMessageToolCall | null = null
const toolCalls: ChatCompletionMessageToolCall[] = []
for await (const part of completion) {
const delta = part.choices[0]?.delta?.content || ''
if (part.choices[0]?.delta?.tool_calls) {
extractToolCall(part, currentCall, toolCalls)
}
if (delta) {
accumulatedText += delta
// Create a new object each time to avoid reference issues
// Use a timeout to prevent React from batching updates too quickly
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText
)
updateStreamingContent(currentContent)
await new Promise((resolve) => setTimeout(resolve, 0))
}
}
// Create a final content object for adding to the thread
const finalContent = newAssistantThreadContent(
activeThread.id,
accumulatedText
)
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
const updatedMessage = await postMessageProcessing(
toolCalls,
builder,
finalContent
)
addMessage(updatedMessage ?? finalContent)
isCompleted = !toolCalls.length
}
} catch (error) {
console.error('Error sending message:', error)
}
updateStreamingContent(undefined)
}, [
getCurrentThread,
provider,
updateStreamingContent,
addMessage,
prompt,
setPrompt,
selectedModel,
tools,
updateLoadingModel,
])
return ( return (
<div className="relative"> <div className="relative">
<div <div
@ -266,7 +133,7 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
if (e.key === 'Enter' && !e.shiftKey && prompt) { if (e.key === 'Enter' && !e.shiftKey && prompt) {
e.preventDefault() e.preventDefault()
// Submit the message when Enter is pressed without Shift // Submit the message when Enter is pressed without Shift
sendMessage() sendMessage(prompt)
// When Shift+Enter is pressed, a new line is added (default behavior) // When Shift+Enter is pressed, a new line is added (default behavior)
} }
}} }}
@ -351,7 +218,7 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
variant={!prompt ? null : 'default'} variant={!prompt ? null : 'default'}
size="icon" size="icon"
disabled={!prompt} disabled={!prompt}
onClick={sendMessage} onClick={() => sendMessage(prompt)}
> >
{streamingContent ? ( {streamingContent ? (
<span className="animate-spin h-4 w-4 border-2 border-current border-t-transparent rounded-full" /> <span className="animate-spin h-4 w-4 border-2 border-current border-t-transparent rounded-full" />

View File

@ -7,7 +7,7 @@ import { Progress } from '@/components/ui/progress'
import { useDownloadStore } from '@/hooks/useDownloadStore' import { useDownloadStore } from '@/hooks/useDownloadStore'
import { abortDownload } from '@/services/models' import { abortDownload } from '@/services/models'
import { DownloadEvent, DownloadState, events } from '@janhq/core' import { DownloadEvent, DownloadState, events } from '@janhq/core'
import { IconPlayerPauseFilled, IconX } from '@tabler/icons-react' import { IconX } from '@tabler/icons-react'
import { useCallback, useEffect, useMemo } from 'react' import { useCallback, useEffect, useMemo } from 'react'
export function DownloadManagement() { export function DownloadManagement() {

View File

@ -1,6 +1,6 @@
import { ThreadMessage } from '@janhq/core' import { ThreadMessage } from '@janhq/core'
import { RenderMarkdown } from './RenderMarkdown' import { RenderMarkdown } from './RenderMarkdown'
import { Fragment, memo, useMemo, useState } from 'react' import { Fragment, memo, useCallback, useMemo, useState } from 'react'
import { import {
IconCopy, IconCopy,
IconCopyCheck, IconCopyCheck,
@ -13,6 +13,7 @@ import { cn } from '@/lib/utils'
import { useMessages } from '@/hooks/useMessages' import { useMessages } from '@/hooks/useMessages'
import ThinkingBlock from '@/containers/ThinkingBlock' import ThinkingBlock from '@/containers/ThinkingBlock'
import ToolCallBlock from '@/containers/ToolCallBlock' import ToolCallBlock from '@/containers/ToolCallBlock'
import { useChat } from '@/hooks/useChat'
const CopyButton = ({ text }: { text: string }) => { const CopyButton = ({ text }: { text: string }) => {
const [copied, setCopied] = useState(false) const [copied, setCopied] = useState(false)
@ -25,7 +26,7 @@ const CopyButton = ({ text }: { text: string }) => {
return ( return (
<button <button
className="flex items-center gap-1 hover:text-accent transition-colors group relative" className="flex items-center gap-1 hover:text-accent transition-colors group relative cursor-pointer"
onClick={handleCopy} onClick={handleCopy}
> >
{copied ? ( {copied ? (
@ -80,7 +81,18 @@ export const ThreadContent = memo(
} }
}, [text]) }, [text])
const { deleteMessage } = useMessages() const { getMessages, deleteMessage } = useMessages()
const { sendMessage } = useChat()
const regenerate = useCallback(() => {
// Only regenerate assistant message is allowed
deleteMessage(item.thread_id, item.id)
const threadMessages = getMessages(item.thread_id)
const lastMessage = threadMessages[threadMessages.length - 1]
if (!lastMessage) return
deleteMessage(lastMessage.thread_id, lastMessage.id)
sendMessage(lastMessage.content?.[0]?.text?.value || '')
}, [deleteMessage, getMessages, item, sendMessage])
const isToolCalls = const isToolCalls =
item.metadata && item.metadata &&
@ -170,17 +182,17 @@ export const ThreadContent = memo(
Delete Delete
</span> </span>
</button> </button>
{item.isLastMessage && (
<button <button
className="flex items-center gap-1 hover:text-accent transition-colors cursor-pointer group relative" className="flex items-center gap-1 hover:text-accent transition-colors cursor-pointer group relative"
onClick={() => { onClick={regenerate}
console.log('Regenerate clicked')
}}
> >
<IconRefresh size={16} /> <IconRefresh size={16} />
<span className="opacity-0 w-0 overflow-hidden whitespace-nowrap group-hover:w-auto group-hover:opacity-100 transition-all duration-300 ease-in-out"> <span className="opacity-0 w-0 overflow-hidden whitespace-nowrap group-hover:w-auto group-hover:opacity-100 transition-all duration-300 ease-in-out">
Regenerate Regenerate
</span> </span>
</button> </button>
)}
</div> </div>
</div> </div>
)} )}

View File

@ -1,20 +1,27 @@
import { create } from 'zustand' import { create } from 'zustand'
import { ThreadMessage } from '@janhq/core' import { ThreadMessage } from '@janhq/core'
import { MCPTool } from '@/types/completion'
type AppState = { type AppState = {
streamingContent?: ThreadMessage streamingContent?: ThreadMessage
loadingModel?: boolean loadingModel?: boolean
tools: MCPTool[]
updateStreamingContent: (content: ThreadMessage | undefined) => void updateStreamingContent: (content: ThreadMessage | undefined) => void
updateLoadingModel: (loading: boolean) => void updateLoadingModel: (loading: boolean) => void
updateTools: (tools: MCPTool[]) => void
} }
export const useAppState = create<AppState>()((set) => ({ export const useAppState = create<AppState>()((set) => ({
streamingContent: undefined, streamingContent: undefined,
loadingModel: false, loadingModel: false,
tools: [],
updateStreamingContent: (content) => { updateStreamingContent: (content) => {
set({ streamingContent: content }) set({ streamingContent: content })
}, },
updateLoadingModel: (loading) => { updateLoadingModel: (loading) => {
set({ loadingModel: loading }) set({ loadingModel: loading })
}, },
updateTools: (tools) => {
set({ tools })
},
})) }))

View File

@ -0,0 +1,149 @@
import { useCallback, useMemo } from 'react'
import { usePrompt } from './usePrompt'
import { useModelProvider } from './useModelProvider'
import { useThreads } from './useThreads'
import { useAppState } from './useAppState'
import { useMessages } from './useMessages'
import { useRouter } from '@tanstack/react-router'
import { defaultModel } from '@/lib/models'
import { route } from '@/constants/routes'
import {
emptyThreadContent,
extractToolCall,
newAssistantThreadContent,
newUserThreadContent,
postMessageProcessing,
sendCompletion,
startModel,
} from '@/lib/completion'
import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'
export const useChat = () => {
const { prompt, setPrompt } = usePrompt()
const { tools } = useAppState()
const { getProviderByName, selectedModel, selectedProvider } =
useModelProvider()
const { getCurrentThread: retrieveThread, createThread } = useThreads()
const { updateStreamingContent, updateLoadingModel } = useAppState()
const { addMessage } = useMessages()
const router = useRouter()
const provider = useMemo(() => {
return getProviderByName(selectedProvider)
}, [selectedProvider, getProviderByName])
const getCurrentThread = useCallback(async () => {
let currentThread = retrieveThread()
if (!currentThread) {
currentThread = await createThread(
{
id: selectedModel?.id ?? defaultModel(selectedProvider),
provider: selectedProvider,
},
prompt
)
router.navigate({
to: route.threadsDetail,
params: { threadId: currentThread.id },
})
}
return currentThread
}, [
createThread,
prompt,
retrieveThread,
router,
selectedModel?.id,
selectedProvider,
])
const sendMessage = useCallback(
async (message: string) => {
const activeThread = await getCurrentThread()
if (!activeThread || !provider) return
updateStreamingContent(emptyThreadContent)
addMessage(newUserThreadContent(activeThread.id, message))
setPrompt('')
try {
if (selectedModel?.id) {
updateLoadingModel(true)
await startModel(provider.provider, selectedModel.id).catch(
console.error
)
updateLoadingModel(false)
}
const builder = new CompletionMessagesBuilder()
// REMARK: Would it possible to not attach the entire message history to the request?
// TODO: If not amend messages history here
builder.addUserMessage(message)
let isCompleted = false
while (!isCompleted) {
const completion = await sendCompletion(
activeThread,
provider,
builder.getMessages(),
tools
)
if (!completion) throw new Error('No completion received')
let accumulatedText = ''
const currentCall: ChatCompletionMessageToolCall | null = null
const toolCalls: ChatCompletionMessageToolCall[] = []
for await (const part of completion) {
const delta = part.choices[0]?.delta?.content || ''
if (part.choices[0]?.delta?.tool_calls) {
extractToolCall(part, currentCall, toolCalls)
}
if (delta) {
accumulatedText += delta
// Create a new object each time to avoid reference issues
// Use a timeout to prevent React from batching updates too quickly
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText
)
updateStreamingContent(currentContent)
await new Promise((resolve) => setTimeout(resolve, 0))
}
}
// Create a final content object for adding to the thread
const finalContent = newAssistantThreadContent(
activeThread.id,
accumulatedText
)
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
const updatedMessage = await postMessageProcessing(
toolCalls,
builder,
finalContent
)
addMessage(updatedMessage ?? finalContent)
isCompleted = !toolCalls.length
}
} catch (error) {
console.error('Error sending message:', error)
}
updateStreamingContent(undefined)
},
[
getCurrentThread,
provider,
updateStreamingContent,
addMessage,
setPrompt,
selectedModel,
tools,
updateLoadingModel,
]
)
return { sendMessage }
}

View File

@ -9,6 +9,7 @@ import {
type MessageState = { type MessageState = {
messages: Record<string, ThreadMessage[]> messages: Record<string, ThreadMessage[]>
getMessages: (threadId: string) => ThreadMessage[]
setMessages: (threadId: string, messages: ThreadMessage[]) => void setMessages: (threadId: string, messages: ThreadMessage[]) => void
addMessage: (message: ThreadMessage) => void addMessage: (message: ThreadMessage) => void
deleteMessage: (threadId: string, messageId: string) => void deleteMessage: (threadId: string, messageId: string) => void
@ -16,8 +17,11 @@ type MessageState = {
export const useMessages = create<MessageState>()( export const useMessages = create<MessageState>()(
persist( persist(
(set) => ({ (set, get) => ({
messages: {}, messages: {},
getMessages: (threadId) => {
return get().messages[threadId] || []
},
setMessages: (threadId, messages) => { setMessages: (threadId, messages) => {
set((state) => ({ set((state) => ({
messages: { messages: {