diff --git a/src-tauri/src/core/threads.rs b/src-tauri/src/core/threads.rs index e4b8f2e52..34b0bbd77 100644 --- a/src-tauri/src/core/threads.rs +++ b/src-tauri/src/core/threads.rs @@ -278,8 +278,10 @@ pub async fn create_message( ensure_thread_dir_exists(app_handle.clone(), &thread_id)?; let path = get_messages_path(app_handle.clone(), &thread_id); - let uuid = Uuid::new_v4().to_string(); - message["id"] = serde_json::Value::String(uuid); + if message.get("id").is_none() { + let uuid = Uuid::new_v4().to_string(); + message["id"] = serde_json::Value::String(uuid); + } // Acquire per-thread lock before writing { @@ -292,7 +294,7 @@ pub async fn create_message( let _guard = lock.lock().await; - let mut file = fs::OpenOptions::new() + let mut file: File = fs::OpenOptions::new() .create(true) .append(true) .open(path) @@ -354,21 +356,34 @@ pub async fn modify_message( /// Deletes a message from a thread's messages.jsonl file by message ID. /// Rewrites the entire messages.jsonl file for the thread. +/// Uses a per-thread async lock to prevent race conditions and ensure file consistency. #[command] pub async fn delete_message( app_handle: tauri::AppHandle, thread_id: String, message_id: String, ) -> Result<(), String> { - let mut messages = list_messages(app_handle.clone(), thread_id.clone()).await?; - messages.retain(|m| m.get("id").and_then(|v| v.as_str()) != Some(message_id.as_str())); + // Acquire per-thread lock before modifying + { + let mut locks = MESSAGE_LOCKS.lock().await; + let lock = locks + .entry(thread_id.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone(); + drop(locks); // Release the map lock before awaiting the file lock - // Rewrite remaining messages - let path = get_messages_path(app_handle.clone(), &thread_id); - let mut file = File::create(path).map_err(|e| e.to_string())?; - for msg in messages { - let data = serde_json::to_string(&msg).map_err(|e| e.to_string())?; - writeln!(file, "{}", data).map_err(|e| e.to_string())?; + let _guard = lock.lock().await; + + let mut messages = list_messages(app_handle.clone(), thread_id.clone()).await?; + messages.retain(|m| m.get("id").and_then(|v| v.as_str()) != Some(message_id.as_str())); + + // Rewrite remaining messages + let path = get_messages_path(app_handle.clone(), &thread_id); + let mut file = File::create(path).map_err(|e| e.to_string())?; + for msg in messages { + let data = serde_json::to_string(&msg).map_err(|e| e.to_string())?; + writeln!(file, "{}", data).map_err(|e| e.to_string())?; + } } Ok(()) @@ -441,7 +456,7 @@ pub async fn modify_thread_assistant( serde_json::from_str(&data).map_err(|e| e.to_string())? }; let assistant_id = assistant - .get("id") + .get("assistant_id") .and_then(|v| v.as_str()) .ok_or("Missing assistant_id")?; if let Some(assistants) = thread diff --git a/web/containers/Providers/ModelHandler.tsx b/web/containers/Providers/ModelHandler.tsx index a1fffa011..9590e5048 100644 --- a/web/containers/Providers/ModelHandler.tsx +++ b/web/containers/Providers/ModelHandler.tsx @@ -216,76 +216,66 @@ export default function ModelHandler() { model: activeModelRef.current?.name, } }) - return - } else if ( - message.status === MessageStatus.Error && - activeModelRef.current?.engine && - engines && - isLocalEngine(engines, activeModelRef.current.engine) - ) { - extensionManager - .get(ExtensionTypeEnum.Model) - ?.isModelLoaded(activeModelRef.current?.id as string) - .then((isLoaded) => { - if (!isLoaded) { - setActiveModel(undefined) - setStateModel({ - state: 'start', - loading: false, - model: undefined, - }) - } - }) - } - // Mark the thread as not waiting for response - updateThreadWaiting(message.thread_id, false) + } else { + // Mark the thread as not waiting for response + updateThreadWaiting(message.thread_id, false) - setIsGeneratingResponse(false) + setIsGeneratingResponse(false) - const thread = threadsRef.current?.find((e) => e.id == message.thread_id) - if (!thread) return + const thread = threadsRef.current?.find( + (e) => e.id == message.thread_id + ) + if (!thread) return - const messageContent = message.content[0]?.text?.value + const messageContent = message.content[0]?.text?.value - const metadata = { - ...thread.metadata, - ...(messageContent && { lastMessage: messageContent }), - updated_at: Date.now(), - } + const metadata = { + ...thread.metadata, + ...(messageContent && { lastMessage: messageContent }), + updated_at: Date.now(), + } - updateThread({ - ...thread, - metadata, - }) - - extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.modifyThread({ + updateThread({ ...thread, metadata, }) - // Update message's metadata with token usage - message.metadata = { - ...message.metadata, - token_speed: tokenSpeedRef.current?.tokenSpeed, - model: activeModelRef.current?.name, - } + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.modifyThread({ + ...thread, + metadata, + }) - if (message.status === MessageStatus.Error) { + // Update message's metadata with token usage message.metadata = { ...message.metadata, - error: message.content[0]?.text?.value, - error_code: message.error_code, + token_speed: tokenSpeedRef.current?.tokenSpeed, + model: activeModelRef.current?.name, } + + if (message.status === MessageStatus.Error) { + message.metadata = { + ...message.metadata, + error: message.content[0]?.text?.value, + error_code: message.error_code, + } + // Unassign active model if any + setActiveModel(undefined) + setStateModel({ + state: 'start', + loading: false, + model: undefined, + }) + } + + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.createMessage(message) + + // Attempt to generate the title of the Thread when needed + generateThreadTitle(message, thread) } - - extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.createMessage(message) - - // Attempt to generate the title of the Thread when needed - generateThreadTitle(message, thread) }, // eslint-disable-next-line react-hooks/exhaustive-deps [setIsGeneratingResponse, updateMessage, updateThread, updateThreadWaiting] diff --git a/web/hooks/useDeleteThread.ts b/web/hooks/useDeleteThread.ts index 59aa3a83b..d0c7cac1a 100644 --- a/web/hooks/useDeleteThread.ts +++ b/web/hooks/useDeleteThread.ts @@ -38,12 +38,13 @@ export default function useDeleteThread() { ?.listMessages(threadId) .catch(console.error) if (messages) { - messages.forEach((message) => { - extensionManager + for (const message of messages) { + await extensionManager .get(ExtensionTypeEnum.Conversational) ?.deleteMessage(threadId, message.id) .catch(console.error) - }) + } + const thread = threads.find((e) => e.id === threadId) if (thread) { const updatedThread = { diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index bfe3a601f..49e0d3e5b 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -135,12 +135,16 @@ export default function useSendChatMessage() { ) => { if (!message || message.trim().length === 0) return - if (!activeThreadRef.current || !activeAssistantRef.current) { + const activeThread = activeThreadRef.current + const activeAssistant = activeAssistantRef.current + const activeModel = selectedModelRef.current + + if (!activeThread || !activeAssistant) { console.error('No active thread or assistant') return } - if (selectedModelRef.current?.id === undefined) { + if (!activeModel?.id) { setModelDropdownState(true) return } @@ -153,7 +157,7 @@ export default function useSendChatMessage() { const prompt = message.trim() - updateThreadWaiting(activeThreadRef.current.id, true) + updateThreadWaiting(activeThread.id, true) setCurrentPrompt('') setEditPrompt('') @@ -164,15 +168,14 @@ export default function useSendChatMessage() { base64Blob = await compressImage(base64Blob, 512) } - const modelRequest = - selectedModelRef?.current ?? activeAssistantRef.current?.model + const modelRequest = selectedModel ?? activeAssistant.model // Fallback support for previous broken threads - if (activeAssistantRef.current?.model?.id === '*') { - activeAssistantRef.current.model = { - id: modelRequest.id, - settings: modelRequest.settings, - parameters: modelRequest.parameters, + if (activeAssistant.model?.id === '*') { + activeAssistant.model = { + id: activeModel.id, + settings: activeModel.settings, + parameters: activeModel.parameters, } } if (runtimeParams.stream == null) { @@ -187,7 +190,7 @@ export default function useSendChatMessage() { settings: settingParams, parameters: runtimeParams, }, - activeThreadRef.current, + activeThread, messages ?? currentMessages, (await window.core.api.getTools())?.map((tool: ModelTool) => ({ type: 'function' as const, @@ -198,7 +201,7 @@ export default function useSendChatMessage() { strict: false, }, })) - ).addSystemMessage(activeAssistantRef.current?.instructions) + ).addSystemMessage(activeAssistant.instructions) requestBuilder.pushMessage(prompt, base64Blob, fileUpload) @@ -211,10 +214,10 @@ export default function useSendChatMessage() { // Update thread state const updatedThread: Thread = { - ...activeThreadRef.current, + ...activeThread, updated: newMessage.created_at, metadata: { - ...activeThreadRef.current.metadata, + ...activeThread.metadata, lastMessage: prompt, }, } @@ -237,17 +240,16 @@ export default function useSendChatMessage() { } // Start Model if not started - const modelId = - selectedModelRef.current?.id ?? activeAssistantRef.current?.model.id + const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id if (base64Blob) { setFileUpload(undefined) } - if (modelRef.current?.id !== modelId && modelId) { + if (activeModel?.id !== modelId && modelId) { const error = await startModel(modelId).catch((error: Error) => error) if (error) { - updateThreadWaiting(activeThreadRef.current.id, false) + updateThreadWaiting(activeThread.id, false) return } } @@ -271,8 +273,8 @@ export default function useSendChatMessage() { const message: ThreadMessage = { id: messageId, object: 'message', - thread_id: activeThreadRef.current.id, - assistant_id: activeAssistantRef.current.assistant_id, + thread_id: activeThread.id, + assistant_id: activeAssistant.assistant_id, role: ChatCompletionRole.Assistant, content: [], metadata: { @@ -317,6 +319,8 @@ export default function useSendChatMessage() { message ) } + message.status = MessageStatus.Ready + events.emit(MessageEvent.OnMessageUpdate, message) } } else { // Request for inference @@ -504,8 +508,6 @@ export default function useSendChatMessage() { events.emit(MessageEvent.OnMessageUpdate, message) } } - message.status = MessageStatus.Ready - events.emit(MessageEvent.OnMessageUpdate, message) } return { diff --git a/web/screens/Thread/ThreadCenterPanel/TextMessage/ToolCallBlock.tsx b/web/screens/Thread/ThreadCenterPanel/TextMessage/ToolCallBlock.tsx index 3ebfba7d3..818af5b4d 100644 --- a/web/screens/Thread/ThreadCenterPanel/TextMessage/ToolCallBlock.tsx +++ b/web/screens/Thread/ThreadCenterPanel/TextMessage/ToolCallBlock.tsx @@ -46,7 +46,7 @@ const ToolCallBlock = ({ id, name, result, loading }: Props) => { {isExpanded && (
- {result.trim()} + {result ?? ''}
)}