chore: handle chat functions (#5009)

This commit is contained in:
Louis 2025-05-18 20:41:10 +07:00 committed by GitHub
parent c1091ce812
commit 74c2c59c90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 201 additions and 162 deletions

View File

@ -3,7 +3,7 @@
import TextareaAutosize from 'react-textarea-autosize'
import { cn } from '@/lib/utils'
import { usePrompt } from '@/hooks/usePrompt'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useEffect, useRef, useState } from 'react'
import { Button } from '@/components/ui/button'
import { ArrowRight } from 'lucide-react'
import {
@ -20,28 +20,14 @@ import {
import { useTranslation } from 'react-i18next'
import { useGeneralSetting } from '@/hooks/useGeneralSetting'
import { useModelProvider } from '@/hooks/useModelProvider'
import {
emptyThreadContent,
extractToolCall,
newAssistantThreadContent,
newUserThreadContent,
postMessageProcessing,
sendCompletion,
startModel,
} from '@/lib/completion'
import { useThreads } from '@/hooks/useThreads'
import { defaultModel } from '@/lib/models'
import { useMessages } from '@/hooks/useMessages'
import { useRouter } from '@tanstack/react-router'
import { route } from '@/constants/routes'
import { useAppState } from '@/hooks/useAppState'
import { MovingBorder } from './MovingBorder'
import { MCPTool } from '@/types/completion'
import { listen } from '@tauri-apps/api/event'
import { SystemEvent } from '@/types/events'
import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'
import { getTools } from '@/services/mcp'
import { useChat } from '@/hooks/useChat'
type ChatInputProps = {
className?: string
@ -52,24 +38,14 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
const textareaRef = useRef<HTMLTextAreaElement>(null)
const [isFocused, setIsFocused] = useState(false)
const [rows, setRows] = useState(1)
const [tools, setTools] = useState<MCPTool[]>([])
const { streamingContent, updateTools } = useAppState()
const { prompt, setPrompt } = usePrompt()
const { t } = useTranslation()
const { spellCheckChatInput } = useGeneralSetting()
const maxRows = 10
const { getProviderByName, selectedModel, selectedProvider } =
useModelProvider()
const { getCurrentThread: retrieveThread, createThread } = useThreads()
const { streamingContent, updateStreamingContent, updateLoadingModel } =
useAppState()
const { addMessage } = useMessages()
const router = useRouter()
const provider = useMemo(() => {
return getProviderByName(selectedProvider)
}, [selectedProvider, getProviderByName])
const { selectedModel } = useModelProvider()
const { sendMessage } = useChat()
useEffect(() => {
const handleFocusIn = () => {
@ -94,20 +70,20 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
}, [])
useEffect(() => {
function updateTools() {
function setTools() {
getTools().then((data: MCPTool[]) => {
setTools(data)
updateTools(data)
})
}
updateTools()
setTools()
let unsubscribe = () => {}
listen(SystemEvent.MCP_UPDATE, updateTools).then((unsub) => {
listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
// Unsubscribe from the event when the component unmounts
unsubscribe = unsub
})
return unsubscribe
}, [])
}, [updateTools])
useEffect(() => {
if (textareaRef.current) {
@ -115,115 +91,6 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
}
}, [])
const getCurrentThread = useCallback(async () => {
let currentThread = retrieveThread()
if (!currentThread) {
currentThread = await createThread(
{
id: selectedModel?.id ?? defaultModel(selectedProvider),
provider: selectedProvider,
},
prompt
)
router.navigate({
to: route.threadsDetail,
params: { threadId: currentThread.id },
})
}
return currentThread
}, [
createThread,
prompt,
retrieveThread,
router,
selectedModel?.id,
selectedProvider,
])
const sendMessage = useCallback(async () => {
const activeThread = await getCurrentThread()
if (!activeThread || !provider) return
updateStreamingContent(emptyThreadContent)
addMessage(newUserThreadContent(activeThread.id, prompt))
setPrompt('')
try {
if (selectedModel?.id) {
updateLoadingModel(true)
await startModel(provider.provider, selectedModel.id).catch(
console.error
)
updateLoadingModel(false)
}
const builder = new CompletionMessagesBuilder()
// REMARK: Would it possible to not attach the entire message history to the request?
// TODO: If not amend messages history here
builder.addUserMessage(prompt)
let isCompleted = false
while (!isCompleted) {
const completion = await sendCompletion(
activeThread,
provider,
builder.getMessages(),
tools
)
if (!completion) throw new Error('No completion received')
let accumulatedText = ''
const currentCall: ChatCompletionMessageToolCall | null = null
const toolCalls: ChatCompletionMessageToolCall[] = []
for await (const part of completion) {
const delta = part.choices[0]?.delta?.content || ''
if (part.choices[0]?.delta?.tool_calls) {
extractToolCall(part, currentCall, toolCalls)
}
if (delta) {
accumulatedText += delta
// Create a new object each time to avoid reference issues
// Use a timeout to prevent React from batching updates too quickly
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText
)
updateStreamingContent(currentContent)
await new Promise((resolve) => setTimeout(resolve, 0))
}
}
// Create a final content object for adding to the thread
const finalContent = newAssistantThreadContent(
activeThread.id,
accumulatedText
)
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
const updatedMessage = await postMessageProcessing(
toolCalls,
builder,
finalContent
)
addMessage(updatedMessage ?? finalContent)
isCompleted = !toolCalls.length
}
} catch (error) {
console.error('Error sending message:', error)
}
updateStreamingContent(undefined)
}, [
getCurrentThread,
provider,
updateStreamingContent,
addMessage,
prompt,
setPrompt,
selectedModel,
tools,
updateLoadingModel,
])
return (
<div className="relative">
<div
@ -266,7 +133,7 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
if (e.key === 'Enter' && !e.shiftKey && prompt) {
e.preventDefault()
// Submit the message when Enter is pressed without Shift
sendMessage()
sendMessage(prompt)
// When Shift+Enter is pressed, a new line is added (default behavior)
}
}}
@ -351,7 +218,7 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
variant={!prompt ? null : 'default'}
size="icon"
disabled={!prompt}
onClick={sendMessage}
onClick={() => sendMessage(prompt)}
>
{streamingContent ? (
<span className="animate-spin h-4 w-4 border-2 border-current border-t-transparent rounded-full" />

View File

@ -7,7 +7,7 @@ import { Progress } from '@/components/ui/progress'
import { useDownloadStore } from '@/hooks/useDownloadStore'
import { abortDownload } from '@/services/models'
import { DownloadEvent, DownloadState, events } from '@janhq/core'
import { IconPlayerPauseFilled, IconX } from '@tabler/icons-react'
import { IconX } from '@tabler/icons-react'
import { useCallback, useEffect, useMemo } from 'react'
export function DownloadManagement() {

View File

@ -1,6 +1,6 @@
import { ThreadMessage } from '@janhq/core'
import { RenderMarkdown } from './RenderMarkdown'
import { Fragment, memo, useMemo, useState } from 'react'
import { Fragment, memo, useCallback, useMemo, useState } from 'react'
import {
IconCopy,
IconCopyCheck,
@ -13,6 +13,7 @@ import { cn } from '@/lib/utils'
import { useMessages } from '@/hooks/useMessages'
import ThinkingBlock from '@/containers/ThinkingBlock'
import ToolCallBlock from '@/containers/ToolCallBlock'
import { useChat } from '@/hooks/useChat'
const CopyButton = ({ text }: { text: string }) => {
const [copied, setCopied] = useState(false)
@ -25,7 +26,7 @@ const CopyButton = ({ text }: { text: string }) => {
return (
<button
className="flex items-center gap-1 hover:text-accent transition-colors group relative"
className="flex items-center gap-1 hover:text-accent transition-colors group relative cursor-pointer"
onClick={handleCopy}
>
{copied ? (
@ -80,7 +81,18 @@ export const ThreadContent = memo(
}
}, [text])
const { deleteMessage } = useMessages()
const { getMessages, deleteMessage } = useMessages()
const { sendMessage } = useChat()
const regenerate = useCallback(() => {
// Only regenerate assistant message is allowed
deleteMessage(item.thread_id, item.id)
const threadMessages = getMessages(item.thread_id)
const lastMessage = threadMessages[threadMessages.length - 1]
if (!lastMessage) return
deleteMessage(lastMessage.thread_id, lastMessage.id)
sendMessage(lastMessage.content?.[0]?.text?.value || '')
}, [deleteMessage, getMessages, item, sendMessage])
const isToolCalls =
item.metadata &&
@ -170,17 +182,17 @@ export const ThreadContent = memo(
Delete
</span>
</button>
<button
className="flex items-center gap-1 hover:text-accent transition-colors cursor-pointer group relative"
onClick={() => {
console.log('Regenerate clicked')
}}
>
<IconRefresh size={16} />
<span className="opacity-0 w-0 overflow-hidden whitespace-nowrap group-hover:w-auto group-hover:opacity-100 transition-all duration-300 ease-in-out">
Regenerate
</span>
</button>
{item.isLastMessage && (
<button
className="flex items-center gap-1 hover:text-accent transition-colors cursor-pointer group relative"
onClick={regenerate}
>
<IconRefresh size={16} />
<span className="opacity-0 w-0 overflow-hidden whitespace-nowrap group-hover:w-auto group-hover:opacity-100 transition-all duration-300 ease-in-out">
Regenerate
</span>
</button>
)}
</div>
</div>
)}

View File

@ -1,20 +1,27 @@
import { create } from 'zustand'
import { ThreadMessage } from '@janhq/core'
import { MCPTool } from '@/types/completion'
type AppState = {
streamingContent?: ThreadMessage
loadingModel?: boolean
tools: MCPTool[]
updateStreamingContent: (content: ThreadMessage | undefined) => void
updateLoadingModel: (loading: boolean) => void
updateTools: (tools: MCPTool[]) => void
}
export const useAppState = create<AppState>()((set) => ({
streamingContent: undefined,
loadingModel: false,
tools: [],
updateStreamingContent: (content) => {
set({ streamingContent: content })
},
updateLoadingModel: (loading) => {
set({ loadingModel: loading })
},
updateTools: (tools) => {
set({ tools })
},
}))

View File

@ -0,0 +1,149 @@
import { useCallback, useMemo } from 'react'
import { usePrompt } from './usePrompt'
import { useModelProvider } from './useModelProvider'
import { useThreads } from './useThreads'
import { useAppState } from './useAppState'
import { useMessages } from './useMessages'
import { useRouter } from '@tanstack/react-router'
import { defaultModel } from '@/lib/models'
import { route } from '@/constants/routes'
import {
emptyThreadContent,
extractToolCall,
newAssistantThreadContent,
newUserThreadContent,
postMessageProcessing,
sendCompletion,
startModel,
} from '@/lib/completion'
import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'
export const useChat = () => {
const { prompt, setPrompt } = usePrompt()
const { tools } = useAppState()
const { getProviderByName, selectedModel, selectedProvider } =
useModelProvider()
const { getCurrentThread: retrieveThread, createThread } = useThreads()
const { updateStreamingContent, updateLoadingModel } = useAppState()
const { addMessage } = useMessages()
const router = useRouter()
const provider = useMemo(() => {
return getProviderByName(selectedProvider)
}, [selectedProvider, getProviderByName])
const getCurrentThread = useCallback(async () => {
let currentThread = retrieveThread()
if (!currentThread) {
currentThread = await createThread(
{
id: selectedModel?.id ?? defaultModel(selectedProvider),
provider: selectedProvider,
},
prompt
)
router.navigate({
to: route.threadsDetail,
params: { threadId: currentThread.id },
})
}
return currentThread
}, [
createThread,
prompt,
retrieveThread,
router,
selectedModel?.id,
selectedProvider,
])
const sendMessage = useCallback(
async (message: string) => {
const activeThread = await getCurrentThread()
if (!activeThread || !provider) return
updateStreamingContent(emptyThreadContent)
addMessage(newUserThreadContent(activeThread.id, message))
setPrompt('')
try {
if (selectedModel?.id) {
updateLoadingModel(true)
await startModel(provider.provider, selectedModel.id).catch(
console.error
)
updateLoadingModel(false)
}
const builder = new CompletionMessagesBuilder()
// REMARK: Would it possible to not attach the entire message history to the request?
// TODO: If not amend messages history here
builder.addUserMessage(message)
let isCompleted = false
while (!isCompleted) {
const completion = await sendCompletion(
activeThread,
provider,
builder.getMessages(),
tools
)
if (!completion) throw new Error('No completion received')
let accumulatedText = ''
const currentCall: ChatCompletionMessageToolCall | null = null
const toolCalls: ChatCompletionMessageToolCall[] = []
for await (const part of completion) {
const delta = part.choices[0]?.delta?.content || ''
if (part.choices[0]?.delta?.tool_calls) {
extractToolCall(part, currentCall, toolCalls)
}
if (delta) {
accumulatedText += delta
// Create a new object each time to avoid reference issues
// Use a timeout to prevent React from batching updates too quickly
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText
)
updateStreamingContent(currentContent)
await new Promise((resolve) => setTimeout(resolve, 0))
}
}
// Create a final content object for adding to the thread
const finalContent = newAssistantThreadContent(
activeThread.id,
accumulatedText
)
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
const updatedMessage = await postMessageProcessing(
toolCalls,
builder,
finalContent
)
addMessage(updatedMessage ?? finalContent)
isCompleted = !toolCalls.length
}
} catch (error) {
console.error('Error sending message:', error)
}
updateStreamingContent(undefined)
},
[
getCurrentThread,
provider,
updateStreamingContent,
addMessage,
setPrompt,
selectedModel,
tools,
updateLoadingModel,
]
)
return { sendMessage }
}

View File

@ -9,6 +9,7 @@ import {
type MessageState = {
messages: Record<string, ThreadMessage[]>
getMessages: (threadId: string) => ThreadMessage[]
setMessages: (threadId: string, messages: ThreadMessage[]) => void
addMessage: (message: ThreadMessage) => void
deleteMessage: (threadId: string, messageId: string) => void
@ -16,8 +17,11 @@ type MessageState = {
export const useMessages = create<MessageState>()(
persist(
(set) => ({
(set, get) => ({
messages: {},
getMessages: (threadId) => {
return get().messages[threadId] || []
},
setMessages: (threadId, messages) => {
set((state) => ({
messages: {