Merge branch 'feat/tauri-build-option' of https://github.com/menloresearch/jan into chore/tauri-cicd
This commit is contained in:
commit
a3cb4f0ee7
@ -7,6 +7,7 @@ export enum ChatCompletionRole {
|
|||||||
System = 'system',
|
System = 'system',
|
||||||
Assistant = 'assistant',
|
Assistant = 'assistant',
|
||||||
User = 'user',
|
User = 'user',
|
||||||
|
Tool = 'tool',
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -18,6 +19,9 @@ export type ChatCompletionMessage = {
|
|||||||
content?: ChatCompletionMessageContent
|
content?: ChatCompletionMessageContent
|
||||||
/** The role of the author of this message. **/
|
/** The role of the author of this message. **/
|
||||||
role: ChatCompletionRole
|
role: ChatCompletionRole
|
||||||
|
type?: string
|
||||||
|
output?: string
|
||||||
|
tool_call_id?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ChatCompletionMessageContent =
|
export type ChatCompletionMessageContent =
|
||||||
|
|||||||
@ -36,6 +36,8 @@ export type ThreadMessage = {
|
|||||||
type?: string
|
type?: string
|
||||||
/** The error code which explain what error type. Used in conjunction with MessageStatus.Error */
|
/** The error code which explain what error type. Used in conjunction with MessageStatus.Error */
|
||||||
error_code?: ErrorCode
|
error_code?: ErrorCode
|
||||||
|
|
||||||
|
tool_call_id?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -23,9 +23,7 @@
|
|||||||
"typescript": "^5.7.2"
|
"typescript": "^5.7.2"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@janhq/core": "../../core/package.tgz",
|
"@janhq/core": "../../core/package.tgz"
|
||||||
"ky": "^1.7.2",
|
|
||||||
"p-queue": "^8.0.1"
|
|
||||||
},
|
},
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=18.0.0"
|
"node": ">=18.0.0"
|
||||||
|
|||||||
@ -4,40 +4,12 @@ import {
|
|||||||
ThreadAssistantInfo,
|
ThreadAssistantInfo,
|
||||||
ThreadMessage,
|
ThreadMessage,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import ky, { KyInstance } from 'ky'
|
|
||||||
|
|
||||||
type ThreadList = {
|
|
||||||
data: Thread[]
|
|
||||||
}
|
|
||||||
|
|
||||||
type MessageList = {
|
|
||||||
data: ThreadMessage[]
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* JSONConversationalExtension is a ConversationalExtension implementation that provides
|
* JSONConversationalExtension is a ConversationalExtension implementation that provides
|
||||||
* functionality for managing threads.
|
* functionality for managing threads.
|
||||||
*/
|
*/
|
||||||
export default class CortexConversationalExtension extends ConversationalExtension {
|
export default class CortexConversationalExtension extends ConversationalExtension {
|
||||||
api?: KyInstance
|
|
||||||
/**
|
|
||||||
* Get the API instance
|
|
||||||
* @returns
|
|
||||||
*/
|
|
||||||
async apiInstance(): Promise<KyInstance> {
|
|
||||||
if (this.api) return this.api
|
|
||||||
const apiKey = (await window.core?.api.appToken())
|
|
||||||
this.api = ky.extend({
|
|
||||||
prefixUrl: API_URL,
|
|
||||||
headers: apiKey
|
|
||||||
? {
|
|
||||||
Authorization: `Bearer ${apiKey}`,
|
|
||||||
}
|
|
||||||
: {},
|
|
||||||
retry: 10,
|
|
||||||
})
|
|
||||||
return this.api
|
|
||||||
}
|
|
||||||
/**
|
/**
|
||||||
* Called when the extension is loaded.
|
* Called when the extension is loaded.
|
||||||
*/
|
*/
|
||||||
@ -54,12 +26,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi
|
|||||||
* Returns a Promise that resolves to an array of Conversation objects.
|
* Returns a Promise that resolves to an array of Conversation objects.
|
||||||
*/
|
*/
|
||||||
async listThreads(): Promise<Thread[]> {
|
async listThreads(): Promise<Thread[]> {
|
||||||
return this.apiInstance().then((api) =>
|
return window.core.api.listThreads()
|
||||||
api
|
|
||||||
.get('v1/threads?limit=-1')
|
|
||||||
.json<ThreadList>()
|
|
||||||
.then((e) => e.data)
|
|
||||||
) as Promise<Thread[]>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -67,9 +34,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi
|
|||||||
* @param thread The Thread object to save.
|
* @param thread The Thread object to save.
|
||||||
*/
|
*/
|
||||||
async createThread(thread: Thread): Promise<Thread> {
|
async createThread(thread: Thread): Promise<Thread> {
|
||||||
return this.apiInstance().then((api) =>
|
return window.core.api.createThread({ thread })
|
||||||
api.post('v1/threads', { json: thread }).json<Thread>()
|
|
||||||
) as Promise<Thread>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -77,10 +42,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi
|
|||||||
* @param thread The Thread object to save.
|
* @param thread The Thread object to save.
|
||||||
*/
|
*/
|
||||||
async modifyThread(thread: Thread): Promise<void> {
|
async modifyThread(thread: Thread): Promise<void> {
|
||||||
return this.apiInstance()
|
return window.core.api.modifyThread({ thread })
|
||||||
.then((api) => api.patch(`v1/threads/${thread.id}`, { json: thread }))
|
|
||||||
|
|
||||||
.then()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -88,9 +50,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi
|
|||||||
* @param threadId The ID of the thread to delete.
|
* @param threadId The ID of the thread to delete.
|
||||||
*/
|
*/
|
||||||
async deleteThread(threadId: string): Promise<void> {
|
async deleteThread(threadId: string): Promise<void> {
|
||||||
return this.apiInstance()
|
return window.core.api.deleteThread({ threadId })
|
||||||
.then((api) => api.delete(`v1/threads/${threadId}`))
|
|
||||||
.then()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -99,13 +59,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi
|
|||||||
* @returns A Promise that resolves when the message has been added.
|
* @returns A Promise that resolves when the message has been added.
|
||||||
*/
|
*/
|
||||||
async createMessage(message: ThreadMessage): Promise<ThreadMessage> {
|
async createMessage(message: ThreadMessage): Promise<ThreadMessage> {
|
||||||
return this.apiInstance().then((api) =>
|
return window.core.api.createMessage({ message })
|
||||||
api
|
|
||||||
.post(`v1/threads/${message.thread_id}/messages`, {
|
|
||||||
json: message,
|
|
||||||
})
|
|
||||||
.json<ThreadMessage>()
|
|
||||||
) as Promise<ThreadMessage>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -114,13 +68,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi
|
|||||||
* @returns
|
* @returns
|
||||||
*/
|
*/
|
||||||
async modifyMessage(message: ThreadMessage): Promise<ThreadMessage> {
|
async modifyMessage(message: ThreadMessage): Promise<ThreadMessage> {
|
||||||
return this.apiInstance().then((api) =>
|
return window.core.api.modifyMessage({ message })
|
||||||
api
|
|
||||||
.patch(`v1/threads/${message.thread_id}/messages/${message.id}`, {
|
|
||||||
json: message,
|
|
||||||
})
|
|
||||||
.json<ThreadMessage>()
|
|
||||||
) as Promise<ThreadMessage>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -130,9 +78,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi
|
|||||||
* @returns A Promise that resolves when the message has been successfully deleted.
|
* @returns A Promise that resolves when the message has been successfully deleted.
|
||||||
*/
|
*/
|
||||||
async deleteMessage(threadId: string, messageId: string): Promise<void> {
|
async deleteMessage(threadId: string, messageId: string): Promise<void> {
|
||||||
return this.apiInstance()
|
return window.core.api.deleteMessage({ threadId, messageId })
|
||||||
.then((api) => api.delete(`v1/threads/${threadId}/messages/${messageId}`))
|
|
||||||
.then()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -141,12 +87,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi
|
|||||||
* @returns A Promise that resolves to an array of ThreadMessage objects.
|
* @returns A Promise that resolves to an array of ThreadMessage objects.
|
||||||
*/
|
*/
|
||||||
async listMessages(threadId: string): Promise<ThreadMessage[]> {
|
async listMessages(threadId: string): Promise<ThreadMessage[]> {
|
||||||
return this.apiInstance().then((api) =>
|
return window.core.api.listMessages({ threadId })
|
||||||
api
|
|
||||||
.get(`v1/threads/${threadId}/messages?order=asc&limit=-1`)
|
|
||||||
.json<MessageList>()
|
|
||||||
.then((e) => e.data)
|
|
||||||
) as Promise<ThreadMessage[]>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -156,9 +97,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi
|
|||||||
* the details of the assistant associated with the specified thread.
|
* the details of the assistant associated with the specified thread.
|
||||||
*/
|
*/
|
||||||
async getThreadAssistant(threadId: string): Promise<ThreadAssistantInfo> {
|
async getThreadAssistant(threadId: string): Promise<ThreadAssistantInfo> {
|
||||||
return this.apiInstance().then((api) =>
|
return window.core.api.getThreadAssistant({ threadId })
|
||||||
api.get(`v1/assistants/${threadId}?limit=-1`).json<ThreadAssistantInfo>()
|
|
||||||
) as Promise<ThreadAssistantInfo>
|
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Creates a new assistant for the specified thread.
|
* Creates a new assistant for the specified thread.
|
||||||
@ -170,11 +109,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi
|
|||||||
threadId: string,
|
threadId: string,
|
||||||
assistant: ThreadAssistantInfo
|
assistant: ThreadAssistantInfo
|
||||||
): Promise<ThreadAssistantInfo> {
|
): Promise<ThreadAssistantInfo> {
|
||||||
return this.apiInstance().then((api) =>
|
return window.core.api.createThreadAssistant(threadId, assistant)
|
||||||
api
|
|
||||||
.post(`v1/assistants/${threadId}`, { json: assistant })
|
|
||||||
.json<ThreadAssistantInfo>()
|
|
||||||
) as Promise<ThreadAssistantInfo>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -187,10 +122,6 @@ export default class CortexConversationalExtension extends ConversationalExtensi
|
|||||||
threadId: string,
|
threadId: string,
|
||||||
assistant: ThreadAssistantInfo
|
assistant: ThreadAssistantInfo
|
||||||
): Promise<ThreadAssistantInfo> {
|
): Promise<ThreadAssistantInfo> {
|
||||||
return this.apiInstance().then((api) =>
|
return window.core.api.modifyThreadAssistant({ threadId, assistant })
|
||||||
api
|
|
||||||
.patch(`v1/assistants/${threadId}`, { json: assistant })
|
|
||||||
.json<ThreadAssistantInfo>()
|
|
||||||
) as Promise<ThreadAssistantInfo>
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -38,6 +38,8 @@ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "mai
|
|||||||
"transport-child-process",
|
"transport-child-process",
|
||||||
"tower",
|
"tower",
|
||||||
] }
|
] }
|
||||||
|
uuid = { version = "1.7", features = ["v4"] }
|
||||||
|
|
||||||
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
|
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
|
||||||
tauri-plugin-updater = "2"
|
tauri-plugin-updater = "2"
|
||||||
|
once_cell = "1.18"
|
||||||
|
|||||||
@ -7,6 +7,9 @@ use tauri::{AppHandle, Manager, Runtime, State};
|
|||||||
use super::{server, setup, state::AppState};
|
use super::{server, setup, state::AppState};
|
||||||
|
|
||||||
const CONFIGURATION_FILE_NAME: &str = "settings.json";
|
const CONFIGURATION_FILE_NAME: &str = "settings.json";
|
||||||
|
const DEFAULT_MCP_CONFIG: &str = r#"{
|
||||||
|
"mcpServers": {}
|
||||||
|
}"#;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct AppConfiguration {
|
pub struct AppConfiguration {
|
||||||
@ -93,6 +96,10 @@ pub fn update_app_configuration(
|
|||||||
|
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn get_jan_data_folder_path<R: Runtime>(app_handle: tauri::AppHandle<R>) -> PathBuf {
|
pub fn get_jan_data_folder_path<R: Runtime>(app_handle: tauri::AppHandle<R>) -> PathBuf {
|
||||||
|
if cfg!(test) {
|
||||||
|
return PathBuf::from("./data");
|
||||||
|
}
|
||||||
|
|
||||||
let app_configurations = get_app_configurations(app_handle);
|
let app_configurations = get_app_configurations(app_handle);
|
||||||
PathBuf::from(app_configurations.data_folder)
|
PathBuf::from(app_configurations.data_folder)
|
||||||
}
|
}
|
||||||
@ -348,3 +355,29 @@ pub async fn call_tool(
|
|||||||
|
|
||||||
Err(format!("Tool {} not found", tool_name))
|
Err(format!("Tool {} not found", tool_name))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tauri::command]
|
||||||
|
pub async fn get_mcp_configs(app: AppHandle) -> Result<String, String> {
|
||||||
|
let mut path = get_jan_data_folder_path(app);
|
||||||
|
path.push("mcp_config.json");
|
||||||
|
log::info!("read mcp configs, path: {:?}", path);
|
||||||
|
|
||||||
|
// Create default empty config if file doesn't exist
|
||||||
|
if !path.exists() {
|
||||||
|
log::info!("mcp_config.json not found, creating default empty config");
|
||||||
|
fs::write(&path, DEFAULT_MCP_CONFIG)
|
||||||
|
.map_err(|e| format!("Failed to create default MCP config: {}", e))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let contents = fs::read_to_string(path).map_err(|e| e.to_string())?;
|
||||||
|
return Ok(contents);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tauri::command]
|
||||||
|
pub async fn save_mcp_configs(app: AppHandle, configs: String) -> Result<(), String> {
|
||||||
|
let mut path = get_jan_data_folder_path(app);
|
||||||
|
path.push("mcp_config.json");
|
||||||
|
log::info!("save mcp configs, path: {:?}", path);
|
||||||
|
|
||||||
|
fs::write(path, configs).map_err(|e| e.to_string())
|
||||||
|
}
|
||||||
|
|||||||
@ -107,6 +107,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use std::fs::{self, File};
|
use std::fs::{self, File};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
use serde_json::to_string;
|
||||||
use tauri::test::mock_app;
|
use tauri::test::mock_app;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -154,9 +155,11 @@ mod tests {
|
|||||||
fn test_exists_sync() {
|
fn test_exists_sync() {
|
||||||
let app = mock_app();
|
let app = mock_app();
|
||||||
let path = "file://test_exists_sync_file";
|
let path = "file://test_exists_sync_file";
|
||||||
let file_path = get_jan_data_folder_path(app.handle().clone()).join(path);
|
let dir_path = get_jan_data_folder_path(app.handle().clone());
|
||||||
|
fs::create_dir_all(&dir_path).unwrap();
|
||||||
|
let file_path = dir_path.join("test_exists_sync_file");
|
||||||
File::create(&file_path).unwrap();
|
File::create(&file_path).unwrap();
|
||||||
let args = vec![path.to_string()];
|
let args: Vec<String> = vec![path.to_string()];
|
||||||
let result = exists_sync(app.handle().clone(), args).unwrap();
|
let result = exists_sync(app.handle().clone(), args).unwrap();
|
||||||
assert!(result);
|
assert!(result);
|
||||||
fs::remove_file(file_path).unwrap();
|
fs::remove_file(file_path).unwrap();
|
||||||
@ -166,7 +169,9 @@ mod tests {
|
|||||||
fn test_read_file_sync() {
|
fn test_read_file_sync() {
|
||||||
let app = mock_app();
|
let app = mock_app();
|
||||||
let path = "file://test_read_file_sync_file";
|
let path = "file://test_read_file_sync_file";
|
||||||
let file_path = get_jan_data_folder_path(app.handle().clone()).join(path);
|
let dir_path = get_jan_data_folder_path(app.handle().clone());
|
||||||
|
fs::create_dir_all(&dir_path).unwrap();
|
||||||
|
let file_path = dir_path.join("test_read_file_sync_file");
|
||||||
let mut file = File::create(&file_path).unwrap();
|
let mut file = File::create(&file_path).unwrap();
|
||||||
file.write_all(b"test content").unwrap();
|
file.write_all(b"test content").unwrap();
|
||||||
let args = vec![path.to_string()];
|
let args = vec![path.to_string()];
|
||||||
@ -184,7 +189,7 @@ mod tests {
|
|||||||
File::create(dir_path.join("file1.txt")).unwrap();
|
File::create(dir_path.join("file1.txt")).unwrap();
|
||||||
File::create(dir_path.join("file2.txt")).unwrap();
|
File::create(dir_path.join("file2.txt")).unwrap();
|
||||||
|
|
||||||
let args = vec![path.to_string()];
|
let args = vec![dir_path.to_string_lossy().to_string()];
|
||||||
let result = readdir_sync(app.handle().clone(), args).unwrap();
|
let result = readdir_sync(app.handle().clone(), args).unwrap();
|
||||||
assert_eq!(result.len(), 2);
|
assert_eq!(result.len(), 2);
|
||||||
|
|
||||||
|
|||||||
@ -2,8 +2,11 @@ use std::{collections::HashMap, sync::Arc};
|
|||||||
|
|
||||||
use rmcp::{service::RunningService, transport::TokioChildProcess, RoleClient, ServiceExt};
|
use rmcp::{service::RunningService, transport::TokioChildProcess, RoleClient, ServiceExt};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
use tauri::{AppHandle, State};
|
||||||
use tokio::{process::Command, sync::Mutex};
|
use tokio::{process::Command, sync::Mutex};
|
||||||
|
|
||||||
|
use super::{cmd::get_jan_data_folder_path, state::AppState};
|
||||||
|
|
||||||
/// Runs MCP commands by reading configuration from a JSON file and initializing servers
|
/// Runs MCP commands by reading configuration from a JSON file and initializing servers
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
@ -77,6 +80,35 @@ fn extract_command_args(
|
|||||||
Some((command, args, envs))
|
Some((command, args, envs))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tauri::command]
|
||||||
|
pub async fn restart_mcp_servers(
|
||||||
|
app: AppHandle,
|
||||||
|
state: State<'_, AppState>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let app_path = get_jan_data_folder_path(app.clone());
|
||||||
|
let app_path_str = app_path.to_str().unwrap().to_string();
|
||||||
|
let servers = state.mcp_servers.clone();
|
||||||
|
// Stop the servers
|
||||||
|
stop_mcp_servers(state.mcp_servers.clone()).await?;
|
||||||
|
|
||||||
|
// Restart the servers
|
||||||
|
run_mcp_commands(app_path_str, servers).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stop_mcp_servers(
|
||||||
|
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let mut servers_map = servers_state.lock().await;
|
||||||
|
let keys: Vec<String> = servers_map.keys().cloned().collect();
|
||||||
|
for key in keys {
|
||||||
|
if let Some(service) = servers_map.remove(&key) {
|
||||||
|
service.cancel().await.map_err(|e| e.to_string())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
drop(servers_map); // Release the lock after stopping
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
@ -4,3 +4,5 @@ pub mod mcp;
|
|||||||
pub mod server;
|
pub mod server;
|
||||||
pub mod setup;
|
pub mod setup;
|
||||||
pub mod state;
|
pub mod state;
|
||||||
|
pub mod threads;
|
||||||
|
pub mod utils;
|
||||||
613
src-tauri/src/core/threads.rs
Normal file
613
src-tauri/src/core/threads.rs
Normal file
@ -0,0 +1,613 @@
|
|||||||
|
/*!
|
||||||
|
Thread and Message Persistence Module
|
||||||
|
|
||||||
|
This module provides all logic for managing threads and their messages, including creation, modification, deletion, and listing.
|
||||||
|
Messages for each thread are persisted in a JSONL file (messages.jsonl) per thread directory.
|
||||||
|
|
||||||
|
**Concurrency and Consistency Guarantee:**
|
||||||
|
- All operations that write or modify messages for a thread are protected by a global, per-thread asynchronous lock.
|
||||||
|
- This design ensures that only one operation can write to a thread's messages.jsonl file at a time, preventing race conditions.
|
||||||
|
- As a result, the messages.jsonl file for each thread is always consistent and never corrupted, even under concurrent access.
|
||||||
|
*/
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fs::{self, File};
|
||||||
|
use std::io::{BufRead, BufReader, Write};
|
||||||
|
use tauri::command;
|
||||||
|
use tauri::Runtime;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
// For async file write serialization
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
// Global per-thread locks for message file writes
|
||||||
|
static MESSAGE_LOCKS: Lazy<Mutex<HashMap<String, Arc<Mutex<()>>>>> =
|
||||||
|
Lazy::new(|| Mutex::new(HashMap::new()));
|
||||||
|
|
||||||
|
use super::utils::{
|
||||||
|
ensure_data_dirs, ensure_thread_dir_exists, get_data_dir, get_messages_path, get_thread_dir,
|
||||||
|
get_thread_metadata_path, THREADS_FILE,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct Thread {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub title: String,
|
||||||
|
pub assistants: Vec<ThreadAssistantInfo>,
|
||||||
|
pub created: i64,
|
||||||
|
pub updated: i64,
|
||||||
|
pub metadata: Option<serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct ThreadMessage {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub thread_id: String,
|
||||||
|
pub assistant_id: Option<String>,
|
||||||
|
pub attachments: Option<Vec<Attachment>>,
|
||||||
|
pub role: String,
|
||||||
|
pub content: Vec<ThreadContent>,
|
||||||
|
pub status: String,
|
||||||
|
pub created_at: i64,
|
||||||
|
pub completed_at: i64,
|
||||||
|
pub metadata: Option<serde_json::Value>,
|
||||||
|
pub type_: Option<String>,
|
||||||
|
pub error_code: Option<String>,
|
||||||
|
pub tool_call_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct Attachment {
|
||||||
|
pub file_id: Option<String>,
|
||||||
|
pub tools: Option<Vec<Tool>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum Tool {
|
||||||
|
#[serde(rename = "file_search")]
|
||||||
|
FileSearch,
|
||||||
|
#[serde(rename = "code_interpreter")]
|
||||||
|
CodeInterpreter,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct ThreadContent {
|
||||||
|
pub type_: String,
|
||||||
|
pub text: Option<ContentValue>,
|
||||||
|
pub image_url: Option<ImageContentValue>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct ContentValue {
|
||||||
|
pub value: String,
|
||||||
|
pub annotations: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct ImageContentValue {
|
||||||
|
pub detail: Option<String>,
|
||||||
|
pub url: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct ThreadAssistantInfo {
|
||||||
|
pub assistant_id: String,
|
||||||
|
pub assistant_name: String,
|
||||||
|
pub model: ModelInfo,
|
||||||
|
pub instructions: Option<String>,
|
||||||
|
pub tools: Option<Vec<AssistantTool>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct ModelInfo {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub settings: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum AssistantTool {
|
||||||
|
#[serde(rename = "code_interpreter")]
|
||||||
|
CodeInterpreter,
|
||||||
|
#[serde(rename = "retrieval")]
|
||||||
|
Retrieval,
|
||||||
|
#[serde(rename = "function")]
|
||||||
|
Function {
|
||||||
|
name: String,
|
||||||
|
description: Option<String>,
|
||||||
|
parameters: Option<serde_json::Value>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct ThreadState {
|
||||||
|
pub has_more: bool,
|
||||||
|
pub waiting_for_response: bool,
|
||||||
|
pub error: Option<String>,
|
||||||
|
pub last_message: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lists all threads by reading their metadata from the threads directory.
|
||||||
|
/// Returns a vector of thread metadata as JSON values.
|
||||||
|
#[command]
|
||||||
|
pub async fn list_threads<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
) -> Result<Vec<serde_json::Value>, String> {
|
||||||
|
ensure_data_dirs(app_handle.clone())?;
|
||||||
|
let data_dir = get_data_dir(app_handle.clone());
|
||||||
|
let mut threads = Vec::new();
|
||||||
|
|
||||||
|
if !data_dir.exists() {
|
||||||
|
return Ok(threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
for entry in fs::read_dir(&data_dir).map_err(|e| e.to_string())? {
|
||||||
|
let entry = entry.map_err(|e| e.to_string())?;
|
||||||
|
let path = entry.path();
|
||||||
|
if path.is_dir() {
|
||||||
|
let thread_metadata_path = path.join(THREADS_FILE);
|
||||||
|
if thread_metadata_path.exists() {
|
||||||
|
let data = fs::read_to_string(&thread_metadata_path).map_err(|e| e.to_string())?;
|
||||||
|
match serde_json::from_str(&data) {
|
||||||
|
Ok(thread) => threads.push(thread),
|
||||||
|
Err(e) => {
|
||||||
|
println!("Failed to parse thread file: {}", e);
|
||||||
|
continue; // skip invalid thread files
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(threads)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new thread, assigns it a unique ID, and persists its metadata.
|
||||||
|
/// Ensures the thread directory exists and writes thread.json.
|
||||||
|
#[command]
|
||||||
|
pub async fn create_thread<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
mut thread: serde_json::Value,
|
||||||
|
) -> Result<serde_json::Value, String> {
|
||||||
|
ensure_data_dirs(app_handle.clone())?;
|
||||||
|
let uuid = Uuid::new_v4().to_string();
|
||||||
|
thread["id"] = serde_json::Value::String(uuid.clone());
|
||||||
|
let thread_dir = get_thread_dir(app_handle.clone(), &uuid);
|
||||||
|
if !thread_dir.exists() {
|
||||||
|
fs::create_dir_all(&thread_dir).map_err(|e| e.to_string())?;
|
||||||
|
}
|
||||||
|
let path = get_thread_metadata_path(app_handle.clone(), &uuid);
|
||||||
|
let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?;
|
||||||
|
fs::write(path, data).map_err(|e| e.to_string())?;
|
||||||
|
Ok(thread)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Modifies an existing thread's metadata by overwriting its thread.json file.
|
||||||
|
/// Returns an error if the thread directory does not exist.
|
||||||
|
#[command]
|
||||||
|
pub async fn modify_thread<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
thread: serde_json::Value,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let thread_id = thread
|
||||||
|
.get("id")
|
||||||
|
.and_then(|id| id.as_str())
|
||||||
|
.ok_or("Missing thread id")?;
|
||||||
|
let thread_dir = get_thread_dir(app_handle.clone(), thread_id);
|
||||||
|
if !thread_dir.exists() {
|
||||||
|
return Err("Thread directory does not exist".to_string());
|
||||||
|
}
|
||||||
|
let path = get_thread_metadata_path(app_handle.clone(), thread_id);
|
||||||
|
let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?;
|
||||||
|
fs::write(path, data).map_err(|e| e.to_string())?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Deletes a thread and all its associated files by removing its directory.
|
||||||
|
#[command]
|
||||||
|
pub async fn delete_thread<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
thread_id: String,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let thread_dir = get_thread_dir(app_handle.clone(), &thread_id);
|
||||||
|
if thread_dir.exists() {
|
||||||
|
fs::remove_dir_all(thread_dir).map_err(|e| e.to_string())?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lists all messages for a given thread by reading and parsing its messages.jsonl file.
|
||||||
|
/// Returns a vector of message JSON values.
|
||||||
|
#[command]
|
||||||
|
pub async fn list_messages<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
thread_id: String,
|
||||||
|
) -> Result<Vec<serde_json::Value>, String> {
|
||||||
|
let path = get_messages_path(app_handle, &thread_id);
|
||||||
|
if !path.exists() {
|
||||||
|
return Ok(vec![]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let file = File::open(&path).map_err(|e| {
|
||||||
|
eprintln!("Error opening file {}: {}", path.display(), e);
|
||||||
|
e.to_string()
|
||||||
|
})?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
for line in reader.lines() {
|
||||||
|
let line = line.map_err(|e| {
|
||||||
|
eprintln!("Error reading line from file {}: {}", path.display(), e);
|
||||||
|
e.to_string()
|
||||||
|
})?;
|
||||||
|
let message: serde_json::Value = serde_json::from_str(&line).map_err(|e| {
|
||||||
|
eprintln!(
|
||||||
|
"Error parsing JSON from line in file {}: {}",
|
||||||
|
path.display(),
|
||||||
|
e
|
||||||
|
);
|
||||||
|
e.to_string()
|
||||||
|
})?;
|
||||||
|
messages.push(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Appends a new message to a thread's messages.jsonl file.
|
||||||
|
/// Uses a per-thread async lock to prevent race conditions and ensure file consistency.
|
||||||
|
#[command]
|
||||||
|
pub async fn create_message<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
mut message: serde_json::Value,
|
||||||
|
) -> Result<serde_json::Value, String> {
|
||||||
|
let thread_id = {
|
||||||
|
let id = message
|
||||||
|
.get("thread_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or("Missing thread_id")?;
|
||||||
|
id.to_string()
|
||||||
|
};
|
||||||
|
ensure_thread_dir_exists(app_handle.clone(), &thread_id)?;
|
||||||
|
let path = get_messages_path(app_handle.clone(), &thread_id);
|
||||||
|
|
||||||
|
if message.get("id").is_none() {
|
||||||
|
let uuid = Uuid::new_v4().to_string();
|
||||||
|
message["id"] = serde_json::Value::String(uuid);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire per-thread lock before writing
|
||||||
|
{
|
||||||
|
let mut locks = MESSAGE_LOCKS.lock().await;
|
||||||
|
let lock = locks
|
||||||
|
.entry(thread_id.to_string())
|
||||||
|
.or_insert_with(|| Arc::new(Mutex::new(())))
|
||||||
|
.clone();
|
||||||
|
drop(locks); // Release the map lock before awaiting the file lock
|
||||||
|
|
||||||
|
let _guard = lock.lock().await;
|
||||||
|
|
||||||
|
let mut file: File = fs::OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.append(true)
|
||||||
|
.open(path)
|
||||||
|
.map_err(|e| e.to_string())?;
|
||||||
|
|
||||||
|
let data = serde_json::to_string(&message).map_err(|e| e.to_string())?;
|
||||||
|
writeln!(file, "{}", data).map_err(|e| e.to_string())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(message)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Modifies an existing message in a thread's messages.jsonl file.
|
||||||
|
/// Uses a per-thread async lock to prevent race conditions and ensure file consistency.
|
||||||
|
/// Rewrites the entire messages.jsonl file for the thread.
|
||||||
|
#[command]
|
||||||
|
pub async fn modify_message<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
message: serde_json::Value,
|
||||||
|
) -> Result<serde_json::Value, String> {
|
||||||
|
let thread_id = message
|
||||||
|
.get("thread_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or("Missing thread_id")?;
|
||||||
|
let message_id = message
|
||||||
|
.get("id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or("Missing message id")?;
|
||||||
|
|
||||||
|
// Acquire per-thread lock before modifying
|
||||||
|
{
|
||||||
|
let mut locks = MESSAGE_LOCKS.lock().await;
|
||||||
|
let lock = locks
|
||||||
|
.entry(thread_id.to_string())
|
||||||
|
.or_insert_with(|| Arc::new(Mutex::new(())))
|
||||||
|
.clone();
|
||||||
|
drop(locks); // Release the map lock before awaiting the file lock
|
||||||
|
|
||||||
|
let _guard = lock.lock().await;
|
||||||
|
|
||||||
|
let mut messages = list_messages(app_handle.clone(), thread_id.to_string()).await?;
|
||||||
|
if let Some(index) = messages
|
||||||
|
.iter()
|
||||||
|
.position(|m| m.get("id").and_then(|v| v.as_str()) == Some(message_id))
|
||||||
|
{
|
||||||
|
messages[index] = message.clone();
|
||||||
|
|
||||||
|
// Rewrite all messages
|
||||||
|
let path = get_messages_path(app_handle.clone(), thread_id);
|
||||||
|
let mut file = File::create(path).map_err(|e| e.to_string())?;
|
||||||
|
for msg in messages {
|
||||||
|
let data = serde_json::to_string(&msg).map_err(|e| e.to_string())?;
|
||||||
|
writeln!(file, "{}", data).map_err(|e| e.to_string())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(message)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Deletes a message from a thread's messages.jsonl file by message ID.
|
||||||
|
/// Rewrites the entire messages.jsonl file for the thread.
|
||||||
|
/// Uses a per-thread async lock to prevent race conditions and ensure file consistency.
|
||||||
|
#[command]
|
||||||
|
pub async fn delete_message<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
thread_id: String,
|
||||||
|
message_id: String,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// Acquire per-thread lock before modifying
|
||||||
|
{
|
||||||
|
let mut locks = MESSAGE_LOCKS.lock().await;
|
||||||
|
let lock = locks
|
||||||
|
.entry(thread_id.to_string())
|
||||||
|
.or_insert_with(|| Arc::new(Mutex::new(())))
|
||||||
|
.clone();
|
||||||
|
drop(locks); // Release the map lock before awaiting the file lock
|
||||||
|
|
||||||
|
let _guard = lock.lock().await;
|
||||||
|
|
||||||
|
let mut messages = list_messages(app_handle.clone(), thread_id.clone()).await?;
|
||||||
|
messages.retain(|m| m.get("id").and_then(|v| v.as_str()) != Some(message_id.as_str()));
|
||||||
|
|
||||||
|
// Rewrite remaining messages
|
||||||
|
let path = get_messages_path(app_handle.clone(), &thread_id);
|
||||||
|
let mut file = File::create(path).map_err(|e| e.to_string())?;
|
||||||
|
for msg in messages {
|
||||||
|
let data = serde_json::to_string(&msg).map_err(|e| e.to_string())?;
|
||||||
|
writeln!(file, "{}", data).map_err(|e| e.to_string())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieves the first assistant associated with a thread.
|
||||||
|
/// Returns an error if the thread or assistant is not found.
|
||||||
|
#[command]
|
||||||
|
pub async fn get_thread_assistant<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
thread_id: String,
|
||||||
|
) -> Result<serde_json::Value, String> {
|
||||||
|
let path = get_thread_metadata_path(app_handle, &thread_id);
|
||||||
|
if !path.exists() {
|
||||||
|
return Err("Thread not found".to_string());
|
||||||
|
}
|
||||||
|
let data = fs::read_to_string(&path).map_err(|e| e.to_string())?;
|
||||||
|
let thread: serde_json::Value = serde_json::from_str(&data).map_err(|e| e.to_string())?;
|
||||||
|
if let Some(assistants) = thread.get("assistants").and_then(|a| a.as_array()) {
|
||||||
|
if let Some(first) = assistants.get(0) {
|
||||||
|
Ok(first.clone())
|
||||||
|
} else {
|
||||||
|
Err("Assistant not found".to_string())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err("Assistant not found".to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Adds a new assistant to a thread's metadata.
|
||||||
|
/// Updates thread.json with the new assistant information.
|
||||||
|
#[command]
|
||||||
|
pub async fn create_thread_assistant<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
thread_id: String,
|
||||||
|
assistant: serde_json::Value,
|
||||||
|
) -> Result<serde_json::Value, String> {
|
||||||
|
let path = get_thread_metadata_path(app_handle.clone(), &thread_id);
|
||||||
|
if !path.exists() {
|
||||||
|
return Err("Thread not found".to_string());
|
||||||
|
}
|
||||||
|
let mut thread: serde_json::Value = {
|
||||||
|
let data = fs::read_to_string(&path).map_err(|e| e.to_string())?;
|
||||||
|
serde_json::from_str(&data).map_err(|e| e.to_string())?
|
||||||
|
};
|
||||||
|
if let Some(assistants) = thread.get_mut("assistants").and_then(|a| a.as_array_mut()) {
|
||||||
|
assistants.push(assistant.clone());
|
||||||
|
} else {
|
||||||
|
thread["assistants"] = serde_json::Value::Array(vec![assistant.clone()]);
|
||||||
|
}
|
||||||
|
let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?;
|
||||||
|
fs::write(&path, data).map_err(|e| e.to_string())?;
|
||||||
|
Ok(assistant)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Modifies an existing assistant's information in a thread's metadata.
|
||||||
|
/// Updates thread.json with the modified assistant data.
|
||||||
|
#[command]
|
||||||
|
pub async fn modify_thread_assistant<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
thread_id: String,
|
||||||
|
assistant: serde_json::Value,
|
||||||
|
) -> Result<serde_json::Value, String> {
|
||||||
|
let path = get_thread_metadata_path(app_handle.clone(), &thread_id);
|
||||||
|
if !path.exists() {
|
||||||
|
return Err("Thread not found".to_string());
|
||||||
|
}
|
||||||
|
let mut thread: serde_json::Value = {
|
||||||
|
let data = fs::read_to_string(&path).map_err(|e| e.to_string())?;
|
||||||
|
serde_json::from_str(&data).map_err(|e| e.to_string())?
|
||||||
|
};
|
||||||
|
let assistant_id = assistant
|
||||||
|
.get("assistant_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or("Missing assistant_id")?;
|
||||||
|
if let Some(assistants) = thread
|
||||||
|
.get_mut("assistants")
|
||||||
|
.and_then(|a: &mut serde_json::Value| a.as_array_mut())
|
||||||
|
{
|
||||||
|
if let Some(index) = assistants
|
||||||
|
.iter()
|
||||||
|
.position(|a| a.get("assistant_id").and_then(|v| v.as_str()) == Some(assistant_id))
|
||||||
|
{
|
||||||
|
assistants[index] = assistant.clone();
|
||||||
|
let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?;
|
||||||
|
fs::write(&path, data).map_err(|e| e.to_string())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(assistant)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::core::cmd::get_jan_data_folder_path;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use tauri::test::{mock_app, MockRuntime};
|
||||||
|
|
||||||
|
// Helper to create a mock app handle with a temp data dir
|
||||||
|
fn mock_app_with_temp_data_dir() -> (tauri::App<MockRuntime>, PathBuf) {
|
||||||
|
let app = mock_app();
|
||||||
|
let data_dir = get_jan_data_folder_path(app.handle().clone());
|
||||||
|
println!("Mock app data dir: {}", data_dir.display());
|
||||||
|
// Patch get_data_dir to use temp dir (requires get_data_dir to be overridable or injectable)
|
||||||
|
// For now, we assume get_data_dir uses tauri::api::path::app_data_dir(&app_handle)
|
||||||
|
// and that we can set the environment variable to redirect it.
|
||||||
|
(app, data_dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_and_list_threads() {
|
||||||
|
let (app, data_dir) = mock_app_with_temp_data_dir();
|
||||||
|
// Create a thread
|
||||||
|
let thread = json!({
|
||||||
|
"object": "thread",
|
||||||
|
"title": "Test Thread",
|
||||||
|
"assistants": [],
|
||||||
|
"created": 1234567890,
|
||||||
|
"updated": 1234567890,
|
||||||
|
"metadata": null
|
||||||
|
});
|
||||||
|
let created = create_thread(app.handle().clone(), thread.clone())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(created["title"], "Test Thread");
|
||||||
|
|
||||||
|
// List threads
|
||||||
|
let threads = list_threads(app.handle().clone()).await.unwrap();
|
||||||
|
assert!(threads.len() > 0);
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
fs::remove_dir_all(data_dir).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_and_list_messages() {
|
||||||
|
let (app, data_dir) = mock_app_with_temp_data_dir();
|
||||||
|
// Create a thread first
|
||||||
|
let thread = json!({
|
||||||
|
"object": "thread",
|
||||||
|
"title": "Msg Thread",
|
||||||
|
"assistants": [],
|
||||||
|
"created": 123,
|
||||||
|
"updated": 123,
|
||||||
|
"metadata": null
|
||||||
|
});
|
||||||
|
let created = create_thread(app.handle().clone(), thread.clone())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let thread_id = created["id"].as_str().unwrap().to_string();
|
||||||
|
|
||||||
|
// Create a message
|
||||||
|
let message = json!({
|
||||||
|
"object": "message",
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"assistant_id": null,
|
||||||
|
"attachments": null,
|
||||||
|
"role": "user",
|
||||||
|
"content": [],
|
||||||
|
"status": "sent",
|
||||||
|
"created_at": 123,
|
||||||
|
"completed_at": 123,
|
||||||
|
"metadata": null,
|
||||||
|
"type_": null,
|
||||||
|
"error_code": null,
|
||||||
|
"tool_call_id": null
|
||||||
|
});
|
||||||
|
let created_msg = create_message(app.handle().clone(), message).await.unwrap();
|
||||||
|
assert_eq!(created_msg["role"], "user");
|
||||||
|
|
||||||
|
// List messages
|
||||||
|
let messages = list_messages(app.handle().clone(), thread_id.clone())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(messages.len() > 0);
|
||||||
|
assert_eq!(messages[0]["role"], "user");
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
fs::remove_dir_all(data_dir).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_and_get_thread_assistant() {
|
||||||
|
let (app, data_dir) = mock_app_with_temp_data_dir();
|
||||||
|
// Create a thread
|
||||||
|
let thread = json!({
|
||||||
|
"object": "thread",
|
||||||
|
"title": "Assistant Thread",
|
||||||
|
"assistants": [],
|
||||||
|
"created": 1,
|
||||||
|
"updated": 1,
|
||||||
|
"metadata": null
|
||||||
|
});
|
||||||
|
let created = create_thread(app.handle().clone(), thread.clone())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let thread_id = created["id"].as_str().unwrap().to_string();
|
||||||
|
|
||||||
|
// Add assistant
|
||||||
|
let assistant = json!({
|
||||||
|
"id": "assistant-1",
|
||||||
|
"assistant_name": "Test Assistant",
|
||||||
|
"model": {
|
||||||
|
"id": "model-1",
|
||||||
|
"name": "Test Model",
|
||||||
|
"settings": json!({})
|
||||||
|
},
|
||||||
|
"instructions": null,
|
||||||
|
"tools": null
|
||||||
|
});
|
||||||
|
let _ = create_thread_assistant(app.handle().clone(), thread_id.clone(), assistant.clone())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Get assistant
|
||||||
|
let got = get_thread_assistant(app.handle().clone(), thread_id.clone())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(got["assistant_name"], "Test Assistant");
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
fs::remove_dir_all(data_dir).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
48
src-tauri/src/core/utils/mod.rs
Normal file
48
src-tauri/src/core/utils/mod.rs
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
use std::fs;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use tauri::Runtime;
|
||||||
|
|
||||||
|
use super::cmd::get_jan_data_folder_path;
|
||||||
|
|
||||||
|
pub const THREADS_DIR: &str = "threads";
|
||||||
|
pub const THREADS_FILE: &str = "thread.json";
|
||||||
|
pub const MESSAGES_FILE: &str = "messages.jsonl";
|
||||||
|
|
||||||
|
pub fn get_data_dir<R: Runtime>(app_handle: tauri::AppHandle<R>) -> PathBuf {
|
||||||
|
get_jan_data_folder_path(app_handle).join(THREADS_DIR)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_thread_dir<R: Runtime>(app_handle: tauri::AppHandle<R>, thread_id: &str) -> PathBuf {
|
||||||
|
get_data_dir(app_handle).join(thread_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_thread_metadata_path<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
thread_id: &str,
|
||||||
|
) -> PathBuf {
|
||||||
|
get_thread_dir(app_handle, thread_id).join(THREADS_FILE)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_messages_path<R: Runtime>(app_handle: tauri::AppHandle<R>, thread_id: &str) -> PathBuf {
|
||||||
|
get_thread_dir(app_handle, thread_id).join(MESSAGES_FILE)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ensure_data_dirs<R: Runtime>(app_handle: tauri::AppHandle<R>) -> Result<(), String> {
|
||||||
|
let data_dir = get_data_dir(app_handle.clone());
|
||||||
|
if !data_dir.exists() {
|
||||||
|
fs::create_dir_all(&data_dir).map_err(|e| e.to_string())?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ensure_thread_dir_exists<R: Runtime>(
|
||||||
|
app_handle: tauri::AppHandle<R>,
|
||||||
|
thread_id: &str,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
ensure_data_dirs(app_handle.clone())?;
|
||||||
|
let thread_dir = get_thread_dir(app_handle, thread_id);
|
||||||
|
if !thread_dir.exists() {
|
||||||
|
fs::create_dir(&thread_dir).map_err(|e| e.to_string())?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@ -39,9 +39,24 @@ pub fn run() {
|
|||||||
core::cmd::app_token,
|
core::cmd::app_token,
|
||||||
core::cmd::start_server,
|
core::cmd::start_server,
|
||||||
core::cmd::stop_server,
|
core::cmd::stop_server,
|
||||||
|
core::cmd::save_mcp_configs,
|
||||||
|
core::cmd::get_mcp_configs,
|
||||||
// MCP commands
|
// MCP commands
|
||||||
core::cmd::get_tools,
|
core::cmd::get_tools,
|
||||||
core::cmd::call_tool
|
core::cmd::call_tool,
|
||||||
|
core::mcp::restart_mcp_servers,
|
||||||
|
// Threads
|
||||||
|
core::threads::list_threads,
|
||||||
|
core::threads::create_thread,
|
||||||
|
core::threads::modify_thread,
|
||||||
|
core::threads::delete_thread,
|
||||||
|
core::threads::list_messages,
|
||||||
|
core::threads::create_message,
|
||||||
|
core::threads::modify_message,
|
||||||
|
core::threads::delete_message,
|
||||||
|
core::threads::get_thread_assistant,
|
||||||
|
core::threads::create_thread_assistant,
|
||||||
|
core::threads::modify_thread_assistant
|
||||||
])
|
])
|
||||||
.manage(AppState {
|
.manage(AppState {
|
||||||
app_token: Some(generate_app_token()),
|
app_token: Some(generate_app_token()),
|
||||||
|
|||||||
@ -232,26 +232,6 @@ const ModelDropdown = ({
|
|||||||
stopModel()
|
stopModel()
|
||||||
|
|
||||||
if (activeThread) {
|
if (activeThread) {
|
||||||
// Change assistand tools based on model support RAG
|
|
||||||
updateThreadMetadata({
|
|
||||||
...activeThread,
|
|
||||||
assistants: [
|
|
||||||
{
|
|
||||||
...activeAssistant,
|
|
||||||
tools: [
|
|
||||||
{
|
|
||||||
type: 'retrieval',
|
|
||||||
enabled: model?.engine === InferenceEngine.cortex,
|
|
||||||
settings: {
|
|
||||||
...(activeAssistant.tools &&
|
|
||||||
activeAssistant.tools[0]?.settings),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
],
|
|
||||||
})
|
|
||||||
|
|
||||||
const contextLength = model?.settings.ctx_len
|
const contextLength = model?.settings.ctx_len
|
||||||
? Math.min(8192, model?.settings.ctx_len ?? 8192)
|
? Math.min(8192, model?.settings.ctx_len ?? 8192)
|
||||||
: undefined
|
: undefined
|
||||||
@ -273,11 +253,25 @@ const ModelDropdown = ({
|
|||||||
|
|
||||||
// Update model parameter to the thread file
|
// Update model parameter to the thread file
|
||||||
if (model)
|
if (model)
|
||||||
updateModelParameter(activeThread, {
|
updateModelParameter(
|
||||||
params: modelParams,
|
activeThread,
|
||||||
modelId: model.id,
|
{
|
||||||
engine: model.engine,
|
params: modelParams,
|
||||||
})
|
modelId: model.id,
|
||||||
|
engine: model.engine,
|
||||||
|
},
|
||||||
|
// Update tools
|
||||||
|
[
|
||||||
|
{
|
||||||
|
type: 'retrieval',
|
||||||
|
enabled: model?.engine === InferenceEngine.cortex,
|
||||||
|
settings: {
|
||||||
|
...(activeAssistant.tools &&
|
||||||
|
activeAssistant.tools[0]?.settings),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
|
|||||||
@ -83,7 +83,7 @@ const ModelSearch = ({ onSearchLocal }: Props) => {
|
|||||||
value={searchText}
|
value={searchText}
|
||||||
clearable={searchText.length > 0}
|
clearable={searchText.length > 0}
|
||||||
onClear={onClear}
|
onClear={onClear}
|
||||||
className="border-0 bg-[hsla(var(--app-bg))]"
|
className="bg-[hsla(var(--app-bg))]"
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
onSearchLocal?.(inputRef.current?.value ?? '')
|
onSearchLocal?.(inputRef.current?.value ?? '')
|
||||||
}}
|
}}
|
||||||
|
|||||||
@ -114,7 +114,7 @@ export default function ModelHandler() {
|
|||||||
|
|
||||||
const onNewMessageResponse = useCallback(
|
const onNewMessageResponse = useCallback(
|
||||||
async (message: ThreadMessage) => {
|
async (message: ThreadMessage) => {
|
||||||
if (message.type === MessageRequestType.Thread) {
|
if (message.type !== MessageRequestType.Summary) {
|
||||||
addNewMessage(message)
|
addNewMessage(message)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -129,35 +129,20 @@ export default function ModelHandler() {
|
|||||||
const updateThreadTitle = useCallback(
|
const updateThreadTitle = useCallback(
|
||||||
(message: ThreadMessage) => {
|
(message: ThreadMessage) => {
|
||||||
// Update only when it's finished
|
// Update only when it's finished
|
||||||
if (message.status !== MessageStatus.Ready) {
|
if (message.status !== MessageStatus.Ready) return
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
|
const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
|
||||||
if (!thread) {
|
|
||||||
console.warn(
|
|
||||||
`Failed to update title for thread ${message.thread_id}: Thread not found!`
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
let messageContent = message.content[0]?.text?.value
|
let messageContent = message.content[0]?.text?.value
|
||||||
if (!messageContent) {
|
if (!thread || !messageContent) return
|
||||||
console.warn(
|
|
||||||
`Failed to update title for thread ${message.thread_id}: Responded content is null!`
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// No new line character is presented in the title
|
// No new line character is presented in the title
|
||||||
// And non-alphanumeric characters should be removed
|
// And non-alphanumeric characters should be removed
|
||||||
if (messageContent.includes('\n')) {
|
if (messageContent.includes('\n'))
|
||||||
messageContent = messageContent.replace(/\n/g, ' ')
|
messageContent = messageContent.replace(/\n/g, ' ')
|
||||||
}
|
|
||||||
const match = messageContent.match(/<\/think>(.*)$/)
|
const match = messageContent.match(/<\/think>(.*)$/)
|
||||||
if (match) {
|
if (match) messageContent = match[1]
|
||||||
messageContent = match[1]
|
|
||||||
}
|
|
||||||
// Remove non-alphanumeric characters
|
// Remove non-alphanumeric characters
|
||||||
const cleanedMessageContent = messageContent
|
const cleanedMessageContent = messageContent
|
||||||
.replace(/[^\p{L}\s]+/gu, '')
|
.replace(/[^\p{L}\s]+/gu, '')
|
||||||
@ -193,18 +178,13 @@ export default function ModelHandler() {
|
|||||||
|
|
||||||
const updateThreadMessage = useCallback(
|
const updateThreadMessage = useCallback(
|
||||||
(message: ThreadMessage) => {
|
(message: ThreadMessage) => {
|
||||||
if (
|
updateMessage(
|
||||||
messageGenerationSubscriber.current &&
|
message.id,
|
||||||
message.thread_id === activeThreadRef.current?.id &&
|
message.thread_id,
|
||||||
!messageGenerationSubscriber.current!.thread_id
|
message.content,
|
||||||
) {
|
message.metadata,
|
||||||
updateMessage(
|
message.status
|
||||||
message.id,
|
)
|
||||||
message.thread_id,
|
|
||||||
message.content,
|
|
||||||
message.status
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (message.status === MessageStatus.Pending) {
|
if (message.status === MessageStatus.Pending) {
|
||||||
if (message.content.length) {
|
if (message.content.length) {
|
||||||
@ -236,82 +216,66 @@ export default function ModelHandler() {
|
|||||||
model: activeModelRef.current?.name,
|
model: activeModelRef.current?.name,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return
|
} else {
|
||||||
} else if (
|
// Mark the thread as not waiting for response
|
||||||
message.status === MessageStatus.Error &&
|
updateThreadWaiting(message.thread_id, false)
|
||||||
activeModelRef.current?.engine &&
|
|
||||||
engines &&
|
|
||||||
isLocalEngine(engines, activeModelRef.current.engine)
|
|
||||||
) {
|
|
||||||
;(async () => {
|
|
||||||
if (
|
|
||||||
!(await extensionManager
|
|
||||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
|
||||||
?.isModelLoaded(activeModelRef.current?.id as string))
|
|
||||||
) {
|
|
||||||
setActiveModel(undefined)
|
|
||||||
setStateModel({ state: 'start', loading: false, model: undefined })
|
|
||||||
}
|
|
||||||
})()
|
|
||||||
}
|
|
||||||
// Mark the thread as not waiting for response
|
|
||||||
updateThreadWaiting(message.thread_id, false)
|
|
||||||
|
|
||||||
setIsGeneratingResponse(false)
|
setIsGeneratingResponse(false)
|
||||||
|
|
||||||
const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
|
const thread = threadsRef.current?.find(
|
||||||
if (!thread) return
|
(e) => e.id == message.thread_id
|
||||||
|
)
|
||||||
|
if (!thread) return
|
||||||
|
|
||||||
const messageContent = message.content[0]?.text?.value
|
const messageContent = message.content[0]?.text?.value
|
||||||
|
|
||||||
const metadata = {
|
const metadata = {
|
||||||
...thread.metadata,
|
...thread.metadata,
|
||||||
...(messageContent && { lastMessage: messageContent }),
|
...(messageContent && { lastMessage: messageContent }),
|
||||||
updated_at: Date.now(),
|
updated_at: Date.now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
updateThread({
|
updateThread({
|
||||||
...thread,
|
|
||||||
metadata,
|
|
||||||
})
|
|
||||||
|
|
||||||
extensionManager
|
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
|
||||||
?.modifyThread({
|
|
||||||
...thread,
|
...thread,
|
||||||
metadata,
|
metadata,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Update message's metadata with token usage
|
extensionManager
|
||||||
message.metadata = {
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
...message.metadata,
|
?.modifyThread({
|
||||||
token_speed: tokenSpeedRef.current?.tokenSpeed,
|
...thread,
|
||||||
model: activeModelRef.current?.name,
|
metadata,
|
||||||
}
|
})
|
||||||
|
|
||||||
if (message.status === MessageStatus.Error) {
|
// Update message's metadata with token usage
|
||||||
message.metadata = {
|
message.metadata = {
|
||||||
...message.metadata,
|
...message.metadata,
|
||||||
error: message.content[0]?.text?.value,
|
token_speed: tokenSpeedRef.current?.tokenSpeed,
|
||||||
error_code: message.error_code,
|
model: activeModelRef.current?.name,
|
||||||
}
|
}
|
||||||
}
|
|
||||||
;(async () => {
|
if (message.status === MessageStatus.Error) {
|
||||||
const updatedMessage = await extensionManager
|
message.metadata = {
|
||||||
|
...message.metadata,
|
||||||
|
error: message.content[0]?.text?.value,
|
||||||
|
error_code: message.error_code,
|
||||||
|
}
|
||||||
|
// Unassign active model if any
|
||||||
|
setActiveModel(undefined)
|
||||||
|
setStateModel({
|
||||||
|
state: 'start',
|
||||||
|
loading: false,
|
||||||
|
model: undefined,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.createMessage(message)
|
?.createMessage(message)
|
||||||
.catch(() => undefined)
|
|
||||||
if (updatedMessage) {
|
|
||||||
deleteMessage(message.id)
|
|
||||||
addNewMessage(updatedMessage)
|
|
||||||
setTokenSpeed((prev) =>
|
|
||||||
prev ? { ...prev, message: updatedMessage.id } : undefined
|
|
||||||
)
|
|
||||||
}
|
|
||||||
})()
|
|
||||||
|
|
||||||
// Attempt to generate the title of the Thread when needed
|
// Attempt to generate the title of the Thread when needed
|
||||||
generateThreadTitle(message, thread)
|
generateThreadTitle(message, thread)
|
||||||
|
}
|
||||||
},
|
},
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
[setIsGeneratingResponse, updateMessage, updateThread, updateThreadWaiting]
|
[setIsGeneratingResponse, updateMessage, updateThread, updateThreadWaiting]
|
||||||
@ -319,25 +283,21 @@ export default function ModelHandler() {
|
|||||||
|
|
||||||
const onMessageResponseUpdate = useCallback(
|
const onMessageResponseUpdate = useCallback(
|
||||||
(message: ThreadMessage) => {
|
(message: ThreadMessage) => {
|
||||||
switch (message.type) {
|
if (message.type === MessageRequestType.Summary)
|
||||||
case MessageRequestType.Summary:
|
updateThreadTitle(message)
|
||||||
updateThreadTitle(message)
|
else updateThreadMessage(message)
|
||||||
break
|
|
||||||
default:
|
|
||||||
updateThreadMessage(message)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
[updateThreadMessage, updateThreadTitle]
|
[updateThreadMessage, updateThreadTitle]
|
||||||
)
|
)
|
||||||
|
|
||||||
const generateThreadTitle = (message: ThreadMessage, thread: Thread) => {
|
const generateThreadTitle = (message: ThreadMessage, thread: Thread) => {
|
||||||
// If this is the first ever prompt in the thread
|
// If this is the first ever prompt in the thread
|
||||||
if ((thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle)
|
if (
|
||||||
|
!activeModelRef.current ||
|
||||||
|
(thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if (!activeModelRef.current) return
|
|
||||||
|
|
||||||
// Check model engine; we don't want to generate a title when it's not a local engine. remote model using first promp
|
// Check model engine; we don't want to generate a title when it's not a local engine. remote model using first promp
|
||||||
if (
|
if (
|
||||||
activeModelRef.current?.engine !== InferenceEngine.cortex &&
|
activeModelRef.current?.engine !== InferenceEngine.cortex &&
|
||||||
|
|||||||
@ -165,6 +165,7 @@ export const updateMessageAtom = atom(
|
|||||||
id: string,
|
id: string,
|
||||||
conversationId: string,
|
conversationId: string,
|
||||||
text: ThreadContent[],
|
text: ThreadContent[],
|
||||||
|
metadata: Record<string, unknown> | undefined,
|
||||||
status: MessageStatus
|
status: MessageStatus
|
||||||
) => {
|
) => {
|
||||||
const messages = get(chatMessages)[conversationId] ?? []
|
const messages = get(chatMessages)[conversationId] ?? []
|
||||||
@ -172,6 +173,7 @@ export const updateMessageAtom = atom(
|
|||||||
if (message) {
|
if (message) {
|
||||||
message.content = text
|
message.content = text
|
||||||
message.status = status
|
message.status = status
|
||||||
|
message.metadata = metadata
|
||||||
const updatedMessages = [...messages]
|
const updatedMessages = [...messages]
|
||||||
|
|
||||||
const newData: Record<string, ThreadMessage[]> = {
|
const newData: Record<string, ThreadMessage[]> = {
|
||||||
@ -192,6 +194,7 @@ export const updateMessageAtom = atom(
|
|||||||
created_at: Date.now() / 1000,
|
created_at: Date.now() / 1000,
|
||||||
completed_at: Date.now() / 1000,
|
completed_at: Date.now() / 1000,
|
||||||
object: 'thread.message',
|
object: 'thread.message',
|
||||||
|
metadata: metadata,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -180,7 +180,7 @@ export const useCreateNewThread = () => {
|
|||||||
updateThreadCallback(thread)
|
updateThreadCallback(thread)
|
||||||
if (thread.assistants && thread.assistants?.length > 0) {
|
if (thread.assistants && thread.assistants?.length > 0) {
|
||||||
setActiveAssistant(thread.assistants[0])
|
setActiveAssistant(thread.assistants[0])
|
||||||
updateAssistantCallback(thread.id, thread.assistants[0])
|
return updateAssistantCallback(thread.id, thread.assistants[0])
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
|
|||||||
@ -38,12 +38,13 @@ export default function useDeleteThread() {
|
|||||||
?.listMessages(threadId)
|
?.listMessages(threadId)
|
||||||
.catch(console.error)
|
.catch(console.error)
|
||||||
if (messages) {
|
if (messages) {
|
||||||
messages.forEach((message) => {
|
for (const message of messages) {
|
||||||
extensionManager
|
await extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.deleteMessage(threadId, message.id)
|
?.deleteMessage(threadId, message.id)
|
||||||
.catch(console.error)
|
.catch(console.error)
|
||||||
})
|
}
|
||||||
|
|
||||||
const thread = threads.find((e) => e.id === threadId)
|
const thread = threads.find((e) => e.id === threadId)
|
||||||
if (thread) {
|
if (thread) {
|
||||||
const updatedThread = {
|
const updatedThread = {
|
||||||
|
|||||||
@ -24,8 +24,10 @@ import {
|
|||||||
ChatCompletionMessageParam,
|
ChatCompletionMessageParam,
|
||||||
ChatCompletionRole as OpenAIChatCompletionRole,
|
ChatCompletionRole as OpenAIChatCompletionRole,
|
||||||
ChatCompletionTool,
|
ChatCompletionTool,
|
||||||
|
ChatCompletionMessageToolCall,
|
||||||
} from 'openai/resources/chat'
|
} from 'openai/resources/chat'
|
||||||
|
|
||||||
|
import { Stream } from 'openai/streaming'
|
||||||
import { ulid } from 'ulidx'
|
import { ulid } from 'ulidx'
|
||||||
|
|
||||||
import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
|
import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
|
||||||
@ -133,12 +135,16 @@ export default function useSendChatMessage() {
|
|||||||
) => {
|
) => {
|
||||||
if (!message || message.trim().length === 0) return
|
if (!message || message.trim().length === 0) return
|
||||||
|
|
||||||
if (!activeThreadRef.current || !activeAssistantRef.current) {
|
const activeThread = activeThreadRef.current
|
||||||
|
const activeAssistant = activeAssistantRef.current
|
||||||
|
const activeModel = selectedModelRef.current
|
||||||
|
|
||||||
|
if (!activeThread || !activeAssistant) {
|
||||||
console.error('No active thread or assistant')
|
console.error('No active thread or assistant')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (selectedModelRef.current?.id === undefined) {
|
if (!activeModel?.id) {
|
||||||
setModelDropdownState(true)
|
setModelDropdownState(true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -151,7 +157,7 @@ export default function useSendChatMessage() {
|
|||||||
|
|
||||||
const prompt = message.trim()
|
const prompt = message.trim()
|
||||||
|
|
||||||
updateThreadWaiting(activeThreadRef.current.id, true)
|
updateThreadWaiting(activeThread.id, true)
|
||||||
setCurrentPrompt('')
|
setCurrentPrompt('')
|
||||||
setEditPrompt('')
|
setEditPrompt('')
|
||||||
|
|
||||||
@ -162,15 +168,14 @@ export default function useSendChatMessage() {
|
|||||||
base64Blob = await compressImage(base64Blob, 512)
|
base64Blob = await compressImage(base64Blob, 512)
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelRequest =
|
const modelRequest = selectedModel ?? activeAssistant.model
|
||||||
selectedModelRef?.current ?? activeAssistantRef.current?.model
|
|
||||||
|
|
||||||
// Fallback support for previous broken threads
|
// Fallback support for previous broken threads
|
||||||
if (activeAssistantRef.current?.model?.id === '*') {
|
if (activeAssistant.model?.id === '*') {
|
||||||
activeAssistantRef.current.model = {
|
activeAssistant.model = {
|
||||||
id: modelRequest.id,
|
id: activeModel.id,
|
||||||
settings: modelRequest.settings,
|
settings: activeModel.settings,
|
||||||
parameters: modelRequest.parameters,
|
parameters: activeModel.parameters,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (runtimeParams.stream == null) {
|
if (runtimeParams.stream == null) {
|
||||||
@ -185,7 +190,7 @@ export default function useSendChatMessage() {
|
|||||||
settings: settingParams,
|
settings: settingParams,
|
||||||
parameters: runtimeParams,
|
parameters: runtimeParams,
|
||||||
},
|
},
|
||||||
activeThreadRef.current,
|
activeThread,
|
||||||
messages ?? currentMessages,
|
messages ?? currentMessages,
|
||||||
(await window.core.api.getTools())?.map((tool: ModelTool) => ({
|
(await window.core.api.getTools())?.map((tool: ModelTool) => ({
|
||||||
type: 'function' as const,
|
type: 'function' as const,
|
||||||
@ -196,7 +201,7 @@ export default function useSendChatMessage() {
|
|||||||
strict: false,
|
strict: false,
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
).addSystemMessage(activeAssistantRef.current?.instructions)
|
).addSystemMessage(activeAssistant.instructions)
|
||||||
|
|
||||||
requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
|
requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
|
||||||
|
|
||||||
@ -209,10 +214,10 @@ export default function useSendChatMessage() {
|
|||||||
|
|
||||||
// Update thread state
|
// Update thread state
|
||||||
const updatedThread: Thread = {
|
const updatedThread: Thread = {
|
||||||
...activeThreadRef.current,
|
...activeThread,
|
||||||
updated: newMessage.created_at,
|
updated: newMessage.created_at,
|
||||||
metadata: {
|
metadata: {
|
||||||
...activeThreadRef.current.metadata,
|
...activeThread.metadata,
|
||||||
lastMessage: prompt,
|
lastMessage: prompt,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -235,17 +240,16 @@ export default function useSendChatMessage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start Model if not started
|
// Start Model if not started
|
||||||
const modelId =
|
const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id
|
||||||
selectedModelRef.current?.id ?? activeAssistantRef.current?.model.id
|
|
||||||
|
|
||||||
if (base64Blob) {
|
if (base64Blob) {
|
||||||
setFileUpload(undefined)
|
setFileUpload(undefined)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (modelRef.current?.id !== modelId && modelId) {
|
if (activeModel?.id !== modelId && modelId) {
|
||||||
const error = await startModel(modelId).catch((error: Error) => error)
|
const error = await startModel(modelId).catch((error: Error) => error)
|
||||||
if (error) {
|
if (error) {
|
||||||
updateThreadWaiting(activeThreadRef.current.id, false)
|
updateThreadWaiting(activeThread.id, false)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -258,111 +262,65 @@ export default function useSendChatMessage() {
|
|||||||
baseURL: `${API_BASE_URL}/v1`,
|
baseURL: `${API_BASE_URL}/v1`,
|
||||||
dangerouslyAllowBrowser: true,
|
dangerouslyAllowBrowser: true,
|
||||||
})
|
})
|
||||||
|
let parentMessageId: string | undefined
|
||||||
while (!isDone) {
|
while (!isDone) {
|
||||||
|
let messageId = ulid()
|
||||||
|
if (!parentMessageId) {
|
||||||
|
parentMessageId = ulid()
|
||||||
|
messageId = parentMessageId
|
||||||
|
}
|
||||||
const data = requestBuilder.build()
|
const data = requestBuilder.build()
|
||||||
|
const message: ThreadMessage = {
|
||||||
|
id: messageId,
|
||||||
|
object: 'message',
|
||||||
|
thread_id: activeThread.id,
|
||||||
|
assistant_id: activeAssistant.assistant_id,
|
||||||
|
role: ChatCompletionRole.Assistant,
|
||||||
|
content: [],
|
||||||
|
metadata: {
|
||||||
|
...(messageId !== parentMessageId
|
||||||
|
? { parent_id: parentMessageId }
|
||||||
|
: {}),
|
||||||
|
},
|
||||||
|
status: MessageStatus.Pending,
|
||||||
|
created_at: Date.now() / 1000,
|
||||||
|
completed_at: Date.now() / 1000,
|
||||||
|
}
|
||||||
|
events.emit(MessageEvent.OnMessageResponse, message)
|
||||||
const response = await openai.chat.completions.create({
|
const response = await openai.chat.completions.create({
|
||||||
messages: (data.messages ?? []).map((e) => {
|
messages: requestBuilder.messages as ChatCompletionMessageParam[],
|
||||||
return {
|
|
||||||
role: e.role as OpenAIChatCompletionRole,
|
|
||||||
content: e.content,
|
|
||||||
}
|
|
||||||
}) as ChatCompletionMessageParam[],
|
|
||||||
model: data.model?.id ?? '',
|
model: data.model?.id ?? '',
|
||||||
tools: data.tools as ChatCompletionTool[],
|
tools: data.tools as ChatCompletionTool[],
|
||||||
stream: false,
|
stream: data.model?.parameters?.stream ?? false,
|
||||||
|
tool_choice: 'auto',
|
||||||
})
|
})
|
||||||
if (response.choices[0]?.message.content) {
|
// Variables to track and accumulate streaming content
|
||||||
const newMessage: ThreadMessage = {
|
if (!message.content.length) {
|
||||||
id: ulid(),
|
message.content = [
|
||||||
object: 'message',
|
{
|
||||||
thread_id: activeThreadRef.current.id,
|
type: ContentType.Text,
|
||||||
assistant_id: activeAssistantRef.current.assistant_id,
|
text: {
|
||||||
attachments: [],
|
value: '',
|
||||||
role: response.choices[0].message.role as ChatCompletionRole,
|
annotations: [],
|
||||||
content: [
|
|
||||||
{
|
|
||||||
type: ContentType.Text,
|
|
||||||
text: {
|
|
||||||
value: response.choices[0].message.content
|
|
||||||
? response.choices[0].message.content
|
|
||||||
: '',
|
|
||||||
annotations: [],
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
],
|
},
|
||||||
status: MessageStatus.Ready,
|
]
|
||||||
created_at: Date.now(),
|
}
|
||||||
completed_at: Date.now(),
|
if (data.model?.parameters?.stream)
|
||||||
}
|
isDone = await processStreamingResponse(
|
||||||
requestBuilder.pushAssistantMessage(
|
response as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
|
||||||
response.choices[0].message.content ?? ''
|
requestBuilder,
|
||||||
|
message
|
||||||
|
)
|
||||||
|
else {
|
||||||
|
isDone = await processNonStreamingResponse(
|
||||||
|
response as OpenAI.Chat.Completions.ChatCompletion,
|
||||||
|
requestBuilder,
|
||||||
|
message
|
||||||
)
|
)
|
||||||
events.emit(MessageEvent.OnMessageUpdate, newMessage)
|
|
||||||
}
|
}
|
||||||
|
message.status = MessageStatus.Ready
|
||||||
if (response.choices[0]?.message.tool_calls) {
|
events.emit(MessageEvent.OnMessageUpdate, message)
|
||||||
for (const toolCall of response.choices[0].message.tool_calls) {
|
|
||||||
const id = ulid()
|
|
||||||
const toolMessage: ThreadMessage = {
|
|
||||||
id: id,
|
|
||||||
object: 'message',
|
|
||||||
thread_id: activeThreadRef.current.id,
|
|
||||||
assistant_id: activeAssistantRef.current.assistant_id,
|
|
||||||
attachments: [],
|
|
||||||
role: ChatCompletionRole.Assistant,
|
|
||||||
content: [
|
|
||||||
{
|
|
||||||
type: ContentType.Text,
|
|
||||||
text: {
|
|
||||||
value: `<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`,
|
|
||||||
annotations: [],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
status: MessageStatus.Pending,
|
|
||||||
created_at: Date.now(),
|
|
||||||
completed_at: Date.now(),
|
|
||||||
}
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, toolMessage)
|
|
||||||
const result = await window.core.api.callTool({
|
|
||||||
toolName: toolCall.function.name,
|
|
||||||
arguments: JSON.parse(toolCall.function.arguments),
|
|
||||||
})
|
|
||||||
if (result.error) {
|
|
||||||
console.error(result.error)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
const message: ThreadMessage = {
|
|
||||||
id: id,
|
|
||||||
object: 'message',
|
|
||||||
thread_id: activeThreadRef.current.id,
|
|
||||||
assistant_id: activeAssistantRef.current.assistant_id,
|
|
||||||
attachments: [],
|
|
||||||
role: ChatCompletionRole.Assistant,
|
|
||||||
content: [
|
|
||||||
{
|
|
||||||
type: ContentType.Text,
|
|
||||||
text: {
|
|
||||||
value:
|
|
||||||
`<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}</think>` +
|
|
||||||
(result.content[0]?.text ?? ''),
|
|
||||||
annotations: [],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
status: MessageStatus.Ready,
|
|
||||||
created_at: Date.now(),
|
|
||||||
completed_at: Date.now(),
|
|
||||||
}
|
|
||||||
requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '')
|
|
||||||
requestBuilder.pushMessage('Go for the next step')
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
isDone =
|
|
||||||
!response.choices[0]?.message.tool_calls ||
|
|
||||||
!response.choices[0]?.message.tool_calls.length
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Request for inference
|
// Request for inference
|
||||||
@ -376,6 +334,182 @@ export default function useSendChatMessage() {
|
|||||||
setEngineParamsUpdate(false)
|
setEngineParamsUpdate(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const processNonStreamingResponse = async (
|
||||||
|
response: OpenAI.Chat.Completions.ChatCompletion,
|
||||||
|
requestBuilder: MessageRequestBuilder,
|
||||||
|
message: ThreadMessage
|
||||||
|
): Promise<boolean> => {
|
||||||
|
// Handle tool calls in the response
|
||||||
|
const toolCalls: ChatCompletionMessageToolCall[] =
|
||||||
|
response.choices[0]?.message?.tool_calls ?? []
|
||||||
|
const content = response.choices[0].message?.content
|
||||||
|
message.content = [
|
||||||
|
{
|
||||||
|
type: ContentType.Text,
|
||||||
|
text: {
|
||||||
|
value: content ?? '',
|
||||||
|
annotations: [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
events.emit(MessageEvent.OnMessageUpdate, message)
|
||||||
|
await postMessageProcessing(
|
||||||
|
toolCalls ?? [],
|
||||||
|
requestBuilder,
|
||||||
|
message,
|
||||||
|
content ?? ''
|
||||||
|
)
|
||||||
|
return !toolCalls || !toolCalls.length
|
||||||
|
}
|
||||||
|
|
||||||
|
const processStreamingResponse = async (
|
||||||
|
response: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
|
||||||
|
requestBuilder: MessageRequestBuilder,
|
||||||
|
message: ThreadMessage
|
||||||
|
): Promise<boolean> => {
|
||||||
|
// Variables to track and accumulate streaming content
|
||||||
|
let currentToolCall: {
|
||||||
|
id: string
|
||||||
|
function: { name: string; arguments: string }
|
||||||
|
} | null = null
|
||||||
|
let accumulatedContent = ''
|
||||||
|
const toolCalls: ChatCompletionMessageToolCall[] = []
|
||||||
|
// Process the streaming chunks
|
||||||
|
for await (const chunk of response) {
|
||||||
|
// Handle tool calls in the chunk
|
||||||
|
if (chunk.choices[0]?.delta?.tool_calls) {
|
||||||
|
const deltaToolCalls = chunk.choices[0].delta.tool_calls
|
||||||
|
|
||||||
|
// Handle the beginning of a new tool call
|
||||||
|
if (
|
||||||
|
deltaToolCalls[0]?.index !== undefined &&
|
||||||
|
deltaToolCalls[0]?.function
|
||||||
|
) {
|
||||||
|
const index = deltaToolCalls[0].index
|
||||||
|
|
||||||
|
// Create new tool call if this is the first chunk for it
|
||||||
|
if (!toolCalls[index]) {
|
||||||
|
toolCalls[index] = {
|
||||||
|
id: deltaToolCalls[0]?.id || '',
|
||||||
|
function: {
|
||||||
|
name: deltaToolCalls[0]?.function?.name || '',
|
||||||
|
arguments: deltaToolCalls[0]?.function?.arguments || '',
|
||||||
|
},
|
||||||
|
type: 'function',
|
||||||
|
}
|
||||||
|
currentToolCall = toolCalls[index]
|
||||||
|
} else {
|
||||||
|
// Continuation of existing tool call
|
||||||
|
currentToolCall = toolCalls[index]
|
||||||
|
|
||||||
|
// Append to function name or arguments if they exist in this chunk
|
||||||
|
if (deltaToolCalls[0]?.function?.name) {
|
||||||
|
currentToolCall!.function.name += deltaToolCalls[0].function.name
|
||||||
|
}
|
||||||
|
|
||||||
|
if (deltaToolCalls[0]?.function?.arguments) {
|
||||||
|
currentToolCall!.function.arguments +=
|
||||||
|
deltaToolCalls[0].function.arguments
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle regular content in the chunk
|
||||||
|
if (chunk.choices[0]?.delta?.content) {
|
||||||
|
const content = chunk.choices[0].delta.content
|
||||||
|
accumulatedContent += content
|
||||||
|
|
||||||
|
message.content = [
|
||||||
|
{
|
||||||
|
type: ContentType.Text,
|
||||||
|
text: {
|
||||||
|
value: accumulatedContent,
|
||||||
|
annotations: [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
events.emit(MessageEvent.OnMessageUpdate, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await postMessageProcessing(
|
||||||
|
toolCalls ?? [],
|
||||||
|
requestBuilder,
|
||||||
|
message,
|
||||||
|
accumulatedContent ?? ''
|
||||||
|
)
|
||||||
|
return !toolCalls || !toolCalls.length
|
||||||
|
}
|
||||||
|
|
||||||
|
const postMessageProcessing = async (
|
||||||
|
toolCalls: ChatCompletionMessageToolCall[],
|
||||||
|
requestBuilder: MessageRequestBuilder,
|
||||||
|
message: ThreadMessage,
|
||||||
|
content: string
|
||||||
|
) => {
|
||||||
|
requestBuilder.pushAssistantMessage({
|
||||||
|
content,
|
||||||
|
role: 'assistant',
|
||||||
|
refusal: null,
|
||||||
|
tool_calls: toolCalls,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle completed tool calls
|
||||||
|
if (toolCalls.length > 0) {
|
||||||
|
for (const toolCall of toolCalls) {
|
||||||
|
const toolId = ulid()
|
||||||
|
const toolCallsMetadata =
|
||||||
|
message.metadata?.tool_calls &&
|
||||||
|
Array.isArray(message.metadata?.tool_calls)
|
||||||
|
? message.metadata?.tool_calls
|
||||||
|
: []
|
||||||
|
message.metadata = {
|
||||||
|
...(message.metadata ?? {}),
|
||||||
|
tool_calls: [
|
||||||
|
...toolCallsMetadata,
|
||||||
|
{
|
||||||
|
tool: {
|
||||||
|
...toolCall,
|
||||||
|
id: toolId,
|
||||||
|
},
|
||||||
|
response: undefined,
|
||||||
|
state: 'pending',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
events.emit(MessageEvent.OnMessageUpdate, message)
|
||||||
|
|
||||||
|
const result = await window.core.api.callTool({
|
||||||
|
toolName: toolCall.function.name,
|
||||||
|
arguments: JSON.parse(toolCall.function.arguments),
|
||||||
|
})
|
||||||
|
if (result.error) break
|
||||||
|
|
||||||
|
message.metadata = {
|
||||||
|
...(message.metadata ?? {}),
|
||||||
|
tool_calls: [
|
||||||
|
...toolCallsMetadata,
|
||||||
|
{
|
||||||
|
tool: {
|
||||||
|
...toolCall,
|
||||||
|
id: toolId,
|
||||||
|
},
|
||||||
|
response: result,
|
||||||
|
state: 'ready',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
requestBuilder.pushToolMessage(
|
||||||
|
result.content[0]?.text ?? '',
|
||||||
|
toolCall.id
|
||||||
|
)
|
||||||
|
events.emit(MessageEvent.OnMessageUpdate, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
sendChatMessage,
|
sendChatMessage,
|
||||||
resendChatMessage,
|
resendChatMessage,
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import { useCallback } from 'react'
|
import { useCallback } from 'react'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
|
AssistantTool,
|
||||||
ConversationalExtension,
|
ConversationalExtension,
|
||||||
ExtensionTypeEnum,
|
ExtensionTypeEnum,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
@ -51,7 +52,11 @@ export default function useUpdateModelParameters() {
|
|||||||
)
|
)
|
||||||
|
|
||||||
const updateModelParameter = useCallback(
|
const updateModelParameter = useCallback(
|
||||||
async (thread: Thread, settings: UpdateModelParameter) => {
|
async (
|
||||||
|
thread: Thread,
|
||||||
|
settings: UpdateModelParameter,
|
||||||
|
tools?: AssistantTool[]
|
||||||
|
) => {
|
||||||
if (!activeAssistant) return
|
if (!activeAssistant) return
|
||||||
|
|
||||||
const toUpdateSettings = processStopWords(settings.params ?? {})
|
const toUpdateSettings = processStopWords(settings.params ?? {})
|
||||||
@ -70,6 +75,7 @@ export default function useUpdateModelParameters() {
|
|||||||
const settingParams = extractModelLoadParams(updatedModelParams)
|
const settingParams = extractModelLoadParams(updatedModelParams)
|
||||||
const assistantInfo = {
|
const assistantInfo = {
|
||||||
...activeAssistant,
|
...activeAssistant,
|
||||||
|
tools: tools ?? activeAssistant.tools,
|
||||||
model: {
|
model: {
|
||||||
...activeAssistant?.model,
|
...activeAssistant?.model,
|
||||||
parameters: runtimeParams,
|
parameters: runtimeParams,
|
||||||
|
|||||||
@ -37,6 +37,7 @@
|
|||||||
"marked": "^9.1.2",
|
"marked": "^9.1.2",
|
||||||
"next": "14.2.3",
|
"next": "14.2.3",
|
||||||
"next-themes": "^0.2.1",
|
"next-themes": "^0.2.1",
|
||||||
|
"npx-scope-finder": "^1.3.0",
|
||||||
"openai": "^4.90.0",
|
"openai": "^4.90.0",
|
||||||
"postcss": "8.4.31",
|
"postcss": "8.4.31",
|
||||||
"postcss-url": "10.1.3",
|
"postcss-url": "10.1.3",
|
||||||
|
|||||||
@ -1,9 +1,8 @@
|
|||||||
import { useRef, useState } from 'react'
|
import { useState } from 'react'
|
||||||
|
|
||||||
import { Slider, Input, Tooltip } from '@janhq/joi'
|
import { Slider, Input } from '@janhq/joi'
|
||||||
|
|
||||||
import { atom, useAtom } from 'jotai'
|
import { atom, useAtom } from 'jotai'
|
||||||
import { InfoIcon } from 'lucide-react'
|
|
||||||
|
|
||||||
export const hubModelSizeMinAtom = atom(0)
|
export const hubModelSizeMinAtom = atom(0)
|
||||||
export const hubModelSizeMaxAtom = atom(100)
|
export const hubModelSizeMaxAtom = atom(100)
|
||||||
|
|||||||
99
web/screens/Settings/MCP/configuration.tsx
Normal file
99
web/screens/Settings/MCP/configuration.tsx
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import React, { useState, useEffect, useCallback } from 'react'
|
||||||
|
|
||||||
|
import { Button, TextArea } from '@janhq/joi'
|
||||||
|
import { useAtomValue } from 'jotai'
|
||||||
|
|
||||||
|
import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||||
|
|
||||||
|
const MCPConfiguration = () => {
|
||||||
|
const janDataFolderPath = useAtomValue(janDataFolderPathAtom)
|
||||||
|
const [configContent, setConfigContent] = useState('')
|
||||||
|
const [isSaving, setIsSaving] = useState(false)
|
||||||
|
const [error, setError] = useState('')
|
||||||
|
const [success, setSuccess] = useState('')
|
||||||
|
|
||||||
|
const readConfigFile = useCallback(async () => {
|
||||||
|
try {
|
||||||
|
// Read the file
|
||||||
|
const content = await window.core?.api.getMcpConfigs()
|
||||||
|
setConfigContent(content)
|
||||||
|
|
||||||
|
setError('')
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Error reading config file:', err)
|
||||||
|
setError('Failed to read config file')
|
||||||
|
}
|
||||||
|
}, [janDataFolderPath])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (janDataFolderPath) {
|
||||||
|
readConfigFile()
|
||||||
|
}
|
||||||
|
}, [janDataFolderPath, readConfigFile])
|
||||||
|
|
||||||
|
const saveConfigFile = useCallback(async () => {
|
||||||
|
try {
|
||||||
|
setIsSaving(true)
|
||||||
|
setSuccess('')
|
||||||
|
setError('')
|
||||||
|
|
||||||
|
// Validate JSON
|
||||||
|
try {
|
||||||
|
JSON.parse(configContent)
|
||||||
|
} catch (err) {
|
||||||
|
setError('Invalid JSON format')
|
||||||
|
setIsSaving(false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
await window.core?.api?.saveMcpConfigs({ configs: configContent })
|
||||||
|
await window.core?.api?.restartMcpServers()
|
||||||
|
|
||||||
|
setSuccess('Config saved successfully')
|
||||||
|
setIsSaving(false)
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Error saving config file:', err)
|
||||||
|
setError('Failed to save config file')
|
||||||
|
setIsSaving(false)
|
||||||
|
}
|
||||||
|
}, [janDataFolderPath, configContent])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{error && (
|
||||||
|
<div className="mb-4 rounded bg-[hsla(var(--destructive-bg))] px-4 py-3 text-[hsla(var(--destructive-fg))]">
|
||||||
|
{error}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{success && (
|
||||||
|
<div className="mb-4 rounded bg-[hsla(var(--success-bg))] px-4 py-3 text-[hsla(var(--success-fg))]">
|
||||||
|
{success}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="mb-4 mt-2">
|
||||||
|
<label className="mb-2 block text-sm font-medium">
|
||||||
|
Configuration File (JSON)
|
||||||
|
</label>
|
||||||
|
<TextArea
|
||||||
|
// className="h-80 w-full rounded border border-gray-800 p-2 font-mono text-sm"
|
||||||
|
className="font-mono text-xs"
|
||||||
|
value={configContent}
|
||||||
|
rows={20}
|
||||||
|
onChange={(e) => {
|
||||||
|
setConfigContent(e.target.value)
|
||||||
|
setSuccess('')
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex justify-end">
|
||||||
|
<Button onClick={saveConfigFile} disabled={isSaving}>
|
||||||
|
{isSaving ? 'Saving...' : 'Save Config'}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default MCPConfiguration
|
||||||
41
web/screens/Settings/MCP/index.tsx
Normal file
41
web/screens/Settings/MCP/index.tsx
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import React, { useState, useEffect, useCallback } from 'react'
|
||||||
|
|
||||||
|
import { ScrollArea, Tabs } from '@janhq/joi'
|
||||||
|
|
||||||
|
import { useAtomValue } from 'jotai'
|
||||||
|
|
||||||
|
import MCPConfiguration from './configuration'
|
||||||
|
import MCPSearch from './search'
|
||||||
|
|
||||||
|
import { showScrollBarAtom } from '@/helpers/atoms/Setting.atom'
|
||||||
|
|
||||||
|
const MCP = () => {
|
||||||
|
const [tabValue, setTabValue] = useState('search_mcp')
|
||||||
|
const showScrollBar = useAtomValue(showScrollBarAtom)
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ScrollArea
|
||||||
|
type={showScrollBar ? 'always' : 'scroll'}
|
||||||
|
className="h-full w-full"
|
||||||
|
>
|
||||||
|
<div className="block w-full px-4 pb-4">
|
||||||
|
<div className="sticky top-0 bg-[hsla(var(--app-bg))] py-4">
|
||||||
|
<h2 className="mb-4 text-lg font-bold">MCP servers</h2>
|
||||||
|
<Tabs
|
||||||
|
options={[
|
||||||
|
{ name: 'Search MCP', value: 'search_mcp' },
|
||||||
|
{ name: 'Configuration', value: 'config' },
|
||||||
|
]}
|
||||||
|
tabStyle="segmented"
|
||||||
|
value={tabValue}
|
||||||
|
onValueChange={(value) => setTabValue(value as string)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{tabValue === 'search_mcp' && <MCPSearch />}
|
||||||
|
{tabValue === 'config' && <MCPConfiguration />}
|
||||||
|
</div>
|
||||||
|
</ScrollArea>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default MCP
|
||||||
209
web/screens/Settings/MCP/search.tsx
Normal file
209
web/screens/Settings/MCP/search.tsx
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
import React, { useState, useEffect, useCallback } from 'react'
|
||||||
|
|
||||||
|
import { Button, Input } from '@janhq/joi'
|
||||||
|
import { PlusIcon } from 'lucide-react'
|
||||||
|
import { npxFinder, NPMPackage } from 'npx-scope-finder'
|
||||||
|
|
||||||
|
import { toaster } from '@/containers/Toast'
|
||||||
|
|
||||||
|
interface MCPConfig {
|
||||||
|
mcpServers: {
|
||||||
|
[key: string]: {
|
||||||
|
command: string
|
||||||
|
args: string[]
|
||||||
|
env: Record<string, string>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const MCPSearch = () => {
|
||||||
|
const [showToast, setShowToast] = useState(false)
|
||||||
|
const [toastMessage, setToastMessage] = useState('')
|
||||||
|
const [toastType, setToastType] = useState<'success' | 'error'>('success')
|
||||||
|
const [orgName, setOrgName] = useState('@modelcontextprotocol')
|
||||||
|
const [packages, setPackages] = useState<NPMPackage[]>([])
|
||||||
|
const [loading, setLoading] = useState(false)
|
||||||
|
const [error, setError] = useState('')
|
||||||
|
|
||||||
|
const searchOrganizationPackages = useCallback(async () => {
|
||||||
|
if (!orgName) return
|
||||||
|
|
||||||
|
try {
|
||||||
|
setLoading(true)
|
||||||
|
setError('')
|
||||||
|
|
||||||
|
// Remove @ symbol if present at the beginning
|
||||||
|
// const scopeName = orgName.startsWith('@') ? orgName.substring(1) : orgName
|
||||||
|
|
||||||
|
// Use npxFinder to search for packages from the specified organization
|
||||||
|
const result = await npxFinder(orgName)
|
||||||
|
|
||||||
|
setPackages(result || [])
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Error searching for packages:', err)
|
||||||
|
setError('Failed to search for packages. Please try again.')
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}, [orgName])
|
||||||
|
|
||||||
|
// Search for packages when the component mounts
|
||||||
|
useEffect(() => {
|
||||||
|
searchOrganizationPackages()
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<h2 className="mt-2 text-lg font-bold">NPX Package List</h2>
|
||||||
|
<p className="text-[hsla(var(--text-secondary))]">
|
||||||
|
Search and add npm packages as MCP servers
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div className="mt-6">
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<input
|
||||||
|
id="orgName"
|
||||||
|
type="text"
|
||||||
|
value={orgName}
|
||||||
|
onChange={(e) => setOrgName(e.target.value)}
|
||||||
|
onKeyDown={(e) => {
|
||||||
|
if (e.key === 'Enter' && orgName) {
|
||||||
|
e.preventDefault()
|
||||||
|
searchOrganizationPackages()
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
className="input w-full"
|
||||||
|
placeholder="Enter npm scope name (e.g. @janhq)"
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
onClick={searchOrganizationPackages}
|
||||||
|
disabled={loading || !orgName}
|
||||||
|
>
|
||||||
|
{loading ? 'Searching...' : 'Search'}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
{error && <p className="mt-2 text-sm text-red-500">{error}</p>}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{packages.length > 0 ? (
|
||||||
|
<div className="mt-6">
|
||||||
|
{packages.map((pkg, index) => (
|
||||||
|
<div
|
||||||
|
key={index}
|
||||||
|
className="my-2 rounded-xl border border-[hsla(var(--app-border))]"
|
||||||
|
>
|
||||||
|
<div className="flex justify-between border-b border-[hsla(var(--app-border))] px-4 py-3">
|
||||||
|
<span>{pkg.name?.split('/')[1]}</span>
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="font-mono text-sm text-[hsla(var(--text-secondary))]">
|
||||||
|
{pkg.version}
|
||||||
|
</span>
|
||||||
|
<Button theme="icon" onClick={() => handleAddToConfig(pkg)}>
|
||||||
|
<PlusIcon />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="px-4 py-3">
|
||||||
|
<p>{pkg.description || 'No description'}</p>
|
||||||
|
<p className="my-2 whitespace-nowrap text-[hsla(var(--text-secondary))]">
|
||||||
|
Usage: npx {pkg.name}
|
||||||
|
</p>
|
||||||
|
<a
|
||||||
|
target="_blank"
|
||||||
|
href={`https://www.npmjs.com/package/${pkg.name}`}
|
||||||
|
>{`https://www.npmjs.com/package/${pkg.name}`}</a>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
!loading && (
|
||||||
|
<div className="mt-4">
|
||||||
|
<p>
|
||||||
|
No packages found. Try searching for a different organization.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
)}
|
||||||
|
|
||||||
|
{showToast && (
|
||||||
|
<div
|
||||||
|
className={`fixed bottom-4 right-4 z-50 rounded-md p-4 shadow-lg ${
|
||||||
|
toastType === 'success'
|
||||||
|
? 'bg-green-100 text-green-800'
|
||||||
|
: 'bg-red-100 text-red-800'
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<span>{toastMessage}</span>
|
||||||
|
<button
|
||||||
|
onClick={() => setShowToast(false)}
|
||||||
|
className="ml-4 text-gray-500 hover:text-gray-700"
|
||||||
|
>
|
||||||
|
×
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
|
||||||
|
// Function to add a package to the MCP configuration
|
||||||
|
async function handleAddToConfig(pkg: NPMPackage) {
|
||||||
|
try {
|
||||||
|
// Get current configuration
|
||||||
|
const currentConfig = await window.core?.api.getMcpConfigs()
|
||||||
|
|
||||||
|
// Parse the configuration
|
||||||
|
let config: MCPConfig
|
||||||
|
try {
|
||||||
|
config = JSON.parse(currentConfig || '{"mcpServers": {}}')
|
||||||
|
} catch (err) {
|
||||||
|
// If parsing fails, start with an empty configuration
|
||||||
|
config = { mcpServers: {} }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a unique server name based on the package name
|
||||||
|
const serverName = pkg.name?.split('/')[1] || 'unknown'
|
||||||
|
|
||||||
|
// Check if this server already exists
|
||||||
|
if (config.mcpServers[serverName]) {
|
||||||
|
toaster({
|
||||||
|
title: `Add ${serverName} success`,
|
||||||
|
description: `Server ${serverName} already exists in configuration`,
|
||||||
|
type: 'error',
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the new server configuration
|
||||||
|
config.mcpServers[serverName] = {
|
||||||
|
command: 'npx',
|
||||||
|
args: ['-y', pkg.name || ''],
|
||||||
|
env: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the updated configuration
|
||||||
|
await window.core?.api?.saveMcpConfigs({
|
||||||
|
configs: JSON.stringify(config, null, 2),
|
||||||
|
})
|
||||||
|
await window.core?.api?.restartMcpServers()
|
||||||
|
|
||||||
|
toaster({
|
||||||
|
title: `Add ${serverName} success`,
|
||||||
|
description: `Added ${serverName} to MCP configuration`,
|
||||||
|
type: 'success',
|
||||||
|
})
|
||||||
|
} catch (err) {
|
||||||
|
toaster({
|
||||||
|
title: `Add ${pkg.name?.split('/')[1] || 'unknown'} failed`,
|
||||||
|
description: `Failed to add package to configuration`,
|
||||||
|
type: 'error',
|
||||||
|
})
|
||||||
|
console.error('Error adding package to configuration:', err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default MCPSearch
|
||||||
@ -15,6 +15,7 @@ import RemoteEngineSettings from '@/screens/Settings/Engines/RemoteEngineSetting
|
|||||||
import ExtensionSetting from '@/screens/Settings/ExtensionSetting'
|
import ExtensionSetting from '@/screens/Settings/ExtensionSetting'
|
||||||
import Hardware from '@/screens/Settings/Hardware'
|
import Hardware from '@/screens/Settings/Hardware'
|
||||||
import Hotkeys from '@/screens/Settings/Hotkeys'
|
import Hotkeys from '@/screens/Settings/Hotkeys'
|
||||||
|
import MCP from '@/screens/Settings/MCP'
|
||||||
import MyModels from '@/screens/Settings/MyModels'
|
import MyModels from '@/screens/Settings/MyModels'
|
||||||
import Privacy from '@/screens/Settings/Privacy'
|
import Privacy from '@/screens/Settings/Privacy'
|
||||||
|
|
||||||
@ -31,6 +32,9 @@ const SettingDetail = () => {
|
|||||||
case 'Engines':
|
case 'Engines':
|
||||||
return <Engines />
|
return <Engines />
|
||||||
|
|
||||||
|
case 'MCP Servers':
|
||||||
|
return <MCP />
|
||||||
|
|
||||||
case 'Extensions':
|
case 'Extensions':
|
||||||
return <ExtensionCatalog />
|
return <ExtensionCatalog />
|
||||||
|
|
||||||
|
|||||||
@ -19,6 +19,7 @@ export const SettingScreenList = [
|
|||||||
'Privacy',
|
'Privacy',
|
||||||
'Advanced Settings',
|
'Advanced Settings',
|
||||||
'Engines',
|
'Engines',
|
||||||
|
'MCP Servers',
|
||||||
'Extensions',
|
'Extensions',
|
||||||
] as const
|
] as const
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,15 @@ import { useEffect, useRef, useState } from 'react'
|
|||||||
|
|
||||||
import { InferenceEngine } from '@janhq/core'
|
import { InferenceEngine } from '@janhq/core'
|
||||||
|
|
||||||
import { TextArea, Button, Tooltip, useClickOutside, Badge } from '@janhq/joi'
|
import {
|
||||||
|
TextArea,
|
||||||
|
Button,
|
||||||
|
Tooltip,
|
||||||
|
useClickOutside,
|
||||||
|
Badge,
|
||||||
|
Modal,
|
||||||
|
ModalClose,
|
||||||
|
} from '@janhq/joi'
|
||||||
import { useAtom, useAtomValue } from 'jotai'
|
import { useAtom, useAtomValue } from 'jotai'
|
||||||
import {
|
import {
|
||||||
FileTextIcon,
|
FileTextIcon,
|
||||||
@ -13,6 +21,7 @@ import {
|
|||||||
SettingsIcon,
|
SettingsIcon,
|
||||||
ChevronUpIcon,
|
ChevronUpIcon,
|
||||||
Settings2Icon,
|
Settings2Icon,
|
||||||
|
WrenchIcon,
|
||||||
} from 'lucide-react'
|
} from 'lucide-react'
|
||||||
|
|
||||||
import { twMerge } from 'tailwind-merge'
|
import { twMerge } from 'tailwind-merge'
|
||||||
@ -45,6 +54,7 @@ import {
|
|||||||
isBlockingSendAtom,
|
isBlockingSendAtom,
|
||||||
} from '@/helpers/atoms/Thread.atom'
|
} from '@/helpers/atoms/Thread.atom'
|
||||||
import { activeTabThreadRightPanelAtom } from '@/helpers/atoms/ThreadRightPanel.atom'
|
import { activeTabThreadRightPanelAtom } from '@/helpers/atoms/ThreadRightPanel.atom'
|
||||||
|
import { ModelTool } from '@/types/model'
|
||||||
|
|
||||||
const ChatInput = () => {
|
const ChatInput = () => {
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
@ -69,6 +79,8 @@ const ChatInput = () => {
|
|||||||
const isBlockingSend = useAtomValue(isBlockingSendAtom)
|
const isBlockingSend = useAtomValue(isBlockingSendAtom)
|
||||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
const { stopInference } = useActiveModel()
|
const { stopInference } = useActiveModel()
|
||||||
|
const [tools, setTools] = useState<any>([])
|
||||||
|
const [showToolsModal, setShowToolsModal] = useState(false)
|
||||||
|
|
||||||
const upload = uploader()
|
const upload = uploader()
|
||||||
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
|
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
|
||||||
@ -92,6 +104,12 @@ const ChatInput = () => {
|
|||||||
}
|
}
|
||||||
}, [activeSettingInputBox, selectedModel, setActiveSettingInputBox])
|
}, [activeSettingInputBox, selectedModel, setActiveSettingInputBox])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
window.core?.api?.getTools().then((data: ModelTool[]) => {
|
||||||
|
setTools(data)
|
||||||
|
})
|
||||||
|
}, [])
|
||||||
|
|
||||||
const onStopInferenceClick = async () => {
|
const onStopInferenceClick = async () => {
|
||||||
stopInference()
|
stopInference()
|
||||||
}
|
}
|
||||||
@ -385,6 +403,62 @@ const ChatInput = () => {
|
|||||||
className="flex-shrink-0 cursor-pointer text-[hsla(var(--text-secondary))]"
|
className="flex-shrink-0 cursor-pointer text-[hsla(var(--text-secondary))]"
|
||||||
/>
|
/>
|
||||||
</Badge>
|
</Badge>
|
||||||
|
{tools && tools.length > 0 && (
|
||||||
|
<>
|
||||||
|
<Badge
|
||||||
|
theme="secondary"
|
||||||
|
className={twMerge(
|
||||||
|
'flex cursor-pointer items-center gap-x-1'
|
||||||
|
)}
|
||||||
|
variant={'outline'}
|
||||||
|
onClick={() => setShowToolsModal(true)}
|
||||||
|
>
|
||||||
|
<WrenchIcon
|
||||||
|
size={16}
|
||||||
|
className="flex-shrink-0 cursor-pointer text-[hsla(var(--text-secondary))]"
|
||||||
|
/>
|
||||||
|
<span className="text-xs">{tools.length}</span>
|
||||||
|
</Badge>
|
||||||
|
|
||||||
|
<Modal
|
||||||
|
open={showToolsModal}
|
||||||
|
onOpenChange={setShowToolsModal}
|
||||||
|
title="Available MCP Tools"
|
||||||
|
content={
|
||||||
|
<div className="overflow-y-auto">
|
||||||
|
<div className="mb-2 py-2 text-sm text-[hsla(var(--text-secondary))]">
|
||||||
|
Jan can use tools provided by specialized servers using
|
||||||
|
Model Context Protocol.{' '}
|
||||||
|
<a
|
||||||
|
href="https://modelcontextprotocol.io/introduction"
|
||||||
|
target="_blank"
|
||||||
|
className="text-[hsla(var(--app-link))]"
|
||||||
|
>
|
||||||
|
Learn more about MCP
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
{tools.map((tool: any) => (
|
||||||
|
<div
|
||||||
|
key={tool.name}
|
||||||
|
className="flex items-center gap-x-3 px-4 py-3 hover:bg-[hsla(var(--dropdown-menu-hover-bg))]"
|
||||||
|
>
|
||||||
|
<WrenchIcon
|
||||||
|
size={16}
|
||||||
|
className="flex-shrink-0 text-[hsla(var(--text-secondary))]"
|
||||||
|
/>
|
||||||
|
<div>
|
||||||
|
<div className="font-medium">{tool.name}</div>
|
||||||
|
<div className="text-sm text-[hsla(var(--text-secondary))]">
|
||||||
|
{tool.description}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
{selectedModel && (
|
{selectedModel && (
|
||||||
<Button
|
<Button
|
||||||
|
|||||||
@ -0,0 +1,57 @@
|
|||||||
|
import React from 'react'
|
||||||
|
|
||||||
|
import { atom, useAtom } from 'jotai'
|
||||||
|
import { ChevronDown, ChevronUp, Loader } from 'lucide-react'
|
||||||
|
|
||||||
|
import { MarkdownTextMessage } from './MarkdownTextMessage'
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
result: string
|
||||||
|
name: string
|
||||||
|
id: number
|
||||||
|
loading: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
const toolCallBlockStateAtom = atom<{ [id: number]: boolean }>({})
|
||||||
|
|
||||||
|
const ToolCallBlock = ({ id, name, result, loading }: Props) => {
|
||||||
|
const [collapseState, setCollapseState] = useAtom(toolCallBlockStateAtom)
|
||||||
|
|
||||||
|
const isExpanded = collapseState[id] ?? false
|
||||||
|
const handleClick = () => {
|
||||||
|
setCollapseState((prev) => ({ ...prev, [id]: !isExpanded }))
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<div className="mx-auto w-full">
|
||||||
|
<div className="mb-4 rounded-lg border border-dashed border-[hsla(var(--app-border))] p-2">
|
||||||
|
<div
|
||||||
|
className="flex cursor-pointer items-center gap-3"
|
||||||
|
onClick={handleClick}
|
||||||
|
>
|
||||||
|
{loading && (
|
||||||
|
<Loader className="h-4 w-4 animate-spin text-[hsla(var(--primary-bg))]" />
|
||||||
|
)}
|
||||||
|
<button className="flex items-center gap-2 focus:outline-none">
|
||||||
|
{isExpanded ? (
|
||||||
|
<ChevronUp className="h-4 w-4" />
|
||||||
|
) : (
|
||||||
|
<ChevronDown className="h-4 w-4" />
|
||||||
|
)}
|
||||||
|
<span className="font-medium">
|
||||||
|
{' '}
|
||||||
|
View result from <span className="font-bold">{name}</span>
|
||||||
|
</span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{isExpanded && (
|
||||||
|
<div className="mt-2 overflow-x-hidden pl-6 text-[hsla(var(--text-secondary))]">
|
||||||
|
<span>{result ?? ''} </span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default ToolCallBlock
|
||||||
@ -18,6 +18,8 @@ import ImageMessage from './ImageMessage'
|
|||||||
import { MarkdownTextMessage } from './MarkdownTextMessage'
|
import { MarkdownTextMessage } from './MarkdownTextMessage'
|
||||||
import ThinkingBlock from './ThinkingBlock'
|
import ThinkingBlock from './ThinkingBlock'
|
||||||
|
|
||||||
|
import ToolCallBlock from './ToolCallBlock'
|
||||||
|
|
||||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import {
|
import {
|
||||||
editMessageAtom,
|
editMessageAtom,
|
||||||
@ -65,57 +67,64 @@ const MessageContainer: React.FC<
|
|||||||
[props.content]
|
[props.content]
|
||||||
)
|
)
|
||||||
|
|
||||||
const attachedFile = useMemo(() => 'attachments' in props, [props])
|
const attachedFile = useMemo(
|
||||||
|
() =>
|
||||||
|
'attachments' in props &&
|
||||||
|
!!props.attachments?.length &&
|
||||||
|
props.attachments?.length > 0,
|
||||||
|
[props]
|
||||||
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={twMerge(
|
className={twMerge(
|
||||||
'group relative mx-auto px-4 py-2',
|
'group relative mx-auto px-4',
|
||||||
|
!(props.metadata && 'parent_id' in props.metadata) && 'py-2',
|
||||||
chatWidth === 'compact' && 'max-w-[700px]',
|
chatWidth === 'compact' && 'max-w-[700px]',
|
||||||
isUser && 'pb-4 pt-0'
|
!isUser && 'pb-4 pt-0'
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<div
|
{!(props.metadata && 'parent_id' in props.metadata) && (
|
||||||
className={twMerge(
|
|
||||||
'mb-2 flex items-center justify-start',
|
|
||||||
!isUser && 'mt-2 gap-x-2'
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
{!isUser && !isSystem && <LogoMark width={28} />}
|
|
||||||
|
|
||||||
<div
|
<div
|
||||||
className={twMerge(
|
className={twMerge(
|
||||||
'font-extrabold capitalize',
|
'mb-2 flex items-center justify-start',
|
||||||
isUser && 'text-gray-500'
|
!isUser && 'mt-2 gap-x-2'
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{!isUser && (
|
{!isUser && !isSystem && <LogoMark width={28} />}
|
||||||
<>
|
<div
|
||||||
{props.metadata && 'model' in props.metadata
|
className={twMerge(
|
||||||
? (props.metadata?.model as string)
|
'font-extrabold capitalize',
|
||||||
: props.isCurrentMessage
|
isUser && 'text-gray-500'
|
||||||
? selectedModel?.name
|
)}
|
||||||
: (activeAssistant?.assistant_name ?? props.role)}
|
>
|
||||||
</>
|
{!isUser && (
|
||||||
)}
|
<>
|
||||||
|
{props.metadata && 'model' in props.metadata
|
||||||
|
? (props.metadata?.model as string)
|
||||||
|
: props.isCurrentMessage
|
||||||
|
? selectedModel?.name
|
||||||
|
: (activeAssistant?.assistant_name ?? props.role)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p className="text-xs font-medium text-gray-400">
|
||||||
|
{props.created_at &&
|
||||||
|
displayDate(props.created_at ?? Date.now() / 1000)}
|
||||||
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
<p className="text-xs font-medium text-gray-400">
|
<div className="flex w-full flex-col">
|
||||||
{props.created_at &&
|
|
||||||
displayDate(props.created_at ?? Date.now() / 1000)}
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="flex w-full flex-col ">
|
|
||||||
<div
|
<div
|
||||||
className={twMerge(
|
className={twMerge(
|
||||||
'absolute right-0 order-1 flex cursor-pointer items-center justify-start gap-x-2 transition-all',
|
'absolute right-0 order-1 flex cursor-pointer items-center justify-start gap-x-2 transition-all',
|
||||||
isUser
|
twMerge(
|
||||||
? twMerge(
|
'hidden group-hover:absolute group-hover:-bottom-4 group-hover:right-4 group-hover:z-50 group-hover:flex',
|
||||||
'hidden group-hover:absolute group-hover:right-4 group-hover:top-4 group-hover:z-50 group-hover:flex',
|
image && 'group-hover:-top-2'
|
||||||
image && 'group-hover:-top-2'
|
),
|
||||||
)
|
|
||||||
: 'relative left-0 order-2 flex w-full justify-between opacity-0 group-hover:opacity-100',
|
|
||||||
props.isCurrentMessage && 'opacity-100'
|
props.isCurrentMessage && 'opacity-100'
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
@ -179,6 +188,22 @@ const MessageContainer: React.FC<
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
{props.metadata &&
|
||||||
|
'tool_calls' in props.metadata &&
|
||||||
|
Array.isArray(props.metadata.tool_calls) &&
|
||||||
|
props.metadata.tool_calls.length && (
|
||||||
|
<>
|
||||||
|
{props.metadata.tool_calls.map((toolCall) => (
|
||||||
|
<ToolCallBlock
|
||||||
|
id={toolCall.tool?.id}
|
||||||
|
name={toolCall.tool?.function?.name ?? ''}
|
||||||
|
key={toolCall.tool?.id}
|
||||||
|
result={JSON.stringify(toolCall.response)}
|
||||||
|
loading={toolCall.state === 'pending'}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</>
|
</>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -8,6 +8,20 @@ export const Routes = [
|
|||||||
'installExtensions',
|
'installExtensions',
|
||||||
'getTools',
|
'getTools',
|
||||||
'callTool',
|
'callTool',
|
||||||
|
'listThreads',
|
||||||
|
'createThread',
|
||||||
|
'modifyThread',
|
||||||
|
'deleteThread',
|
||||||
|
'listMessages',
|
||||||
|
'createMessage',
|
||||||
|
'modifyMessage',
|
||||||
|
'deleteMessage',
|
||||||
|
'getThreadAssistant',
|
||||||
|
'createThreadAssistant',
|
||||||
|
'modifyThreadAssistant',
|
||||||
|
'saveMcpConfigs',
|
||||||
|
'getMcpConfigs',
|
||||||
|
'restartMcpServers',
|
||||||
].map((r) => ({
|
].map((r) => ({
|
||||||
path: `app`,
|
path: `app`,
|
||||||
route: r,
|
route: r,
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import {
|
|||||||
Thread,
|
Thread,
|
||||||
ThreadMessage,
|
ThreadMessage,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
|
import { ChatCompletionMessage as OAIChatCompletionMessage } from 'openai/resources/chat'
|
||||||
import { ulid } from 'ulidx'
|
import { ulid } from 'ulidx'
|
||||||
|
|
||||||
import { Stack } from '@/utils/Stack'
|
import { Stack } from '@/utils/Stack'
|
||||||
@ -45,12 +46,26 @@ export class MessageRequestBuilder {
|
|||||||
this.tools = tools
|
this.tools = tools
|
||||||
}
|
}
|
||||||
|
|
||||||
pushAssistantMessage(message: string) {
|
pushAssistantMessage(message: OAIChatCompletionMessage) {
|
||||||
|
const { content, refusal, ...rest } = message
|
||||||
|
const normalizedMessage = {
|
||||||
|
...rest,
|
||||||
|
...(content ? { content } : {}),
|
||||||
|
...(refusal ? { refusal } : {}),
|
||||||
|
}
|
||||||
|
this.messages = [
|
||||||
|
...this.messages,
|
||||||
|
normalizedMessage as ChatCompletionMessage,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
pushToolMessage(message: string, toolCallId: string) {
|
||||||
this.messages = [
|
this.messages = [
|
||||||
...this.messages,
|
...this.messages,
|
||||||
{
|
{
|
||||||
role: ChatCompletionRole.Assistant,
|
role: ChatCompletionRole.Tool,
|
||||||
content: message,
|
content: message,
|
||||||
|
tool_call_id: toolCallId,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -140,40 +155,13 @@ export class MessageRequestBuilder {
|
|||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
normalizeMessages = (
|
|
||||||
messages: ChatCompletionMessage[]
|
|
||||||
): ChatCompletionMessage[] => {
|
|
||||||
const stack = new Stack<ChatCompletionMessage>()
|
|
||||||
for (const message of messages) {
|
|
||||||
if (stack.isEmpty()) {
|
|
||||||
stack.push(message)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
const topMessage = stack.peek()
|
|
||||||
|
|
||||||
if (message.role === topMessage.role) {
|
|
||||||
// add an empty message
|
|
||||||
stack.push({
|
|
||||||
role:
|
|
||||||
topMessage.role === ChatCompletionRole.User
|
|
||||||
? ChatCompletionRole.Assistant
|
|
||||||
: ChatCompletionRole.User,
|
|
||||||
content: '.', // some model requires not empty message
|
|
||||||
})
|
|
||||||
}
|
|
||||||
stack.push(message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return stack.reverseOutput()
|
|
||||||
}
|
|
||||||
|
|
||||||
build(): MessageRequest {
|
build(): MessageRequest {
|
||||||
return {
|
return {
|
||||||
id: this.msgId,
|
id: this.msgId,
|
||||||
type: this.type,
|
type: this.type,
|
||||||
attachments: [],
|
attachments: [],
|
||||||
threadId: this.thread.id,
|
threadId: this.thread.id,
|
||||||
messages: this.normalizeMessages(this.messages),
|
messages: this.messages,
|
||||||
model: this.model,
|
model: this.model,
|
||||||
thread: this.thread,
|
thread: this.thread,
|
||||||
tools: this.tools,
|
tools: this.tools,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user