diff --git a/core/src/types/inference/inferenceEntity.ts b/core/src/types/inference/inferenceEntity.ts index c37e3b079..ac2e48d32 100644 --- a/core/src/types/inference/inferenceEntity.ts +++ b/core/src/types/inference/inferenceEntity.ts @@ -7,6 +7,7 @@ export enum ChatCompletionRole { System = 'system', Assistant = 'assistant', User = 'user', + Tool = 'tool', } /** @@ -18,6 +19,9 @@ export type ChatCompletionMessage = { content?: ChatCompletionMessageContent /** The role of the author of this message. **/ role: ChatCompletionRole + type?: string + output?: string + tool_call_id?: string } export type ChatCompletionMessageContent = diff --git a/core/src/types/message/messageEntity.ts b/core/src/types/message/messageEntity.ts index 280ce75a3..20979c68e 100644 --- a/core/src/types/message/messageEntity.ts +++ b/core/src/types/message/messageEntity.ts @@ -36,6 +36,8 @@ export type ThreadMessage = { type?: string /** The error code which explain what error type. Used in conjunction with MessageStatus.Error */ error_code?: ErrorCode + + tool_call_id?: string } /** diff --git a/extensions/conversational-extension/package.json b/extensions/conversational-extension/package.json index a5224b99b..693adf6d6 100644 --- a/extensions/conversational-extension/package.json +++ b/extensions/conversational-extension/package.json @@ -23,9 +23,7 @@ "typescript": "^5.7.2" }, "dependencies": { - "@janhq/core": "../../core/package.tgz", - "ky": "^1.7.2", - "p-queue": "^8.0.1" + "@janhq/core": "../../core/package.tgz" }, "engines": { "node": ">=18.0.0" diff --git a/extensions/conversational-extension/src/index.ts b/extensions/conversational-extension/src/index.ts index e2e068939..720291d88 100644 --- a/extensions/conversational-extension/src/index.ts +++ b/extensions/conversational-extension/src/index.ts @@ -4,40 +4,12 @@ import { ThreadAssistantInfo, ThreadMessage, } from '@janhq/core' -import ky, { KyInstance } from 'ky' - -type ThreadList = { - data: Thread[] -} - -type MessageList = { - data: ThreadMessage[] -} /** * JSONConversationalExtension is a ConversationalExtension implementation that provides * functionality for managing threads. */ export default class CortexConversationalExtension extends ConversationalExtension { - api?: KyInstance - /** - * Get the API instance - * @returns - */ - async apiInstance(): Promise { - if (this.api) return this.api - const apiKey = (await window.core?.api.appToken()) - this.api = ky.extend({ - prefixUrl: API_URL, - headers: apiKey - ? { - Authorization: `Bearer ${apiKey}`, - } - : {}, - retry: 10, - }) - return this.api - } /** * Called when the extension is loaded. */ @@ -54,12 +26,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * Returns a Promise that resolves to an array of Conversation objects. */ async listThreads(): Promise { - return this.apiInstance().then((api) => - api - .get('v1/threads?limit=-1') - .json() - .then((e) => e.data) - ) as Promise + return window.core.api.listThreads() } /** @@ -67,9 +34,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @param thread The Thread object to save. */ async createThread(thread: Thread): Promise { - return this.apiInstance().then((api) => - api.post('v1/threads', { json: thread }).json() - ) as Promise + return window.core.api.createThread({ thread }) } /** @@ -77,10 +42,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @param thread The Thread object to save. */ async modifyThread(thread: Thread): Promise { - return this.apiInstance() - .then((api) => api.patch(`v1/threads/${thread.id}`, { json: thread })) - - .then() + return window.core.api.modifyThread({ thread }) } /** @@ -88,9 +50,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @param threadId The ID of the thread to delete. */ async deleteThread(threadId: string): Promise { - return this.apiInstance() - .then((api) => api.delete(`v1/threads/${threadId}`)) - .then() + return window.core.api.deleteThread({ threadId }) } /** @@ -99,13 +59,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @returns A Promise that resolves when the message has been added. */ async createMessage(message: ThreadMessage): Promise { - return this.apiInstance().then((api) => - api - .post(`v1/threads/${message.thread_id}/messages`, { - json: message, - }) - .json() - ) as Promise + return window.core.api.createMessage({ message }) } /** @@ -114,13 +68,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @returns */ async modifyMessage(message: ThreadMessage): Promise { - return this.apiInstance().then((api) => - api - .patch(`v1/threads/${message.thread_id}/messages/${message.id}`, { - json: message, - }) - .json() - ) as Promise + return window.core.api.modifyMessage({ message }) } /** @@ -130,9 +78,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @returns A Promise that resolves when the message has been successfully deleted. */ async deleteMessage(threadId: string, messageId: string): Promise { - return this.apiInstance() - .then((api) => api.delete(`v1/threads/${threadId}/messages/${messageId}`)) - .then() + return window.core.api.deleteMessage({ threadId, messageId }) } /** @@ -141,12 +87,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @returns A Promise that resolves to an array of ThreadMessage objects. */ async listMessages(threadId: string): Promise { - return this.apiInstance().then((api) => - api - .get(`v1/threads/${threadId}/messages?order=asc&limit=-1`) - .json() - .then((e) => e.data) - ) as Promise + return window.core.api.listMessages({ threadId }) } /** @@ -156,9 +97,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * the details of the assistant associated with the specified thread. */ async getThreadAssistant(threadId: string): Promise { - return this.apiInstance().then((api) => - api.get(`v1/assistants/${threadId}?limit=-1`).json() - ) as Promise + return window.core.api.getThreadAssistant({ threadId }) } /** * Creates a new assistant for the specified thread. @@ -170,11 +109,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi threadId: string, assistant: ThreadAssistantInfo ): Promise { - return this.apiInstance().then((api) => - api - .post(`v1/assistants/${threadId}`, { json: assistant }) - .json() - ) as Promise + return window.core.api.createThreadAssistant(threadId, assistant) } /** @@ -187,10 +122,6 @@ export default class CortexConversationalExtension extends ConversationalExtensi threadId: string, assistant: ThreadAssistantInfo ): Promise { - return this.apiInstance().then((api) => - api - .patch(`v1/assistants/${threadId}`, { json: assistant }) - .json() - ) as Promise + return window.core.api.modifyThreadAssistant({ threadId, assistant }) } } diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 0c2e9236e..3e6fae702 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -38,6 +38,8 @@ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "mai "transport-child-process", "tower", ] } +uuid = { version = "1.7", features = ["v4"] } [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] tauri-plugin-updater = "2" +once_cell = "1.18" diff --git a/src-tauri/src/core/cmd.rs b/src-tauri/src/core/cmd.rs index 0fe706a1a..3d7d921ee 100644 --- a/src-tauri/src/core/cmd.rs +++ b/src-tauri/src/core/cmd.rs @@ -7,6 +7,9 @@ use tauri::{AppHandle, Manager, Runtime, State}; use super::{server, setup, state::AppState}; const CONFIGURATION_FILE_NAME: &str = "settings.json"; +const DEFAULT_MCP_CONFIG: &str = r#"{ + "mcpServers": {} +}"#; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct AppConfiguration { @@ -93,6 +96,10 @@ pub fn update_app_configuration( #[tauri::command] pub fn get_jan_data_folder_path(app_handle: tauri::AppHandle) -> PathBuf { + if cfg!(test) { + return PathBuf::from("./data"); + } + let app_configurations = get_app_configurations(app_handle); PathBuf::from(app_configurations.data_folder) } @@ -348,3 +355,29 @@ pub async fn call_tool( Err(format!("Tool {} not found", tool_name)) } + +#[tauri::command] +pub async fn get_mcp_configs(app: AppHandle) -> Result { + let mut path = get_jan_data_folder_path(app); + path.push("mcp_config.json"); + log::info!("read mcp configs, path: {:?}", path); + + // Create default empty config if file doesn't exist + if !path.exists() { + log::info!("mcp_config.json not found, creating default empty config"); + fs::write(&path, DEFAULT_MCP_CONFIG) + .map_err(|e| format!("Failed to create default MCP config: {}", e))?; + } + + let contents = fs::read_to_string(path).map_err(|e| e.to_string())?; + return Ok(contents); +} + +#[tauri::command] +pub async fn save_mcp_configs(app: AppHandle, configs: String) -> Result<(), String> { + let mut path = get_jan_data_folder_path(app); + path.push("mcp_config.json"); + log::info!("save mcp configs, path: {:?}", path); + + fs::write(path, configs).map_err(|e| e.to_string()) +} diff --git a/src-tauri/src/core/fs.rs b/src-tauri/src/core/fs.rs index 9e77a812c..c0d7d423d 100644 --- a/src-tauri/src/core/fs.rs +++ b/src-tauri/src/core/fs.rs @@ -107,6 +107,7 @@ mod tests { use super::*; use std::fs::{self, File}; use std::io::Write; + use serde_json::to_string; use tauri::test::mock_app; #[test] @@ -154,9 +155,11 @@ mod tests { fn test_exists_sync() { let app = mock_app(); let path = "file://test_exists_sync_file"; - let file_path = get_jan_data_folder_path(app.handle().clone()).join(path); + let dir_path = get_jan_data_folder_path(app.handle().clone()); + fs::create_dir_all(&dir_path).unwrap(); + let file_path = dir_path.join("test_exists_sync_file"); File::create(&file_path).unwrap(); - let args = vec![path.to_string()]; + let args: Vec = vec![path.to_string()]; let result = exists_sync(app.handle().clone(), args).unwrap(); assert!(result); fs::remove_file(file_path).unwrap(); @@ -166,7 +169,9 @@ mod tests { fn test_read_file_sync() { let app = mock_app(); let path = "file://test_read_file_sync_file"; - let file_path = get_jan_data_folder_path(app.handle().clone()).join(path); + let dir_path = get_jan_data_folder_path(app.handle().clone()); + fs::create_dir_all(&dir_path).unwrap(); + let file_path = dir_path.join("test_read_file_sync_file"); let mut file = File::create(&file_path).unwrap(); file.write_all(b"test content").unwrap(); let args = vec![path.to_string()]; @@ -184,7 +189,7 @@ mod tests { File::create(dir_path.join("file1.txt")).unwrap(); File::create(dir_path.join("file2.txt")).unwrap(); - let args = vec![path.to_string()]; + let args = vec![dir_path.to_string_lossy().to_string()]; let result = readdir_sync(app.handle().clone(), args).unwrap(); assert_eq!(result.len(), 2); diff --git a/src-tauri/src/core/mcp.rs b/src-tauri/src/core/mcp.rs index f2109618a..06bbdcbb2 100644 --- a/src-tauri/src/core/mcp.rs +++ b/src-tauri/src/core/mcp.rs @@ -2,8 +2,11 @@ use std::{collections::HashMap, sync::Arc}; use rmcp::{service::RunningService, transport::TokioChildProcess, RoleClient, ServiceExt}; use serde_json::Value; +use tauri::{AppHandle, State}; use tokio::{process::Command, sync::Mutex}; +use super::{cmd::get_jan_data_folder_path, state::AppState}; + /// Runs MCP commands by reading configuration from a JSON file and initializing servers /// /// # Arguments @@ -77,6 +80,35 @@ fn extract_command_args( Some((command, args, envs)) } +#[tauri::command] +pub async fn restart_mcp_servers( + app: AppHandle, + state: State<'_, AppState>, +) -> Result<(), String> { + let app_path = get_jan_data_folder_path(app.clone()); + let app_path_str = app_path.to_str().unwrap().to_string(); + let servers = state.mcp_servers.clone(); + // Stop the servers + stop_mcp_servers(state.mcp_servers.clone()).await?; + + // Restart the servers + run_mcp_commands(app_path_str, servers).await +} + +pub async fn stop_mcp_servers( + servers_state: Arc>>>, +) -> Result<(), String> { + let mut servers_map = servers_state.lock().await; + let keys: Vec = servers_map.keys().cloned().collect(); + for key in keys { + if let Some(service) = servers_map.remove(&key) { + service.cancel().await.map_err(|e| e.to_string())?; + } + } + drop(servers_map); // Release the lock after stopping + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src-tauri/src/core/mod.rs b/src-tauri/src/core/mod.rs index e4f0ee6c4..8d4edde3c 100644 --- a/src-tauri/src/core/mod.rs +++ b/src-tauri/src/core/mod.rs @@ -4,3 +4,5 @@ pub mod mcp; pub mod server; pub mod setup; pub mod state; +pub mod threads; +pub mod utils; \ No newline at end of file diff --git a/src-tauri/src/core/threads.rs b/src-tauri/src/core/threads.rs new file mode 100644 index 000000000..051837992 --- /dev/null +++ b/src-tauri/src/core/threads.rs @@ -0,0 +1,613 @@ +/*! + Thread and Message Persistence Module + + This module provides all logic for managing threads and their messages, including creation, modification, deletion, and listing. + Messages for each thread are persisted in a JSONL file (messages.jsonl) per thread directory. + + **Concurrency and Consistency Guarantee:** + - All operations that write or modify messages for a thread are protected by a global, per-thread asynchronous lock. + - This design ensures that only one operation can write to a thread's messages.jsonl file at a time, preventing race conditions. + - As a result, the messages.jsonl file for each thread is always consistent and never corrupted, even under concurrent access. +*/ + +use serde::{Deserialize, Serialize}; +use std::fs::{self, File}; +use std::io::{BufRead, BufReader, Write}; +use tauri::command; +use tauri::Runtime; +use uuid::Uuid; + +// For async file write serialization +use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; + +// Global per-thread locks for message file writes +static MESSAGE_LOCKS: Lazy>>>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +use super::utils::{ + ensure_data_dirs, ensure_thread_dir_exists, get_data_dir, get_messages_path, get_thread_dir, + get_thread_metadata_path, THREADS_FILE, +}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Thread { + pub id: String, + pub object: String, + pub title: String, + pub assistants: Vec, + pub created: i64, + pub updated: i64, + pub metadata: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ThreadMessage { + pub id: String, + pub object: String, + pub thread_id: String, + pub assistant_id: Option, + pub attachments: Option>, + pub role: String, + pub content: Vec, + pub status: String, + pub created_at: i64, + pub completed_at: i64, + pub metadata: Option, + pub type_: Option, + pub error_code: Option, + pub tool_call_id: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Attachment { + pub file_id: Option, + pub tools: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum Tool { + #[serde(rename = "file_search")] + FileSearch, + #[serde(rename = "code_interpreter")] + CodeInterpreter, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ThreadContent { + pub type_: String, + pub text: Option, + pub image_url: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ContentValue { + pub value: String, + pub annotations: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ImageContentValue { + pub detail: Option, + pub url: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ThreadAssistantInfo { + pub assistant_id: String, + pub assistant_name: String, + pub model: ModelInfo, + pub instructions: Option, + pub tools: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ModelInfo { + pub id: String, + pub name: String, + pub settings: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum AssistantTool { + #[serde(rename = "code_interpreter")] + CodeInterpreter, + #[serde(rename = "retrieval")] + Retrieval, + #[serde(rename = "function")] + Function { + name: String, + description: Option, + parameters: Option, + }, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ThreadState { + pub has_more: bool, + pub waiting_for_response: bool, + pub error: Option, + pub last_message: Option, +} + +/// Lists all threads by reading their metadata from the threads directory. +/// Returns a vector of thread metadata as JSON values. +#[command] +pub async fn list_threads( + app_handle: tauri::AppHandle, +) -> Result, String> { + ensure_data_dirs(app_handle.clone())?; + let data_dir = get_data_dir(app_handle.clone()); + let mut threads = Vec::new(); + + if !data_dir.exists() { + return Ok(threads); + } + + for entry in fs::read_dir(&data_dir).map_err(|e| e.to_string())? { + let entry = entry.map_err(|e| e.to_string())?; + let path = entry.path(); + if path.is_dir() { + let thread_metadata_path = path.join(THREADS_FILE); + if thread_metadata_path.exists() { + let data = fs::read_to_string(&thread_metadata_path).map_err(|e| e.to_string())?; + match serde_json::from_str(&data) { + Ok(thread) => threads.push(thread), + Err(e) => { + println!("Failed to parse thread file: {}", e); + continue; // skip invalid thread files + } + } + } + } + } + + Ok(threads) +} + +/// Creates a new thread, assigns it a unique ID, and persists its metadata. +/// Ensures the thread directory exists and writes thread.json. +#[command] +pub async fn create_thread( + app_handle: tauri::AppHandle, + mut thread: serde_json::Value, +) -> Result { + ensure_data_dirs(app_handle.clone())?; + let uuid = Uuid::new_v4().to_string(); + thread["id"] = serde_json::Value::String(uuid.clone()); + let thread_dir = get_thread_dir(app_handle.clone(), &uuid); + if !thread_dir.exists() { + fs::create_dir_all(&thread_dir).map_err(|e| e.to_string())?; + } + let path = get_thread_metadata_path(app_handle.clone(), &uuid); + let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?; + fs::write(path, data).map_err(|e| e.to_string())?; + Ok(thread) +} + +/// Modifies an existing thread's metadata by overwriting its thread.json file. +/// Returns an error if the thread directory does not exist. +#[command] +pub async fn modify_thread( + app_handle: tauri::AppHandle, + thread: serde_json::Value, +) -> Result<(), String> { + let thread_id = thread + .get("id") + .and_then(|id| id.as_str()) + .ok_or("Missing thread id")?; + let thread_dir = get_thread_dir(app_handle.clone(), thread_id); + if !thread_dir.exists() { + return Err("Thread directory does not exist".to_string()); + } + let path = get_thread_metadata_path(app_handle.clone(), thread_id); + let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?; + fs::write(path, data).map_err(|e| e.to_string())?; + Ok(()) +} + +/// Deletes a thread and all its associated files by removing its directory. +#[command] +pub async fn delete_thread( + app_handle: tauri::AppHandle, + thread_id: String, +) -> Result<(), String> { + let thread_dir = get_thread_dir(app_handle.clone(), &thread_id); + if thread_dir.exists() { + fs::remove_dir_all(thread_dir).map_err(|e| e.to_string())?; + } + Ok(()) +} + +/// Lists all messages for a given thread by reading and parsing its messages.jsonl file. +/// Returns a vector of message JSON values. +#[command] +pub async fn list_messages( + app_handle: tauri::AppHandle, + thread_id: String, +) -> Result, String> { + let path = get_messages_path(app_handle, &thread_id); + if !path.exists() { + return Ok(vec![]); + } + + let file = File::open(&path).map_err(|e| { + eprintln!("Error opening file {}: {}", path.display(), e); + e.to_string() + })?; + let reader = BufReader::new(file); + + let mut messages = Vec::new(); + for line in reader.lines() { + let line = line.map_err(|e| { + eprintln!("Error reading line from file {}: {}", path.display(), e); + e.to_string() + })?; + let message: serde_json::Value = serde_json::from_str(&line).map_err(|e| { + eprintln!( + "Error parsing JSON from line in file {}: {}", + path.display(), + e + ); + e.to_string() + })?; + messages.push(message); + } + + Ok(messages) +} + +/// Appends a new message to a thread's messages.jsonl file. +/// Uses a per-thread async lock to prevent race conditions and ensure file consistency. +#[command] +pub async fn create_message( + app_handle: tauri::AppHandle, + mut message: serde_json::Value, +) -> Result { + let thread_id = { + let id = message + .get("thread_id") + .and_then(|v| v.as_str()) + .ok_or("Missing thread_id")?; + id.to_string() + }; + ensure_thread_dir_exists(app_handle.clone(), &thread_id)?; + let path = get_messages_path(app_handle.clone(), &thread_id); + + if message.get("id").is_none() { + let uuid = Uuid::new_v4().to_string(); + message["id"] = serde_json::Value::String(uuid); + } + + // Acquire per-thread lock before writing + { + let mut locks = MESSAGE_LOCKS.lock().await; + let lock = locks + .entry(thread_id.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone(); + drop(locks); // Release the map lock before awaiting the file lock + + let _guard = lock.lock().await; + + let mut file: File = fs::OpenOptions::new() + .create(true) + .append(true) + .open(path) + .map_err(|e| e.to_string())?; + + let data = serde_json::to_string(&message).map_err(|e| e.to_string())?; + writeln!(file, "{}", data).map_err(|e| e.to_string())?; + } + + Ok(message) +} + +/// Modifies an existing message in a thread's messages.jsonl file. +/// Uses a per-thread async lock to prevent race conditions and ensure file consistency. +/// Rewrites the entire messages.jsonl file for the thread. +#[command] +pub async fn modify_message( + app_handle: tauri::AppHandle, + message: serde_json::Value, +) -> Result { + let thread_id = message + .get("thread_id") + .and_then(|v| v.as_str()) + .ok_or("Missing thread_id")?; + let message_id = message + .get("id") + .and_then(|v| v.as_str()) + .ok_or("Missing message id")?; + + // Acquire per-thread lock before modifying + { + let mut locks = MESSAGE_LOCKS.lock().await; + let lock = locks + .entry(thread_id.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone(); + drop(locks); // Release the map lock before awaiting the file lock + + let _guard = lock.lock().await; + + let mut messages = list_messages(app_handle.clone(), thread_id.to_string()).await?; + if let Some(index) = messages + .iter() + .position(|m| m.get("id").and_then(|v| v.as_str()) == Some(message_id)) + { + messages[index] = message.clone(); + + // Rewrite all messages + let path = get_messages_path(app_handle.clone(), thread_id); + let mut file = File::create(path).map_err(|e| e.to_string())?; + for msg in messages { + let data = serde_json::to_string(&msg).map_err(|e| e.to_string())?; + writeln!(file, "{}", data).map_err(|e| e.to_string())?; + } + } + } + Ok(message) +} + +/// Deletes a message from a thread's messages.jsonl file by message ID. +/// Rewrites the entire messages.jsonl file for the thread. +/// Uses a per-thread async lock to prevent race conditions and ensure file consistency. +#[command] +pub async fn delete_message( + app_handle: tauri::AppHandle, + thread_id: String, + message_id: String, +) -> Result<(), String> { + // Acquire per-thread lock before modifying + { + let mut locks = MESSAGE_LOCKS.lock().await; + let lock = locks + .entry(thread_id.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone(); + drop(locks); // Release the map lock before awaiting the file lock + + let _guard = lock.lock().await; + + let mut messages = list_messages(app_handle.clone(), thread_id.clone()).await?; + messages.retain(|m| m.get("id").and_then(|v| v.as_str()) != Some(message_id.as_str())); + + // Rewrite remaining messages + let path = get_messages_path(app_handle.clone(), &thread_id); + let mut file = File::create(path).map_err(|e| e.to_string())?; + for msg in messages { + let data = serde_json::to_string(&msg).map_err(|e| e.to_string())?; + writeln!(file, "{}", data).map_err(|e| e.to_string())?; + } + } + + Ok(()) +} + +/// Retrieves the first assistant associated with a thread. +/// Returns an error if the thread or assistant is not found. +#[command] +pub async fn get_thread_assistant( + app_handle: tauri::AppHandle, + thread_id: String, +) -> Result { + let path = get_thread_metadata_path(app_handle, &thread_id); + if !path.exists() { + return Err("Thread not found".to_string()); + } + let data = fs::read_to_string(&path).map_err(|e| e.to_string())?; + let thread: serde_json::Value = serde_json::from_str(&data).map_err(|e| e.to_string())?; + if let Some(assistants) = thread.get("assistants").and_then(|a| a.as_array()) { + if let Some(first) = assistants.get(0) { + Ok(first.clone()) + } else { + Err("Assistant not found".to_string()) + } + } else { + Err("Assistant not found".to_string()) + } +} + +/// Adds a new assistant to a thread's metadata. +/// Updates thread.json with the new assistant information. +#[command] +pub async fn create_thread_assistant( + app_handle: tauri::AppHandle, + thread_id: String, + assistant: serde_json::Value, +) -> Result { + let path = get_thread_metadata_path(app_handle.clone(), &thread_id); + if !path.exists() { + return Err("Thread not found".to_string()); + } + let mut thread: serde_json::Value = { + let data = fs::read_to_string(&path).map_err(|e| e.to_string())?; + serde_json::from_str(&data).map_err(|e| e.to_string())? + }; + if let Some(assistants) = thread.get_mut("assistants").and_then(|a| a.as_array_mut()) { + assistants.push(assistant.clone()); + } else { + thread["assistants"] = serde_json::Value::Array(vec![assistant.clone()]); + } + let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?; + fs::write(&path, data).map_err(|e| e.to_string())?; + Ok(assistant) +} + +/// Modifies an existing assistant's information in a thread's metadata. +/// Updates thread.json with the modified assistant data. +#[command] +pub async fn modify_thread_assistant( + app_handle: tauri::AppHandle, + thread_id: String, + assistant: serde_json::Value, +) -> Result { + let path = get_thread_metadata_path(app_handle.clone(), &thread_id); + if !path.exists() { + return Err("Thread not found".to_string()); + } + let mut thread: serde_json::Value = { + let data = fs::read_to_string(&path).map_err(|e| e.to_string())?; + serde_json::from_str(&data).map_err(|e| e.to_string())? + }; + let assistant_id = assistant + .get("assistant_id") + .and_then(|v| v.as_str()) + .ok_or("Missing assistant_id")?; + if let Some(assistants) = thread + .get_mut("assistants") + .and_then(|a: &mut serde_json::Value| a.as_array_mut()) + { + if let Some(index) = assistants + .iter() + .position(|a| a.get("assistant_id").and_then(|v| v.as_str()) == Some(assistant_id)) + { + assistants[index] = assistant.clone(); + let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?; + fs::write(&path, data).map_err(|e| e.to_string())?; + } + } + Ok(assistant) +} + +#[cfg(test)] +mod tests { + use crate::core::cmd::get_jan_data_folder_path; + + use super::*; + use serde_json::json; + use std::fs; + use std::path::PathBuf; + use tauri::test::{mock_app, MockRuntime}; + + // Helper to create a mock app handle with a temp data dir + fn mock_app_with_temp_data_dir() -> (tauri::App, PathBuf) { + let app = mock_app(); + let data_dir = get_jan_data_folder_path(app.handle().clone()); + println!("Mock app data dir: {}", data_dir.display()); + // Patch get_data_dir to use temp dir (requires get_data_dir to be overridable or injectable) + // For now, we assume get_data_dir uses tauri::api::path::app_data_dir(&app_handle) + // and that we can set the environment variable to redirect it. + (app, data_dir) + } + + #[tokio::test] + async fn test_create_and_list_threads() { + let (app, data_dir) = mock_app_with_temp_data_dir(); + // Create a thread + let thread = json!({ + "object": "thread", + "title": "Test Thread", + "assistants": [], + "created": 1234567890, + "updated": 1234567890, + "metadata": null + }); + let created = create_thread(app.handle().clone(), thread.clone()) + .await + .unwrap(); + assert_eq!(created["title"], "Test Thread"); + + // List threads + let threads = list_threads(app.handle().clone()).await.unwrap(); + assert!(threads.len() > 0); + + // Clean up + fs::remove_dir_all(data_dir).unwrap(); + } + + #[tokio::test] + async fn test_create_and_list_messages() { + let (app, data_dir) = mock_app_with_temp_data_dir(); + // Create a thread first + let thread = json!({ + "object": "thread", + "title": "Msg Thread", + "assistants": [], + "created": 123, + "updated": 123, + "metadata": null + }); + let created = create_thread(app.handle().clone(), thread.clone()) + .await + .unwrap(); + let thread_id = created["id"].as_str().unwrap().to_string(); + + // Create a message + let message = json!({ + "object": "message", + "thread_id": thread_id, + "assistant_id": null, + "attachments": null, + "role": "user", + "content": [], + "status": "sent", + "created_at": 123, + "completed_at": 123, + "metadata": null, + "type_": null, + "error_code": null, + "tool_call_id": null + }); + let created_msg = create_message(app.handle().clone(), message).await.unwrap(); + assert_eq!(created_msg["role"], "user"); + + // List messages + let messages = list_messages(app.handle().clone(), thread_id.clone()) + .await + .unwrap(); + assert!(messages.len() > 0); + assert_eq!(messages[0]["role"], "user"); + + // Clean up + fs::remove_dir_all(data_dir).unwrap(); + } + + #[tokio::test] + async fn test_create_and_get_thread_assistant() { + let (app, data_dir) = mock_app_with_temp_data_dir(); + // Create a thread + let thread = json!({ + "object": "thread", + "title": "Assistant Thread", + "assistants": [], + "created": 1, + "updated": 1, + "metadata": null + }); + let created = create_thread(app.handle().clone(), thread.clone()) + .await + .unwrap(); + let thread_id = created["id"].as_str().unwrap().to_string(); + + // Add assistant + let assistant = json!({ + "id": "assistant-1", + "assistant_name": "Test Assistant", + "model": { + "id": "model-1", + "name": "Test Model", + "settings": json!({}) + }, + "instructions": null, + "tools": null + }); + let _ = create_thread_assistant(app.handle().clone(), thread_id.clone(), assistant.clone()) + .await + .unwrap(); + + // Get assistant + let got = get_thread_assistant(app.handle().clone(), thread_id.clone()) + .await + .unwrap(); + assert_eq!(got["assistant_name"], "Test Assistant"); + + // Clean up + fs::remove_dir_all(data_dir).unwrap(); + } +} diff --git a/src-tauri/src/core/utils/mod.rs b/src-tauri/src/core/utils/mod.rs new file mode 100644 index 000000000..7f80e6f3a --- /dev/null +++ b/src-tauri/src/core/utils/mod.rs @@ -0,0 +1,48 @@ +use std::fs; +use std::path::PathBuf; +use tauri::Runtime; + +use super::cmd::get_jan_data_folder_path; + +pub const THREADS_DIR: &str = "threads"; +pub const THREADS_FILE: &str = "thread.json"; +pub const MESSAGES_FILE: &str = "messages.jsonl"; + +pub fn get_data_dir(app_handle: tauri::AppHandle) -> PathBuf { + get_jan_data_folder_path(app_handle).join(THREADS_DIR) +} + +pub fn get_thread_dir(app_handle: tauri::AppHandle, thread_id: &str) -> PathBuf { + get_data_dir(app_handle).join(thread_id) +} + +pub fn get_thread_metadata_path( + app_handle: tauri::AppHandle, + thread_id: &str, +) -> PathBuf { + get_thread_dir(app_handle, thread_id).join(THREADS_FILE) +} + +pub fn get_messages_path(app_handle: tauri::AppHandle, thread_id: &str) -> PathBuf { + get_thread_dir(app_handle, thread_id).join(MESSAGES_FILE) +} + +pub fn ensure_data_dirs(app_handle: tauri::AppHandle) -> Result<(), String> { + let data_dir = get_data_dir(app_handle.clone()); + if !data_dir.exists() { + fs::create_dir_all(&data_dir).map_err(|e| e.to_string())?; + } + Ok(()) +} + +pub fn ensure_thread_dir_exists( + app_handle: tauri::AppHandle, + thread_id: &str, +) -> Result<(), String> { + ensure_data_dirs(app_handle.clone())?; + let thread_dir = get_thread_dir(app_handle, thread_id); + if !thread_dir.exists() { + fs::create_dir(&thread_dir).map_err(|e| e.to_string())?; + } + Ok(()) +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 40cd83f57..3b5b13a64 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -39,9 +39,24 @@ pub fn run() { core::cmd::app_token, core::cmd::start_server, core::cmd::stop_server, + core::cmd::save_mcp_configs, + core::cmd::get_mcp_configs, // MCP commands core::cmd::get_tools, - core::cmd::call_tool + core::cmd::call_tool, + core::mcp::restart_mcp_servers, + // Threads + core::threads::list_threads, + core::threads::create_thread, + core::threads::modify_thread, + core::threads::delete_thread, + core::threads::list_messages, + core::threads::create_message, + core::threads::modify_message, + core::threads::delete_message, + core::threads::get_thread_assistant, + core::threads::create_thread_assistant, + core::threads::modify_thread_assistant ]) .manage(AppState { app_token: Some(generate_app_token()), diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx index d95a114c4..0d9f862df 100644 --- a/web/containers/ModelDropdown/index.tsx +++ b/web/containers/ModelDropdown/index.tsx @@ -232,26 +232,6 @@ const ModelDropdown = ({ stopModel() if (activeThread) { - // Change assistand tools based on model support RAG - updateThreadMetadata({ - ...activeThread, - assistants: [ - { - ...activeAssistant, - tools: [ - { - type: 'retrieval', - enabled: model?.engine === InferenceEngine.cortex, - settings: { - ...(activeAssistant.tools && - activeAssistant.tools[0]?.settings), - }, - }, - ], - }, - ], - }) - const contextLength = model?.settings.ctx_len ? Math.min(8192, model?.settings.ctx_len ?? 8192) : undefined @@ -273,11 +253,25 @@ const ModelDropdown = ({ // Update model parameter to the thread file if (model) - updateModelParameter(activeThread, { - params: modelParams, - modelId: model.id, - engine: model.engine, - }) + updateModelParameter( + activeThread, + { + params: modelParams, + modelId: model.id, + engine: model.engine, + }, + // Update tools + [ + { + type: 'retrieval', + enabled: model?.engine === InferenceEngine.cortex, + settings: { + ...(activeAssistant.tools && + activeAssistant.tools[0]?.settings), + }, + }, + ] + ) } }, [ diff --git a/web/containers/ModelSearch/index.tsx b/web/containers/ModelSearch/index.tsx index aa40f8331..ceecacd39 100644 --- a/web/containers/ModelSearch/index.tsx +++ b/web/containers/ModelSearch/index.tsx @@ -83,7 +83,7 @@ const ModelSearch = ({ onSearchLocal }: Props) => { value={searchText} clearable={searchText.length > 0} onClear={onClear} - className="border-0 bg-[hsla(var(--app-bg))]" + className="bg-[hsla(var(--app-bg))]" onClick={() => { onSearchLocal?.(inputRef.current?.value ?? '') }} diff --git a/web/containers/Providers/ModelHandler.tsx b/web/containers/Providers/ModelHandler.tsx index 786dbd4f0..9590e5048 100644 --- a/web/containers/Providers/ModelHandler.tsx +++ b/web/containers/Providers/ModelHandler.tsx @@ -114,7 +114,7 @@ export default function ModelHandler() { const onNewMessageResponse = useCallback( async (message: ThreadMessage) => { - if (message.type === MessageRequestType.Thread) { + if (message.type !== MessageRequestType.Summary) { addNewMessage(message) } }, @@ -129,35 +129,20 @@ export default function ModelHandler() { const updateThreadTitle = useCallback( (message: ThreadMessage) => { // Update only when it's finished - if (message.status !== MessageStatus.Ready) { - return - } + if (message.status !== MessageStatus.Ready) return const thread = threadsRef.current?.find((e) => e.id == message.thread_id) - if (!thread) { - console.warn( - `Failed to update title for thread ${message.thread_id}: Thread not found!` - ) - return - } - let messageContent = message.content[0]?.text?.value - if (!messageContent) { - console.warn( - `Failed to update title for thread ${message.thread_id}: Responded content is null!` - ) - return - } + if (!thread || !messageContent) return // No new line character is presented in the title // And non-alphanumeric characters should be removed - if (messageContent.includes('\n')) { + if (messageContent.includes('\n')) messageContent = messageContent.replace(/\n/g, ' ') - } + const match = messageContent.match(/<\/think>(.*)$/) - if (match) { - messageContent = match[1] - } + if (match) messageContent = match[1] + // Remove non-alphanumeric characters const cleanedMessageContent = messageContent .replace(/[^\p{L}\s]+/gu, '') @@ -193,18 +178,13 @@ export default function ModelHandler() { const updateThreadMessage = useCallback( (message: ThreadMessage) => { - if ( - messageGenerationSubscriber.current && - message.thread_id === activeThreadRef.current?.id && - !messageGenerationSubscriber.current!.thread_id - ) { - updateMessage( - message.id, - message.thread_id, - message.content, - message.status - ) - } + updateMessage( + message.id, + message.thread_id, + message.content, + message.metadata, + message.status + ) if (message.status === MessageStatus.Pending) { if (message.content.length) { @@ -236,82 +216,66 @@ export default function ModelHandler() { model: activeModelRef.current?.name, } }) - return - } else if ( - message.status === MessageStatus.Error && - activeModelRef.current?.engine && - engines && - isLocalEngine(engines, activeModelRef.current.engine) - ) { - ;(async () => { - if ( - !(await extensionManager - .get(ExtensionTypeEnum.Model) - ?.isModelLoaded(activeModelRef.current?.id as string)) - ) { - setActiveModel(undefined) - setStateModel({ state: 'start', loading: false, model: undefined }) - } - })() - } - // Mark the thread as not waiting for response - updateThreadWaiting(message.thread_id, false) + } else { + // Mark the thread as not waiting for response + updateThreadWaiting(message.thread_id, false) - setIsGeneratingResponse(false) + setIsGeneratingResponse(false) - const thread = threadsRef.current?.find((e) => e.id == message.thread_id) - if (!thread) return + const thread = threadsRef.current?.find( + (e) => e.id == message.thread_id + ) + if (!thread) return - const messageContent = message.content[0]?.text?.value + const messageContent = message.content[0]?.text?.value - const metadata = { - ...thread.metadata, - ...(messageContent && { lastMessage: messageContent }), - updated_at: Date.now(), - } + const metadata = { + ...thread.metadata, + ...(messageContent && { lastMessage: messageContent }), + updated_at: Date.now(), + } - updateThread({ - ...thread, - metadata, - }) - - extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.modifyThread({ + updateThread({ ...thread, metadata, }) - // Update message's metadata with token usage - message.metadata = { - ...message.metadata, - token_speed: tokenSpeedRef.current?.tokenSpeed, - model: activeModelRef.current?.name, - } + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.modifyThread({ + ...thread, + metadata, + }) - if (message.status === MessageStatus.Error) { + // Update message's metadata with token usage message.metadata = { ...message.metadata, - error: message.content[0]?.text?.value, - error_code: message.error_code, + token_speed: tokenSpeedRef.current?.tokenSpeed, + model: activeModelRef.current?.name, } - } - ;(async () => { - const updatedMessage = await extensionManager + + if (message.status === MessageStatus.Error) { + message.metadata = { + ...message.metadata, + error: message.content[0]?.text?.value, + error_code: message.error_code, + } + // Unassign active model if any + setActiveModel(undefined) + setStateModel({ + state: 'start', + loading: false, + model: undefined, + }) + } + + extensionManager .get(ExtensionTypeEnum.Conversational) ?.createMessage(message) - .catch(() => undefined) - if (updatedMessage) { - deleteMessage(message.id) - addNewMessage(updatedMessage) - setTokenSpeed((prev) => - prev ? { ...prev, message: updatedMessage.id } : undefined - ) - } - })() - // Attempt to generate the title of the Thread when needed - generateThreadTitle(message, thread) + // Attempt to generate the title of the Thread when needed + generateThreadTitle(message, thread) + } }, // eslint-disable-next-line react-hooks/exhaustive-deps [setIsGeneratingResponse, updateMessage, updateThread, updateThreadWaiting] @@ -319,25 +283,21 @@ export default function ModelHandler() { const onMessageResponseUpdate = useCallback( (message: ThreadMessage) => { - switch (message.type) { - case MessageRequestType.Summary: - updateThreadTitle(message) - break - default: - updateThreadMessage(message) - break - } + if (message.type === MessageRequestType.Summary) + updateThreadTitle(message) + else updateThreadMessage(message) }, [updateThreadMessage, updateThreadTitle] ) const generateThreadTitle = (message: ThreadMessage, thread: Thread) => { // If this is the first ever prompt in the thread - if ((thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle) + if ( + !activeModelRef.current || + (thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle + ) return - if (!activeModelRef.current) return - // Check model engine; we don't want to generate a title when it's not a local engine. remote model using first promp if ( activeModelRef.current?.engine !== InferenceEngine.cortex && diff --git a/web/helpers/atoms/ChatMessage.atom.ts b/web/helpers/atoms/ChatMessage.atom.ts index 1847aa422..faae6e298 100644 --- a/web/helpers/atoms/ChatMessage.atom.ts +++ b/web/helpers/atoms/ChatMessage.atom.ts @@ -165,6 +165,7 @@ export const updateMessageAtom = atom( id: string, conversationId: string, text: ThreadContent[], + metadata: Record | undefined, status: MessageStatus ) => { const messages = get(chatMessages)[conversationId] ?? [] @@ -172,6 +173,7 @@ export const updateMessageAtom = atom( if (message) { message.content = text message.status = status + message.metadata = metadata const updatedMessages = [...messages] const newData: Record = { @@ -192,6 +194,7 @@ export const updateMessageAtom = atom( created_at: Date.now() / 1000, completed_at: Date.now() / 1000, object: 'thread.message', + metadata: metadata, }) } } diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts index 57ceeb385..6f1efc04d 100644 --- a/web/hooks/useCreateNewThread.ts +++ b/web/hooks/useCreateNewThread.ts @@ -180,7 +180,7 @@ export const useCreateNewThread = () => { updateThreadCallback(thread) if (thread.assistants && thread.assistants?.length > 0) { setActiveAssistant(thread.assistants[0]) - updateAssistantCallback(thread.id, thread.assistants[0]) + return updateAssistantCallback(thread.id, thread.assistants[0]) } }, [ diff --git a/web/hooks/useDeleteThread.ts b/web/hooks/useDeleteThread.ts index 59aa3a83b..d0c7cac1a 100644 --- a/web/hooks/useDeleteThread.ts +++ b/web/hooks/useDeleteThread.ts @@ -38,12 +38,13 @@ export default function useDeleteThread() { ?.listMessages(threadId) .catch(console.error) if (messages) { - messages.forEach((message) => { - extensionManager + for (const message of messages) { + await extensionManager .get(ExtensionTypeEnum.Conversational) ?.deleteMessage(threadId, message.id) .catch(console.error) - }) + } + const thread = threads.find((e) => e.id === threadId) if (thread) { const updatedThread = { diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index 3242b085c..49e0d3e5b 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -24,8 +24,10 @@ import { ChatCompletionMessageParam, ChatCompletionRole as OpenAIChatCompletionRole, ChatCompletionTool, + ChatCompletionMessageToolCall, } from 'openai/resources/chat' +import { Stream } from 'openai/streaming' import { ulid } from 'ulidx' import { modelDropdownStateAtom } from '@/containers/ModelDropdown' @@ -133,12 +135,16 @@ export default function useSendChatMessage() { ) => { if (!message || message.trim().length === 0) return - if (!activeThreadRef.current || !activeAssistantRef.current) { + const activeThread = activeThreadRef.current + const activeAssistant = activeAssistantRef.current + const activeModel = selectedModelRef.current + + if (!activeThread || !activeAssistant) { console.error('No active thread or assistant') return } - if (selectedModelRef.current?.id === undefined) { + if (!activeModel?.id) { setModelDropdownState(true) return } @@ -151,7 +157,7 @@ export default function useSendChatMessage() { const prompt = message.trim() - updateThreadWaiting(activeThreadRef.current.id, true) + updateThreadWaiting(activeThread.id, true) setCurrentPrompt('') setEditPrompt('') @@ -162,15 +168,14 @@ export default function useSendChatMessage() { base64Blob = await compressImage(base64Blob, 512) } - const modelRequest = - selectedModelRef?.current ?? activeAssistantRef.current?.model + const modelRequest = selectedModel ?? activeAssistant.model // Fallback support for previous broken threads - if (activeAssistantRef.current?.model?.id === '*') { - activeAssistantRef.current.model = { - id: modelRequest.id, - settings: modelRequest.settings, - parameters: modelRequest.parameters, + if (activeAssistant.model?.id === '*') { + activeAssistant.model = { + id: activeModel.id, + settings: activeModel.settings, + parameters: activeModel.parameters, } } if (runtimeParams.stream == null) { @@ -185,7 +190,7 @@ export default function useSendChatMessage() { settings: settingParams, parameters: runtimeParams, }, - activeThreadRef.current, + activeThread, messages ?? currentMessages, (await window.core.api.getTools())?.map((tool: ModelTool) => ({ type: 'function' as const, @@ -196,7 +201,7 @@ export default function useSendChatMessage() { strict: false, }, })) - ).addSystemMessage(activeAssistantRef.current?.instructions) + ).addSystemMessage(activeAssistant.instructions) requestBuilder.pushMessage(prompt, base64Blob, fileUpload) @@ -209,10 +214,10 @@ export default function useSendChatMessage() { // Update thread state const updatedThread: Thread = { - ...activeThreadRef.current, + ...activeThread, updated: newMessage.created_at, metadata: { - ...activeThreadRef.current.metadata, + ...activeThread.metadata, lastMessage: prompt, }, } @@ -235,17 +240,16 @@ export default function useSendChatMessage() { } // Start Model if not started - const modelId = - selectedModelRef.current?.id ?? activeAssistantRef.current?.model.id + const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id if (base64Blob) { setFileUpload(undefined) } - if (modelRef.current?.id !== modelId && modelId) { + if (activeModel?.id !== modelId && modelId) { const error = await startModel(modelId).catch((error: Error) => error) if (error) { - updateThreadWaiting(activeThreadRef.current.id, false) + updateThreadWaiting(activeThread.id, false) return } } @@ -258,111 +262,65 @@ export default function useSendChatMessage() { baseURL: `${API_BASE_URL}/v1`, dangerouslyAllowBrowser: true, }) + let parentMessageId: string | undefined while (!isDone) { + let messageId = ulid() + if (!parentMessageId) { + parentMessageId = ulid() + messageId = parentMessageId + } const data = requestBuilder.build() + const message: ThreadMessage = { + id: messageId, + object: 'message', + thread_id: activeThread.id, + assistant_id: activeAssistant.assistant_id, + role: ChatCompletionRole.Assistant, + content: [], + metadata: { + ...(messageId !== parentMessageId + ? { parent_id: parentMessageId } + : {}), + }, + status: MessageStatus.Pending, + created_at: Date.now() / 1000, + completed_at: Date.now() / 1000, + } + events.emit(MessageEvent.OnMessageResponse, message) const response = await openai.chat.completions.create({ - messages: (data.messages ?? []).map((e) => { - return { - role: e.role as OpenAIChatCompletionRole, - content: e.content, - } - }) as ChatCompletionMessageParam[], + messages: requestBuilder.messages as ChatCompletionMessageParam[], model: data.model?.id ?? '', tools: data.tools as ChatCompletionTool[], - stream: false, + stream: data.model?.parameters?.stream ?? false, + tool_choice: 'auto', }) - if (response.choices[0]?.message.content) { - const newMessage: ThreadMessage = { - id: ulid(), - object: 'message', - thread_id: activeThreadRef.current.id, - assistant_id: activeAssistantRef.current.assistant_id, - attachments: [], - role: response.choices[0].message.role as ChatCompletionRole, - content: [ - { - type: ContentType.Text, - text: { - value: response.choices[0].message.content - ? response.choices[0].message.content - : '', - annotations: [], - }, + // Variables to track and accumulate streaming content + if (!message.content.length) { + message.content = [ + { + type: ContentType.Text, + text: { + value: '', + annotations: [], }, - ], - status: MessageStatus.Ready, - created_at: Date.now(), - completed_at: Date.now(), - } - requestBuilder.pushAssistantMessage( - response.choices[0].message.content ?? '' + }, + ] + } + if (data.model?.parameters?.stream) + isDone = await processStreamingResponse( + response as Stream, + requestBuilder, + message + ) + else { + isDone = await processNonStreamingResponse( + response as OpenAI.Chat.Completions.ChatCompletion, + requestBuilder, + message ) - events.emit(MessageEvent.OnMessageUpdate, newMessage) } - - if (response.choices[0]?.message.tool_calls) { - for (const toolCall of response.choices[0].message.tool_calls) { - const id = ulid() - const toolMessage: ThreadMessage = { - id: id, - object: 'message', - thread_id: activeThreadRef.current.id, - assistant_id: activeAssistantRef.current.assistant_id, - attachments: [], - role: ChatCompletionRole.Assistant, - content: [ - { - type: ContentType.Text, - text: { - value: `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`, - annotations: [], - }, - }, - ], - status: MessageStatus.Pending, - created_at: Date.now(), - completed_at: Date.now(), - } - events.emit(MessageEvent.OnMessageUpdate, toolMessage) - const result = await window.core.api.callTool({ - toolName: toolCall.function.name, - arguments: JSON.parse(toolCall.function.arguments), - }) - if (result.error) { - console.error(result.error) - break - } - const message: ThreadMessage = { - id: id, - object: 'message', - thread_id: activeThreadRef.current.id, - assistant_id: activeAssistantRef.current.assistant_id, - attachments: [], - role: ChatCompletionRole.Assistant, - content: [ - { - type: ContentType.Text, - text: { - value: - `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}` + - (result.content[0]?.text ?? ''), - annotations: [], - }, - }, - ], - status: MessageStatus.Ready, - created_at: Date.now(), - completed_at: Date.now(), - } - requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '') - requestBuilder.pushMessage('Go for the next step') - events.emit(MessageEvent.OnMessageUpdate, message) - } - } - - isDone = - !response.choices[0]?.message.tool_calls || - !response.choices[0]?.message.tool_calls.length + message.status = MessageStatus.Ready + events.emit(MessageEvent.OnMessageUpdate, message) } } else { // Request for inference @@ -376,6 +334,182 @@ export default function useSendChatMessage() { setEngineParamsUpdate(false) } + const processNonStreamingResponse = async ( + response: OpenAI.Chat.Completions.ChatCompletion, + requestBuilder: MessageRequestBuilder, + message: ThreadMessage + ): Promise => { + // Handle tool calls in the response + const toolCalls: ChatCompletionMessageToolCall[] = + response.choices[0]?.message?.tool_calls ?? [] + const content = response.choices[0].message?.content + message.content = [ + { + type: ContentType.Text, + text: { + value: content ?? '', + annotations: [], + }, + }, + ] + events.emit(MessageEvent.OnMessageUpdate, message) + await postMessageProcessing( + toolCalls ?? [], + requestBuilder, + message, + content ?? '' + ) + return !toolCalls || !toolCalls.length + } + + const processStreamingResponse = async ( + response: Stream, + requestBuilder: MessageRequestBuilder, + message: ThreadMessage + ): Promise => { + // Variables to track and accumulate streaming content + let currentToolCall: { + id: string + function: { name: string; arguments: string } + } | null = null + let accumulatedContent = '' + const toolCalls: ChatCompletionMessageToolCall[] = [] + // Process the streaming chunks + for await (const chunk of response) { + // Handle tool calls in the chunk + if (chunk.choices[0]?.delta?.tool_calls) { + const deltaToolCalls = chunk.choices[0].delta.tool_calls + + // Handle the beginning of a new tool call + if ( + deltaToolCalls[0]?.index !== undefined && + deltaToolCalls[0]?.function + ) { + const index = deltaToolCalls[0].index + + // Create new tool call if this is the first chunk for it + if (!toolCalls[index]) { + toolCalls[index] = { + id: deltaToolCalls[0]?.id || '', + function: { + name: deltaToolCalls[0]?.function?.name || '', + arguments: deltaToolCalls[0]?.function?.arguments || '', + }, + type: 'function', + } + currentToolCall = toolCalls[index] + } else { + // Continuation of existing tool call + currentToolCall = toolCalls[index] + + // Append to function name or arguments if they exist in this chunk + if (deltaToolCalls[0]?.function?.name) { + currentToolCall!.function.name += deltaToolCalls[0].function.name + } + + if (deltaToolCalls[0]?.function?.arguments) { + currentToolCall!.function.arguments += + deltaToolCalls[0].function.arguments + } + } + } + } + + // Handle regular content in the chunk + if (chunk.choices[0]?.delta?.content) { + const content = chunk.choices[0].delta.content + accumulatedContent += content + + message.content = [ + { + type: ContentType.Text, + text: { + value: accumulatedContent, + annotations: [], + }, + }, + ] + events.emit(MessageEvent.OnMessageUpdate, message) + } + } + + await postMessageProcessing( + toolCalls ?? [], + requestBuilder, + message, + accumulatedContent ?? '' + ) + return !toolCalls || !toolCalls.length + } + + const postMessageProcessing = async ( + toolCalls: ChatCompletionMessageToolCall[], + requestBuilder: MessageRequestBuilder, + message: ThreadMessage, + content: string + ) => { + requestBuilder.pushAssistantMessage({ + content, + role: 'assistant', + refusal: null, + tool_calls: toolCalls, + }) + + // Handle completed tool calls + if (toolCalls.length > 0) { + for (const toolCall of toolCalls) { + const toolId = ulid() + const toolCallsMetadata = + message.metadata?.tool_calls && + Array.isArray(message.metadata?.tool_calls) + ? message.metadata?.tool_calls + : [] + message.metadata = { + ...(message.metadata ?? {}), + tool_calls: [ + ...toolCallsMetadata, + { + tool: { + ...toolCall, + id: toolId, + }, + response: undefined, + state: 'pending', + }, + ], + } + events.emit(MessageEvent.OnMessageUpdate, message) + + const result = await window.core.api.callTool({ + toolName: toolCall.function.name, + arguments: JSON.parse(toolCall.function.arguments), + }) + if (result.error) break + + message.metadata = { + ...(message.metadata ?? {}), + tool_calls: [ + ...toolCallsMetadata, + { + tool: { + ...toolCall, + id: toolId, + }, + response: result, + state: 'ready', + }, + ], + } + + requestBuilder.pushToolMessage( + result.content[0]?.text ?? '', + toolCall.id + ) + events.emit(MessageEvent.OnMessageUpdate, message) + } + } + } + return { sendChatMessage, resendChatMessage, diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts index dab2f6e28..8bab0c357 100644 --- a/web/hooks/useUpdateModelParameters.ts +++ b/web/hooks/useUpdateModelParameters.ts @@ -1,6 +1,7 @@ import { useCallback } from 'react' import { + AssistantTool, ConversationalExtension, ExtensionTypeEnum, InferenceEngine, @@ -51,7 +52,11 @@ export default function useUpdateModelParameters() { ) const updateModelParameter = useCallback( - async (thread: Thread, settings: UpdateModelParameter) => { + async ( + thread: Thread, + settings: UpdateModelParameter, + tools?: AssistantTool[] + ) => { if (!activeAssistant) return const toUpdateSettings = processStopWords(settings.params ?? {}) @@ -70,6 +75,7 @@ export default function useUpdateModelParameters() { const settingParams = extractModelLoadParams(updatedModelParams) const assistantInfo = { ...activeAssistant, + tools: tools ?? activeAssistant.tools, model: { ...activeAssistant?.model, parameters: runtimeParams, diff --git a/web/package.json b/web/package.json index 7999c74e9..cdf2d8d8b 100644 --- a/web/package.json +++ b/web/package.json @@ -37,6 +37,7 @@ "marked": "^9.1.2", "next": "14.2.3", "next-themes": "^0.2.1", + "npx-scope-finder": "^1.3.0", "openai": "^4.90.0", "postcss": "8.4.31", "postcss-url": "10.1.3", diff --git a/web/screens/Hub/ModelFilter/ModelSize/index.tsx b/web/screens/Hub/ModelFilter/ModelSize/index.tsx index b95d57f8b..a8d411e33 100644 --- a/web/screens/Hub/ModelFilter/ModelSize/index.tsx +++ b/web/screens/Hub/ModelFilter/ModelSize/index.tsx @@ -1,9 +1,8 @@ -import { useRef, useState } from 'react' +import { useState } from 'react' -import { Slider, Input, Tooltip } from '@janhq/joi' +import { Slider, Input } from '@janhq/joi' import { atom, useAtom } from 'jotai' -import { InfoIcon } from 'lucide-react' export const hubModelSizeMinAtom = atom(0) export const hubModelSizeMaxAtom = atom(100) diff --git a/web/screens/Settings/MCP/configuration.tsx b/web/screens/Settings/MCP/configuration.tsx new file mode 100644 index 000000000..c5593d528 --- /dev/null +++ b/web/screens/Settings/MCP/configuration.tsx @@ -0,0 +1,99 @@ +import React, { useState, useEffect, useCallback } from 'react' + +import { Button, TextArea } from '@janhq/joi' +import { useAtomValue } from 'jotai' + +import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' + +const MCPConfiguration = () => { + const janDataFolderPath = useAtomValue(janDataFolderPathAtom) + const [configContent, setConfigContent] = useState('') + const [isSaving, setIsSaving] = useState(false) + const [error, setError] = useState('') + const [success, setSuccess] = useState('') + + const readConfigFile = useCallback(async () => { + try { + // Read the file + const content = await window.core?.api.getMcpConfigs() + setConfigContent(content) + + setError('') + } catch (err) { + console.error('Error reading config file:', err) + setError('Failed to read config file') + } + }, [janDataFolderPath]) + + useEffect(() => { + if (janDataFolderPath) { + readConfigFile() + } + }, [janDataFolderPath, readConfigFile]) + + const saveConfigFile = useCallback(async () => { + try { + setIsSaving(true) + setSuccess('') + setError('') + + // Validate JSON + try { + JSON.parse(configContent) + } catch (err) { + setError('Invalid JSON format') + setIsSaving(false) + return + } + await window.core?.api?.saveMcpConfigs({ configs: configContent }) + await window.core?.api?.restartMcpServers() + + setSuccess('Config saved successfully') + setIsSaving(false) + } catch (err) { + console.error('Error saving config file:', err) + setError('Failed to save config file') + setIsSaving(false) + } + }, [janDataFolderPath, configContent]) + + return ( + <> + {error && ( +
+ {error} +
+ )} + + {success && ( +
+ {success} +
+ )} + +
+ +