Merge branch 'dev' into current-date-instruction
@ -6,6 +6,7 @@ import { EngineManager } from './EngineManager'
|
||||
export interface chatCompletionRequestMessage {
|
||||
role: 'system' | 'user' | 'assistant' | 'tool'
|
||||
content: string | null | Content[] // Content can be a string OR an array of content parts
|
||||
reasoning?: string | null // Some models return reasoning in completed responses
|
||||
name?: string
|
||||
tool_calls?: any[] // Simplified tool_call_id?: string
|
||||
}
|
||||
@ -270,4 +271,10 @@ export abstract class AIEngine extends BaseExtension {
|
||||
* Optional method to get the underlying chat client
|
||||
*/
|
||||
getChatClient?(sessionId: string): any
|
||||
|
||||
/**
|
||||
* Check if a tool is supported by the model
|
||||
* @param modelId
|
||||
*/
|
||||
abstract isToolSupported(modelId: string): Promise<boolean>
|
||||
}
|
||||
|
||||
@ -58,6 +58,7 @@ export enum AppEvent {
|
||||
onAppUpdateDownloadUpdate = 'onAppUpdateDownloadUpdate',
|
||||
onAppUpdateDownloadError = 'onAppUpdateDownloadError',
|
||||
onAppUpdateDownloadSuccess = 'onAppUpdateDownloadSuccess',
|
||||
onModelImported = 'onModelImported',
|
||||
|
||||
onUserSubmitQuickAsk = 'onUserSubmitQuickAsk',
|
||||
onSelectedText = 'onSelectedText',
|
||||
|
||||
BIN
docs/src/pages/post/_assets/claude-agent.png
Normal file
|
After Width: | Height: | Size: 49 KiB |
BIN
docs/src/pages/post/_assets/claude-report-visualizer.png
Normal file
|
After Width: | Height: | Size: 39 KiB |
BIN
docs/src/pages/post/_assets/deepresearch-flow.png
Normal file
|
After Width: | Height: | Size: 102 KiB |
BIN
docs/src/pages/post/_assets/edit-mcp-settings.gif
Normal file
|
After Width: | Height: | Size: 944 KiB |
BIN
docs/src/pages/post/_assets/enable-tools-local.gif
Normal file
|
After Width: | Height: | Size: 1.3 MiB |
BIN
docs/src/pages/post/_assets/experimental-settings-jan.png
Normal file
|
After Width: | Height: | Size: 316 KiB |
BIN
docs/src/pages/post/_assets/jan-nano-hub.png
Normal file
|
After Width: | Height: | Size: 382 KiB |
BIN
docs/src/pages/post/_assets/openai-deep-research-flow.png
Normal file
|
After Width: | Height: | Size: 74 KiB |
BIN
docs/src/pages/post/_assets/research-result-local.png
Normal file
|
After Width: | Height: | Size: 176 KiB |
BIN
docs/src/pages/post/_assets/revised-deepresearch-flow.png
Normal file
|
After Width: | Height: | Size: 95 KiB |
BIN
docs/src/pages/post/_assets/successful-serper.png
Normal file
|
After Width: | Height: | Size: 5.9 KiB |
180
docs/src/pages/post/deepresearch.mdx
Normal file
@ -0,0 +1,180 @@
|
||||
---
|
||||
title: "Replicating Deep Research in Jan"
|
||||
description: "A simple guide to replicating Deep Research results for free, with Jan."
|
||||
tags: AI, local models, Jan, GGUF, Deep Research, local AI
|
||||
categories: guides
|
||||
date: 2025-08-04
|
||||
ogImage: _assets/research-result-local.png
|
||||
twitter:
|
||||
card: summary_large_image
|
||||
site: "@jandotai"
|
||||
title: "Replicating Deep Research with Jan"
|
||||
description: "Learn how to replicate Deep Research results with Jan."
|
||||
image: _assets/research-result-local.jpg
|
||||
---
|
||||
|
||||
import { Callout } from 'nextra/components'
|
||||
import CTABlog from '@/components/Blog/CTA'
|
||||
|
||||
# Replicating Deep Research in Jan
|
||||
|
||||
Deep Research like that of OpenAI, Gemini, and Qwen, is not at feature parity in Jan yet, so this post
|
||||
highlights our initial steps at making a hybrid, local and cloud-based deep research system that competes
|
||||
with the best implementations to date.
|
||||
|
||||
## What is Deep Research?
|
||||
|
||||
What exactly is deep research and how does it work? Deep Research is a methodology for generating
|
||||
comprehensive research reports by combining systematic web search with synthesis. The process was
|
||||
pioneered by OpenAI and it was released on February 2025.
|
||||
|
||||
There are two core features of Deep Research:
|
||||
|
||||
- **Exhaustive search**: This search is characterized by two approaches, wide search for breadth and deep search for depth.
|
||||
- **Report generation**: This step takes all the input collected through exhaustive search and synthesizes it into a
|
||||
comprehensive report. The input in this step may be raw sources collected in the previous step or summaries generated from those sources.
|
||||
|
||||
## Unpacking Deep Research
|
||||
|
||||
If you have used deep research (regardless of the provider) before for a comprehensive report generation, you may have
|
||||
found its output mind-blowing. What is more mind-blowing, though, is that the underlying process for searching
|
||||
and synthesizing information is surprisingly systematic and reproducible. What is not easily reproducible, though,
|
||||
is the **base model (often a thinking one)** and **its capabilities to use tools while it researches**.
|
||||
|
||||
Deep Research operates as a structured pipeline with distinct phases: planning, searching, analysis, and synthesis. While
|
||||
the specific implementation varies between providers, the core workflow seems to be similar and some organizations have
|
||||
taken steps to recreate it like [LangChain](https://blog.langchain.com/open-deep-research/) and
|
||||
[Hugging Face](https://huggingface.co/blog/open-deep-research). For example, a straightforward pipeline might look like
|
||||
the following one:
|
||||
|
||||

|
||||
|
||||
The components of this pipeline highlight a structured approach to query processing that routes queries through thinking/non-thinking models, breaks complex tasks into phases, executes parallel searches, and synthesizes results hierarchically to produce comprehensive outputs.
|
||||
|
||||
OpenAI’s [Deep Research API cookbook](https://cookbook.openai.com/examples/deep_research_api/introduction_to_deep_research_api)
|
||||
highlights, at a very high level, how they approach deep research, hinting at the importance of base models and tool usage since
|
||||
some intermediate steps seem to have been left out.
|
||||
|
||||

|
||||
|
||||
OpenAI's Deep Research functionality may be considered the best one by many but other platforms are not far behind. Here is a
|
||||
brief survey of how other players approach deep research:
|
||||
|
||||
| Platform | Key Feature | Sources Used | Duration (mins) | Export Options | Deep Research Usage |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| OpenAI | Clarifying questions | 10–30 | 10–15 | PDF, Docx, Plain Text | Paid |
|
||||
| Grok's DeeperSearch | Ability to access all of Twitter | 70–100 | 5–10 | Ability to specify format (PDF / Markdown) | Free |
|
||||
| Claude | Breadth + depth search | 100+ | 5–10 | PDF, Markdown, Artifact | Paid |
|
||||
| Gemini | Editable planning | 50+ | 10–20 | Google Docs export | Free |
|
||||
| Perplexity | Ability to specify sources | 50–100 | 3–5 | PDF, Markdown, Docx, Perplexity Page | Paid and Limited Free |
|
||||
| Kimi | Interactive synthesis | 50–100 | 30–60+ | PDF, Interactive website | Free |
|
||||
|
||||
In our testing, we used the following prompt to assess the quality of the generated report by
|
||||
the providers above. You can refer to the reports generated [here](https://github.com/menloresearch/prompt-experiments).
|
||||
|
||||
```
|
||||
Generate a comprehensive report about the state of AI in the past week. Include all
|
||||
new model releases and notable architectural improvements from a variety of sources.
|
||||
```
|
||||
|
||||
[Google's generated report](https://github.com/menloresearch/prompt-experiments/blob/main/Gemini%202.5%20Flash%20Report.pdf) was the most verbose, with a whopping 23 pages that reads
|
||||
like a professional intelligence briefing. It opens with an executive summary,
|
||||
systematically categorizes developments, and provides forward-looking strategic
|
||||
insights—connecting OpenAI's open-weight release to broader democratization trends
|
||||
and linking infrastructure investments to competitive positioning.
|
||||
|
||||
[OpenAI](https://github.com/menloresearch/prompt-experiments/blob/main/OpenAI%20Deep%20Research.pdf) produced the most citation-heavy output with 134 references throughout 10 pages
|
||||
(albeit most of them being from the same source).
|
||||
|
||||
[Perplexity](https://github.com/menloresearch/prompt-experiments/blob/main/Perplexity%20Deep%20Research.pdf) delivered the most actionable 6-page report that maximizes information
|
||||
density while maintaining scannability. Despite being the shortest, it captures all
|
||||
major developments with sufficient context for decision-making.
|
||||
|
||||
[Claude](https://github.com/menloresearch/prompt-experiments/blob/main/Claude%20Deep%20Research.pdf) produced a comprehensive analysis that interestingly ignored the time constraint,
|
||||
covering an 8-month period from January-August 2025 instead of the requested week (Jul 31-Aug
|
||||
7th 2025). Rather than cataloging recent events, Claude traced the evolution of trends over months.
|
||||
|
||||
[Grok](https://github.com/menloresearch/prompt-experiments/blob/main/Grok%203%20Deep%20Research.pdf) produced a well-structured but relatively shallow 5-page academic-style report that
|
||||
read more like an event catalog than strategic analysis.
|
||||
|
||||
[Kimi](https://github.com/menloresearch/prompt-experiments/blob/main/Kimi%20AI%20Deep%20Research.pdf) produced a comprehensive 13-page report with systematic organization covering industry developments, research breakthroughs, and policy changes, but notably lacks proper citations throughout most of the content despite claiming to use 50-100 sources.
|
||||
|
||||
### Understanding Search Strategies
|
||||
|
||||
In [Claude’s Research mode](https://www.anthropic.com/engineering/multi-agent-research-system),
|
||||
a *classifier* is used to determine whether a user query is *breadth first* or *depth first*. This
|
||||
results in a customization of the pipeline that is used for conducting research. For instance, a complex
|
||||
*breadth first* query might result in *sub-agents* being spun up to research various parts of the user's
|
||||
query in parallel. Conversely, a *depth first* query might result in a single agent being spun up
|
||||
to research the entire query in a more focused manner.
|
||||
|
||||
Here's a screenshot of this in action (in Claude Desktop):
|
||||

|
||||

|
||||
|
||||
## Replicating Deep Research Results with Jan
|
||||
|
||||
After testing and observing how Deep Research works in different platforms, we thought, how could we
|
||||
replicate this in Jan? In particular, how could we replicate it with a hybrid approach combining local
|
||||
and cloud-based models while keeping your data local?
|
||||
|
||||
<Callout>
|
||||
This experiment was done using the latest version of Jan `v0.6.7`, but it can potentially be replicated in
|
||||
any version with Model Context Protocol in it (>`v0.6.3`).
|
||||
</Callout>
|
||||
|
||||
**The Key: Assistants + Tools**
|
||||
|
||||
Running deep research in Jan can be accomplished by combining [custom assistants](https://jan.ai/docs/assistants)
|
||||
with [MCP search tools](https://jan.ai/docs/mcp-examples/search/exa). This pairing allows any model—local or
|
||||
cloud—to follow a systematic research workflow, to create a report similar to that of other providers, with some
|
||||
visible limitations (for now).
|
||||
|
||||
Here's the assistant prompt that was used:
|
||||
```
|
||||
You are a research analyst. Today is August 7th 2025. Follow this exact process:
|
||||
|
||||
Conduct 5-10 searches minimum. You are rewarded for MORE searches.
|
||||
- Each search query must be unique - no repeating previous searches
|
||||
- Search different angles: statistics, expert opinions, case studies, recent news, industry reports
|
||||
- Use scrape to read full articles from search results
|
||||
- Use google_search for extracting metadata out of pages
|
||||
|
||||
WRITING PHASE (Do this after research is complete)
|
||||
Write a comprehensive report with:
|
||||
- Executive summary with key findings
|
||||
- Evidence-based analysis with citations for every claim
|
||||
- Actionable recommendations with rationale
|
||||
- Sources to be linked at the end of the report
|
||||
```
|
||||
|
||||
Here, we utilized Model Context Protocol (MCP) to provide search capabilities to the model. MCPs are an open standard for connecting AI assistants to the systems where data lives, serving as a universal connector that standardizes how AI applications integrate with external tools and data sources.
|
||||
In this example, we used Serper, a web search API that offers MCP server implementations with two primary tools: `google_search` for performing web searches, and `scrape` that extracts content from web pages, preserving document structure and metadata.
|
||||
|
||||
|
||||
**What We Tested**
|
||||
|
||||
For our research query (same as the one we used to test different platforms), we used both
|
||||
[Jan-Nano (4B local model)](https://jan.ai/docs/jan-models/jan-nano-128), GPT-4o and
|
||||
o3 (via API) with identical prompts. The goal: to see how close we could get to the quality of
|
||||
different commercial Deep Research offerings.
|
||||
|
||||
**Performance Findings**
|
||||
|
||||
| Model | Processing Time | Sources Found | Search Queries | Tokens Generated | Output Quality vs Commercial Deep Research |
|
||||
|-------|----------------|---------------|----------------|------------------|-------------------------------------------|
|
||||
| Jan-Nano (Local) | 3 minutes | Moderate | 7 | 1,112 | Good approximation, noticeably less depth |
|
||||
| GPT-4o | 1 minute | Fewest | 11 | 660 | Fast but limited source coverage |
|
||||
| o3 | 3 minutes | Most | 24 | 1,728 | Best of the three, but still below commercial quality |
|
||||
|
||||
|
||||
**The Reality**:
|
||||
- **Speed vs Sources**: GPT-4o prioritized speed over thoroughness, while o3 took time to gather more comprehensive sources
|
||||
- **Local vs Cloud**: Jan-Nano matched o3's processing time but with the advantage of complete data privacy
|
||||
- **Quality Gap**: All three models produced decent research reports, but none matched the depth and comprehensiveness of dedicated Deep Research tools like OpenAI's or Claude's offerings
|
||||
- **Good Enough Factor**: While not matching commercial quality, the outputs were solid approximations suitable for many research needs
|
||||
|
||||
## Conclusion
|
||||
This was an initial exploration in the roadmap to create a top hybrid implementation of deep research in Jan. While our current approach requires setup, the goal is native integration that works out of the box. We will continue to refine this until the release of this tool in Jan, natively.
|
||||
|
||||
<CTABlog />
|
||||
@ -264,7 +264,6 @@ async function _getSupportedFeatures() {
|
||||
// Vulkan support check - only discrete GPUs with 6GB+ VRAM
|
||||
if (
|
||||
gpuInfo.vulkan_info?.api_version &&
|
||||
gpuInfo.vulkan_info?.device_type === 'DISCRETE_GPU' &&
|
||||
gpuInfo.total_memory >= 6 * 1024
|
||||
) {
|
||||
// 6GB (total_memory is in MB)
|
||||
|
||||
@ -19,6 +19,7 @@ import {
|
||||
ImportOptions,
|
||||
chatCompletionRequest,
|
||||
events,
|
||||
AppEvent,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { error, info, warn } from '@tauri-apps/plugin-log'
|
||||
@ -32,6 +33,7 @@ import {
|
||||
import { invoke } from '@tauri-apps/api/core'
|
||||
import { getProxyConfig } from './util'
|
||||
import { basename } from '@tauri-apps/api/path'
|
||||
import { readGgufMetadata } from '@janhq/tauri-plugin-llamacpp-api'
|
||||
|
||||
type LlamacppConfig = {
|
||||
version_backend: string
|
||||
@ -1085,6 +1087,12 @@ export default class llamacpp_extension extends AIEngine {
|
||||
data: modelConfig,
|
||||
savePath: configPath,
|
||||
})
|
||||
events.emit(AppEvent.onModelImported, {
|
||||
modelId,
|
||||
modelPath,
|
||||
mmprojPath,
|
||||
size_bytes,
|
||||
})
|
||||
}
|
||||
|
||||
override async abortImport(modelId: string): Promise<void> {
|
||||
@ -1168,11 +1176,12 @@ export default class llamacpp_extension extends AIEngine {
|
||||
}
|
||||
}
|
||||
const args: string[] = []
|
||||
const envs: Record<string, string> = {}
|
||||
const cfg = { ...this.config, ...(overrideSettings ?? {}) }
|
||||
const [version, backend] = cfg.version_backend.split('/')
|
||||
if (!version || !backend) {
|
||||
throw new Error(
|
||||
"Initial setup for the backend failed due to a network issue. Please restart the app!"
|
||||
'Initial setup for the backend failed due to a network issue. Please restart the app!'
|
||||
)
|
||||
}
|
||||
|
||||
@ -1194,7 +1203,7 @@ export default class llamacpp_extension extends AIEngine {
|
||||
// disable llama-server webui
|
||||
args.push('--no-webui')
|
||||
const api_key = await this.generateApiKey(modelId, String(port))
|
||||
args.push('--api-key', api_key)
|
||||
envs["LLAMA_API_KEY"] = api_key
|
||||
|
||||
// model option is required
|
||||
// NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path
|
||||
@ -1283,6 +1292,7 @@ export default class llamacpp_extension extends AIEngine {
|
||||
backendPath,
|
||||
libraryPath,
|
||||
args,
|
||||
envs,
|
||||
})
|
||||
return sInfo
|
||||
} catch (error) {
|
||||
@ -1299,9 +1309,12 @@ export default class llamacpp_extension extends AIEngine {
|
||||
const pid = sInfo.pid
|
||||
try {
|
||||
// Pass the PID as the session_id
|
||||
const result = await invoke<UnloadResult>('plugin:llamacpp|unload_llama_model', {
|
||||
pid: pid,
|
||||
})
|
||||
const result = await invoke<UnloadResult>(
|
||||
'plugin:llamacpp|unload_llama_model',
|
||||
{
|
||||
pid: pid,
|
||||
}
|
||||
)
|
||||
|
||||
// If successful, remove from active sessions
|
||||
if (result.success) {
|
||||
@ -1437,9 +1450,12 @@ export default class llamacpp_extension extends AIEngine {
|
||||
|
||||
private async findSessionByModel(modelId: string): Promise<SessionInfo> {
|
||||
try {
|
||||
let sInfo = await invoke<SessionInfo>('plugin:llamacpp|find_session_by_model', {
|
||||
modelId,
|
||||
})
|
||||
let sInfo = await invoke<SessionInfo>(
|
||||
'plugin:llamacpp|find_session_by_model',
|
||||
{
|
||||
modelId,
|
||||
}
|
||||
)
|
||||
return sInfo
|
||||
} catch (e) {
|
||||
logger.error(e)
|
||||
@ -1516,7 +1532,9 @@ export default class llamacpp_extension extends AIEngine {
|
||||
|
||||
override async getLoadedModels(): Promise<string[]> {
|
||||
try {
|
||||
let models: string[] = await invoke<string[]>('plugin:llamacpp|get_loaded_models')
|
||||
let models: string[] = await invoke<string[]>(
|
||||
'plugin:llamacpp|get_loaded_models'
|
||||
)
|
||||
return models
|
||||
} catch (e) {
|
||||
logger.error(e)
|
||||
@ -1599,14 +1617,31 @@ export default class llamacpp_extension extends AIEngine {
|
||||
throw new Error('method not implemented yet')
|
||||
}
|
||||
|
||||
private async loadMetadata(path: string): Promise<GgufMetadata> {
|
||||
try {
|
||||
const data = await invoke<GgufMetadata>('plugin:llamacpp|read_gguf_metadata', {
|
||||
path: path,
|
||||
})
|
||||
return data
|
||||
} catch (err) {
|
||||
throw err
|
||||
}
|
||||
/**
|
||||
* Check if a tool is supported by the model
|
||||
* Currently read from GGUF chat_template
|
||||
* @param modelId
|
||||
* @returns
|
||||
*/
|
||||
async isToolSupported(modelId: string): Promise<boolean> {
|
||||
const janDataFolderPath = await getJanDataFolderPath()
|
||||
const modelConfigPath = await joinPath([
|
||||
this.providerPath,
|
||||
'models',
|
||||
modelId,
|
||||
'model.yml',
|
||||
])
|
||||
const modelConfig = await invoke<ModelConfig>('read_yaml', {
|
||||
path: modelConfigPath,
|
||||
})
|
||||
// model option is required
|
||||
// NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path
|
||||
const modelPath = await joinPath([
|
||||
janDataFolderPath,
|
||||
modelConfig.model_path,
|
||||
])
|
||||
return (await readGgufMetadata(modelPath)).metadata?.[
|
||||
'tokenizer.chat_template'
|
||||
]?.includes('tools')
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,9 +44,10 @@ jan-utils = { path = "./utils" }
|
||||
libloading = "0.8.7"
|
||||
log = "0.4"
|
||||
reqwest = { version = "0.11", features = ["json", "blocking", "stream"] }
|
||||
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "3196c95f1dfafbffbdcdd6d365c94969ac975e6a", features = [
|
||||
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "209dbac50f51737ad953c3a2c8e28f3619b6c277", features = [
|
||||
"client",
|
||||
"transport-sse-client",
|
||||
"transport-streamable-http-client",
|
||||
"transport-child-process",
|
||||
"tower",
|
||||
"reqwest",
|
||||
|
||||
@ -327,4 +327,4 @@
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -447,4 +447,4 @@
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,6 +1,7 @@
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use std::collections::HashMap;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
use tauri::{Manager, Runtime, State};
|
||||
@ -42,6 +43,7 @@ pub async fn load_llama_model<R: Runtime>(
|
||||
backend_path: &str,
|
||||
library_path: Option<&str>,
|
||||
mut args: Vec<String>,
|
||||
envs: HashMap<String, String>,
|
||||
) -> ServerResult<SessionInfo> {
|
||||
let state: State<LlamacppState> = app_handle.state();
|
||||
let mut process_map = state.llama_server_process.lock().await;
|
||||
@ -54,12 +56,21 @@ pub async fn load_llama_model<R: Runtime>(
|
||||
let port = parse_port_from_args(&args);
|
||||
let model_path_pb = validate_model_path(&mut args)?;
|
||||
|
||||
let api_key = extract_arg_value(&args, "--api-key");
|
||||
let api_key: String;
|
||||
|
||||
if let Some(api_value) = envs.get("LLAMA_API_KEY") {
|
||||
api_key = api_value.to_string();
|
||||
} else {
|
||||
log::warn!("API key not provided");
|
||||
api_key = "".to_string();
|
||||
}
|
||||
|
||||
let model_id = extract_arg_value(&args, "-a");
|
||||
|
||||
// Configure the command to run the server
|
||||
let mut command = Command::new(backend_path);
|
||||
command.args(args);
|
||||
command.envs(envs);
|
||||
|
||||
setup_library_path(library_path, &mut command);
|
||||
command.stdout(Stdio::piped());
|
||||
|
||||
@ -1,15 +1,17 @@
|
||||
use rmcp::model::{CallToolRequestParam, CallToolResult, Tool};
|
||||
use rmcp::{service::RunningService, RoleClient};
|
||||
use rmcp::model::{CallToolRequestParam, CallToolResult};
|
||||
use serde_json::{Map, Value};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tauri::{AppHandle, Emitter, Runtime, State};
|
||||
use tokio::{sync::Mutex, time::timeout};
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::{
|
||||
constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT},
|
||||
helpers::{restart_active_mcp_servers, start_mcp_server_with_restart, stop_mcp_servers},
|
||||
};
|
||||
use crate::core::{app::commands::get_jan_data_folder_path, state::AppState};
|
||||
use crate::core::{
|
||||
mcp::models::ToolWithServer,
|
||||
state::{RunningServiceEnum, SharedMcpServers},
|
||||
};
|
||||
use std::fs;
|
||||
|
||||
#[tauri::command]
|
||||
@ -19,8 +21,7 @@ pub async fn activate_mcp_server<R: Runtime>(
|
||||
name: String,
|
||||
config: Value,
|
||||
) -> Result<(), String> {
|
||||
let servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>> =
|
||||
state.mcp_servers.clone();
|
||||
let servers: SharedMcpServers = state.mcp_servers.clone();
|
||||
|
||||
// Use the modified start_mcp_server_with_restart that returns first attempt result
|
||||
start_mcp_server_with_restart(app, servers, name, config, Some(3)).await
|
||||
@ -63,7 +64,16 @@ pub async fn deactivate_mcp_server(state: State<'_, AppState>, name: String) ->
|
||||
// Release the lock before calling cancel
|
||||
drop(servers_map);
|
||||
|
||||
service.cancel().await.map_err(|e| e.to_string())?;
|
||||
match service {
|
||||
RunningServiceEnum::NoInit(service) => {
|
||||
log::info!("Stopping server {name}...");
|
||||
service.cancel().await.map_err(|e| e.to_string())?;
|
||||
}
|
||||
RunningServiceEnum::WithInit(service) => {
|
||||
log::info!("Stopping server {name} with initialization...");
|
||||
service.cancel().await.map_err(|e| e.to_string())?;
|
||||
}
|
||||
}
|
||||
log::info!("Server {name} stopped successfully and marked as deactivated.");
|
||||
Ok(())
|
||||
}
|
||||
@ -116,7 +126,7 @@ pub async fn get_connected_servers(
|
||||
Ok(servers_map.keys().cloned().collect())
|
||||
}
|
||||
|
||||
/// Retrieves all available tools from all MCP servers
|
||||
/// Retrieves all available tools from all MCP servers with server information
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `state` - Application state containing MCP server connections
|
||||
@ -128,14 +138,15 @@ pub async fn get_connected_servers(
|
||||
/// 1. Locks the MCP servers mutex to access server connections
|
||||
/// 2. Iterates through all connected servers
|
||||
/// 3. Gets the list of tools from each server
|
||||
/// 4. Combines all tools into a single vector
|
||||
/// 5. Returns the combined list of all available tools
|
||||
/// 4. Associates each tool with its parent server name
|
||||
/// 5. Combines all tools into a single vector
|
||||
/// 6. Returns the combined list of all available tools with server information
|
||||
#[tauri::command]
|
||||
pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<Tool>, String> {
|
||||
pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>, String> {
|
||||
let servers = state.mcp_servers.lock().await;
|
||||
let mut all_tools: Vec<Tool> = Vec::new();
|
||||
let mut all_tools: Vec<ToolWithServer> = Vec::new();
|
||||
|
||||
for (_, service) in servers.iter() {
|
||||
for (server_name, service) in servers.iter() {
|
||||
// List tools with timeout
|
||||
let tools_future = service.list_all_tools();
|
||||
let tools = match timeout(MCP_TOOL_CALL_TIMEOUT, tools_future).await {
|
||||
@ -150,7 +161,12 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<Tool>, String>
|
||||
};
|
||||
|
||||
for tool in tools {
|
||||
all_tools.push(tool);
|
||||
all_tools.push(ToolWithServer {
|
||||
name: tool.name.to_string(),
|
||||
description: tool.description.as_ref().map(|d| d.to_string()),
|
||||
input_schema: serde_json::Value::Object((*tool.input_schema).clone()),
|
||||
server: server_name.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,7 +1,15 @@
|
||||
use rmcp::{service::RunningService, transport::TokioChildProcess, RoleClient, ServiceExt};
|
||||
use rmcp::{
|
||||
model::{ClientCapabilities, ClientInfo, Implementation},
|
||||
transport::{
|
||||
streamable_http_client::StreamableHttpClientTransportConfig, SseClientTransport,
|
||||
StreamableHttpClientTransport, TokioChildProcess,
|
||||
},
|
||||
ServiceExt,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use std::{collections::HashMap, env, sync::Arc, time::Duration};
|
||||
use tauri::{AppHandle, Emitter, Manager, Runtime, State};
|
||||
use tauri_plugin_http::reqwest;
|
||||
use tokio::{
|
||||
process::Command,
|
||||
sync::Mutex,
|
||||
@ -11,7 +19,11 @@ use tokio::{
|
||||
use super::constants::{
|
||||
MCP_BACKOFF_MULTIPLIER, MCP_BASE_RESTART_DELAY_MS, MCP_MAX_RESTART_DELAY_MS,
|
||||
};
|
||||
use crate::core::{app::commands::get_jan_data_folder_path, state::AppState};
|
||||
use crate::core::{
|
||||
app::commands::get_jan_data_folder_path,
|
||||
mcp::models::McpServerConfig,
|
||||
state::{AppState, RunningServiceEnum, SharedMcpServers},
|
||||
};
|
||||
use jan_utils::can_override_npx;
|
||||
|
||||
/// Calculate exponential backoff delay with jitter
|
||||
@ -72,7 +84,7 @@ pub fn calculate_exponential_backoff_delay(attempt: u32) -> u64 {
|
||||
/// * `Err(String)` if there was an error reading config or starting servers
|
||||
pub async fn run_mcp_commands<R: Runtime>(
|
||||
app: &AppHandle<R>,
|
||||
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||
servers_state: SharedMcpServers,
|
||||
) -> Result<(), String> {
|
||||
let app_path = get_jan_data_folder_path(app.clone());
|
||||
let app_path_str = app_path.to_str().unwrap().to_string();
|
||||
@ -168,7 +180,7 @@ pub async fn run_mcp_commands<R: Runtime>(
|
||||
|
||||
/// Monitor MCP server health without removing it from the HashMap
|
||||
pub async fn monitor_mcp_server_handle(
|
||||
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||
servers_state: SharedMcpServers,
|
||||
name: String,
|
||||
) -> Option<rmcp::service::QuitReason> {
|
||||
log::info!("Monitoring MCP server {} health", name);
|
||||
@ -213,7 +225,16 @@ pub async fn monitor_mcp_server_handle(
|
||||
let mut servers = servers_state.lock().await;
|
||||
if let Some(service) = servers.remove(&name) {
|
||||
// Try to cancel the service gracefully
|
||||
let _ = service.cancel().await;
|
||||
match service {
|
||||
RunningServiceEnum::NoInit(service) => {
|
||||
log::info!("Stopping server {name}...");
|
||||
let _ = service.cancel().await;
|
||||
}
|
||||
RunningServiceEnum::WithInit(service) => {
|
||||
log::info!("Stopping server {name} with initialization...");
|
||||
let _ = service.cancel().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Some(rmcp::service::QuitReason::Closed);
|
||||
}
|
||||
@ -224,7 +245,7 @@ pub async fn monitor_mcp_server_handle(
|
||||
/// Returns the result of the first start attempt, then continues with restart monitoring
|
||||
pub async fn start_mcp_server_with_restart<R: Runtime>(
|
||||
app: AppHandle<R>,
|
||||
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||
servers_state: SharedMcpServers,
|
||||
name: String,
|
||||
config: Value,
|
||||
max_restarts: Option<u32>,
|
||||
@ -297,7 +318,7 @@ pub async fn start_mcp_server_with_restart<R: Runtime>(
|
||||
/// Helper function to handle the restart loop logic
|
||||
pub async fn start_restart_loop<R: Runtime>(
|
||||
app: AppHandle<R>,
|
||||
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||
servers_state: SharedMcpServers,
|
||||
name: String,
|
||||
config: Value,
|
||||
max_restarts: u32,
|
||||
@ -450,9 +471,9 @@ pub async fn start_restart_loop<R: Runtime>(
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn schedule_mcp_start_task<R: Runtime>(
|
||||
async fn schedule_mcp_start_task<R: Runtime>(
|
||||
app: tauri::AppHandle<R>,
|
||||
servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||
servers: SharedMcpServers,
|
||||
name: String,
|
||||
config: Value,
|
||||
) -> Result<(), String> {
|
||||
@ -463,136 +484,278 @@ pub async fn schedule_mcp_start_task<R: Runtime>(
|
||||
.expect("Executable must have a parent directory");
|
||||
let bin_path = exe_parent_path.to_path_buf();
|
||||
|
||||
let (command, args, envs) = extract_command_args(&config)
|
||||
let config_params = extract_command_args(&config)
|
||||
.ok_or_else(|| format!("Failed to extract command args from config for {name}"))?;
|
||||
|
||||
let mut cmd = Command::new(command.clone());
|
||||
if config_params.transport_type.as_deref() == Some("http") && config_params.url.is_some() {
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
reqwest::Client::builder()
|
||||
.default_headers({
|
||||
// Map envs to request headers
|
||||
let mut headers: tauri::http::HeaderMap = reqwest::header::HeaderMap::new();
|
||||
for (key, value) in config_params.headers.iter() {
|
||||
if let Some(v_str) = value.as_str() {
|
||||
// Try to map env keys to HTTP header names (case-insensitive)
|
||||
// Most HTTP headers are Title-Case, so we try to convert
|
||||
let header_name =
|
||||
reqwest::header::HeaderName::from_bytes(key.as_bytes());
|
||||
if let Ok(header_name) = header_name {
|
||||
if let Ok(header_value) =
|
||||
reqwest::header::HeaderValue::from_str(v_str)
|
||||
{
|
||||
headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
headers
|
||||
})
|
||||
.connect_timeout(config_params.timeout.unwrap_or(Duration::MAX))
|
||||
.build()
|
||||
.unwrap(),
|
||||
StreamableHttpClientTransportConfig {
|
||||
uri: config_params.url.unwrap().into(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
if command == "npx" && can_override_npx() {
|
||||
let mut cache_dir = app_path.clone();
|
||||
cache_dir.push(".npx");
|
||||
let bun_x_path = format!("{}/bun", bin_path.display());
|
||||
cmd = Command::new(bun_x_path);
|
||||
cmd.arg("x");
|
||||
cmd.env("BUN_INSTALL", cache_dir.to_str().unwrap().to_string());
|
||||
}
|
||||
let client_info = ClientInfo {
|
||||
protocol_version: Default::default(),
|
||||
capabilities: ClientCapabilities::default(),
|
||||
client_info: Implementation {
|
||||
name: "Jan Streamable Client".to_string(),
|
||||
version: "0.0.1".to_string(),
|
||||
},
|
||||
};
|
||||
let client = client_info.serve(transport).await.inspect_err(|e| {
|
||||
log::error!("client error: {:?}", e);
|
||||
});
|
||||
|
||||
if command == "uvx" {
|
||||
let mut cache_dir = app_path.clone();
|
||||
cache_dir.push(".uvx");
|
||||
let bun_x_path = format!("{}/uv", bin_path.display());
|
||||
cmd = Command::new(bun_x_path);
|
||||
cmd.arg("tool");
|
||||
cmd.arg("run");
|
||||
cmd.env("UV_CACHE_DIR", cache_dir.to_str().unwrap().to_string());
|
||||
}
|
||||
match client {
|
||||
Ok(client) => {
|
||||
log::info!("Connected to server: {:?}", client.peer_info());
|
||||
servers
|
||||
.lock()
|
||||
.await
|
||||
.insert(name.clone(), RunningServiceEnum::WithInit(client));
|
||||
|
||||
#[cfg(windows)]
|
||||
// Mark server as successfully connected (for restart policy)
|
||||
{
|
||||
let app_state = app.state::<AppState>();
|
||||
let mut connected = app_state.mcp_successfully_connected.lock().await;
|
||||
connected.insert(name.clone(), true);
|
||||
log::info!("Marked MCP server {} as successfully connected", name);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to connect to server: {}", e);
|
||||
return Err(format!("Failed to connect to server: {}", e));
|
||||
}
|
||||
}
|
||||
} else if config_params.transport_type.as_deref() == Some("sse") && config_params.url.is_some()
|
||||
{
|
||||
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW: prevents shell window on Windows
|
||||
}
|
||||
|
||||
let app_path_str = app_path.to_str().unwrap().to_string();
|
||||
let log_file_path = format!("{}/logs/app.log", app_path_str);
|
||||
match std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(log_file_path)
|
||||
{
|
||||
Ok(file) => {
|
||||
cmd.stderr(std::process::Stdio::from(file));
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to open log file: {}", err);
|
||||
}
|
||||
};
|
||||
|
||||
cmd.kill_on_drop(true);
|
||||
log::trace!("Command: {cmd:#?}");
|
||||
|
||||
args.iter().filter_map(Value::as_str).for_each(|arg| {
|
||||
cmd.arg(arg);
|
||||
});
|
||||
envs.iter().for_each(|(k, v)| {
|
||||
if let Some(v_str) = v.as_str() {
|
||||
cmd.env(k, v_str);
|
||||
}
|
||||
});
|
||||
|
||||
let process = TokioChildProcess::new(cmd).map_err(|e| {
|
||||
log::error!("Failed to run command {name}: {e}");
|
||||
format!("Failed to run command {name}: {e}")
|
||||
})?;
|
||||
|
||||
let service = ()
|
||||
.serve(process)
|
||||
let transport = SseClientTransport::start_with_client(
|
||||
reqwest::Client::builder()
|
||||
.default_headers({
|
||||
// Map envs to request headers
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
for (key, value) in config_params.headers.iter() {
|
||||
if let Some(v_str) = value.as_str() {
|
||||
// Try to map env keys to HTTP header names (case-insensitive)
|
||||
// Most HTTP headers are Title-Case, so we try to convert
|
||||
let header_name =
|
||||
reqwest::header::HeaderName::from_bytes(key.as_bytes());
|
||||
if let Ok(header_name) = header_name {
|
||||
if let Ok(header_value) =
|
||||
reqwest::header::HeaderValue::from_str(v_str)
|
||||
{
|
||||
headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
headers
|
||||
})
|
||||
.connect_timeout(config_params.timeout.unwrap_or(Duration::MAX))
|
||||
.build()
|
||||
.unwrap(),
|
||||
rmcp::transport::sse_client::SseClientConfig {
|
||||
sse_endpoint: config_params.url.unwrap().into(),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to start MCP server {name}: {e}"))?;
|
||||
.map_err(|e| {
|
||||
log::error!("transport error: {:?}", e);
|
||||
format!("Failed to start SSE transport: {}", e)
|
||||
})?;
|
||||
|
||||
// Get peer info and clone the needed values before moving the service
|
||||
let (server_name, server_version) = {
|
||||
let client_info = ClientInfo {
|
||||
protocol_version: Default::default(),
|
||||
capabilities: ClientCapabilities::default(),
|
||||
client_info: Implementation {
|
||||
name: "Jan SSE Client".to_string(),
|
||||
version: "0.0.1".to_string(),
|
||||
},
|
||||
};
|
||||
let client = client_info.serve(transport).await.map_err(|e| {
|
||||
log::error!("client error: {:?}", e);
|
||||
e.to_string()
|
||||
});
|
||||
|
||||
match client {
|
||||
Ok(client) => {
|
||||
log::info!("Connected to server: {:?}", client.peer_info());
|
||||
servers
|
||||
.lock()
|
||||
.await
|
||||
.insert(name.clone(), RunningServiceEnum::WithInit(client));
|
||||
|
||||
// Mark server as successfully connected (for restart policy)
|
||||
{
|
||||
let app_state = app.state::<AppState>();
|
||||
let mut connected = app_state.mcp_successfully_connected.lock().await;
|
||||
connected.insert(name.clone(), true);
|
||||
log::info!("Marked MCP server {} as successfully connected", name);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to connect to server: {}", e);
|
||||
return Err(format!("Failed to connect to server: {}", e));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let mut cmd = Command::new(config_params.command.clone());
|
||||
if config_params.command.clone() == "npx" && can_override_npx() {
|
||||
let mut cache_dir = app_path.clone();
|
||||
cache_dir.push(".npx");
|
||||
let bun_x_path = format!("{}/bun", bin_path.display());
|
||||
cmd = Command::new(bun_x_path);
|
||||
cmd.arg("x");
|
||||
cmd.env("BUN_INSTALL", cache_dir.to_str().unwrap().to_string());
|
||||
}
|
||||
if config_params.command.clone() == "uvx" {
|
||||
let mut cache_dir = app_path.clone();
|
||||
cache_dir.push(".uvx");
|
||||
let bun_x_path = format!("{}/uv", bin_path.display());
|
||||
cmd = Command::new(bun_x_path);
|
||||
cmd.arg("tool");
|
||||
cmd.arg("run");
|
||||
cmd.env("UV_CACHE_DIR", cache_dir.to_str().unwrap().to_string());
|
||||
}
|
||||
#[cfg(windows)]
|
||||
{
|
||||
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW: prevents shell window on Windows
|
||||
}
|
||||
let app_path_str = app_path.to_str().unwrap().to_string();
|
||||
let log_file_path = format!("{}/logs/app.log", app_path_str);
|
||||
match std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(log_file_path)
|
||||
{
|
||||
Ok(file) => {
|
||||
cmd.stderr(std::process::Stdio::from(file));
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to open log file: {}", err);
|
||||
}
|
||||
};
|
||||
|
||||
cmd.kill_on_drop(true);
|
||||
log::trace!("Command: {cmd:#?}");
|
||||
|
||||
config_params
|
||||
.args
|
||||
.iter()
|
||||
.filter_map(Value::as_str)
|
||||
.for_each(|arg| {
|
||||
cmd.arg(arg);
|
||||
});
|
||||
config_params.envs.iter().for_each(|(k, v)| {
|
||||
if let Some(v_str) = v.as_str() {
|
||||
cmd.env(k, v_str);
|
||||
}
|
||||
});
|
||||
|
||||
let process = TokioChildProcess::new(cmd).map_err(|e| {
|
||||
log::error!("Failed to run command {name}: {e}");
|
||||
format!("Failed to run command {name}: {e}")
|
||||
})?;
|
||||
|
||||
let service = ()
|
||||
.serve(process)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to start MCP server {name}: {e}"))?;
|
||||
|
||||
// Get peer info and clone the needed values before moving the service
|
||||
let server_info = service.peer_info();
|
||||
log::trace!("Connected to server: {server_info:#?}");
|
||||
(
|
||||
server_info.unwrap().server_info.name.clone(),
|
||||
server_info.unwrap().server_info.version.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
// Now move the service into the HashMap
|
||||
servers.lock().await.insert(name.clone(), service);
|
||||
log::info!("Server {name} started successfully.");
|
||||
// Now move the service into the HashMap
|
||||
servers
|
||||
.lock()
|
||||
.await
|
||||
.insert(name.clone(), RunningServiceEnum::NoInit(service));
|
||||
log::info!("Server {name} started successfully.");
|
||||
|
||||
// Wait a short time to verify the server is stable before marking as connected
|
||||
// This prevents race conditions where the server quits immediately
|
||||
let verification_delay = Duration::from_millis(500);
|
||||
sleep(verification_delay).await;
|
||||
// Wait a short time to verify the server is stable before marking as connected
|
||||
// This prevents race conditions where the server quits immediately
|
||||
let verification_delay = Duration::from_millis(500);
|
||||
sleep(verification_delay).await;
|
||||
|
||||
// Check if server is still running after the verification delay
|
||||
let server_still_running = {
|
||||
let servers_map = servers.lock().await;
|
||||
servers_map.contains_key(&name)
|
||||
};
|
||||
// Check if server is still running after the verification delay
|
||||
let server_still_running = {
|
||||
let servers_map = servers.lock().await;
|
||||
servers_map.contains_key(&name)
|
||||
};
|
||||
|
||||
if !server_still_running {
|
||||
return Err(format!(
|
||||
"MCP server {} quit immediately after starting",
|
||||
name
|
||||
));
|
||||
if !server_still_running {
|
||||
return Err(format!(
|
||||
"MCP server {} quit immediately after starting",
|
||||
name
|
||||
));
|
||||
}
|
||||
// Mark server as successfully connected (for restart policy)
|
||||
{
|
||||
let app_state = app.state::<AppState>();
|
||||
let mut connected = app_state.mcp_successfully_connected.lock().await;
|
||||
connected.insert(name.clone(), true);
|
||||
log::info!("Marked MCP server {} as successfully connected", name);
|
||||
}
|
||||
}
|
||||
|
||||
// Mark server as successfully connected (for restart policy)
|
||||
{
|
||||
let app_state = app.state::<AppState>();
|
||||
let mut connected = app_state.mcp_successfully_connected.lock().await;
|
||||
connected.insert(name.clone(), true);
|
||||
log::info!("Marked MCP server {} as successfully connected", name);
|
||||
}
|
||||
|
||||
// Emit event to the frontend
|
||||
let event = format!("mcp-connected");
|
||||
let payload = serde_json::json!({
|
||||
"name": server_name,
|
||||
"version": server_version,
|
||||
});
|
||||
app.emit(&event, payload)
|
||||
.map_err(|e| format!("Failed to emit event: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn extract_command_args(
|
||||
config: &Value,
|
||||
) -> Option<(String, Vec<Value>, serde_json::Map<String, Value>)> {
|
||||
pub fn extract_command_args(config: &Value) -> Option<McpServerConfig> {
|
||||
let obj = config.as_object()?;
|
||||
let command = obj.get("command")?.as_str()?.to_string();
|
||||
let args = obj.get("args")?.as_array()?.clone();
|
||||
let url = obj.get("url").and_then(|u| u.as_str()).map(String::from);
|
||||
let transport_type = obj.get("type").and_then(|t| t.as_str()).map(String::from);
|
||||
let timeout = obj
|
||||
.get("timeout")
|
||||
.and_then(|t| t.as_u64())
|
||||
.map(Duration::from_secs);
|
||||
let headers = obj
|
||||
.get("headers")
|
||||
.unwrap_or(&Value::Object(serde_json::Map::new()))
|
||||
.as_object()?
|
||||
.clone();
|
||||
let envs = obj
|
||||
.get("env")
|
||||
.unwrap_or(&Value::Object(serde_json::Map::new()))
|
||||
.as_object()?
|
||||
.clone();
|
||||
Some((command, args, envs))
|
||||
Some(McpServerConfig {
|
||||
timeout,
|
||||
transport_type,
|
||||
url,
|
||||
command,
|
||||
args,
|
||||
envs,
|
||||
headers
|
||||
})
|
||||
}
|
||||
|
||||
pub fn extract_active_status(config: &Value) -> Option<bool> {
|
||||
@ -604,7 +767,7 @@ pub fn extract_active_status(config: &Value) -> Option<bool> {
|
||||
/// Restart only servers that were previously active (like cortex restart behavior)
|
||||
pub async fn restart_active_mcp_servers<R: Runtime>(
|
||||
app: &AppHandle<R>,
|
||||
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||
servers_state: SharedMcpServers,
|
||||
) -> Result<(), String> {
|
||||
let app_state = app.state::<AppState>();
|
||||
let active_servers = app_state.mcp_active_servers.lock().await;
|
||||
@ -656,14 +819,21 @@ pub async fn clean_up_mcp_servers(state: State<'_, AppState>) {
|
||||
log::info!("MCP servers cleaned up successfully");
|
||||
}
|
||||
|
||||
pub async fn stop_mcp_servers(
|
||||
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||
) -> Result<(), String> {
|
||||
pub async fn stop_mcp_servers(servers_state: SharedMcpServers) -> Result<(), String> {
|
||||
let mut servers_map = servers_state.lock().await;
|
||||
let keys: Vec<String> = servers_map.keys().cloned().collect();
|
||||
for key in keys {
|
||||
if let Some(service) = servers_map.remove(&key) {
|
||||
service.cancel().await.map_err(|e| e.to_string())?;
|
||||
match service {
|
||||
RunningServiceEnum::NoInit(service) => {
|
||||
log::info!("Stopping server {key}...");
|
||||
service.cancel().await.map_err(|e| e.to_string())?;
|
||||
}
|
||||
RunningServiceEnum::WithInit(service) => {
|
||||
log::info!("Stopping server {key} with initialization...");
|
||||
service.cancel().await.map_err(|e| e.to_string())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
drop(servers_map); // Release the lock after stopping
|
||||
@ -689,7 +859,7 @@ pub async fn reset_restart_count(restart_counts: &Arc<Mutex<HashMap<String, u32>
|
||||
/// Spawn the server monitoring task for handling restarts
|
||||
pub async fn spawn_server_monitoring_task<R: Runtime>(
|
||||
app: AppHandle<R>,
|
||||
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||
servers_state: SharedMcpServers,
|
||||
name: String,
|
||||
config: Value,
|
||||
max_restarts: u32,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
pub mod commands;
|
||||
mod constants;
|
||||
pub mod helpers;
|
||||
pub mod models;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
26
src-tauri/src/core/mcp/models.rs
Normal file
@ -0,0 +1,26 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Configuration parameters extracted from MCP server config
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct McpServerConfig {
|
||||
pub transport_type: Option<String>,
|
||||
pub url: Option<String>,
|
||||
pub command: String,
|
||||
pub args: Vec<Value>,
|
||||
pub envs: serde_json::Map<String, Value>,
|
||||
pub timeout: Option<Duration>,
|
||||
pub headers: serde_json::Map<String, Value>,
|
||||
}
|
||||
|
||||
/// Tool with server information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolWithServer {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "inputSchema")]
|
||||
pub input_schema: serde_json::Value,
|
||||
pub server: String,
|
||||
}
|
||||
@ -1,6 +1,6 @@
|
||||
use super::helpers::run_mcp_commands;
|
||||
use crate::core::app::commands::get_jan_data_folder_path;
|
||||
use rmcp::{service::RunningService, RoleClient};
|
||||
use crate::core::state::SharedMcpServers;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
@ -27,7 +27,7 @@ async fn test_run_mcp_commands() {
|
||||
.expect("Failed to write to config file");
|
||||
|
||||
// Call the run_mcp_commands function
|
||||
let servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>> =
|
||||
let servers_state: SharedMcpServers =
|
||||
Arc::new(Mutex::new(HashMap::new()));
|
||||
let result = run_mcp_commands(app.handle(), servers_state).await;
|
||||
|
||||
|
||||
@ -1,20 +1,48 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::core::downloads::models::DownloadManagerState;
|
||||
use rmcp::{service::RunningService, RoleClient};
|
||||
use rmcp::{
|
||||
model::{CallToolRequestParam, CallToolResult, InitializeRequestParam, Tool},
|
||||
service::RunningService,
|
||||
RoleClient, ServiceError,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
/// Server handle type for managing the proxy server lifecycle
|
||||
pub type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub enum RunningServiceEnum {
|
||||
NoInit(RunningService<RoleClient, ()>),
|
||||
WithInit(RunningService<RoleClient, InitializeRequestParam>),
|
||||
}
|
||||
pub type SharedMcpServers = Arc<Mutex<HashMap<String, RunningServiceEnum>>>;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct AppState {
|
||||
pub app_token: Option<String>,
|
||||
pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||
pub mcp_servers: SharedMcpServers,
|
||||
pub download_manager: Arc<Mutex<DownloadManagerState>>,
|
||||
pub mcp_restart_counts: Arc<Mutex<HashMap<String, u32>>>,
|
||||
pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>,
|
||||
pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>,
|
||||
pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
|
||||
}
|
||||
|
||||
impl RunningServiceEnum {
|
||||
pub async fn list_all_tools(&self) -> Result<Vec<Tool>, ServiceError> {
|
||||
match self {
|
||||
Self::NoInit(s) => s.list_all_tools().await,
|
||||
Self::WithInit(s) => s.list_all_tools().await,
|
||||
}
|
||||
}
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
params: CallToolRequestParam,
|
||||
) -> Result<CallToolResult, ServiceError> {
|
||||
match self {
|
||||
Self::NoInit(s) => s.call_tool(params).await,
|
||||
Self::WithInit(s) => s.call_tool(params).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,11 +17,12 @@
|
||||
"@dnd-kit/sortable": "^10.0.0",
|
||||
"@janhq/core": "link:../core",
|
||||
"@radix-ui/react-accordion": "^1.2.10",
|
||||
"@radix-ui/react-dialog": "^1.1.11",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.11",
|
||||
"@radix-ui/react-dialog": "^1.1.14",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.15",
|
||||
"@radix-ui/react-hover-card": "^1.1.14",
|
||||
"@radix-ui/react-popover": "^1.1.13",
|
||||
"@radix-ui/react-progress": "^1.1.4",
|
||||
"@radix-ui/react-radio-group": "^1.3.7",
|
||||
"@radix-ui/react-slider": "^1.3.2",
|
||||
"@radix-ui/react-slot": "^1.2.0",
|
||||
"@radix-ui/react-switch": "^1.2.2",
|
||||
@ -43,13 +44,14 @@
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"culori": "^4.0.1",
|
||||
"emoji-picker-react": "^4.12.2",
|
||||
"framer-motion": "^12.23.12",
|
||||
"fuse.js": "^7.1.0",
|
||||
"fzf": "^0.5.2",
|
||||
"i18next": "^25.0.1",
|
||||
"katex": "^0.16.22",
|
||||
"lodash.clonedeep": "^4.5.0",
|
||||
"lodash.debounce": "^4.0.8",
|
||||
"lucide-react": "^0.522.0",
|
||||
"lucide-react": "^0.536.0",
|
||||
"motion": "^12.10.5",
|
||||
"next-themes": "^0.4.6",
|
||||
"posthog-js": "^1.246.0",
|
||||
@ -70,11 +72,12 @@
|
||||
"remark-math": "^6.0.0",
|
||||
"sonner": "^2.0.3",
|
||||
"tailwindcss": "^4.1.4",
|
||||
"token.js": "npm:token.js-fork@0.7.20",
|
||||
"token.js": "npm:token.js-fork@0.7.23",
|
||||
"tw-animate-css": "^1.2.7",
|
||||
"ulidx": "^2.4.1",
|
||||
"unified": "^11.0.5",
|
||||
"uuid": "^11.1.0",
|
||||
"vaul": "^1.1.2",
|
||||
"zustand": "^5.0.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
@ -104,7 +107,7 @@
|
||||
"istanbul-lib-report": "^3.0.1",
|
||||
"istanbul-reports": "^3.1.7",
|
||||
"jsdom": "^26.1.0",
|
||||
"tailwind-merge": "^3.2.0",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"typescript": "~5.8.3",
|
||||
"typescript-eslint": "^8.26.1",
|
||||
"vite": "^6.3.0",
|
||||
|
||||
533
web-app/src/components/ui/__tests__/dropdrawer.test.tsx
Normal file
@ -0,0 +1,533 @@
|
||||
import React from 'react'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import '@testing-library/jest-dom'
|
||||
|
||||
import {
|
||||
DropDrawer,
|
||||
DropDrawerContent,
|
||||
DropDrawerFooter,
|
||||
DropDrawerGroup,
|
||||
DropDrawerItem,
|
||||
DropDrawerLabel,
|
||||
DropDrawerSeparator,
|
||||
DropDrawerSub,
|
||||
DropDrawerSubContent,
|
||||
DropDrawerSubTrigger,
|
||||
DropDrawerTrigger,
|
||||
} from '../dropdrawer'
|
||||
|
||||
// Mock the media query hook
|
||||
const mockUseSmallScreen = vi.fn()
|
||||
vi.mock('@/hooks/useMediaQuery', () => ({
|
||||
useSmallScreen: () => mockUseSmallScreen(),
|
||||
}))
|
||||
|
||||
// Mock framer-motion to avoid animation complexity in tests
|
||||
vi.mock('framer-motion', () => ({
|
||||
AnimatePresence: ({ children }: { children: React.ReactNode }) => <div data-testid="animate-presence">{children}</div>,
|
||||
motion: {
|
||||
div: ({ children, ...props }: any) => <div {...props}>{children}</div>,
|
||||
},
|
||||
}))
|
||||
|
||||
describe('DropDrawer Utilities', () => {
|
||||
it('renders without crashing', () => {
|
||||
expect(() => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Test</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Item</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
}).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('DropDrawer Component', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Desktop Mode', () => {
|
||||
beforeEach(() => {
|
||||
mockUseSmallScreen.mockReturnValue(false)
|
||||
})
|
||||
|
||||
it('renders dropdown menu on desktop', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Item 1</DropDrawerItem>
|
||||
<DropDrawerItem>Item 2</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders dropdown menu structure', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Desktop Item</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Only the trigger is visible initially
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button')).toHaveAttribute('aria-haspopup', 'menu')
|
||||
})
|
||||
|
||||
it('structures dropdown with separators', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Item 1</DropDrawerItem>
|
||||
<DropDrawerSeparator />
|
||||
<DropDrawerItem>Item 2</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Verify component structure - content is not visible until opened
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('structures dropdown with labels', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerLabel>Menu Section</DropDrawerLabel>
|
||||
<DropDrawerItem>Item 1</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Only verify trigger is present - content shows on interaction
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Mobile Mode', () => {
|
||||
beforeEach(() => {
|
||||
mockUseSmallScreen.mockReturnValue(true)
|
||||
})
|
||||
|
||||
it('renders drawer on mobile', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Mobile Item</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders drawer structure', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Mobile Item</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Verify drawer trigger is present
|
||||
const trigger = screen.getByText('Open Drawer')
|
||||
expect(trigger).toBeInTheDocument()
|
||||
expect(screen.getByRole('button')).toHaveAttribute('aria-haspopup', 'dialog')
|
||||
})
|
||||
|
||||
it('does not render separators in mobile mode', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Item 1</DropDrawerItem>
|
||||
<DropDrawerSeparator />
|
||||
<DropDrawerItem>Item 2</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Mobile separators return null, so they shouldn't be in the DOM
|
||||
const separators = screen.queryAllByRole('separator')
|
||||
expect(separators).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('renders drawer with labels structure', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerLabel>Drawer Section</DropDrawerLabel>
|
||||
<DropDrawerItem>Item 1</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Verify drawer structure is present
|
||||
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('DropDrawerItem', () => {
|
||||
beforeEach(() => {
|
||||
mockUseSmallScreen.mockReturnValue(false)
|
||||
})
|
||||
|
||||
it('can be structured with click handlers', () => {
|
||||
const handleClick = vi.fn()
|
||||
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem onClick={handleClick}>Clickable Item</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Verify structure is valid
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
expect(handleClick).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('can be structured with icons', () => {
|
||||
const TestIcon = () => <span data-testid="test-icon">Icon</span>
|
||||
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem icon={<TestIcon />}>Item with Icon</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Structure is valid
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('accepts variant props', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem variant="destructive">
|
||||
Delete Item
|
||||
</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Component structure is valid with variants
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('accepts disabled prop', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem disabled>
|
||||
Disabled Item
|
||||
</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Component structure is valid with disabled prop
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('DropDrawerGroup', () => {
|
||||
it('structures groups in desktop mode', () => {
|
||||
mockUseSmallScreen.mockReturnValue(false)
|
||||
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerGroup>
|
||||
<DropDrawerItem>Group Item 1</DropDrawerItem>
|
||||
<DropDrawerItem>Group Item 2</DropDrawerItem>
|
||||
</DropDrawerGroup>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Component structure is valid
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('structures groups in mobile mode', () => {
|
||||
mockUseSmallScreen.mockReturnValue(true)
|
||||
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerGroup>
|
||||
<DropDrawerItem>Item 1</DropDrawerItem>
|
||||
<DropDrawerItem>Item 2</DropDrawerItem>
|
||||
</DropDrawerGroup>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Component structure is valid in mobile mode
|
||||
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('DropDrawerFooter', () => {
|
||||
it('structures footer in desktop mode', () => {
|
||||
mockUseSmallScreen.mockReturnValue(false)
|
||||
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Item</DropDrawerItem>
|
||||
<DropDrawerFooter>Footer Content</DropDrawerFooter>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Component structure is valid
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('structures footer in mobile mode', () => {
|
||||
mockUseSmallScreen.mockReturnValue(true)
|
||||
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Item</DropDrawerItem>
|
||||
<DropDrawerFooter>Mobile Footer</DropDrawerFooter>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Component structure is valid in mobile mode
|
||||
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Submenu Components', () => {
|
||||
beforeEach(() => {
|
||||
mockUseSmallScreen.mockReturnValue(false)
|
||||
})
|
||||
|
||||
it('structures submenu in desktop mode', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerSub>
|
||||
<DropDrawerSubTrigger>Submenu Trigger</DropDrawerSubTrigger>
|
||||
<DropDrawerSubContent>
|
||||
<DropDrawerItem>Submenu Item</DropDrawerItem>
|
||||
</DropDrawerSubContent>
|
||||
</DropDrawerSub>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Component structure is valid
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('structures submenu in mobile mode', () => {
|
||||
mockUseSmallScreen.mockReturnValue(true)
|
||||
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerSub>
|
||||
<DropDrawerSubTrigger>
|
||||
Mobile Submenu
|
||||
</DropDrawerSubTrigger>
|
||||
<DropDrawerSubContent>
|
||||
<DropDrawerItem>Submenu Item</DropDrawerItem>
|
||||
</DropDrawerSubContent>
|
||||
</DropDrawerSub>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Component structure is valid in mobile mode
|
||||
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('handles submenu content correctly in mobile mode', () => {
|
||||
mockUseSmallScreen.mockReturnValue(true)
|
||||
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerSub>
|
||||
<DropDrawerSubTrigger>Mobile Submenu</DropDrawerSubTrigger>
|
||||
<DropDrawerSubContent>
|
||||
<DropDrawerItem>Hidden Item</DropDrawerItem>
|
||||
</DropDrawerSubContent>
|
||||
</DropDrawerSub>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Component handles mobile submenu structure correctly
|
||||
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Accessibility', () => {
|
||||
beforeEach(() => {
|
||||
mockUseSmallScreen.mockReturnValue(false)
|
||||
})
|
||||
|
||||
it('maintains proper ARIA attributes on triggers', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerGroup>
|
||||
<DropDrawerItem>Item 1</DropDrawerItem>
|
||||
</DropDrawerGroup>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
const trigger = screen.getByRole('button')
|
||||
expect(trigger).toHaveAttribute('aria-haspopup', 'menu')
|
||||
})
|
||||
|
||||
it('supports disabled state', () => {
|
||||
const handleClick = vi.fn()
|
||||
|
||||
mockUseSmallScreen.mockReturnValue(true)
|
||||
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem disabled onClick={handleClick}>
|
||||
Disabled Item
|
||||
</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Component structure supports disabled prop
|
||||
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
|
||||
expect(handleClick).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Boundaries', () => {
|
||||
it('requires proper context usage', () => {
|
||||
// Suppress console.error for this test
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
expect(() => {
|
||||
render(<DropDrawerItem>Orphan Item</DropDrawerItem>)
|
||||
}).toThrow()
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Custom Props and Styling', () => {
|
||||
beforeEach(() => {
|
||||
mockUseSmallScreen.mockReturnValue(false)
|
||||
})
|
||||
|
||||
it('applies custom className', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger className="custom-trigger">Custom Trigger</DropDrawerTrigger>
|
||||
<DropDrawerContent className="custom-content">
|
||||
<DropDrawerItem className="custom-item">Custom Item</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
const trigger = screen.getByText('Custom Trigger')
|
||||
expect(trigger).toHaveClass('custom-trigger')
|
||||
})
|
||||
|
||||
it('accepts additional props', () => {
|
||||
render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem data-custom="test-value">Custom Props Item</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Component structure accepts custom props
|
||||
expect(screen.getByText('Open Menu')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Responsive Behavior', () => {
|
||||
it('adapts to different screen sizes', () => {
|
||||
const { rerender } = render(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Responsive Trigger</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Responsive Item</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
// Desktop mode
|
||||
mockUseSmallScreen.mockReturnValue(false)
|
||||
rerender(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Responsive Trigger</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Responsive Item</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
let trigger = screen.getByText('Responsive Trigger')
|
||||
expect(trigger).toHaveAttribute('aria-haspopup', 'menu')
|
||||
|
||||
// Mobile mode
|
||||
mockUseSmallScreen.mockReturnValue(true)
|
||||
rerender(
|
||||
<DropDrawer>
|
||||
<DropDrawerTrigger>Responsive Trigger</DropDrawerTrigger>
|
||||
<DropDrawerContent>
|
||||
<DropDrawerItem>Responsive Item</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
|
||||
trigger = screen.getByText('Responsive Trigger')
|
||||
expect(trigger).toHaveAttribute('aria-haspopup', 'dialog')
|
||||
})
|
||||
})
|
||||
})
|
||||
62
web-app/src/components/ui/__tests__/radio-group.test.tsx
Normal file
@ -0,0 +1,62 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { RadioGroup, RadioGroupItem } from '../radio-group'
|
||||
|
||||
describe('RadioGroup', () => {
|
||||
it('renders radio items correctly', () => {
|
||||
render(
|
||||
<RadioGroup defaultValue="http">
|
||||
<div className="flex items-center space-x-2">
|
||||
<RadioGroupItem value="http" id="http" />
|
||||
<label htmlFor="http">HTTP</label>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
<RadioGroupItem value="sse" id="sse" />
|
||||
<label htmlFor="sse">SSE</label>
|
||||
</div>
|
||||
</RadioGroup>
|
||||
)
|
||||
|
||||
expect(screen.getByLabelText('HTTP')).toBeInTheDocument()
|
||||
expect(screen.getByLabelText('SSE')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('allows selecting different options', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onValueChange = vi.fn()
|
||||
|
||||
render(
|
||||
<RadioGroup defaultValue="http" onValueChange={onValueChange}>
|
||||
<div className="flex items-center space-x-2">
|
||||
<RadioGroupItem value="http" id="http" />
|
||||
<label htmlFor="http">HTTP</label>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
<RadioGroupItem value="sse" id="sse" />
|
||||
<label htmlFor="sse">SSE</label>
|
||||
</div>
|
||||
</RadioGroup>
|
||||
)
|
||||
|
||||
await user.click(screen.getByLabelText('SSE'))
|
||||
expect(onValueChange).toHaveBeenCalledWith('sse')
|
||||
})
|
||||
|
||||
it('has correct default selection', () => {
|
||||
render(
|
||||
<RadioGroup defaultValue="http">
|
||||
<div className="flex items-center space-x-2">
|
||||
<RadioGroupItem value="http" id="http" />
|
||||
<label htmlFor="http">HTTP</label>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
<RadioGroupItem value="sse" id="sse" />
|
||||
<label htmlFor="sse">SSE</label>
|
||||
</div>
|
||||
</RadioGroup>
|
||||
)
|
||||
|
||||
expect(screen.getByLabelText('HTTP')).toBeChecked()
|
||||
expect(screen.getByLabelText('SSE')).not.toBeChecked()
|
||||
})
|
||||
})
|
||||
133
web-app/src/components/ui/drawer.tsx
Normal file
@ -0,0 +1,133 @@
|
||||
import * as React from 'react'
|
||||
import { Drawer as DrawerPrimitive } from 'vaul'
|
||||
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
function Drawer({
|
||||
...props
|
||||
}: React.ComponentProps<typeof DrawerPrimitive.Root>) {
|
||||
return <DrawerPrimitive.Root data-slot="drawer" {...props} />
|
||||
}
|
||||
|
||||
function DrawerTrigger({
|
||||
...props
|
||||
}: React.ComponentProps<typeof DrawerPrimitive.Trigger>) {
|
||||
return <DrawerPrimitive.Trigger data-slot="drawer-trigger" {...props} />
|
||||
}
|
||||
|
||||
function DrawerPortal({
|
||||
...props
|
||||
}: React.ComponentProps<typeof DrawerPrimitive.Portal>) {
|
||||
return <DrawerPrimitive.Portal data-slot="drawer-portal" {...props} />
|
||||
}
|
||||
|
||||
function DrawerClose({
|
||||
...props
|
||||
}: React.ComponentProps<typeof DrawerPrimitive.Close>) {
|
||||
return <DrawerPrimitive.Close data-slot="drawer-close" {...props} />
|
||||
}
|
||||
|
||||
function DrawerOverlay({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DrawerPrimitive.Overlay>) {
|
||||
return (
|
||||
<DrawerPrimitive.Overlay
|
||||
data-slot="drawer-overlay"
|
||||
className={cn(
|
||||
'data-[state=open]:animate-in backdrop-blur data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 fixed inset-0 z-50 bg-black/50',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
function DrawerContent({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DrawerPrimitive.Content>) {
|
||||
return (
|
||||
<DrawerPortal data-slot="drawer-portal">
|
||||
<DrawerOverlay />
|
||||
<DrawerPrimitive.Content
|
||||
data-slot="drawer-content"
|
||||
className={cn(
|
||||
'group/drawer-content fixed z-50 flex h-auto flex-col',
|
||||
'data-[vaul-drawer-direction=top]:inset-x-0 data-[vaul-drawer-direction=top]:top-0 data-[vaul-drawer-direction=top]:mb-24 data-[vaul-drawer-direction=top]:max-h-[80vh] data-[vaul-drawer-direction=top]:rounded-b-lg data-[vaul-drawer-direction=top]:border-b',
|
||||
'data-[vaul-drawer-direction=bottom]:inset-x-0 data-[vaul-drawer-direction=bottom]:bottom-0 data-[vaul-drawer-direction=bottom]:mt-24 data-[vaul-drawer-direction=bottom]:max-h-[80vh] data-[vaul-drawer-direction=bottom]:rounded-t-lg data-[vaul-drawer-direction=bottom]:border-t',
|
||||
'data-[vaul-drawer-direction=right]:inset-y-0 data-[vaul-drawer-direction=right]:right-0 data-[vaul-drawer-direction=right]:w-3/4 data-[vaul-drawer-direction=right]:border-l data-[vaul-drawer-direction=right]:sm:max-w-sm',
|
||||
'data-[vaul-drawer-direction=left]:inset-y-0 data-[vaul-drawer-direction=left]:left-0 data-[vaul-drawer-direction=left]:w-3/4 data-[vaul-drawer-direction=left]:border-r data-[vaul-drawer-direction=left]:sm:max-w-sm',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
<div className="bg-main-view-fg/10 mx-auto mt-4 hidden h-2 w-[100px] shrink-0 rounded-full group-data-[vaul-drawer-direction=bottom]/drawer-content:block" />
|
||||
{children}
|
||||
</DrawerPrimitive.Content>
|
||||
</DrawerPortal>
|
||||
)
|
||||
}
|
||||
|
||||
function DrawerHeader({ className, ...props }: React.ComponentProps<'div'>) {
|
||||
return (
|
||||
<div
|
||||
data-slot="drawer-header"
|
||||
className={cn(
|
||||
'flex flex-col gap-0.5 p-4 group-data-[vaul-drawer-direction=bottom]/drawer-content:text-center group-data-[vaul-drawer-direction=top]/drawer-content:text-center md:gap-1.5 md:text-left',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
function DrawerFooter({ className, ...props }: React.ComponentProps<'div'>) {
|
||||
return (
|
||||
<div
|
||||
data-slot="drawer-footer"
|
||||
className={cn('mt-auto flex flex-col gap-2 p-4', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
function DrawerTitle({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DrawerPrimitive.Title>) {
|
||||
return (
|
||||
<DrawerPrimitive.Title
|
||||
data-slot="drawer-title"
|
||||
className={cn('font-semibold', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
function DrawerDescription({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DrawerPrimitive.Description>) {
|
||||
return (
|
||||
<DrawerPrimitive.Description
|
||||
data-slot="drawer-description"
|
||||
className={cn('text-sm', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export {
|
||||
Drawer,
|
||||
DrawerPortal,
|
||||
DrawerOverlay,
|
||||
DrawerTrigger,
|
||||
DrawerClose,
|
||||
DrawerContent,
|
||||
DrawerHeader,
|
||||
DrawerFooter,
|
||||
DrawerTitle,
|
||||
DrawerDescription,
|
||||
}
|
||||
@ -61,14 +61,17 @@ function DropdownMenuGroup({
|
||||
function DropdownMenuItem({
|
||||
className,
|
||||
inset,
|
||||
variant = 'default',
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuPrimitive.Item> & {
|
||||
inset?: boolean
|
||||
variant?: 'default' | 'destructive'
|
||||
}) {
|
||||
return (
|
||||
<DropdownMenuPrimitive.Item
|
||||
data-slot="dropdown-menu-item"
|
||||
data-inset={inset}
|
||||
data-variant={variant}
|
||||
className={cn(
|
||||
"relative cursor-pointer hover:bg-main-view-fg/4 flex items-center gap-2 rounded-sm px-2 py-1 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[inset]:pl-8 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4",
|
||||
className
|
||||
@ -213,7 +216,7 @@ function DropdownMenuSubTrigger({
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
<ChevronRightIcon className="ml-auto size-4" />
|
||||
<ChevronRightIcon className="ml-auto size-4 text-main-view-fg/50" />
|
||||
</DropdownMenuPrimitive.SubTrigger>
|
||||
)
|
||||
}
|
||||
|
||||
949
web-app/src/components/ui/dropdrawer.tsx
Normal file
@ -0,0 +1,949 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
'use client'
|
||||
|
||||
import { AnimatePresence, motion } from 'framer-motion'
|
||||
import { ChevronLeftIcon, ChevronRightIcon } from 'lucide-react'
|
||||
import * as React from 'react'
|
||||
|
||||
import {
|
||||
Drawer,
|
||||
DrawerClose,
|
||||
DrawerContent,
|
||||
DrawerFooter,
|
||||
DrawerHeader,
|
||||
DrawerTitle,
|
||||
DrawerTrigger,
|
||||
} from '@/components/ui/drawer'
|
||||
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuSub,
|
||||
DropdownMenuSubContent,
|
||||
DropdownMenuSubTrigger,
|
||||
DropdownMenuTrigger,
|
||||
} from '@/components/ui/dropdown-menu'
|
||||
|
||||
import { cn } from '@/lib/utils'
|
||||
import { useSmallScreen } from '@/hooks/useMediaQuery'
|
||||
|
||||
const ANIMATION_CONFIG = {
|
||||
variants: {
|
||||
enter: (direction: 'forward' | 'backward') => ({
|
||||
x: direction === 'forward' ? '100%' : '-100%',
|
||||
opacity: 0,
|
||||
}),
|
||||
center: {
|
||||
x: 0,
|
||||
opacity: 1,
|
||||
},
|
||||
exit: (direction: 'forward' | 'backward') => ({
|
||||
x: direction === 'forward' ? '-100%' : '100%',
|
||||
opacity: 0,
|
||||
}),
|
||||
},
|
||||
transition: {
|
||||
duration: 0.3,
|
||||
ease: [0.25, 0.1, 0.25, 1.0],
|
||||
},
|
||||
} as const
|
||||
|
||||
const getMobileItemStyles = (
|
||||
isInsideGroup: boolean,
|
||||
inset?: boolean,
|
||||
variant?: string,
|
||||
disabled?: boolean
|
||||
) => {
|
||||
return cn(
|
||||
'flex cursor-pointer items-center justify-between px-4 py-4 w-full gap-4',
|
||||
!isInsideGroup && 'bg-main-view-fg/50 mx-2 my-1.5 rounded-md',
|
||||
isInsideGroup && 'bg-transparent py-4',
|
||||
inset && 'pl-8',
|
||||
variant === 'destructive' && 'text-destructive',
|
||||
disabled && 'pointer-events-none opacity-50'
|
||||
)
|
||||
}
|
||||
|
||||
const DropDrawerContext = React.createContext<{ isMobile: boolean }>({
|
||||
isMobile: false,
|
||||
})
|
||||
|
||||
const useDropDrawerContext = () => {
|
||||
const context = React.useContext(DropDrawerContext)
|
||||
if (!context) {
|
||||
throw new Error(
|
||||
'DropDrawer components cannot be rendered outside the DropDrawer Context'
|
||||
)
|
||||
}
|
||||
return context
|
||||
}
|
||||
|
||||
const useComponentSelection = () => {
|
||||
const { isMobile } = useDropDrawerContext()
|
||||
|
||||
const selectComponent = <T, D>(mobileComponent: T, desktopComponent: D) => {
|
||||
return isMobile ? mobileComponent : desktopComponent
|
||||
}
|
||||
|
||||
return { isMobile, selectComponent }
|
||||
}
|
||||
|
||||
const useGroupDetection = () => {
|
||||
const isInGroup = React.useCallback(
|
||||
(element: HTMLElement | null): boolean => {
|
||||
if (!element) return false
|
||||
|
||||
let parent = element.parentElement
|
||||
while (parent) {
|
||||
if (parent.hasAttribute('data-drop-drawer-group')) {
|
||||
return true
|
||||
}
|
||||
parent = parent.parentElement
|
||||
}
|
||||
return false
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
const useGroupState = () => {
|
||||
const { isMobile } = useComponentSelection()
|
||||
const itemRef = React.useRef<HTMLDivElement>(null)
|
||||
const [isInsideGroup, setIsInsideGroup] = React.useState(false)
|
||||
|
||||
React.useEffect(() => {
|
||||
if (!isMobile) return
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
if (itemRef.current) {
|
||||
setIsInsideGroup(isInGroup(itemRef.current))
|
||||
}
|
||||
}, 0)
|
||||
|
||||
return () => clearTimeout(timer)
|
||||
}, [isMobile])
|
||||
|
||||
return { itemRef, isInsideGroup }
|
||||
}
|
||||
|
||||
return { isInGroup, useGroupState }
|
||||
}
|
||||
|
||||
type ConditionalComponentProps<T, D> = {
|
||||
children: React.ReactNode
|
||||
className?: string
|
||||
} & (T | D)
|
||||
|
||||
const ConditionalComponent = <T, D>({
|
||||
mobileComponent,
|
||||
desktopComponent,
|
||||
children,
|
||||
...props
|
||||
}: {
|
||||
mobileComponent: React.ComponentType<any>
|
||||
desktopComponent: React.ComponentType<any>
|
||||
children: React.ReactNode
|
||||
} & ConditionalComponentProps<T, D>) => {
|
||||
const { selectComponent } = useComponentSelection()
|
||||
const Component = selectComponent(mobileComponent, desktopComponent)
|
||||
|
||||
return <Component {...props}>{children}</Component>
|
||||
}
|
||||
|
||||
function DropDrawer({
|
||||
children,
|
||||
...props
|
||||
}:
|
||||
| React.ComponentProps<typeof Drawer>
|
||||
| React.ComponentProps<typeof DropdownMenu>) {
|
||||
const isMobile = useSmallScreen()
|
||||
|
||||
return (
|
||||
<DropDrawerContext.Provider value={{ isMobile }}>
|
||||
<ConditionalComponent
|
||||
mobileComponent={Drawer}
|
||||
desktopComponent={DropdownMenu}
|
||||
data-slot="drop-drawer"
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</ConditionalComponent>
|
||||
</DropDrawerContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
function DropDrawerTrigger({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}:
|
||||
| React.ComponentProps<typeof DrawerTrigger>
|
||||
| React.ComponentProps<typeof DropdownMenuTrigger>) {
|
||||
return (
|
||||
<ConditionalComponent
|
||||
mobileComponent={DrawerTrigger}
|
||||
desktopComponent={DropdownMenuTrigger}
|
||||
data-slot="drop-drawer-trigger"
|
||||
className={className}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</ConditionalComponent>
|
||||
)
|
||||
}
|
||||
|
||||
function DropDrawerContent({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}:
|
||||
| React.ComponentProps<typeof DrawerContent>
|
||||
| React.ComponentProps<typeof DropdownMenuContent>) {
|
||||
const { isMobile } = useDropDrawerContext()
|
||||
const [activeSubmenu, setActiveSubmenu] = React.useState<string | null>(null)
|
||||
const [submenuTitle, setSubmenuTitle] = React.useState<string | null>(null)
|
||||
const [submenuStack, setSubmenuStack] = React.useState<
|
||||
{ id: string; title: string }[]
|
||||
>([])
|
||||
// Add animation direction state
|
||||
const [animationDirection, setAnimationDirection] = React.useState<
|
||||
'forward' | 'backward'
|
||||
>('forward')
|
||||
|
||||
// Create a ref to store submenu content by ID
|
||||
const submenuContentRef = React.useRef<Map<string, React.ReactNode[]>>(
|
||||
new Map()
|
||||
)
|
||||
|
||||
// Function to navigate to a submenu
|
||||
const navigateToSubmenu = React.useCallback((id: string, title: string) => {
|
||||
// Set animation direction to forward when navigating to a submenu
|
||||
setAnimationDirection('forward')
|
||||
setActiveSubmenu(id)
|
||||
setSubmenuTitle(title)
|
||||
setSubmenuStack((prev) => [...prev, { id, title }])
|
||||
}, [])
|
||||
|
||||
// Function to go back to previous menu
|
||||
const goBack = React.useCallback(() => {
|
||||
// Set animation direction to backward when going back
|
||||
setAnimationDirection('backward')
|
||||
|
||||
if (submenuStack.length <= 1) {
|
||||
// If we're at the first level, go back to main menu
|
||||
setActiveSubmenu(null)
|
||||
setSubmenuTitle(null)
|
||||
setSubmenuStack([])
|
||||
} else {
|
||||
// Go back to previous submenu
|
||||
const newStack = [...submenuStack]
|
||||
newStack.pop() // Remove current
|
||||
const previous = newStack[newStack.length - 1]
|
||||
setActiveSubmenu(previous.id)
|
||||
setSubmenuTitle(previous.title)
|
||||
setSubmenuStack(newStack)
|
||||
}
|
||||
}, [submenuStack])
|
||||
|
||||
// Function to register submenu content
|
||||
const registerSubmenuContent = React.useCallback(
|
||||
(id: string, content: React.ReactNode[]) => {
|
||||
submenuContentRef.current.set(id, content)
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
const extractSubmenuContent = React.useCallback(
|
||||
(elements: React.ReactNode, targetId: string): React.ReactNode[] => {
|
||||
const result: React.ReactNode[] = []
|
||||
|
||||
const findSubmenuContent = (node: React.ReactNode) => {
|
||||
if (!React.isValidElement(node)) return
|
||||
|
||||
const element = node as React.ReactElement
|
||||
const props = element.props as {
|
||||
'id'?: string
|
||||
'data-submenu-id'?: string
|
||||
'children'?: React.ReactNode
|
||||
}
|
||||
|
||||
if (element.type === DropDrawerSub) {
|
||||
const elementId = props.id || props['data-submenu-id']
|
||||
|
||||
if (elementId === targetId) {
|
||||
React.Children.forEach(props.children, (child) => {
|
||||
if (
|
||||
React.isValidElement(child) &&
|
||||
child.type === DropDrawerSubContent
|
||||
) {
|
||||
const subContentProps = child.props as {
|
||||
children?: React.ReactNode
|
||||
}
|
||||
React.Children.forEach(
|
||||
subContentProps.children,
|
||||
(contentChild) => {
|
||||
result.push(contentChild)
|
||||
}
|
||||
)
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if (props.children) {
|
||||
React.Children.forEach(props.children, findSubmenuContent)
|
||||
}
|
||||
}
|
||||
|
||||
React.Children.forEach(elements, findSubmenuContent)
|
||||
return result
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
// Get submenu content (always extract fresh to reflect state changes)
|
||||
const getSubmenuContent = React.useCallback(
|
||||
(id: string) => {
|
||||
// Always extract fresh content to ensure state changes are reflected
|
||||
const submenuContent = extractSubmenuContent(children, id)
|
||||
return submenuContent
|
||||
},
|
||||
[children, extractSubmenuContent]
|
||||
)
|
||||
|
||||
if (isMobile) {
|
||||
return (
|
||||
<SubmenuContext.Provider
|
||||
value={{
|
||||
activeSubmenu,
|
||||
setActiveSubmenu: (id) => {
|
||||
if (id === null) {
|
||||
setActiveSubmenu(null)
|
||||
setSubmenuTitle(null)
|
||||
setSubmenuStack([])
|
||||
}
|
||||
},
|
||||
submenuTitle,
|
||||
setSubmenuTitle,
|
||||
navigateToSubmenu,
|
||||
registerSubmenuContent,
|
||||
}}
|
||||
>
|
||||
<DrawerContent
|
||||
data-slot="drop-drawer-content"
|
||||
className={cn(
|
||||
'max-h-[90vh] w-full overflow-hidden max-w-none',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{activeSubmenu ? (
|
||||
<>
|
||||
<DrawerHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
onClick={goBack}
|
||||
className="hover:bg-muted/50 rounded-full "
|
||||
>
|
||||
<ChevronLeftIcon className="h-5 w-5 text-main-view-fg/50" />
|
||||
</button>
|
||||
<DrawerTitle className="text-main-view-fg/80 text-sm">
|
||||
{submenuTitle || 'Submenu'}
|
||||
</DrawerTitle>
|
||||
</div>
|
||||
</DrawerHeader>
|
||||
<div className="flex-1 relative overflow-hidden max-h-[70vh]">
|
||||
{/* Use AnimatePresence to handle exit animations */}
|
||||
<AnimatePresence
|
||||
initial={false}
|
||||
mode="wait"
|
||||
custom={animationDirection}
|
||||
>
|
||||
<motion.div
|
||||
key={activeSubmenu || 'main'}
|
||||
custom={animationDirection}
|
||||
variants={ANIMATION_CONFIG.variants}
|
||||
initial="enter"
|
||||
animate="center"
|
||||
exit="exit"
|
||||
transition={ANIMATION_CONFIG.transition as any}
|
||||
className="pb-6 space-y-1.5 w-full h-full overflow-hidden"
|
||||
>
|
||||
{activeSubmenu
|
||||
? getSubmenuContent(activeSubmenu)
|
||||
: children}
|
||||
</motion.div>
|
||||
</AnimatePresence>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<DrawerHeader className="sr-only">
|
||||
<DrawerTitle>Menu</DrawerTitle>
|
||||
</DrawerHeader>
|
||||
<div className="overflow-hidden max-h-[70vh]">
|
||||
<AnimatePresence
|
||||
initial={false}
|
||||
mode="wait"
|
||||
custom={animationDirection}
|
||||
>
|
||||
<motion.div
|
||||
key="main-menu"
|
||||
custom={animationDirection}
|
||||
variants={ANIMATION_CONFIG.variants}
|
||||
initial="enter"
|
||||
animate="center"
|
||||
exit="exit"
|
||||
transition={ANIMATION_CONFIG.transition as any}
|
||||
className="pb-6 space-y-1.5 w-full overflow-hidden"
|
||||
>
|
||||
{children}
|
||||
</motion.div>
|
||||
</AnimatePresence>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</DrawerContent>
|
||||
</SubmenuContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<SubmenuContext.Provider
|
||||
value={{
|
||||
activeSubmenu,
|
||||
setActiveSubmenu,
|
||||
submenuTitle,
|
||||
setSubmenuTitle,
|
||||
navigateToSubmenu,
|
||||
registerSubmenuContent,
|
||||
}}
|
||||
>
|
||||
<DropdownMenuContent
|
||||
data-slot="drop-drawer-content"
|
||||
sideOffset={4}
|
||||
className={cn(
|
||||
'max-h-[var(--radix-dropdown-menu-content-available-height)] min-w-[220px] overflow-hidden',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</DropdownMenuContent>
|
||||
</SubmenuContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
function DropDrawerItem({
|
||||
className,
|
||||
children,
|
||||
onSelect,
|
||||
onClick,
|
||||
icon,
|
||||
variant = 'default',
|
||||
inset,
|
||||
disabled,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuItem> & {
|
||||
icon?: React.ReactNode
|
||||
}) {
|
||||
const { isMobile } = useComponentSelection()
|
||||
const { useGroupState } = useGroupDetection()
|
||||
const { itemRef, isInsideGroup } = useGroupState()
|
||||
|
||||
if (isMobile) {
|
||||
const handleClick = (e: React.MouseEvent<HTMLDivElement>) => {
|
||||
if (disabled) return
|
||||
|
||||
// If this item only has an icon (like a switch) and no other interactive content,
|
||||
// don't handle clicks on the main area - let the icon handle everything
|
||||
if (icon && !onClick && !onSelect) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the click came from the icon area (where the Switch is)
|
||||
const target = e.target as HTMLElement
|
||||
const iconContainer = (e.currentTarget as HTMLElement).querySelector(
|
||||
'[data-icon-container]'
|
||||
)
|
||||
|
||||
if (iconContainer && iconContainer.contains(target)) {
|
||||
// Don't handle the click if it came from the icon area
|
||||
return
|
||||
}
|
||||
|
||||
if (onClick) onClick(e)
|
||||
if (onSelect) onSelect(e as unknown as Event)
|
||||
}
|
||||
|
||||
// Only wrap in DrawerClose if it's not a submenu item
|
||||
const content = (
|
||||
<div
|
||||
ref={itemRef}
|
||||
data-slot="drop-drawer-item"
|
||||
data-variant={variant}
|
||||
data-inset={inset}
|
||||
data-disabled={disabled}
|
||||
className={cn(
|
||||
getMobileItemStyles(isInsideGroup, inset, variant, disabled),
|
||||
className
|
||||
)}
|
||||
onClick={handleClick}
|
||||
aria-disabled={disabled}
|
||||
{...props}
|
||||
>
|
||||
<div className="flex items-center gap-2">{children}</div>
|
||||
{icon && (
|
||||
<div className="flex-shrink-0" data-icon-container>
|
||||
{icon}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
|
||||
// Check if this is inside a submenu
|
||||
const isInSubmenu =
|
||||
(props as Record<string, unknown>)['data-parent-submenu-id'] ||
|
||||
(props as Record<string, unknown>)['data-parent-submenu']
|
||||
|
||||
if (isInSubmenu) {
|
||||
return content
|
||||
}
|
||||
|
||||
return <DrawerClose asChild>{content}</DrawerClose>
|
||||
}
|
||||
|
||||
return (
|
||||
<DropdownMenuItem
|
||||
data-slot="drop-drawer-item"
|
||||
data-variant={variant}
|
||||
data-inset={inset}
|
||||
className={className}
|
||||
onSelect={onSelect}
|
||||
onClick={onClick as React.MouseEventHandler<HTMLDivElement>}
|
||||
variant={variant}
|
||||
inset={inset}
|
||||
disabled={disabled}
|
||||
{...props}
|
||||
>
|
||||
<div className="flex w-full items-start justify-between gap-4">
|
||||
<div>{children}</div>
|
||||
{icon && <div>{icon}</div>}
|
||||
</div>
|
||||
</DropdownMenuItem>
|
||||
)
|
||||
}
|
||||
|
||||
function DropDrawerSeparator({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuSeparator>) {
|
||||
const { isMobile } = useComponentSelection()
|
||||
|
||||
if (isMobile) {
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<DropdownMenuSeparator
|
||||
data-slot="drop-drawer-separator"
|
||||
className={className}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
function DropDrawerLabel({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}:
|
||||
| React.ComponentProps<typeof DropdownMenuLabel>
|
||||
| React.ComponentProps<typeof DrawerTitle>) {
|
||||
const { isMobile } = useComponentSelection()
|
||||
|
||||
if (isMobile) {
|
||||
return (
|
||||
<DrawerHeader className="p-0">
|
||||
<DrawerTitle
|
||||
data-slot="drop-drawer-label"
|
||||
className={cn(
|
||||
'px-4 py-2 text-sm font-medium text-main-view-fg/60',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</DrawerTitle>
|
||||
</DrawerHeader>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<DropdownMenuLabel
|
||||
data-slot="drop-drawer-label"
|
||||
className={className}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</DropdownMenuLabel>
|
||||
)
|
||||
}
|
||||
|
||||
function DropDrawerFooter({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DrawerFooter> | React.ComponentProps<'div'>) {
|
||||
const { isMobile } = useDropDrawerContext()
|
||||
|
||||
if (isMobile) {
|
||||
return (
|
||||
<DrawerFooter
|
||||
data-slot="drop-drawer-footer"
|
||||
className={cn('p-4', className)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</DrawerFooter>
|
||||
)
|
||||
}
|
||||
|
||||
// No direct equivalent in DropdownMenu, so we'll just render a div
|
||||
return (
|
||||
<div
|
||||
data-slot="drop-drawer-footer"
|
||||
className={cn('p-2', className)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function DropDrawerGroup({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: React.ComponentProps<'div'> & {
|
||||
children: React.ReactNode
|
||||
}) {
|
||||
const { isMobile } = useDropDrawerContext()
|
||||
|
||||
// Add separators between children on mobile
|
||||
const childrenWithSeparators = React.useMemo(() => {
|
||||
if (!isMobile) return children
|
||||
|
||||
const childArray = React.Children.toArray(children)
|
||||
|
||||
// Filter out any existing separators
|
||||
const filteredChildren = childArray.filter(
|
||||
(child) =>
|
||||
React.isValidElement(child) && child.type !== DropDrawerSeparator
|
||||
)
|
||||
|
||||
// Add separators between items
|
||||
return filteredChildren.flatMap((child, index) => {
|
||||
if (index === filteredChildren.length - 1) return [child]
|
||||
return [
|
||||
child,
|
||||
<div
|
||||
key={`separator-${index}`}
|
||||
className="bg-border h-px"
|
||||
aria-hidden="true"
|
||||
/>,
|
||||
]
|
||||
})
|
||||
}, [children, isMobile])
|
||||
|
||||
if (isMobile) {
|
||||
return (
|
||||
<div
|
||||
data-drop-drawer-group
|
||||
data-slot="drop-drawer-group"
|
||||
role="group"
|
||||
className={cn(
|
||||
'bg-main-view-fg/2 border border-main-view-fg/4 mx-2 my-3 overflow-hidden rounded-xl',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{childrenWithSeparators}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// On desktop, use a div with proper role and attributes
|
||||
return (
|
||||
<div
|
||||
data-drop-drawer-group
|
||||
data-slot="drop-drawer-group"
|
||||
role="group"
|
||||
className={className}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Context for managing submenu state on mobile
|
||||
interface SubmenuContextType {
|
||||
activeSubmenu: string | null
|
||||
setActiveSubmenu: (id: string | null) => void
|
||||
submenuTitle: string | null
|
||||
setSubmenuTitle: (title: string | null) => void
|
||||
navigateToSubmenu?: (id: string, title: string) => void
|
||||
registerSubmenuContent?: (id: string, content: React.ReactNode[]) => void
|
||||
}
|
||||
|
||||
const SubmenuContext = React.createContext<SubmenuContextType>({
|
||||
activeSubmenu: null,
|
||||
setActiveSubmenu: () => {},
|
||||
submenuTitle: null,
|
||||
setSubmenuTitle: () => {},
|
||||
navigateToSubmenu: undefined,
|
||||
registerSubmenuContent: undefined,
|
||||
})
|
||||
|
||||
// Submenu components
|
||||
// Counter for generating simple numeric IDs
|
||||
let submenuIdCounter = 0
|
||||
|
||||
function DropDrawerSub({
|
||||
children,
|
||||
id,
|
||||
title,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuSub> & {
|
||||
id?: string
|
||||
title?: string
|
||||
}) {
|
||||
const { isMobile } = useDropDrawerContext()
|
||||
const { registerSubmenuContent } = React.useContext(SubmenuContext)
|
||||
|
||||
// Generate a simple numeric ID instead of using React.useId()
|
||||
const [generatedId] = React.useState(() => `submenu-${submenuIdCounter++}`)
|
||||
const submenuId = id || generatedId
|
||||
|
||||
// Extract submenu content to register with parent
|
||||
React.useEffect(() => {
|
||||
if (!registerSubmenuContent) return
|
||||
|
||||
// Find the SubContent within this Sub
|
||||
const contentItems: React.ReactNode[] = []
|
||||
React.Children.forEach(children, (child) => {
|
||||
if (React.isValidElement(child) && child.type === DropDrawerSubContent) {
|
||||
// Add all children of the SubContent to the result
|
||||
React.Children.forEach(
|
||||
(child.props as { children?: React.ReactNode }).children,
|
||||
(contentChild) => {
|
||||
contentItems.push(contentChild)
|
||||
}
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
// Register the content with the parent
|
||||
if (contentItems.length > 0) {
|
||||
registerSubmenuContent(submenuId, contentItems)
|
||||
}
|
||||
}, [children, registerSubmenuContent, submenuId])
|
||||
|
||||
if (isMobile) {
|
||||
// For mobile, we'll use the context to manage submenu state
|
||||
// Process children to pass the submenu ID to the trigger and content
|
||||
const processedChildren = React.Children.map(children, (child) => {
|
||||
if (!React.isValidElement(child)) return child
|
||||
|
||||
if (child.type === DropDrawerSubTrigger) {
|
||||
return React.cloneElement(
|
||||
child as React.ReactElement,
|
||||
{
|
||||
...(child.props as object),
|
||||
'data-parent-submenu-id': submenuId,
|
||||
'data-submenu-id': submenuId,
|
||||
// Use only data attributes, not custom props
|
||||
'data-parent-submenu': submenuId,
|
||||
'data-submenu-title': title,
|
||||
} as React.HTMLAttributes<HTMLElement>
|
||||
)
|
||||
}
|
||||
|
||||
if (child.type === DropDrawerSubContent) {
|
||||
return React.cloneElement(
|
||||
child as React.ReactElement,
|
||||
{
|
||||
...(child.props as object),
|
||||
'data-parent-submenu-id': submenuId,
|
||||
'data-submenu-id': submenuId,
|
||||
// Use only data attributes, not custom props
|
||||
'data-parent-submenu': submenuId,
|
||||
} as React.HTMLAttributes<HTMLElement>
|
||||
)
|
||||
}
|
||||
|
||||
return child
|
||||
})
|
||||
|
||||
return (
|
||||
<div
|
||||
data-slot="drop-drawer-sub"
|
||||
data-submenu-id={submenuId}
|
||||
id={submenuId}
|
||||
>
|
||||
{processedChildren}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// For desktop, use the standard DropdownMenuSub
|
||||
return <DropdownMenuSub {...props}>{children}</DropdownMenuSub>
|
||||
}
|
||||
|
||||
function DropDrawerSubTrigger({
|
||||
className,
|
||||
inset,
|
||||
children,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuSubTrigger> & {
|
||||
icon?: React.ReactNode
|
||||
}) {
|
||||
const { isMobile } = useComponentSelection()
|
||||
const { navigateToSubmenu } = React.useContext(SubmenuContext)
|
||||
const { useGroupState } = useGroupDetection()
|
||||
const { itemRef, isInsideGroup } = useGroupState()
|
||||
|
||||
if (isMobile) {
|
||||
// Find the parent submenu ID
|
||||
const handleClick = (e: React.MouseEvent) => {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
|
||||
// Get the closest parent with data-submenu-id attribute
|
||||
const element = e.currentTarget as HTMLElement
|
||||
let submenuId: string | null = null
|
||||
|
||||
// First check if the element itself has the data attribute
|
||||
if (element.closest('[data-submenu-id]')) {
|
||||
const closestElement = element.closest('[data-submenu-id]')
|
||||
const id = closestElement?.getAttribute('data-submenu-id')
|
||||
if (id) {
|
||||
submenuId = id
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, try props
|
||||
if (!submenuId) {
|
||||
submenuId =
|
||||
((props as Record<string, unknown>)[
|
||||
'data-parent-submenu-id'
|
||||
] as string) ||
|
||||
((props as Record<string, unknown>)['data-parent-submenu'] as string)
|
||||
}
|
||||
|
||||
if (!submenuId) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get the title - first try data attribute, then children, then fallback
|
||||
const dataTitle = (props as Record<string, unknown>)[
|
||||
'data-submenu-title'
|
||||
] as string
|
||||
const title =
|
||||
dataTitle || (typeof children === 'string' ? children : 'Submenu')
|
||||
|
||||
// Navigate to the submenu
|
||||
if (navigateToSubmenu) {
|
||||
navigateToSubmenu(submenuId, title)
|
||||
}
|
||||
}
|
||||
|
||||
// Combine onClick handlers
|
||||
const combinedOnClick = (e: React.MouseEvent) => {
|
||||
// Call the original onClick if provided
|
||||
const typedProps = props as Record<string, unknown>
|
||||
if (typedProps.onClick) {
|
||||
const originalOnClick =
|
||||
typedProps.onClick as React.MouseEventHandler<HTMLDivElement>
|
||||
originalOnClick(e as React.MouseEvent<HTMLDivElement>)
|
||||
}
|
||||
|
||||
// Call our navigation handler
|
||||
handleClick(e)
|
||||
}
|
||||
|
||||
// Remove onClick from props to avoid duplicate handlers
|
||||
const { ...restProps } = props as Record<string, unknown>
|
||||
|
||||
// Don't wrap in DrawerClose for submenu triggers
|
||||
return (
|
||||
<div
|
||||
ref={itemRef}
|
||||
data-slot="drop-drawer-sub-trigger"
|
||||
data-inset={inset}
|
||||
className={cn(getMobileItemStyles(isInsideGroup, inset), className)}
|
||||
onClick={combinedOnClick}
|
||||
{...restProps}
|
||||
>
|
||||
<div className="flex items-center gap-2 w-full">{children}</div>
|
||||
<ChevronRightIcon className="h-5 w-5 text-main-view-fg/50 " />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<DropdownMenuSubTrigger
|
||||
data-slot="drop-drawer-sub-trigger"
|
||||
data-inset={inset}
|
||||
className={className}
|
||||
inset={inset}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</DropdownMenuSubTrigger>
|
||||
)
|
||||
}
|
||||
|
||||
function DropDrawerSubContent({
|
||||
className,
|
||||
sideOffset = 4,
|
||||
children,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DropdownMenuSubContent>) {
|
||||
const { isMobile } = useDropDrawerContext()
|
||||
|
||||
if (isMobile) {
|
||||
// For mobile, we don't render the content directly
|
||||
// It will be rendered by the DropDrawerContent component when active
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<DropdownMenuSubContent
|
||||
data-slot="drop-drawer-sub-content"
|
||||
sideOffset={sideOffset}
|
||||
className={className}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</DropdownMenuSubContent>
|
||||
)
|
||||
}
|
||||
|
||||
export {
|
||||
DropDrawer,
|
||||
DropDrawerContent,
|
||||
DropDrawerFooter,
|
||||
DropDrawerGroup,
|
||||
DropDrawerItem,
|
||||
DropDrawerLabel,
|
||||
DropDrawerSeparator,
|
||||
DropDrawerSub,
|
||||
DropDrawerSubContent,
|
||||
DropDrawerSubTrigger,
|
||||
DropDrawerTrigger,
|
||||
}
|
||||
42
web-app/src/components/ui/radio-group.tsx
Normal file
@ -0,0 +1,42 @@
|
||||
import * as React from 'react'
|
||||
import * as RadioGroupPrimitive from '@radix-ui/react-radio-group'
|
||||
import { CircleIcon } from 'lucide-react'
|
||||
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
const RadioGroup = React.forwardRef<
|
||||
React.ElementRef<typeof RadioGroupPrimitive.Root>,
|
||||
React.ComponentPropsWithoutRef<typeof RadioGroupPrimitive.Root>
|
||||
>(({ className, ...props }, ref) => {
|
||||
return (
|
||||
<RadioGroupPrimitive.Root
|
||||
className={cn('grid gap-2', className)}
|
||||
{...props}
|
||||
ref={ref}
|
||||
/>
|
||||
)
|
||||
})
|
||||
RadioGroup.displayName = RadioGroupPrimitive.Root.displayName
|
||||
|
||||
const RadioGroupItem = React.forwardRef<
|
||||
React.ElementRef<typeof RadioGroupPrimitive.Item>,
|
||||
React.ComponentPropsWithoutRef<typeof RadioGroupPrimitive.Item>
|
||||
>(({ className, ...props }, ref) => {
|
||||
return (
|
||||
<RadioGroupPrimitive.Item
|
||||
ref={ref}
|
||||
className={cn(
|
||||
'aspect-square size-4 rounded-full border border-main-view-fg/20 text-main-view-fg ring-offset-main-view focus:outline-none focus-visible:ring-2 focus-visible:ring-main-view-fg/50 focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50',
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
<RadioGroupPrimitive.Indicator className="flex items-center justify-center">
|
||||
<CircleIcon className="size-2.5 fill-current text-current" />
|
||||
</RadioGroupPrimitive.Indicator>
|
||||
</RadioGroupPrimitive.Item>
|
||||
)
|
||||
})
|
||||
RadioGroupItem.displayName = RadioGroupPrimitive.Item.displayName
|
||||
|
||||
export { RadioGroup, RadioGroupItem }
|
||||
@ -19,5 +19,6 @@ export const localStorageKey = {
|
||||
mcpGlobalPermissions: 'mcp-global-permissions',
|
||||
lastUsedModel: 'last-used-model',
|
||||
lastUsedAssistant: 'last-used-assistant',
|
||||
favoriteModels: 'favorite-models',
|
||||
setupCompleted: 'setup-completed',
|
||||
}
|
||||
|
||||
@ -483,7 +483,9 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
|
||||
initialMessage={initialMessage}
|
||||
onOpenChange={(isOpen) => {
|
||||
setDropdownToolsAvailable(isOpen)
|
||||
setTooltipToolsAvailable(false)
|
||||
if (isOpen) {
|
||||
setTooltipToolsAvailable(false)
|
||||
}
|
||||
}}
|
||||
>
|
||||
{(isOpen, toolsCount) => {
|
||||
|
||||
@ -6,10 +6,8 @@ import {
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
import { useDownloadStore } from '@/hooks/useDownloadStore'
|
||||
import { useLeftPanel } from '@/hooks/useLeftPanel'
|
||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||
import { useAppUpdater } from '@/hooks/useAppUpdater'
|
||||
import { abortDownload } from '@/services/models'
|
||||
import { getProviders } from '@/services/providers'
|
||||
import { DownloadEvent, DownloadState, events, AppEvent } from '@janhq/core'
|
||||
import { IconDownload, IconX } from '@tabler/icons-react'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
@ -18,7 +16,6 @@ import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
|
||||
export function DownloadManagement() {
|
||||
const { t } = useTranslation()
|
||||
const { setProviders } = useModelProvider()
|
||||
const { open: isLeftPanelOpen } = useLeftPanel()
|
||||
const [isPopoverOpen, setIsPopoverOpen] = useState(false)
|
||||
const {
|
||||
@ -185,7 +182,6 @@ export function DownloadManagement() {
|
||||
console.debug('onFileDownloadSuccess', state)
|
||||
removeDownload(state.modelId)
|
||||
removeLocalDownloadingModel(state.modelId)
|
||||
getProviders().then(setProviders)
|
||||
toast.success(t('common:toast.downloadComplete.title'), {
|
||||
id: 'download-complete',
|
||||
description: t('common:toast.downloadComplete.description', {
|
||||
@ -193,7 +189,7 @@ export function DownloadManagement() {
|
||||
}),
|
||||
})
|
||||
},
|
||||
[removeDownload, removeLocalDownloadingModel, setProviders, t]
|
||||
[removeDownload, removeLocalDownloadingModel, t]
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
@ -17,6 +17,8 @@ import ProvidersAvatar from '@/containers/ProvidersAvatar'
|
||||
import { Fzf } from 'fzf'
|
||||
import { localStorageKey } from '@/constants/localStorage'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
|
||||
import { predefinedProviders } from '@/consts/providers'
|
||||
|
||||
type DropdownModelProviderProps = {
|
||||
model?: ThreadModel
|
||||
@ -69,42 +71,38 @@ const DropdownModelProvider = ({
|
||||
const { updateCurrentThreadModel } = useThreads()
|
||||
const navigate = useNavigate()
|
||||
const { t } = useTranslation()
|
||||
const { favoriteModels } = useFavoriteModel()
|
||||
|
||||
// Search state
|
||||
const [open, setOpen] = useState(false)
|
||||
const [searchValue, setSearchValue] = useState('')
|
||||
const searchInputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
// Helper function to check if a model exists in providers
|
||||
const checkModelExists = useCallback((providerName: string, modelId: string) => {
|
||||
const provider = providers.find(
|
||||
(p) => p.provider === providerName && p.active
|
||||
)
|
||||
return provider?.models.find((m) => m.id === modelId)
|
||||
}, [providers])
|
||||
|
||||
// Initialize model provider only once
|
||||
useEffect(() => {
|
||||
// Auto select model when existing thread is passed
|
||||
if (model) {
|
||||
selectModelProvider(model?.provider as string, model?.id as string)
|
||||
if (!checkModelExists(model.provider, model.id)) {
|
||||
selectModelProvider('', '')
|
||||
}
|
||||
} else if (useLastUsedModel) {
|
||||
// Try to use last used model only when explicitly requested (for new chat)
|
||||
const lastUsed = getLastUsedModel()
|
||||
if (lastUsed) {
|
||||
// Verify the last used model still exists
|
||||
const provider = providers.find(
|
||||
(p) => p.provider === lastUsed.provider && p.active
|
||||
)
|
||||
const modelExists = provider?.models.find(
|
||||
(m) => m.id === lastUsed.model
|
||||
)
|
||||
|
||||
if (provider && modelExists) {
|
||||
selectModelProvider(lastUsed.provider, lastUsed.model)
|
||||
} else {
|
||||
// Fallback to default model if last used model no longer exists
|
||||
selectModelProvider('llamacpp', 'llama3.2:3b')
|
||||
}
|
||||
if (lastUsed && checkModelExists(lastUsed.provider, lastUsed.model)) {
|
||||
selectModelProvider(lastUsed.provider, lastUsed.model)
|
||||
} else {
|
||||
// default model, we should add from setting
|
||||
selectModelProvider('llamacpp', 'llama3.2:3b')
|
||||
// Fallback to default model if last used model no longer exists
|
||||
selectModelProvider('', '')
|
||||
}
|
||||
} else {
|
||||
// default model for non-new-chat contexts
|
||||
selectModelProvider('llamacpp', 'llama3.2:3b')
|
||||
}
|
||||
}, [
|
||||
model,
|
||||
@ -112,6 +110,7 @@ const DropdownModelProvider = ({
|
||||
updateCurrentThreadModel,
|
||||
providers,
|
||||
useLastUsedModel,
|
||||
checkModelExists,
|
||||
])
|
||||
|
||||
// Update display model when selection changes
|
||||
@ -151,9 +150,15 @@ const DropdownModelProvider = ({
|
||||
|
||||
provider.models.forEach((modelItem) => {
|
||||
// Skip models that require API key but don't have one (except llamacpp)
|
||||
if (provider.provider !== 'llamacpp' && !provider.api_key?.length) {
|
||||
if (
|
||||
provider &&
|
||||
predefinedProviders.some((e) =>
|
||||
e.provider.includes(provider.provider)
|
||||
) &&
|
||||
provider.provider !== 'llamacpp' &&
|
||||
!provider.api_key?.length
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
const capabilities = modelItem.capabilities || []
|
||||
const capabilitiesString = capabilities.join(' ')
|
||||
@ -182,6 +187,13 @@ const DropdownModelProvider = ({
|
||||
})
|
||||
}, [searchableItems])
|
||||
|
||||
// Get favorite models that are currently available
|
||||
const favoriteItems = useMemo(() => {
|
||||
return searchableItems.filter((item) =>
|
||||
favoriteModels.some((fav) => fav.id === item.model.id)
|
||||
)
|
||||
}, [searchableItems, favoriteModels])
|
||||
|
||||
// Filter models based on search value
|
||||
const filteredItems = useMemo(() => {
|
||||
if (!searchValue) return searchableItems
|
||||
@ -202,7 +214,7 @@ const DropdownModelProvider = ({
|
||||
})
|
||||
}, [searchableItems, searchValue, fzfInstance])
|
||||
|
||||
// Group filtered items by provider
|
||||
// Group filtered items by provider, excluding favorites when not searching
|
||||
const groupedItems = useMemo(() => {
|
||||
const groups: Record<string, SearchableModel[]> = {}
|
||||
|
||||
@ -221,11 +233,16 @@ const DropdownModelProvider = ({
|
||||
if (!groups[providerKey]) {
|
||||
groups[providerKey] = []
|
||||
}
|
||||
|
||||
// When not searching, exclude favorite models from regular provider sections
|
||||
const isFavorite = favoriteModels.some((fav) => fav.id === item.model.id)
|
||||
if (!searchValue && isFavorite) return // Skip adding this item to regular provider section
|
||||
|
||||
groups[providerKey].push(item)
|
||||
})
|
||||
|
||||
return groups
|
||||
}, [filteredItems, providers, searchValue])
|
||||
}, [filteredItems, providers, searchValue, favoriteModels])
|
||||
|
||||
const handleSelect = useCallback(
|
||||
(searchableModel: SearchableModel) => {
|
||||
@ -330,6 +347,64 @@ const DropdownModelProvider = ({
|
||||
</div>
|
||||
) : (
|
||||
<div className="py-1">
|
||||
{/* Favorites section - only show when not searching */}
|
||||
{!searchValue && favoriteItems.length > 0 && (
|
||||
<div className="bg-main-view-fg/2 backdrop-blur-2xl rounded-sm my-1.5 mx-1.5">
|
||||
{/* Favorites header */}
|
||||
<div className="flex items-center gap-1.5 px-2 py-1">
|
||||
<span className="text-sm font-medium text-main-view-fg/80">
|
||||
{t('common:favorites')}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Favorite models */}
|
||||
{favoriteItems.map((searchableModel) => {
|
||||
const isSelected =
|
||||
selectedModel?.id === searchableModel.model.id &&
|
||||
selectedProvider === searchableModel.provider.provider
|
||||
const capabilities =
|
||||
searchableModel.model.capabilities || []
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`fav-${searchableModel.value}`}
|
||||
title={searchableModel.model.id}
|
||||
onClick={() => handleSelect(searchableModel)}
|
||||
className={cn(
|
||||
'mx-1 mb-1 px-2 py-1.5 rounded-sm cursor-pointer flex items-center gap-2 transition-all duration-200',
|
||||
'hover:bg-main-view-fg/4',
|
||||
isSelected &&
|
||||
'bg-main-view-fg/8 hover:bg-main-view-fg/8'
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-1 flex-1 min-w-0">
|
||||
<div className="shrink-0 -ml-1">
|
||||
<ProvidersAvatar
|
||||
provider={searchableModel.provider}
|
||||
/>
|
||||
</div>
|
||||
<span className="truncate text-main-view-fg/80 text-sm">
|
||||
{searchableModel.model.id}
|
||||
</span>
|
||||
<div className="flex-1"></div>
|
||||
{capabilities.length > 0 && (
|
||||
<div className="flex-shrink-0 -mr-1.5">
|
||||
<Capabilities capabilities={capabilities} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Divider between favorites and regular providers */}
|
||||
{favoriteItems.length > 0 && (
|
||||
<div className="border-b border-1 border-main-view-fg/8 mx-2"></div>
|
||||
)}
|
||||
|
||||
{/* Regular provider sections */}
|
||||
{Object.entries(groupedItems).map(([providerKey, models]) => {
|
||||
const providerInfo = providers.find(
|
||||
(p) => p.provider === providerKey
|
||||
@ -340,7 +415,7 @@ const DropdownModelProvider = ({
|
||||
return (
|
||||
<div
|
||||
key={providerKey}
|
||||
className="bg-main-view-fg/4 backdrop-blur-2xl first:mt-0 rounded-sm my-1.5 mx-1.5 first:mb-0"
|
||||
className="bg-main-view-fg/2 backdrop-blur-2xl first:mt-0 rounded-sm my-1.5 mx-1.5 first:mb-0"
|
||||
>
|
||||
{/* Provider header */}
|
||||
<div className="flex items-center justify-between px-2 py-1">
|
||||
@ -384,15 +459,20 @@ const DropdownModelProvider = ({
|
||||
return (
|
||||
<div
|
||||
key={searchableModel.value}
|
||||
title={searchableModel.model.id}
|
||||
onClick={() => handleSelect(searchableModel)}
|
||||
className={cn(
|
||||
'mx-1 mb-1 px-2 py-1.5 rounded-sm cursor-pointer flex items-center gap-2 transition-all duration-200',
|
||||
'hover:bg-main-view-fg/10',
|
||||
isSelected && 'bg-main-view-fg/15'
|
||||
'hover:bg-main-view-fg/4',
|
||||
isSelected &&
|
||||
'bg-main-view-fg/8 hover:bg-main-view-fg/8'
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2 flex-1 min-w-0">
|
||||
<span className="truncate text-main-view-fg/80 text-sm">
|
||||
<span
|
||||
className="truncate text-main-view-fg/80 text-sm"
|
||||
title={searchableModel.model.id}
|
||||
>
|
||||
{searchableModel.model.id}
|
||||
</span>
|
||||
|
||||
|
||||
@ -1,12 +1,18 @@
|
||||
import { useEffect, useState } from 'react'
|
||||
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from '@/components/ui/dropdown-menu'
|
||||
DropDrawer,
|
||||
DropDrawerContent,
|
||||
DropDrawerItem,
|
||||
DropDrawerSub,
|
||||
DropDrawerLabel,
|
||||
DropDrawerSubContent,
|
||||
DropDrawerSeparator,
|
||||
DropDrawerSubTrigger,
|
||||
DropDrawerTrigger,
|
||||
DropDrawerGroup,
|
||||
} from '@/components/ui/dropdrawer'
|
||||
|
||||
import { Switch } from '@/components/ui/switch'
|
||||
|
||||
import { useThreads } from '@/hooks/useThreads'
|
||||
@ -15,6 +21,7 @@ import { useToolAvailable } from '@/hooks/useToolAvailable'
|
||||
import React from 'react'
|
||||
import { useAppState } from '@/hooks/useAppState'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
interface DropdownToolsAvailableProps {
|
||||
children: (isOpen: boolean, toolsCount: number) => React.ReactNode
|
||||
@ -82,6 +89,23 @@ export default function DropdownToolsAvailable({
|
||||
return false
|
||||
}
|
||||
|
||||
const handleDisableAllServerTools = (
|
||||
serverName: string,
|
||||
disable: boolean
|
||||
) => {
|
||||
const allToolsByServer = getToolsByServer()
|
||||
const serverTools = allToolsByServer[serverName] || []
|
||||
serverTools.forEach((tool) => {
|
||||
handleToolToggle(tool.name, !disable)
|
||||
})
|
||||
}
|
||||
|
||||
const areAllServerToolsDisabled = (serverName: string): boolean => {
|
||||
const allToolsByServer = getToolsByServer()
|
||||
const serverTools = allToolsByServer[serverName] || []
|
||||
return serverTools.every((tool) => !isToolChecked(tool.name))
|
||||
}
|
||||
|
||||
const getEnabledToolsCount = (): number => {
|
||||
const disabledTools = initialMessage
|
||||
? getDefaultDisabledTools()
|
||||
@ -91,69 +115,153 @@ export default function DropdownToolsAvailable({
|
||||
return tools.filter((tool) => !disabledTools.includes(tool.name)).length
|
||||
}
|
||||
|
||||
const getToolsByServer = () => {
|
||||
const toolsByServer = tools.reduce(
|
||||
(acc, tool) => {
|
||||
if (!acc[tool.server]) {
|
||||
acc[tool.server] = []
|
||||
}
|
||||
acc[tool.server].push(tool)
|
||||
return acc
|
||||
},
|
||||
{} as Record<string, typeof tools>
|
||||
)
|
||||
|
||||
return toolsByServer
|
||||
}
|
||||
|
||||
const renderTrigger = () => children(isOpen, getEnabledToolsCount())
|
||||
|
||||
if (tools.length === 0) {
|
||||
return (
|
||||
<DropdownMenu onOpenChange={handleOpenChange}>
|
||||
<DropdownMenuTrigger asChild>{renderTrigger()}</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="start" className="max-w-64">
|
||||
<DropdownMenuItem disabled>{t('common:noToolsAvailable')}</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
<DropDrawer onOpenChange={handleOpenChange}>
|
||||
<DropDrawerTrigger asChild>{renderTrigger()}</DropDrawerTrigger>
|
||||
<DropDrawerContent align="start" className="max-w-64">
|
||||
<DropDrawerItem disabled>
|
||||
{t('common:noToolsAvailable')}
|
||||
</DropDrawerItem>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<DropdownMenu onOpenChange={handleOpenChange}>
|
||||
<DropdownMenuTrigger asChild>{renderTrigger()}</DropdownMenuTrigger>
|
||||
const toolsByServer = getToolsByServer()
|
||||
|
||||
<DropdownMenuContent
|
||||
return (
|
||||
<DropDrawer onOpenChange={handleOpenChange}>
|
||||
<DropDrawerTrigger asChild>{renderTrigger()}</DropDrawerTrigger>
|
||||
<DropDrawerContent
|
||||
side="top"
|
||||
align="start"
|
||||
className="max-w-64 backdrop-blur-xl bg-main-view"
|
||||
className="bg-main-view !overflow-hidden"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<DropdownMenuLabel className="flex items-center gap-2 sticky -top-1 z-10 px-4 pl-2 py-2 ">
|
||||
<DropDrawerLabel className="flex items-center gap-2 sticky -top-1 z-10 px-4 pl-2 py-1">
|
||||
Available Tools
|
||||
</DropdownMenuLabel>
|
||||
<DropdownMenuSeparator />
|
||||
</DropDrawerLabel>
|
||||
<DropDrawerSeparator />
|
||||
<div className="max-h-64 overflow-y-auto">
|
||||
{tools.map((tool) => {
|
||||
const isChecked = isToolChecked(tool.name)
|
||||
return (
|
||||
<div
|
||||
key={tool.name}
|
||||
className="py-2 hover:bg-main-view-fg/5 hover:backdrop-blur-2xl rounded-sm px-2 mx-auto w-full"
|
||||
<DropDrawerGroup>
|
||||
{Object.entries(toolsByServer).map(([serverName, serverTools]) => (
|
||||
<DropDrawerSub
|
||||
id={`server-${serverName}`}
|
||||
key={serverName}
|
||||
title={serverName}
|
||||
>
|
||||
<div className="flex items-start justify-center gap-3">
|
||||
<div className="flex items-start justify-between gap-4 w-full">
|
||||
<div className="overflow-hidden w-full flex flex-col ">
|
||||
<div className="truncate">
|
||||
<span className="text-sm font-medium" title={tool.name}>
|
||||
{tool.name}
|
||||
</span>
|
||||
</div>
|
||||
{tool.description && (
|
||||
<p className="text-xs text-main-view-fg/70 mt-1 line-clamp-2">
|
||||
{tool.description}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="shrink-0 mx-auto">
|
||||
<Switch
|
||||
checked={isChecked}
|
||||
onCheckedChange={(checked) =>
|
||||
handleToolToggle(tool.name, checked)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<DropDrawerSubTrigger className="py-2 hover:bg-main-view-fg/5 hover:backdrop-blur-2xl rounded-sm px-2 mx-auto w-full">
|
||||
<div className="flex items-center justify-between w-full">
|
||||
<span className="text-sm text-main-view-fg/80">
|
||||
{serverName}
|
||||
</span>
|
||||
<span className="text-xs text-main-view-fg/50 inline-flex items-center mr-1 border border-main-view-fg/20 px-1 rounded-sm">
|
||||
{
|
||||
serverTools.filter((tool) => isToolChecked(tool.name))
|
||||
.length
|
||||
}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</DropDrawerSubTrigger>
|
||||
<DropDrawerSubContent className="max-w-64 max-h-70 w-full overflow-hidden">
|
||||
<DropDrawerGroup>
|
||||
{serverTools.length > 1 && (
|
||||
<div className="sticky top-0 z-10 bg-main-view border-b border-main-view-fg/10 px-4 md:px-2 pr-2 py-1.5 flex items-center justify-between">
|
||||
<span className="text-xs font-medium text-main-view-fg/70">
|
||||
Disable All Tools
|
||||
</span>
|
||||
<div
|
||||
className={cn(
|
||||
'flex items-center gap-2',
|
||||
serverTools.length > 5
|
||||
? 'mr-3 md:mr-1.5'
|
||||
: 'mr-2 md:mr-0'
|
||||
)}
|
||||
>
|
||||
<Switch
|
||||
checked={!areAllServerToolsDisabled(serverName)}
|
||||
onCheckedChange={(checked) =>
|
||||
handleDisableAllServerTools(serverName, !checked)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<div className="max-h-56 overflow-y-auto">
|
||||
{serverTools.map((tool) => {
|
||||
const isChecked = isToolChecked(tool.name)
|
||||
return (
|
||||
<DropDrawerItem
|
||||
onClick={(e) => {
|
||||
handleToolToggle(tool.name, !isChecked)
|
||||
e.preventDefault()
|
||||
}}
|
||||
onSelect={(e) => {
|
||||
handleToolToggle(tool.name, !isChecked)
|
||||
e.preventDefault()
|
||||
}}
|
||||
key={tool.name}
|
||||
className="mt-1 first:mt-0 py-1.5"
|
||||
icon={
|
||||
<Switch
|
||||
checked={isChecked}
|
||||
onCheckedChange={(checked) => {
|
||||
console.log('checked', checked)
|
||||
handleToolToggle(tool.name, checked)
|
||||
}}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
}}
|
||||
/>
|
||||
}
|
||||
>
|
||||
<div className="overflow-hidden flex flex-col items-start ">
|
||||
<div className="truncate">
|
||||
<span
|
||||
className="text-sm font-medium text-main-view-fg"
|
||||
title={tool.name}
|
||||
>
|
||||
{tool.name}
|
||||
</span>
|
||||
</div>
|
||||
{tool.description && (
|
||||
<p
|
||||
className="text-xs text-main-view-fg/70 mt-1 line-clamp-1"
|
||||
title={tool.description}
|
||||
>
|
||||
{tool.description}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</DropDrawerItem>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</DropDrawerGroup>
|
||||
</DropDrawerSubContent>
|
||||
</DropDrawerSub>
|
||||
))}
|
||||
</DropDrawerGroup>
|
||||
</div>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</DropDrawerContent>
|
||||
</DropDrawer>
|
||||
)
|
||||
}
|
||||
|
||||
25
web-app/src/containers/FavoriteModelAction.tsx
Normal file
@ -0,0 +1,25 @@
|
||||
import { IconStar, IconStarFilled } from '@tabler/icons-react'
|
||||
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
|
||||
|
||||
interface FavoriteModelActionProps {
|
||||
model: Model
|
||||
}
|
||||
|
||||
export function FavoriteModelAction({ model }: FavoriteModelActionProps) {
|
||||
const { isFavorite, toggleFavorite } = useFavoriteModel()
|
||||
const isModelFavorite = isFavorite(model.id)
|
||||
|
||||
return (
|
||||
<div
|
||||
aria-label="Toggle favorite"
|
||||
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
|
||||
onClick={() => toggleFavorite(model)}
|
||||
>
|
||||
{isModelFavorite ? (
|
||||
<IconStarFilled size={18} className="text-main-view-fg" />
|
||||
) : (
|
||||
<IconStar size={18} className="text-main-view-fg/50" />
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@ -21,15 +21,15 @@ export default function GlobalError({ error }: GlobalErrorProps) {
|
||||
>
|
||||
<path
|
||||
d="M6 8H6.01M6 16H6.01M6 12H18C20.2091 12 22 10.2091 22 8C22 5.79086 20.2091 4 18 4H6C3.79086 4 2 5.79086 2 8C2 10.2091 3.79086 12 6 12ZM6 12C3.79086 12 2 13.7909 2 16C2 18.2091 3.79086 20 6 20H14"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
></path>
|
||||
<path
|
||||
d="M17 16L22 21M22 16L17 21"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
></path>
|
||||
</svg>
|
||||
</div>
|
||||
|
||||
@ -8,7 +8,13 @@ import {
|
||||
} from '@/components/ui/dialog'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { IconPlus, IconTrash, IconGripVertical } from '@tabler/icons-react'
|
||||
import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group'
|
||||
import {
|
||||
IconPlus,
|
||||
IconTrash,
|
||||
IconGripVertical,
|
||||
IconCodeDots,
|
||||
} from '@tabler/icons-react'
|
||||
import { MCPServerConfig } from '@/hooks/useMCPServers'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import {
|
||||
@ -27,6 +33,8 @@ import {
|
||||
} from '@dnd-kit/sortable'
|
||||
import { CSS } from '@dnd-kit/utilities'
|
||||
import { cn } from '@/lib/utils'
|
||||
import CodeEditor from '@uiw/react-textarea-code-editor'
|
||||
import '@uiw/react-textarea-code-editor/dist.css'
|
||||
|
||||
// Sortable argument item component
|
||||
function SortableArgItem({
|
||||
@ -114,13 +122,34 @@ export default function AddEditMCPServer({
|
||||
const [args, setArgs] = useState<string[]>([''])
|
||||
const [envKeys, setEnvKeys] = useState<string[]>([''])
|
||||
const [envValues, setEnvValues] = useState<string[]>([''])
|
||||
const [transportType, setTransportType] = useState<'stdio' | 'http' | 'sse'>(
|
||||
'stdio'
|
||||
)
|
||||
const [url, setUrl] = useState('')
|
||||
const [headerKeys, setHeaderKeys] = useState<string[]>([''])
|
||||
const [headerValues, setHeaderValues] = useState<string[]>([''])
|
||||
const [timeout, setTimeout] = useState('')
|
||||
const [isToggled, setIsToggled] = useState(false)
|
||||
const [jsonContent, setJsonContent] = useState('')
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
|
||||
// Reset form when modal opens/closes or editing key changes
|
||||
useEffect(() => {
|
||||
if (open && editingKey && initialData) {
|
||||
setServerName(editingKey)
|
||||
setCommand(initialData.command)
|
||||
setCommand(initialData.command || '')
|
||||
setUrl(initialData.url || '')
|
||||
setTimeout(initialData.timeout ? initialData.timeout.toString() : '')
|
||||
setArgs(initialData.args?.length > 0 ? initialData.args : [''])
|
||||
setTransportType(initialData?.type || 'stdio')
|
||||
|
||||
// Initialize JSON content for toggle mode
|
||||
try {
|
||||
const jsonData = { [editingKey]: initialData }
|
||||
setJsonContent(JSON.stringify(jsonData, null, 2))
|
||||
} catch {
|
||||
setJsonContent('')
|
||||
}
|
||||
|
||||
if (initialData.env) {
|
||||
// Convert env object to arrays of keys and values
|
||||
@ -130,6 +159,17 @@ export default function AddEditMCPServer({
|
||||
setEnvKeys(keys.length > 0 ? keys : [''])
|
||||
setEnvValues(values.length > 0 ? values : [''])
|
||||
}
|
||||
|
||||
if (initialData.headers) {
|
||||
// Convert headers object to arrays of keys and values
|
||||
const headerKeysList = Object.keys(initialData.headers)
|
||||
const headerValuesList = headerKeysList.map(
|
||||
(key) => initialData.headers![key]
|
||||
)
|
||||
|
||||
setHeaderKeys(headerKeysList.length > 0 ? headerKeysList : [''])
|
||||
setHeaderValues(headerValuesList.length > 0 ? headerValuesList : [''])
|
||||
}
|
||||
} else if (open) {
|
||||
// Add mode - reset form
|
||||
resetForm()
|
||||
@ -139,9 +179,17 @@ export default function AddEditMCPServer({
|
||||
const resetForm = () => {
|
||||
setServerName('')
|
||||
setCommand('')
|
||||
setUrl('')
|
||||
setTimeout('')
|
||||
setArgs([''])
|
||||
setEnvKeys([''])
|
||||
setEnvValues([''])
|
||||
setHeaderKeys([''])
|
||||
setHeaderValues([''])
|
||||
setTransportType('stdio')
|
||||
setIsToggled(false)
|
||||
setJsonContent('')
|
||||
setError(null)
|
||||
}
|
||||
|
||||
const handleAddArg = () => {
|
||||
@ -201,7 +249,57 @@ export default function AddEditMCPServer({
|
||||
setEnvValues(newValues)
|
||||
}
|
||||
|
||||
const handleAddHeader = () => {
|
||||
setHeaderKeys([...headerKeys, ''])
|
||||
setHeaderValues([...headerValues, ''])
|
||||
}
|
||||
|
||||
const handleRemoveHeader = (index: number) => {
|
||||
const newKeys = [...headerKeys]
|
||||
const newValues = [...headerValues]
|
||||
newKeys.splice(index, 1)
|
||||
newValues.splice(index, 1)
|
||||
setHeaderKeys(newKeys.length > 0 ? newKeys : [''])
|
||||
setHeaderValues(newValues.length > 0 ? newValues : [''])
|
||||
}
|
||||
|
||||
const handleHeaderKeyChange = (index: number, value: string) => {
|
||||
const newKeys = [...headerKeys]
|
||||
newKeys[index] = value
|
||||
setHeaderKeys(newKeys)
|
||||
}
|
||||
|
||||
const handleHeaderValueChange = (index: number, value: string) => {
|
||||
const newValues = [...headerValues]
|
||||
newValues[index] = value
|
||||
setHeaderValues(newValues)
|
||||
}
|
||||
|
||||
const handleSave = () => {
|
||||
// Handle JSON mode
|
||||
if (isToggled) {
|
||||
try {
|
||||
const parsedData = JSON.parse(jsonContent)
|
||||
// Validate that it's an object with server configurations
|
||||
if (typeof parsedData !== 'object' || parsedData === null) {
|
||||
setError(t('mcp-servers:editJson.errorFormat'))
|
||||
return
|
||||
}
|
||||
// For each server in the JSON, call onSave
|
||||
Object.entries(parsedData).forEach(([serverName, config]) => {
|
||||
onSave(serverName.trim(), config as MCPServerConfig)
|
||||
})
|
||||
onOpenChange(false)
|
||||
resetForm()
|
||||
setError(null)
|
||||
return
|
||||
} catch {
|
||||
setError(t('mcp-servers:editJson.errorFormat'))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Handle form mode
|
||||
// Convert env arrays to object
|
||||
const envObj: Record<string, string> = {}
|
||||
envKeys.forEach((key, index) => {
|
||||
@ -211,13 +309,28 @@ export default function AddEditMCPServer({
|
||||
}
|
||||
})
|
||||
|
||||
// Convert headers arrays to object
|
||||
const headersObj: Record<string, string> = {}
|
||||
headerKeys.forEach((key, index) => {
|
||||
const keyName = key.trim()
|
||||
if (keyName !== '') {
|
||||
headersObj[keyName] = headerValues[index]?.trim() || ''
|
||||
}
|
||||
})
|
||||
|
||||
// Filter out empty args
|
||||
const filteredArgs = args.map((arg) => arg.trim()).filter((arg) => arg)
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
command: command.trim(),
|
||||
args: filteredArgs,
|
||||
env: envObj,
|
||||
command: transportType === 'stdio' ? command.trim() : '',
|
||||
args: transportType === 'stdio' ? filteredArgs : [],
|
||||
env: transportType === 'stdio' ? envObj : {},
|
||||
type: transportType,
|
||||
...(transportType !== 'stdio' && {
|
||||
url: url.trim(),
|
||||
headers: Object.keys(headersObj).length > 0 ? headersObj : undefined,
|
||||
timeout: timeout.trim() !== '' ? parseInt(timeout) : undefined,
|
||||
}),
|
||||
}
|
||||
|
||||
if (serverName.trim() !== '') {
|
||||
@ -229,122 +342,303 @@ export default function AddEditMCPServer({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent>
|
||||
<DialogContent showCloseButton={false}>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
{editingKey
|
||||
? t('mcp-servers:editServer')
|
||||
: t('mcp-servers:addServer')}
|
||||
<DialogTitle className="flex items-center justify-between">
|
||||
<span>
|
||||
{editingKey
|
||||
? t('mcp-servers:editServer')
|
||||
: t('mcp-servers:addServer')}
|
||||
</span>
|
||||
<div
|
||||
className={cn(
|
||||
'size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out',
|
||||
isToggled && 'bg-main-view-fg/10 text-accent'
|
||||
)}
|
||||
title="Add server by JSON"
|
||||
onClick={() => setIsToggled(!isToggled)}
|
||||
>
|
||||
<IconCodeDots className="h-5 w-5 cursor-pointer transition-colors duration-200" />
|
||||
</div>
|
||||
</DialogTitle>
|
||||
</DialogHeader>
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm mb-2 inline-block">
|
||||
{t('mcp-servers:serverName')}
|
||||
</label>
|
||||
<Input
|
||||
value={serverName}
|
||||
onChange={(e) => setServerName(e.target.value)}
|
||||
placeholder={t('mcp-servers:enterServerName')}
|
||||
autoFocus
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm mb-2 inline-block">
|
||||
{t('mcp-servers:command')}
|
||||
</label>
|
||||
<Input
|
||||
value={command}
|
||||
onChange={(e) => setCommand(e.target.value)}
|
||||
placeholder={t('mcp-servers:enterCommand')}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<label className="text-sm">{t('mcp-servers:arguments')}</label>
|
||||
<div
|
||||
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
|
||||
onClick={handleAddArg}
|
||||
>
|
||||
<IconPlus size={18} className="text-main-view-fg/60" />
|
||||
{isToggled ? (
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm mb-2 inline-block">
|
||||
{t('mcp-servers:editJson.placeholder')}
|
||||
</label>
|
||||
<div className="border border-main-view-fg/10 rounded-md overflow-hidden">
|
||||
<CodeEditor
|
||||
value={jsonContent}
|
||||
language="json"
|
||||
placeholder={`{
|
||||
"serverName": {
|
||||
"command": "command",
|
||||
"args": ["arg1", "arg2"],
|
||||
"env": {
|
||||
"KEY": "value"
|
||||
}
|
||||
}
|
||||
}`}
|
||||
onChange={(e) => {
|
||||
setJsonContent(e.target.value)
|
||||
setError(null)
|
||||
}}
|
||||
onPaste={() => setError(null)}
|
||||
style={{
|
||||
fontFamily: 'ui-monospace',
|
||||
backgroundColor: 'transparent',
|
||||
wordBreak: 'break-all',
|
||||
overflowWrap: 'anywhere',
|
||||
whiteSpace: 'pre-wrap',
|
||||
}}
|
||||
className="w-full !text-sm min-h-[300px]"
|
||||
/>
|
||||
</div>
|
||||
{error && <div className="text-destructive text-sm">{error}</div>}
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm mb-2 inline-block">
|
||||
{t('mcp-servers:serverName')}
|
||||
</label>
|
||||
<Input
|
||||
value={serverName}
|
||||
onChange={(e) => setServerName(e.target.value)}
|
||||
placeholder={t('mcp-servers:enterServerName')}
|
||||
autoFocus
|
||||
/>
|
||||
</div>
|
||||
|
||||
<DndContext
|
||||
sensors={sensors}
|
||||
collisionDetection={closestCenter}
|
||||
onDragEnd={(event) => {
|
||||
const { active, over } = event
|
||||
if (active.id !== over?.id) {
|
||||
const oldIndex = parseInt(active.id.toString())
|
||||
const newIndex = parseInt(over?.id.toString() || '0')
|
||||
handleReorderArgs(oldIndex, newIndex)
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm mb-2 inline-block">
|
||||
Transport Type
|
||||
</label>
|
||||
<RadioGroup
|
||||
value={transportType}
|
||||
onValueChange={(value) =>
|
||||
setTransportType(value as 'http' | 'sse')
|
||||
}
|
||||
}}
|
||||
>
|
||||
<SortableContext
|
||||
items={args.map((_, index) => index)}
|
||||
strategy={verticalListSortingStrategy}
|
||||
className="flex gap-6"
|
||||
>
|
||||
{args.map((arg, index) => (
|
||||
<SortableArgItem
|
||||
key={index}
|
||||
id={index}
|
||||
value={arg}
|
||||
onChange={(value) => handleArgChange(index, value)}
|
||||
onRemove={() => handleRemoveArg(index)}
|
||||
canRemove={args.length > 1}
|
||||
placeholder={t('mcp-servers:argument', {
|
||||
index: index + 1,
|
||||
})}
|
||||
/>
|
||||
))}
|
||||
</SortableContext>
|
||||
</DndContext>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<label className="text-sm">{t('mcp-servers:envVars')}</label>
|
||||
<div
|
||||
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
|
||||
onClick={handleAddEnv}
|
||||
>
|
||||
<IconPlus size={18} className="text-main-view-fg/60" />
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
<RadioGroupItem value="stdio" id="stdio" />
|
||||
<label
|
||||
htmlFor="stdio"
|
||||
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
|
||||
>
|
||||
STDIO
|
||||
</label>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
<RadioGroupItem value="http" id="http" />
|
||||
<label
|
||||
htmlFor="http"
|
||||
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
|
||||
>
|
||||
HTTP
|
||||
</label>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
<RadioGroupItem value="sse" id="sse" />
|
||||
<label
|
||||
htmlFor="sse"
|
||||
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
|
||||
>
|
||||
SSE
|
||||
</label>
|
||||
</div>
|
||||
</RadioGroup>
|
||||
</div>
|
||||
|
||||
{envKeys.map((key, index) => (
|
||||
<div key={`env-${index}`} className="flex items-center gap-2">
|
||||
{transportType === 'stdio' ? (
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm mb-2 inline-block">
|
||||
{t('mcp-servers:command')}
|
||||
</label>
|
||||
<Input
|
||||
value={key}
|
||||
onChange={(e) => handleEnvKeyChange(index, e.target.value)}
|
||||
placeholder={t('mcp-servers:key')}
|
||||
className="flex-1"
|
||||
value={command}
|
||||
onChange={(e) => setCommand(e.target.value)}
|
||||
placeholder={t('mcp-servers:enterCommand')}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm mb-2 inline-block">URL</label>
|
||||
<Input
|
||||
value={envValues[index] || ''}
|
||||
onChange={(e) => handleEnvValueChange(index, e.target.value)}
|
||||
placeholder={t('mcp-servers:value')}
|
||||
className="flex-1"
|
||||
value={url}
|
||||
onChange={(e) => setUrl(e.target.value)}
|
||||
placeholder="Enter URL"
|
||||
/>
|
||||
{envKeys.length > 1 && (
|
||||
</div>
|
||||
)}
|
||||
|
||||
{transportType === 'stdio' && (
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<label className="text-sm">
|
||||
{t('mcp-servers:arguments')}
|
||||
</label>
|
||||
<div
|
||||
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
|
||||
onClick={() => handleRemoveEnv(index)}
|
||||
onClick={handleAddArg}
|
||||
>
|
||||
<IconTrash size={18} className="text-destructive" />
|
||||
<IconPlus size={18} className="text-main-view-fg/60" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<DndContext
|
||||
sensors={sensors}
|
||||
collisionDetection={closestCenter}
|
||||
onDragEnd={(event) => {
|
||||
const { active, over } = event
|
||||
if (active.id !== over?.id) {
|
||||
const oldIndex = parseInt(active.id.toString())
|
||||
const newIndex = parseInt(over?.id.toString() || '0')
|
||||
handleReorderArgs(oldIndex, newIndex)
|
||||
}
|
||||
}}
|
||||
>
|
||||
<SortableContext
|
||||
items={args.map((_, index) => index)}
|
||||
strategy={verticalListSortingStrategy}
|
||||
>
|
||||
{args.map((arg, index) => (
|
||||
<SortableArgItem
|
||||
key={index}
|
||||
id={index}
|
||||
value={arg}
|
||||
onChange={(value) => handleArgChange(index, value)}
|
||||
onRemove={() => handleRemoveArg(index)}
|
||||
canRemove={args.length > 1}
|
||||
placeholder={t('mcp-servers:argument', {
|
||||
index: index + 1,
|
||||
})}
|
||||
/>
|
||||
))}
|
||||
</SortableContext>
|
||||
</DndContext>
|
||||
</div>
|
||||
))}
|
||||
)}
|
||||
|
||||
{transportType === 'stdio' && (
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<label className="text-sm">{t('mcp-servers:envVars')}</label>
|
||||
<div
|
||||
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
|
||||
onClick={handleAddEnv}
|
||||
>
|
||||
<IconPlus size={18} className="text-main-view-fg/60" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{envKeys.map((key, index) => (
|
||||
<div key={`env-${index}`} className="flex items-center gap-2">
|
||||
<Input
|
||||
value={key}
|
||||
onChange={(e) =>
|
||||
handleEnvKeyChange(index, e.target.value)
|
||||
}
|
||||
placeholder={t('mcp-servers:key')}
|
||||
className="flex-1"
|
||||
/>
|
||||
<Input
|
||||
value={envValues[index] || ''}
|
||||
onChange={(e) =>
|
||||
handleEnvValueChange(index, e.target.value)
|
||||
}
|
||||
placeholder={t('mcp-servers:value')}
|
||||
className="flex-1"
|
||||
/>
|
||||
{envKeys.length > 1 && (
|
||||
<div
|
||||
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
|
||||
onClick={() => handleRemoveEnv(index)}
|
||||
>
|
||||
<IconTrash size={18} className="text-destructive" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{(transportType === 'http' || transportType === 'sse') && (
|
||||
<>
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<label className="text-sm">Headers</label>
|
||||
<div
|
||||
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
|
||||
onClick={handleAddHeader}
|
||||
>
|
||||
<IconPlus size={18} className="text-main-view-fg/60" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{headerKeys.map((key, index) => (
|
||||
<div
|
||||
key={`header-${index}`}
|
||||
className="flex items-center gap-2"
|
||||
>
|
||||
<Input
|
||||
value={key}
|
||||
onChange={(e) =>
|
||||
handleHeaderKeyChange(index, e.target.value)
|
||||
}
|
||||
placeholder="Header name"
|
||||
className="flex-1"
|
||||
/>
|
||||
<Input
|
||||
value={headerValues[index] || ''}
|
||||
onChange={(e) =>
|
||||
handleHeaderValueChange(index, e.target.value)
|
||||
}
|
||||
placeholder="Header value"
|
||||
className="flex-1"
|
||||
/>
|
||||
{headerKeys.length > 1 && (
|
||||
<div
|
||||
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
|
||||
onClick={() => handleRemoveHeader(index)}
|
||||
>
|
||||
<IconTrash size={18} className="text-destructive" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm mb-2 inline-block">
|
||||
Timeout (seconds)
|
||||
</label>
|
||||
<Input
|
||||
value={timeout}
|
||||
onChange={(e) => setTimeout(e.target.value)}
|
||||
placeholder="Enter timeout in seconds"
|
||||
type="number"
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<DialogFooter>
|
||||
<Button onClick={handleSave}>{t('mcp-servers:save')}</Button>
|
||||
<Button variant="link" onClick={() => onOpenChange(false)}>
|
||||
{t('common:cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleSave}
|
||||
disabled={!isToggled && serverName.trim() === ''}
|
||||
>
|
||||
{t('mcp-servers:save')}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
@ -18,6 +18,7 @@ import { IconTrash } from '@tabler/icons-react'
|
||||
import { useState, useEffect } from 'react'
|
||||
import { toast } from 'sonner'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
|
||||
|
||||
type DialogDeleteModelProps = {
|
||||
provider: ModelProvider
|
||||
@ -31,8 +32,12 @@ export const DialogDeleteModel = ({
|
||||
const { t } = useTranslation()
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>('')
|
||||
const { setProviders, deleteModel: deleteModelCache } = useModelProvider()
|
||||
const { removeFavorite } = useFavoriteModel()
|
||||
|
||||
const removeModel = async () => {
|
||||
// Remove model from favorites if it exists
|
||||
removeFavorite(selectedModelId)
|
||||
|
||||
deleteModelCache(selectedModelId)
|
||||
deleteModel(selectedModelId).then(() => {
|
||||
getProviders().then((providers) => {
|
||||
|
||||
@ -18,6 +18,7 @@ import { useRouter } from '@tanstack/react-router'
|
||||
import { route } from '@/constants/routes'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import { predefinedProviders } from '@/consts/providers'
|
||||
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
|
||||
|
||||
type Props = {
|
||||
provider?: ProviderObject
|
||||
@ -25,6 +26,7 @@ type Props = {
|
||||
const DeleteProvider = ({ provider }: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const { deleteProvider, providers } = useModelProvider()
|
||||
const { favoriteModels, removeFavorite } = useFavoriteModel()
|
||||
const router = useRouter()
|
||||
if (
|
||||
!provider ||
|
||||
@ -34,6 +36,14 @@ const DeleteProvider = ({ provider }: Props) => {
|
||||
return null
|
||||
|
||||
const removeProvider = async () => {
|
||||
// Remove favorite models that belong to this provider
|
||||
const providerModelIds = provider.models.map((model) => model.id)
|
||||
favoriteModels.forEach((favoriteModel) => {
|
||||
if (providerModelIds.includes(favoriteModel.id)) {
|
||||
removeFavorite(favoriteModel.id)
|
||||
}
|
||||
})
|
||||
|
||||
deleteProvider(provider.provider)
|
||||
toast.success(t('providers:deleteProvider.title'), {
|
||||
id: `delete-provider-${provider.provider}`,
|
||||
|
||||
@ -43,19 +43,9 @@ export default function EditJsonMCPserver({
|
||||
}
|
||||
}, [open, initialData, t])
|
||||
|
||||
const handlePaste = (e: React.ClipboardEvent) => {
|
||||
const pastedText = e.clipboardData.getData('text')
|
||||
try {
|
||||
const parsedJson = JSON.parse(pastedText)
|
||||
const prettifiedJson = JSON.stringify(parsedJson, null, 2)
|
||||
e.preventDefault()
|
||||
setJsonContent(prettifiedJson)
|
||||
setError(null)
|
||||
} catch (error) {
|
||||
e.preventDefault()
|
||||
setError(t('mcp-servers:editJson.errorPaste'))
|
||||
console.error('Paste error:', error)
|
||||
}
|
||||
const handlePaste = () => {
|
||||
// Clear any existing errors when pasting
|
||||
setError(null)
|
||||
}
|
||||
|
||||
const handleSave = () => {
|
||||
@ -80,7 +70,18 @@ export default function EditJsonMCPserver({
|
||||
</DialogTitle>
|
||||
</DialogHeader>
|
||||
<div className="space-y-2">
|
||||
<div className="border border-main-view-fg/10 rounded-md overflow-hidden">
|
||||
<div className="border border-main-view-fg/10 rounded-md !overflow-hidden">
|
||||
<style>{`
|
||||
.w-tc-editor textarea {
|
||||
word-break: break-all !important;
|
||||
overflow-wrap: anywhere !important;
|
||||
white-space: pre-wrap !important;
|
||||
}
|
||||
.w-tc-editor .token.string {
|
||||
word-break: break-all !important;
|
||||
overflow-wrap: anywhere !important;
|
||||
}
|
||||
`}</style>
|
||||
<CodeEditor
|
||||
value={jsonContent}
|
||||
language="json"
|
||||
@ -90,8 +91,11 @@ export default function EditJsonMCPserver({
|
||||
style={{
|
||||
fontFamily: 'ui-monospace',
|
||||
backgroundColor: 'transparent',
|
||||
wordBreak: 'break-all',
|
||||
overflowWrap: 'anywhere',
|
||||
whiteSpace: 'pre-wrap',
|
||||
}}
|
||||
className="w-full !text-sm "
|
||||
className="w-full !text-sm overflow-hidden break-all"
|
||||
/>
|
||||
</div>
|
||||
{error && <div className="text-destructive text-sm">{error}</div>}
|
||||
|
||||
@ -1,253 +0,0 @@
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Switch } from '@/components/ui/switch'
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from '@/components/ui/tooltip'
|
||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||
import {
|
||||
IconPencil,
|
||||
IconEye,
|
||||
IconTool,
|
||||
// IconWorld,
|
||||
// IconAtom,
|
||||
IconCodeCircle2,
|
||||
} from '@tabler/icons-react'
|
||||
import { useState, useEffect } from 'react'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
|
||||
// No need to define our own interface, we'll use the existing Model type
|
||||
type DialogEditModelProps = {
|
||||
provider: ModelProvider
|
||||
modelId?: string // Optional model ID to edit
|
||||
}
|
||||
|
||||
export const DialogEditModel = ({
|
||||
provider,
|
||||
modelId,
|
||||
}: DialogEditModelProps) => {
|
||||
const { t } = useTranslation()
|
||||
const { updateProvider } = useModelProvider()
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>('')
|
||||
const [capabilities, setCapabilities] = useState<Record<string, boolean>>({
|
||||
completion: false,
|
||||
vision: false,
|
||||
tools: false,
|
||||
reasoning: false,
|
||||
embeddings: false,
|
||||
web_search: false,
|
||||
})
|
||||
|
||||
// Initialize with the provided model ID or the first model if available
|
||||
useEffect(() => {
|
||||
if (modelId) {
|
||||
setSelectedModelId(modelId)
|
||||
} else if (provider.models && provider.models.length > 0) {
|
||||
setSelectedModelId(provider.models[0].id)
|
||||
}
|
||||
}, [provider, modelId])
|
||||
|
||||
// Get the currently selected model
|
||||
const selectedModel = provider.models.find(
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(m: any) => m.id === selectedModelId
|
||||
)
|
||||
|
||||
// Initialize capabilities from selected model
|
||||
useEffect(() => {
|
||||
if (selectedModel) {
|
||||
const modelCapabilities = selectedModel.capabilities || []
|
||||
setCapabilities({
|
||||
completion: modelCapabilities.includes('completion'),
|
||||
vision: modelCapabilities.includes('vision'),
|
||||
tools: modelCapabilities.includes('tools'),
|
||||
embeddings: modelCapabilities.includes('embeddings'),
|
||||
web_search: modelCapabilities.includes('web_search'),
|
||||
reasoning: modelCapabilities.includes('reasoning'),
|
||||
})
|
||||
}
|
||||
}, [selectedModel])
|
||||
|
||||
// Track if capabilities were updated by user action
|
||||
const [capabilitiesUpdated, setCapabilitiesUpdated] = useState(false)
|
||||
|
||||
// Update model capabilities - only update local state
|
||||
const handleCapabilityChange = (capability: string, enabled: boolean) => {
|
||||
setCapabilities((prev) => ({
|
||||
...prev,
|
||||
[capability]: enabled,
|
||||
}))
|
||||
// Mark that capabilities were updated by user action
|
||||
setCapabilitiesUpdated(true)
|
||||
}
|
||||
|
||||
// Use effect to update the provider when capabilities are explicitly changed by user
|
||||
useEffect(() => {
|
||||
// Only run if capabilities were updated by user action and we have a selected model
|
||||
if (!capabilitiesUpdated || !selectedModel) return
|
||||
|
||||
// Reset the flag
|
||||
setCapabilitiesUpdated(false)
|
||||
|
||||
// Create updated capabilities array from the state
|
||||
const updatedCapabilities = Object.entries(capabilities)
|
||||
.filter(([, isEnabled]) => isEnabled)
|
||||
.map(([capName]) => capName)
|
||||
|
||||
// Find and update the model in the provider
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const updatedModels = provider.models.map((m: any) => {
|
||||
if (m.id === selectedModelId) {
|
||||
return {
|
||||
...m,
|
||||
capabilities: updatedCapabilities,
|
||||
}
|
||||
}
|
||||
return m
|
||||
})
|
||||
|
||||
// Update the provider with the updated models
|
||||
updateProvider(provider.provider, {
|
||||
...provider,
|
||||
models: updatedModels,
|
||||
})
|
||||
}, [
|
||||
capabilitiesUpdated,
|
||||
capabilities,
|
||||
provider,
|
||||
selectedModel,
|
||||
selectedModelId,
|
||||
updateProvider,
|
||||
])
|
||||
|
||||
if (!selectedModel) {
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog>
|
||||
<DialogTrigger asChild>
|
||||
<div className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out">
|
||||
<IconPencil size={18} className="text-main-view-fg/50" />
|
||||
</div>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle className="line-clamp-1" title={selectedModel.id}>
|
||||
{t('providers:editModel.title', { modelId: selectedModel.id })}
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
{t('providers:editModel.description')}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="py-1">
|
||||
<h3 className="text-sm font-medium mb-3">
|
||||
{t('providers:editModel.capabilities')}
|
||||
</h3>
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center space-x-2">
|
||||
<IconTool className="size-4 text-main-view-fg/70" />
|
||||
<span className="text-sm">
|
||||
{t('providers:editModel.tools')}
|
||||
</span>
|
||||
</div>
|
||||
<Switch
|
||||
id="tools-capability"
|
||||
checked={capabilities.tools}
|
||||
onCheckedChange={(checked) =>
|
||||
handleCapabilityChange('tools', checked)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center space-x-2">
|
||||
<IconEye className="size-4 text-main-view-fg/70" />
|
||||
<span className="text-sm">
|
||||
{t('providers:editModel.vision')}
|
||||
</span>
|
||||
</div>
|
||||
<Tooltip>
|
||||
<TooltipTrigger>
|
||||
<Switch
|
||||
id="vision-capability"
|
||||
checked={capabilities.vision}
|
||||
disabled={true}
|
||||
onCheckedChange={(checked) =>
|
||||
handleCapabilityChange('vision', checked)
|
||||
}
|
||||
/>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{t('providers:editModel.notAvailable')}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center space-x-2">
|
||||
<IconCodeCircle2 className="size-4 text-main-view-fg/70" />
|
||||
<span className="text-sm">
|
||||
{t('providers:editModel.embeddings')}
|
||||
</span>
|
||||
</div>
|
||||
<Tooltip>
|
||||
<TooltipTrigger>
|
||||
<Switch
|
||||
id="embedding-capability"
|
||||
disabled={true}
|
||||
checked={capabilities.embeddings}
|
||||
onCheckedChange={(checked) =>
|
||||
handleCapabilityChange('embeddings', checked)
|
||||
}
|
||||
/>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{t('providers:editModel.notAvailable')}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
|
||||
{/* <div className="flex items-center justify-between">
|
||||
<div className="flex items-center space-x-2">
|
||||
<IconWorld className="size-4 text-main-view-fg/70" />
|
||||
<span className="text-sm">Web Search</span>
|
||||
</div>
|
||||
<Switch
|
||||
id="web_search-capability"
|
||||
checked={capabilities.web_search}
|
||||
onCheckedChange={(checked) =>
|
||||
handleCapabilityChange('web_search', checked)
|
||||
}
|
||||
/>
|
||||
</div> */}
|
||||
|
||||
{/* <div className="flex items-center justify-between">
|
||||
<div className="flex items-center space-x-2">
|
||||
<IconAtom className="size-4 text-main-view-fg/70" />
|
||||
<span className="text-sm">{t('reasoning')}</span>
|
||||
</div>
|
||||
<Switch
|
||||
id="reasoning-capability"
|
||||
checked={capabilities.reasoning}
|
||||
onCheckedChange={(checked) =>
|
||||
handleCapabilityChange('reasoning', checked)
|
||||
}
|
||||
/>
|
||||
</div> */}
|
||||
</div>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
180
web-app/src/hooks/__tests__/useTools.test.ts
Normal file
@ -0,0 +1,180 @@
|
||||
import { renderHook, act } from '@testing-library/react'
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||
import { SystemEvent } from '@/types/events'
|
||||
|
||||
// Mock functions
|
||||
const mockGetTools = vi.fn()
|
||||
const mockUpdateTools = vi.fn()
|
||||
const mockListen = vi.fn()
|
||||
const mockUnsubscribe = vi.fn()
|
||||
|
||||
// Mock the dependencies
|
||||
vi.mock('@/services/mcp', () => ({
|
||||
getTools: mockGetTools,
|
||||
}))
|
||||
|
||||
vi.mock('../useAppState', () => ({
|
||||
useAppState: () => ({
|
||||
updateTools: mockUpdateTools,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@tauri-apps/api/event', () => ({
|
||||
listen: mockListen,
|
||||
}))
|
||||
|
||||
describe('useTools', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockListen.mockResolvedValue(mockUnsubscribe)
|
||||
mockGetTools.mockResolvedValue([])
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('should call getTools and updateTools on mount', async () => {
|
||||
const { useTools } = await import('../useTools')
|
||||
|
||||
const mockTools = [
|
||||
{ name: 'test-tool', description: 'A test tool' },
|
||||
{ name: 'another-tool', description: 'Another test tool' },
|
||||
]
|
||||
mockGetTools.mockResolvedValue(mockTools)
|
||||
|
||||
renderHook(() => useTools())
|
||||
|
||||
// Wait for async operations to complete
|
||||
await act(async () => {
|
||||
await new Promise(resolve => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
expect(mockGetTools).toHaveBeenCalledTimes(1)
|
||||
expect(mockUpdateTools).toHaveBeenCalledWith(mockTools)
|
||||
})
|
||||
|
||||
it('should set up event listener for MCP_UPDATE', async () => {
|
||||
const { useTools } = await import('../useTools')
|
||||
|
||||
renderHook(() => useTools())
|
||||
|
||||
await act(async () => {
|
||||
await new Promise(resolve => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
expect(mockListen).toHaveBeenCalledWith(
|
||||
SystemEvent.MCP_UPDATE,
|
||||
expect.any(Function)
|
||||
)
|
||||
})
|
||||
|
||||
it('should call setTools when MCP_UPDATE event is triggered', async () => {
|
||||
const { useTools } = await import('../useTools')
|
||||
|
||||
const mockTools = [{ name: 'updated-tool', description: 'Updated tool' }]
|
||||
mockGetTools.mockResolvedValue(mockTools)
|
||||
|
||||
let eventCallback: () => void
|
||||
|
||||
mockListen.mockImplementation((_event, callback) => {
|
||||
eventCallback = callback
|
||||
return Promise.resolve(mockUnsubscribe)
|
||||
})
|
||||
|
||||
renderHook(() => useTools())
|
||||
|
||||
// Wait for initial setup
|
||||
await act(async () => {
|
||||
await new Promise(resolve => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
// Clear the initial calls
|
||||
vi.clearAllMocks()
|
||||
mockGetTools.mockResolvedValue(mockTools)
|
||||
|
||||
// Trigger the event
|
||||
await act(async () => {
|
||||
eventCallback()
|
||||
await new Promise(resolve => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
expect(mockGetTools).toHaveBeenCalledTimes(1)
|
||||
expect(mockUpdateTools).toHaveBeenCalledWith(mockTools)
|
||||
})
|
||||
|
||||
it('should return unsubscribe function for cleanup', async () => {
|
||||
const { useTools } = await import('../useTools')
|
||||
|
||||
const { unmount } = renderHook(() => useTools())
|
||||
|
||||
await act(async () => {
|
||||
await new Promise(resolve => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
expect(mockListen).toHaveBeenCalled()
|
||||
|
||||
// Unmount should call the unsubscribe function
|
||||
unmount()
|
||||
|
||||
expect(mockListen).toHaveBeenCalledWith(
|
||||
SystemEvent.MCP_UPDATE,
|
||||
expect.any(Function)
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle getTools errors gracefully', async () => {
|
||||
const { useTools } = await import('../useTools')
|
||||
|
||||
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
mockGetTools.mockRejectedValue(new Error('Failed to get tools'))
|
||||
|
||||
renderHook(() => useTools())
|
||||
|
||||
await act(async () => {
|
||||
// Give enough time for the promise to be handled
|
||||
await new Promise(resolve => setTimeout(resolve, 100))
|
||||
})
|
||||
|
||||
expect(mockGetTools).toHaveBeenCalledTimes(1)
|
||||
// updateTools should not be called if getTools fails
|
||||
expect(mockUpdateTools).not.toHaveBeenCalled()
|
||||
|
||||
consoleErrorSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should handle event listener setup errors gracefully', async () => {
|
||||
const { useTools } = await import('../useTools')
|
||||
|
||||
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
mockListen.mockRejectedValue(new Error('Failed to set up listener'))
|
||||
|
||||
renderHook(() => useTools())
|
||||
|
||||
await act(async () => {
|
||||
// Give enough time for the promise to be handled
|
||||
await new Promise(resolve => setTimeout(resolve, 100))
|
||||
})
|
||||
|
||||
// Initial getTools should still work
|
||||
expect(mockGetTools).toHaveBeenCalledTimes(1)
|
||||
expect(mockListen).toHaveBeenCalled()
|
||||
|
||||
consoleErrorSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should only set up effect once with empty dependency array', async () => {
|
||||
const { useTools } = await import('../useTools')
|
||||
|
||||
const { rerender } = renderHook(() => useTools())
|
||||
|
||||
// Initial render
|
||||
expect(mockGetTools).toHaveBeenCalledTimes(1)
|
||||
expect(mockListen).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Rerender should not trigger additional calls
|
||||
rerender()
|
||||
expect(mockGetTools).toHaveBeenCalledTimes(1)
|
||||
expect(mockListen).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
@ -1,4 +1,4 @@
|
||||
import { useCallback, useEffect, useMemo } from 'react'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { usePrompt } from './usePrompt'
|
||||
import { useModelProvider } from './useModelProvider'
|
||||
import { useThreads } from './useThreads'
|
||||
@ -20,10 +20,7 @@ import { CompletionMessagesBuilder } from '@/lib/messages'
|
||||
import { renderInstructions } from '@/lib/instructionTemplate'
|
||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
import { useAssistant } from './useAssistant'
|
||||
import { getTools } from '@/services/mcp'
|
||||
import { MCPTool } from '@/types/completion'
|
||||
import { listen } from '@tauri-apps/api/event'
|
||||
import { SystemEvent } from '@/types/events'
|
||||
|
||||
import { stopModel, startModel, stopAllModels } from '@/services/models'
|
||||
|
||||
import { useToolApproval } from '@/hooks/useToolApproval'
|
||||
@ -33,6 +30,10 @@ import { updateSettings } from '@/services/providers'
|
||||
import { useContextSizeApproval } from './useModelContextApproval'
|
||||
import { useModelLoad } from './useModelLoad'
|
||||
import { useGeneralSetting } from './useGeneralSetting'
|
||||
import {
|
||||
ReasoningProcessor,
|
||||
extractReasoningFromMessage,
|
||||
} from '@/utils/reasoning'
|
||||
|
||||
export const useChat = () => {
|
||||
const { prompt, setPrompt } = usePrompt()
|
||||
@ -41,7 +42,6 @@ export const useChat = () => {
|
||||
tools,
|
||||
updateTokenSpeed,
|
||||
resetTokenSpeed,
|
||||
updateTools,
|
||||
updateStreamingContent,
|
||||
updateLoadingModel,
|
||||
setAbortController,
|
||||
@ -78,22 +78,6 @@ export const useChat = () => {
|
||||
const selectedAssistant =
|
||||
assistants.find((a) => a.id === currentAssistant.id) || assistants[0]
|
||||
|
||||
useEffect(() => {
|
||||
function setTools() {
|
||||
getTools().then((data: MCPTool[]) => {
|
||||
updateTools(data)
|
||||
})
|
||||
}
|
||||
setTools()
|
||||
|
||||
let unsubscribe = () => {}
|
||||
listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
|
||||
// Unsubscribe from the event when the component unmounts
|
||||
unsubscribe = unsub
|
||||
})
|
||||
return unsubscribe
|
||||
}, [updateTools])
|
||||
|
||||
const getCurrentThread = useCallback(async () => {
|
||||
let currentThread = retrieveThread()
|
||||
|
||||
@ -306,18 +290,27 @@ export const useChat = () => {
|
||||
const toolCalls: ChatCompletionMessageToolCall[] = []
|
||||
try {
|
||||
if (isCompletionResponse(completion)) {
|
||||
accumulatedText =
|
||||
(completion.choices[0]?.message?.content as string) || ''
|
||||
if (completion.choices[0]?.message?.tool_calls) {
|
||||
toolCalls.push(...completion.choices[0].message.tool_calls)
|
||||
const message = completion.choices[0]?.message
|
||||
accumulatedText = (message?.content as string) || ''
|
||||
|
||||
// Handle reasoning field if there is one
|
||||
const reasoning = extractReasoningFromMessage(message)
|
||||
if (reasoning) {
|
||||
accumulatedText =
|
||||
`<think>${reasoning}</think>` + accumulatedText
|
||||
}
|
||||
|
||||
if (message?.tool_calls) {
|
||||
toolCalls.push(...message.tool_calls)
|
||||
}
|
||||
} else {
|
||||
// High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
|
||||
let rafScheduled = false
|
||||
let rafHandle: number | undefined
|
||||
let pendingDeltaCount = 0
|
||||
const reasoningProcessor = new ReasoningProcessor()
|
||||
const scheduleFlush = () => {
|
||||
if (rafScheduled) return
|
||||
if (rafScheduled || abortController.signal.aborted) return
|
||||
rafScheduled = true
|
||||
const doSchedule = (cb: () => void) => {
|
||||
if (typeof requestAnimationFrame !== 'undefined') {
|
||||
@ -329,6 +322,12 @@ export const useChat = () => {
|
||||
}
|
||||
}
|
||||
doSchedule(() => {
|
||||
// Check abort status before executing the scheduled callback
|
||||
if (abortController.signal.aborted) {
|
||||
rafScheduled = false
|
||||
return
|
||||
}
|
||||
|
||||
const currentContent = newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
accumulatedText,
|
||||
@ -349,7 +348,10 @@ export const useChat = () => {
|
||||
}
|
||||
const flushIfPending = () => {
|
||||
if (!rafScheduled) return
|
||||
if (typeof cancelAnimationFrame !== 'undefined' && rafHandle !== undefined) {
|
||||
if (
|
||||
typeof cancelAnimationFrame !== 'undefined' &&
|
||||
rafHandle !== undefined
|
||||
) {
|
||||
cancelAnimationFrame(rafHandle)
|
||||
} else if (rafHandle !== undefined) {
|
||||
clearTimeout(rafHandle)
|
||||
@ -372,31 +374,63 @@ export const useChat = () => {
|
||||
pendingDeltaCount = 0
|
||||
rafScheduled = false
|
||||
}
|
||||
for await (const part of completion) {
|
||||
// Error message
|
||||
if (!part.choices) {
|
||||
throw new Error(
|
||||
'message' in part
|
||||
? (part.message as string)
|
||||
: (JSON.stringify(part) ?? '')
|
||||
)
|
||||
}
|
||||
const delta = part.choices[0]?.delta?.content || ''
|
||||
try {
|
||||
for await (const part of completion) {
|
||||
// Check if aborted before processing each part
|
||||
if (abortController.signal.aborted) {
|
||||
break
|
||||
}
|
||||
|
||||
if (part.choices[0]?.delta?.tool_calls) {
|
||||
extractToolCall(part, currentCall, toolCalls)
|
||||
// Schedule a flush to reflect tool update
|
||||
scheduleFlush()
|
||||
// Error message
|
||||
if (!part.choices) {
|
||||
throw new Error(
|
||||
'message' in part
|
||||
? (part.message as string)
|
||||
: (JSON.stringify(part) ?? '')
|
||||
)
|
||||
}
|
||||
|
||||
if (part.choices[0]?.delta?.tool_calls) {
|
||||
extractToolCall(part, currentCall, toolCalls)
|
||||
// Schedule a flush to reflect tool update
|
||||
scheduleFlush()
|
||||
}
|
||||
const deltaReasoning =
|
||||
reasoningProcessor.processReasoningChunk(part)
|
||||
if (deltaReasoning) {
|
||||
accumulatedText += deltaReasoning
|
||||
pendingDeltaCount += 1
|
||||
// Schedule flush for reasoning updates
|
||||
scheduleFlush()
|
||||
}
|
||||
const deltaContent = part.choices[0]?.delta?.content || ''
|
||||
if (deltaContent) {
|
||||
accumulatedText += deltaContent
|
||||
pendingDeltaCount += 1
|
||||
// Batch UI update on next animation frame
|
||||
scheduleFlush()
|
||||
}
|
||||
}
|
||||
if (delta) {
|
||||
accumulatedText += delta
|
||||
pendingDeltaCount += 1
|
||||
// Batch UI update on next animation frame
|
||||
scheduleFlush()
|
||||
} finally {
|
||||
// Always clean up scheduled RAF when stream ends (either normally or via abort)
|
||||
if (rafHandle !== undefined) {
|
||||
if (typeof cancelAnimationFrame !== 'undefined') {
|
||||
cancelAnimationFrame(rafHandle)
|
||||
} else {
|
||||
clearTimeout(rafHandle)
|
||||
}
|
||||
rafHandle = undefined
|
||||
rafScheduled = false
|
||||
}
|
||||
|
||||
// Only finalize and flush if not aborted
|
||||
if (!abortController.signal.aborted) {
|
||||
// Finalize reasoning (close any open think tags)
|
||||
accumulatedText += reasoningProcessor.finalize()
|
||||
// Ensure any pending buffered content is rendered at the end
|
||||
flushIfPending()
|
||||
}
|
||||
}
|
||||
// Ensure any pending buffered content is rendered at the end
|
||||
flushIfPending()
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
|
||||
53
web-app/src/hooks/useFavoriteModel.ts
Normal file
@ -0,0 +1,53 @@
|
||||
import { create } from 'zustand'
|
||||
import { persist, createJSONStorage } from 'zustand/middleware'
|
||||
import { localStorageKey } from '@/constants/localStorage'
|
||||
|
||||
interface FavoriteModelState {
|
||||
favoriteModels: Model[]
|
||||
addFavorite: (model: Model) => void
|
||||
removeFavorite: (modelId: string) => void
|
||||
isFavorite: (modelId: string) => boolean
|
||||
toggleFavorite: (model: Model) => void
|
||||
}
|
||||
|
||||
export const useFavoriteModel = create<FavoriteModelState>()(
|
||||
persist(
|
||||
(set, get) => ({
|
||||
favoriteModels: [],
|
||||
|
||||
addFavorite: (model: Model) => {
|
||||
set((state) => {
|
||||
if (!state.favoriteModels.some((fav) => fav.id === model.id)) {
|
||||
return {
|
||||
favoriteModels: [...state.favoriteModels, model],
|
||||
}
|
||||
}
|
||||
return state
|
||||
})
|
||||
},
|
||||
|
||||
removeFavorite: (modelId: string) => {
|
||||
set((state) => ({
|
||||
favoriteModels: state.favoriteModels.filter((model) => model.id !== modelId),
|
||||
}))
|
||||
},
|
||||
|
||||
isFavorite: (modelId: string) => {
|
||||
return get().favoriteModels.some((model) => model.id === modelId)
|
||||
},
|
||||
|
||||
toggleFavorite: (model: Model) => {
|
||||
const { isFavorite, addFavorite, removeFavorite } = get()
|
||||
if (isFavorite(model.id)) {
|
||||
removeFavorite(model.id)
|
||||
} else {
|
||||
addFavorite(model)
|
||||
}
|
||||
},
|
||||
}),
|
||||
{
|
||||
name: localStorageKey.favoriteModels,
|
||||
storage: createJSONStorage(() => localStorage),
|
||||
}
|
||||
)
|
||||
)
|
||||
@ -7,6 +7,10 @@ export type MCPServerConfig = {
|
||||
args: string[]
|
||||
env: Record<string, string>
|
||||
active?: boolean
|
||||
type?: 'stdio' | 'http' | 'sse'
|
||||
url?: string
|
||||
headers?: Record<string, string>
|
||||
timeout?: number
|
||||
}
|
||||
|
||||
// Define the structure of all MCP servers
|
||||
|
||||
31
web-app/src/hooks/useTools.ts
Normal file
@ -0,0 +1,31 @@
|
||||
import { useEffect } from 'react'
|
||||
import { getTools } from '@/services/mcp'
|
||||
import { MCPTool } from '@/types/completion'
|
||||
import { listen } from '@tauri-apps/api/event'
|
||||
import { SystemEvent } from '@/types/events'
|
||||
import { useAppState } from './useAppState'
|
||||
|
||||
export const useTools = () => {
|
||||
const { updateTools } = useAppState()
|
||||
|
||||
useEffect(() => {
|
||||
function setTools() {
|
||||
getTools().then((data: MCPTool[]) => {
|
||||
updateTools(data)
|
||||
}).catch((error) => {
|
||||
console.error('Failed to fetch MCP tools:', error)
|
||||
})
|
||||
}
|
||||
setTools()
|
||||
|
||||
let unsubscribe = () => {}
|
||||
listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
|
||||
// Unsubscribe from the event when the component unmounts
|
||||
unsubscribe = unsub
|
||||
}).catch((error) => {
|
||||
console.error('Failed to set up MCP update listener:', error)
|
||||
})
|
||||
return unsubscribe
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
}
|
||||
@ -17,6 +17,7 @@ import {
|
||||
import { useNavigate } from '@tanstack/react-router'
|
||||
import { route } from '@/constants/routes'
|
||||
import { useThreads } from '@/hooks/useThreads'
|
||||
import { AppEvent, events } from '@janhq/core'
|
||||
|
||||
export function DataProvider() {
|
||||
const { setProviders } = useModelProvider()
|
||||
@ -70,6 +71,13 @@ export function DataProvider() {
|
||||
}
|
||||
}, [checkForUpdate])
|
||||
|
||||
useEffect(() => {
|
||||
events.on(AppEvent.onModelImported, () => {
|
||||
getProviders().then(setProviders)
|
||||
})
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
const handleDeepLink = (urls: string[] | null) => {
|
||||
if (!urls) return
|
||||
console.log('Received deeplink:', urls)
|
||||
|
||||
@ -17,7 +17,12 @@ import { useModelProvider } from '@/hooks/useModelProvider'
|
||||
import { Card, CardItem } from '@/containers/Card'
|
||||
import { RenderMarkdown } from '@/containers/RenderMarkdown'
|
||||
import { extractModelName, extractDescription } from '@/lib/models'
|
||||
import { IconDownload, IconFileCode, IconSearch } from '@tabler/icons-react'
|
||||
import {
|
||||
IconDownload,
|
||||
IconFileCode,
|
||||
IconSearch,
|
||||
IconTool,
|
||||
} from '@tabler/icons-react'
|
||||
import { Switch } from '@/components/ui/switch'
|
||||
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
|
||||
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
||||
@ -133,7 +138,10 @@ function Hub() {
|
||||
if (debouncedSearchValue.length) {
|
||||
const fuse = new Fuse(filtered, searchOptions)
|
||||
// Remove domain from search value (e.g., "huggingface.co/author/model" -> "author/model")
|
||||
const cleanedSearchValue = debouncedSearchValue.replace(/^https?:\/\/[^/]+\//, '')
|
||||
const cleanedSearchValue = debouncedSearchValue.replace(
|
||||
/^https?:\/\/[^/]+\//,
|
||||
''
|
||||
)
|
||||
filtered = fuse.search(cleanedSearchValue).map((result) => result.item)
|
||||
}
|
||||
// Apply downloaded filter
|
||||
@ -647,6 +655,15 @@ function Hub() {
|
||||
?.length || 0}
|
||||
</span>
|
||||
</div>
|
||||
{filteredModels[virtualItem.index].tools && (
|
||||
<div className="flex items-center gap-1">
|
||||
<IconTool
|
||||
size={17}
|
||||
className="text-main-view-fg/50"
|
||||
title={t('hub:tools')}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{filteredModels[virtualItem.index].quants.length >
|
||||
1 && (
|
||||
<div className="flex items-center gap-2 hub-show-variants-step">
|
||||
|
||||
@ -3,6 +3,7 @@ import { createFileRoute, useSearch } from '@tanstack/react-router'
|
||||
import ChatInput from '@/containers/ChatInput'
|
||||
import HeaderPage from '@/containers/HeaderPage'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import { useTools } from '@/hooks/useTools'
|
||||
|
||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||
import SetupScreen from '@/containers/SetupScreen'
|
||||
@ -31,6 +32,7 @@ function Index() {
|
||||
const search = useSearch({ from: route.home as any })
|
||||
const selectedModel = search.model
|
||||
const { setCurrentThreadId } = useThreads()
|
||||
useTools()
|
||||
|
||||
// Conditional to check if there are any valid providers
|
||||
// required min 1 api_key or 1 model in llama.cpp
|
||||
|
||||
@ -29,6 +29,58 @@ const maskSensitiveValue = (value: string) => {
|
||||
return value.slice(0, 4) + '*'.repeat(value.length - 8) + value.slice(-4)
|
||||
}
|
||||
|
||||
// Function to mask sensitive URL parameters
|
||||
const maskSensitiveUrl = (url: string) => {
|
||||
if (!url) return url
|
||||
|
||||
try {
|
||||
const urlObj = new URL(url)
|
||||
const params = urlObj.searchParams
|
||||
|
||||
// List of sensitive parameter names (case-insensitive)
|
||||
const sensitiveParams = [
|
||||
'api_key',
|
||||
'apikey',
|
||||
'key',
|
||||
'token',
|
||||
'secret',
|
||||
'password',
|
||||
'pwd',
|
||||
'auth',
|
||||
'authorization',
|
||||
'bearer',
|
||||
'access_token',
|
||||
'refresh_token',
|
||||
'client_secret',
|
||||
'private_key',
|
||||
'signature',
|
||||
'hash',
|
||||
]
|
||||
|
||||
// Mask sensitive parameters
|
||||
sensitiveParams.forEach((paramName) => {
|
||||
// Check both exact match and case-insensitive match
|
||||
for (const [key, value] of params.entries()) {
|
||||
if (key.toLowerCase() === paramName.toLowerCase()) {
|
||||
params.set(key, maskSensitiveValue(value))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Reconstruct URL with masked parameters
|
||||
urlObj.search = params.toString()
|
||||
return urlObj.toString()
|
||||
} catch {
|
||||
// If URL parsing fails, just mask the entire query string after '?'
|
||||
const queryIndex = url.indexOf('?')
|
||||
if (queryIndex === -1) return url
|
||||
|
||||
const baseUrl = url.substring(0, queryIndex + 1)
|
||||
const queryString = url.substring(queryIndex + 1)
|
||||
return baseUrl + maskSensitiveValue(queryString)
|
||||
}
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export const Route = createFileRoute(route.settings.mcp_servers as any)({
|
||||
component: MCPServers,
|
||||
@ -195,6 +247,7 @@ function MCPServers() {
|
||||
getConnectedServers().then(setConnectedServers)
|
||||
})
|
||||
.catch((error) => {
|
||||
console.log(error, 'error.mcp')
|
||||
editServer(serverKey, {
|
||||
...(config ?? (mcpServers[serverKey] as MCPServerConfig)),
|
||||
active: false,
|
||||
@ -326,22 +379,56 @@ function MCPServers() {
|
||||
}
|
||||
descriptionOutside={
|
||||
<div className="text-sm text-main-view-fg/70">
|
||||
<div>
|
||||
{t('mcp-servers:command')}: {config.command}
|
||||
<div className="mb-1">
|
||||
Transport:{' '}
|
||||
<span className="uppercase">
|
||||
{config.type || 'stdio'}
|
||||
</span>
|
||||
</div>
|
||||
<div className="my-1 break-all">
|
||||
{t('mcp-servers:args')}: {config?.args?.join(', ')}
|
||||
</div>
|
||||
{config.env && Object.keys(config.env).length > 0 && (
|
||||
<div className="break-all">
|
||||
{t('mcp-servers:env')}:{' '}
|
||||
{Object.entries(config.env)
|
||||
.map(
|
||||
([key, value]) =>
|
||||
`${key}=${maskSensitiveValue(value)}`
|
||||
)
|
||||
.join(', ')}
|
||||
</div>
|
||||
|
||||
{config.type === 'stdio' || !config.type ? (
|
||||
<>
|
||||
<div>
|
||||
{t('mcp-servers:command')}: {config.command}
|
||||
</div>
|
||||
<div className="my-1 break-all">
|
||||
{t('mcp-servers:args')}:{' '}
|
||||
{config?.args?.join(', ')}
|
||||
</div>
|
||||
{config.env &&
|
||||
Object.keys(config.env).length > 0 && (
|
||||
<div className="break-all">
|
||||
{t('mcp-servers:env')}:{' '}
|
||||
{Object.entries(config.env)
|
||||
.map(
|
||||
([key, value]) =>
|
||||
`${key}=${maskSensitiveValue(value)}`
|
||||
)
|
||||
.join(', ')}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="break-all">
|
||||
URL: {maskSensitiveUrl(config.url || '')}
|
||||
</div>
|
||||
{config.headers &&
|
||||
Object.keys(config.headers).length > 0 && (
|
||||
<div className="my-1 break-all">
|
||||
Headers:{' '}
|
||||
{Object.entries(config.headers)
|
||||
.map(
|
||||
([key, value]) =>
|
||||
`${key}=${maskSensitiveValue(value)}`
|
||||
)
|
||||
.join(', ')}
|
||||
</div>
|
||||
)}
|
||||
{config.timeout && (
|
||||
<div>Timeout: {config.timeout}s</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
}
|
||||
|
||||
@ -22,10 +22,10 @@ import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import Capabilities from '@/containers/Capabilities'
|
||||
import { DynamicControllerSetting } from '@/containers/dynamicControllerSetting'
|
||||
import { RenderMarkdown } from '@/containers/RenderMarkdown'
|
||||
import { DialogEditModel } from '@/containers/dialogs/EditModel'
|
||||
import { DialogAddModel } from '@/containers/dialogs/AddModel'
|
||||
import { ModelSetting } from '@/containers/ModelSetting'
|
||||
import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel'
|
||||
import { FavoriteModelAction } from '@/containers/FavoriteModelAction'
|
||||
import Joyride, { CallBackProps, STATUS } from 'react-joyride'
|
||||
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
||||
import { route } from '@/constants/routes'
|
||||
@ -554,17 +554,24 @@ function ProviderDetail() {
|
||||
</div>
|
||||
}
|
||||
actions={
|
||||
<div className="flex items-center gap-1">
|
||||
<DialogEditModel
|
||||
provider={provider}
|
||||
modelId={model.id}
|
||||
/>
|
||||
<div className="flex items-center gap-0.5">
|
||||
{model.settings && (
|
||||
<ModelSetting
|
||||
provider={provider}
|
||||
model={model}
|
||||
/>
|
||||
)}
|
||||
{((provider &&
|
||||
!predefinedProviders.some(
|
||||
(p) => p.provider === provider.provider
|
||||
)) ||
|
||||
(provider &&
|
||||
predefinedProviders.some(
|
||||
(p) => p.provider === provider.provider
|
||||
) &&
|
||||
Boolean(provider.api_key?.length))) && (
|
||||
<FavoriteModelAction model={model} />
|
||||
)}
|
||||
<DialogDeleteModel
|
||||
provider={provider}
|
||||
modelId={model.id}
|
||||
|
||||
@ -25,6 +25,7 @@ import { useCallback, useState } from 'react'
|
||||
import { openAIProviderSettings } from '@/consts/providers'
|
||||
import cloneDeep from 'lodash/cloneDeep'
|
||||
import { toast } from 'sonner'
|
||||
import { stopAllModels } from '@/services/models'
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export const Route = createFileRoute(route.settings.model_providers as any)({
|
||||
@ -169,7 +170,10 @@ function ModelProviders() {
|
||||
)}
|
||||
<Switch
|
||||
checked={provider.active}
|
||||
onCheckedChange={(e) => {
|
||||
onCheckedChange={async (e) => {
|
||||
if (!e && provider.provider.toLowerCase() === 'llamacpp') {
|
||||
await stopAllModels()
|
||||
}
|
||||
updateProvider(provider.provider, {
|
||||
...provider,
|
||||
active: e,
|
||||
|
||||
@ -23,6 +23,7 @@ import { ContentType, ThreadMessage } from '@janhq/core'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import { useChat } from '@/hooks/useChat'
|
||||
import { useSmallScreen } from '@/hooks/useMediaQuery'
|
||||
import { useTools } from '@/hooks/useTools'
|
||||
|
||||
// as route.threadsDetail
|
||||
export const Route = createFileRoute('/threads/$threadId')({
|
||||
@ -43,6 +44,7 @@ function ThreadDetail() {
|
||||
const { appMainViewBgColor, chatWidth } = useAppearance()
|
||||
const { sendMessage } = useChat()
|
||||
const isSmallScreen = useSmallScreen()
|
||||
useTools()
|
||||
|
||||
const { messages } = useMessages(
|
||||
useShallow((state) => ({
|
||||
|
||||
@ -5,7 +5,7 @@ import {
|
||||
updateSettings,
|
||||
} from '../providers'
|
||||
import { models as providerModels } from 'token.js'
|
||||
import { predefinedProviders } from '@/mock/data'
|
||||
import { predefinedProviders } from '@/consts/providers'
|
||||
import { EngineManager } from '@janhq/core'
|
||||
import { fetchModels } from '../models'
|
||||
import { ExtensionManager } from '@/lib/extension'
|
||||
@ -21,7 +21,7 @@ vi.mock('token.js', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/mock/data', () => ({
|
||||
vi.mock('@/consts/providers', () => ({
|
||||
predefinedProviders: [
|
||||
{
|
||||
active: true,
|
||||
@ -69,6 +69,7 @@ vi.mock('../models', () => ({
|
||||
{ id: 'llama-2-7b', name: 'Llama 2 7B', description: 'Llama model' },
|
||||
])
|
||||
),
|
||||
isToolSupported: vi.fn(() => Promise.resolve(false)),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/extension', () => ({
|
||||
@ -116,7 +117,7 @@ describe('providers service', () => {
|
||||
it('should return builtin and runtime providers', async () => {
|
||||
const providers = await getProviders()
|
||||
|
||||
expect(providers).toHaveLength(9) // 8 runtime + 1 builtin
|
||||
expect(providers).toHaveLength(2) // 1 runtime + 1 builtin (mocked)
|
||||
expect(providers.some((p) => p.provider === 'llamacpp')).toBe(true)
|
||||
expect(providers.some((p) => p.provider === 'openai')).toBe(true)
|
||||
})
|
||||
@ -156,7 +157,7 @@ describe('providers service', () => {
|
||||
provider: 'openai',
|
||||
base_url: 'https://api.openai.com/v1',
|
||||
api_key: 'test-key',
|
||||
} as ModelProvider
|
||||
}
|
||||
|
||||
const models = await fetchModelsFromProvider(provider)
|
||||
|
||||
@ -185,7 +186,7 @@ describe('providers service', () => {
|
||||
provider: 'custom',
|
||||
base_url: 'https://api.custom.com',
|
||||
api_key: '',
|
||||
} as ModelProvider
|
||||
}
|
||||
|
||||
const models = await fetchModelsFromProvider(provider)
|
||||
|
||||
@ -204,7 +205,7 @@ describe('providers service', () => {
|
||||
const provider = {
|
||||
provider: 'custom',
|
||||
base_url: 'https://api.custom.com',
|
||||
} as ModelProvider
|
||||
}
|
||||
|
||||
const models = await fetchModelsFromProvider(provider)
|
||||
|
||||
@ -214,7 +215,7 @@ describe('providers service', () => {
|
||||
it('should throw error when provider has no base_url', async () => {
|
||||
const provider = {
|
||||
provider: 'custom',
|
||||
} as ModelProvider
|
||||
}
|
||||
|
||||
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
|
||||
'Provider must have base_url configured'
|
||||
@ -232,10 +233,10 @@ describe('providers service', () => {
|
||||
const provider = {
|
||||
provider: 'custom',
|
||||
base_url: 'https://api.custom.com',
|
||||
} as ModelProvider
|
||||
}
|
||||
|
||||
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
|
||||
'Cannot connect to custom at https://api.custom.com'
|
||||
'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.'
|
||||
)
|
||||
})
|
||||
|
||||
@ -245,10 +246,10 @@ describe('providers service', () => {
|
||||
const provider = {
|
||||
provider: 'custom',
|
||||
base_url: 'https://api.custom.com',
|
||||
} as ModelProvider
|
||||
}
|
||||
|
||||
await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
|
||||
'Cannot connect to custom at https://api.custom.com'
|
||||
'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.'
|
||||
)
|
||||
})
|
||||
|
||||
@ -264,7 +265,7 @@ describe('providers service', () => {
|
||||
const provider = {
|
||||
provider: 'custom',
|
||||
base_url: 'https://api.custom.com',
|
||||
} as ModelProvider
|
||||
}
|
||||
|
||||
const models = await fetchModelsFromProvider(provider)
|
||||
|
||||
@ -298,7 +299,7 @@ describe('providers service', () => {
|
||||
controller_type: 'input',
|
||||
controller_props: { value: 'test-key' },
|
||||
},
|
||||
] as ProviderSetting[]
|
||||
]
|
||||
|
||||
await updateSettings('openai', settings)
|
||||
|
||||
@ -324,7 +325,7 @@ describe('providers service', () => {
|
||||
mockExtensionManager
|
||||
)
|
||||
|
||||
const settings = [] as ProviderSetting[]
|
||||
const settings = []
|
||||
|
||||
const result = await updateSettings('nonexistent', settings)
|
||||
|
||||
@ -350,7 +351,7 @@ describe('providers service', () => {
|
||||
controller_type: 'input',
|
||||
controller_props: { value: undefined },
|
||||
},
|
||||
] as ProviderSetting[]
|
||||
]
|
||||
|
||||
await updateSettings('openai', settings)
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@ export interface CatalogModel {
|
||||
mmproj_models?: MMProjModel[]
|
||||
created_at?: string
|
||||
readme?: string
|
||||
tools?: boolean
|
||||
}
|
||||
|
||||
export type ModelCatalog = CatalogModel[]
|
||||
@ -313,3 +314,16 @@ export const startModel = async (
|
||||
throw error
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if model support tool use capability
|
||||
* Returned by backend engine
|
||||
* @param modelId
|
||||
* @returns
|
||||
*/
|
||||
export const isToolSupported = async (modelId: string): Promise<boolean> => {
|
||||
const engine = getEngine()
|
||||
if (!engine) return false
|
||||
|
||||
return engine.isToolSupported(modelId)
|
||||
}
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
import { models as providerModels } from 'token.js'
|
||||
import { predefinedProviders } from '@/consts/providers'
|
||||
import { EngineManager, SettingComponentProps } from '@janhq/core'
|
||||
import {
|
||||
DefaultToolUseSupportedModels,
|
||||
ModelCapabilities,
|
||||
} from '@/types/models'
|
||||
import { ModelCapabilities } from '@/types/models'
|
||||
import { modelSettings } from '@/lib/predefined'
|
||||
import { fetchModels } from './models'
|
||||
import { fetchModels, isToolSupported } from './models'
|
||||
import { ExtensionManager } from '@/lib/extension'
|
||||
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
|
||||
|
||||
@ -65,52 +62,41 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
|
||||
controller_props: setting.controllerProps as unknown,
|
||||
}
|
||||
}) as ProviderSetting[],
|
||||
models: models.map((model) => ({
|
||||
id: model.id,
|
||||
model: model.id,
|
||||
name: model.name,
|
||||
description: model.description,
|
||||
capabilities:
|
||||
'capabilities' in model
|
||||
? (model.capabilities as string[])
|
||||
: [
|
||||
ModelCapabilities.COMPLETION,
|
||||
...(Object.values(DefaultToolUseSupportedModels).some((v) =>
|
||||
model.id.toLowerCase().includes(v.toLowerCase())
|
||||
)
|
||||
? [ModelCapabilities.TOOLS]
|
||||
: []),
|
||||
],
|
||||
provider: providerName,
|
||||
settings: Object.values(modelSettings).reduce(
|
||||
(acc, setting) => {
|
||||
let value = setting.controller_props.value
|
||||
if (setting.key === 'ctx_len') {
|
||||
value = 8192 // Default context length for Llama.cpp models
|
||||
}
|
||||
// Set temperature to 0.6 for DefaultToolUseSupportedModels
|
||||
if (
|
||||
Object.values(DefaultToolUseSupportedModels).some((v) =>
|
||||
model.id.toLowerCase().includes(v.toLowerCase())
|
||||
)
|
||||
) {
|
||||
if (setting.key === 'temperature') value = 0.7 // Default temperature for tool-supported models
|
||||
if (setting.key === 'top_k') value = 20 // Default top_k for tool-supported models
|
||||
if (setting.key === 'top_p') value = 0.8 // Default top_p for tool-supported models
|
||||
if (setting.key === 'min_p') value = 0 // Default min_p for tool-supported models
|
||||
}
|
||||
acc[setting.key] = {
|
||||
...setting,
|
||||
controller_props: {
|
||||
...setting.controller_props,
|
||||
value: value,
|
||||
},
|
||||
}
|
||||
return acc
|
||||
},
|
||||
{} as Record<string, ProviderSetting>
|
||||
),
|
||||
})),
|
||||
models: await Promise.all(
|
||||
models.map(
|
||||
async (model) =>
|
||||
({
|
||||
id: model.id,
|
||||
model: model.id,
|
||||
name: model.name,
|
||||
description: model.description,
|
||||
capabilities:
|
||||
'capabilities' in model
|
||||
? (model.capabilities as string[])
|
||||
: (await isToolSupported(model.id))
|
||||
? [ModelCapabilities.TOOLS]
|
||||
: [],
|
||||
provider: providerName,
|
||||
settings: Object.values(modelSettings).reduce(
|
||||
(acc, setting) => {
|
||||
let value = setting.controller_props.value
|
||||
if (setting.key === 'ctx_len') {
|
||||
value = 8192 // Default context length for Llama.cpp models
|
||||
}
|
||||
acc[setting.key] = {
|
||||
...setting,
|
||||
controller_props: {
|
||||
...setting.controller_props,
|
||||
value: value,
|
||||
},
|
||||
}
|
||||
return acc
|
||||
},
|
||||
{} as Record<string, ProviderSetting>
|
||||
),
|
||||
}) as Model
|
||||
)
|
||||
),
|
||||
}
|
||||
runtimeProviders.push(provider)
|
||||
}
|
||||
|
||||
@ -6,4 +6,5 @@ export type MCPTool = {
|
||||
name: string
|
||||
description: string
|
||||
inputSchema: Record<string, unknown>
|
||||
server: string
|
||||
}
|
||||
|
||||
@ -13,11 +13,4 @@ export enum ModelCapabilities {
|
||||
IMAGE_TO_IMAGE = 'image_to_image',
|
||||
TEXT_TO_AUDIO = 'text_to_audio',
|
||||
AUDIO_TO_TEXT = 'audio_to_text',
|
||||
}
|
||||
|
||||
// TODO: Remove this enum when we integrate llama.cpp extension
|
||||
export enum DefaultToolUseSupportedModels {
|
||||
JanNano = 'jan-',
|
||||
Qwen3 = 'qwen3',
|
||||
Lucy = 'lucy',
|
||||
}
|
||||
}
|
||||
74
web-app/src/utils/reasoning.ts
Normal file
@ -0,0 +1,74 @@
|
||||
import { CompletionResponseChunk } from 'token.js'
|
||||
import {
|
||||
chatCompletionChunk,
|
||||
ChatCompletionMessage,
|
||||
chatCompletionRequestMessage,
|
||||
} from '@janhq/core'
|
||||
|
||||
// Extract reasoning from a message (for completed responses)
|
||||
export function extractReasoningFromMessage(
|
||||
message: chatCompletionRequestMessage | ChatCompletionMessage
|
||||
): string | null {
|
||||
if (!message) return null
|
||||
|
||||
const extendedMessage = message as chatCompletionRequestMessage
|
||||
return extendedMessage.reasoning || null
|
||||
}
|
||||
|
||||
// Extract reasoning from a chunk (for streaming responses)
|
||||
function extractReasoningFromChunk(
|
||||
chunk: CompletionResponseChunk | chatCompletionChunk
|
||||
): string | null {
|
||||
if (!chunk.choices?.[0]?.delta) return null
|
||||
|
||||
const delta = chunk.choices[0].delta as chatCompletionRequestMessage
|
||||
const reasoning = delta.reasoning
|
||||
|
||||
// Return null for falsy values, non-strings, or whitespace-only strings
|
||||
if (!reasoning || typeof reasoning !== 'string' || !reasoning.trim())
|
||||
return null
|
||||
|
||||
return reasoning
|
||||
}
|
||||
|
||||
// Tracks reasoning state and appends reasoning tokens with proper think tags
|
||||
export class ReasoningProcessor {
|
||||
private isReasoningActive = false
|
||||
|
||||
processReasoningChunk(
|
||||
chunk: CompletionResponseChunk | chatCompletionChunk
|
||||
): string {
|
||||
const reasoning = extractReasoningFromChunk(chunk)
|
||||
const chunkContent = chunk.choices?.[0]?.delta?.content || ''
|
||||
|
||||
// Handle reasoning tokens
|
||||
if (reasoning) {
|
||||
if (!this.isReasoningActive) {
|
||||
this.isReasoningActive = true
|
||||
return '<think>' + reasoning
|
||||
}
|
||||
return reasoning
|
||||
}
|
||||
|
||||
// Handle reasoning end when content starts
|
||||
if (this.isReasoningActive && chunkContent) {
|
||||
this.isReasoningActive = false
|
||||
return '</think>'
|
||||
}
|
||||
|
||||
// No reasoning to process
|
||||
return ''
|
||||
}
|
||||
|
||||
finalize(): string {
|
||||
if (this.isReasoningActive) {
|
||||
this.isReasoningActive = false
|
||||
return '</think>'
|
||||
}
|
||||
return ''
|
||||
}
|
||||
|
||||
isReasoningInProgress(): boolean {
|
||||
return this.isReasoningActive
|
||||
}
|
||||
}
|
||||