fix: tool outputs are gone after switching to another thread

This commit is contained in:
Louis 2025-04-17 21:09:53 +07:00
parent 19146fec6a
commit c4ae61dd75
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
5 changed files with 102 additions and 94 deletions

View File

@ -278,8 +278,10 @@ pub async fn create_message<R: Runtime>(
ensure_thread_dir_exists(app_handle.clone(), &thread_id)?; ensure_thread_dir_exists(app_handle.clone(), &thread_id)?;
let path = get_messages_path(app_handle.clone(), &thread_id); let path = get_messages_path(app_handle.clone(), &thread_id);
if message.get("id").is_none() {
let uuid = Uuid::new_v4().to_string(); let uuid = Uuid::new_v4().to_string();
message["id"] = serde_json::Value::String(uuid); message["id"] = serde_json::Value::String(uuid);
}
// Acquire per-thread lock before writing // Acquire per-thread lock before writing
{ {
@ -292,7 +294,7 @@ pub async fn create_message<R: Runtime>(
let _guard = lock.lock().await; let _guard = lock.lock().await;
let mut file = fs::OpenOptions::new() let mut file: File = fs::OpenOptions::new()
.create(true) .create(true)
.append(true) .append(true)
.open(path) .open(path)
@ -354,12 +356,24 @@ pub async fn modify_message<R: Runtime>(
/// Deletes a message from a thread's messages.jsonl file by message ID. /// Deletes a message from a thread's messages.jsonl file by message ID.
/// Rewrites the entire messages.jsonl file for the thread. /// Rewrites the entire messages.jsonl file for the thread.
/// Uses a per-thread async lock to prevent race conditions and ensure file consistency.
#[command] #[command]
pub async fn delete_message<R: Runtime>( pub async fn delete_message<R: Runtime>(
app_handle: tauri::AppHandle<R>, app_handle: tauri::AppHandle<R>,
thread_id: String, thread_id: String,
message_id: String, message_id: String,
) -> Result<(), String> { ) -> Result<(), String> {
// Acquire per-thread lock before modifying
{
let mut locks = MESSAGE_LOCKS.lock().await;
let lock = locks
.entry(thread_id.to_string())
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone();
drop(locks); // Release the map lock before awaiting the file lock
let _guard = lock.lock().await;
let mut messages = list_messages(app_handle.clone(), thread_id.clone()).await?; let mut messages = list_messages(app_handle.clone(), thread_id.clone()).await?;
messages.retain(|m| m.get("id").and_then(|v| v.as_str()) != Some(message_id.as_str())); messages.retain(|m| m.get("id").and_then(|v| v.as_str()) != Some(message_id.as_str()));
@ -370,6 +384,7 @@ pub async fn delete_message<R: Runtime>(
let data = serde_json::to_string(&msg).map_err(|e| e.to_string())?; let data = serde_json::to_string(&msg).map_err(|e| e.to_string())?;
writeln!(file, "{}", data).map_err(|e| e.to_string())?; writeln!(file, "{}", data).map_err(|e| e.to_string())?;
} }
}
Ok(()) Ok(())
} }
@ -441,7 +456,7 @@ pub async fn modify_thread_assistant<R: Runtime>(
serde_json::from_str(&data).map_err(|e| e.to_string())? serde_json::from_str(&data).map_err(|e| e.to_string())?
}; };
let assistant_id = assistant let assistant_id = assistant
.get("id") .get("assistant_id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.ok_or("Missing assistant_id")?; .ok_or("Missing assistant_id")?;
if let Some(assistants) = thread if let Some(assistants) = thread

View File

@ -216,33 +216,15 @@ export default function ModelHandler() {
model: activeModelRef.current?.name, model: activeModelRef.current?.name,
} }
}) })
return } else {
} else if (
message.status === MessageStatus.Error &&
activeModelRef.current?.engine &&
engines &&
isLocalEngine(engines, activeModelRef.current.engine)
) {
extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.isModelLoaded(activeModelRef.current?.id as string)
.then((isLoaded) => {
if (!isLoaded) {
setActiveModel(undefined)
setStateModel({
state: 'start',
loading: false,
model: undefined,
})
}
})
}
// Mark the thread as not waiting for response // Mark the thread as not waiting for response
updateThreadWaiting(message.thread_id, false) updateThreadWaiting(message.thread_id, false)
setIsGeneratingResponse(false) setIsGeneratingResponse(false)
const thread = threadsRef.current?.find((e) => e.id == message.thread_id) const thread = threadsRef.current?.find(
(e) => e.id == message.thread_id
)
if (!thread) return if (!thread) return
const messageContent = message.content[0]?.text?.value const messageContent = message.content[0]?.text?.value
@ -278,6 +260,13 @@ export default function ModelHandler() {
error: message.content[0]?.text?.value, error: message.content[0]?.text?.value,
error_code: message.error_code, error_code: message.error_code,
} }
// Unassign active model if any
setActiveModel(undefined)
setStateModel({
state: 'start',
loading: false,
model: undefined,
})
} }
extensionManager extensionManager
@ -286,6 +275,7 @@ export default function ModelHandler() {
// Attempt to generate the title of the Thread when needed // Attempt to generate the title of the Thread when needed
generateThreadTitle(message, thread) generateThreadTitle(message, thread)
}
}, },
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
[setIsGeneratingResponse, updateMessage, updateThread, updateThreadWaiting] [setIsGeneratingResponse, updateMessage, updateThread, updateThreadWaiting]

View File

@ -38,12 +38,13 @@ export default function useDeleteThread() {
?.listMessages(threadId) ?.listMessages(threadId)
.catch(console.error) .catch(console.error)
if (messages) { if (messages) {
messages.forEach((message) => { for (const message of messages) {
extensionManager await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.deleteMessage(threadId, message.id) ?.deleteMessage(threadId, message.id)
.catch(console.error) .catch(console.error)
}) }
const thread = threads.find((e) => e.id === threadId) const thread = threads.find((e) => e.id === threadId)
if (thread) { if (thread) {
const updatedThread = { const updatedThread = {

View File

@ -135,12 +135,16 @@ export default function useSendChatMessage() {
) => { ) => {
if (!message || message.trim().length === 0) return if (!message || message.trim().length === 0) return
if (!activeThreadRef.current || !activeAssistantRef.current) { const activeThread = activeThreadRef.current
const activeAssistant = activeAssistantRef.current
const activeModel = selectedModelRef.current
if (!activeThread || !activeAssistant) {
console.error('No active thread or assistant') console.error('No active thread or assistant')
return return
} }
if (selectedModelRef.current?.id === undefined) { if (!activeModel?.id) {
setModelDropdownState(true) setModelDropdownState(true)
return return
} }
@ -153,7 +157,7 @@ export default function useSendChatMessage() {
const prompt = message.trim() const prompt = message.trim()
updateThreadWaiting(activeThreadRef.current.id, true) updateThreadWaiting(activeThread.id, true)
setCurrentPrompt('') setCurrentPrompt('')
setEditPrompt('') setEditPrompt('')
@ -164,15 +168,14 @@ export default function useSendChatMessage() {
base64Blob = await compressImage(base64Blob, 512) base64Blob = await compressImage(base64Blob, 512)
} }
const modelRequest = const modelRequest = selectedModel ?? activeAssistant.model
selectedModelRef?.current ?? activeAssistantRef.current?.model
// Fallback support for previous broken threads // Fallback support for previous broken threads
if (activeAssistantRef.current?.model?.id === '*') { if (activeAssistant.model?.id === '*') {
activeAssistantRef.current.model = { activeAssistant.model = {
id: modelRequest.id, id: activeModel.id,
settings: modelRequest.settings, settings: activeModel.settings,
parameters: modelRequest.parameters, parameters: activeModel.parameters,
} }
} }
if (runtimeParams.stream == null) { if (runtimeParams.stream == null) {
@ -187,7 +190,7 @@ export default function useSendChatMessage() {
settings: settingParams, settings: settingParams,
parameters: runtimeParams, parameters: runtimeParams,
}, },
activeThreadRef.current, activeThread,
messages ?? currentMessages, messages ?? currentMessages,
(await window.core.api.getTools())?.map((tool: ModelTool) => ({ (await window.core.api.getTools())?.map((tool: ModelTool) => ({
type: 'function' as const, type: 'function' as const,
@ -198,7 +201,7 @@ export default function useSendChatMessage() {
strict: false, strict: false,
}, },
})) }))
).addSystemMessage(activeAssistantRef.current?.instructions) ).addSystemMessage(activeAssistant.instructions)
requestBuilder.pushMessage(prompt, base64Blob, fileUpload) requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
@ -211,10 +214,10 @@ export default function useSendChatMessage() {
// Update thread state // Update thread state
const updatedThread: Thread = { const updatedThread: Thread = {
...activeThreadRef.current, ...activeThread,
updated: newMessage.created_at, updated: newMessage.created_at,
metadata: { metadata: {
...activeThreadRef.current.metadata, ...activeThread.metadata,
lastMessage: prompt, lastMessage: prompt,
}, },
} }
@ -237,17 +240,16 @@ export default function useSendChatMessage() {
} }
// Start Model if not started // Start Model if not started
const modelId = const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id
selectedModelRef.current?.id ?? activeAssistantRef.current?.model.id
if (base64Blob) { if (base64Blob) {
setFileUpload(undefined) setFileUpload(undefined)
} }
if (modelRef.current?.id !== modelId && modelId) { if (activeModel?.id !== modelId && modelId) {
const error = await startModel(modelId).catch((error: Error) => error) const error = await startModel(modelId).catch((error: Error) => error)
if (error) { if (error) {
updateThreadWaiting(activeThreadRef.current.id, false) updateThreadWaiting(activeThread.id, false)
return return
} }
} }
@ -271,8 +273,8 @@ export default function useSendChatMessage() {
const message: ThreadMessage = { const message: ThreadMessage = {
id: messageId, id: messageId,
object: 'message', object: 'message',
thread_id: activeThreadRef.current.id, thread_id: activeThread.id,
assistant_id: activeAssistantRef.current.assistant_id, assistant_id: activeAssistant.assistant_id,
role: ChatCompletionRole.Assistant, role: ChatCompletionRole.Assistant,
content: [], content: [],
metadata: { metadata: {
@ -317,6 +319,8 @@ export default function useSendChatMessage() {
message message
) )
} }
message.status = MessageStatus.Ready
events.emit(MessageEvent.OnMessageUpdate, message)
} }
} else { } else {
// Request for inference // Request for inference
@ -504,8 +508,6 @@ export default function useSendChatMessage() {
events.emit(MessageEvent.OnMessageUpdate, message) events.emit(MessageEvent.OnMessageUpdate, message)
} }
} }
message.status = MessageStatus.Ready
events.emit(MessageEvent.OnMessageUpdate, message)
} }
return { return {

View File

@ -46,7 +46,7 @@ const ToolCallBlock = ({ id, name, result, loading }: Props) => {
{isExpanded && ( {isExpanded && (
<div className="mt-2 overflow-x-hidden pl-6 text-[hsla(var(--text-secondary))]"> <div className="mt-2 overflow-x-hidden pl-6 text-[hsla(var(--text-secondary))]">
<span>{result.trim()} </span> <span>{result ?? ''} </span>
</div> </div>
)} )}
</div> </div>