chore: handle chat functions (#5009)
This commit is contained in:
parent
c1091ce812
commit
74c2c59c90
@ -3,7 +3,7 @@
|
||||
import TextareaAutosize from 'react-textarea-autosize'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { usePrompt } from '@/hooks/usePrompt'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { ArrowRight } from 'lucide-react'
|
||||
import {
|
||||
@ -20,28 +20,14 @@ import {
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useGeneralSetting } from '@/hooks/useGeneralSetting'
|
||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||
import {
|
||||
emptyThreadContent,
|
||||
extractToolCall,
|
||||
newAssistantThreadContent,
|
||||
newUserThreadContent,
|
||||
postMessageProcessing,
|
||||
sendCompletion,
|
||||
startModel,
|
||||
} from '@/lib/completion'
|
||||
import { useThreads } from '@/hooks/useThreads'
|
||||
import { defaultModel } from '@/lib/models'
|
||||
import { useMessages } from '@/hooks/useMessages'
|
||||
import { useRouter } from '@tanstack/react-router'
|
||||
import { route } from '@/constants/routes'
|
||||
|
||||
import { useAppState } from '@/hooks/useAppState'
|
||||
import { MovingBorder } from './MovingBorder'
|
||||
import { MCPTool } from '@/types/completion'
|
||||
import { listen } from '@tauri-apps/api/event'
|
||||
import { SystemEvent } from '@/types/events'
|
||||
import { CompletionMessagesBuilder } from '@/lib/messages'
|
||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
import { getTools } from '@/services/mcp'
|
||||
import { useChat } from '@/hooks/useChat'
|
||||
|
||||
type ChatInputProps = {
|
||||
className?: string
|
||||
@ -52,24 +38,14 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null)
|
||||
const [isFocused, setIsFocused] = useState(false)
|
||||
const [rows, setRows] = useState(1)
|
||||
const [tools, setTools] = useState<MCPTool[]>([])
|
||||
const { streamingContent, updateTools } = useAppState()
|
||||
const { prompt, setPrompt } = usePrompt()
|
||||
const { t } = useTranslation()
|
||||
const { spellCheckChatInput } = useGeneralSetting()
|
||||
const maxRows = 10
|
||||
|
||||
const { getProviderByName, selectedModel, selectedProvider } =
|
||||
useModelProvider()
|
||||
|
||||
const { getCurrentThread: retrieveThread, createThread } = useThreads()
|
||||
const { streamingContent, updateStreamingContent, updateLoadingModel } =
|
||||
useAppState()
|
||||
const { addMessage } = useMessages()
|
||||
const router = useRouter()
|
||||
|
||||
const provider = useMemo(() => {
|
||||
return getProviderByName(selectedProvider)
|
||||
}, [selectedProvider, getProviderByName])
|
||||
const { selectedModel } = useModelProvider()
|
||||
const { sendMessage } = useChat()
|
||||
|
||||
useEffect(() => {
|
||||
const handleFocusIn = () => {
|
||||
@ -94,20 +70,20 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
function updateTools() {
|
||||
function setTools() {
|
||||
getTools().then((data: MCPTool[]) => {
|
||||
setTools(data)
|
||||
updateTools(data)
|
||||
})
|
||||
}
|
||||
updateTools()
|
||||
setTools()
|
||||
|
||||
let unsubscribe = () => {}
|
||||
listen(SystemEvent.MCP_UPDATE, updateTools).then((unsub) => {
|
||||
listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
|
||||
// Unsubscribe from the event when the component unmounts
|
||||
unsubscribe = unsub
|
||||
})
|
||||
return unsubscribe
|
||||
}, [])
|
||||
}, [updateTools])
|
||||
|
||||
useEffect(() => {
|
||||
if (textareaRef.current) {
|
||||
@ -115,115 +91,6 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
|
||||
}
|
||||
}, [])
|
||||
|
||||
const getCurrentThread = useCallback(async () => {
|
||||
let currentThread = retrieveThread()
|
||||
if (!currentThread) {
|
||||
currentThread = await createThread(
|
||||
{
|
||||
id: selectedModel?.id ?? defaultModel(selectedProvider),
|
||||
provider: selectedProvider,
|
||||
},
|
||||
prompt
|
||||
)
|
||||
router.navigate({
|
||||
to: route.threadsDetail,
|
||||
params: { threadId: currentThread.id },
|
||||
})
|
||||
}
|
||||
return currentThread
|
||||
}, [
|
||||
createThread,
|
||||
prompt,
|
||||
retrieveThread,
|
||||
router,
|
||||
selectedModel?.id,
|
||||
selectedProvider,
|
||||
])
|
||||
|
||||
const sendMessage = useCallback(async () => {
|
||||
const activeThread = await getCurrentThread()
|
||||
|
||||
if (!activeThread || !provider) return
|
||||
|
||||
updateStreamingContent(emptyThreadContent)
|
||||
addMessage(newUserThreadContent(activeThread.id, prompt))
|
||||
setPrompt('')
|
||||
try {
|
||||
if (selectedModel?.id) {
|
||||
updateLoadingModel(true)
|
||||
await startModel(provider.provider, selectedModel.id).catch(
|
||||
console.error
|
||||
)
|
||||
updateLoadingModel(false)
|
||||
}
|
||||
|
||||
const builder = new CompletionMessagesBuilder()
|
||||
// REMARK: Would it possible to not attach the entire message history to the request?
|
||||
// TODO: If not amend messages history here
|
||||
builder.addUserMessage(prompt)
|
||||
|
||||
let isCompleted = false
|
||||
|
||||
while (!isCompleted) {
|
||||
const completion = await sendCompletion(
|
||||
activeThread,
|
||||
provider,
|
||||
builder.getMessages(),
|
||||
tools
|
||||
)
|
||||
|
||||
if (!completion) throw new Error('No completion received')
|
||||
let accumulatedText = ''
|
||||
const currentCall: ChatCompletionMessageToolCall | null = null
|
||||
const toolCalls: ChatCompletionMessageToolCall[] = []
|
||||
for await (const part of completion) {
|
||||
const delta = part.choices[0]?.delta?.content || ''
|
||||
if (part.choices[0]?.delta?.tool_calls) {
|
||||
extractToolCall(part, currentCall, toolCalls)
|
||||
}
|
||||
if (delta) {
|
||||
accumulatedText += delta
|
||||
// Create a new object each time to avoid reference issues
|
||||
// Use a timeout to prevent React from batching updates too quickly
|
||||
const currentContent = newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
accumulatedText
|
||||
)
|
||||
updateStreamingContent(currentContent)
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
}
|
||||
}
|
||||
// Create a final content object for adding to the thread
|
||||
const finalContent = newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
accumulatedText
|
||||
)
|
||||
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
|
||||
const updatedMessage = await postMessageProcessing(
|
||||
toolCalls,
|
||||
builder,
|
||||
finalContent
|
||||
)
|
||||
addMessage(updatedMessage ?? finalContent)
|
||||
|
||||
isCompleted = !toolCalls.length
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error sending message:', error)
|
||||
}
|
||||
updateStreamingContent(undefined)
|
||||
}, [
|
||||
getCurrentThread,
|
||||
provider,
|
||||
updateStreamingContent,
|
||||
addMessage,
|
||||
prompt,
|
||||
setPrompt,
|
||||
selectedModel,
|
||||
tools,
|
||||
updateLoadingModel,
|
||||
])
|
||||
|
||||
return (
|
||||
<div className="relative">
|
||||
<div
|
||||
@ -266,7 +133,7 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
|
||||
if (e.key === 'Enter' && !e.shiftKey && prompt) {
|
||||
e.preventDefault()
|
||||
// Submit the message when Enter is pressed without Shift
|
||||
sendMessage()
|
||||
sendMessage(prompt)
|
||||
// When Shift+Enter is pressed, a new line is added (default behavior)
|
||||
}
|
||||
}}
|
||||
@ -351,7 +218,7 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => {
|
||||
variant={!prompt ? null : 'default'}
|
||||
size="icon"
|
||||
disabled={!prompt}
|
||||
onClick={sendMessage}
|
||||
onClick={() => sendMessage(prompt)}
|
||||
>
|
||||
{streamingContent ? (
|
||||
<span className="animate-spin h-4 w-4 border-2 border-current border-t-transparent rounded-full" />
|
||||
|
||||
@ -7,7 +7,7 @@ import { Progress } from '@/components/ui/progress'
|
||||
import { useDownloadStore } from '@/hooks/useDownloadStore'
|
||||
import { abortDownload } from '@/services/models'
|
||||
import { DownloadEvent, DownloadState, events } from '@janhq/core'
|
||||
import { IconPlayerPauseFilled, IconX } from '@tabler/icons-react'
|
||||
import { IconX } from '@tabler/icons-react'
|
||||
import { useCallback, useEffect, useMemo } from 'react'
|
||||
|
||||
export function DownloadManagement() {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { ThreadMessage } from '@janhq/core'
|
||||
import { RenderMarkdown } from './RenderMarkdown'
|
||||
import { Fragment, memo, useMemo, useState } from 'react'
|
||||
import { Fragment, memo, useCallback, useMemo, useState } from 'react'
|
||||
import {
|
||||
IconCopy,
|
||||
IconCopyCheck,
|
||||
@ -13,6 +13,7 @@ import { cn } from '@/lib/utils'
|
||||
import { useMessages } from '@/hooks/useMessages'
|
||||
import ThinkingBlock from '@/containers/ThinkingBlock'
|
||||
import ToolCallBlock from '@/containers/ToolCallBlock'
|
||||
import { useChat } from '@/hooks/useChat'
|
||||
|
||||
const CopyButton = ({ text }: { text: string }) => {
|
||||
const [copied, setCopied] = useState(false)
|
||||
@ -25,7 +26,7 @@ const CopyButton = ({ text }: { text: string }) => {
|
||||
|
||||
return (
|
||||
<button
|
||||
className="flex items-center gap-1 hover:text-accent transition-colors group relative"
|
||||
className="flex items-center gap-1 hover:text-accent transition-colors group relative cursor-pointer"
|
||||
onClick={handleCopy}
|
||||
>
|
||||
{copied ? (
|
||||
@ -80,7 +81,18 @@ export const ThreadContent = memo(
|
||||
}
|
||||
}, [text])
|
||||
|
||||
const { deleteMessage } = useMessages()
|
||||
const { getMessages, deleteMessage } = useMessages()
|
||||
const { sendMessage } = useChat()
|
||||
|
||||
const regenerate = useCallback(() => {
|
||||
// Only regenerate assistant message is allowed
|
||||
deleteMessage(item.thread_id, item.id)
|
||||
const threadMessages = getMessages(item.thread_id)
|
||||
const lastMessage = threadMessages[threadMessages.length - 1]
|
||||
if (!lastMessage) return
|
||||
deleteMessage(lastMessage.thread_id, lastMessage.id)
|
||||
sendMessage(lastMessage.content?.[0]?.text?.value || '')
|
||||
}, [deleteMessage, getMessages, item, sendMessage])
|
||||
|
||||
const isToolCalls =
|
||||
item.metadata &&
|
||||
@ -170,17 +182,17 @@ export const ThreadContent = memo(
|
||||
Delete
|
||||
</span>
|
||||
</button>
|
||||
<button
|
||||
className="flex items-center gap-1 hover:text-accent transition-colors cursor-pointer group relative"
|
||||
onClick={() => {
|
||||
console.log('Regenerate clicked')
|
||||
}}
|
||||
>
|
||||
<IconRefresh size={16} />
|
||||
<span className="opacity-0 w-0 overflow-hidden whitespace-nowrap group-hover:w-auto group-hover:opacity-100 transition-all duration-300 ease-in-out">
|
||||
Regenerate
|
||||
</span>
|
||||
</button>
|
||||
{item.isLastMessage && (
|
||||
<button
|
||||
className="flex items-center gap-1 hover:text-accent transition-colors cursor-pointer group relative"
|
||||
onClick={regenerate}
|
||||
>
|
||||
<IconRefresh size={16} />
|
||||
<span className="opacity-0 w-0 overflow-hidden whitespace-nowrap group-hover:w-auto group-hover:opacity-100 transition-all duration-300 ease-in-out">
|
||||
Regenerate
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@ -1,20 +1,27 @@
|
||||
import { create } from 'zustand'
|
||||
import { ThreadMessage } from '@janhq/core'
|
||||
import { MCPTool } from '@/types/completion'
|
||||
|
||||
type AppState = {
|
||||
streamingContent?: ThreadMessage
|
||||
loadingModel?: boolean
|
||||
tools: MCPTool[]
|
||||
updateStreamingContent: (content: ThreadMessage | undefined) => void
|
||||
updateLoadingModel: (loading: boolean) => void
|
||||
updateTools: (tools: MCPTool[]) => void
|
||||
}
|
||||
|
||||
export const useAppState = create<AppState>()((set) => ({
|
||||
streamingContent: undefined,
|
||||
loadingModel: false,
|
||||
tools: [],
|
||||
updateStreamingContent: (content) => {
|
||||
set({ streamingContent: content })
|
||||
},
|
||||
updateLoadingModel: (loading) => {
|
||||
set({ loadingModel: loading })
|
||||
},
|
||||
updateTools: (tools) => {
|
||||
set({ tools })
|
||||
},
|
||||
}))
|
||||
|
||||
149
web-app/src/hooks/useChat.ts
Normal file
149
web-app/src/hooks/useChat.ts
Normal file
@ -0,0 +1,149 @@
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { usePrompt } from './usePrompt'
|
||||
import { useModelProvider } from './useModelProvider'
|
||||
import { useThreads } from './useThreads'
|
||||
import { useAppState } from './useAppState'
|
||||
import { useMessages } from './useMessages'
|
||||
import { useRouter } from '@tanstack/react-router'
|
||||
import { defaultModel } from '@/lib/models'
|
||||
import { route } from '@/constants/routes'
|
||||
import {
|
||||
emptyThreadContent,
|
||||
extractToolCall,
|
||||
newAssistantThreadContent,
|
||||
newUserThreadContent,
|
||||
postMessageProcessing,
|
||||
sendCompletion,
|
||||
startModel,
|
||||
} from '@/lib/completion'
|
||||
import { CompletionMessagesBuilder } from '@/lib/messages'
|
||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
|
||||
export const useChat = () => {
|
||||
const { prompt, setPrompt } = usePrompt()
|
||||
const { tools } = useAppState()
|
||||
|
||||
const { getProviderByName, selectedModel, selectedProvider } =
|
||||
useModelProvider()
|
||||
|
||||
const { getCurrentThread: retrieveThread, createThread } = useThreads()
|
||||
const { updateStreamingContent, updateLoadingModel } = useAppState()
|
||||
const { addMessage } = useMessages()
|
||||
const router = useRouter()
|
||||
|
||||
const provider = useMemo(() => {
|
||||
return getProviderByName(selectedProvider)
|
||||
}, [selectedProvider, getProviderByName])
|
||||
const getCurrentThread = useCallback(async () => {
|
||||
let currentThread = retrieveThread()
|
||||
if (!currentThread) {
|
||||
currentThread = await createThread(
|
||||
{
|
||||
id: selectedModel?.id ?? defaultModel(selectedProvider),
|
||||
provider: selectedProvider,
|
||||
},
|
||||
prompt
|
||||
)
|
||||
router.navigate({
|
||||
to: route.threadsDetail,
|
||||
params: { threadId: currentThread.id },
|
||||
})
|
||||
}
|
||||
return currentThread
|
||||
}, [
|
||||
createThread,
|
||||
prompt,
|
||||
retrieveThread,
|
||||
router,
|
||||
selectedModel?.id,
|
||||
selectedProvider,
|
||||
])
|
||||
|
||||
const sendMessage = useCallback(
|
||||
async (message: string) => {
|
||||
const activeThread = await getCurrentThread()
|
||||
|
||||
if (!activeThread || !provider) return
|
||||
|
||||
updateStreamingContent(emptyThreadContent)
|
||||
addMessage(newUserThreadContent(activeThread.id, message))
|
||||
setPrompt('')
|
||||
try {
|
||||
if (selectedModel?.id) {
|
||||
updateLoadingModel(true)
|
||||
await startModel(provider.provider, selectedModel.id).catch(
|
||||
console.error
|
||||
)
|
||||
updateLoadingModel(false)
|
||||
}
|
||||
|
||||
const builder = new CompletionMessagesBuilder()
|
||||
// REMARK: Would it possible to not attach the entire message history to the request?
|
||||
// TODO: If not amend messages history here
|
||||
builder.addUserMessage(message)
|
||||
|
||||
let isCompleted = false
|
||||
|
||||
while (!isCompleted) {
|
||||
const completion = await sendCompletion(
|
||||
activeThread,
|
||||
provider,
|
||||
builder.getMessages(),
|
||||
tools
|
||||
)
|
||||
|
||||
if (!completion) throw new Error('No completion received')
|
||||
let accumulatedText = ''
|
||||
const currentCall: ChatCompletionMessageToolCall | null = null
|
||||
const toolCalls: ChatCompletionMessageToolCall[] = []
|
||||
for await (const part of completion) {
|
||||
const delta = part.choices[0]?.delta?.content || ''
|
||||
if (part.choices[0]?.delta?.tool_calls) {
|
||||
extractToolCall(part, currentCall, toolCalls)
|
||||
}
|
||||
if (delta) {
|
||||
accumulatedText += delta
|
||||
// Create a new object each time to avoid reference issues
|
||||
// Use a timeout to prevent React from batching updates too quickly
|
||||
const currentContent = newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
accumulatedText
|
||||
)
|
||||
updateStreamingContent(currentContent)
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
}
|
||||
}
|
||||
// Create a final content object for adding to the thread
|
||||
const finalContent = newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
accumulatedText
|
||||
)
|
||||
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
|
||||
const updatedMessage = await postMessageProcessing(
|
||||
toolCalls,
|
||||
builder,
|
||||
finalContent
|
||||
)
|
||||
addMessage(updatedMessage ?? finalContent)
|
||||
|
||||
isCompleted = !toolCalls.length
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error sending message:', error)
|
||||
}
|
||||
updateStreamingContent(undefined)
|
||||
},
|
||||
[
|
||||
getCurrentThread,
|
||||
provider,
|
||||
updateStreamingContent,
|
||||
addMessage,
|
||||
setPrompt,
|
||||
selectedModel,
|
||||
tools,
|
||||
updateLoadingModel,
|
||||
]
|
||||
)
|
||||
|
||||
return { sendMessage }
|
||||
}
|
||||
@ -9,6 +9,7 @@ import {
|
||||
|
||||
type MessageState = {
|
||||
messages: Record<string, ThreadMessage[]>
|
||||
getMessages: (threadId: string) => ThreadMessage[]
|
||||
setMessages: (threadId: string, messages: ThreadMessage[]) => void
|
||||
addMessage: (message: ThreadMessage) => void
|
||||
deleteMessage: (threadId: string, messageId: string) => void
|
||||
@ -16,8 +17,11 @@ type MessageState = {
|
||||
|
||||
export const useMessages = create<MessageState>()(
|
||||
persist(
|
||||
(set) => ({
|
||||
(set, get) => ({
|
||||
messages: {},
|
||||
getMessages: (threadId) => {
|
||||
return get().messages[threadId] || []
|
||||
},
|
||||
setMessages: (threadId, messages) => {
|
||||
set((state) => ({
|
||||
messages: {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user