Merge branch 'dev' into current-date-instruction

This commit is contained in:
Kamal Fariz Mahyuddin 2025-08-19 14:40:08 -07:00 committed by GitHub
commit df27def9cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
63 changed files with 3746 additions and 776 deletions

View File

@ -6,6 +6,7 @@ import { EngineManager } from './EngineManager'
export interface chatCompletionRequestMessage {
role: 'system' | 'user' | 'assistant' | 'tool'
content: string | null | Content[] // Content can be a string OR an array of content parts
reasoning?: string | null // Some models return reasoning in completed responses
name?: string
tool_calls?: any[] // Simplified tool_call_id?: string
}
@ -270,4 +271,10 @@ export abstract class AIEngine extends BaseExtension {
* Optional method to get the underlying chat client
*/
getChatClient?(sessionId: string): any
/**
* Check if a tool is supported by the model
* @param modelId
*/
abstract isToolSupported(modelId: string): Promise<boolean>
}

View File

@ -58,6 +58,7 @@ export enum AppEvent {
onAppUpdateDownloadUpdate = 'onAppUpdateDownloadUpdate',
onAppUpdateDownloadError = 'onAppUpdateDownloadError',
onAppUpdateDownloadSuccess = 'onAppUpdateDownloadSuccess',
onModelImported = 'onModelImported',
onUserSubmitQuickAsk = 'onUserSubmitQuickAsk',
onSelectedText = 'onSelectedText',

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 944 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 316 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 382 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 176 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

View File

@ -0,0 +1,180 @@
---
title: "Replicating Deep Research in Jan"
description: "A simple guide to replicating Deep Research results for free, with Jan."
tags: AI, local models, Jan, GGUF, Deep Research, local AI
categories: guides
date: 2025-08-04
ogImage: _assets/research-result-local.png
twitter:
card: summary_large_image
site: "@jandotai"
title: "Replicating Deep Research with Jan"
description: "Learn how to replicate Deep Research results with Jan."
image: _assets/research-result-local.jpg
---
import { Callout } from 'nextra/components'
import CTABlog from '@/components/Blog/CTA'
# Replicating Deep Research in Jan
Deep Research like that of OpenAI, Gemini, and Qwen, is not at feature parity in Jan yet, so this post
highlights our initial steps at making a hybrid, local and cloud-based deep research system that competes
with the best implementations to date.
## What is Deep Research?
What exactly is deep research and how does it work? Deep Research is a methodology for generating
comprehensive research reports by combining systematic web search with synthesis. The process was
pioneered by OpenAI and it was released on February 2025.
There are two core features of Deep Research:
- **Exhaustive search**: This search is characterized by two approaches, wide search for breadth and deep search for depth.
- **Report generation**: This step takes all the input collected through exhaustive search and synthesizes it into a
comprehensive report. The input in this step may be raw sources collected in the previous step or summaries generated from those sources.
## Unpacking Deep Research
If you have used deep research (regardless of the provider) before for a comprehensive report generation, you may have
found its output mind-blowing. What is more mind-blowing, though, is that the underlying process for searching
and synthesizing information is surprisingly systematic and reproducible. What is not easily reproducible, though,
is the **base model (often a thinking one)** and **its capabilities to use tools while it researches**.
Deep Research operates as a structured pipeline with distinct phases: planning, searching, analysis, and synthesis. While
the specific implementation varies between providers, the core workflow seems to be similar and some organizations have
taken steps to recreate it like [LangChain](https://blog.langchain.com/open-deep-research/) and
[Hugging Face](https://huggingface.co/blog/open-deep-research). For example, a straightforward pipeline might look like
the following one:
![Deep Research Flow Excalidraw](./_assets/revised-deepresearch-flow.png)
The components of this pipeline highlight a structured approach to query processing that routes queries through thinking/non-thinking models, breaks complex tasks into phases, executes parallel searches, and synthesizes results hierarchically to produce comprehensive outputs.
OpenAIs [Deep Research API cookbook](https://cookbook.openai.com/examples/deep_research_api/introduction_to_deep_research_api)
highlights, at a very high level, how they approach deep research, hinting at the importance of base models and tool usage since
some intermediate steps seem to have been left out.
![OpenAI's Deep Research UX Flow](./_assets/openai-deep-research-flow.png)
OpenAI's Deep Research functionality may be considered the best one by many but other platforms are not far behind. Here is a
brief survey of how other players approach deep research:
| Platform | Key Feature | Sources Used | Duration (mins) | Export Options | Deep Research Usage |
| --- | --- | --- | --- | --- | --- |
| OpenAI | Clarifying questions | 1030 | 1015 | PDF, Docx, Plain Text | Paid |
| Grok's DeeperSearch | Ability to access all of Twitter | 70100 | 510 | Ability to specify format (PDF / Markdown) | Free |
| Claude | Breadth + depth search | 100+ | 510 | PDF, Markdown, Artifact | Paid |
| Gemini | Editable planning | 50+ | 1020 | Google Docs export | Free |
| Perplexity | Ability to specify sources | 50100 | 35 | PDF, Markdown, Docx, Perplexity Page | Paid and Limited Free |
| Kimi | Interactive synthesis | 50100 | 3060+ | PDF, Interactive website | Free |
In our testing, we used the following prompt to assess the quality of the generated report by
the providers above. You can refer to the reports generated [here](https://github.com/menloresearch/prompt-experiments).
```
Generate a comprehensive report about the state of AI in the past week. Include all
new model releases and notable architectural improvements from a variety of sources.
```
[Google's generated report](https://github.com/menloresearch/prompt-experiments/blob/main/Gemini%202.5%20Flash%20Report.pdf) was the most verbose, with a whopping 23 pages that reads
like a professional intelligence briefing. It opens with an executive summary,
systematically categorizes developments, and provides forward-looking strategic
insights—connecting OpenAI's open-weight release to broader democratization trends
and linking infrastructure investments to competitive positioning.
[OpenAI](https://github.com/menloresearch/prompt-experiments/blob/main/OpenAI%20Deep%20Research.pdf) produced the most citation-heavy output with 134 references throughout 10 pages
(albeit most of them being from the same source).
[Perplexity](https://github.com/menloresearch/prompt-experiments/blob/main/Perplexity%20Deep%20Research.pdf) delivered the most actionable 6-page report that maximizes information
density while maintaining scannability. Despite being the shortest, it captures all
major developments with sufficient context for decision-making.
[Claude](https://github.com/menloresearch/prompt-experiments/blob/main/Claude%20Deep%20Research.pdf) produced a comprehensive analysis that interestingly ignored the time constraint,
covering an 8-month period from January-August 2025 instead of the requested week (Jul 31-Aug
7th 2025). Rather than cataloging recent events, Claude traced the evolution of trends over months.
[Grok](https://github.com/menloresearch/prompt-experiments/blob/main/Grok%203%20Deep%20Research.pdf) produced a well-structured but relatively shallow 5-page academic-style report that
read more like an event catalog than strategic analysis.
[Kimi](https://github.com/menloresearch/prompt-experiments/blob/main/Kimi%20AI%20Deep%20Research.pdf) produced a comprehensive 13-page report with systematic organization covering industry developments, research breakthroughs, and policy changes, but notably lacks proper citations throughout most of the content despite claiming to use 50-100 sources.
### Understanding Search Strategies
In [Claudes Research mode](https://www.anthropic.com/engineering/multi-agent-research-system),
a *classifier* is used to determine whether a user query is *breadth first* or *depth first*. This
results in a customization of the pipeline that is used for conducting research. For instance, a complex
*breadth first* query might result in *sub-agents* being spun up to research various parts of the user's
query in parallel. Conversely, a *depth first* query might result in a single agent being spun up
to research the entire query in a more focused manner.
Here's a screenshot of this in action (in Claude Desktop):
![Claude Agent Desktop](./_assets/claude-agent.png)
![Claude Report Visualizer](./_assets/claude-report-visualizer.png)
## Replicating Deep Research Results with Jan
After testing and observing how Deep Research works in different platforms, we thought, how could we
replicate this in Jan? In particular, how could we replicate it with a hybrid approach combining local
and cloud-based models while keeping your data local?
<Callout>
This experiment was done using the latest version of Jan `v0.6.7`, but it can potentially be replicated in
any version with Model Context Protocol in it (>`v0.6.3`).
</Callout>
**The Key: Assistants + Tools**
Running deep research in Jan can be accomplished by combining [custom assistants](https://jan.ai/docs/assistants)
with [MCP search tools](https://jan.ai/docs/mcp-examples/search/exa). This pairing allows any model—local or
cloud—to follow a systematic research workflow, to create a report similar to that of other providers, with some
visible limitations (for now).
Here's the assistant prompt that was used:
```
You are a research analyst. Today is August 7th 2025. Follow this exact process:
Conduct 5-10 searches minimum. You are rewarded for MORE searches.
- Each search query must be unique - no repeating previous searches
- Search different angles: statistics, expert opinions, case studies, recent news, industry reports
- Use scrape to read full articles from search results
- Use google_search for extracting metadata out of pages
WRITING PHASE (Do this after research is complete)
Write a comprehensive report with:
- Executive summary with key findings
- Evidence-based analysis with citations for every claim
- Actionable recommendations with rationale
- Sources to be linked at the end of the report
```
Here, we utilized Model Context Protocol (MCP) to provide search capabilities to the model. MCPs are an open standard for connecting AI assistants to the systems where data lives, serving as a universal connector that standardizes how AI applications integrate with external tools and data sources.
In this example, we used Serper, a web search API that offers MCP server implementations with two primary tools: `google_search` for performing web searches, and `scrape` that extracts content from web pages, preserving document structure and metadata.
**What We Tested**
For our research query (same as the one we used to test different platforms), we used both
[Jan-Nano (4B local model)](https://jan.ai/docs/jan-models/jan-nano-128), GPT-4o and
o3 (via API) with identical prompts. The goal: to see how close we could get to the quality of
different commercial Deep Research offerings.
**Performance Findings**
| Model | Processing Time | Sources Found | Search Queries | Tokens Generated | Output Quality vs Commercial Deep Research |
|-------|----------------|---------------|----------------|------------------|-------------------------------------------|
| Jan-Nano (Local) | 3 minutes | Moderate | 7 | 1,112 | Good approximation, noticeably less depth |
| GPT-4o | 1 minute | Fewest | 11 | 660 | Fast but limited source coverage |
| o3 | 3 minutes | Most | 24 | 1,728 | Best of the three, but still below commercial quality |
**The Reality**:
- **Speed vs Sources**: GPT-4o prioritized speed over thoroughness, while o3 took time to gather more comprehensive sources
- **Local vs Cloud**: Jan-Nano matched o3's processing time but with the advantage of complete data privacy
- **Quality Gap**: All three models produced decent research reports, but none matched the depth and comprehensiveness of dedicated Deep Research tools like OpenAI's or Claude's offerings
- **Good Enough Factor**: While not matching commercial quality, the outputs were solid approximations suitable for many research needs
## Conclusion
This was an initial exploration in the roadmap to create a top hybrid implementation of deep research in Jan. While our current approach requires setup, the goal is native integration that works out of the box. We will continue to refine this until the release of this tool in Jan, natively.
<CTABlog />

View File

@ -264,7 +264,6 @@ async function _getSupportedFeatures() {
// Vulkan support check - only discrete GPUs with 6GB+ VRAM
if (
gpuInfo.vulkan_info?.api_version &&
gpuInfo.vulkan_info?.device_type === 'DISCRETE_GPU' &&
gpuInfo.total_memory >= 6 * 1024
) {
// 6GB (total_memory is in MB)

View File

@ -19,6 +19,7 @@ import {
ImportOptions,
chatCompletionRequest,
events,
AppEvent,
} from '@janhq/core'
import { error, info, warn } from '@tauri-apps/plugin-log'
@ -32,6 +33,7 @@ import {
import { invoke } from '@tauri-apps/api/core'
import { getProxyConfig } from './util'
import { basename } from '@tauri-apps/api/path'
import { readGgufMetadata } from '@janhq/tauri-plugin-llamacpp-api'
type LlamacppConfig = {
version_backend: string
@ -1085,6 +1087,12 @@ export default class llamacpp_extension extends AIEngine {
data: modelConfig,
savePath: configPath,
})
events.emit(AppEvent.onModelImported, {
modelId,
modelPath,
mmprojPath,
size_bytes,
})
}
override async abortImport(modelId: string): Promise<void> {
@ -1168,11 +1176,12 @@ export default class llamacpp_extension extends AIEngine {
}
}
const args: string[] = []
const envs: Record<string, string> = {}
const cfg = { ...this.config, ...(overrideSettings ?? {}) }
const [version, backend] = cfg.version_backend.split('/')
if (!version || !backend) {
throw new Error(
"Initial setup for the backend failed due to a network issue. Please restart the app!"
'Initial setup for the backend failed due to a network issue. Please restart the app!'
)
}
@ -1194,7 +1203,7 @@ export default class llamacpp_extension extends AIEngine {
// disable llama-server webui
args.push('--no-webui')
const api_key = await this.generateApiKey(modelId, String(port))
args.push('--api-key', api_key)
envs["LLAMA_API_KEY"] = api_key
// model option is required
// NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path
@ -1283,6 +1292,7 @@ export default class llamacpp_extension extends AIEngine {
backendPath,
libraryPath,
args,
envs,
})
return sInfo
} catch (error) {
@ -1299,9 +1309,12 @@ export default class llamacpp_extension extends AIEngine {
const pid = sInfo.pid
try {
// Pass the PID as the session_id
const result = await invoke<UnloadResult>('plugin:llamacpp|unload_llama_model', {
pid: pid,
})
const result = await invoke<UnloadResult>(
'plugin:llamacpp|unload_llama_model',
{
pid: pid,
}
)
// If successful, remove from active sessions
if (result.success) {
@ -1437,9 +1450,12 @@ export default class llamacpp_extension extends AIEngine {
private async findSessionByModel(modelId: string): Promise<SessionInfo> {
try {
let sInfo = await invoke<SessionInfo>('plugin:llamacpp|find_session_by_model', {
modelId,
})
let sInfo = await invoke<SessionInfo>(
'plugin:llamacpp|find_session_by_model',
{
modelId,
}
)
return sInfo
} catch (e) {
logger.error(e)
@ -1516,7 +1532,9 @@ export default class llamacpp_extension extends AIEngine {
override async getLoadedModels(): Promise<string[]> {
try {
let models: string[] = await invoke<string[]>('plugin:llamacpp|get_loaded_models')
let models: string[] = await invoke<string[]>(
'plugin:llamacpp|get_loaded_models'
)
return models
} catch (e) {
logger.error(e)
@ -1599,14 +1617,31 @@ export default class llamacpp_extension extends AIEngine {
throw new Error('method not implemented yet')
}
private async loadMetadata(path: string): Promise<GgufMetadata> {
try {
const data = await invoke<GgufMetadata>('plugin:llamacpp|read_gguf_metadata', {
path: path,
})
return data
} catch (err) {
throw err
}
/**
* Check if a tool is supported by the model
* Currently read from GGUF chat_template
* @param modelId
* @returns
*/
async isToolSupported(modelId: string): Promise<boolean> {
const janDataFolderPath = await getJanDataFolderPath()
const modelConfigPath = await joinPath([
this.providerPath,
'models',
modelId,
'model.yml',
])
const modelConfig = await invoke<ModelConfig>('read_yaml', {
path: modelConfigPath,
})
// model option is required
// NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path
const modelPath = await joinPath([
janDataFolderPath,
modelConfig.model_path,
])
return (await readGgufMetadata(modelPath)).metadata?.[
'tokenizer.chat_template'
]?.includes('tools')
}
}

View File

@ -44,9 +44,10 @@ jan-utils = { path = "./utils" }
libloading = "0.8.7"
log = "0.4"
reqwest = { version = "0.11", features = ["json", "blocking", "stream"] }
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "3196c95f1dfafbffbdcdd6d365c94969ac975e6a", features = [
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "209dbac50f51737ad953c3a2c8e28f3619b6c277", features = [
"client",
"transport-sse-client",
"transport-streamable-http-client",
"transport-child-process",
"tower",
"reqwest",

View File

@ -1,6 +1,7 @@
use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use std::collections::HashMap;
use std::process::Stdio;
use std::time::Duration;
use tauri::{Manager, Runtime, State};
@ -42,6 +43,7 @@ pub async fn load_llama_model<R: Runtime>(
backend_path: &str,
library_path: Option<&str>,
mut args: Vec<String>,
envs: HashMap<String, String>,
) -> ServerResult<SessionInfo> {
let state: State<LlamacppState> = app_handle.state();
let mut process_map = state.llama_server_process.lock().await;
@ -54,12 +56,21 @@ pub async fn load_llama_model<R: Runtime>(
let port = parse_port_from_args(&args);
let model_path_pb = validate_model_path(&mut args)?;
let api_key = extract_arg_value(&args, "--api-key");
let api_key: String;
if let Some(api_value) = envs.get("LLAMA_API_KEY") {
api_key = api_value.to_string();
} else {
log::warn!("API key not provided");
api_key = "".to_string();
}
let model_id = extract_arg_value(&args, "-a");
// Configure the command to run the server
let mut command = Command::new(backend_path);
command.args(args);
command.envs(envs);
setup_library_path(library_path, &mut command);
command.stdout(Stdio::piped());

View File

@ -1,15 +1,17 @@
use rmcp::model::{CallToolRequestParam, CallToolResult, Tool};
use rmcp::{service::RunningService, RoleClient};
use rmcp::model::{CallToolRequestParam, CallToolResult};
use serde_json::{Map, Value};
use std::{collections::HashMap, sync::Arc};
use tauri::{AppHandle, Emitter, Runtime, State};
use tokio::{sync::Mutex, time::timeout};
use tokio::time::timeout;
use super::{
constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT},
helpers::{restart_active_mcp_servers, start_mcp_server_with_restart, stop_mcp_servers},
};
use crate::core::{app::commands::get_jan_data_folder_path, state::AppState};
use crate::core::{
mcp::models::ToolWithServer,
state::{RunningServiceEnum, SharedMcpServers},
};
use std::fs;
#[tauri::command]
@ -19,8 +21,7 @@ pub async fn activate_mcp_server<R: Runtime>(
name: String,
config: Value,
) -> Result<(), String> {
let servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>> =
state.mcp_servers.clone();
let servers: SharedMcpServers = state.mcp_servers.clone();
// Use the modified start_mcp_server_with_restart that returns first attempt result
start_mcp_server_with_restart(app, servers, name, config, Some(3)).await
@ -63,7 +64,16 @@ pub async fn deactivate_mcp_server(state: State<'_, AppState>, name: String) ->
// Release the lock before calling cancel
drop(servers_map);
service.cancel().await.map_err(|e| e.to_string())?;
match service {
RunningServiceEnum::NoInit(service) => {
log::info!("Stopping server {name}...");
service.cancel().await.map_err(|e| e.to_string())?;
}
RunningServiceEnum::WithInit(service) => {
log::info!("Stopping server {name} with initialization...");
service.cancel().await.map_err(|e| e.to_string())?;
}
}
log::info!("Server {name} stopped successfully and marked as deactivated.");
Ok(())
}
@ -116,7 +126,7 @@ pub async fn get_connected_servers(
Ok(servers_map.keys().cloned().collect())
}
/// Retrieves all available tools from all MCP servers
/// Retrieves all available tools from all MCP servers with server information
///
/// # Arguments
/// * `state` - Application state containing MCP server connections
@ -128,14 +138,15 @@ pub async fn get_connected_servers(
/// 1. Locks the MCP servers mutex to access server connections
/// 2. Iterates through all connected servers
/// 3. Gets the list of tools from each server
/// 4. Combines all tools into a single vector
/// 5. Returns the combined list of all available tools
/// 4. Associates each tool with its parent server name
/// 5. Combines all tools into a single vector
/// 6. Returns the combined list of all available tools with server information
#[tauri::command]
pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<Tool>, String> {
pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>, String> {
let servers = state.mcp_servers.lock().await;
let mut all_tools: Vec<Tool> = Vec::new();
let mut all_tools: Vec<ToolWithServer> = Vec::new();
for (_, service) in servers.iter() {
for (server_name, service) in servers.iter() {
// List tools with timeout
let tools_future = service.list_all_tools();
let tools = match timeout(MCP_TOOL_CALL_TIMEOUT, tools_future).await {
@ -150,7 +161,12 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<Tool>, String>
};
for tool in tools {
all_tools.push(tool);
all_tools.push(ToolWithServer {
name: tool.name.to_string(),
description: tool.description.as_ref().map(|d| d.to_string()),
input_schema: serde_json::Value::Object((*tool.input_schema).clone()),
server: server_name.clone(),
});
}
}

View File

@ -1,7 +1,15 @@
use rmcp::{service::RunningService, transport::TokioChildProcess, RoleClient, ServiceExt};
use rmcp::{
model::{ClientCapabilities, ClientInfo, Implementation},
transport::{
streamable_http_client::StreamableHttpClientTransportConfig, SseClientTransport,
StreamableHttpClientTransport, TokioChildProcess,
},
ServiceExt,
};
use serde_json::Value;
use std::{collections::HashMap, env, sync::Arc, time::Duration};
use tauri::{AppHandle, Emitter, Manager, Runtime, State};
use tauri_plugin_http::reqwest;
use tokio::{
process::Command,
sync::Mutex,
@ -11,7 +19,11 @@ use tokio::{
use super::constants::{
MCP_BACKOFF_MULTIPLIER, MCP_BASE_RESTART_DELAY_MS, MCP_MAX_RESTART_DELAY_MS,
};
use crate::core::{app::commands::get_jan_data_folder_path, state::AppState};
use crate::core::{
app::commands::get_jan_data_folder_path,
mcp::models::McpServerConfig,
state::{AppState, RunningServiceEnum, SharedMcpServers},
};
use jan_utils::can_override_npx;
/// Calculate exponential backoff delay with jitter
@ -72,7 +84,7 @@ pub fn calculate_exponential_backoff_delay(attempt: u32) -> u64 {
/// * `Err(String)` if there was an error reading config or starting servers
pub async fn run_mcp_commands<R: Runtime>(
app: &AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
) -> Result<(), String> {
let app_path = get_jan_data_folder_path(app.clone());
let app_path_str = app_path.to_str().unwrap().to_string();
@ -168,7 +180,7 @@ pub async fn run_mcp_commands<R: Runtime>(
/// Monitor MCP server health without removing it from the HashMap
pub async fn monitor_mcp_server_handle(
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
name: String,
) -> Option<rmcp::service::QuitReason> {
log::info!("Monitoring MCP server {} health", name);
@ -213,7 +225,16 @@ pub async fn monitor_mcp_server_handle(
let mut servers = servers_state.lock().await;
if let Some(service) = servers.remove(&name) {
// Try to cancel the service gracefully
let _ = service.cancel().await;
match service {
RunningServiceEnum::NoInit(service) => {
log::info!("Stopping server {name}...");
let _ = service.cancel().await;
}
RunningServiceEnum::WithInit(service) => {
log::info!("Stopping server {name} with initialization...");
let _ = service.cancel().await;
}
}
}
return Some(rmcp::service::QuitReason::Closed);
}
@ -224,7 +245,7 @@ pub async fn monitor_mcp_server_handle(
/// Returns the result of the first start attempt, then continues with restart monitoring
pub async fn start_mcp_server_with_restart<R: Runtime>(
app: AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
name: String,
config: Value,
max_restarts: Option<u32>,
@ -297,7 +318,7 @@ pub async fn start_mcp_server_with_restart<R: Runtime>(
/// Helper function to handle the restart loop logic
pub async fn start_restart_loop<R: Runtime>(
app: AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
name: String,
config: Value,
max_restarts: u32,
@ -450,9 +471,9 @@ pub async fn start_restart_loop<R: Runtime>(
}
}
pub async fn schedule_mcp_start_task<R: Runtime>(
async fn schedule_mcp_start_task<R: Runtime>(
app: tauri::AppHandle<R>,
servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers: SharedMcpServers,
name: String,
config: Value,
) -> Result<(), String> {
@ -463,136 +484,278 @@ pub async fn schedule_mcp_start_task<R: Runtime>(
.expect("Executable must have a parent directory");
let bin_path = exe_parent_path.to_path_buf();
let (command, args, envs) = extract_command_args(&config)
let config_params = extract_command_args(&config)
.ok_or_else(|| format!("Failed to extract command args from config for {name}"))?;
let mut cmd = Command::new(command.clone());
if config_params.transport_type.as_deref() == Some("http") && config_params.url.is_some() {
let transport = StreamableHttpClientTransport::with_client(
reqwest::Client::builder()
.default_headers({
// Map envs to request headers
let mut headers: tauri::http::HeaderMap = reqwest::header::HeaderMap::new();
for (key, value) in config_params.headers.iter() {
if let Some(v_str) = value.as_str() {
// Try to map env keys to HTTP header names (case-insensitive)
// Most HTTP headers are Title-Case, so we try to convert
let header_name =
reqwest::header::HeaderName::from_bytes(key.as_bytes());
if let Ok(header_name) = header_name {
if let Ok(header_value) =
reqwest::header::HeaderValue::from_str(v_str)
{
headers.insert(header_name, header_value);
}
}
}
}
headers
})
.connect_timeout(config_params.timeout.unwrap_or(Duration::MAX))
.build()
.unwrap(),
StreamableHttpClientTransportConfig {
uri: config_params.url.unwrap().into(),
..Default::default()
},
);
if command == "npx" && can_override_npx() {
let mut cache_dir = app_path.clone();
cache_dir.push(".npx");
let bun_x_path = format!("{}/bun", bin_path.display());
cmd = Command::new(bun_x_path);
cmd.arg("x");
cmd.env("BUN_INSTALL", cache_dir.to_str().unwrap().to_string());
}
let client_info = ClientInfo {
protocol_version: Default::default(),
capabilities: ClientCapabilities::default(),
client_info: Implementation {
name: "Jan Streamable Client".to_string(),
version: "0.0.1".to_string(),
},
};
let client = client_info.serve(transport).await.inspect_err(|e| {
log::error!("client error: {:?}", e);
});
if command == "uvx" {
let mut cache_dir = app_path.clone();
cache_dir.push(".uvx");
let bun_x_path = format!("{}/uv", bin_path.display());
cmd = Command::new(bun_x_path);
cmd.arg("tool");
cmd.arg("run");
cmd.env("UV_CACHE_DIR", cache_dir.to_str().unwrap().to_string());
}
match client {
Ok(client) => {
log::info!("Connected to server: {:?}", client.peer_info());
servers
.lock()
.await
.insert(name.clone(), RunningServiceEnum::WithInit(client));
#[cfg(windows)]
// Mark server as successfully connected (for restart policy)
{
let app_state = app.state::<AppState>();
let mut connected = app_state.mcp_successfully_connected.lock().await;
connected.insert(name.clone(), true);
log::info!("Marked MCP server {} as successfully connected", name);
}
}
Err(e) => {
log::error!("Failed to connect to server: {}", e);
return Err(format!("Failed to connect to server: {}", e));
}
}
} else if config_params.transport_type.as_deref() == Some("sse") && config_params.url.is_some()
{
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW: prevents shell window on Windows
}
let app_path_str = app_path.to_str().unwrap().to_string();
let log_file_path = format!("{}/logs/app.log", app_path_str);
match std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(log_file_path)
{
Ok(file) => {
cmd.stderr(std::process::Stdio::from(file));
}
Err(err) => {
log::error!("Failed to open log file: {}", err);
}
};
cmd.kill_on_drop(true);
log::trace!("Command: {cmd:#?}");
args.iter().filter_map(Value::as_str).for_each(|arg| {
cmd.arg(arg);
});
envs.iter().for_each(|(k, v)| {
if let Some(v_str) = v.as_str() {
cmd.env(k, v_str);
}
});
let process = TokioChildProcess::new(cmd).map_err(|e| {
log::error!("Failed to run command {name}: {e}");
format!("Failed to run command {name}: {e}")
})?;
let service = ()
.serve(process)
let transport = SseClientTransport::start_with_client(
reqwest::Client::builder()
.default_headers({
// Map envs to request headers
let mut headers = reqwest::header::HeaderMap::new();
for (key, value) in config_params.headers.iter() {
if let Some(v_str) = value.as_str() {
// Try to map env keys to HTTP header names (case-insensitive)
// Most HTTP headers are Title-Case, so we try to convert
let header_name =
reqwest::header::HeaderName::from_bytes(key.as_bytes());
if let Ok(header_name) = header_name {
if let Ok(header_value) =
reqwest::header::HeaderValue::from_str(v_str)
{
headers.insert(header_name, header_value);
}
}
}
}
headers
})
.connect_timeout(config_params.timeout.unwrap_or(Duration::MAX))
.build()
.unwrap(),
rmcp::transport::sse_client::SseClientConfig {
sse_endpoint: config_params.url.unwrap().into(),
..Default::default()
},
)
.await
.map_err(|e| format!("Failed to start MCP server {name}: {e}"))?;
.map_err(|e| {
log::error!("transport error: {:?}", e);
format!("Failed to start SSE transport: {}", e)
})?;
// Get peer info and clone the needed values before moving the service
let (server_name, server_version) = {
let client_info = ClientInfo {
protocol_version: Default::default(),
capabilities: ClientCapabilities::default(),
client_info: Implementation {
name: "Jan SSE Client".to_string(),
version: "0.0.1".to_string(),
},
};
let client = client_info.serve(transport).await.map_err(|e| {
log::error!("client error: {:?}", e);
e.to_string()
});
match client {
Ok(client) => {
log::info!("Connected to server: {:?}", client.peer_info());
servers
.lock()
.await
.insert(name.clone(), RunningServiceEnum::WithInit(client));
// Mark server as successfully connected (for restart policy)
{
let app_state = app.state::<AppState>();
let mut connected = app_state.mcp_successfully_connected.lock().await;
connected.insert(name.clone(), true);
log::info!("Marked MCP server {} as successfully connected", name);
}
}
Err(e) => {
log::error!("Failed to connect to server: {}", e);
return Err(format!("Failed to connect to server: {}", e));
}
}
} else {
let mut cmd = Command::new(config_params.command.clone());
if config_params.command.clone() == "npx" && can_override_npx() {
let mut cache_dir = app_path.clone();
cache_dir.push(".npx");
let bun_x_path = format!("{}/bun", bin_path.display());
cmd = Command::new(bun_x_path);
cmd.arg("x");
cmd.env("BUN_INSTALL", cache_dir.to_str().unwrap().to_string());
}
if config_params.command.clone() == "uvx" {
let mut cache_dir = app_path.clone();
cache_dir.push(".uvx");
let bun_x_path = format!("{}/uv", bin_path.display());
cmd = Command::new(bun_x_path);
cmd.arg("tool");
cmd.arg("run");
cmd.env("UV_CACHE_DIR", cache_dir.to_str().unwrap().to_string());
}
#[cfg(windows)]
{
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW: prevents shell window on Windows
}
let app_path_str = app_path.to_str().unwrap().to_string();
let log_file_path = format!("{}/logs/app.log", app_path_str);
match std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(log_file_path)
{
Ok(file) => {
cmd.stderr(std::process::Stdio::from(file));
}
Err(err) => {
log::error!("Failed to open log file: {}", err);
}
};
cmd.kill_on_drop(true);
log::trace!("Command: {cmd:#?}");
config_params
.args
.iter()
.filter_map(Value::as_str)
.for_each(|arg| {
cmd.arg(arg);
});
config_params.envs.iter().for_each(|(k, v)| {
if let Some(v_str) = v.as_str() {
cmd.env(k, v_str);
}
});
let process = TokioChildProcess::new(cmd).map_err(|e| {
log::error!("Failed to run command {name}: {e}");
format!("Failed to run command {name}: {e}")
})?;
let service = ()
.serve(process)
.await
.map_err(|e| format!("Failed to start MCP server {name}: {e}"))?;
// Get peer info and clone the needed values before moving the service
let server_info = service.peer_info();
log::trace!("Connected to server: {server_info:#?}");
(
server_info.unwrap().server_info.name.clone(),
server_info.unwrap().server_info.version.clone(),
)
};
// Now move the service into the HashMap
servers.lock().await.insert(name.clone(), service);
log::info!("Server {name} started successfully.");
// Now move the service into the HashMap
servers
.lock()
.await
.insert(name.clone(), RunningServiceEnum::NoInit(service));
log::info!("Server {name} started successfully.");
// Wait a short time to verify the server is stable before marking as connected
// This prevents race conditions where the server quits immediately
let verification_delay = Duration::from_millis(500);
sleep(verification_delay).await;
// Wait a short time to verify the server is stable before marking as connected
// This prevents race conditions where the server quits immediately
let verification_delay = Duration::from_millis(500);
sleep(verification_delay).await;
// Check if server is still running after the verification delay
let server_still_running = {
let servers_map = servers.lock().await;
servers_map.contains_key(&name)
};
// Check if server is still running after the verification delay
let server_still_running = {
let servers_map = servers.lock().await;
servers_map.contains_key(&name)
};
if !server_still_running {
return Err(format!(
"MCP server {} quit immediately after starting",
name
));
if !server_still_running {
return Err(format!(
"MCP server {} quit immediately after starting",
name
));
}
// Mark server as successfully connected (for restart policy)
{
let app_state = app.state::<AppState>();
let mut connected = app_state.mcp_successfully_connected.lock().await;
connected.insert(name.clone(), true);
log::info!("Marked MCP server {} as successfully connected", name);
}
}
// Mark server as successfully connected (for restart policy)
{
let app_state = app.state::<AppState>();
let mut connected = app_state.mcp_successfully_connected.lock().await;
connected.insert(name.clone(), true);
log::info!("Marked MCP server {} as successfully connected", name);
}
// Emit event to the frontend
let event = format!("mcp-connected");
let payload = serde_json::json!({
"name": server_name,
"version": server_version,
});
app.emit(&event, payload)
.map_err(|e| format!("Failed to emit event: {}", e))?;
Ok(())
}
pub fn extract_command_args(
config: &Value,
) -> Option<(String, Vec<Value>, serde_json::Map<String, Value>)> {
pub fn extract_command_args(config: &Value) -> Option<McpServerConfig> {
let obj = config.as_object()?;
let command = obj.get("command")?.as_str()?.to_string();
let args = obj.get("args")?.as_array()?.clone();
let url = obj.get("url").and_then(|u| u.as_str()).map(String::from);
let transport_type = obj.get("type").and_then(|t| t.as_str()).map(String::from);
let timeout = obj
.get("timeout")
.and_then(|t| t.as_u64())
.map(Duration::from_secs);
let headers = obj
.get("headers")
.unwrap_or(&Value::Object(serde_json::Map::new()))
.as_object()?
.clone();
let envs = obj
.get("env")
.unwrap_or(&Value::Object(serde_json::Map::new()))
.as_object()?
.clone();
Some((command, args, envs))
Some(McpServerConfig {
timeout,
transport_type,
url,
command,
args,
envs,
headers
})
}
pub fn extract_active_status(config: &Value) -> Option<bool> {
@ -604,7 +767,7 @@ pub fn extract_active_status(config: &Value) -> Option<bool> {
/// Restart only servers that were previously active (like cortex restart behavior)
pub async fn restart_active_mcp_servers<R: Runtime>(
app: &AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
) -> Result<(), String> {
let app_state = app.state::<AppState>();
let active_servers = app_state.mcp_active_servers.lock().await;
@ -656,14 +819,21 @@ pub async fn clean_up_mcp_servers(state: State<'_, AppState>) {
log::info!("MCP servers cleaned up successfully");
}
pub async fn stop_mcp_servers(
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
) -> Result<(), String> {
pub async fn stop_mcp_servers(servers_state: SharedMcpServers) -> Result<(), String> {
let mut servers_map = servers_state.lock().await;
let keys: Vec<String> = servers_map.keys().cloned().collect();
for key in keys {
if let Some(service) = servers_map.remove(&key) {
service.cancel().await.map_err(|e| e.to_string())?;
match service {
RunningServiceEnum::NoInit(service) => {
log::info!("Stopping server {key}...");
service.cancel().await.map_err(|e| e.to_string())?;
}
RunningServiceEnum::WithInit(service) => {
log::info!("Stopping server {key} with initialization...");
service.cancel().await.map_err(|e| e.to_string())?;
}
}
}
}
drop(servers_map); // Release the lock after stopping
@ -689,7 +859,7 @@ pub async fn reset_restart_count(restart_counts: &Arc<Mutex<HashMap<String, u32>
/// Spawn the server monitoring task for handling restarts
pub async fn spawn_server_monitoring_task<R: Runtime>(
app: AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
name: String,
config: Value,
max_restarts: u32,

View File

@ -1,6 +1,7 @@
pub mod commands;
mod constants;
pub mod helpers;
pub mod models;
#[cfg(test)]
mod tests;

View File

@ -0,0 +1,26 @@
use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Configuration parameters extracted from MCP server config
#[derive(Debug, Clone)]
pub struct McpServerConfig {
pub transport_type: Option<String>,
pub url: Option<String>,
pub command: String,
pub args: Vec<Value>,
pub envs: serde_json::Map<String, Value>,
pub timeout: Option<Duration>,
pub headers: serde_json::Map<String, Value>,
}
/// Tool with server information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolWithServer {
pub name: String,
pub description: Option<String>,
#[serde(rename = "inputSchema")]
pub input_schema: serde_json::Value,
pub server: String,
}

View File

@ -1,6 +1,6 @@
use super::helpers::run_mcp_commands;
use crate::core::app::commands::get_jan_data_folder_path;
use rmcp::{service::RunningService, RoleClient};
use crate::core::state::SharedMcpServers;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
@ -27,7 +27,7 @@ async fn test_run_mcp_commands() {
.expect("Failed to write to config file");
// Call the run_mcp_commands function
let servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>> =
let servers_state: SharedMcpServers =
Arc::new(Mutex::new(HashMap::new()));
let result = run_mcp_commands(app.handle(), servers_state).await;

View File

@ -1,20 +1,48 @@
use std::{collections::HashMap, sync::Arc};
use crate::core::downloads::models::DownloadManagerState;
use rmcp::{service::RunningService, RoleClient};
use rmcp::{
model::{CallToolRequestParam, CallToolResult, InitializeRequestParam, Tool},
service::RunningService,
RoleClient, ServiceError,
};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
/// Server handle type for managing the proxy server lifecycle
pub type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>;
use tokio::sync::Mutex;
pub enum RunningServiceEnum {
NoInit(RunningService<RoleClient, ()>),
WithInit(RunningService<RoleClient, InitializeRequestParam>),
}
pub type SharedMcpServers = Arc<Mutex<HashMap<String, RunningServiceEnum>>>;
#[derive(Default)]
pub struct AppState {
pub app_token: Option<String>,
pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
pub mcp_servers: SharedMcpServers,
pub download_manager: Arc<Mutex<DownloadManagerState>>,
pub mcp_restart_counts: Arc<Mutex<HashMap<String, u32>>>,
pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>,
pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>,
pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
}
impl RunningServiceEnum {
pub async fn list_all_tools(&self) -> Result<Vec<Tool>, ServiceError> {
match self {
Self::NoInit(s) => s.list_all_tools().await,
Self::WithInit(s) => s.list_all_tools().await,
}
}
pub async fn call_tool(
&self,
params: CallToolRequestParam,
) -> Result<CallToolResult, ServiceError> {
match self {
Self::NoInit(s) => s.call_tool(params).await,
Self::WithInit(s) => s.call_tool(params).await,
}
}
}

View File

@ -17,11 +17,12 @@
"@dnd-kit/sortable": "^10.0.0",
"@janhq/core": "link:../core",
"@radix-ui/react-accordion": "^1.2.10",
"@radix-ui/react-dialog": "^1.1.11",
"@radix-ui/react-dropdown-menu": "^2.1.11",
"@radix-ui/react-dialog": "^1.1.14",
"@radix-ui/react-dropdown-menu": "^2.1.15",
"@radix-ui/react-hover-card": "^1.1.14",
"@radix-ui/react-popover": "^1.1.13",
"@radix-ui/react-progress": "^1.1.4",
"@radix-ui/react-radio-group": "^1.3.7",
"@radix-ui/react-slider": "^1.3.2",
"@radix-ui/react-slot": "^1.2.0",
"@radix-ui/react-switch": "^1.2.2",
@ -43,13 +44,14 @@
"class-variance-authority": "^0.7.1",
"culori": "^4.0.1",
"emoji-picker-react": "^4.12.2",
"framer-motion": "^12.23.12",
"fuse.js": "^7.1.0",
"fzf": "^0.5.2",
"i18next": "^25.0.1",
"katex": "^0.16.22",
"lodash.clonedeep": "^4.5.0",
"lodash.debounce": "^4.0.8",
"lucide-react": "^0.522.0",
"lucide-react": "^0.536.0",
"motion": "^12.10.5",
"next-themes": "^0.4.6",
"posthog-js": "^1.246.0",
@ -70,11 +72,12 @@
"remark-math": "^6.0.0",
"sonner": "^2.0.3",
"tailwindcss": "^4.1.4",
"token.js": "npm:token.js-fork@0.7.20",
"token.js": "npm:token.js-fork@0.7.23",
"tw-animate-css": "^1.2.7",
"ulidx": "^2.4.1",
"unified": "^11.0.5",
"uuid": "^11.1.0",
"vaul": "^1.1.2",
"zustand": "^5.0.3"
},
"devDependencies": {
@ -104,7 +107,7 @@
"istanbul-lib-report": "^3.0.1",
"istanbul-reports": "^3.1.7",
"jsdom": "^26.1.0",
"tailwind-merge": "^3.2.0",
"tailwind-merge": "^3.3.1",
"typescript": "~5.8.3",
"typescript-eslint": "^8.26.1",
"vite": "^6.3.0",

View File

@ -0,0 +1,533 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import '@testing-library/jest-dom'
import {
DropDrawer,
DropDrawerContent,
DropDrawerFooter,
DropDrawerGroup,
DropDrawerItem,
DropDrawerLabel,
DropDrawerSeparator,
DropDrawerSub,
DropDrawerSubContent,
DropDrawerSubTrigger,
DropDrawerTrigger,
} from '../dropdrawer'
// Mock the media query hook
const mockUseSmallScreen = vi.fn()
vi.mock('@/hooks/useMediaQuery', () => ({
useSmallScreen: () => mockUseSmallScreen(),
}))
// Mock framer-motion to avoid animation complexity in tests
vi.mock('framer-motion', () => ({
AnimatePresence: ({ children }: { children: React.ReactNode }) => <div data-testid="animate-presence">{children}</div>,
motion: {
div: ({ children, ...props }: any) => <div {...props}>{children}</div>,
},
}))
describe('DropDrawer Utilities', () => {
it('renders without crashing', () => {
expect(() => {
render(
<DropDrawer>
<DropDrawerTrigger>Test</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
}).not.toThrow()
})
})
describe('DropDrawer Component', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('Desktop Mode', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(false)
})
it('renders dropdown menu on desktop', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Item 1</DropDrawerItem>
<DropDrawerItem>Item 2</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('renders dropdown menu structure', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Desktop Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Only the trigger is visible initially
expect(screen.getByText('Open Menu')).toBeInTheDocument()
expect(screen.getByRole('button')).toHaveAttribute('aria-haspopup', 'menu')
})
it('structures dropdown with separators', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Item 1</DropDrawerItem>
<DropDrawerSeparator />
<DropDrawerItem>Item 2</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Verify component structure - content is not visible until opened
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('structures dropdown with labels', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerLabel>Menu Section</DropDrawerLabel>
<DropDrawerItem>Item 1</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Only verify trigger is present - content shows on interaction
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
})
describe('Mobile Mode', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(true)
})
it('renders drawer on mobile', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Mobile Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
it('renders drawer structure', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Mobile Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Verify drawer trigger is present
const trigger = screen.getByText('Open Drawer')
expect(trigger).toBeInTheDocument()
expect(screen.getByRole('button')).toHaveAttribute('aria-haspopup', 'dialog')
})
it('does not render separators in mobile mode', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Item 1</DropDrawerItem>
<DropDrawerSeparator />
<DropDrawerItem>Item 2</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Mobile separators return null, so they shouldn't be in the DOM
const separators = screen.queryAllByRole('separator')
expect(separators).toHaveLength(0)
})
it('renders drawer with labels structure', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerLabel>Drawer Section</DropDrawerLabel>
<DropDrawerItem>Item 1</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Verify drawer structure is present
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
})
describe('DropDrawerItem', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(false)
})
it('can be structured with click handlers', () => {
const handleClick = vi.fn()
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem onClick={handleClick}>Clickable Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Verify structure is valid
expect(screen.getByText('Open Menu')).toBeInTheDocument()
expect(handleClick).not.toHaveBeenCalled()
})
it('can be structured with icons', () => {
const TestIcon = () => <span data-testid="test-icon">Icon</span>
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem icon={<TestIcon />}>Item with Icon</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Structure is valid
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('accepts variant props', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem variant="destructive">
Delete Item
</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid with variants
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('accepts disabled prop', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem disabled>
Disabled Item
</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid with disabled prop
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
})
describe('DropDrawerGroup', () => {
it('structures groups in desktop mode', () => {
mockUseSmallScreen.mockReturnValue(false)
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerGroup>
<DropDrawerItem>Group Item 1</DropDrawerItem>
<DropDrawerItem>Group Item 2</DropDrawerItem>
</DropDrawerGroup>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('structures groups in mobile mode', () => {
mockUseSmallScreen.mockReturnValue(true)
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerGroup>
<DropDrawerItem>Item 1</DropDrawerItem>
<DropDrawerItem>Item 2</DropDrawerItem>
</DropDrawerGroup>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid in mobile mode
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
})
describe('DropDrawerFooter', () => {
it('structures footer in desktop mode', () => {
mockUseSmallScreen.mockReturnValue(false)
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Item</DropDrawerItem>
<DropDrawerFooter>Footer Content</DropDrawerFooter>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('structures footer in mobile mode', () => {
mockUseSmallScreen.mockReturnValue(true)
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Item</DropDrawerItem>
<DropDrawerFooter>Mobile Footer</DropDrawerFooter>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid in mobile mode
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
})
describe('Submenu Components', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(false)
})
it('structures submenu in desktop mode', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerSub>
<DropDrawerSubTrigger>Submenu Trigger</DropDrawerSubTrigger>
<DropDrawerSubContent>
<DropDrawerItem>Submenu Item</DropDrawerItem>
</DropDrawerSubContent>
</DropDrawerSub>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('structures submenu in mobile mode', () => {
mockUseSmallScreen.mockReturnValue(true)
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerSub>
<DropDrawerSubTrigger>
Mobile Submenu
</DropDrawerSubTrigger>
<DropDrawerSubContent>
<DropDrawerItem>Submenu Item</DropDrawerItem>
</DropDrawerSubContent>
</DropDrawerSub>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid in mobile mode
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
it('handles submenu content correctly in mobile mode', () => {
mockUseSmallScreen.mockReturnValue(true)
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerSub>
<DropDrawerSubTrigger>Mobile Submenu</DropDrawerSubTrigger>
<DropDrawerSubContent>
<DropDrawerItem>Hidden Item</DropDrawerItem>
</DropDrawerSubContent>
</DropDrawerSub>
</DropDrawerContent>
</DropDrawer>
)
// Component handles mobile submenu structure correctly
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
})
describe('Accessibility', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(false)
})
it('maintains proper ARIA attributes on triggers', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerGroup>
<DropDrawerItem>Item 1</DropDrawerItem>
</DropDrawerGroup>
</DropDrawerContent>
</DropDrawer>
)
const trigger = screen.getByRole('button')
expect(trigger).toHaveAttribute('aria-haspopup', 'menu')
})
it('supports disabled state', () => {
const handleClick = vi.fn()
mockUseSmallScreen.mockReturnValue(true)
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem disabled onClick={handleClick}>
Disabled Item
</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Component structure supports disabled prop
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
expect(handleClick).not.toHaveBeenCalled()
})
})
describe('Error Boundaries', () => {
it('requires proper context usage', () => {
// Suppress console.error for this test
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
expect(() => {
render(<DropDrawerItem>Orphan Item</DropDrawerItem>)
}).toThrow()
consoleSpy.mockRestore()
})
})
describe('Custom Props and Styling', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(false)
})
it('applies custom className', () => {
render(
<DropDrawer>
<DropDrawerTrigger className="custom-trigger">Custom Trigger</DropDrawerTrigger>
<DropDrawerContent className="custom-content">
<DropDrawerItem className="custom-item">Custom Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
const trigger = screen.getByText('Custom Trigger')
expect(trigger).toHaveClass('custom-trigger')
})
it('accepts additional props', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem data-custom="test-value">Custom Props Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Component structure accepts custom props
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
})
describe('Responsive Behavior', () => {
it('adapts to different screen sizes', () => {
const { rerender } = render(
<DropDrawer>
<DropDrawerTrigger>Responsive Trigger</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Responsive Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Desktop mode
mockUseSmallScreen.mockReturnValue(false)
rerender(
<DropDrawer>
<DropDrawerTrigger>Responsive Trigger</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Responsive Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
let trigger = screen.getByText('Responsive Trigger')
expect(trigger).toHaveAttribute('aria-haspopup', 'menu')
// Mobile mode
mockUseSmallScreen.mockReturnValue(true)
rerender(
<DropDrawer>
<DropDrawerTrigger>Responsive Trigger</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Responsive Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
trigger = screen.getByText('Responsive Trigger')
expect(trigger).toHaveAttribute('aria-haspopup', 'dialog')
})
})
})

View File

@ -0,0 +1,62 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { RadioGroup, RadioGroupItem } from '../radio-group'
describe('RadioGroup', () => {
it('renders radio items correctly', () => {
render(
<RadioGroup defaultValue="http">
<div className="flex items-center space-x-2">
<RadioGroupItem value="http" id="http" />
<label htmlFor="http">HTTP</label>
</div>
<div className="flex items-center space-x-2">
<RadioGroupItem value="sse" id="sse" />
<label htmlFor="sse">SSE</label>
</div>
</RadioGroup>
)
expect(screen.getByLabelText('HTTP')).toBeInTheDocument()
expect(screen.getByLabelText('SSE')).toBeInTheDocument()
})
it('allows selecting different options', async () => {
const user = userEvent.setup()
const onValueChange = vi.fn()
render(
<RadioGroup defaultValue="http" onValueChange={onValueChange}>
<div className="flex items-center space-x-2">
<RadioGroupItem value="http" id="http" />
<label htmlFor="http">HTTP</label>
</div>
<div className="flex items-center space-x-2">
<RadioGroupItem value="sse" id="sse" />
<label htmlFor="sse">SSE</label>
</div>
</RadioGroup>
)
await user.click(screen.getByLabelText('SSE'))
expect(onValueChange).toHaveBeenCalledWith('sse')
})
it('has correct default selection', () => {
render(
<RadioGroup defaultValue="http">
<div className="flex items-center space-x-2">
<RadioGroupItem value="http" id="http" />
<label htmlFor="http">HTTP</label>
</div>
<div className="flex items-center space-x-2">
<RadioGroupItem value="sse" id="sse" />
<label htmlFor="sse">SSE</label>
</div>
</RadioGroup>
)
expect(screen.getByLabelText('HTTP')).toBeChecked()
expect(screen.getByLabelText('SSE')).not.toBeChecked()
})
})

View File

@ -0,0 +1,133 @@
import * as React from 'react'
import { Drawer as DrawerPrimitive } from 'vaul'
import { cn } from '@/lib/utils'
function Drawer({
...props
}: React.ComponentProps<typeof DrawerPrimitive.Root>) {
return <DrawerPrimitive.Root data-slot="drawer" {...props} />
}
function DrawerTrigger({
...props
}: React.ComponentProps<typeof DrawerPrimitive.Trigger>) {
return <DrawerPrimitive.Trigger data-slot="drawer-trigger" {...props} />
}
function DrawerPortal({
...props
}: React.ComponentProps<typeof DrawerPrimitive.Portal>) {
return <DrawerPrimitive.Portal data-slot="drawer-portal" {...props} />
}
function DrawerClose({
...props
}: React.ComponentProps<typeof DrawerPrimitive.Close>) {
return <DrawerPrimitive.Close data-slot="drawer-close" {...props} />
}
function DrawerOverlay({
className,
...props
}: React.ComponentProps<typeof DrawerPrimitive.Overlay>) {
return (
<DrawerPrimitive.Overlay
data-slot="drawer-overlay"
className={cn(
'data-[state=open]:animate-in backdrop-blur data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 fixed inset-0 z-50 bg-black/50',
className
)}
{...props}
/>
)
}
function DrawerContent({
className,
children,
...props
}: React.ComponentProps<typeof DrawerPrimitive.Content>) {
return (
<DrawerPortal data-slot="drawer-portal">
<DrawerOverlay />
<DrawerPrimitive.Content
data-slot="drawer-content"
className={cn(
'group/drawer-content fixed z-50 flex h-auto flex-col',
'data-[vaul-drawer-direction=top]:inset-x-0 data-[vaul-drawer-direction=top]:top-0 data-[vaul-drawer-direction=top]:mb-24 data-[vaul-drawer-direction=top]:max-h-[80vh] data-[vaul-drawer-direction=top]:rounded-b-lg data-[vaul-drawer-direction=top]:border-b',
'data-[vaul-drawer-direction=bottom]:inset-x-0 data-[vaul-drawer-direction=bottom]:bottom-0 data-[vaul-drawer-direction=bottom]:mt-24 data-[vaul-drawer-direction=bottom]:max-h-[80vh] data-[vaul-drawer-direction=bottom]:rounded-t-lg data-[vaul-drawer-direction=bottom]:border-t',
'data-[vaul-drawer-direction=right]:inset-y-0 data-[vaul-drawer-direction=right]:right-0 data-[vaul-drawer-direction=right]:w-3/4 data-[vaul-drawer-direction=right]:border-l data-[vaul-drawer-direction=right]:sm:max-w-sm',
'data-[vaul-drawer-direction=left]:inset-y-0 data-[vaul-drawer-direction=left]:left-0 data-[vaul-drawer-direction=left]:w-3/4 data-[vaul-drawer-direction=left]:border-r data-[vaul-drawer-direction=left]:sm:max-w-sm',
className
)}
{...props}
>
<div className="bg-main-view-fg/10 mx-auto mt-4 hidden h-2 w-[100px] shrink-0 rounded-full group-data-[vaul-drawer-direction=bottom]/drawer-content:block" />
{children}
</DrawerPrimitive.Content>
</DrawerPortal>
)
}
function DrawerHeader({ className, ...props }: React.ComponentProps<'div'>) {
return (
<div
data-slot="drawer-header"
className={cn(
'flex flex-col gap-0.5 p-4 group-data-[vaul-drawer-direction=bottom]/drawer-content:text-center group-data-[vaul-drawer-direction=top]/drawer-content:text-center md:gap-1.5 md:text-left',
className
)}
{...props}
/>
)
}
function DrawerFooter({ className, ...props }: React.ComponentProps<'div'>) {
return (
<div
data-slot="drawer-footer"
className={cn('mt-auto flex flex-col gap-2 p-4', className)}
{...props}
/>
)
}
function DrawerTitle({
className,
...props
}: React.ComponentProps<typeof DrawerPrimitive.Title>) {
return (
<DrawerPrimitive.Title
data-slot="drawer-title"
className={cn('font-semibold', className)}
{...props}
/>
)
}
function DrawerDescription({
className,
...props
}: React.ComponentProps<typeof DrawerPrimitive.Description>) {
return (
<DrawerPrimitive.Description
data-slot="drawer-description"
className={cn('text-sm', className)}
{...props}
/>
)
}
export {
Drawer,
DrawerPortal,
DrawerOverlay,
DrawerTrigger,
DrawerClose,
DrawerContent,
DrawerHeader,
DrawerFooter,
DrawerTitle,
DrawerDescription,
}

View File

@ -61,14 +61,17 @@ function DropdownMenuGroup({
function DropdownMenuItem({
className,
inset,
variant = 'default',
...props
}: React.ComponentProps<typeof DropdownMenuPrimitive.Item> & {
inset?: boolean
variant?: 'default' | 'destructive'
}) {
return (
<DropdownMenuPrimitive.Item
data-slot="dropdown-menu-item"
data-inset={inset}
data-variant={variant}
className={cn(
"relative cursor-pointer hover:bg-main-view-fg/4 flex items-center gap-2 rounded-sm px-2 py-1 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[inset]:pl-8 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4",
className
@ -213,7 +216,7 @@ function DropdownMenuSubTrigger({
{...props}
>
{children}
<ChevronRightIcon className="ml-auto size-4" />
<ChevronRightIcon className="ml-auto size-4 text-main-view-fg/50" />
</DropdownMenuPrimitive.SubTrigger>
)
}

View File

@ -0,0 +1,949 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
'use client'
import { AnimatePresence, motion } from 'framer-motion'
import { ChevronLeftIcon, ChevronRightIcon } from 'lucide-react'
import * as React from 'react'
import {
Drawer,
DrawerClose,
DrawerContent,
DrawerFooter,
DrawerHeader,
DrawerTitle,
DrawerTrigger,
} from '@/components/ui/drawer'
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuSeparator,
DropdownMenuSub,
DropdownMenuSubContent,
DropdownMenuSubTrigger,
DropdownMenuTrigger,
} from '@/components/ui/dropdown-menu'
import { cn } from '@/lib/utils'
import { useSmallScreen } from '@/hooks/useMediaQuery'
const ANIMATION_CONFIG = {
variants: {
enter: (direction: 'forward' | 'backward') => ({
x: direction === 'forward' ? '100%' : '-100%',
opacity: 0,
}),
center: {
x: 0,
opacity: 1,
},
exit: (direction: 'forward' | 'backward') => ({
x: direction === 'forward' ? '-100%' : '100%',
opacity: 0,
}),
},
transition: {
duration: 0.3,
ease: [0.25, 0.1, 0.25, 1.0],
},
} as const
const getMobileItemStyles = (
isInsideGroup: boolean,
inset?: boolean,
variant?: string,
disabled?: boolean
) => {
return cn(
'flex cursor-pointer items-center justify-between px-4 py-4 w-full gap-4',
!isInsideGroup && 'bg-main-view-fg/50 mx-2 my-1.5 rounded-md',
isInsideGroup && 'bg-transparent py-4',
inset && 'pl-8',
variant === 'destructive' && 'text-destructive',
disabled && 'pointer-events-none opacity-50'
)
}
const DropDrawerContext = React.createContext<{ isMobile: boolean }>({
isMobile: false,
})
const useDropDrawerContext = () => {
const context = React.useContext(DropDrawerContext)
if (!context) {
throw new Error(
'DropDrawer components cannot be rendered outside the DropDrawer Context'
)
}
return context
}
const useComponentSelection = () => {
const { isMobile } = useDropDrawerContext()
const selectComponent = <T, D>(mobileComponent: T, desktopComponent: D) => {
return isMobile ? mobileComponent : desktopComponent
}
return { isMobile, selectComponent }
}
const useGroupDetection = () => {
const isInGroup = React.useCallback(
(element: HTMLElement | null): boolean => {
if (!element) return false
let parent = element.parentElement
while (parent) {
if (parent.hasAttribute('data-drop-drawer-group')) {
return true
}
parent = parent.parentElement
}
return false
},
[]
)
const useGroupState = () => {
const { isMobile } = useComponentSelection()
const itemRef = React.useRef<HTMLDivElement>(null)
const [isInsideGroup, setIsInsideGroup] = React.useState(false)
React.useEffect(() => {
if (!isMobile) return
const timer = setTimeout(() => {
if (itemRef.current) {
setIsInsideGroup(isInGroup(itemRef.current))
}
}, 0)
return () => clearTimeout(timer)
}, [isMobile])
return { itemRef, isInsideGroup }
}
return { isInGroup, useGroupState }
}
type ConditionalComponentProps<T, D> = {
children: React.ReactNode
className?: string
} & (T | D)
const ConditionalComponent = <T, D>({
mobileComponent,
desktopComponent,
children,
...props
}: {
mobileComponent: React.ComponentType<any>
desktopComponent: React.ComponentType<any>
children: React.ReactNode
} & ConditionalComponentProps<T, D>) => {
const { selectComponent } = useComponentSelection()
const Component = selectComponent(mobileComponent, desktopComponent)
return <Component {...props}>{children}</Component>
}
function DropDrawer({
children,
...props
}:
| React.ComponentProps<typeof Drawer>
| React.ComponentProps<typeof DropdownMenu>) {
const isMobile = useSmallScreen()
return (
<DropDrawerContext.Provider value={{ isMobile }}>
<ConditionalComponent
mobileComponent={Drawer}
desktopComponent={DropdownMenu}
data-slot="drop-drawer"
{...props}
>
{children}
</ConditionalComponent>
</DropDrawerContext.Provider>
)
}
function DropDrawerTrigger({
className,
children,
...props
}:
| React.ComponentProps<typeof DrawerTrigger>
| React.ComponentProps<typeof DropdownMenuTrigger>) {
return (
<ConditionalComponent
mobileComponent={DrawerTrigger}
desktopComponent={DropdownMenuTrigger}
data-slot="drop-drawer-trigger"
className={className}
{...props}
>
{children}
</ConditionalComponent>
)
}
function DropDrawerContent({
className,
children,
...props
}:
| React.ComponentProps<typeof DrawerContent>
| React.ComponentProps<typeof DropdownMenuContent>) {
const { isMobile } = useDropDrawerContext()
const [activeSubmenu, setActiveSubmenu] = React.useState<string | null>(null)
const [submenuTitle, setSubmenuTitle] = React.useState<string | null>(null)
const [submenuStack, setSubmenuStack] = React.useState<
{ id: string; title: string }[]
>([])
// Add animation direction state
const [animationDirection, setAnimationDirection] = React.useState<
'forward' | 'backward'
>('forward')
// Create a ref to store submenu content by ID
const submenuContentRef = React.useRef<Map<string, React.ReactNode[]>>(
new Map()
)
// Function to navigate to a submenu
const navigateToSubmenu = React.useCallback((id: string, title: string) => {
// Set animation direction to forward when navigating to a submenu
setAnimationDirection('forward')
setActiveSubmenu(id)
setSubmenuTitle(title)
setSubmenuStack((prev) => [...prev, { id, title }])
}, [])
// Function to go back to previous menu
const goBack = React.useCallback(() => {
// Set animation direction to backward when going back
setAnimationDirection('backward')
if (submenuStack.length <= 1) {
// If we're at the first level, go back to main menu
setActiveSubmenu(null)
setSubmenuTitle(null)
setSubmenuStack([])
} else {
// Go back to previous submenu
const newStack = [...submenuStack]
newStack.pop() // Remove current
const previous = newStack[newStack.length - 1]
setActiveSubmenu(previous.id)
setSubmenuTitle(previous.title)
setSubmenuStack(newStack)
}
}, [submenuStack])
// Function to register submenu content
const registerSubmenuContent = React.useCallback(
(id: string, content: React.ReactNode[]) => {
submenuContentRef.current.set(id, content)
},
[]
)
const extractSubmenuContent = React.useCallback(
(elements: React.ReactNode, targetId: string): React.ReactNode[] => {
const result: React.ReactNode[] = []
const findSubmenuContent = (node: React.ReactNode) => {
if (!React.isValidElement(node)) return
const element = node as React.ReactElement
const props = element.props as {
'id'?: string
'data-submenu-id'?: string
'children'?: React.ReactNode
}
if (element.type === DropDrawerSub) {
const elementId = props.id || props['data-submenu-id']
if (elementId === targetId) {
React.Children.forEach(props.children, (child) => {
if (
React.isValidElement(child) &&
child.type === DropDrawerSubContent
) {
const subContentProps = child.props as {
children?: React.ReactNode
}
React.Children.forEach(
subContentProps.children,
(contentChild) => {
result.push(contentChild)
}
)
}
})
return
}
}
if (props.children) {
React.Children.forEach(props.children, findSubmenuContent)
}
}
React.Children.forEach(elements, findSubmenuContent)
return result
},
[]
)
// Get submenu content (always extract fresh to reflect state changes)
const getSubmenuContent = React.useCallback(
(id: string) => {
// Always extract fresh content to ensure state changes are reflected
const submenuContent = extractSubmenuContent(children, id)
return submenuContent
},
[children, extractSubmenuContent]
)
if (isMobile) {
return (
<SubmenuContext.Provider
value={{
activeSubmenu,
setActiveSubmenu: (id) => {
if (id === null) {
setActiveSubmenu(null)
setSubmenuTitle(null)
setSubmenuStack([])
}
},
submenuTitle,
setSubmenuTitle,
navigateToSubmenu,
registerSubmenuContent,
}}
>
<DrawerContent
data-slot="drop-drawer-content"
className={cn(
'max-h-[90vh] w-full overflow-hidden max-w-none',
className
)}
{...props}
>
{activeSubmenu ? (
<>
<DrawerHeader>
<div className="flex items-center gap-2">
<button
onClick={goBack}
className="hover:bg-muted/50 rounded-full "
>
<ChevronLeftIcon className="h-5 w-5 text-main-view-fg/50" />
</button>
<DrawerTitle className="text-main-view-fg/80 text-sm">
{submenuTitle || 'Submenu'}
</DrawerTitle>
</div>
</DrawerHeader>
<div className="flex-1 relative overflow-hidden max-h-[70vh]">
{/* Use AnimatePresence to handle exit animations */}
<AnimatePresence
initial={false}
mode="wait"
custom={animationDirection}
>
<motion.div
key={activeSubmenu || 'main'}
custom={animationDirection}
variants={ANIMATION_CONFIG.variants}
initial="enter"
animate="center"
exit="exit"
transition={ANIMATION_CONFIG.transition as any}
className="pb-6 space-y-1.5 w-full h-full overflow-hidden"
>
{activeSubmenu
? getSubmenuContent(activeSubmenu)
: children}
</motion.div>
</AnimatePresence>
</div>
</>
) : (
<>
<DrawerHeader className="sr-only">
<DrawerTitle>Menu</DrawerTitle>
</DrawerHeader>
<div className="overflow-hidden max-h-[70vh]">
<AnimatePresence
initial={false}
mode="wait"
custom={animationDirection}
>
<motion.div
key="main-menu"
custom={animationDirection}
variants={ANIMATION_CONFIG.variants}
initial="enter"
animate="center"
exit="exit"
transition={ANIMATION_CONFIG.transition as any}
className="pb-6 space-y-1.5 w-full overflow-hidden"
>
{children}
</motion.div>
</AnimatePresence>
</div>
</>
)}
</DrawerContent>
</SubmenuContext.Provider>
)
}
return (
<SubmenuContext.Provider
value={{
activeSubmenu,
setActiveSubmenu,
submenuTitle,
setSubmenuTitle,
navigateToSubmenu,
registerSubmenuContent,
}}
>
<DropdownMenuContent
data-slot="drop-drawer-content"
sideOffset={4}
className={cn(
'max-h-[var(--radix-dropdown-menu-content-available-height)] min-w-[220px] overflow-hidden',
className
)}
{...props}
>
{children}
</DropdownMenuContent>
</SubmenuContext.Provider>
)
}
function DropDrawerItem({
className,
children,
onSelect,
onClick,
icon,
variant = 'default',
inset,
disabled,
...props
}: React.ComponentProps<typeof DropdownMenuItem> & {
icon?: React.ReactNode
}) {
const { isMobile } = useComponentSelection()
const { useGroupState } = useGroupDetection()
const { itemRef, isInsideGroup } = useGroupState()
if (isMobile) {
const handleClick = (e: React.MouseEvent<HTMLDivElement>) => {
if (disabled) return
// If this item only has an icon (like a switch) and no other interactive content,
// don't handle clicks on the main area - let the icon handle everything
if (icon && !onClick && !onSelect) {
return
}
// Check if the click came from the icon area (where the Switch is)
const target = e.target as HTMLElement
const iconContainer = (e.currentTarget as HTMLElement).querySelector(
'[data-icon-container]'
)
if (iconContainer && iconContainer.contains(target)) {
// Don't handle the click if it came from the icon area
return
}
if (onClick) onClick(e)
if (onSelect) onSelect(e as unknown as Event)
}
// Only wrap in DrawerClose if it's not a submenu item
const content = (
<div
ref={itemRef}
data-slot="drop-drawer-item"
data-variant={variant}
data-inset={inset}
data-disabled={disabled}
className={cn(
getMobileItemStyles(isInsideGroup, inset, variant, disabled),
className
)}
onClick={handleClick}
aria-disabled={disabled}
{...props}
>
<div className="flex items-center gap-2">{children}</div>
{icon && (
<div className="flex-shrink-0" data-icon-container>
{icon}
</div>
)}
</div>
)
// Check if this is inside a submenu
const isInSubmenu =
(props as Record<string, unknown>)['data-parent-submenu-id'] ||
(props as Record<string, unknown>)['data-parent-submenu']
if (isInSubmenu) {
return content
}
return <DrawerClose asChild>{content}</DrawerClose>
}
return (
<DropdownMenuItem
data-slot="drop-drawer-item"
data-variant={variant}
data-inset={inset}
className={className}
onSelect={onSelect}
onClick={onClick as React.MouseEventHandler<HTMLDivElement>}
variant={variant}
inset={inset}
disabled={disabled}
{...props}
>
<div className="flex w-full items-start justify-between gap-4">
<div>{children}</div>
{icon && <div>{icon}</div>}
</div>
</DropdownMenuItem>
)
}
function DropDrawerSeparator({
className,
...props
}: React.ComponentProps<typeof DropdownMenuSeparator>) {
const { isMobile } = useComponentSelection()
if (isMobile) {
return null
}
return (
<DropdownMenuSeparator
data-slot="drop-drawer-separator"
className={className}
{...props}
/>
)
}
function DropDrawerLabel({
className,
children,
...props
}:
| React.ComponentProps<typeof DropdownMenuLabel>
| React.ComponentProps<typeof DrawerTitle>) {
const { isMobile } = useComponentSelection()
if (isMobile) {
return (
<DrawerHeader className="p-0">
<DrawerTitle
data-slot="drop-drawer-label"
className={cn(
'px-4 py-2 text-sm font-medium text-main-view-fg/60',
className
)}
{...props}
>
{children}
</DrawerTitle>
</DrawerHeader>
)
}
return (
<DropdownMenuLabel
data-slot="drop-drawer-label"
className={className}
{...props}
>
{children}
</DropdownMenuLabel>
)
}
function DropDrawerFooter({
className,
children,
...props
}: React.ComponentProps<typeof DrawerFooter> | React.ComponentProps<'div'>) {
const { isMobile } = useDropDrawerContext()
if (isMobile) {
return (
<DrawerFooter
data-slot="drop-drawer-footer"
className={cn('p-4', className)}
{...props}
>
{children}
</DrawerFooter>
)
}
// No direct equivalent in DropdownMenu, so we'll just render a div
return (
<div
data-slot="drop-drawer-footer"
className={cn('p-2', className)}
{...props}
>
{children}
</div>
)
}
function DropDrawerGroup({
className,
children,
...props
}: React.ComponentProps<'div'> & {
children: React.ReactNode
}) {
const { isMobile } = useDropDrawerContext()
// Add separators between children on mobile
const childrenWithSeparators = React.useMemo(() => {
if (!isMobile) return children
const childArray = React.Children.toArray(children)
// Filter out any existing separators
const filteredChildren = childArray.filter(
(child) =>
React.isValidElement(child) && child.type !== DropDrawerSeparator
)
// Add separators between items
return filteredChildren.flatMap((child, index) => {
if (index === filteredChildren.length - 1) return [child]
return [
child,
<div
key={`separator-${index}`}
className="bg-border h-px"
aria-hidden="true"
/>,
]
})
}, [children, isMobile])
if (isMobile) {
return (
<div
data-drop-drawer-group
data-slot="drop-drawer-group"
role="group"
className={cn(
'bg-main-view-fg/2 border border-main-view-fg/4 mx-2 my-3 overflow-hidden rounded-xl',
className
)}
{...props}
>
{childrenWithSeparators}
</div>
)
}
// On desktop, use a div with proper role and attributes
return (
<div
data-drop-drawer-group
data-slot="drop-drawer-group"
role="group"
className={className}
{...props}
>
{children}
</div>
)
}
// Context for managing submenu state on mobile
interface SubmenuContextType {
activeSubmenu: string | null
setActiveSubmenu: (id: string | null) => void
submenuTitle: string | null
setSubmenuTitle: (title: string | null) => void
navigateToSubmenu?: (id: string, title: string) => void
registerSubmenuContent?: (id: string, content: React.ReactNode[]) => void
}
const SubmenuContext = React.createContext<SubmenuContextType>({
activeSubmenu: null,
setActiveSubmenu: () => {},
submenuTitle: null,
setSubmenuTitle: () => {},
navigateToSubmenu: undefined,
registerSubmenuContent: undefined,
})
// Submenu components
// Counter for generating simple numeric IDs
let submenuIdCounter = 0
function DropDrawerSub({
children,
id,
title,
...props
}: React.ComponentProps<typeof DropdownMenuSub> & {
id?: string
title?: string
}) {
const { isMobile } = useDropDrawerContext()
const { registerSubmenuContent } = React.useContext(SubmenuContext)
// Generate a simple numeric ID instead of using React.useId()
const [generatedId] = React.useState(() => `submenu-${submenuIdCounter++}`)
const submenuId = id || generatedId
// Extract submenu content to register with parent
React.useEffect(() => {
if (!registerSubmenuContent) return
// Find the SubContent within this Sub
const contentItems: React.ReactNode[] = []
React.Children.forEach(children, (child) => {
if (React.isValidElement(child) && child.type === DropDrawerSubContent) {
// Add all children of the SubContent to the result
React.Children.forEach(
(child.props as { children?: React.ReactNode }).children,
(contentChild) => {
contentItems.push(contentChild)
}
)
}
})
// Register the content with the parent
if (contentItems.length > 0) {
registerSubmenuContent(submenuId, contentItems)
}
}, [children, registerSubmenuContent, submenuId])
if (isMobile) {
// For mobile, we'll use the context to manage submenu state
// Process children to pass the submenu ID to the trigger and content
const processedChildren = React.Children.map(children, (child) => {
if (!React.isValidElement(child)) return child
if (child.type === DropDrawerSubTrigger) {
return React.cloneElement(
child as React.ReactElement,
{
...(child.props as object),
'data-parent-submenu-id': submenuId,
'data-submenu-id': submenuId,
// Use only data attributes, not custom props
'data-parent-submenu': submenuId,
'data-submenu-title': title,
} as React.HTMLAttributes<HTMLElement>
)
}
if (child.type === DropDrawerSubContent) {
return React.cloneElement(
child as React.ReactElement,
{
...(child.props as object),
'data-parent-submenu-id': submenuId,
'data-submenu-id': submenuId,
// Use only data attributes, not custom props
'data-parent-submenu': submenuId,
} as React.HTMLAttributes<HTMLElement>
)
}
return child
})
return (
<div
data-slot="drop-drawer-sub"
data-submenu-id={submenuId}
id={submenuId}
>
{processedChildren}
</div>
)
}
// For desktop, use the standard DropdownMenuSub
return <DropdownMenuSub {...props}>{children}</DropdownMenuSub>
}
function DropDrawerSubTrigger({
className,
inset,
children,
...props
}: React.ComponentProps<typeof DropdownMenuSubTrigger> & {
icon?: React.ReactNode
}) {
const { isMobile } = useComponentSelection()
const { navigateToSubmenu } = React.useContext(SubmenuContext)
const { useGroupState } = useGroupDetection()
const { itemRef, isInsideGroup } = useGroupState()
if (isMobile) {
// Find the parent submenu ID
const handleClick = (e: React.MouseEvent) => {
e.preventDefault()
e.stopPropagation()
// Get the closest parent with data-submenu-id attribute
const element = e.currentTarget as HTMLElement
let submenuId: string | null = null
// First check if the element itself has the data attribute
if (element.closest('[data-submenu-id]')) {
const closestElement = element.closest('[data-submenu-id]')
const id = closestElement?.getAttribute('data-submenu-id')
if (id) {
submenuId = id
}
}
// If not found, try props
if (!submenuId) {
submenuId =
((props as Record<string, unknown>)[
'data-parent-submenu-id'
] as string) ||
((props as Record<string, unknown>)['data-parent-submenu'] as string)
}
if (!submenuId) {
return
}
// Get the title - first try data attribute, then children, then fallback
const dataTitle = (props as Record<string, unknown>)[
'data-submenu-title'
] as string
const title =
dataTitle || (typeof children === 'string' ? children : 'Submenu')
// Navigate to the submenu
if (navigateToSubmenu) {
navigateToSubmenu(submenuId, title)
}
}
// Combine onClick handlers
const combinedOnClick = (e: React.MouseEvent) => {
// Call the original onClick if provided
const typedProps = props as Record<string, unknown>
if (typedProps.onClick) {
const originalOnClick =
typedProps.onClick as React.MouseEventHandler<HTMLDivElement>
originalOnClick(e as React.MouseEvent<HTMLDivElement>)
}
// Call our navigation handler
handleClick(e)
}
// Remove onClick from props to avoid duplicate handlers
const { ...restProps } = props as Record<string, unknown>
// Don't wrap in DrawerClose for submenu triggers
return (
<div
ref={itemRef}
data-slot="drop-drawer-sub-trigger"
data-inset={inset}
className={cn(getMobileItemStyles(isInsideGroup, inset), className)}
onClick={combinedOnClick}
{...restProps}
>
<div className="flex items-center gap-2 w-full">{children}</div>
<ChevronRightIcon className="h-5 w-5 text-main-view-fg/50 " />
</div>
)
}
return (
<DropdownMenuSubTrigger
data-slot="drop-drawer-sub-trigger"
data-inset={inset}
className={className}
inset={inset}
{...props}
>
{children}
</DropdownMenuSubTrigger>
)
}
function DropDrawerSubContent({
className,
sideOffset = 4,
children,
...props
}: React.ComponentProps<typeof DropdownMenuSubContent>) {
const { isMobile } = useDropDrawerContext()
if (isMobile) {
// For mobile, we don't render the content directly
// It will be rendered by the DropDrawerContent component when active
return null
}
return (
<DropdownMenuSubContent
data-slot="drop-drawer-sub-content"
sideOffset={sideOffset}
className={className}
{...props}
>
{children}
</DropdownMenuSubContent>
)
}
export {
DropDrawer,
DropDrawerContent,
DropDrawerFooter,
DropDrawerGroup,
DropDrawerItem,
DropDrawerLabel,
DropDrawerSeparator,
DropDrawerSub,
DropDrawerSubContent,
DropDrawerSubTrigger,
DropDrawerTrigger,
}

View File

@ -0,0 +1,42 @@
import * as React from 'react'
import * as RadioGroupPrimitive from '@radix-ui/react-radio-group'
import { CircleIcon } from 'lucide-react'
import { cn } from '@/lib/utils'
const RadioGroup = React.forwardRef<
React.ElementRef<typeof RadioGroupPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof RadioGroupPrimitive.Root>
>(({ className, ...props }, ref) => {
return (
<RadioGroupPrimitive.Root
className={cn('grid gap-2', className)}
{...props}
ref={ref}
/>
)
})
RadioGroup.displayName = RadioGroupPrimitive.Root.displayName
const RadioGroupItem = React.forwardRef<
React.ElementRef<typeof RadioGroupPrimitive.Item>,
React.ComponentPropsWithoutRef<typeof RadioGroupPrimitive.Item>
>(({ className, ...props }, ref) => {
return (
<RadioGroupPrimitive.Item
ref={ref}
className={cn(
'aspect-square size-4 rounded-full border border-main-view-fg/20 text-main-view-fg ring-offset-main-view focus:outline-none focus-visible:ring-2 focus-visible:ring-main-view-fg/50 focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50',
className
)}
{...props}
>
<RadioGroupPrimitive.Indicator className="flex items-center justify-center">
<CircleIcon className="size-2.5 fill-current text-current" />
</RadioGroupPrimitive.Indicator>
</RadioGroupPrimitive.Item>
)
})
RadioGroupItem.displayName = RadioGroupPrimitive.Item.displayName
export { RadioGroup, RadioGroupItem }

View File

@ -19,5 +19,6 @@ export const localStorageKey = {
mcpGlobalPermissions: 'mcp-global-permissions',
lastUsedModel: 'last-used-model',
lastUsedAssistant: 'last-used-assistant',
favoriteModels: 'favorite-models',
setupCompleted: 'setup-completed',
}

View File

@ -483,7 +483,9 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
initialMessage={initialMessage}
onOpenChange={(isOpen) => {
setDropdownToolsAvailable(isOpen)
setTooltipToolsAvailable(false)
if (isOpen) {
setTooltipToolsAvailable(false)
}
}}
>
{(isOpen, toolsCount) => {

View File

@ -6,10 +6,8 @@ import {
import { Progress } from '@/components/ui/progress'
import { useDownloadStore } from '@/hooks/useDownloadStore'
import { useLeftPanel } from '@/hooks/useLeftPanel'
import { useModelProvider } from '@/hooks/useModelProvider'
import { useAppUpdater } from '@/hooks/useAppUpdater'
import { abortDownload } from '@/services/models'
import { getProviders } from '@/services/providers'
import { DownloadEvent, DownloadState, events, AppEvent } from '@janhq/core'
import { IconDownload, IconX } from '@tabler/icons-react'
import { useCallback, useEffect, useMemo, useState } from 'react'
@ -18,7 +16,6 @@ import { useTranslation } from '@/i18n/react-i18next-compat'
export function DownloadManagement() {
const { t } = useTranslation()
const { setProviders } = useModelProvider()
const { open: isLeftPanelOpen } = useLeftPanel()
const [isPopoverOpen, setIsPopoverOpen] = useState(false)
const {
@ -185,7 +182,6 @@ export function DownloadManagement() {
console.debug('onFileDownloadSuccess', state)
removeDownload(state.modelId)
removeLocalDownloadingModel(state.modelId)
getProviders().then(setProviders)
toast.success(t('common:toast.downloadComplete.title'), {
id: 'download-complete',
description: t('common:toast.downloadComplete.description', {
@ -193,7 +189,7 @@ export function DownloadManagement() {
}),
})
},
[removeDownload, removeLocalDownloadingModel, setProviders, t]
[removeDownload, removeLocalDownloadingModel, t]
)
useEffect(() => {

View File

@ -17,6 +17,8 @@ import ProvidersAvatar from '@/containers/ProvidersAvatar'
import { Fzf } from 'fzf'
import { localStorageKey } from '@/constants/localStorage'
import { useTranslation } from '@/i18n/react-i18next-compat'
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
import { predefinedProviders } from '@/consts/providers'
type DropdownModelProviderProps = {
model?: ThreadModel
@ -69,42 +71,38 @@ const DropdownModelProvider = ({
const { updateCurrentThreadModel } = useThreads()
const navigate = useNavigate()
const { t } = useTranslation()
const { favoriteModels } = useFavoriteModel()
// Search state
const [open, setOpen] = useState(false)
const [searchValue, setSearchValue] = useState('')
const searchInputRef = useRef<HTMLInputElement>(null)
// Helper function to check if a model exists in providers
const checkModelExists = useCallback((providerName: string, modelId: string) => {
const provider = providers.find(
(p) => p.provider === providerName && p.active
)
return provider?.models.find((m) => m.id === modelId)
}, [providers])
// Initialize model provider only once
useEffect(() => {
// Auto select model when existing thread is passed
if (model) {
selectModelProvider(model?.provider as string, model?.id as string)
if (!checkModelExists(model.provider, model.id)) {
selectModelProvider('', '')
}
} else if (useLastUsedModel) {
// Try to use last used model only when explicitly requested (for new chat)
const lastUsed = getLastUsedModel()
if (lastUsed) {
// Verify the last used model still exists
const provider = providers.find(
(p) => p.provider === lastUsed.provider && p.active
)
const modelExists = provider?.models.find(
(m) => m.id === lastUsed.model
)
if (provider && modelExists) {
selectModelProvider(lastUsed.provider, lastUsed.model)
} else {
// Fallback to default model if last used model no longer exists
selectModelProvider('llamacpp', 'llama3.2:3b')
}
if (lastUsed && checkModelExists(lastUsed.provider, lastUsed.model)) {
selectModelProvider(lastUsed.provider, lastUsed.model)
} else {
// default model, we should add from setting
selectModelProvider('llamacpp', 'llama3.2:3b')
// Fallback to default model if last used model no longer exists
selectModelProvider('', '')
}
} else {
// default model for non-new-chat contexts
selectModelProvider('llamacpp', 'llama3.2:3b')
}
}, [
model,
@ -112,6 +110,7 @@ const DropdownModelProvider = ({
updateCurrentThreadModel,
providers,
useLastUsedModel,
checkModelExists,
])
// Update display model when selection changes
@ -151,9 +150,15 @@ const DropdownModelProvider = ({
provider.models.forEach((modelItem) => {
// Skip models that require API key but don't have one (except llamacpp)
if (provider.provider !== 'llamacpp' && !provider.api_key?.length) {
if (
provider &&
predefinedProviders.some((e) =>
e.provider.includes(provider.provider)
) &&
provider.provider !== 'llamacpp' &&
!provider.api_key?.length
)
return
}
const capabilities = modelItem.capabilities || []
const capabilitiesString = capabilities.join(' ')
@ -182,6 +187,13 @@ const DropdownModelProvider = ({
})
}, [searchableItems])
// Get favorite models that are currently available
const favoriteItems = useMemo(() => {
return searchableItems.filter((item) =>
favoriteModels.some((fav) => fav.id === item.model.id)
)
}, [searchableItems, favoriteModels])
// Filter models based on search value
const filteredItems = useMemo(() => {
if (!searchValue) return searchableItems
@ -202,7 +214,7 @@ const DropdownModelProvider = ({
})
}, [searchableItems, searchValue, fzfInstance])
// Group filtered items by provider
// Group filtered items by provider, excluding favorites when not searching
const groupedItems = useMemo(() => {
const groups: Record<string, SearchableModel[]> = {}
@ -221,11 +233,16 @@ const DropdownModelProvider = ({
if (!groups[providerKey]) {
groups[providerKey] = []
}
// When not searching, exclude favorite models from regular provider sections
const isFavorite = favoriteModels.some((fav) => fav.id === item.model.id)
if (!searchValue && isFavorite) return // Skip adding this item to regular provider section
groups[providerKey].push(item)
})
return groups
}, [filteredItems, providers, searchValue])
}, [filteredItems, providers, searchValue, favoriteModels])
const handleSelect = useCallback(
(searchableModel: SearchableModel) => {
@ -330,6 +347,64 @@ const DropdownModelProvider = ({
</div>
) : (
<div className="py-1">
{/* Favorites section - only show when not searching */}
{!searchValue && favoriteItems.length > 0 && (
<div className="bg-main-view-fg/2 backdrop-blur-2xl rounded-sm my-1.5 mx-1.5">
{/* Favorites header */}
<div className="flex items-center gap-1.5 px-2 py-1">
<span className="text-sm font-medium text-main-view-fg/80">
{t('common:favorites')}
</span>
</div>
{/* Favorite models */}
{favoriteItems.map((searchableModel) => {
const isSelected =
selectedModel?.id === searchableModel.model.id &&
selectedProvider === searchableModel.provider.provider
const capabilities =
searchableModel.model.capabilities || []
return (
<div
key={`fav-${searchableModel.value}`}
title={searchableModel.model.id}
onClick={() => handleSelect(searchableModel)}
className={cn(
'mx-1 mb-1 px-2 py-1.5 rounded-sm cursor-pointer flex items-center gap-2 transition-all duration-200',
'hover:bg-main-view-fg/4',
isSelected &&
'bg-main-view-fg/8 hover:bg-main-view-fg/8'
)}
>
<div className="flex items-center gap-1 flex-1 min-w-0">
<div className="shrink-0 -ml-1">
<ProvidersAvatar
provider={searchableModel.provider}
/>
</div>
<span className="truncate text-main-view-fg/80 text-sm">
{searchableModel.model.id}
</span>
<div className="flex-1"></div>
{capabilities.length > 0 && (
<div className="flex-shrink-0 -mr-1.5">
<Capabilities capabilities={capabilities} />
</div>
)}
</div>
</div>
)
})}
</div>
)}
{/* Divider between favorites and regular providers */}
{favoriteItems.length > 0 && (
<div className="border-b border-1 border-main-view-fg/8 mx-2"></div>
)}
{/* Regular provider sections */}
{Object.entries(groupedItems).map(([providerKey, models]) => {
const providerInfo = providers.find(
(p) => p.provider === providerKey
@ -340,7 +415,7 @@ const DropdownModelProvider = ({
return (
<div
key={providerKey}
className="bg-main-view-fg/4 backdrop-blur-2xl first:mt-0 rounded-sm my-1.5 mx-1.5 first:mb-0"
className="bg-main-view-fg/2 backdrop-blur-2xl first:mt-0 rounded-sm my-1.5 mx-1.5 first:mb-0"
>
{/* Provider header */}
<div className="flex items-center justify-between px-2 py-1">
@ -384,15 +459,20 @@ const DropdownModelProvider = ({
return (
<div
key={searchableModel.value}
title={searchableModel.model.id}
onClick={() => handleSelect(searchableModel)}
className={cn(
'mx-1 mb-1 px-2 py-1.5 rounded-sm cursor-pointer flex items-center gap-2 transition-all duration-200',
'hover:bg-main-view-fg/10',
isSelected && 'bg-main-view-fg/15'
'hover:bg-main-view-fg/4',
isSelected &&
'bg-main-view-fg/8 hover:bg-main-view-fg/8'
)}
>
<div className="flex items-center gap-2 flex-1 min-w-0">
<span className="truncate text-main-view-fg/80 text-sm">
<span
className="truncate text-main-view-fg/80 text-sm"
title={searchableModel.model.id}
>
{searchableModel.model.id}
</span>

View File

@ -1,12 +1,18 @@
import { useEffect, useState } from 'react'
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from '@/components/ui/dropdown-menu'
DropDrawer,
DropDrawerContent,
DropDrawerItem,
DropDrawerSub,
DropDrawerLabel,
DropDrawerSubContent,
DropDrawerSeparator,
DropDrawerSubTrigger,
DropDrawerTrigger,
DropDrawerGroup,
} from '@/components/ui/dropdrawer'
import { Switch } from '@/components/ui/switch'
import { useThreads } from '@/hooks/useThreads'
@ -15,6 +21,7 @@ import { useToolAvailable } from '@/hooks/useToolAvailable'
import React from 'react'
import { useAppState } from '@/hooks/useAppState'
import { useTranslation } from '@/i18n/react-i18next-compat'
import { cn } from '@/lib/utils'
interface DropdownToolsAvailableProps {
children: (isOpen: boolean, toolsCount: number) => React.ReactNode
@ -82,6 +89,23 @@ export default function DropdownToolsAvailable({
return false
}
const handleDisableAllServerTools = (
serverName: string,
disable: boolean
) => {
const allToolsByServer = getToolsByServer()
const serverTools = allToolsByServer[serverName] || []
serverTools.forEach((tool) => {
handleToolToggle(tool.name, !disable)
})
}
const areAllServerToolsDisabled = (serverName: string): boolean => {
const allToolsByServer = getToolsByServer()
const serverTools = allToolsByServer[serverName] || []
return serverTools.every((tool) => !isToolChecked(tool.name))
}
const getEnabledToolsCount = (): number => {
const disabledTools = initialMessage
? getDefaultDisabledTools()
@ -91,69 +115,153 @@ export default function DropdownToolsAvailable({
return tools.filter((tool) => !disabledTools.includes(tool.name)).length
}
const getToolsByServer = () => {
const toolsByServer = tools.reduce(
(acc, tool) => {
if (!acc[tool.server]) {
acc[tool.server] = []
}
acc[tool.server].push(tool)
return acc
},
{} as Record<string, typeof tools>
)
return toolsByServer
}
const renderTrigger = () => children(isOpen, getEnabledToolsCount())
if (tools.length === 0) {
return (
<DropdownMenu onOpenChange={handleOpenChange}>
<DropdownMenuTrigger asChild>{renderTrigger()}</DropdownMenuTrigger>
<DropdownMenuContent align="start" className="max-w-64">
<DropdownMenuItem disabled>{t('common:noToolsAvailable')}</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
<DropDrawer onOpenChange={handleOpenChange}>
<DropDrawerTrigger asChild>{renderTrigger()}</DropDrawerTrigger>
<DropDrawerContent align="start" className="max-w-64">
<DropDrawerItem disabled>
{t('common:noToolsAvailable')}
</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
}
return (
<DropdownMenu onOpenChange={handleOpenChange}>
<DropdownMenuTrigger asChild>{renderTrigger()}</DropdownMenuTrigger>
const toolsByServer = getToolsByServer()
<DropdownMenuContent
return (
<DropDrawer onOpenChange={handleOpenChange}>
<DropDrawerTrigger asChild>{renderTrigger()}</DropDrawerTrigger>
<DropDrawerContent
side="top"
align="start"
className="max-w-64 backdrop-blur-xl bg-main-view"
className="bg-main-view !overflow-hidden"
onClick={(e) => e.stopPropagation()}
>
<DropdownMenuLabel className="flex items-center gap-2 sticky -top-1 z-10 px-4 pl-2 py-2 ">
<DropDrawerLabel className="flex items-center gap-2 sticky -top-1 z-10 px-4 pl-2 py-1">
Available Tools
</DropdownMenuLabel>
<DropdownMenuSeparator />
</DropDrawerLabel>
<DropDrawerSeparator />
<div className="max-h-64 overflow-y-auto">
{tools.map((tool) => {
const isChecked = isToolChecked(tool.name)
return (
<div
key={tool.name}
className="py-2 hover:bg-main-view-fg/5 hover:backdrop-blur-2xl rounded-sm px-2 mx-auto w-full"
<DropDrawerGroup>
{Object.entries(toolsByServer).map(([serverName, serverTools]) => (
<DropDrawerSub
id={`server-${serverName}`}
key={serverName}
title={serverName}
>
<div className="flex items-start justify-center gap-3">
<div className="flex items-start justify-between gap-4 w-full">
<div className="overflow-hidden w-full flex flex-col ">
<div className="truncate">
<span className="text-sm font-medium" title={tool.name}>
{tool.name}
</span>
</div>
{tool.description && (
<p className="text-xs text-main-view-fg/70 mt-1 line-clamp-2">
{tool.description}
</p>
)}
</div>
<div className="shrink-0 mx-auto">
<Switch
checked={isChecked}
onCheckedChange={(checked) =>
handleToolToggle(tool.name, checked)
}
/>
</div>
<DropDrawerSubTrigger className="py-2 hover:bg-main-view-fg/5 hover:backdrop-blur-2xl rounded-sm px-2 mx-auto w-full">
<div className="flex items-center justify-between w-full">
<span className="text-sm text-main-view-fg/80">
{serverName}
</span>
<span className="text-xs text-main-view-fg/50 inline-flex items-center mr-1 border border-main-view-fg/20 px-1 rounded-sm">
{
serverTools.filter((tool) => isToolChecked(tool.name))
.length
}
</span>
</div>
</div>
</div>
)
})}
</DropDrawerSubTrigger>
<DropDrawerSubContent className="max-w-64 max-h-70 w-full overflow-hidden">
<DropDrawerGroup>
{serverTools.length > 1 && (
<div className="sticky top-0 z-10 bg-main-view border-b border-main-view-fg/10 px-4 md:px-2 pr-2 py-1.5 flex items-center justify-between">
<span className="text-xs font-medium text-main-view-fg/70">
Disable All Tools
</span>
<div
className={cn(
'flex items-center gap-2',
serverTools.length > 5
? 'mr-3 md:mr-1.5'
: 'mr-2 md:mr-0'
)}
>
<Switch
checked={!areAllServerToolsDisabled(serverName)}
onCheckedChange={(checked) =>
handleDisableAllServerTools(serverName, !checked)
}
/>
</div>
</div>
)}
<div className="max-h-56 overflow-y-auto">
{serverTools.map((tool) => {
const isChecked = isToolChecked(tool.name)
return (
<DropDrawerItem
onClick={(e) => {
handleToolToggle(tool.name, !isChecked)
e.preventDefault()
}}
onSelect={(e) => {
handleToolToggle(tool.name, !isChecked)
e.preventDefault()
}}
key={tool.name}
className="mt-1 first:mt-0 py-1.5"
icon={
<Switch
checked={isChecked}
onCheckedChange={(checked) => {
console.log('checked', checked)
handleToolToggle(tool.name, checked)
}}
onClick={(e) => {
e.stopPropagation()
}}
/>
}
>
<div className="overflow-hidden flex flex-col items-start ">
<div className="truncate">
<span
className="text-sm font-medium text-main-view-fg"
title={tool.name}
>
{tool.name}
</span>
</div>
{tool.description && (
<p
className="text-xs text-main-view-fg/70 mt-1 line-clamp-1"
title={tool.description}
>
{tool.description}
</p>
)}
</div>
</DropDrawerItem>
)
})}
</div>
</DropDrawerGroup>
</DropDrawerSubContent>
</DropDrawerSub>
))}
</DropDrawerGroup>
</div>
</DropdownMenuContent>
</DropdownMenu>
</DropDrawerContent>
</DropDrawer>
)
}

View File

@ -0,0 +1,25 @@
import { IconStar, IconStarFilled } from '@tabler/icons-react'
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
interface FavoriteModelActionProps {
model: Model
}
export function FavoriteModelAction({ model }: FavoriteModelActionProps) {
const { isFavorite, toggleFavorite } = useFavoriteModel()
const isModelFavorite = isFavorite(model.id)
return (
<div
aria-label="Toggle favorite"
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={() => toggleFavorite(model)}
>
{isModelFavorite ? (
<IconStarFilled size={18} className="text-main-view-fg" />
) : (
<IconStar size={18} className="text-main-view-fg/50" />
)}
</div>
)
}

View File

@ -21,15 +21,15 @@ export default function GlobalError({ error }: GlobalErrorProps) {
>
<path
d="M6 8H6.01M6 16H6.01M6 12H18C20.2091 12 22 10.2091 22 8C22 5.79086 20.2091 4 18 4H6C3.79086 4 2 5.79086 2 8C2 10.2091 3.79086 12 6 12ZM6 12C3.79086 12 2 13.7909 2 16C2 18.2091 3.79086 20 6 20H14"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
></path>
<path
d="M17 16L22 21M22 16L17 21"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
></path>
</svg>
</div>

View File

@ -8,7 +8,13 @@ import {
} from '@/components/ui/dialog'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { IconPlus, IconTrash, IconGripVertical } from '@tabler/icons-react'
import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group'
import {
IconPlus,
IconTrash,
IconGripVertical,
IconCodeDots,
} from '@tabler/icons-react'
import { MCPServerConfig } from '@/hooks/useMCPServers'
import { useTranslation } from '@/i18n/react-i18next-compat'
import {
@ -27,6 +33,8 @@ import {
} from '@dnd-kit/sortable'
import { CSS } from '@dnd-kit/utilities'
import { cn } from '@/lib/utils'
import CodeEditor from '@uiw/react-textarea-code-editor'
import '@uiw/react-textarea-code-editor/dist.css'
// Sortable argument item component
function SortableArgItem({
@ -114,13 +122,34 @@ export default function AddEditMCPServer({
const [args, setArgs] = useState<string[]>([''])
const [envKeys, setEnvKeys] = useState<string[]>([''])
const [envValues, setEnvValues] = useState<string[]>([''])
const [transportType, setTransportType] = useState<'stdio' | 'http' | 'sse'>(
'stdio'
)
const [url, setUrl] = useState('')
const [headerKeys, setHeaderKeys] = useState<string[]>([''])
const [headerValues, setHeaderValues] = useState<string[]>([''])
const [timeout, setTimeout] = useState('')
const [isToggled, setIsToggled] = useState(false)
const [jsonContent, setJsonContent] = useState('')
const [error, setError] = useState<string | null>(null)
// Reset form when modal opens/closes or editing key changes
useEffect(() => {
if (open && editingKey && initialData) {
setServerName(editingKey)
setCommand(initialData.command)
setCommand(initialData.command || '')
setUrl(initialData.url || '')
setTimeout(initialData.timeout ? initialData.timeout.toString() : '')
setArgs(initialData.args?.length > 0 ? initialData.args : [''])
setTransportType(initialData?.type || 'stdio')
// Initialize JSON content for toggle mode
try {
const jsonData = { [editingKey]: initialData }
setJsonContent(JSON.stringify(jsonData, null, 2))
} catch {
setJsonContent('')
}
if (initialData.env) {
// Convert env object to arrays of keys and values
@ -130,6 +159,17 @@ export default function AddEditMCPServer({
setEnvKeys(keys.length > 0 ? keys : [''])
setEnvValues(values.length > 0 ? values : [''])
}
if (initialData.headers) {
// Convert headers object to arrays of keys and values
const headerKeysList = Object.keys(initialData.headers)
const headerValuesList = headerKeysList.map(
(key) => initialData.headers![key]
)
setHeaderKeys(headerKeysList.length > 0 ? headerKeysList : [''])
setHeaderValues(headerValuesList.length > 0 ? headerValuesList : [''])
}
} else if (open) {
// Add mode - reset form
resetForm()
@ -139,9 +179,17 @@ export default function AddEditMCPServer({
const resetForm = () => {
setServerName('')
setCommand('')
setUrl('')
setTimeout('')
setArgs([''])
setEnvKeys([''])
setEnvValues([''])
setHeaderKeys([''])
setHeaderValues([''])
setTransportType('stdio')
setIsToggled(false)
setJsonContent('')
setError(null)
}
const handleAddArg = () => {
@ -201,7 +249,57 @@ export default function AddEditMCPServer({
setEnvValues(newValues)
}
const handleAddHeader = () => {
setHeaderKeys([...headerKeys, ''])
setHeaderValues([...headerValues, ''])
}
const handleRemoveHeader = (index: number) => {
const newKeys = [...headerKeys]
const newValues = [...headerValues]
newKeys.splice(index, 1)
newValues.splice(index, 1)
setHeaderKeys(newKeys.length > 0 ? newKeys : [''])
setHeaderValues(newValues.length > 0 ? newValues : [''])
}
const handleHeaderKeyChange = (index: number, value: string) => {
const newKeys = [...headerKeys]
newKeys[index] = value
setHeaderKeys(newKeys)
}
const handleHeaderValueChange = (index: number, value: string) => {
const newValues = [...headerValues]
newValues[index] = value
setHeaderValues(newValues)
}
const handleSave = () => {
// Handle JSON mode
if (isToggled) {
try {
const parsedData = JSON.parse(jsonContent)
// Validate that it's an object with server configurations
if (typeof parsedData !== 'object' || parsedData === null) {
setError(t('mcp-servers:editJson.errorFormat'))
return
}
// For each server in the JSON, call onSave
Object.entries(parsedData).forEach(([serverName, config]) => {
onSave(serverName.trim(), config as MCPServerConfig)
})
onOpenChange(false)
resetForm()
setError(null)
return
} catch {
setError(t('mcp-servers:editJson.errorFormat'))
return
}
}
// Handle form mode
// Convert env arrays to object
const envObj: Record<string, string> = {}
envKeys.forEach((key, index) => {
@ -211,13 +309,28 @@ export default function AddEditMCPServer({
}
})
// Convert headers arrays to object
const headersObj: Record<string, string> = {}
headerKeys.forEach((key, index) => {
const keyName = key.trim()
if (keyName !== '') {
headersObj[keyName] = headerValues[index]?.trim() || ''
}
})
// Filter out empty args
const filteredArgs = args.map((arg) => arg.trim()).filter((arg) => arg)
const config: MCPServerConfig = {
command: command.trim(),
args: filteredArgs,
env: envObj,
command: transportType === 'stdio' ? command.trim() : '',
args: transportType === 'stdio' ? filteredArgs : [],
env: transportType === 'stdio' ? envObj : {},
type: transportType,
...(transportType !== 'stdio' && {
url: url.trim(),
headers: Object.keys(headersObj).length > 0 ? headersObj : undefined,
timeout: timeout.trim() !== '' ? parseInt(timeout) : undefined,
}),
}
if (serverName.trim() !== '') {
@ -229,122 +342,303 @@ export default function AddEditMCPServer({
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent>
<DialogContent showCloseButton={false}>
<DialogHeader>
<DialogTitle>
{editingKey
? t('mcp-servers:editServer')
: t('mcp-servers:addServer')}
<DialogTitle className="flex items-center justify-between">
<span>
{editingKey
? t('mcp-servers:editServer')
: t('mcp-servers:addServer')}
</span>
<div
className={cn(
'size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out',
isToggled && 'bg-main-view-fg/10 text-accent'
)}
title="Add server by JSON"
onClick={() => setIsToggled(!isToggled)}
>
<IconCodeDots className="h-5 w-5 cursor-pointer transition-colors duration-200" />
</div>
</DialogTitle>
</DialogHeader>
<div className="space-y-4">
<div className="space-y-2">
<label className="text-sm mb-2 inline-block">
{t('mcp-servers:serverName')}
</label>
<Input
value={serverName}
onChange={(e) => setServerName(e.target.value)}
placeholder={t('mcp-servers:enterServerName')}
autoFocus
/>
</div>
<div className="space-y-2">
<label className="text-sm mb-2 inline-block">
{t('mcp-servers:command')}
</label>
<Input
value={command}
onChange={(e) => setCommand(e.target.value)}
placeholder={t('mcp-servers:enterCommand')}
/>
</div>
<div className="space-y-2">
<div className="flex items-center justify-between">
<label className="text-sm">{t('mcp-servers:arguments')}</label>
<div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={handleAddArg}
>
<IconPlus size={18} className="text-main-view-fg/60" />
{isToggled ? (
<div className="space-y-4">
<div className="space-y-2">
<label className="text-sm mb-2 inline-block">
{t('mcp-servers:editJson.placeholder')}
</label>
<div className="border border-main-view-fg/10 rounded-md overflow-hidden">
<CodeEditor
value={jsonContent}
language="json"
placeholder={`{
"serverName": {
"command": "command",
"args": ["arg1", "arg2"],
"env": {
"KEY": "value"
}
}
}`}
onChange={(e) => {
setJsonContent(e.target.value)
setError(null)
}}
onPaste={() => setError(null)}
style={{
fontFamily: 'ui-monospace',
backgroundColor: 'transparent',
wordBreak: 'break-all',
overflowWrap: 'anywhere',
whiteSpace: 'pre-wrap',
}}
className="w-full !text-sm min-h-[300px]"
/>
</div>
{error && <div className="text-destructive text-sm">{error}</div>}
</div>
</div>
) : (
<div className="space-y-4">
<div className="space-y-2">
<label className="text-sm mb-2 inline-block">
{t('mcp-servers:serverName')}
</label>
<Input
value={serverName}
onChange={(e) => setServerName(e.target.value)}
placeholder={t('mcp-servers:enterServerName')}
autoFocus
/>
</div>
<DndContext
sensors={sensors}
collisionDetection={closestCenter}
onDragEnd={(event) => {
const { active, over } = event
if (active.id !== over?.id) {
const oldIndex = parseInt(active.id.toString())
const newIndex = parseInt(over?.id.toString() || '0')
handleReorderArgs(oldIndex, newIndex)
<div className="space-y-2">
<label className="text-sm mb-2 inline-block">
Transport Type
</label>
<RadioGroup
value={transportType}
onValueChange={(value) =>
setTransportType(value as 'http' | 'sse')
}
}}
>
<SortableContext
items={args.map((_, index) => index)}
strategy={verticalListSortingStrategy}
className="flex gap-6"
>
{args.map((arg, index) => (
<SortableArgItem
key={index}
id={index}
value={arg}
onChange={(value) => handleArgChange(index, value)}
onRemove={() => handleRemoveArg(index)}
canRemove={args.length > 1}
placeholder={t('mcp-servers:argument', {
index: index + 1,
})}
/>
))}
</SortableContext>
</DndContext>
</div>
<div className="space-y-2">
<div className="flex items-center justify-between">
<label className="text-sm">{t('mcp-servers:envVars')}</label>
<div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={handleAddEnv}
>
<IconPlus size={18} className="text-main-view-fg/60" />
</div>
<div className="flex items-center space-x-2">
<RadioGroupItem value="stdio" id="stdio" />
<label
htmlFor="stdio"
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
>
STDIO
</label>
</div>
<div className="flex items-center space-x-2">
<RadioGroupItem value="http" id="http" />
<label
htmlFor="http"
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
>
HTTP
</label>
</div>
<div className="flex items-center space-x-2">
<RadioGroupItem value="sse" id="sse" />
<label
htmlFor="sse"
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
>
SSE
</label>
</div>
</RadioGroup>
</div>
{envKeys.map((key, index) => (
<div key={`env-${index}`} className="flex items-center gap-2">
{transportType === 'stdio' ? (
<div className="space-y-2">
<label className="text-sm mb-2 inline-block">
{t('mcp-servers:command')}
</label>
<Input
value={key}
onChange={(e) => handleEnvKeyChange(index, e.target.value)}
placeholder={t('mcp-servers:key')}
className="flex-1"
value={command}
onChange={(e) => setCommand(e.target.value)}
placeholder={t('mcp-servers:enterCommand')}
/>
</div>
) : (
<div className="space-y-2">
<label className="text-sm mb-2 inline-block">URL</label>
<Input
value={envValues[index] || ''}
onChange={(e) => handleEnvValueChange(index, e.target.value)}
placeholder={t('mcp-servers:value')}
className="flex-1"
value={url}
onChange={(e) => setUrl(e.target.value)}
placeholder="Enter URL"
/>
{envKeys.length > 1 && (
</div>
)}
{transportType === 'stdio' && (
<div className="space-y-2">
<div className="flex items-center justify-between">
<label className="text-sm">
{t('mcp-servers:arguments')}
</label>
<div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={() => handleRemoveEnv(index)}
onClick={handleAddArg}
>
<IconTrash size={18} className="text-destructive" />
<IconPlus size={18} className="text-main-view-fg/60" />
</div>
)}
</div>
<DndContext
sensors={sensors}
collisionDetection={closestCenter}
onDragEnd={(event) => {
const { active, over } = event
if (active.id !== over?.id) {
const oldIndex = parseInt(active.id.toString())
const newIndex = parseInt(over?.id.toString() || '0')
handleReorderArgs(oldIndex, newIndex)
}
}}
>
<SortableContext
items={args.map((_, index) => index)}
strategy={verticalListSortingStrategy}
>
{args.map((arg, index) => (
<SortableArgItem
key={index}
id={index}
value={arg}
onChange={(value) => handleArgChange(index, value)}
onRemove={() => handleRemoveArg(index)}
canRemove={args.length > 1}
placeholder={t('mcp-servers:argument', {
index: index + 1,
})}
/>
))}
</SortableContext>
</DndContext>
</div>
))}
)}
{transportType === 'stdio' && (
<div className="space-y-2">
<div className="flex items-center justify-between">
<label className="text-sm">{t('mcp-servers:envVars')}</label>
<div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={handleAddEnv}
>
<IconPlus size={18} className="text-main-view-fg/60" />
</div>
</div>
{envKeys.map((key, index) => (
<div key={`env-${index}`} className="flex items-center gap-2">
<Input
value={key}
onChange={(e) =>
handleEnvKeyChange(index, e.target.value)
}
placeholder={t('mcp-servers:key')}
className="flex-1"
/>
<Input
value={envValues[index] || ''}
onChange={(e) =>
handleEnvValueChange(index, e.target.value)
}
placeholder={t('mcp-servers:value')}
className="flex-1"
/>
{envKeys.length > 1 && (
<div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={() => handleRemoveEnv(index)}
>
<IconTrash size={18} className="text-destructive" />
</div>
)}
</div>
))}
</div>
)}
{(transportType === 'http' || transportType === 'sse') && (
<>
<div className="space-y-2">
<div className="flex items-center justify-between">
<label className="text-sm">Headers</label>
<div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={handleAddHeader}
>
<IconPlus size={18} className="text-main-view-fg/60" />
</div>
</div>
{headerKeys.map((key, index) => (
<div
key={`header-${index}`}
className="flex items-center gap-2"
>
<Input
value={key}
onChange={(e) =>
handleHeaderKeyChange(index, e.target.value)
}
placeholder="Header name"
className="flex-1"
/>
<Input
value={headerValues[index] || ''}
onChange={(e) =>
handleHeaderValueChange(index, e.target.value)
}
placeholder="Header value"
className="flex-1"
/>
{headerKeys.length > 1 && (
<div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={() => handleRemoveHeader(index)}
>
<IconTrash size={18} className="text-destructive" />
</div>
)}
</div>
))}
</div>
<div className="space-y-2">
<label className="text-sm mb-2 inline-block">
Timeout (seconds)
</label>
<Input
value={timeout}
onChange={(e) => setTimeout(e.target.value)}
placeholder="Enter timeout in seconds"
type="number"
/>
</div>
</>
)}
</div>
</div>
)}
<DialogFooter>
<Button onClick={handleSave}>{t('mcp-servers:save')}</Button>
<Button variant="link" onClick={() => onOpenChange(false)}>
{t('common:cancel')}
</Button>
<Button
onClick={handleSave}
disabled={!isToggled && serverName.trim() === ''}
>
{t('mcp-servers:save')}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>

View File

@ -18,6 +18,7 @@ import { IconTrash } from '@tabler/icons-react'
import { useState, useEffect } from 'react'
import { toast } from 'sonner'
import { useTranslation } from '@/i18n/react-i18next-compat'
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
type DialogDeleteModelProps = {
provider: ModelProvider
@ -31,8 +32,12 @@ export const DialogDeleteModel = ({
const { t } = useTranslation()
const [selectedModelId, setSelectedModelId] = useState<string>('')
const { setProviders, deleteModel: deleteModelCache } = useModelProvider()
const { removeFavorite } = useFavoriteModel()
const removeModel = async () => {
// Remove model from favorites if it exists
removeFavorite(selectedModelId)
deleteModelCache(selectedModelId)
deleteModel(selectedModelId).then(() => {
getProviders().then((providers) => {

View File

@ -18,6 +18,7 @@ import { useRouter } from '@tanstack/react-router'
import { route } from '@/constants/routes'
import { useTranslation } from '@/i18n/react-i18next-compat'
import { predefinedProviders } from '@/consts/providers'
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
type Props = {
provider?: ProviderObject
@ -25,6 +26,7 @@ type Props = {
const DeleteProvider = ({ provider }: Props) => {
const { t } = useTranslation()
const { deleteProvider, providers } = useModelProvider()
const { favoriteModels, removeFavorite } = useFavoriteModel()
const router = useRouter()
if (
!provider ||
@ -34,6 +36,14 @@ const DeleteProvider = ({ provider }: Props) => {
return null
const removeProvider = async () => {
// Remove favorite models that belong to this provider
const providerModelIds = provider.models.map((model) => model.id)
favoriteModels.forEach((favoriteModel) => {
if (providerModelIds.includes(favoriteModel.id)) {
removeFavorite(favoriteModel.id)
}
})
deleteProvider(provider.provider)
toast.success(t('providers:deleteProvider.title'), {
id: `delete-provider-${provider.provider}`,

View File

@ -43,19 +43,9 @@ export default function EditJsonMCPserver({
}
}, [open, initialData, t])
const handlePaste = (e: React.ClipboardEvent) => {
const pastedText = e.clipboardData.getData('text')
try {
const parsedJson = JSON.parse(pastedText)
const prettifiedJson = JSON.stringify(parsedJson, null, 2)
e.preventDefault()
setJsonContent(prettifiedJson)
setError(null)
} catch (error) {
e.preventDefault()
setError(t('mcp-servers:editJson.errorPaste'))
console.error('Paste error:', error)
}
const handlePaste = () => {
// Clear any existing errors when pasting
setError(null)
}
const handleSave = () => {
@ -80,7 +70,18 @@ export default function EditJsonMCPserver({
</DialogTitle>
</DialogHeader>
<div className="space-y-2">
<div className="border border-main-view-fg/10 rounded-md overflow-hidden">
<div className="border border-main-view-fg/10 rounded-md !overflow-hidden">
<style>{`
.w-tc-editor textarea {
word-break: break-all !important;
overflow-wrap: anywhere !important;
white-space: pre-wrap !important;
}
.w-tc-editor .token.string {
word-break: break-all !important;
overflow-wrap: anywhere !important;
}
`}</style>
<CodeEditor
value={jsonContent}
language="json"
@ -90,8 +91,11 @@ export default function EditJsonMCPserver({
style={{
fontFamily: 'ui-monospace',
backgroundColor: 'transparent',
wordBreak: 'break-all',
overflowWrap: 'anywhere',
whiteSpace: 'pre-wrap',
}}
className="w-full !text-sm "
className="w-full !text-sm overflow-hidden break-all"
/>
</div>
{error && <div className="text-destructive text-sm">{error}</div>}

View File

@ -1,253 +0,0 @@
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
DialogTrigger,
} from '@/components/ui/dialog'
import { Switch } from '@/components/ui/switch'
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from '@/components/ui/tooltip'
import { useModelProvider } from '@/hooks/useModelProvider'
import {
IconPencil,
IconEye,
IconTool,
// IconWorld,
// IconAtom,
IconCodeCircle2,
} from '@tabler/icons-react'
import { useState, useEffect } from 'react'
import { useTranslation } from '@/i18n/react-i18next-compat'
// No need to define our own interface, we'll use the existing Model type
type DialogEditModelProps = {
provider: ModelProvider
modelId?: string // Optional model ID to edit
}
export const DialogEditModel = ({
provider,
modelId,
}: DialogEditModelProps) => {
const { t } = useTranslation()
const { updateProvider } = useModelProvider()
const [selectedModelId, setSelectedModelId] = useState<string>('')
const [capabilities, setCapabilities] = useState<Record<string, boolean>>({
completion: false,
vision: false,
tools: false,
reasoning: false,
embeddings: false,
web_search: false,
})
// Initialize with the provided model ID or the first model if available
useEffect(() => {
if (modelId) {
setSelectedModelId(modelId)
} else if (provider.models && provider.models.length > 0) {
setSelectedModelId(provider.models[0].id)
}
}, [provider, modelId])
// Get the currently selected model
const selectedModel = provider.models.find(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(m: any) => m.id === selectedModelId
)
// Initialize capabilities from selected model
useEffect(() => {
if (selectedModel) {
const modelCapabilities = selectedModel.capabilities || []
setCapabilities({
completion: modelCapabilities.includes('completion'),
vision: modelCapabilities.includes('vision'),
tools: modelCapabilities.includes('tools'),
embeddings: modelCapabilities.includes('embeddings'),
web_search: modelCapabilities.includes('web_search'),
reasoning: modelCapabilities.includes('reasoning'),
})
}
}, [selectedModel])
// Track if capabilities were updated by user action
const [capabilitiesUpdated, setCapabilitiesUpdated] = useState(false)
// Update model capabilities - only update local state
const handleCapabilityChange = (capability: string, enabled: boolean) => {
setCapabilities((prev) => ({
...prev,
[capability]: enabled,
}))
// Mark that capabilities were updated by user action
setCapabilitiesUpdated(true)
}
// Use effect to update the provider when capabilities are explicitly changed by user
useEffect(() => {
// Only run if capabilities were updated by user action and we have a selected model
if (!capabilitiesUpdated || !selectedModel) return
// Reset the flag
setCapabilitiesUpdated(false)
// Create updated capabilities array from the state
const updatedCapabilities = Object.entries(capabilities)
.filter(([, isEnabled]) => isEnabled)
.map(([capName]) => capName)
// Find and update the model in the provider
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const updatedModels = provider.models.map((m: any) => {
if (m.id === selectedModelId) {
return {
...m,
capabilities: updatedCapabilities,
}
}
return m
})
// Update the provider with the updated models
updateProvider(provider.provider, {
...provider,
models: updatedModels,
})
}, [
capabilitiesUpdated,
capabilities,
provider,
selectedModel,
selectedModelId,
updateProvider,
])
if (!selectedModel) {
return null
}
return (
<Dialog>
<DialogTrigger asChild>
<div className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out">
<IconPencil size={18} className="text-main-view-fg/50" />
</div>
</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle className="line-clamp-1" title={selectedModel.id}>
{t('providers:editModel.title', { modelId: selectedModel.id })}
</DialogTitle>
<DialogDescription>
{t('providers:editModel.description')}
</DialogDescription>
</DialogHeader>
<div className="py-1">
<h3 className="text-sm font-medium mb-3">
{t('providers:editModel.capabilities')}
</h3>
<div className="space-y-4">
<div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<IconTool className="size-4 text-main-view-fg/70" />
<span className="text-sm">
{t('providers:editModel.tools')}
</span>
</div>
<Switch
id="tools-capability"
checked={capabilities.tools}
onCheckedChange={(checked) =>
handleCapabilityChange('tools', checked)
}
/>
</div>
<div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<IconEye className="size-4 text-main-view-fg/70" />
<span className="text-sm">
{t('providers:editModel.vision')}
</span>
</div>
<Tooltip>
<TooltipTrigger>
<Switch
id="vision-capability"
checked={capabilities.vision}
disabled={true}
onCheckedChange={(checked) =>
handleCapabilityChange('vision', checked)
}
/>
</TooltipTrigger>
<TooltipContent>
{t('providers:editModel.notAvailable')}
</TooltipContent>
</Tooltip>
</div>
<div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<IconCodeCircle2 className="size-4 text-main-view-fg/70" />
<span className="text-sm">
{t('providers:editModel.embeddings')}
</span>
</div>
<Tooltip>
<TooltipTrigger>
<Switch
id="embedding-capability"
disabled={true}
checked={capabilities.embeddings}
onCheckedChange={(checked) =>
handleCapabilityChange('embeddings', checked)
}
/>
</TooltipTrigger>
<TooltipContent>
{t('providers:editModel.notAvailable')}
</TooltipContent>
</Tooltip>
</div>
{/* <div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<IconWorld className="size-4 text-main-view-fg/70" />
<span className="text-sm">Web Search</span>
</div>
<Switch
id="web_search-capability"
checked={capabilities.web_search}
onCheckedChange={(checked) =>
handleCapabilityChange('web_search', checked)
}
/>
</div> */}
{/* <div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<IconAtom className="size-4 text-main-view-fg/70" />
<span className="text-sm">{t('reasoning')}</span>
</div>
<Switch
id="reasoning-capability"
checked={capabilities.reasoning}
onCheckedChange={(checked) =>
handleCapabilityChange('reasoning', checked)
}
/>
</div> */}
</div>
</div>
</DialogContent>
</Dialog>
)
}

View File

@ -0,0 +1,180 @@
import { renderHook, act } from '@testing-library/react'
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
import { SystemEvent } from '@/types/events'
// Mock functions
const mockGetTools = vi.fn()
const mockUpdateTools = vi.fn()
const mockListen = vi.fn()
const mockUnsubscribe = vi.fn()
// Mock the dependencies
vi.mock('@/services/mcp', () => ({
getTools: mockGetTools,
}))
vi.mock('../useAppState', () => ({
useAppState: () => ({
updateTools: mockUpdateTools,
}),
}))
vi.mock('@tauri-apps/api/event', () => ({
listen: mockListen,
}))
describe('useTools', () => {
beforeEach(() => {
vi.clearAllMocks()
mockListen.mockResolvedValue(mockUnsubscribe)
mockGetTools.mockResolvedValue([])
})
afterEach(() => {
vi.restoreAllMocks()
})
it('should call getTools and updateTools on mount', async () => {
const { useTools } = await import('../useTools')
const mockTools = [
{ name: 'test-tool', description: 'A test tool' },
{ name: 'another-tool', description: 'Another test tool' },
]
mockGetTools.mockResolvedValue(mockTools)
renderHook(() => useTools())
// Wait for async operations to complete
await act(async () => {
await new Promise(resolve => setTimeout(resolve, 0))
})
expect(mockGetTools).toHaveBeenCalledTimes(1)
expect(mockUpdateTools).toHaveBeenCalledWith(mockTools)
})
it('should set up event listener for MCP_UPDATE', async () => {
const { useTools } = await import('../useTools')
renderHook(() => useTools())
await act(async () => {
await new Promise(resolve => setTimeout(resolve, 0))
})
expect(mockListen).toHaveBeenCalledWith(
SystemEvent.MCP_UPDATE,
expect.any(Function)
)
})
it('should call setTools when MCP_UPDATE event is triggered', async () => {
const { useTools } = await import('../useTools')
const mockTools = [{ name: 'updated-tool', description: 'Updated tool' }]
mockGetTools.mockResolvedValue(mockTools)
let eventCallback: () => void
mockListen.mockImplementation((_event, callback) => {
eventCallback = callback
return Promise.resolve(mockUnsubscribe)
})
renderHook(() => useTools())
// Wait for initial setup
await act(async () => {
await new Promise(resolve => setTimeout(resolve, 0))
})
// Clear the initial calls
vi.clearAllMocks()
mockGetTools.mockResolvedValue(mockTools)
// Trigger the event
await act(async () => {
eventCallback()
await new Promise(resolve => setTimeout(resolve, 0))
})
expect(mockGetTools).toHaveBeenCalledTimes(1)
expect(mockUpdateTools).toHaveBeenCalledWith(mockTools)
})
it('should return unsubscribe function for cleanup', async () => {
const { useTools } = await import('../useTools')
const { unmount } = renderHook(() => useTools())
await act(async () => {
await new Promise(resolve => setTimeout(resolve, 0))
})
expect(mockListen).toHaveBeenCalled()
// Unmount should call the unsubscribe function
unmount()
expect(mockListen).toHaveBeenCalledWith(
SystemEvent.MCP_UPDATE,
expect.any(Function)
)
})
it('should handle getTools errors gracefully', async () => {
const { useTools } = await import('../useTools')
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
mockGetTools.mockRejectedValue(new Error('Failed to get tools'))
renderHook(() => useTools())
await act(async () => {
// Give enough time for the promise to be handled
await new Promise(resolve => setTimeout(resolve, 100))
})
expect(mockGetTools).toHaveBeenCalledTimes(1)
// updateTools should not be called if getTools fails
expect(mockUpdateTools).not.toHaveBeenCalled()
consoleErrorSpy.mockRestore()
})
it('should handle event listener setup errors gracefully', async () => {
const { useTools } = await import('../useTools')
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
mockListen.mockRejectedValue(new Error('Failed to set up listener'))
renderHook(() => useTools())
await act(async () => {
// Give enough time for the promise to be handled
await new Promise(resolve => setTimeout(resolve, 100))
})
// Initial getTools should still work
expect(mockGetTools).toHaveBeenCalledTimes(1)
expect(mockListen).toHaveBeenCalled()
consoleErrorSpy.mockRestore()
})
it('should only set up effect once with empty dependency array', async () => {
const { useTools } = await import('../useTools')
const { rerender } = renderHook(() => useTools())
// Initial render
expect(mockGetTools).toHaveBeenCalledTimes(1)
expect(mockListen).toHaveBeenCalledTimes(1)
// Rerender should not trigger additional calls
rerender()
expect(mockGetTools).toHaveBeenCalledTimes(1)
expect(mockListen).toHaveBeenCalledTimes(1)
})
})

View File

@ -1,4 +1,4 @@
import { useCallback, useEffect, useMemo } from 'react'
import { useCallback, useMemo } from 'react'
import { usePrompt } from './usePrompt'
import { useModelProvider } from './useModelProvider'
import { useThreads } from './useThreads'
@ -20,10 +20,7 @@ import { CompletionMessagesBuilder } from '@/lib/messages'
import { renderInstructions } from '@/lib/instructionTemplate'
import { ChatCompletionMessageToolCall } from 'openai/resources'
import { useAssistant } from './useAssistant'
import { getTools } from '@/services/mcp'
import { MCPTool } from '@/types/completion'
import { listen } from '@tauri-apps/api/event'
import { SystemEvent } from '@/types/events'
import { stopModel, startModel, stopAllModels } from '@/services/models'
import { useToolApproval } from '@/hooks/useToolApproval'
@ -33,6 +30,10 @@ import { updateSettings } from '@/services/providers'
import { useContextSizeApproval } from './useModelContextApproval'
import { useModelLoad } from './useModelLoad'
import { useGeneralSetting } from './useGeneralSetting'
import {
ReasoningProcessor,
extractReasoningFromMessage,
} from '@/utils/reasoning'
export const useChat = () => {
const { prompt, setPrompt } = usePrompt()
@ -41,7 +42,6 @@ export const useChat = () => {
tools,
updateTokenSpeed,
resetTokenSpeed,
updateTools,
updateStreamingContent,
updateLoadingModel,
setAbortController,
@ -78,22 +78,6 @@ export const useChat = () => {
const selectedAssistant =
assistants.find((a) => a.id === currentAssistant.id) || assistants[0]
useEffect(() => {
function setTools() {
getTools().then((data: MCPTool[]) => {
updateTools(data)
})
}
setTools()
let unsubscribe = () => {}
listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
// Unsubscribe from the event when the component unmounts
unsubscribe = unsub
})
return unsubscribe
}, [updateTools])
const getCurrentThread = useCallback(async () => {
let currentThread = retrieveThread()
@ -306,18 +290,27 @@ export const useChat = () => {
const toolCalls: ChatCompletionMessageToolCall[] = []
try {
if (isCompletionResponse(completion)) {
accumulatedText =
(completion.choices[0]?.message?.content as string) || ''
if (completion.choices[0]?.message?.tool_calls) {
toolCalls.push(...completion.choices[0].message.tool_calls)
const message = completion.choices[0]?.message
accumulatedText = (message?.content as string) || ''
// Handle reasoning field if there is one
const reasoning = extractReasoningFromMessage(message)
if (reasoning) {
accumulatedText =
`<think>${reasoning}</think>` + accumulatedText
}
if (message?.tool_calls) {
toolCalls.push(...message.tool_calls)
}
} else {
// High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
let rafScheduled = false
let rafHandle: number | undefined
let pendingDeltaCount = 0
const reasoningProcessor = new ReasoningProcessor()
const scheduleFlush = () => {
if (rafScheduled) return
if (rafScheduled || abortController.signal.aborted) return
rafScheduled = true
const doSchedule = (cb: () => void) => {
if (typeof requestAnimationFrame !== 'undefined') {
@ -329,6 +322,12 @@ export const useChat = () => {
}
}
doSchedule(() => {
// Check abort status before executing the scheduled callback
if (abortController.signal.aborted) {
rafScheduled = false
return
}
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
@ -349,7 +348,10 @@ export const useChat = () => {
}
const flushIfPending = () => {
if (!rafScheduled) return
if (typeof cancelAnimationFrame !== 'undefined' && rafHandle !== undefined) {
if (
typeof cancelAnimationFrame !== 'undefined' &&
rafHandle !== undefined
) {
cancelAnimationFrame(rafHandle)
} else if (rafHandle !== undefined) {
clearTimeout(rafHandle)
@ -372,31 +374,63 @@ export const useChat = () => {
pendingDeltaCount = 0
rafScheduled = false
}
for await (const part of completion) {
// Error message
if (!part.choices) {
throw new Error(
'message' in part
? (part.message as string)
: (JSON.stringify(part) ?? '')
)
}
const delta = part.choices[0]?.delta?.content || ''
try {
for await (const part of completion) {
// Check if aborted before processing each part
if (abortController.signal.aborted) {
break
}
if (part.choices[0]?.delta?.tool_calls) {
extractToolCall(part, currentCall, toolCalls)
// Schedule a flush to reflect tool update
scheduleFlush()
// Error message
if (!part.choices) {
throw new Error(
'message' in part
? (part.message as string)
: (JSON.stringify(part) ?? '')
)
}
if (part.choices[0]?.delta?.tool_calls) {
extractToolCall(part, currentCall, toolCalls)
// Schedule a flush to reflect tool update
scheduleFlush()
}
const deltaReasoning =
reasoningProcessor.processReasoningChunk(part)
if (deltaReasoning) {
accumulatedText += deltaReasoning
pendingDeltaCount += 1
// Schedule flush for reasoning updates
scheduleFlush()
}
const deltaContent = part.choices[0]?.delta?.content || ''
if (deltaContent) {
accumulatedText += deltaContent
pendingDeltaCount += 1
// Batch UI update on next animation frame
scheduleFlush()
}
}
if (delta) {
accumulatedText += delta
pendingDeltaCount += 1
// Batch UI update on next animation frame
scheduleFlush()
} finally {
// Always clean up scheduled RAF when stream ends (either normally or via abort)
if (rafHandle !== undefined) {
if (typeof cancelAnimationFrame !== 'undefined') {
cancelAnimationFrame(rafHandle)
} else {
clearTimeout(rafHandle)
}
rafHandle = undefined
rafScheduled = false
}
// Only finalize and flush if not aborted
if (!abortController.signal.aborted) {
// Finalize reasoning (close any open think tags)
accumulatedText += reasoningProcessor.finalize()
// Ensure any pending buffered content is rendered at the end
flushIfPending()
}
}
// Ensure any pending buffered content is rendered at the end
flushIfPending()
}
} catch (error) {
const errorMessage =

View File

@ -0,0 +1,53 @@
import { create } from 'zustand'
import { persist, createJSONStorage } from 'zustand/middleware'
import { localStorageKey } from '@/constants/localStorage'
interface FavoriteModelState {
favoriteModels: Model[]
addFavorite: (model: Model) => void
removeFavorite: (modelId: string) => void
isFavorite: (modelId: string) => boolean
toggleFavorite: (model: Model) => void
}
export const useFavoriteModel = create<FavoriteModelState>()(
persist(
(set, get) => ({
favoriteModels: [],
addFavorite: (model: Model) => {
set((state) => {
if (!state.favoriteModels.some((fav) => fav.id === model.id)) {
return {
favoriteModels: [...state.favoriteModels, model],
}
}
return state
})
},
removeFavorite: (modelId: string) => {
set((state) => ({
favoriteModels: state.favoriteModels.filter((model) => model.id !== modelId),
}))
},
isFavorite: (modelId: string) => {
return get().favoriteModels.some((model) => model.id === modelId)
},
toggleFavorite: (model: Model) => {
const { isFavorite, addFavorite, removeFavorite } = get()
if (isFavorite(model.id)) {
removeFavorite(model.id)
} else {
addFavorite(model)
}
},
}),
{
name: localStorageKey.favoriteModels,
storage: createJSONStorage(() => localStorage),
}
)
)

View File

@ -7,6 +7,10 @@ export type MCPServerConfig = {
args: string[]
env: Record<string, string>
active?: boolean
type?: 'stdio' | 'http' | 'sse'
url?: string
headers?: Record<string, string>
timeout?: number
}
// Define the structure of all MCP servers

View File

@ -0,0 +1,31 @@
import { useEffect } from 'react'
import { getTools } from '@/services/mcp'
import { MCPTool } from '@/types/completion'
import { listen } from '@tauri-apps/api/event'
import { SystemEvent } from '@/types/events'
import { useAppState } from './useAppState'
export const useTools = () => {
const { updateTools } = useAppState()
useEffect(() => {
function setTools() {
getTools().then((data: MCPTool[]) => {
updateTools(data)
}).catch((error) => {
console.error('Failed to fetch MCP tools:', error)
})
}
setTools()
let unsubscribe = () => {}
listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
// Unsubscribe from the event when the component unmounts
unsubscribe = unsub
}).catch((error) => {
console.error('Failed to set up MCP update listener:', error)
})
return unsubscribe
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
}

View File

@ -17,6 +17,7 @@ import {
import { useNavigate } from '@tanstack/react-router'
import { route } from '@/constants/routes'
import { useThreads } from '@/hooks/useThreads'
import { AppEvent, events } from '@janhq/core'
export function DataProvider() {
const { setProviders } = useModelProvider()
@ -70,6 +71,13 @@ export function DataProvider() {
}
}, [checkForUpdate])
useEffect(() => {
events.on(AppEvent.onModelImported, () => {
getProviders().then(setProviders)
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
const handleDeepLink = (urls: string[] | null) => {
if (!urls) return
console.log('Received deeplink:', urls)

View File

@ -17,7 +17,12 @@ import { useModelProvider } from '@/hooks/useModelProvider'
import { Card, CardItem } from '@/containers/Card'
import { RenderMarkdown } from '@/containers/RenderMarkdown'
import { extractModelName, extractDescription } from '@/lib/models'
import { IconDownload, IconFileCode, IconSearch } from '@tabler/icons-react'
import {
IconDownload,
IconFileCode,
IconSearch,
IconTool,
} from '@tabler/icons-react'
import { Switch } from '@/components/ui/switch'
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
@ -133,7 +138,10 @@ function Hub() {
if (debouncedSearchValue.length) {
const fuse = new Fuse(filtered, searchOptions)
// Remove domain from search value (e.g., "huggingface.co/author/model" -> "author/model")
const cleanedSearchValue = debouncedSearchValue.replace(/^https?:\/\/[^/]+\//, '')
const cleanedSearchValue = debouncedSearchValue.replace(
/^https?:\/\/[^/]+\//,
''
)
filtered = fuse.search(cleanedSearchValue).map((result) => result.item)
}
// Apply downloaded filter
@ -647,6 +655,15 @@ function Hub() {
?.length || 0}
</span>
</div>
{filteredModels[virtualItem.index].tools && (
<div className="flex items-center gap-1">
<IconTool
size={17}
className="text-main-view-fg/50"
title={t('hub:tools')}
/>
</div>
)}
{filteredModels[virtualItem.index].quants.length >
1 && (
<div className="flex items-center gap-2 hub-show-variants-step">

View File

@ -3,6 +3,7 @@ import { createFileRoute, useSearch } from '@tanstack/react-router'
import ChatInput from '@/containers/ChatInput'
import HeaderPage from '@/containers/HeaderPage'
import { useTranslation } from '@/i18n/react-i18next-compat'
import { useTools } from '@/hooks/useTools'
import { useModelProvider } from '@/hooks/useModelProvider'
import SetupScreen from '@/containers/SetupScreen'
@ -31,6 +32,7 @@ function Index() {
const search = useSearch({ from: route.home as any })
const selectedModel = search.model
const { setCurrentThreadId } = useThreads()
useTools()
// Conditional to check if there are any valid providers
// required min 1 api_key or 1 model in llama.cpp

View File

@ -29,6 +29,58 @@ const maskSensitiveValue = (value: string) => {
return value.slice(0, 4) + '*'.repeat(value.length - 8) + value.slice(-4)
}
// Function to mask sensitive URL parameters
const maskSensitiveUrl = (url: string) => {
if (!url) return url
try {
const urlObj = new URL(url)
const params = urlObj.searchParams
// List of sensitive parameter names (case-insensitive)
const sensitiveParams = [
'api_key',
'apikey',
'key',
'token',
'secret',
'password',
'pwd',
'auth',
'authorization',
'bearer',
'access_token',
'refresh_token',
'client_secret',
'private_key',
'signature',
'hash',
]
// Mask sensitive parameters
sensitiveParams.forEach((paramName) => {
// Check both exact match and case-insensitive match
for (const [key, value] of params.entries()) {
if (key.toLowerCase() === paramName.toLowerCase()) {
params.set(key, maskSensitiveValue(value))
}
}
})
// Reconstruct URL with masked parameters
urlObj.search = params.toString()
return urlObj.toString()
} catch {
// If URL parsing fails, just mask the entire query string after '?'
const queryIndex = url.indexOf('?')
if (queryIndex === -1) return url
const baseUrl = url.substring(0, queryIndex + 1)
const queryString = url.substring(queryIndex + 1)
return baseUrl + maskSensitiveValue(queryString)
}
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const Route = createFileRoute(route.settings.mcp_servers as any)({
component: MCPServers,
@ -195,6 +247,7 @@ function MCPServers() {
getConnectedServers().then(setConnectedServers)
})
.catch((error) => {
console.log(error, 'error.mcp')
editServer(serverKey, {
...(config ?? (mcpServers[serverKey] as MCPServerConfig)),
active: false,
@ -326,22 +379,56 @@ function MCPServers() {
}
descriptionOutside={
<div className="text-sm text-main-view-fg/70">
<div>
{t('mcp-servers:command')}: {config.command}
<div className="mb-1">
Transport:{' '}
<span className="uppercase">
{config.type || 'stdio'}
</span>
</div>
<div className="my-1 break-all">
{t('mcp-servers:args')}: {config?.args?.join(', ')}
</div>
{config.env && Object.keys(config.env).length > 0 && (
<div className="break-all">
{t('mcp-servers:env')}:{' '}
{Object.entries(config.env)
.map(
([key, value]) =>
`${key}=${maskSensitiveValue(value)}`
)
.join(', ')}
</div>
{config.type === 'stdio' || !config.type ? (
<>
<div>
{t('mcp-servers:command')}: {config.command}
</div>
<div className="my-1 break-all">
{t('mcp-servers:args')}:{' '}
{config?.args?.join(', ')}
</div>
{config.env &&
Object.keys(config.env).length > 0 && (
<div className="break-all">
{t('mcp-servers:env')}:{' '}
{Object.entries(config.env)
.map(
([key, value]) =>
`${key}=${maskSensitiveValue(value)}`
)
.join(', ')}
</div>
)}
</>
) : (
<>
<div className="break-all">
URL: {maskSensitiveUrl(config.url || '')}
</div>
{config.headers &&
Object.keys(config.headers).length > 0 && (
<div className="my-1 break-all">
Headers:{' '}
{Object.entries(config.headers)
.map(
([key, value]) =>
`${key}=${maskSensitiveValue(value)}`
)
.join(', ')}
</div>
)}
{config.timeout && (
<div>Timeout: {config.timeout}s</div>
)}
</>
)}
</div>
}

View File

@ -22,10 +22,10 @@ import { useTranslation } from '@/i18n/react-i18next-compat'
import Capabilities from '@/containers/Capabilities'
import { DynamicControllerSetting } from '@/containers/dynamicControllerSetting'
import { RenderMarkdown } from '@/containers/RenderMarkdown'
import { DialogEditModel } from '@/containers/dialogs/EditModel'
import { DialogAddModel } from '@/containers/dialogs/AddModel'
import { ModelSetting } from '@/containers/ModelSetting'
import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel'
import { FavoriteModelAction } from '@/containers/FavoriteModelAction'
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
import { route } from '@/constants/routes'
@ -554,17 +554,24 @@ function ProviderDetail() {
</div>
}
actions={
<div className="flex items-center gap-1">
<DialogEditModel
provider={provider}
modelId={model.id}
/>
<div className="flex items-center gap-0.5">
{model.settings && (
<ModelSetting
provider={provider}
model={model}
/>
)}
{((provider &&
!predefinedProviders.some(
(p) => p.provider === provider.provider
)) ||
(provider &&
predefinedProviders.some(
(p) => p.provider === provider.provider
) &&
Boolean(provider.api_key?.length))) && (
<FavoriteModelAction model={model} />
)}
<DialogDeleteModel
provider={provider}
modelId={model.id}

View File

@ -25,6 +25,7 @@ import { useCallback, useState } from 'react'
import { openAIProviderSettings } from '@/consts/providers'
import cloneDeep from 'lodash/cloneDeep'
import { toast } from 'sonner'
import { stopAllModels } from '@/services/models'
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const Route = createFileRoute(route.settings.model_providers as any)({
@ -169,7 +170,10 @@ function ModelProviders() {
)}
<Switch
checked={provider.active}
onCheckedChange={(e) => {
onCheckedChange={async (e) => {
if (!e && provider.provider.toLowerCase() === 'llamacpp') {
await stopAllModels()
}
updateProvider(provider.provider, {
...provider,
active: e,

View File

@ -23,6 +23,7 @@ import { ContentType, ThreadMessage } from '@janhq/core'
import { useTranslation } from '@/i18n/react-i18next-compat'
import { useChat } from '@/hooks/useChat'
import { useSmallScreen } from '@/hooks/useMediaQuery'
import { useTools } from '@/hooks/useTools'
// as route.threadsDetail
export const Route = createFileRoute('/threads/$threadId')({
@ -43,6 +44,7 @@ function ThreadDetail() {
const { appMainViewBgColor, chatWidth } = useAppearance()
const { sendMessage } = useChat()
const isSmallScreen = useSmallScreen()
useTools()
const { messages } = useMessages(
useShallow((state) => ({

View File

@ -5,7 +5,7 @@ import {
updateSettings,
} from '../providers'
import { models as providerModels } from 'token.js'
import { predefinedProviders } from '@/mock/data'
import { predefinedProviders } from '@/consts/providers'
import { EngineManager } from '@janhq/core'
import { fetchModels } from '../models'
import { ExtensionManager } from '@/lib/extension'
@ -21,7 +21,7 @@ vi.mock('token.js', () => ({
},
}))
vi.mock('@/mock/data', () => ({
vi.mock('@/consts/providers', () => ({
predefinedProviders: [
{
active: true,
@ -69,6 +69,7 @@ vi.mock('../models', () => ({
{ id: 'llama-2-7b', name: 'Llama 2 7B', description: 'Llama model' },
])
),
isToolSupported: vi.fn(() => Promise.resolve(false)),
}))
vi.mock('@/lib/extension', () => ({
@ -116,7 +117,7 @@ describe('providers service', () => {
it('should return builtin and runtime providers', async () => {
const providers = await getProviders()
expect(providers).toHaveLength(9) // 8 runtime + 1 builtin
expect(providers).toHaveLength(2) // 1 runtime + 1 builtin (mocked)
expect(providers.some((p) => p.provider === 'llamacpp')).toBe(true)
expect(providers.some((p) => p.provider === 'openai')).toBe(true)
})
@ -156,7 +157,7 @@ describe('providers service', () => {
provider: 'openai',
base_url: 'https://api.openai.com/v1',
api_key: 'test-key',
} as ModelProvider
}
const models = await fetchModelsFromProvider(provider)
@ -185,7 +186,7 @@ describe('providers service', () => {
provider: 'custom',
base_url: 'https://api.custom.com',
api_key: '',
} as ModelProvider
}
const models = await fetchModelsFromProvider(provider)
@ -204,7 +205,7 @@ describe('providers service', () => {
const provider = {
provider: 'custom',
base_url: 'https://api.custom.com',
} as ModelProvider
}
const models = await fetchModelsFromProvider(provider)
@ -214,7 +215,7 @@ describe('providers service', () => {
it('should throw error when provider has no base_url', async () => {
const provider = {
provider: 'custom',
} as ModelProvider
}
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
'Provider must have base_url configured'
@ -232,10 +233,10 @@ describe('providers service', () => {
const provider = {
provider: 'custom',
base_url: 'https://api.custom.com',
} as ModelProvider
}
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
'Cannot connect to custom at https://api.custom.com'
'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.'
)
})
@ -245,10 +246,10 @@ describe('providers service', () => {
const provider = {
provider: 'custom',
base_url: 'https://api.custom.com',
} as ModelProvider
}
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
'Cannot connect to custom at https://api.custom.com'
'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.'
)
})
@ -264,7 +265,7 @@ describe('providers service', () => {
const provider = {
provider: 'custom',
base_url: 'https://api.custom.com',
} as ModelProvider
}
const models = await fetchModelsFromProvider(provider)
@ -298,7 +299,7 @@ describe('providers service', () => {
controller_type: 'input',
controller_props: { value: 'test-key' },
},
] as ProviderSetting[]
]
await updateSettings('openai', settings)
@ -324,7 +325,7 @@ describe('providers service', () => {
mockExtensionManager
)
const settings = [] as ProviderSetting[]
const settings = []
const result = await updateSettings('nonexistent', settings)
@ -350,7 +351,7 @@ describe('providers service', () => {
controller_type: 'input',
controller_props: { value: undefined },
},
] as ProviderSetting[]
]
await updateSettings('openai', settings)

View File

@ -29,6 +29,7 @@ export interface CatalogModel {
mmproj_models?: MMProjModel[]
created_at?: string
readme?: string
tools?: boolean
}
export type ModelCatalog = CatalogModel[]
@ -313,3 +314,16 @@ export const startModel = async (
throw error
})
}
/**
* Check if model support tool use capability
* Returned by backend engine
* @param modelId
* @returns
*/
export const isToolSupported = async (modelId: string): Promise<boolean> => {
const engine = getEngine()
if (!engine) return false
return engine.isToolSupported(modelId)
}

View File

@ -1,12 +1,9 @@
import { models as providerModels } from 'token.js'
import { predefinedProviders } from '@/consts/providers'
import { EngineManager, SettingComponentProps } from '@janhq/core'
import {
DefaultToolUseSupportedModels,
ModelCapabilities,
} from '@/types/models'
import { ModelCapabilities } from '@/types/models'
import { modelSettings } from '@/lib/predefined'
import { fetchModels } from './models'
import { fetchModels, isToolSupported } from './models'
import { ExtensionManager } from '@/lib/extension'
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
@ -65,52 +62,41 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
controller_props: setting.controllerProps as unknown,
}
}) as ProviderSetting[],
models: models.map((model) => ({
id: model.id,
model: model.id,
name: model.name,
description: model.description,
capabilities:
'capabilities' in model
? (model.capabilities as string[])
: [
ModelCapabilities.COMPLETION,
...(Object.values(DefaultToolUseSupportedModels).some((v) =>
model.id.toLowerCase().includes(v.toLowerCase())
)
? [ModelCapabilities.TOOLS]
: []),
],
provider: providerName,
settings: Object.values(modelSettings).reduce(
(acc, setting) => {
let value = setting.controller_props.value
if (setting.key === 'ctx_len') {
value = 8192 // Default context length for Llama.cpp models
}
// Set temperature to 0.6 for DefaultToolUseSupportedModels
if (
Object.values(DefaultToolUseSupportedModels).some((v) =>
model.id.toLowerCase().includes(v.toLowerCase())
)
) {
if (setting.key === 'temperature') value = 0.7 // Default temperature for tool-supported models
if (setting.key === 'top_k') value = 20 // Default top_k for tool-supported models
if (setting.key === 'top_p') value = 0.8 // Default top_p for tool-supported models
if (setting.key === 'min_p') value = 0 // Default min_p for tool-supported models
}
acc[setting.key] = {
...setting,
controller_props: {
...setting.controller_props,
value: value,
},
}
return acc
},
{} as Record<string, ProviderSetting>
),
})),
models: await Promise.all(
models.map(
async (model) =>
({
id: model.id,
model: model.id,
name: model.name,
description: model.description,
capabilities:
'capabilities' in model
? (model.capabilities as string[])
: (await isToolSupported(model.id))
? [ModelCapabilities.TOOLS]
: [],
provider: providerName,
settings: Object.values(modelSettings).reduce(
(acc, setting) => {
let value = setting.controller_props.value
if (setting.key === 'ctx_len') {
value = 8192 // Default context length for Llama.cpp models
}
acc[setting.key] = {
...setting,
controller_props: {
...setting.controller_props,
value: value,
},
}
return acc
},
{} as Record<string, ProviderSetting>
),
}) as Model
)
),
}
runtimeProviders.push(provider)
}

View File

@ -6,4 +6,5 @@ export type MCPTool = {
name: string
description: string
inputSchema: Record<string, unknown>
server: string
}

View File

@ -13,11 +13,4 @@ export enum ModelCapabilities {
IMAGE_TO_IMAGE = 'image_to_image',
TEXT_TO_AUDIO = 'text_to_audio',
AUDIO_TO_TEXT = 'audio_to_text',
}
// TODO: Remove this enum when we integrate llama.cpp extension
export enum DefaultToolUseSupportedModels {
JanNano = 'jan-',
Qwen3 = 'qwen3',
Lucy = 'lucy',
}
}

View File

@ -0,0 +1,74 @@
import { CompletionResponseChunk } from 'token.js'
import {
chatCompletionChunk,
ChatCompletionMessage,
chatCompletionRequestMessage,
} from '@janhq/core'
// Extract reasoning from a message (for completed responses)
export function extractReasoningFromMessage(
message: chatCompletionRequestMessage | ChatCompletionMessage
): string | null {
if (!message) return null
const extendedMessage = message as chatCompletionRequestMessage
return extendedMessage.reasoning || null
}
// Extract reasoning from a chunk (for streaming responses)
function extractReasoningFromChunk(
chunk: CompletionResponseChunk | chatCompletionChunk
): string | null {
if (!chunk.choices?.[0]?.delta) return null
const delta = chunk.choices[0].delta as chatCompletionRequestMessage
const reasoning = delta.reasoning
// Return null for falsy values, non-strings, or whitespace-only strings
if (!reasoning || typeof reasoning !== 'string' || !reasoning.trim())
return null
return reasoning
}
// Tracks reasoning state and appends reasoning tokens with proper think tags
export class ReasoningProcessor {
private isReasoningActive = false
processReasoningChunk(
chunk: CompletionResponseChunk | chatCompletionChunk
): string {
const reasoning = extractReasoningFromChunk(chunk)
const chunkContent = chunk.choices?.[0]?.delta?.content || ''
// Handle reasoning tokens
if (reasoning) {
if (!this.isReasoningActive) {
this.isReasoningActive = true
return '<think>' + reasoning
}
return reasoning
}
// Handle reasoning end when content starts
if (this.isReasoningActive && chunkContent) {
this.isReasoningActive = false
return '</think>'
}
// No reasoning to process
return ''
}
finalize(): string {
if (this.isReasoningActive) {
this.isReasoningActive = false
return '</think>'
}
return ''
}
isReasoningInProgress(): boolean {
return this.isReasoningActive
}
}