chore: handle local models chat with MCP (#5065)
* chore: handle local models chat with MCP
* chore: update MCP server connection status in the settings page
* chore: error handling
* chore: normalize message
* chore: update finally block
parent aba75a7d2c
commit 4d66eaf0a7
@@ -116,10 +116,13 @@ export const ThreadContent = memo(
       // Only regenerate assistant message is allowed
       deleteMessage(item.thread_id, item.id)
       const threadMessages = getMessages(item.thread_id)
-      const lastMessage = threadMessages[threadMessages.length - 1]
-      if (!lastMessage) return
-      deleteMessage(lastMessage.thread_id, lastMessage.id)
-      sendMessage(lastMessage.content?.[0]?.text?.value || '')
+      let toSendMessage = threadMessages.pop()
+      while (toSendMessage && toSendMessage?.role !== 'user') {
+        deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '')
+        toSendMessage = threadMessages.pop()
+      }
+      if (toSendMessage)
+        sendMessage(toSendMessage.content?.[0]?.text?.value || '')
     }, [deleteMessage, getMessages, item, sendMessage])

     const editMessage = useCallback(
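The regenerate path previously resent whichever message happened to be last in the thread; it now pops trailing non-user messages until it reaches the user message to resend. A minimal standalone sketch of that loop, assuming a simplified `ThreadMessage` shape rather than the app's actual message type:

```ts
type ThreadMessage = {
  id: string
  thread_id: string
  role: 'user' | 'assistant'
  text: string
}

// Walk backwards from the end of the thread, discarding trailing
// assistant messages until the user message to resend is found.
function popBackToLastUserMessage(
  threadMessages: ThreadMessage[]
): ThreadMessage | undefined {
  let candidate = threadMessages.pop()
  while (candidate && candidate.role !== 'user') {
    // The real code also deletes each discarded message from the store.
    candidate = threadMessages.pop()
  }
  return candidate
}
```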
@@ -10,6 +10,7 @@ import { route } from '@/constants/routes'
 import {
   emptyThreadContent,
   extractToolCall,
+  isCompletionResponse,
   newAssistantThreadContent,
   newUserThreadContent,
   postMessageProcessing,
@@ -19,6 +20,7 @@ import {
 import { CompletionMessagesBuilder } from '@/lib/messages'
 import { ChatCompletionMessageToolCall } from 'openai/resources'
 import { useAssistant } from './useAssistant'
+import { toast } from 'sonner'

 export const useChat = () => {
   const { prompt, setPrompt } = usePrompt()
@@ -78,9 +80,7 @@ export const useChat = () => {
     try {
       if (selectedModel?.id) {
         updateLoadingModel(true)
-        await startModel(provider, selectedModel.id).catch(
-          console.error
-        )
+        await startModel(provider, selectedModel.id).catch(console.error)
         updateLoadingModel(false)
       }

@@ -100,29 +100,38 @@ export const useChat = () => {
         provider,
         builder.getMessages(),
         abortController,
-        tools
+        tools,
+        // TODO: replace it with according provider setting later on
+        selectedProvider === 'llama.cpp' && tools.length > 0 ? false : true
       )

       if (!completion) throw new Error('No completion received')
       let accumulatedText = ''
       const currentCall: ChatCompletionMessageToolCall | null = null
       const toolCalls: ChatCompletionMessageToolCall[] = []
-      for await (const part of completion) {
-        const delta = part.choices[0]?.delta?.content || ''
-        if (part.choices[0]?.delta?.tool_calls) {
-          extractToolCall(part, currentCall, toolCalls)
-        }
-        if (delta) {
-          accumulatedText += delta
-          // Create a new object each time to avoid reference issues
-          // Use a timeout to prevent React from batching updates too quickly
-          const currentContent = newAssistantThreadContent(
-            activeThread.id,
-            accumulatedText
-          )
-          updateStreamingContent(currentContent)
-          updateTokenSpeed(currentContent)
-          await new Promise((resolve) => setTimeout(resolve, 0))
-        }
+      if (isCompletionResponse(completion)) {
+        accumulatedText = completion.choices[0]?.message?.content || ''
+        if (completion.choices[0]?.message?.tool_calls) {
+          toolCalls.push(...completion.choices[0].message.tool_calls)
+        }
+      } else {
+        for await (const part of completion) {
+          const delta = part.choices[0]?.delta?.content || ''
+          if (part.choices[0]?.delta?.tool_calls) {
+            extractToolCall(part, currentCall, toolCalls)
+          }
+          if (delta) {
+            accumulatedText += delta
+            // Create a new object each time to avoid reference issues
+            // Use a timeout to prevent React from batching updates too quickly
+            const currentContent = newAssistantThreadContent(
+              activeThread.id,
+              accumulatedText
+            )
+            updateStreamingContent(currentContent)
+            updateTokenSpeed(currentContent)
+            await new Promise((resolve) => setTimeout(resolve, 0))
+          }
+        }
       }
       // Create a final content object for adding to the thread
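Streaming is now disabled for llama.cpp whenever MCP tools are attached (the ternary above, flagged with a TODO), so the handler has to accept both response shapes. A reduced, self-contained sketch of the branch, with the UI updates elided and minimal structural types standing in for the token.js/openai ones:

```ts
// Minimal structural stand-ins for the imported types (illustrative only).
type ToolCall = { id: string; function: { name: string; arguments: string } }
type FullResponse = {
  choices: Array<{ message?: { content?: string; tool_calls?: ToolCall[] } }>
}
type StreamChunk = {
  choices: Array<{ delta?: { content?: string; tool_calls?: ToolCall[] } }>
}
type StreamResponse = AsyncIterable<StreamChunk>

const isFull = (r: FullResponse | StreamResponse): r is FullResponse =>
  'choices' in r

async function consume(completion: FullResponse | StreamResponse) {
  let text = ''
  const toolCalls: ToolCall[] = []
  if (isFull(completion)) {
    // Non-streaming: the whole message (and any tool calls) arrives at once.
    text = completion.choices[0]?.message?.content || ''
    toolCalls.push(...(completion.choices[0]?.message?.tool_calls ?? []))
  } else {
    // Streaming: accumulate content deltas chunk by chunk.
    for await (const part of completion) {
      text += part.choices[0]?.delta?.content || ''
    }
  }
  return { text, toolCalls }
}
```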
@@ -141,9 +150,14 @@ export const useChat = () => {
         isCompleted = !toolCalls.length
       }
     } catch (error) {
+      toast.error(
+        `Error sending message: ${error && typeof error === 'object' && 'message' in error ? error.message : error}`
+      )
       console.error('Error sending message:', error)
+    } finally {
+      updateLoadingModel(false)
+      updateStreamingContent(undefined)
     }
-    updateStreamingContent(undefined)
   },
   [
     getCurrentThread,
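Moving the two state resets into `finally` guarantees they run on every exit path, whether the request succeeds, throws, or is aborted; previously `updateLoadingModel(false)` could be skipped by an early failure. The shape of the change in isolation (the stubs are illustrative, not the real hook callbacks):

```ts
import { toast } from 'sonner'

// Stubs standing in for the real hook callbacks (illustrative only).
declare function doCompletion(): Promise<void>
declare function updateLoadingModel(loading: boolean): void
declare function updateStreamingContent(content?: unknown): void

async function send() {
  try {
    await doCompletion() // may throw, or reject on abort
  } catch (error) {
    toast.error(`Error sending message: ${String(error)}`)
    console.error('Error sending message:', error)
  } finally {
    // Runs on every exit path, so the loading flag and the streaming
    // placeholder are never left dangling after a failure.
    updateLoadingModel(false)
    updateStreamingContent(undefined)
  }
}
```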
@@ -157,6 +171,7 @@ export const useChat = () => {
       setAbortController,
       updateLoadingModel,
       tools,
+      selectedProvider,
       updateTokenSpeed,
     ]
   )

@@ -10,6 +10,7 @@ import { invoke } from '@tauri-apps/api/core'
 import {
   ChatCompletionMessageParam,
   ChatCompletionTool,
+  CompletionResponse,
   CompletionResponseChunk,
   models,
   StreamCompletionResponse,
@@ -111,8 +112,9 @@ export const sendCompletion = async (
   provider: ModelProvider,
   messages: ChatCompletionMessageParam[],
   abortController: AbortController,
-  tools: MCPTool[] = []
-): Promise<StreamCompletionResponse | undefined> => {
+  tools: MCPTool[] = [],
+  stream: boolean = true
+): Promise<StreamCompletionResponse | CompletionResponse | undefined> => {
   if (!thread?.model?.id || !provider) return undefined

   let providerName = provider.provider as unknown as keyof typeof models
@@ -127,22 +129,37 @@ export const sendCompletion = async (
   })

   // TODO: Add message history
-  const completion = await tokenJS.chat.completions.create(
-    {
-      stream: true,
-      provider: providerName,
-      model: thread.model?.id,
-      messages,
-      tools: normalizeTools(tools),
-      tool_choice: tools.length ? 'auto' : undefined,
-    },
-    {
-      signal: abortController.signal,
-    }
-  )
+  const completion = stream
+    ? await tokenJS.chat.completions.create(
+        {
+          stream: true,
+          provider: providerName,
+          model: thread.model?.id,
+          messages,
+          tools: normalizeTools(tools),
+          tool_choice: tools.length ? 'auto' : undefined,
+        },
+        {
+          signal: abortController.signal,
+        }
+      )
+    : await tokenJS.chat.completions.create({
+        stream: false,
+        provider: providerName,
+        model: thread.model?.id,
+        messages,
+        tools: normalizeTools(tools),
+        tool_choice: tools.length ? 'auto' : undefined,
+      })
   return completion
 }

+export const isCompletionResponse = (
+  response: StreamCompletionResponse | CompletionResponse
+): response is CompletionResponse => {
+  return 'choices' in response
+}
+
 /**
  * @fileoverview Helper function to start a model.
  * This function loads the model from the provider.
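The guard discriminates on a top-level `choices` property, which a materialized `CompletionResponse` exposes and the async-iterable stream does not. A usage sketch, assuming the surrounding `useChat` variables are in scope:

```ts
const completion = await sendCompletion(
  activeThread,
  provider,
  builder.getMessages(),
  abortController,
  tools,
  false // request a non-streaming response
)
if (completion && isCompletionResponse(completion)) {
  // Narrowed to CompletionResponse: .choices is safe to read directly.
  const text = completion.choices[0]?.message?.content || ''
}
```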
@@ -137,6 +137,12 @@ function MCPServers() {

   useEffect(() => {
     getConnectedServers().then(setConnectedServers)
+
+    const intervalId = setInterval(() => {
+      getConnectedServers().then(setConnectedServers)
+    }, 5000)
+
+    return () => clearInterval(intervalId)
   }, [setConnectedServers])

   return (
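The settings page now refreshes the connected-server list every five seconds and clears the timer on unmount. The same pattern extracted as a generic hook, as a sketch (`usePolled` is illustrative, not part of this codebase):

```tsx
import { useEffect, useState } from 'react'

// Fetch once on mount, then refresh on a fixed interval; the cleanup
// function stops the timer when the component unmounts.
function usePolled<T>(fetcher: () => Promise<T>, intervalMs = 5000) {
  const [value, setValue] = useState<T>()
  useEffect(() => {
    fetcher().then(setValue)
    const id = setInterval(() => fetcher().then(setValue), intervalMs)
    return () => clearInterval(id)
  }, [fetcher, intervalMs])
  return value
}
```

Note that `fetcher` must be referentially stable (e.g. wrapped in `useCallback`), otherwise the effect tears down and recreates the interval on every render.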