chore: handle local models chat with MCP (#5065)

* chore: handle local models chat with MCP

* chore: update MCP server connection status in the settings page

* chore: error handling

* chore: normalize message

* chore: update finally block
Louis authored 2025-05-22 16:06:55 +07:00, committed by GitHub
parent aba75a7d2c
commit 4d66eaf0a7
4 changed files with 80 additions and 39 deletions


@@ -116,10 +116,13 @@ export const ThreadContent = memo(
       // Only regenerate assistant message is allowed
       deleteMessage(item.thread_id, item.id)
       const threadMessages = getMessages(item.thread_id)
-      const lastMessage = threadMessages[threadMessages.length - 1]
-      if (!lastMessage) return
-      deleteMessage(lastMessage.thread_id, lastMessage.id)
-      sendMessage(lastMessage.content?.[0]?.text?.value || '')
+      let toSendMessage = threadMessages.pop()
+      while (toSendMessage && toSendMessage?.role !== 'user') {
+        deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '')
+        toSendMessage = threadMessages.pop()
+      }
+      if (toSendMessage)
+        sendMessage(toSendMessage.content?.[0]?.text?.value || '')
     }, [deleteMessage, getMessages, item, sendMessage])

     const editMessage = useCallback(
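Regenerate now rewinds the thread to the most recent user message instead of resending whatever happens to be last: trailing assistant (and tool) messages are deleted one by one until a user message surfaces, and that message is resent. A minimal standalone sketch of the same walk-back, assuming a simplified Message shape (hypothetical; the app's real message objects carry thread_id, content parts, and so on):

// Sketch only: a simplified Message type, not the app's real thread model.
type Message = { id: string; role: 'user' | 'assistant' | 'tool'; text: string }

// Pop trailing non-user messages until the last user message (or nothing) remains.
function rewindToLastUserMessage(messages: Message[]): Message | undefined {
  let candidate = messages.pop()
  while (candidate && candidate.role !== 'user') {
    candidate = messages.pop()
  }
  return candidate
}

Note that pop() mutates the local array, which matches the destructive intent here: the skipped messages are deleted from the thread as well.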


@@ -10,6 +10,7 @@ import { route } from '@/constants/routes'
 import {
   emptyThreadContent,
   extractToolCall,
+  isCompletionResponse,
   newAssistantThreadContent,
   newUserThreadContent,
   postMessageProcessing,
@@ -19,6 +20,7 @@ import {
 import { CompletionMessagesBuilder } from '@/lib/messages'
 import { ChatCompletionMessageToolCall } from 'openai/resources'
 import { useAssistant } from './useAssistant'
+import { toast } from 'sonner'

 export const useChat = () => {
   const { prompt, setPrompt } = usePrompt()
@@ -78,9 +80,7 @@ export const useChat = () => {
     try {
       if (selectedModel?.id) {
         updateLoadingModel(true)
-        await startModel(provider, selectedModel.id).catch(
-          console.error
-        )
+        await startModel(provider, selectedModel.id).catch(console.error)
         updateLoadingModel(false)
       }
@@ -100,13 +100,21 @@ export const useChat = () => {
         provider,
         builder.getMessages(),
         abortController,
-        tools
+        tools,
+        // TODO: replace this with the corresponding provider setting later on
+        selectedProvider === 'llama.cpp' && tools.length > 0 ? false : true
       )
       if (!completion) throw new Error('No completion received')

       let accumulatedText = ''
       const currentCall: ChatCompletionMessageToolCall | null = null
       const toolCalls: ChatCompletionMessageToolCall[] = []
+      if (isCompletionResponse(completion)) {
+        accumulatedText = completion.choices[0]?.message?.content || ''
+        if (completion.choices[0]?.message?.tool_calls) {
+          toolCalls.push(...completion.choices[0].message.tool_calls)
+        }
+      } else {
       for await (const part of completion) {
         const delta = part.choices[0]?.delta?.content || ''
         if (part.choices[0]?.delta?.tool_calls) {
@@ -125,6 +133,7 @@ export const useChat = () => {
           await new Promise((resolve) => setTimeout(resolve, 0))
         }
       }
+      }
       // Create a final content object for adding to the thread
       const finalContent = newAssistantThreadContent(
         activeThread.id,
@@ -141,9 +150,14 @@ export const useChat = () => {
           isCompleted = !toolCalls.length
         }
       } catch (error) {
+        toast.error(
+          `Error sending message: ${error && typeof error === 'object' && 'message' in error ? error.message : error}`
+        )
        console.error('Error sending message:', error)
-      }
+      } finally {
+        updateLoadingModel(false)
        updateStreamingContent(undefined)
+      }
    },
    [
      getCurrentThread,
@@ -157,6 +171,7 @@ export const useChat = () => {
       setAbortController,
       updateLoadingModel,
       tools,
+      selectedProvider,
       updateTokenSpeed,
     ]
   )
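With streaming made conditional, the hook has to consume two response shapes: a complete response object when streaming is off (the llama.cpp-with-tools case above) and an async iterable of deltas otherwise. A minimal sketch of that split, using stand-in types rather than the real CompletionResponse / CompletionResponseChunk imports:

// Stand-in shapes for the sketch; the real types come from the completion helper.
type FullResponse = { choices: { message?: { content?: string } }[] }
type Chunk = { choices: { delta?: { content?: string } }[] }

const isFull = (r: FullResponse | AsyncIterable<Chunk>): r is FullResponse =>
  'choices' in r

async function collectText(r: FullResponse | AsyncIterable<Chunk>): Promise<string> {
  // Non-streaming: the whole message is already present.
  if (isFull(r)) return r.choices[0]?.message?.content ?? ''
  // Streaming: accumulate the content deltas chunk by chunk.
  let text = ''
  for await (const chunk of r) {
    text += chunk.choices[0]?.delta?.content ?? ''
  }
  return text
}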


@ -10,6 +10,7 @@ import { invoke } from '@tauri-apps/api/core'
import { import {
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatCompletionTool, ChatCompletionTool,
CompletionResponse,
CompletionResponseChunk, CompletionResponseChunk,
models, models,
StreamCompletionResponse, StreamCompletionResponse,
@@ -111,8 +112,9 @@ export const sendCompletion = async (
   provider: ModelProvider,
   messages: ChatCompletionMessageParam[],
   abortController: AbortController,
-  tools: MCPTool[] = []
-): Promise<StreamCompletionResponse | undefined> => {
+  tools: MCPTool[] = [],
+  stream: boolean = true
+): Promise<StreamCompletionResponse | CompletionResponse | undefined> => {
   if (!thread?.model?.id || !provider) return undefined

   let providerName = provider.provider as unknown as keyof typeof models
@ -127,7 +129,8 @@ export const sendCompletion = async (
}) })
// TODO: Add message history // TODO: Add message history
const completion = await tokenJS.chat.completions.create( const completion = stream
? await tokenJS.chat.completions.create(
{ {
stream: true, stream: true,
provider: providerName, provider: providerName,
@@ -140,9 +143,23 @@ export const sendCompletion = async (
         signal: abortController.signal,
       }
     )
+    : await tokenJS.chat.completions.create({
+        stream: false,
+        provider: providerName,
+        model: thread.model?.id,
+        messages,
+        tools: normalizeTools(tools),
+        tool_choice: tools.length ? 'auto' : undefined,
+      })
   return completion
 }

+export const isCompletionResponse = (
+  response: StreamCompletionResponse | CompletionResponse
+): response is CompletionResponse => {
+  return 'choices' in response
+}

 /**
  * @fileoverview Helper function to start a model.
  * This function loads the model from the provider.
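Callers opt out of streaming through the new trailing parameter and can use isCompletionResponse to narrow the widened return type. A hedged usage sketch (the leading thread argument is not visible in the hunk above, and the variable names are placeholders):

// Hypothetical call site mirroring the useChat logic: stream unless the
// provider is llama.cpp and MCP tools are attached.
const useStream = !(provider.provider === 'llama.cpp' && tools.length > 0)
const completion = await sendCompletion(
  currentThread, // placeholder for the thread parameter above the hunk
  provider,
  builder.getMessages(),
  new AbortController(),
  tools,
  useStream
)
if (completion && isCompletionResponse(completion)) {
  // Non-streaming path: the full message is available immediately.
  console.log(completion.choices[0]?.message?.content)
}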


@@ -137,6 +137,12 @@ function MCPServers() {

   useEffect(() => {
     getConnectedServers().then(setConnectedServers)
+
+    const intervalId = setInterval(() => {
+      getConnectedServers().then(setConnectedServers)
+    }, 5000)
+
+    return () => clearInterval(intervalId)
   }, [setConnectedServers])

   return (
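The settings page now refreshes the connected-server list every five seconds and clears the interval on unmount. If the pattern spread to other pages it could be lifted into a small hook; a hypothetical sketch, not part of this commit:

import { useEffect } from 'react'

// Hypothetical helper: run `poll` immediately, then every `ms` milliseconds,
// and clear the interval when the component unmounts.
function usePolling(poll: () => void, ms: number) {
  useEffect(() => {
    poll()
    const id = setInterval(poll, ms)
    return () => clearInterval(id)
  }, [poll, ms])
}

The callback would need a stable identity (e.g. via useCallback) so the interval is not torn down and recreated on every render.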