chore: handle local models chat with MCP (#5065)
* chore: handle local models chat with MCP * chore: update MCP server connection status in the settings page * chore: error handling * chore: normalize message * chore: update finally block
This commit is contained in:
parent
aba75a7d2c
commit
4d66eaf0a7
@ -116,10 +116,13 @@ export const ThreadContent = memo(
|
|||||||
// Only regenerate assistant message is allowed
|
// Only regenerate assistant message is allowed
|
||||||
deleteMessage(item.thread_id, item.id)
|
deleteMessage(item.thread_id, item.id)
|
||||||
const threadMessages = getMessages(item.thread_id)
|
const threadMessages = getMessages(item.thread_id)
|
||||||
const lastMessage = threadMessages[threadMessages.length - 1]
|
let toSendMessage = threadMessages.pop()
|
||||||
if (!lastMessage) return
|
while (toSendMessage && toSendMessage?.role !== 'user') {
|
||||||
deleteMessage(lastMessage.thread_id, lastMessage.id)
|
deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '')
|
||||||
sendMessage(lastMessage.content?.[0]?.text?.value || '')
|
toSendMessage = threadMessages.pop()
|
||||||
|
}
|
||||||
|
if (toSendMessage)
|
||||||
|
sendMessage(toSendMessage.content?.[0]?.text?.value || '')
|
||||||
}, [deleteMessage, getMessages, item, sendMessage])
|
}, [deleteMessage, getMessages, item, sendMessage])
|
||||||
|
|
||||||
const editMessage = useCallback(
|
const editMessage = useCallback(
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import { route } from '@/constants/routes'
|
|||||||
import {
|
import {
|
||||||
emptyThreadContent,
|
emptyThreadContent,
|
||||||
extractToolCall,
|
extractToolCall,
|
||||||
|
isCompletionResponse,
|
||||||
newAssistantThreadContent,
|
newAssistantThreadContent,
|
||||||
newUserThreadContent,
|
newUserThreadContent,
|
||||||
postMessageProcessing,
|
postMessageProcessing,
|
||||||
@ -19,6 +20,7 @@ import {
|
|||||||
import { CompletionMessagesBuilder } from '@/lib/messages'
|
import { CompletionMessagesBuilder } from '@/lib/messages'
|
||||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||||
import { useAssistant } from './useAssistant'
|
import { useAssistant } from './useAssistant'
|
||||||
|
import { toast } from 'sonner'
|
||||||
|
|
||||||
export const useChat = () => {
|
export const useChat = () => {
|
||||||
const { prompt, setPrompt } = usePrompt()
|
const { prompt, setPrompt } = usePrompt()
|
||||||
@ -78,9 +80,7 @@ export const useChat = () => {
|
|||||||
try {
|
try {
|
||||||
if (selectedModel?.id) {
|
if (selectedModel?.id) {
|
||||||
updateLoadingModel(true)
|
updateLoadingModel(true)
|
||||||
await startModel(provider, selectedModel.id).catch(
|
await startModel(provider, selectedModel.id).catch(console.error)
|
||||||
console.error
|
|
||||||
)
|
|
||||||
updateLoadingModel(false)
|
updateLoadingModel(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,13 +100,21 @@ export const useChat = () => {
|
|||||||
provider,
|
provider,
|
||||||
builder.getMessages(),
|
builder.getMessages(),
|
||||||
abortController,
|
abortController,
|
||||||
tools
|
tools,
|
||||||
|
// TODO: replace it with according provider setting later on
|
||||||
|
selectedProvider === 'llama.cpp' && tools.length > 0 ? false : true
|
||||||
)
|
)
|
||||||
|
|
||||||
if (!completion) throw new Error('No completion received')
|
if (!completion) throw new Error('No completion received')
|
||||||
let accumulatedText = ''
|
let accumulatedText = ''
|
||||||
const currentCall: ChatCompletionMessageToolCall | null = null
|
const currentCall: ChatCompletionMessageToolCall | null = null
|
||||||
const toolCalls: ChatCompletionMessageToolCall[] = []
|
const toolCalls: ChatCompletionMessageToolCall[] = []
|
||||||
|
if (isCompletionResponse(completion)) {
|
||||||
|
accumulatedText = completion.choices[0]?.message?.content || ''
|
||||||
|
if (completion.choices[0]?.message?.tool_calls) {
|
||||||
|
toolCalls.push(...completion.choices[0].message.tool_calls)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
for await (const part of completion) {
|
for await (const part of completion) {
|
||||||
const delta = part.choices[0]?.delta?.content || ''
|
const delta = part.choices[0]?.delta?.content || ''
|
||||||
if (part.choices[0]?.delta?.tool_calls) {
|
if (part.choices[0]?.delta?.tool_calls) {
|
||||||
@ -125,6 +133,7 @@ export const useChat = () => {
|
|||||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// Create a final content object for adding to the thread
|
// Create a final content object for adding to the thread
|
||||||
const finalContent = newAssistantThreadContent(
|
const finalContent = newAssistantThreadContent(
|
||||||
activeThread.id,
|
activeThread.id,
|
||||||
@ -141,9 +150,14 @@ export const useChat = () => {
|
|||||||
isCompleted = !toolCalls.length
|
isCompleted = !toolCalls.length
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
toast.error(
|
||||||
|
`Error sending message: ${error && typeof error === 'object' && 'message' in error ? error.message : error}`
|
||||||
|
)
|
||||||
console.error('Error sending message:', error)
|
console.error('Error sending message:', error)
|
||||||
}
|
} finally {
|
||||||
|
updateLoadingModel(false)
|
||||||
updateStreamingContent(undefined)
|
updateStreamingContent(undefined)
|
||||||
|
}
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
getCurrentThread,
|
getCurrentThread,
|
||||||
@ -157,6 +171,7 @@ export const useChat = () => {
|
|||||||
setAbortController,
|
setAbortController,
|
||||||
updateLoadingModel,
|
updateLoadingModel,
|
||||||
tools,
|
tools,
|
||||||
|
selectedProvider,
|
||||||
updateTokenSpeed,
|
updateTokenSpeed,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import { invoke } from '@tauri-apps/api/core'
|
|||||||
import {
|
import {
|
||||||
ChatCompletionMessageParam,
|
ChatCompletionMessageParam,
|
||||||
ChatCompletionTool,
|
ChatCompletionTool,
|
||||||
|
CompletionResponse,
|
||||||
CompletionResponseChunk,
|
CompletionResponseChunk,
|
||||||
models,
|
models,
|
||||||
StreamCompletionResponse,
|
StreamCompletionResponse,
|
||||||
@ -111,8 +112,9 @@ export const sendCompletion = async (
|
|||||||
provider: ModelProvider,
|
provider: ModelProvider,
|
||||||
messages: ChatCompletionMessageParam[],
|
messages: ChatCompletionMessageParam[],
|
||||||
abortController: AbortController,
|
abortController: AbortController,
|
||||||
tools: MCPTool[] = []
|
tools: MCPTool[] = [],
|
||||||
): Promise<StreamCompletionResponse | undefined> => {
|
stream: boolean = true
|
||||||
|
): Promise<StreamCompletionResponse | CompletionResponse | undefined> => {
|
||||||
if (!thread?.model?.id || !provider) return undefined
|
if (!thread?.model?.id || !provider) return undefined
|
||||||
|
|
||||||
let providerName = provider.provider as unknown as keyof typeof models
|
let providerName = provider.provider as unknown as keyof typeof models
|
||||||
@ -127,7 +129,8 @@ export const sendCompletion = async (
|
|||||||
})
|
})
|
||||||
|
|
||||||
// TODO: Add message history
|
// TODO: Add message history
|
||||||
const completion = await tokenJS.chat.completions.create(
|
const completion = stream
|
||||||
|
? await tokenJS.chat.completions.create(
|
||||||
{
|
{
|
||||||
stream: true,
|
stream: true,
|
||||||
provider: providerName,
|
provider: providerName,
|
||||||
@ -140,9 +143,23 @@ export const sendCompletion = async (
|
|||||||
signal: abortController.signal,
|
signal: abortController.signal,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
: await tokenJS.chat.completions.create({
|
||||||
|
stream: false,
|
||||||
|
provider: providerName,
|
||||||
|
model: thread.model?.id,
|
||||||
|
messages,
|
||||||
|
tools: normalizeTools(tools),
|
||||||
|
tool_choice: tools.length ? 'auto' : undefined,
|
||||||
|
})
|
||||||
return completion
|
return completion
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const isCompletionResponse = (
|
||||||
|
response: StreamCompletionResponse | CompletionResponse
|
||||||
|
): response is CompletionResponse => {
|
||||||
|
return 'choices' in response
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @fileoverview Helper function to start a model.
|
* @fileoverview Helper function to start a model.
|
||||||
* This function loads the model from the provider.
|
* This function loads the model from the provider.
|
||||||
|
|||||||
@ -137,6 +137,12 @@ function MCPServers() {
|
|||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
getConnectedServers().then(setConnectedServers)
|
getConnectedServers().then(setConnectedServers)
|
||||||
|
|
||||||
|
const intervalId = setInterval(() => {
|
||||||
|
getConnectedServers().then(setConnectedServers)
|
||||||
|
}, 5000)
|
||||||
|
|
||||||
|
return () => clearInterval(intervalId)
|
||||||
}, [setConnectedServers])
|
}, [setConnectedServers])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user