chore: handle local models chat with MCP (#5065)

* chore: handle local models chat with MCP

* chore: update MCP server connection status in the settings page

* chore: error handling

* chore: normalize message

* chore: update finally block
Louis authored 2025-05-22 16:06:55 +07:00, committed via GitHub
parent aba75a7d2c
commit 4d66eaf0a7
4 changed files with 80 additions and 39 deletions

View File

@@ -116,10 +116,13 @@ export const ThreadContent = memo(
       // Only regenerate assistant message is allowed
       deleteMessage(item.thread_id, item.id)
       const threadMessages = getMessages(item.thread_id)
-      const lastMessage = threadMessages[threadMessages.length - 1]
-      if (!lastMessage) return
-      deleteMessage(lastMessage.thread_id, lastMessage.id)
-      sendMessage(lastMessage.content?.[0]?.text?.value || '')
+      let toSendMessage = threadMessages.pop()
+      while (toSendMessage && toSendMessage?.role !== 'user') {
+        deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '')
+        toSendMessage = threadMessages.pop()
+      }
+      if (toSendMessage)
+        sendMessage(toSendMessage.content?.[0]?.text?.value || '')
     }, [deleteMessage, getMessages, item, sendMessage])
 
     const editMessage = useCallback(
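
The old regenerate handler resent whichever message happened to be last in the thread, which after an MCP tool exchange may be an assistant or tool message rather than the user's prompt. The new loop instead pops and deletes trailing non-user messages until it reaches the most recent user message, and resends that. A standalone sketch of the walk-back, using a simplified hypothetical message shape rather than the app's real types:

```typescript
// Simplified, hypothetical message shape for illustration only.
type ThreadMessage = {
  id: string
  thread_id: string
  role: 'user' | 'assistant' | 'tool'
  text: string
}

// Pop and delete trailing non-user messages, returning the most recent
// user message, or undefined if the thread contains none.
function popToLastUserMessage(
  messages: ThreadMessage[],
  deleteMessage: (threadId: string, id: string) => void
): ThreadMessage | undefined {
  let candidate = messages.pop()
  while (candidate && candidate.role !== 'user') {
    deleteMessage(candidate.thread_id, candidate.id)
    candidate = messages.pop()
  }
  return candidate
}
```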

View File

@@ -10,6 +10,7 @@ import { route } from '@/constants/routes'
 import {
   emptyThreadContent,
   extractToolCall,
+  isCompletionResponse,
   newAssistantThreadContent,
   newUserThreadContent,
   postMessageProcessing,
@@ -19,6 +20,7 @@ import {
 import { CompletionMessagesBuilder } from '@/lib/messages'
 import { ChatCompletionMessageToolCall } from 'openai/resources'
 import { useAssistant } from './useAssistant'
+import { toast } from 'sonner'
 
 export const useChat = () => {
   const { prompt, setPrompt } = usePrompt()
@@ -78,9 +80,7 @@
       try {
         if (selectedModel?.id) {
           updateLoadingModel(true)
-          await startModel(provider, selectedModel.id).catch(
-            console.error
-          )
+          await startModel(provider, selectedModel.id).catch(console.error)
           updateLoadingModel(false)
         }
 
@@ -100,29 +100,38 @@
           provider,
           builder.getMessages(),
           abortController,
-          tools
+          tools,
+          // TODO: replace it with according provider setting later on
+          selectedProvider === 'llama.cpp' && tools.length > 0 ? false : true
         )
         if (!completion) throw new Error('No completion received')
 
         let accumulatedText = ''
         const currentCall: ChatCompletionMessageToolCall | null = null
         const toolCalls: ChatCompletionMessageToolCall[] = []
-        for await (const part of completion) {
-          const delta = part.choices[0]?.delta?.content || ''
-          if (part.choices[0]?.delta?.tool_calls) {
-            extractToolCall(part, currentCall, toolCalls)
+        if (isCompletionResponse(completion)) {
+          accumulatedText = completion.choices[0]?.message?.content || ''
+          if (completion.choices[0]?.message?.tool_calls) {
+            toolCalls.push(...completion.choices[0].message.tool_calls)
           }
-          if (delta) {
-            accumulatedText += delta
-            // Create a new object each time to avoid reference issues
-            // Use a timeout to prevent React from batching updates too quickly
-            const currentContent = newAssistantThreadContent(
-              activeThread.id,
-              accumulatedText
-            )
-            updateStreamingContent(currentContent)
-            updateTokenSpeed(currentContent)
-            await new Promise((resolve) => setTimeout(resolve, 0))
+        } else {
+          for await (const part of completion) {
+            const delta = part.choices[0]?.delta?.content || ''
+            if (part.choices[0]?.delta?.tool_calls) {
+              extractToolCall(part, currentCall, toolCalls)
+            }
+            if (delta) {
+              accumulatedText += delta
+              // Create a new object each time to avoid reference issues
+              // Use a timeout to prevent React from batching updates too quickly
+              const currentContent = newAssistantThreadContent(
+                activeThread.id,
+                accumulatedText
+              )
+              updateStreamingContent(currentContent)
+              updateTokenSpeed(currentContent)
+              await new Promise((resolve) => setTimeout(resolve, 0))
+            }
           }
         }
         // Create a final content object for adding to the thread
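
With MCP tools attached, llama.cpp is now asked for a non-streaming completion (the inline TODO marks this as a stopgap until it becomes a per-provider setting), so the hook must consume two response shapes: a plain response whose full message and tool calls arrive at once, and a stream whose deltas are accumulated chunk by chunk. A sketch of that branch in isolation, with simplified stand-ins for the real OpenAI-style types:

```typescript
// Simplified, hypothetical stand-ins for the real response types.
type ToolCall = { id: string; function: { name: string; arguments: string } }
type CompletionResponse = {
  choices: { message: { content?: string; tool_calls?: ToolCall[] } }[]
}
type StreamChunk = {
  choices: { delta: { content?: string; tool_calls?: ToolCall[] } }[]
}
type StreamCompletionResponse = AsyncIterable<StreamChunk>

const isCompletionResponse = (
  r: CompletionResponse | StreamCompletionResponse
): r is CompletionResponse => 'choices' in r

// Collapse either response shape into final text plus tool calls,
// reporting partial text to the caller as a stream accumulates.
async function consumeCompletion(
  completion: CompletionResponse | StreamCompletionResponse,
  onPartial: (accumulated: string) => void
): Promise<{ text: string; toolCalls: ToolCall[] }> {
  let text = ''
  const toolCalls: ToolCall[] = []
  if (isCompletionResponse(completion)) {
    // Non-streaming: the whole message arrives in one payload.
    text = completion.choices[0]?.message?.content ?? ''
    toolCalls.push(...(completion.choices[0]?.message?.tool_calls ?? []))
  } else {
    // Streaming: accumulate content deltas chunk by chunk. (The real hook
    // also merges partial tool-call deltas via extractToolCall.)
    for await (const part of completion) {
      const delta = part.choices[0]?.delta?.content ?? ''
      if (delta) {
        text += delta
        onPartial(text)
      }
    }
  }
  return { text, toolCalls }
}
```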
@@ -141,9 +150,14 @@
           isCompleted = !toolCalls.length
         }
       } catch (error) {
+        toast.error(
+          `Error sending message: ${error && typeof error === 'object' && 'message' in error ? error.message : error}`
+        )
         console.error('Error sending message:', error)
       } finally {
         updateLoadingModel(false)
+        updateStreamingContent(undefined)
       }
-      updateStreamingContent(undefined)
     },
     [
       getCurrentThread,
@@ -157,6 +171,7 @@
       setAbortController,
       updateLoadingModel,
       tools,
+      selectedProvider,
       updateTokenSpeed,
     ]
   )
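
Two smaller changes in the same hook: errors are now surfaced to the user via a toast (a thrown value is `unknown` in TypeScript, hence the guard before reading `.message`), and `updateStreamingContent(undefined)` moved into the `finally` block so the streaming indicator is cleared even when the request throws or is aborted. The same narrowing, factored into a helper for clarity (a sketch, not code from this commit):

```typescript
// Extract a printable message from an unknown thrown value.
function errorMessage(error: unknown): string {
  return error && typeof error === 'object' && 'message' in error
    ? String((error as { message?: unknown }).message)
    : String(error)
}

// Usage: toast.error(`Error sending message: ${errorMessage(error)}`)
```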

View File

@@ -10,6 +10,7 @@ import { invoke } from '@tauri-apps/api/core'
 import {
   ChatCompletionMessageParam,
   ChatCompletionTool,
+  CompletionResponse,
   CompletionResponseChunk,
   models,
   StreamCompletionResponse,
@@ -111,8 +112,9 @@ export const sendCompletion = async (
   provider: ModelProvider,
   messages: ChatCompletionMessageParam[],
   abortController: AbortController,
-  tools: MCPTool[] = []
-): Promise<StreamCompletionResponse | undefined> => {
+  tools: MCPTool[] = [],
+  stream: boolean = true
+): Promise<StreamCompletionResponse | CompletionResponse | undefined> => {
   if (!thread?.model?.id || !provider) return undefined
 
   let providerName = provider.provider as unknown as keyof typeof models
@@ -127,22 +129,37 @@
   })
 
   // TODO: Add message history
-  const completion = await tokenJS.chat.completions.create(
-    {
-      stream: true,
-      provider: providerName,
-      model: thread.model?.id,
-      messages,
-      tools: normalizeTools(tools),
-      tool_choice: tools.length ? 'auto' : undefined,
-    },
-    {
-      signal: abortController.signal,
-    }
-  )
+  const completion = stream
+    ? await tokenJS.chat.completions.create(
+        {
+          stream: true,
+          provider: providerName,
+          model: thread.model?.id,
+          messages,
+          tools: normalizeTools(tools),
+          tool_choice: tools.length ? 'auto' : undefined,
+        },
+        {
+          signal: abortController.signal,
+        }
+      )
+    : await tokenJS.chat.completions.create({
+        stream: false,
+        provider: providerName,
+        model: thread.model?.id,
+        messages,
+        tools: normalizeTools(tools),
+        tool_choice: tools.length ? 'auto' : undefined,
+      })
   return completion
 }
+
+export const isCompletionResponse = (
+  response: StreamCompletionResponse | CompletionResponse
+): response is CompletionResponse => {
+  return 'choices' in response
+}
 
 /**
  * @fileoverview Helper function to start a model.
  * This function loads the model from the provider.
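
`sendCompletion` now accepts a `stream` flag and widens its return type to cover both result shapes, and `isCompletionResponse` distinguishes them at runtime: only a plain response object carries a top-level `choices` property, while the streaming result is an async iterator. The two near-identical `create` calls are likely kept separate because the SDK, like the OpenAI client it mirrors, overloads its return type on the literal value of `stream`; a runtime boolean would not select either overload. A minimal sketch of that pattern against a hypothetical client, not the token.js API itself:

```typescript
type Chunk = { choices: { delta: { content?: string } }[] }
type Response = { choices: { message: { content?: string } }[] }

// Hypothetical client whose return type is overloaded on the literal
// value of `stream`, as OpenAI-style SDKs do.
interface Client {
  create(opts: { stream: true; model: string }): Promise<AsyncIterable<Chunk>>
  create(opts: { stream: false; model: string }): Promise<Response>
}

async function send(
  client: Client,
  model: string,
  stream: boolean
): Promise<AsyncIterable<Chunk> | Response> {
  // Each branch pins the literal so the right overload is chosen;
  // passing the plain boolean through would not type-check.
  return stream
    ? await client.create({ stream: true, model })
    : await client.create({ stream: false, model })
}
```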

View File

@@ -137,6 +137,12 @@ function MCPServers() {
 
   useEffect(() => {
     getConnectedServers().then(setConnectedServers)
+
+    const intervalId = setInterval(() => {
+      getConnectedServers().then(setConnectedServers)
+    }, 5000)
+
+    return () => clearInterval(intervalId)
   }, [setConnectedServers])
 
   return (
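
Previously the settings page fetched the connected-server list once on mount, so servers that connected or dropped later were never reflected in the UI. The effect now also polls every 5 seconds and clears the interval on unmount. The same pattern extracted into a reusable hook, as a sketch (the hook name and `pollMs` parameter are illustrative, not from this commit):

```typescript
import { useEffect, useState } from 'react'

// Stand-in for the app's real getConnectedServers(), assumed to resolve
// to a list of server identifiers.
declare function getConnectedServers(): Promise<string[]>

function useConnectedServers(pollMs: number = 5000): string[] {
  const [servers, setServers] = useState<string[]>([])

  useEffect(() => {
    getConnectedServers().then(setServers) // initial fetch on mount
    const intervalId = setInterval(() => {
      getConnectedServers().then(setServers) // periodic refresh
    }, pollMs)
    return () => clearInterval(intervalId) // stop polling on unmount
  }, [pollMs])

  return servers
}
```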