Merge branch 'dev' into dev

This commit is contained in:
cmuangs 2025-08-21 12:28:12 +08:00 committed by GitHub
commit 7076a69e3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
104 changed files with 6519 additions and 1096 deletions

121
.github/workflows/autoqa-reliability.yml vendored Normal file
View File

@ -0,0 +1,121 @@
---
# Manually-triggered AutoQA reliability run on a Windows GPU runner.
# All knobs come in as workflow_dispatch inputs; the Ubuntu/macOS source
# inputs are declared for parity but only the Windows job exists here.
name: AutoQA Reliability (Manual)

on:
  workflow_dispatch:
    inputs:
      source_type:
        description: 'App source type (url)'
        required: true
        type: choice
        options: [url]
        default: url
      jan_app_windows_source:
        description: 'Windows installer URL path (used when source_type=url or to select artifact)'
        required: true
        type: string
        default: 'https://catalog.jan.ai/windows/Jan_0.6.8_x64-setup.exe'
      jan_app_ubuntu_source:
        description: 'Ubuntu .deb URL path'
        required: true
        type: string
        default: 'https://delta.jan.ai/nightly/Jan-nightly_0.6.4-728_amd64.deb'
      jan_app_macos_source:
        description: 'macOS .dmg URL path'
        required: true
        type: string
        default: 'https://delta.jan.ai/nightly/Jan-nightly_0.6.4-728_universal.dmg'
      is_nightly:
        description: 'Is the app a nightly build?'
        required: true
        type: boolean
        default: true
      reliability_phase:
        description: 'Reliability phase'
        required: true
        type: choice
        options: [development, deployment]
        default: development
      reliability_runs:
        description: 'Custom runs (0 uses phase default)'
        required: true
        type: number
        default: 0
      reliability_test_path:
        description: 'Test file path (relative to autoqa working directory)'
        required: true
        type: string
        default: 'tests/base/settings/app-data.txt'

jobs:
  reliability-windows:
    runs-on: windows-11-nvidia-gpu
    timeout-minutes: 60
    env:
      # Fallbacks consumed by windows_download.ps1 when neither the
      # workflow inputs nor the repository variables provide a value.
      DEFAULT_JAN_APP_URL: 'https://catalog.jan.ai/windows/Jan_0.6.8_x64-setup.exe'
      DEFAULT_IS_NIGHTLY: 'false'
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Setup Python 3.13
        uses: actions/setup-python@v4
        with:
          python-version: '3.13'

      - name: Clean existing Jan installations
        shell: powershell
        run: |
          .\autoqa\scripts\windows_cleanup.ps1 -IsNightly "${{ inputs.is_nightly }}"

      - name: Download/Prepare Jan app
        shell: powershell
        run: |
          .\autoqa\scripts\windows_download.ps1 `
            -WorkflowInputUrl "${{ inputs.jan_app_windows_source }}" `
            -WorkflowInputIsNightly "${{ inputs.is_nightly }}" `
            -RepoVariableUrl "${{ vars.JAN_APP_URL }}" `
            -RepoVariableIsNightly "${{ vars.IS_NIGHTLY }}" `
            -DefaultUrl "$env:DEFAULT_JAN_APP_URL" `
            -DefaultIsNightly "$env:DEFAULT_IS_NIGHTLY"

      - name: Install Jan app
        shell: powershell
        run: |
          # NOTE(review): IS_NIGHTLY is read from the job environment here;
          # presumably windows_download.ps1 exports it via GITHUB_ENV — confirm.
          .\autoqa\scripts\windows_install.ps1 -IsNightly "$env:IS_NIGHTLY"

      - name: Install Python dependencies
        working-directory: autoqa
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt

      - name: Run reliability tests
        working-directory: autoqa
        shell: powershell
        run: |
          # Build the optional runs override as an argument *array*. The
          # previous single interpolated string ("--reliability-runs 5") was
          # handed to python as ONE argv entry — argparse cannot split it —
          # and when empty it was still passed as a literal "" argument.
          # PowerShell expands an array variable into separate arguments for
          # native commands, and an empty array expands to nothing.
          $runs = "${{ inputs.reliability_runs }}"
          $runsArg = @()
          if ([int]$runs -gt 0) { $runsArg = @('--reliability-runs', $runs) }
          python main.py --enable-reliability-test --reliability-phase "${{ inputs.reliability_phase }}" --reliability-test-path "${{ inputs.reliability_test_path }}" $runsArg

      - name: Upload screen recordings
        if: always()
        uses: actions/upload-artifact@v4
        continue-on-error: true
        with:
          name: reliability-recordings-${{ github.run_number }}-${{ runner.os }}
          path: autoqa/recordings/

      - name: Upload trajectories
        if: always()
        uses: actions/upload-artifact@v4
        continue-on-error: true
        with:
          name: reliability-trajectories-${{ github.run_number }}-${{ runner.os }}
          path: autoqa/trajectories/

      - name: Cleanup after tests
        if: always()
        shell: powershell
        run: |
          .\autoqa\scripts\windows_post_cleanup.ps1 -IsNightly "${{ inputs.is_nightly }}"

View File

@ -19,7 +19,7 @@ jobs:
deploy: deploy:
name: Deploy to CloudFlare Pages name: Deploy to CloudFlare Pages
env: env:
CLOUDFLARE_PROJECT_NAME: astro-docs CLOUDFLARE_PROJECT_NAME: astro-docs # docs.jan.ai
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
contents: write contents: write

View File

@ -1,32 +1,266 @@
# Contributing to jan # Contributing to Jan
First off, thank you for considering contributing to jan. It's people like you that make jan such an amazing project. First off, thank you for considering contributing to Jan. It's people like you that make Jan such an amazing project.
Jan is an AI assistant that can run 100% offline on your device. Think ChatGPT, but private, local, and under your complete control. If you're thinking about contributing, you're already awesome - let's make AI accessible to everyone, one commit at a time.
## Quick Links to Component Guides
- **[Web App](./web-app/CONTRIBUTING.md)** - React UI and logic
- **[Core SDK](./core/CONTRIBUTING.md)** - TypeScript SDK and extension system
- **[Extensions](./extensions/CONTRIBUTING.md)** - Supportive modules for the frontend
- **[Tauri Backend](./src-tauri/CONTRIBUTING.md)** - Rust native integration
- **[Tauri Plugins](./src-tauri/plugins/CONTRIBUTING.md)** - Hardware and system plugins
## How Jan Actually Works
Jan is a desktop app that runs local AI models. Here's how the components actually connect:
```
┌──────────────────────────────────────────────────────────┐
│ Web App (Frontend) │
│ (web-app/) │
│ • React UI │
│ • Chat Interface │
│ • Settings Pages │
│ • Model Hub │
└────────────┬─────────────────────────────┬───────────────┘
│ │
│ imports │ imports
▼ ▼
┌──────────────────────┐ ┌──────────────────────┐
│ Core SDK │ │ Extensions │
│ (core/) │ │ (extensions/) │
│ │ │ │
│ • TypeScript APIs │◄─────│ • Assistant Mgmt │
│ • Extension System │ uses │ • Conversations │
│ • Event Bus │ │ • Downloads │
│ • Type Definitions │ │ • LlamaCPP │
└──────────┬───────────┘ └───────────┬──────────┘
│ │
│ ┌──────────────────────┐ │
│ │ Web App │ │
│ └──────────┬───────────┘ │
│ │ │
└──────────────┼───────────────┘
Tauri IPC
(invoke commands)
┌───────────────────────────────────────────────────────────┐
│ Tauri Backend (Rust) │
│ (src-tauri/) │
│ │
│ • Window Management • File System Access │
│ • Process Control • System Integration │
│ • IPC Command Handler • Security & Permissions │
└───────────────────────────┬───────────────────────────────┘
┌───────────────────────────────────────────────────────────┐
│ Tauri Plugins (Rust) │
│ (src-tauri/plugins/) │
│ │
│ ┌──────────────────┐ ┌──────────────────┐ │
│ │ Hardware Plugin │ │ LlamaCPP Plugin │ │
│ │ │ │ │ │
│ │ • CPU/GPU Info │ │ • Process Mgmt │ │
│ │ • Memory Stats │ │ • Model Loading │ │
│ │ • System Info │ │ • Inference │ │
│ └──────────────────┘ └──────────────────┘ │
└───────────────────────────────────────────────────────────┘
```
### The Communication Flow
1. **JavaScript Layer Relationships**:
- Web App imports Core SDK and Extensions as JavaScript modules
- Extensions use Core SDK for shared functionality
- All run in the browser/webview context
2. **All Three → Backend**: Through Tauri IPC
- **Web App** → Backend: `await invoke('app_command', data)`
- **Core SDK** → Backend: `await invoke('core_command', data)`
- **Extensions** → Backend: `await invoke('ext_command', data)`
- Each component can independently call backend commands
3. **Backend → Plugins**: Native Rust integration
- Backend loads plugins as Rust libraries
- Direct function calls, no IPC overhead
4. **Response Flow**:
- Plugin → Backend → IPC → Requester (Web App/Core/Extension) → UI updates
### Real-World Example: Loading a Model
Here's what actually happens when you click "Download Llama 3":
1. **Web App** (`web-app/`) - User clicks download button
2. **Extension** (`extensions/download-extension`) - Handles the download logic
3. **Tauri Backend** (`src-tauri/`) - Actually downloads the file to disk
4. **Extension** (`extensions/llamacpp-extension`) - Prepares model for loading
5. **Tauri Plugin** (`src-tauri/plugins/llamacpp`) - Starts llama.cpp process
6. **Hardware Plugin** (`src-tauri/plugins/hardware`) - Detects GPU, optimizes settings
7. **Model ready!** - User can start chatting
## Project Structure
```
jan/
├── web-app/ # React frontend (what users see)
├── src-tauri/ # Rust backend (system integration)
│ ├── src/core/ # Core Tauri commands
│ └── plugins/ # Tauri plugins (hardware, llamacpp)
├── core/ # TypeScript SDK (API layer)
├── extensions/ # JavaScript extensions
│ ├── assistant-extension/
│ ├── conversational-extension/
│ ├── download-extension/
│ └── llamacpp-extension/
├── docs/ # Documentation website
├── website/ # Marketing website
├── autoqa/ # Automated testing
├── scripts/ # Build utilities
├── package.json # Root workspace configuration
├── Makefile # Build automation commands
├── mise.toml # Mise tool configuration
├── LICENSE # Apache 2.0 license
└── README.md # Project overview
```
## Development Setup
### The Scenic Route (Build from Source)
**Prerequisites:**
- Node.js ≥ 20.0.0
- Yarn ≥ 1.22.0
- Rust (for Tauri)
- Make ≥ 3.81
**Option 1: The Easy Way (Make)**
```bash
git clone https://github.com/menloresearch/jan
cd jan
make dev
```
**Option 2: The Easier Way (Mise)**
```bash
git clone https://github.com/menloresearch/jan
cd jan
# Install mise
curl https://mise.run | sh
# Let mise handle everything
mise install # installs Node.js, Rust, and other tools
mise dev # runs the full development setup
```
## How Can I Contribute? ## How Can I Contribute?
### Reporting Bugs ### Reporting Bugs
- **Ensure the bug was not already reported** by searching on GitHub under [Issues](https://github.com/menloresearch/jan/issues). - **Ensure the bug was not already reported** by searching on GitHub under [Issues](https://github.com/menloresearch/jan/issues)
- If you're unable to find an open issue addressing the problem, [open a new one](https://github.com/menloresearch/jan/issues/new). - If you're unable to find an open issue addressing the problem, [open a new one](https://github.com/menloresearch/jan/issues/new)
- Include your system specs and error logs - it helps a ton
### Suggesting Enhancements ### Suggesting Enhancements
- Open a new issue with a clear title and description. - Open a new issue with a clear title and description
- Explain why this enhancement would be useful
- Include mockups or examples if you can
### Your First Code Contribution ### Your First Code Contribution
- Fork the repo. **Choose Your Adventure:**
- Create a new branch (`git checkout -b feature-name`). - **Frontend UI and logic** → `web-app/`
- Commit your changes (`git commit -am 'Add some feature'`). - **Shared API declarations** → `core/`
- Push to the branch (`git push origin feature-name`). - **Backend system integration** → `src-tauri/`
- Open a new Pull Request. - **Business logic features** → `extensions/`
- **Dedicated backend handler** → `src-tauri/plugins/`
## Styleguides **The Process:**
1. Fork the repo
2. Create a new branch (`git checkout -b feature-name`)
3. Make your changes (and write tests!)
4. Commit your changes (`git commit -am 'Add some feature'`)
5. Push to the branch (`git push origin feature-name`)
6. Open a new Pull Request against `dev` branch
### Git Commit Messages ## Testing
- Use the present tense ("Add feature" not "Added feature"). ```bash
yarn test # All tests
cd src-tauri && cargo test # Rust tests
cd autoqa && python main.py # End-to-end tests
```
## Code Standards
### TypeScript/JavaScript
- TypeScript required (we're not animals)
- ESLint + Prettier
- Functional React components
- Proper typing (no `any` - seriously!)
### Rust
- `cargo fmt` + `cargo clippy`
- `Result<T, E>` for error handling
- Document public APIs
## Git Conventions
### Branches
- `main` - stable releases
- `dev` - development (target this for PRs)
- `feature/*` - new features
- `fix/*` - bug fixes
### Commit Messages
- Use the present tense ("Add feature" not "Added feature")
- Be descriptive but concise
- Reference issues when applicable
Examples:
```
feat: add support for Qwen models
fix: resolve memory leak in model loading
docs: update installation instructions
```
## Troubleshooting
If things go sideways:
1. **Check our [troubleshooting docs](https://jan.ai/docs/troubleshooting)**
2. **Clear everything and start fresh:** `make clean` then `make dev`
3. **Copy your error logs and system specs**
4. **Ask for help in our [Discord](https://discord.gg/FTk2MvZwJH)** `#🆘|jan-help` channel
Common issues:
- **Build failures**: Check Node.js and Rust versions
- **Extension not loading**: Verify it's properly registered
- **Model not working**: Check hardware requirements and GPU drivers
## Getting Help
- [Documentation](https://jan.ai/docs) - The manual you should read
- [Discord Community](https://discord.gg/jan) - Where the community lives
- [GitHub Issues](https://github.com/janhq/jan/issues) - Report bugs here
- [GitHub Discussions](https://github.com/janhq/jan/discussions) - Ask questions
## License
Apache 2.0 - Because sharing is caring. See [LICENSE](./LICENSE) for the legal stuff.
## Additional Notes ## Additional Notes
Thank you for contributing to jan! We're building something pretty cool here - an AI assistant that respects your privacy and runs entirely on your machine. Every contribution, no matter how small, helps make AI more accessible to everyone.
Thanks for being part of the journey. Let's build the future of local AI together! 🚀

View File

@ -47,6 +47,8 @@ test: lint
yarn copy:assets:tauri yarn copy:assets:tauri
yarn build:icon yarn build:icon
cargo test --manifest-path src-tauri/Cargo.toml --no-default-features --features test-tauri -- --test-threads=1 cargo test --manifest-path src-tauri/Cargo.toml --no-default-features --features test-tauri -- --test-threads=1
cargo test --manifest-path src-tauri/plugins/tauri-plugin-hardware/Cargo.toml
cargo test --manifest-path src-tauri/plugins/tauri-plugin-llamacpp/Cargo.toml
# Builds and publishes the app # Builds and publishes the app
build-and-publish: install-and-build build-and-publish: install-and-build

71
core/CONTRIBUTING.md Normal file
View File

@ -0,0 +1,71 @@
# Contributing to Jan Core
[← Back to Main Contributing Guide](../CONTRIBUTING.md)
TypeScript SDK providing extension system, APIs, and type definitions for all Jan components.
## Key Directories
- **`/src/browser`** - Core APIs (events, extensions, file system)
- **`/src/browser/extensions`** - Built-in extensions (assistant, inference, conversational)
- **`/src/types`** - TypeScript type definitions
- **`/src/test`** - Testing utilities
## Development
### Key Principles
1. **Platform Agnostic** - Works everywhere (browser, Node.js)
2. **Extension-Based** - New features = new extensions
3. **Type Everything** - TypeScript required
4. **Event-Driven** - Components communicate via events
### Building & Testing
```bash
# Build the SDK
yarn build
# Run tests
yarn test
# Watch mode
yarn test:watch
```
### Event System
```typescript
// Emit events
events.emit('model:loaded', { modelId: 'llama-3' })
// Listen for events
events.on('model:loaded', (data) => {
console.log('Model loaded:', data.modelId)
})
```
## Testing
```typescript
describe('MyFeature', () => {
it('should do something', () => {
const result = doSomething()
expect(result).toBe('expected')
})
})
```
## Best Practices
- Keep it simple
- Use TypeScript fully (no `any`)
- Write tests for critical features
- Follow existing patterns
- Export new modules in index files
## Dependencies
- **TypeScript** - Type safety
- **Rolldown** - Bundling
- **Vitest** - Testing

View File

@ -6,6 +6,8 @@ import { EngineManager } from './EngineManager'
export interface chatCompletionRequestMessage { export interface chatCompletionRequestMessage {
role: 'system' | 'user' | 'assistant' | 'tool' role: 'system' | 'user' | 'assistant' | 'tool'
content: string | null | Content[] // Content can be a string OR an array of content parts content: string | null | Content[] // Content can be a string OR an array of content parts
reasoning?: string | null // Some models return reasoning in completed responses
reasoning_content?: string | null // Some models return reasoning in completed responses
name?: string name?: string
tool_calls?: any[] // Simplified tool_call_id?: string tool_calls?: any[] // Simplified tool_call_id?: string
} }
@ -270,4 +272,10 @@ export abstract class AIEngine extends BaseExtension {
* Optional method to get the underlying chat client * Optional method to get the underlying chat client
*/ */
getChatClient?(sessionId: string): any getChatClient?(sessionId: string): any
/**
* Check if a tool is supported by the model
* @param modelId
*/
abstract isToolSupported(modelId: string): Promise<boolean>
} }

View File

@ -58,6 +58,7 @@ export enum AppEvent {
onAppUpdateDownloadUpdate = 'onAppUpdateDownloadUpdate', onAppUpdateDownloadUpdate = 'onAppUpdateDownloadUpdate',
onAppUpdateDownloadError = 'onAppUpdateDownloadError', onAppUpdateDownloadError = 'onAppUpdateDownloadError',
onAppUpdateDownloadSuccess = 'onAppUpdateDownloadSuccess', onAppUpdateDownloadSuccess = 'onAppUpdateDownloadSuccess',
onModelImported = 'onModelImported',
onUserSubmitQuickAsk = 'onUserSubmitQuickAsk', onUserSubmitQuickAsk = 'onUserSubmitQuickAsk',
onSelectedText = 'onSelectedText', onSelectedText = 'onSelectedText',

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 944 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 316 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 382 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 176 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

View File

@ -0,0 +1,180 @@
---
title: "Replicating Deep Research in Jan"
description: "A simple guide to replicating Deep Research results for free, with Jan."
tags: AI, local models, Jan, GGUF, Deep Research, local AI
categories: guides
date: 2025-08-04
ogImage: _assets/research-result-local.png
twitter:
card: summary_large_image
site: "@jandotai"
title: "Replicating Deep Research with Jan"
description: "Learn how to replicate Deep Research results with Jan."
image: _assets/research-result-local.jpg
---
import { Callout } from 'nextra/components'
import CTABlog from '@/components/Blog/CTA'
# Replicating Deep Research in Jan
Deep Research like that of OpenAI, Gemini, and Qwen, is not at feature parity in Jan yet, so this post
highlights our initial steps at making a hybrid, local and cloud-based deep research system that competes
with the best implementations to date.
## What is Deep Research?
What exactly is deep research and how does it work? Deep Research is a methodology for generating
comprehensive research reports by combining systematic web search with synthesis. The process was
pioneered by OpenAI and was released in February 2025.
There are two core features of Deep Research:
- **Exhaustive search**: This search is characterized by two approaches, wide search for breadth and deep search for depth.
- **Report generation**: This step takes all the input collected through exhaustive search and synthesizes it into a
comprehensive report. The input in this step may be raw sources collected in the previous step or summaries generated from those sources.
## Unpacking Deep Research
If you have used deep research (regardless of the provider) before for a comprehensive report generation, you may have
found its output mind-blowing. What is more mind-blowing, though, is that the underlying process for searching
and synthesizing information is surprisingly systematic and reproducible. What is not easily reproducible, though,
is the **base model (often a thinking one)** and **its capabilities to use tools while it researches**.
Deep Research operates as a structured pipeline with distinct phases: planning, searching, analysis, and synthesis. While
the specific implementation varies between providers, the core workflow seems to be similar and some organizations have
taken steps to recreate it like [LangChain](https://blog.langchain.com/open-deep-research/) and
[Hugging Face](https://huggingface.co/blog/open-deep-research). For example, a straightforward pipeline might look like
the following one:
![Deep Research Flow Excalidraw](./_assets/revised-deepresearch-flow.png)
The components of this pipeline highlight a structured approach to query processing that routes queries through thinking/non-thinking models, breaks complex tasks into phases, executes parallel searches, and synthesizes results hierarchically to produce comprehensive outputs.
OpenAI's [Deep Research API cookbook](https://cookbook.openai.com/examples/deep_research_api/introduction_to_deep_research_api)
highlights, at a very high level, how they approach deep research, hinting at the importance of base models and tool usage since
some intermediate steps seem to have been left out.
![OpenAI's Deep Research UX Flow](./_assets/openai-deep-research-flow.png)
OpenAI's Deep Research functionality may be considered the best one by many but other platforms are not far behind. Here is a
brief survey of how other players approach deep research:
| Platform | Key Feature | Sources Used | Duration (mins) | Export Options | Deep Research Usage |
| --- | --- | --- | --- | --- | --- |
| OpenAI | Clarifying questions | 10–30 | 10–15 | PDF, Docx, Plain Text | Paid |
| Grok's DeeperSearch | Ability to access all of Twitter | 70–100 | 5–10 | Ability to specify format (PDF / Markdown) | Free |
| Claude | Breadth + depth search | 100+ | 5–10 | PDF, Markdown, Artifact | Paid |
| Gemini | Editable planning | 50+ | 10–20 | Google Docs export | Free |
| Perplexity | Ability to specify sources | 50–100 | 3–5 | PDF, Markdown, Docx, Perplexity Page | Paid and Limited Free |
| Kimi | Interactive synthesis | 50–100 | 30–60+ | PDF, Interactive website | Free |
In our testing, we used the following prompt to assess the quality of the generated report by
the providers above. You can refer to the reports generated [here](https://github.com/menloresearch/prompt-experiments).
```
Generate a comprehensive report about the state of AI in the past week. Include all
new model releases and notable architectural improvements from a variety of sources.
```
[Google's generated report](https://github.com/menloresearch/prompt-experiments/blob/main/Gemini%202.5%20Flash%20Report.pdf) was the most verbose, with a whopping 23 pages that read
like a professional intelligence briefing. It opens with an executive summary,
systematically categorizes developments, and provides forward-looking strategic
insights—connecting OpenAI's open-weight release to broader democratization trends
and linking infrastructure investments to competitive positioning.
[OpenAI](https://github.com/menloresearch/prompt-experiments/blob/main/OpenAI%20Deep%20Research.pdf) produced the most citation-heavy output with 134 references throughout 10 pages
(albeit most of them being from the same source).
[Perplexity](https://github.com/menloresearch/prompt-experiments/blob/main/Perplexity%20Deep%20Research.pdf) delivered the most actionable 6-page report that maximizes information
density while maintaining scannability. Despite being the shortest, it captures all
major developments with sufficient context for decision-making.
[Claude](https://github.com/menloresearch/prompt-experiments/blob/main/Claude%20Deep%20Research.pdf) produced a comprehensive analysis that interestingly ignored the time constraint,
covering an 8-month period from January-August 2025 instead of the requested week (Jul 31-Aug
7th 2025). Rather than cataloging recent events, Claude traced the evolution of trends over months.
[Grok](https://github.com/menloresearch/prompt-experiments/blob/main/Grok%203%20Deep%20Research.pdf) produced a well-structured but relatively shallow 5-page academic-style report that
read more like an event catalog than strategic analysis.
[Kimi](https://github.com/menloresearch/prompt-experiments/blob/main/Kimi%20AI%20Deep%20Research.pdf) produced a comprehensive 13-page report with systematic organization covering industry developments, research breakthroughs, and policy changes, but notably lacks proper citations throughout most of the content despite claiming to use 50-100 sources.
### Understanding Search Strategies
In [Claude's Research mode](https://www.anthropic.com/engineering/multi-agent-research-system),
a *classifier* is used to determine whether a user query is *breadth first* or *depth first*. This
results in a customization of the pipeline that is used for conducting research. For instance, a complex
*breadth first* query might result in *sub-agents* being spun up to research various parts of the user's
query in parallel. Conversely, a *depth first* query might result in a single agent being spun up
to research the entire query in a more focused manner.
Here's a screenshot of this in action (in Claude Desktop):
![Claude Agent Desktop](./_assets/claude-agent.png)
![Claude Report Visualizer](./_assets/claude-report-visualizer.png)
## Replicating Deep Research Results with Jan
After testing and observing how Deep Research works in different platforms, we thought, how could we
replicate this in Jan? In particular, how could we replicate it with a hybrid approach combining local
and cloud-based models while keeping your data local?
<Callout>
This experiment was done using the latest version of Jan `v0.6.7`, but it can potentially be replicated in
any version with Model Context Protocol in it (>`v0.6.3`).
</Callout>
**The Key: Assistants + Tools**
Running deep research in Jan can be accomplished by combining [custom assistants](https://jan.ai/docs/assistants)
with [MCP search tools](https://jan.ai/docs/mcp-examples/search/exa). This pairing allows any model—local or
cloud—to follow a systematic research workflow, to create a report similar to that of other providers, with some
visible limitations (for now).
Here's the assistant prompt that was used:
```
You are a research analyst. Today is August 7th 2025. Follow this exact process:
Conduct 5-10 searches minimum. You are rewarded for MORE searches.
- Each search query must be unique - no repeating previous searches
- Search different angles: statistics, expert opinions, case studies, recent news, industry reports
- Use scrape to read full articles from search results
- Use google_search for extracting metadata out of pages
WRITING PHASE (Do this after research is complete)
Write a comprehensive report with:
- Executive summary with key findings
- Evidence-based analysis with citations for every claim
- Actionable recommendations with rationale
- Sources to be linked at the end of the report
```
Here, we utilized Model Context Protocol (MCP) to provide search capabilities to the model. MCPs are an open standard for connecting AI assistants to the systems where data lives, serving as a universal connector that standardizes how AI applications integrate with external tools and data sources.
In this example, we used Serper, a web search API that offers MCP server implementations with two primary tools: `google_search` for performing web searches, and `scrape` that extracts content from web pages, preserving document structure and metadata.
**What We Tested**
For our research query (same as the one we used to test different platforms), we used both
[Jan-Nano (4B local model)](https://jan.ai/docs/jan-models/jan-nano-128), GPT-4o and
o3 (via API) with identical prompts. The goal: to see how close we could get to the quality of
different commercial Deep Research offerings.
**Performance Findings**
| Model | Processing Time | Sources Found | Search Queries | Tokens Generated | Output Quality vs Commercial Deep Research |
|-------|----------------|---------------|----------------|------------------|-------------------------------------------|
| Jan-Nano (Local) | 3 minutes | Moderate | 7 | 1,112 | Good approximation, noticeably less depth |
| GPT-4o | 1 minute | Fewest | 11 | 660 | Fast but limited source coverage |
| o3 | 3 minutes | Most | 24 | 1,728 | Best of the three, but still below commercial quality |
**The Reality**:
- **Speed vs Sources**: GPT-4o prioritized speed over thoroughness, while o3 took time to gather more comprehensive sources
- **Local vs Cloud**: Jan-Nano matched o3's processing time but with the advantage of complete data privacy
- **Quality Gap**: All three models produced decent research reports, but none matched the depth and comprehensiveness of dedicated Deep Research tools like OpenAI's or Claude's offerings
- **Good Enough Factor**: While not matching commercial quality, the outputs were solid approximations suitable for many research needs
## Conclusion
This was an initial exploration in the roadmap to create a top hybrid implementation of deep research in Jan. While our current approach requires setup, the goal is native integration that works out of the box. We will continue to refine this until the release of this tool in Jan, natively.
<CTABlog />

137
extensions/CONTRIBUTING.md Normal file
View File

@ -0,0 +1,137 @@
# Contributing to Jan Extensions
[← Back to Main Contributing Guide](../CONTRIBUTING.md)
Extensions add specific features to Jan as self-contained modules.
## Current Extensions
### `/assistant-extension`
- Assistant CRUD operations
- `src/index.ts` - Main implementation
### `/conversational-extension`
- Message handling, conversation state
- `src/index.ts` - Chat logic
### `/download-extension`
- Model downloads with progress tracking
- `src/index.ts` - Download logic
- `settings.json` - Download settings
### `/llamacpp-extension`
- Local model inference via llama.cpp
- `src/index.ts` - Entry point
- `src/backend.ts` - llama.cpp integration
- `settings.json` - Model settings
## Creating Extensions
### Setup
```bash
mkdir my-extension
cd my-extension
yarn init
```
### Structure
```
my-extension/
├── package.json
├── rolldown.config.mjs
├── src/index.ts
└── settings.json (optional)
```
### Basic Extension
```typescript
import { Extension } from '@janhq/core'
export default class MyExtension extends Extension {
async onLoad() {
// Extension initialization
}
async onUnload() {
// Cleanup
}
}
```
## Building & Testing
```bash
# Build extension
yarn build
# Run tests
yarn test
```
## Common Patterns
### Service Registration
```typescript
async onLoad() {
this.registerService('myService', {
doSomething: async () => 'result'
})
}
```
### Event Handling
```typescript
async onLoad() {
this.on('model:loaded', (model) => {
console.log('Model loaded:', model.id)
})
}
```
## Extension Lifecycle
1. **Jan starts** → Discovers extensions
2. **Loading** → Calls `onLoad()` method
3. **Active** → Extension responds to events
4. **Unloading** → Calls `onUnload()` on shutdown
## Debugging Extensions
```bash
# Check if extension loaded
console.log(window.core.extensions)
# Debug extension events
this.on('*', console.log)
# Check extension services
console.log(window.core.api)
```
## Common Issues
**Extension not loading?**
- Check package.json format: `@janhq/extension-name`
- Ensure `onLoad()` doesn't throw errors
- Verify exports in index.ts
**Events not working?**
- Check event name spelling
- Ensure listeners are set up in `onLoad()`
## Best Practices
- Keep extensions focused on one feature
- Use async/await for all operations
- Clean up resources in onUnload()
- Handle errors gracefully
- Don't depend on other extensions
## Dependencies
- **@janhq/core** - Core SDK and extension system
- **TypeScript** - Type safety
- **Rolldown** - Bundling

View File

@ -31,6 +31,7 @@
"@janhq/tauri-plugin-hardware-api": "link:../../src-tauri/plugins/tauri-plugin-hardware", "@janhq/tauri-plugin-hardware-api": "link:../../src-tauri/plugins/tauri-plugin-hardware",
"@janhq/tauri-plugin-llamacpp-api": "link:../../src-tauri/plugins/tauri-plugin-llamacpp", "@janhq/tauri-plugin-llamacpp-api": "link:../../src-tauri/plugins/tauri-plugin-llamacpp",
"@tauri-apps/api": "^2.5.0", "@tauri-apps/api": "^2.5.0",
"@tauri-apps/plugin-http": "^2.5.1",
"@tauri-apps/plugin-log": "^2.6.0", "@tauri-apps/plugin-log": "^2.6.0",
"fetch-retry": "^5.0.6", "fetch-retry": "^5.0.6",
"ulidx": "^2.3.0" "ulidx": "^2.3.0"

View File

@ -17,4 +17,7 @@ export default defineConfig({
IS_MAC: JSON.stringify(process.platform === 'darwin'), IS_MAC: JSON.stringify(process.platform === 'darwin'),
IS_LINUX: JSON.stringify(process.platform === 'linux'), IS_LINUX: JSON.stringify(process.platform === 'linux'),
}, },
inject: {
fetch: ['@tauri-apps/plugin-http', 'fetch'],
},
}) })

View File

@ -264,7 +264,6 @@ async function _getSupportedFeatures() {
// Vulkan support check - only discrete GPUs with 6GB+ VRAM // Vulkan support check - only discrete GPUs with 6GB+ VRAM
if ( if (
gpuInfo.vulkan_info?.api_version && gpuInfo.vulkan_info?.api_version &&
gpuInfo.vulkan_info?.device_type === 'DISCRETE_GPU' &&
gpuInfo.total_memory >= 6 * 1024 gpuInfo.total_memory >= 6 * 1024
) { ) {
// 6GB (total_memory is in MB) // 6GB (total_memory is in MB)

View File

@ -19,6 +19,7 @@ import {
ImportOptions, ImportOptions,
chatCompletionRequest, chatCompletionRequest,
events, events,
AppEvent,
} from '@janhq/core' } from '@janhq/core'
import { error, info, warn } from '@tauri-apps/plugin-log' import { error, info, warn } from '@tauri-apps/plugin-log'
@ -32,6 +33,7 @@ import {
import { invoke } from '@tauri-apps/api/core' import { invoke } from '@tauri-apps/api/core'
import { getProxyConfig } from './util' import { getProxyConfig } from './util'
import { basename } from '@tauri-apps/api/path' import { basename } from '@tauri-apps/api/path'
import { readGgufMetadata } from '@janhq/tauri-plugin-llamacpp-api'
type LlamacppConfig = { type LlamacppConfig = {
version_backend: string version_backend: string
@ -39,6 +41,7 @@ type LlamacppConfig = {
auto_unload: boolean auto_unload: boolean
chat_template: string chat_template: string
n_gpu_layers: number n_gpu_layers: number
offload_mmproj: boolean
override_tensor_buffer_t: string override_tensor_buffer_t: string
ctx_size: number ctx_size: number
threads: number threads: number
@ -101,12 +104,6 @@ interface DeviceList {
free: number free: number
} }
interface GgufMetadata {
version: number
tensor_count: number
metadata: Record<string, string>
}
/** /**
* Override the default app.log function to use Jan's logging system. * Override the default app.log function to use Jan's logging system.
* @param args * @param args
@ -1059,13 +1056,34 @@ export default class llamacpp_extension extends AIEngine {
} }
} }
// TODO: check if files are valid GGUF files // Validate GGUF files
// NOTE: modelPath and mmprojPath can be either relative to Jan's data folder (if they are downloaded)
// or absolute paths (if they are provided as local files)
const janDataFolderPath = await getJanDataFolderPath() const janDataFolderPath = await getJanDataFolderPath()
let size_bytes = ( const fullModelPath = await joinPath([janDataFolderPath, modelPath])
await fs.fileStat(await joinPath([janDataFolderPath, modelPath]))
).size try {
// Validate main model file
const modelMetadata = await readGgufMetadata(fullModelPath)
logger.info(
`Model GGUF validation successful: version ${modelMetadata.version}, tensors: ${modelMetadata.tensor_count}`
)
// Validate mmproj file if present
if (mmprojPath) {
const fullMmprojPath = await joinPath([janDataFolderPath, mmprojPath])
const mmprojMetadata = await readGgufMetadata(fullMmprojPath)
logger.info(
`Mmproj GGUF validation successful: version ${mmprojMetadata.version}, tensors: ${mmprojMetadata.tensor_count}`
)
}
} catch (error) {
logger.error('GGUF validation failed:', error)
throw new Error(
`Invalid GGUF file(s): ${error.message || 'File format validation failed'}`
)
}
// Calculate file sizes
let size_bytes = (await fs.fileStat(fullModelPath)).size
if (mmprojPath) { if (mmprojPath) {
size_bytes += ( size_bytes += (
await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath])) await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))
@ -1085,6 +1103,12 @@ export default class llamacpp_extension extends AIEngine {
data: modelConfig, data: modelConfig,
savePath: configPath, savePath: configPath,
}) })
events.emit(AppEvent.onModelImported, {
modelId,
modelPath,
mmprojPath,
size_bytes,
})
} }
override async abortImport(modelId: string): Promise<void> { override async abortImport(modelId: string): Promise<void> {
@ -1168,11 +1192,12 @@ export default class llamacpp_extension extends AIEngine {
} }
} }
const args: string[] = [] const args: string[] = []
const envs: Record<string, string> = {}
const cfg = { ...this.config, ...(overrideSettings ?? {}) } const cfg = { ...this.config, ...(overrideSettings ?? {}) }
const [version, backend] = cfg.version_backend.split('/') const [version, backend] = cfg.version_backend.split('/')
if (!version || !backend) { if (!version || !backend) {
throw new Error( throw new Error(
"Initial setup for the backend failed due to a network issue. Please restart the app!" 'Initial setup for the backend failed due to a network issue. Please restart the app!'
) )
} }
@ -1194,7 +1219,7 @@ export default class llamacpp_extension extends AIEngine {
// disable llama-server webui // disable llama-server webui
args.push('--no-webui') args.push('--no-webui')
const api_key = await this.generateApiKey(modelId, String(port)) const api_key = await this.generateApiKey(modelId, String(port))
args.push('--api-key', api_key) envs['LLAMA_API_KEY'] = api_key
// model option is required // model option is required
// NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path // NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path
@ -1203,7 +1228,6 @@ export default class llamacpp_extension extends AIEngine {
modelConfig.model_path, modelConfig.model_path,
]) ])
args.push('--jinja') args.push('--jinja')
args.push('--reasoning-format', 'none')
args.push('-m', modelPath) args.push('-m', modelPath)
// For overriding tensor buffer type, useful where // For overriding tensor buffer type, useful where
// massive MOE models can be made faster by keeping attention on the GPU // massive MOE models can be made faster by keeping attention on the GPU
@ -1213,6 +1237,10 @@ export default class llamacpp_extension extends AIEngine {
// Takes a regex with matching tensor name as input // Takes a regex with matching tensor name as input
if (cfg.override_tensor_buffer_t) if (cfg.override_tensor_buffer_t)
args.push('--override-tensor', cfg.override_tensor_buffer_t) args.push('--override-tensor', cfg.override_tensor_buffer_t)
// Offload the multimodal projector model to the GPU by default. If there is not enough memory,
// turning this setting off will keep the projector model on the CPU, but image processing can
// take longer.
if (cfg.offload_mmproj === false) args.push('--no-mmproj-offload')
args.push('-a', modelId) args.push('-a', modelId)
args.push('--port', String(port)) args.push('--port', String(port))
if (modelConfig.mmproj_path) { if (modelConfig.mmproj_path) {
@ -1279,11 +1307,15 @@ export default class llamacpp_extension extends AIEngine {
try { try {
// TODO: add LIBRARY_PATH // TODO: add LIBRARY_PATH
const sInfo = await invoke<SessionInfo>('plugin:llamacpp|load_llama_model', { const sInfo = await invoke<SessionInfo>(
backendPath, 'plugin:llamacpp|load_llama_model',
libraryPath, {
args, backendPath,
}) libraryPath,
args,
envs,
}
)
return sInfo return sInfo
} catch (error) { } catch (error) {
logger.error('Error in load command:\n', error) logger.error('Error in load command:\n', error)
@ -1299,9 +1331,12 @@ export default class llamacpp_extension extends AIEngine {
const pid = sInfo.pid const pid = sInfo.pid
try { try {
// Pass the PID as the session_id // Pass the PID as the session_id
const result = await invoke<UnloadResult>('plugin:llamacpp|unload_llama_model', { const result = await invoke<UnloadResult>(
pid: pid, 'plugin:llamacpp|unload_llama_model',
}) {
pid: pid,
}
)
// If successful, remove from active sessions // If successful, remove from active sessions
if (result.success) { if (result.success) {
@ -1370,7 +1405,11 @@ export default class llamacpp_extension extends AIEngine {
method: 'POST', method: 'POST',
headers, headers,
body, body,
signal: abortController?.signal, connectTimeout: 600000, // 10 minutes
signal: AbortSignal.any([
AbortSignal.timeout(600000),
abortController?.signal,
]),
}) })
if (!response.ok) { if (!response.ok) {
const errorData = await response.json().catch(() => null) const errorData = await response.json().catch(() => null)
@ -1437,9 +1476,12 @@ export default class llamacpp_extension extends AIEngine {
private async findSessionByModel(modelId: string): Promise<SessionInfo> { private async findSessionByModel(modelId: string): Promise<SessionInfo> {
try { try {
let sInfo = await invoke<SessionInfo>('plugin:llamacpp|find_session_by_model', { let sInfo = await invoke<SessionInfo>(
modelId, 'plugin:llamacpp|find_session_by_model',
}) {
modelId,
}
)
return sInfo return sInfo
} catch (e) { } catch (e) {
logger.error(e) logger.error(e)
@ -1516,7 +1558,9 @@ export default class llamacpp_extension extends AIEngine {
override async getLoadedModels(): Promise<string[]> { override async getLoadedModels(): Promise<string[]> {
try { try {
let models: string[] = await invoke<string[]>('plugin:llamacpp|get_loaded_models') let models: string[] = await invoke<string[]>(
'plugin:llamacpp|get_loaded_models'
)
return models return models
} catch (e) { } catch (e) {
logger.error(e) logger.error(e)
@ -1524,6 +1568,26 @@ export default class llamacpp_extension extends AIEngine {
} }
} }
/**
* Check if mmproj.gguf file exists for a given model ID
* @param modelId - The model ID to check for mmproj.gguf
* @returns Promise<boolean> - true if mmproj.gguf exists, false otherwise
*/
async checkMmprojExists(modelId: string): Promise<boolean> {
try {
const mmprojPath = await joinPath([
await this.getProviderPath(),
'models',
modelId,
'mmproj.gguf',
])
return await fs.existsSync(mmprojPath)
} catch (e) {
logger.error(`Error checking mmproj.gguf for model ${modelId}:`, e)
return false
}
}
async getDevices(): Promise<DeviceList[]> { async getDevices(): Promise<DeviceList[]> {
const cfg = this.config const cfg = this.config
const [version, backend] = cfg.version_backend.split('/') const [version, backend] = cfg.version_backend.split('/')
@ -1599,14 +1663,31 @@ export default class llamacpp_extension extends AIEngine {
throw new Error('method not implemented yet') throw new Error('method not implemented yet')
} }
private async loadMetadata(path: string): Promise<GgufMetadata> { /**
try { * Check if a tool is supported by the model
const data = await invoke<GgufMetadata>('plugin:llamacpp|read_gguf_metadata', { * Currently read from GGUF chat_template
path: path, * @param modelId
}) * @returns
return data */
} catch (err) { async isToolSupported(modelId: string): Promise<boolean> {
throw err const janDataFolderPath = await getJanDataFolderPath()
} const modelConfigPath = await joinPath([
this.providerPath,
'models',
modelId,
'model.yml',
])
const modelConfig = await invoke<ModelConfig>('read_yaml', {
path: modelConfigPath,
})
// model option is required
// NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path
const modelPath = await joinPath([
janDataFolderPath,
modelConfig.model_path,
])
return (await readGgufMetadata(modelPath)).metadata?.[
'tokenizer.chat_template'
]?.includes('tools')
} }
} }

111
src-tauri/CONTRIBUTING.md Normal file
View File

@ -0,0 +1,111 @@
# Contributing to Tauri Backend
[← Back to Main Contributing Guide](../CONTRIBUTING.md)
Rust backend that handles native system integration, file operations, and process management.
## Key Modules
- **`/src/core/app`** - App state and commands
- **`/src/core/downloads`** - Model download management
- **`/src/core/filesystem`** - File system operations
- **`/src/core/mcp`** - Model Context Protocol
- **`/src/core/server`** - Local API server
- **`/src/core/system`** - System information and utilities
- **`/src/core/threads`** - Conversation management
- **`/utils`** - Shared utility crate (CLI, crypto, HTTP, path utils). Used by plugins and the main backend.
- **`/plugins`** - Native Tauri plugins ([see plugins guide](./plugins/CONTRIBUTING.md))
## Development
### Adding Tauri Commands
```rust
#[tauri::command]
async fn my_command(param: String) -> Result<String, String> {
Ok(format!("Processed: {}", param))
}
// Register in lib.rs
tauri::Builder::default()
.invoke_handler(tauri::generate_handler![my_command])
```
## Building & Testing
```bash
# Development
yarn tauri dev
# Build
yarn tauri build
# Run tests
cargo test
```
### State Management
```rust
#[tauri::command]
async fn get_data(state: State<'_, AppState>) -> Result<Data, Error> {
state.get_data().await
}
```
### Error Handling
```rust
#[derive(Debug, thiserror::Error)]
pub enum AppError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
```
## Debugging
```rust
// Enable debug logging
env::set_var("RUST_LOG", "debug");
// Debug print in commands
#[tauri::command]
async fn my_command() -> Result<String, String> {
println!("Command called"); // Shows in terminal
dbg!("Debug info");
Ok("result".to_string())
}
```
## Platform-Specific Notes
**Windows**: Requires Visual Studio Build Tools
**macOS**: Needs Xcode command line tools
**Linux**: May need additional system packages
```rust
#[cfg(target_os = "windows")]
use std::os::windows::process::CommandExt;
```
## Common Issues
**Build failures**: Check Rust toolchain version
**IPC errors**: Ensure command names match frontend calls
**Permission errors**: Update capabilities configuration
## Best Practices
- Always use `Result<T, E>` for fallible operations
- Validate all input from frontend
- Use async for I/O operations
- Follow Rust naming conventions
- Document public APIs
## Dependencies
- **Tauri** - Desktop app framework
- **Tokio** - Async runtime
- **Serde** - JSON serialization
- **thiserror** - Error handling

55
src-tauri/Cargo.lock generated
View File

@ -854,8 +854,18 @@ version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
dependencies = [ dependencies = [
"darling_core", "darling_core 0.20.11",
"darling_macro", "darling_macro 0.20.11",
]
[[package]]
name = "darling"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08440b3dd222c3d0433e63e097463969485f112baff337dfdaca043a0d760570"
dependencies = [
"darling_core 0.21.2",
"darling_macro 0.21.2",
] ]
[[package]] [[package]]
@ -872,13 +882,38 @@ dependencies = [
"syn 2.0.104", "syn 2.0.104",
] ]
[[package]]
name = "darling_core"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d25b7912bc28a04ab1b7715a68ea03aaa15662b43a1a4b2c480531fd19f8bf7e"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn 2.0.104",
]
[[package]] [[package]]
name = "darling_macro" name = "darling_macro"
version = "0.20.11" version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
dependencies = [ dependencies = [
"darling_core", "darling_core 0.20.11",
"quote",
"syn 2.0.104",
]
[[package]]
name = "darling_macro"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce154b9bea7fb0c8e8326e62d00354000c36e79770ff21b8c84e3aa267d9d531"
dependencies = [
"darling_core 0.21.2",
"quote", "quote",
"syn 2.0.104", "syn 2.0.104",
] ]
@ -3984,8 +4019,8 @@ dependencies = [
[[package]] [[package]]
name = "rmcp" name = "rmcp"
version = "0.2.1" version = "0.5.0"
source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=3196c95f1dfafbffbdcdd6d365c94969ac975e6a#3196c95f1dfafbffbdcdd6d365c94969ac975e6a" source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=209dbac50f51737ad953c3a2c8e28f3619b6c277#209dbac50f51737ad953c3a2c8e28f3619b6c277"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"chrono", "chrono",
@ -4010,10 +4045,10 @@ dependencies = [
[[package]] [[package]]
name = "rmcp-macros" name = "rmcp-macros"
version = "0.2.1" version = "0.5.0"
source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=3196c95f1dfafbffbdcdd6d365c94969ac975e6a#3196c95f1dfafbffbdcdd6d365c94969ac975e6a" source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=209dbac50f51737ad953c3a2c8e28f3619b6c277#209dbac50f51737ad953c3a2c8e28f3619b6c277"
dependencies = [ dependencies = [
"darling", "darling 0.21.2",
"proc-macro2", "proc-macro2",
"quote", "quote",
"serde_json", "serde_json",
@ -4408,7 +4443,7 @@ version = "3.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f" checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f"
dependencies = [ dependencies = [
"darling", "darling 0.20.11",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.104", "syn 2.0.104",
@ -6868,7 +6903,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a76ff259533532054cfbaefb115c613203c73707017459206380f03b3b3f266e" checksum = "a76ff259533532054cfbaefb115c613203c73707017459206380f03b3b3f266e"
dependencies = [ dependencies = [
"darling", "darling 0.20.11",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.104", "syn 2.0.104",

View File

@ -44,9 +44,10 @@ jan-utils = { path = "./utils" }
libloading = "0.8.7" libloading = "0.8.7"
log = "0.4" log = "0.4"
reqwest = { version = "0.11", features = ["json", "blocking", "stream"] } reqwest = { version = "0.11", features = ["json", "blocking", "stream"] }
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "3196c95f1dfafbffbdcdd6d365c94969ac975e6a", features = [ rmcp = { version = "0.6.0", features = [
"client", "client",
"transport-sse-client", "transport-sse-client",
"transport-streamable-http-client",
"transport-child-process", "transport-child-process",
"tower", "tower",
"reqwest", "reqwest",

View File

@ -0,0 +1,119 @@
# Contributing to Tauri Plugins
[← Back to Main Contributing Guide](../../CONTRIBUTING.md) | [← Back to Tauri Guide](../CONTRIBUTING.md)
Native Rust plugins for hardware access, process management, and system integration.
## Current Plugins
### `/tauri-plugin-hardware`
- Hardware detection (CPU, GPU, memory)
### `/tauri-plugin-llamacpp`
- llama.cpp process management and model inference
## Plugin Structure
```
tauri-plugin-name/
├── Cargo.toml
├── src/lib.rs # Plugin entry point
├── src/commands.rs # Tauri commands
├── guest-js/index.ts # JavaScript API
└── permissions/default.toml
```
## Development
### Creating Plugins
Assuming that your new plugin name is `my-plugin`
```bash
# with npx
npx @tauri-apps/cli plugin new my-plugin
# with cargo
cargo tauri plugin new my-plugin
cd tauri-plugin-my-plugin
```
### Plugin Registration
```rust
use tauri::{plugin::{Builder, TauriPlugin}, Runtime};
pub fn init<R: Runtime>() -> TauriPlugin<R> {
Builder::new("my-plugin")
.invoke_handler(tauri::generate_handler![commands::my_command])
.build()
}
```
### Commands & JavaScript API
```rust
#[tauri::command]
pub async fn my_command(param: String) -> Result<String, Error> {
Ok(format!("Result: {}", param))
}
```
```typescript
import { invoke } from '@tauri-apps/api/core'
export async function myCommand(param: string): Promise<string> {
return await invoke('plugin:my-plugin|my_command', { param })
}
```
### Building & Testing
```bash
cargo build # Build plugin
yarn build # Build JavaScript
cargo test # Run tests
```
## Security Considerations
```toml
# permissions/default.toml - Be specific
[[permission]]
identifier = "allow-hardware-info"
description = "Read system hardware information"
# Never use wildcards in production
# ❌ identifier = "allow-*"
# ✅ identifier = "allow-specific-action"
```
## Testing Plugins
```bash
# Test plugin in isolation
cd tauri-plugin-my-plugin
cargo test
# Test with main app
cd ../../
yarn tauri dev
# Test JavaScript API
yarn build && node -e "const plugin = require('./dist-js'); console.log(plugin)"
```
## Best Practices
- Use secure permission configurations
- Validate all command inputs
- Handle platform differences properly
- Clean up resources in Drop implementations
- Test on all target platforms
## Dependencies
- **Tauri** - Plugin framework
- **Serde** - JSON serialization
- **Tokio** - Async runtime (if needed)

View File

@ -327,4 +327,4 @@
] ]
} }
} }
} }

View File

@ -447,4 +447,4 @@
] ]
} }
} }
} }

View File

@ -1,6 +1,7 @@
use base64::{engine::general_purpose, Engine as _}; use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use sha2::Sha256; use sha2::Sha256;
use std::collections::HashMap;
use std::process::Stdio; use std::process::Stdio;
use std::time::Duration; use std::time::Duration;
use tauri::{Manager, Runtime, State}; use tauri::{Manager, Runtime, State};
@ -11,7 +12,7 @@ use tokio::time::Instant;
use crate::device::{get_devices_from_backend, DeviceInfo}; use crate::device::{get_devices_from_backend, DeviceInfo};
use crate::error::{ErrorCode, LlamacppError, ServerError, ServerResult}; use crate::error::{ErrorCode, LlamacppError, ServerError, ServerResult};
use crate::path::{validate_binary_path, validate_model_path}; use crate::path::{validate_binary_path, validate_model_path, validate_mmproj_path};
use crate::process::{ use crate::process::{
find_session_by_model_id, get_all_active_sessions, get_all_loaded_model_ids, find_session_by_model_id, get_all_active_sessions, get_all_loaded_model_ids,
get_random_available_port, is_process_running_by_pid, get_random_available_port, is_process_running_by_pid,
@ -42,6 +43,7 @@ pub async fn load_llama_model<R: Runtime>(
backend_path: &str, backend_path: &str,
library_path: Option<&str>, library_path: Option<&str>,
mut args: Vec<String>, mut args: Vec<String>,
envs: HashMap<String, String>,
) -> ServerResult<SessionInfo> { ) -> ServerResult<SessionInfo> {
let state: State<LlamacppState> = app_handle.state(); let state: State<LlamacppState> = app_handle.state();
let mut process_map = state.llama_server_process.lock().await; let mut process_map = state.llama_server_process.lock().await;
@ -53,13 +55,23 @@ pub async fn load_llama_model<R: Runtime>(
let port = parse_port_from_args(&args); let port = parse_port_from_args(&args);
let model_path_pb = validate_model_path(&mut args)?; let model_path_pb = validate_model_path(&mut args)?;
let _mmproj_path_pb = validate_mmproj_path(&mut args)?;
let api_key: String;
if let Some(api_value) = envs.get("LLAMA_API_KEY") {
api_key = api_value.to_string();
} else {
log::warn!("API key not provided");
api_key = "".to_string();
}
let api_key = extract_arg_value(&args, "--api-key");
let model_id = extract_arg_value(&args, "-a"); let model_id = extract_arg_value(&args, "-a");
// Configure the command to run the server // Configure the command to run the server
let mut command = Command::new(backend_path); let mut command = Command::new(backend_path);
command.args(args); command.args(args);
command.envs(envs);
setup_library_path(library_path, &mut command); setup_library_path(library_path, &mut command);
command.stdout(Stdio::piped()); command.stdout(Stdio::piped());

View File

@ -98,3 +98,50 @@ pub fn validate_model_path(args: &mut Vec<String>) -> ServerResult<PathBuf> {
Ok(model_path_pb) Ok(model_path_pb)
} }
/// Validate mmproj path exists and update args with platform-appropriate path format
pub fn validate_mmproj_path(args: &mut Vec<String>) -> ServerResult<Option<PathBuf>> {
let mmproj_path_index = match args.iter().position(|arg| arg == "--mmproj") {
Some(index) => index,
None => return Ok(None), // mmproj is optional
};
let mmproj_path = args.get(mmproj_path_index + 1).cloned().ok_or_else(|| {
LlamacppError::new(
ErrorCode::ModelLoadFailed,
"Mmproj path was not provided after '--mmproj' flag.".into(),
None,
)
})?;
let mmproj_path_pb = PathBuf::from(&mmproj_path);
if !mmproj_path_pb.exists() {
let err_msg = format!(
"Invalid or inaccessible mmproj path: {}",
mmproj_path_pb.display()
);
log::error!("{}", &err_msg);
return Err(LlamacppError::new(
ErrorCode::ModelFileNotFound,
"The specified mmproj file does not exist or is not accessible.".into(),
Some(err_msg),
)
.into());
}
#[cfg(windows)]
{
// use short path on Windows
if let Some(short) = get_short_path(&mmproj_path_pb) {
args[mmproj_path_index + 1] = short;
} else {
args[mmproj_path_index + 1] = mmproj_path_pb.display().to_string();
}
}
#[cfg(not(windows))]
{
args[mmproj_path_index + 1] = mmproj_path_pb.display().to_string();
}
Ok(Some(mmproj_path_pb))
}

View File

@ -1,15 +1,18 @@
use rmcp::model::{CallToolRequestParam, CallToolResult, Tool}; use rmcp::model::{CallToolRequestParam, CallToolResult};
use rmcp::{service::RunningService, RoleClient};
use serde_json::{Map, Value}; use serde_json::{Map, Value};
use std::{collections::HashMap, sync::Arc};
use tauri::{AppHandle, Emitter, Runtime, State}; use tauri::{AppHandle, Emitter, Runtime, State};
use tokio::{sync::Mutex, time::timeout}; use tokio::time::timeout;
use tokio::sync::oneshot;
use super::{ use super::{
constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT}, constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT},
helpers::{restart_active_mcp_servers, start_mcp_server_with_restart, stop_mcp_servers}, helpers::{restart_active_mcp_servers, start_mcp_server_with_restart, stop_mcp_servers},
}; };
use crate::core::{app::commands::get_jan_data_folder_path, state::AppState}; use crate::core::{app::commands::get_jan_data_folder_path, state::AppState};
use crate::core::{
mcp::models::ToolWithServer,
state::{RunningServiceEnum, SharedMcpServers},
};
use std::fs; use std::fs;
#[tauri::command] #[tauri::command]
@ -19,8 +22,7 @@ pub async fn activate_mcp_server<R: Runtime>(
name: String, name: String,
config: Value, config: Value,
) -> Result<(), String> { ) -> Result<(), String> {
let servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>> = let servers: SharedMcpServers = state.mcp_servers.clone();
state.mcp_servers.clone();
// Use the modified start_mcp_server_with_restart that returns first attempt result // Use the modified start_mcp_server_with_restart that returns first attempt result
start_mcp_server_with_restart(app, servers, name, config, Some(3)).await start_mcp_server_with_restart(app, servers, name, config, Some(3)).await
@ -63,7 +65,16 @@ pub async fn deactivate_mcp_server(state: State<'_, AppState>, name: String) ->
// Release the lock before calling cancel // Release the lock before calling cancel
drop(servers_map); drop(servers_map);
service.cancel().await.map_err(|e| e.to_string())?; match service {
RunningServiceEnum::NoInit(service) => {
log::info!("Stopping server {name}...");
service.cancel().await.map_err(|e| e.to_string())?;
}
RunningServiceEnum::WithInit(service) => {
log::info!("Stopping server {name} with initialization...");
service.cancel().await.map_err(|e| e.to_string())?;
}
}
log::info!("Server {name} stopped successfully and marked as deactivated."); log::info!("Server {name} stopped successfully and marked as deactivated.");
Ok(()) Ok(())
} }
@ -116,7 +127,7 @@ pub async fn get_connected_servers(
Ok(servers_map.keys().cloned().collect()) Ok(servers_map.keys().cloned().collect())
} }
/// Retrieves all available tools from all MCP servers /// Retrieves all available tools from all MCP servers with server information
/// ///
/// # Arguments /// # Arguments
/// * `state` - Application state containing MCP server connections /// * `state` - Application state containing MCP server connections
@ -128,14 +139,15 @@ pub async fn get_connected_servers(
/// 1. Locks the MCP servers mutex to access server connections /// 1. Locks the MCP servers mutex to access server connections
/// 2. Iterates through all connected servers /// 2. Iterates through all connected servers
/// 3. Gets the list of tools from each server /// 3. Gets the list of tools from each server
/// 4. Combines all tools into a single vector /// 4. Associates each tool with its parent server name
/// 5. Returns the combined list of all available tools /// 5. Combines all tools into a single vector
/// 6. Returns the combined list of all available tools with server information
#[tauri::command] #[tauri::command]
pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<Tool>, String> { pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>, String> {
let servers = state.mcp_servers.lock().await; let servers = state.mcp_servers.lock().await;
let mut all_tools: Vec<Tool> = Vec::new(); let mut all_tools: Vec<ToolWithServer> = Vec::new();
for (_, service) in servers.iter() { for (server_name, service) in servers.iter() {
// List tools with timeout // List tools with timeout
let tools_future = service.list_all_tools(); let tools_future = service.list_all_tools();
let tools = match timeout(MCP_TOOL_CALL_TIMEOUT, tools_future).await { let tools = match timeout(MCP_TOOL_CALL_TIMEOUT, tools_future).await {
@ -150,7 +162,12 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<Tool>, String>
}; };
for tool in tools { for tool in tools {
all_tools.push(tool); all_tools.push(ToolWithServer {
name: tool.name.to_string(),
description: tool.description.as_ref().map(|d| d.to_string()),
input_schema: serde_json::Value::Object((*tool.input_schema).clone()),
server: server_name.clone(),
});
} }
} }
@ -163,6 +180,7 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<Tool>, String>
/// * `state` - Application state containing MCP server connections /// * `state` - Application state containing MCP server connections
/// * `tool_name` - Name of the tool to call /// * `tool_name` - Name of the tool to call
/// * `arguments` - Optional map of argument names to values /// * `arguments` - Optional map of argument names to values
/// * `cancellation_token` - Optional token to allow cancellation from JS side
/// ///
/// # Returns /// # Returns
/// * `Result<CallToolResult, String>` - Result of the tool call if successful, or error message if failed /// * `Result<CallToolResult, String>` - Result of the tool call if successful, or error message if failed
@ -171,13 +189,23 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<Tool>, String>
/// 1. Locks the MCP servers mutex to access server connections /// 1. Locks the MCP servers mutex to access server connections
/// 2. Searches through all servers for one containing the named tool /// 2. Searches through all servers for one containing the named tool
/// 3. When found, calls the tool on that server with the provided arguments /// 3. When found, calls the tool on that server with the provided arguments
/// 4. Returns error if no server has the requested tool /// 4. Supports cancellation via cancellation_token
/// 5. Returns error if no server has the requested tool
#[tauri::command] #[tauri::command]
pub async fn call_tool( pub async fn call_tool(
state: State<'_, AppState>, state: State<'_, AppState>,
tool_name: String, tool_name: String,
arguments: Option<Map<String, Value>>, arguments: Option<Map<String, Value>>,
cancellation_token: Option<String>,
) -> Result<CallToolResult, String> { ) -> Result<CallToolResult, String> {
// Set up cancellation if token is provided
let (cancel_tx, cancel_rx) = oneshot::channel::<()>();
if let Some(token) = &cancellation_token {
let mut cancellations = state.tool_call_cancellations.lock().await;
cancellations.insert(token.clone(), cancel_tx);
}
let servers = state.mcp_servers.lock().await; let servers = state.mcp_servers.lock().await;
// Iterate through servers and find the first one that contains the tool // Iterate through servers and find the first one that contains the tool
@ -193,25 +221,77 @@ pub async fn call_tool(
println!("Found tool {} in server", tool_name); println!("Found tool {} in server", tool_name);
// Call the tool with timeout // Call the tool with timeout and cancellation support
let tool_call = service.call_tool(CallToolRequestParam { let tool_call = service.call_tool(CallToolRequestParam {
name: tool_name.clone().into(), name: tool_name.clone().into(),
arguments, arguments,
}); });
return match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await { // Race between timeout, tool call, and cancellation
Ok(result) => result.map_err(|e| e.to_string()), let result = if cancellation_token.is_some() {
Err(_) => Err(format!( tokio::select! {
"Tool call '{}' timed out after {} seconds", result = timeout(MCP_TOOL_CALL_TIMEOUT, tool_call) => {
tool_name, match result {
MCP_TOOL_CALL_TIMEOUT.as_secs() Ok(call_result) => call_result.map_err(|e| e.to_string()),
)), Err(_) => Err(format!(
"Tool call '{}' timed out after {} seconds",
tool_name,
MCP_TOOL_CALL_TIMEOUT.as_secs()
)),
}
}
_ = cancel_rx => {
Err(format!("Tool call '{}' was cancelled", tool_name))
}
}
} else {
match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await {
Ok(call_result) => call_result.map_err(|e| e.to_string()),
Err(_) => Err(format!(
"Tool call '{}' timed out after {} seconds",
tool_name,
MCP_TOOL_CALL_TIMEOUT.as_secs()
)),
}
}; };
// Clean up cancellation token
if let Some(token) = &cancellation_token {
let mut cancellations = state.tool_call_cancellations.lock().await;
cancellations.remove(token);
}
return result;
} }
Err(format!("Tool {} not found", tool_name)) Err(format!("Tool {} not found", tool_name))
} }
/// Cancels a running tool call by its cancellation token
///
/// # Arguments
/// * `state` - Application state containing cancellation tokens
/// * `cancellation_token` - Token identifying the tool call to cancel
///
/// # Returns
/// * `Result<(), String>` - Success if token found and cancelled, error otherwise
#[tauri::command]
pub async fn cancel_tool_call(
state: State<'_, AppState>,
cancellation_token: String,
) -> Result<(), String> {
let mut cancellations = state.tool_call_cancellations.lock().await;
if let Some(cancel_tx) = cancellations.remove(&cancellation_token) {
// Send cancellation signal - ignore if receiver is already dropped
let _ = cancel_tx.send(());
println!("Tool call with token {} cancelled", cancellation_token);
Ok(())
} else {
Err(format!("Cancellation token {} not found", cancellation_token))
}
}
#[tauri::command] #[tauri::command]
pub async fn get_mcp_configs(app: AppHandle) -> Result<String, String> { pub async fn get_mcp_configs(app: AppHandle) -> Result<String, String> {
let mut path = get_jan_data_folder_path(app); let mut path = get_jan_data_folder_path(app);

View File

@ -1,8 +1,17 @@
use rmcp::{service::RunningService, transport::TokioChildProcess, RoleClient, ServiceExt}; use rmcp::{
model::{ClientCapabilities, ClientInfo, Implementation},
transport::{
streamable_http_client::StreamableHttpClientTransportConfig, SseClientTransport,
StreamableHttpClientTransport, TokioChildProcess,
},
ServiceExt,
};
use serde_json::Value; use serde_json::Value;
use std::{collections::HashMap, env, sync::Arc, time::Duration}; use std::{collections::HashMap, env, process::Stdio, sync::Arc, time::Duration};
use tauri::{AppHandle, Emitter, Manager, Runtime, State}; use tauri::{AppHandle, Emitter, Manager, Runtime, State};
use tauri_plugin_http::reqwest;
use tokio::{ use tokio::{
io::AsyncReadExt,
process::Command, process::Command,
sync::Mutex, sync::Mutex,
time::{sleep, timeout}, time::{sleep, timeout},
@ -11,7 +20,11 @@ use tokio::{
use super::constants::{ use super::constants::{
MCP_BACKOFF_MULTIPLIER, MCP_BASE_RESTART_DELAY_MS, MCP_MAX_RESTART_DELAY_MS, MCP_BACKOFF_MULTIPLIER, MCP_BASE_RESTART_DELAY_MS, MCP_MAX_RESTART_DELAY_MS,
}; };
use crate::core::{app::commands::get_jan_data_folder_path, state::AppState}; use crate::core::{
app::commands::get_jan_data_folder_path,
mcp::models::McpServerConfig,
state::{AppState, RunningServiceEnum, SharedMcpServers},
};
use jan_utils::can_override_npx; use jan_utils::can_override_npx;
/// Calculate exponential backoff delay with jitter /// Calculate exponential backoff delay with jitter
@ -72,7 +85,7 @@ pub fn calculate_exponential_backoff_delay(attempt: u32) -> u64 {
/// * `Err(String)` if there was an error reading config or starting servers /// * `Err(String)` if there was an error reading config or starting servers
pub async fn run_mcp_commands<R: Runtime>( pub async fn run_mcp_commands<R: Runtime>(
app: &AppHandle<R>, app: &AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>, servers_state: SharedMcpServers,
) -> Result<(), String> { ) -> Result<(), String> {
let app_path = get_jan_data_folder_path(app.clone()); let app_path = get_jan_data_folder_path(app.clone());
let app_path_str = app_path.to_str().unwrap().to_string(); let app_path_str = app_path.to_str().unwrap().to_string();
@ -168,7 +181,7 @@ pub async fn run_mcp_commands<R: Runtime>(
/// Monitor MCP server health without removing it from the HashMap /// Monitor MCP server health without removing it from the HashMap
pub async fn monitor_mcp_server_handle( pub async fn monitor_mcp_server_handle(
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>, servers_state: SharedMcpServers,
name: String, name: String,
) -> Option<rmcp::service::QuitReason> { ) -> Option<rmcp::service::QuitReason> {
log::info!("Monitoring MCP server {} health", name); log::info!("Monitoring MCP server {} health", name);
@ -213,7 +226,16 @@ pub async fn monitor_mcp_server_handle(
let mut servers = servers_state.lock().await; let mut servers = servers_state.lock().await;
if let Some(service) = servers.remove(&name) { if let Some(service) = servers.remove(&name) {
// Try to cancel the service gracefully // Try to cancel the service gracefully
let _ = service.cancel().await; match service {
RunningServiceEnum::NoInit(service) => {
log::info!("Stopping server {name}...");
let _ = service.cancel().await;
}
RunningServiceEnum::WithInit(service) => {
log::info!("Stopping server {name} with initialization...");
let _ = service.cancel().await;
}
}
} }
return Some(rmcp::service::QuitReason::Closed); return Some(rmcp::service::QuitReason::Closed);
} }
@ -224,7 +246,7 @@ pub async fn monitor_mcp_server_handle(
/// Returns the result of the first start attempt, then continues with restart monitoring /// Returns the result of the first start attempt, then continues with restart monitoring
pub async fn start_mcp_server_with_restart<R: Runtime>( pub async fn start_mcp_server_with_restart<R: Runtime>(
app: AppHandle<R>, app: AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>, servers_state: SharedMcpServers,
name: String, name: String,
config: Value, config: Value,
max_restarts: Option<u32>, max_restarts: Option<u32>,
@ -297,7 +319,7 @@ pub async fn start_mcp_server_with_restart<R: Runtime>(
/// Helper function to handle the restart loop logic /// Helper function to handle the restart loop logic
pub async fn start_restart_loop<R: Runtime>( pub async fn start_restart_loop<R: Runtime>(
app: AppHandle<R>, app: AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>, servers_state: SharedMcpServers,
name: String, name: String,
config: Value, config: Value,
max_restarts: u32, max_restarts: u32,
@ -450,9 +472,9 @@ pub async fn start_restart_loop<R: Runtime>(
} }
} }
pub async fn schedule_mcp_start_task<R: Runtime>( async fn schedule_mcp_start_task<R: Runtime>(
app: tauri::AppHandle<R>, app: tauri::AppHandle<R>,
servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>, servers: SharedMcpServers,
name: String, name: String,
config: Value, config: Value,
) -> Result<(), String> { ) -> Result<(), String> {
@ -463,136 +485,279 @@ pub async fn schedule_mcp_start_task<R: Runtime>(
.expect("Executable must have a parent directory"); .expect("Executable must have a parent directory");
let bin_path = exe_parent_path.to_path_buf(); let bin_path = exe_parent_path.to_path_buf();
let (command, args, envs) = extract_command_args(&config) let config_params = extract_command_args(&config)
.ok_or_else(|| format!("Failed to extract command args from config for {name}"))?; .ok_or_else(|| format!("Failed to extract command args from config for {name}"))?;
let mut cmd = Command::new(command.clone()); if config_params.transport_type.as_deref() == Some("http") && config_params.url.is_some() {
let transport = StreamableHttpClientTransport::with_client(
reqwest::Client::builder()
.default_headers({
// Map envs to request headers
let mut headers: tauri::http::HeaderMap = reqwest::header::HeaderMap::new();
for (key, value) in config_params.headers.iter() {
if let Some(v_str) = value.as_str() {
// Try to map env keys to HTTP header names (case-insensitive)
// Most HTTP headers are Title-Case, so we try to convert
let header_name =
reqwest::header::HeaderName::from_bytes(key.as_bytes());
if let Ok(header_name) = header_name {
if let Ok(header_value) =
reqwest::header::HeaderValue::from_str(v_str)
{
headers.insert(header_name, header_value);
}
}
}
}
headers
})
.connect_timeout(config_params.timeout.unwrap_or(Duration::MAX))
.build()
.unwrap(),
StreamableHttpClientTransportConfig {
uri: config_params.url.unwrap().into(),
..Default::default()
},
);
if command == "npx" && can_override_npx() { let client_info = ClientInfo {
let mut cache_dir = app_path.clone(); protocol_version: Default::default(),
cache_dir.push(".npx"); capabilities: ClientCapabilities::default(),
let bun_x_path = format!("{}/bun", bin_path.display()); client_info: Implementation {
cmd = Command::new(bun_x_path); name: "Jan Streamable Client".to_string(),
cmd.arg("x"); version: "0.0.1".to_string(),
cmd.env("BUN_INSTALL", cache_dir.to_str().unwrap().to_string()); },
} };
let client = client_info.serve(transport).await.inspect_err(|e| {
log::error!("client error: {:?}", e);
});
if command == "uvx" { match client {
let mut cache_dir = app_path.clone(); Ok(client) => {
cache_dir.push(".uvx"); log::info!("Connected to server: {:?}", client.peer_info());
let bun_x_path = format!("{}/uv", bin_path.display()); servers
cmd = Command::new(bun_x_path); .lock()
cmd.arg("tool"); .await
cmd.arg("run"); .insert(name.clone(), RunningServiceEnum::WithInit(client));
cmd.env("UV_CACHE_DIR", cache_dir.to_str().unwrap().to_string());
}
#[cfg(windows)] // Mark server as successfully connected (for restart policy)
{
let app_state = app.state::<AppState>();
let mut connected = app_state.mcp_successfully_connected.lock().await;
connected.insert(name.clone(), true);
log::info!("Marked MCP server {} as successfully connected", name);
}
}
Err(e) => {
log::error!("Failed to connect to server: {}", e);
return Err(format!("Failed to connect to server: {}", e));
}
}
} else if config_params.transport_type.as_deref() == Some("sse") && config_params.url.is_some()
{ {
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW: prevents shell window on Windows let transport = SseClientTransport::start_with_client(
} reqwest::Client::builder()
.default_headers({
let app_path_str = app_path.to_str().unwrap().to_string(); // Map envs to request headers
let log_file_path = format!("{}/logs/app.log", app_path_str); let mut headers = reqwest::header::HeaderMap::new();
match std::fs::OpenOptions::new() for (key, value) in config_params.headers.iter() {
.create(true) if let Some(v_str) = value.as_str() {
.append(true) // Try to map env keys to HTTP header names (case-insensitive)
.open(log_file_path) // Most HTTP headers are Title-Case, so we try to convert
{ let header_name =
Ok(file) => { reqwest::header::HeaderName::from_bytes(key.as_bytes());
cmd.stderr(std::process::Stdio::from(file)); if let Ok(header_name) = header_name {
} if let Ok(header_value) =
Err(err) => { reqwest::header::HeaderValue::from_str(v_str)
log::error!("Failed to open log file: {}", err); {
} headers.insert(header_name, header_value);
}; }
}
cmd.kill_on_drop(true); }
log::trace!("Command: {cmd:#?}"); }
headers
args.iter().filter_map(Value::as_str).for_each(|arg| { })
cmd.arg(arg); .connect_timeout(config_params.timeout.unwrap_or(Duration::MAX))
}); .build()
envs.iter().for_each(|(k, v)| { .unwrap(),
if let Some(v_str) = v.as_str() { rmcp::transport::sse_client::SseClientConfig {
cmd.env(k, v_str); sse_endpoint: config_params.url.unwrap().into(),
} ..Default::default()
}); },
let process = TokioChildProcess::new(cmd).map_err(|e| {
log::error!("Failed to run command {name}: {e}");
format!("Failed to run command {name}: {e}")
})?;
let service = ()
.serve(process)
.await
.map_err(|e| format!("Failed to start MCP server {name}: {e}"))?;
// Get peer info and clone the needed values before moving the service
let (server_name, server_version) = {
let server_info = service.peer_info();
log::trace!("Connected to server: {server_info:#?}");
(
server_info.unwrap().server_info.name.clone(),
server_info.unwrap().server_info.version.clone(),
) )
}; .await
.map_err(|e| {
log::error!("transport error: {:?}", e);
format!("Failed to start SSE transport: {}", e)
})?;
// Now move the service into the HashMap let client_info = ClientInfo {
servers.lock().await.insert(name.clone(), service); protocol_version: Default::default(),
log::info!("Server {name} started successfully."); capabilities: ClientCapabilities::default(),
client_info: Implementation {
name: "Jan SSE Client".to_string(),
version: "0.0.1".to_string(),
},
};
let client = client_info.serve(transport).await.map_err(|e| {
log::error!("client error: {:?}", e);
e.to_string()
});
// Wait a short time to verify the server is stable before marking as connected match client {
// This prevents race conditions where the server quits immediately Ok(client) => {
let verification_delay = Duration::from_millis(500); log::info!("Connected to server: {:?}", client.peer_info());
sleep(verification_delay).await; servers
.lock()
.await
.insert(name.clone(), RunningServiceEnum::WithInit(client));
// Check if server is still running after the verification delay // Mark server as successfully connected (for restart policy)
let server_still_running = { {
let servers_map = servers.lock().await; let app_state = app.state::<AppState>();
servers_map.contains_key(&name) let mut connected = app_state.mcp_successfully_connected.lock().await;
}; connected.insert(name.clone(), true);
log::info!("Marked MCP server {} as successfully connected", name);
}
}
Err(e) => {
log::error!("Failed to connect to server: {}", e);
return Err(format!("Failed to connect to server: {}", e));
}
}
} else {
let mut cmd = Command::new(config_params.command.clone());
if config_params.command.clone() == "npx" && can_override_npx() {
let mut cache_dir = app_path.clone();
cache_dir.push(".npx");
let bun_x_path = format!("{}/bun", bin_path.display());
cmd = Command::new(bun_x_path);
cmd.arg("x");
cmd.env("BUN_INSTALL", cache_dir.to_str().unwrap().to_string());
}
if config_params.command.clone() == "uvx" {
let mut cache_dir = app_path.clone();
cache_dir.push(".uvx");
let bun_x_path = format!("{}/uv", bin_path.display());
cmd = Command::new(bun_x_path);
cmd.arg("tool");
cmd.arg("run");
cmd.env("UV_CACHE_DIR", cache_dir.to_str().unwrap().to_string());
}
#[cfg(windows)]
{
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW: prevents shell window on Windows
}
if !server_still_running { cmd.kill_on_drop(true);
return Err(format!(
"MCP server {} quit immediately after starting", config_params
name .args
)); .iter()
.filter_map(Value::as_str)
.for_each(|arg| {
cmd.arg(arg);
});
config_params.envs.iter().for_each(|(k, v)| {
if let Some(v_str) = v.as_str() {
cmd.env(k, v_str);
}
});
let (process, stderr) = TokioChildProcess::builder(cmd)
.stderr(Stdio::piped())
.spawn()
.map_err(|e| {
log::error!("Failed to run command {name}: {e}");
format!("Failed to run command {name}: {e}")
})?;
let service = ()
.serve(process)
.await
.map_err(|e| format!("Failed to start MCP server {name}: {e}"));
match service {
Ok(server) => {
log::trace!("Connected to server: {:#?}", server.peer_info());
servers
.lock()
.await
.insert(name.clone(), RunningServiceEnum::NoInit(server));
log::info!("Server {name} started successfully.");
}
Err(_) => {
let mut buffer = String::new();
let error = match stderr
.expect("stderr must be piped")
.read_to_string(&mut buffer)
.await
{
Ok(_) => format!("Failed to start MCP server {name}: {buffer}"),
Err(_) => format!("Failed to read MCP server {name} stderr"),
};
log::error!("{error}");
return Err(error);
}
}
// Wait a short time to verify the server is stable before marking as connected
// This prevents race conditions where the server quits immediately
let verification_delay = Duration::from_millis(500);
sleep(verification_delay).await;
// Check if server is still running after the verification delay
let server_still_running = {
let servers_map = servers.lock().await;
servers_map.contains_key(&name)
};
if !server_still_running {
return Err(format!(
"MCP server {} quit immediately after starting",
name
));
}
// Mark server as successfully connected (for restart policy)
{
let app_state = app.state::<AppState>();
let mut connected = app_state.mcp_successfully_connected.lock().await;
connected.insert(name.clone(), true);
log::info!("Marked MCP server {} as successfully connected", name);
}
} }
// Mark server as successfully connected (for restart policy)
{
let app_state = app.state::<AppState>();
let mut connected = app_state.mcp_successfully_connected.lock().await;
connected.insert(name.clone(), true);
log::info!("Marked MCP server {} as successfully connected", name);
}
// Emit event to the frontend
let event = format!("mcp-connected");
let payload = serde_json::json!({
"name": server_name,
"version": server_version,
});
app.emit(&event, payload)
.map_err(|e| format!("Failed to emit event: {}", e))?;
Ok(()) Ok(())
} }
pub fn extract_command_args( pub fn extract_command_args(config: &Value) -> Option<McpServerConfig> {
config: &Value,
) -> Option<(String, Vec<Value>, serde_json::Map<String, Value>)> {
let obj = config.as_object()?; let obj = config.as_object()?;
let command = obj.get("command")?.as_str()?.to_string(); let command = obj.get("command")?.as_str()?.to_string();
let args = obj.get("args")?.as_array()?.clone(); let args = obj.get("args")?.as_array()?.clone();
let url = obj.get("url").and_then(|u| u.as_str()).map(String::from);
let transport_type = obj.get("type").and_then(|t| t.as_str()).map(String::from);
let timeout = obj
.get("timeout")
.and_then(|t| t.as_u64())
.map(Duration::from_secs);
let headers = obj
.get("headers")
.unwrap_or(&Value::Object(serde_json::Map::new()))
.as_object()?
.clone();
let envs = obj let envs = obj
.get("env") .get("env")
.unwrap_or(&Value::Object(serde_json::Map::new())) .unwrap_or(&Value::Object(serde_json::Map::new()))
.as_object()? .as_object()?
.clone(); .clone();
Some((command, args, envs)) Some(McpServerConfig {
timeout,
transport_type,
url,
command,
args,
envs,
headers,
})
} }
pub fn extract_active_status(config: &Value) -> Option<bool> { pub fn extract_active_status(config: &Value) -> Option<bool> {
@ -604,7 +769,7 @@ pub fn extract_active_status(config: &Value) -> Option<bool> {
/// Restart only servers that were previously active (like cortex restart behavior) /// Restart only servers that were previously active (like cortex restart behavior)
pub async fn restart_active_mcp_servers<R: Runtime>( pub async fn restart_active_mcp_servers<R: Runtime>(
app: &AppHandle<R>, app: &AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>, servers_state: SharedMcpServers,
) -> Result<(), String> { ) -> Result<(), String> {
let app_state = app.state::<AppState>(); let app_state = app.state::<AppState>();
let active_servers = app_state.mcp_active_servers.lock().await; let active_servers = app_state.mcp_active_servers.lock().await;
@ -656,14 +821,21 @@ pub async fn clean_up_mcp_servers(state: State<'_, AppState>) {
log::info!("MCP servers cleaned up successfully"); log::info!("MCP servers cleaned up successfully");
} }
pub async fn stop_mcp_servers( pub async fn stop_mcp_servers(servers_state: SharedMcpServers) -> Result<(), String> {
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
) -> Result<(), String> {
let mut servers_map = servers_state.lock().await; let mut servers_map = servers_state.lock().await;
let keys: Vec<String> = servers_map.keys().cloned().collect(); let keys: Vec<String> = servers_map.keys().cloned().collect();
for key in keys { for key in keys {
if let Some(service) = servers_map.remove(&key) { if let Some(service) = servers_map.remove(&key) {
service.cancel().await.map_err(|e| e.to_string())?; match service {
RunningServiceEnum::NoInit(service) => {
log::info!("Stopping server {key}...");
service.cancel().await.map_err(|e| e.to_string())?;
}
RunningServiceEnum::WithInit(service) => {
log::info!("Stopping server {key} with initialization...");
service.cancel().await.map_err(|e| e.to_string())?;
}
}
} }
} }
drop(servers_map); // Release the lock after stopping drop(servers_map); // Release the lock after stopping
@ -689,7 +861,7 @@ pub async fn reset_restart_count(restart_counts: &Arc<Mutex<HashMap<String, u32>
/// Spawn the server monitoring task for handling restarts /// Spawn the server monitoring task for handling restarts
pub async fn spawn_server_monitoring_task<R: Runtime>( pub async fn spawn_server_monitoring_task<R: Runtime>(
app: AppHandle<R>, app: AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>, servers_state: SharedMcpServers,
name: String, name: String,
config: Value, config: Value,
max_restarts: u32, max_restarts: u32,

View File

@ -1,6 +1,7 @@
pub mod commands; pub mod commands;
mod constants; mod constants;
pub mod helpers; pub mod helpers;
pub mod models;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;

View File

@ -0,0 +1,26 @@
use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Configuration parameters extracted from MCP server config
#[derive(Debug, Clone)]
pub struct McpServerConfig {
pub transport_type: Option<String>,
pub url: Option<String>,
pub command: String,
pub args: Vec<Value>,
pub envs: serde_json::Map<String, Value>,
pub timeout: Option<Duration>,
pub headers: serde_json::Map<String, Value>,
}
/// Tool with server information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolWithServer {
pub name: String,
pub description: Option<String>,
#[serde(rename = "inputSchema")]
pub input_schema: serde_json::Value,
pub server: String,
}

View File

@ -1,6 +1,6 @@
use super::helpers::run_mcp_commands; use super::helpers::run_mcp_commands;
use crate::core::app::commands::get_jan_data_folder_path; use crate::core::app::commands::get_jan_data_folder_path;
use rmcp::{service::RunningService, RoleClient}; use crate::core::state::SharedMcpServers;
use std::collections::HashMap; use std::collections::HashMap;
use std::fs::File; use std::fs::File;
use std::io::Write; use std::io::Write;
@ -27,7 +27,7 @@ async fn test_run_mcp_commands() {
.expect("Failed to write to config file"); .expect("Failed to write to config file");
// Call the run_mcp_commands function // Call the run_mcp_commands function
let servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>> = let servers_state: SharedMcpServers =
Arc::new(Mutex::new(HashMap::new())); Arc::new(Mutex::new(HashMap::new()));
let result = run_mcp_commands(app.handle(), servers_state).await; let result = run_mcp_commands(app.handle(), servers_state).await;

View File

@ -1,20 +1,49 @@
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use crate::core::downloads::models::DownloadManagerState; use crate::core::downloads::models::DownloadManagerState;
use rmcp::{service::RunningService, RoleClient}; use rmcp::{
model::{CallToolRequestParam, CallToolResult, InitializeRequestParam, Tool},
service::RunningService,
RoleClient, ServiceError,
};
use tokio::sync::{Mutex, oneshot};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
/// Server handle type for managing the proxy server lifecycle /// Server handle type for managing the proxy server lifecycle
pub type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>; pub type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>;
use tokio::sync::Mutex;
pub enum RunningServiceEnum {
NoInit(RunningService<RoleClient, ()>),
WithInit(RunningService<RoleClient, InitializeRequestParam>),
}
pub type SharedMcpServers = Arc<Mutex<HashMap<String, RunningServiceEnum>>>;
#[derive(Default)] #[derive(Default)]
pub struct AppState { pub struct AppState {
pub app_token: Option<String>, pub app_token: Option<String>,
pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>, pub mcp_servers: SharedMcpServers,
pub download_manager: Arc<Mutex<DownloadManagerState>>, pub download_manager: Arc<Mutex<DownloadManagerState>>,
pub mcp_restart_counts: Arc<Mutex<HashMap<String, u32>>>, pub mcp_restart_counts: Arc<Mutex<HashMap<String, u32>>>,
pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>, pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>,
pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>, pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>,
pub server_handle: Arc<Mutex<Option<ServerHandle>>>, pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
pub tool_call_cancellations: Arc<Mutex<HashMap<String, oneshot::Sender<()>>>>,
}
impl RunningServiceEnum {
pub async fn list_all_tools(&self) -> Result<Vec<Tool>, ServiceError> {
match self {
Self::NoInit(s) => s.list_all_tools().await,
Self::WithInit(s) => s.list_all_tools().await,
}
}
pub async fn call_tool(
&self,
params: CallToolRequestParam,
) -> Result<CallToolResult, ServiceError> {
match self {
Self::NoInit(s) => s.call_tool(params).await,
Self::WithInit(s) => s.call_tool(params).await,
}
}
} }

View File

@ -74,6 +74,7 @@ pub fn run() {
// MCP commands // MCP commands
core::mcp::commands::get_tools, core::mcp::commands::get_tools,
core::mcp::commands::call_tool, core::mcp::commands::call_tool,
core::mcp::commands::cancel_tool_call,
core::mcp::commands::restart_mcp_servers, core::mcp::commands::restart_mcp_servers,
core::mcp::commands::get_connected_servers, core::mcp::commands::get_connected_servers,
core::mcp::commands::save_mcp_configs, core::mcp::commands::save_mcp_configs,
@ -105,6 +106,7 @@ pub fn run() {
mcp_active_servers: Arc::new(Mutex::new(HashMap::new())), mcp_active_servers: Arc::new(Mutex::new(HashMap::new())),
mcp_successfully_connected: Arc::new(Mutex::new(HashMap::new())), mcp_successfully_connected: Arc::new(Mutex::new(HashMap::new())),
server_handle: Arc::new(Mutex::new(None)), server_handle: Arc::new(Mutex::new(None)),
tool_call_cancellations: Arc::new(Mutex::new(HashMap::new())),
}) })
.setup(|app| { .setup(|app| {
app.handle().plugin( app.handle().plugin(

View File

@ -35,7 +35,8 @@
"effects": ["fullScreenUI", "mica", "tabbed", "blur", "acrylic"], "effects": ["fullScreenUI", "mica", "tabbed", "blur", "acrylic"],
"state": "active", "state": "active",
"radius": 8 "radius": 8
} },
"dragDropEnabled": false
} }
], ],
"security": { "security": {

128
web-app/CONTRIBUTING.md Normal file
View File

@ -0,0 +1,128 @@
# Contributing to Jan Web App
[← Back to Main Contributing Guide](../CONTRIBUTING.md)
React frontend using TypeScript, TanStack Router, Radix UI, and Tailwind CSS. State is managed by React State and Zustand.
## Key Directories
- **`/src/components/ui`** - UI components (buttons, dialogs, inputs)
- **`/src/containers`** - Complex feature components (ChatInput, ThreadContent)
- **`/src/hooks`** - Custom React hooks (useChat, useThreads, useAppState)
- **`/src/routes`** - TanStack Router pages
- **`/src/services`** - API layer for backend communication
- **`/src/types`** - TypeScript definitions
## Development
### Component Example
```tsx
interface Props {
title: string
onAction?: () => void
}
export const MyComponent: React.FC<Props> = ({ title, onAction }) => {
return (
<div className="flex items-center gap-2">
<h2>{title}</h2>
<Button onClick={onAction}>Action</Button>
</div>
)
}
```
### Routing
```tsx
export const Route = createFileRoute('/settings/general')({
component: GeneralSettings
})
```
### Building & Testing
```bash
# Development
yarn dev
yarn build
yarn test
```
### State Management
```tsx
// Local state
const [value, setValue] = useState<string>('')
// Global state (Zustand)
export const useAppState = create<AppState>((set) => ({
data: null,
setData: (data) => set({ data })
}))
```
### Tauri Integration
```tsx
import { invoke } from '@tauri-apps/api/tauri'
const result = await invoke('command_name', { param: 'value' })
```
## Performance Tips
```tsx
// Use React.memo for expensive components
const ExpensiveComponent = React.memo(({ data }) => {
return <div>{processData(data)}</div>
})
// Debounce frequent updates
const debouncedValue = useDebounce(searchTerm, 300)
// Virtual scrolling for large lists
import { VariableSizeList } from 'react-window'
```
## Debugging
```bash
# React DevTools
# Install browser extension, then:
# - Inspect component tree
# - Debug hooks and state
# - Profile performance
# Debug Tauri commands
console.log(await window.__TAURI__.invoke('command_name'))
# Check for console errors
# Press F12 → Console tab
```
## Accessibility Guidelines
- Use semantic HTML (`<button>`, `<nav>`, `<main>`)
- Add ARIA labels: `aria-label`, `aria-describedby`
- Ensure keyboard navigation works
- Test with screen readers
- Maintain color contrast ratios
## Best Practices
- Keep components small and focused
- Use TypeScript fully (no `any`)
- Handle loading and error states
- Follow accessibility guidelines
- Extract business logic into hooks
## Dependencies
- **React** - UI framework
- **TypeScript** - Type safety
- **TanStack Router** - Type-safe routing
- **Radix UI** - Accessible components
- **Tailwind CSS** - Utility-first styling
- **Zustand** - State management

View File

@ -17,11 +17,12 @@
"@dnd-kit/sortable": "^10.0.0", "@dnd-kit/sortable": "^10.0.0",
"@janhq/core": "link:../core", "@janhq/core": "link:../core",
"@radix-ui/react-accordion": "^1.2.10", "@radix-ui/react-accordion": "^1.2.10",
"@radix-ui/react-dialog": "^1.1.11", "@radix-ui/react-dialog": "^1.1.14",
"@radix-ui/react-dropdown-menu": "^2.1.11", "@radix-ui/react-dropdown-menu": "^2.1.15",
"@radix-ui/react-hover-card": "^1.1.14", "@radix-ui/react-hover-card": "^1.1.14",
"@radix-ui/react-popover": "^1.1.13", "@radix-ui/react-popover": "^1.1.13",
"@radix-ui/react-progress": "^1.1.4", "@radix-ui/react-progress": "^1.1.4",
"@radix-ui/react-radio-group": "^1.3.7",
"@radix-ui/react-slider": "^1.3.2", "@radix-ui/react-slider": "^1.3.2",
"@radix-ui/react-slot": "^1.2.0", "@radix-ui/react-slot": "^1.2.0",
"@radix-ui/react-switch": "^1.2.2", "@radix-ui/react-switch": "^1.2.2",
@ -43,13 +44,14 @@
"class-variance-authority": "^0.7.1", "class-variance-authority": "^0.7.1",
"culori": "^4.0.1", "culori": "^4.0.1",
"emoji-picker-react": "^4.12.2", "emoji-picker-react": "^4.12.2",
"framer-motion": "^12.23.12",
"fuse.js": "^7.1.0", "fuse.js": "^7.1.0",
"fzf": "^0.5.2", "fzf": "^0.5.2",
"i18next": "^25.0.1", "i18next": "^25.0.1",
"katex": "^0.16.22", "katex": "^0.16.22",
"lodash.clonedeep": "^4.5.0", "lodash.clonedeep": "^4.5.0",
"lodash.debounce": "^4.0.8", "lodash.debounce": "^4.0.8",
"lucide-react": "^0.522.0", "lucide-react": "^0.536.0",
"motion": "^12.10.5", "motion": "^12.10.5",
"next-themes": "^0.4.6", "next-themes": "^0.4.6",
"posthog-js": "^1.246.0", "posthog-js": "^1.246.0",
@ -70,11 +72,12 @@
"remark-math": "^6.0.0", "remark-math": "^6.0.0",
"sonner": "^2.0.3", "sonner": "^2.0.3",
"tailwindcss": "^4.1.4", "tailwindcss": "^4.1.4",
"token.js": "npm:token.js-fork@0.7.20", "token.js": "npm:token.js-fork@0.7.23",
"tw-animate-css": "^1.2.7", "tw-animate-css": "^1.2.7",
"ulidx": "^2.4.1", "ulidx": "^2.4.1",
"unified": "^11.0.5", "unified": "^11.0.5",
"uuid": "^11.1.0", "uuid": "^11.1.0",
"vaul": "^1.1.2",
"zustand": "^5.0.3" "zustand": "^5.0.3"
}, },
"devDependencies": { "devDependencies": {
@ -104,7 +107,7 @@
"istanbul-lib-report": "^3.0.1", "istanbul-lib-report": "^3.0.1",
"istanbul-reports": "^3.1.7", "istanbul-reports": "^3.1.7",
"jsdom": "^26.1.0", "jsdom": "^26.1.0",
"tailwind-merge": "^3.2.0", "tailwind-merge": "^3.3.1",
"typescript": "~5.8.3", "typescript": "~5.8.3",
"typescript-eslint": "^8.26.1", "typescript-eslint": "^8.26.1",
"vite": "^6.3.0", "vite": "^6.3.0",

View File

@ -0,0 +1,533 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import '@testing-library/jest-dom'
import {
DropDrawer,
DropDrawerContent,
DropDrawerFooter,
DropDrawerGroup,
DropDrawerItem,
DropDrawerLabel,
DropDrawerSeparator,
DropDrawerSub,
DropDrawerSubContent,
DropDrawerSubTrigger,
DropDrawerTrigger,
} from '../dropdrawer'
// Mock the media query hook
// (controls whether DropDrawer renders in mobile/drawer or desktop/dropdown mode)
const mockUseSmallScreen = vi.fn()
vi.mock('@/hooks/useMediaQuery', () => ({
  useSmallScreen: () => mockUseSmallScreen(),
}))
// Mock framer-motion to avoid animation complexity in tests
// (AnimatePresence becomes a plain div; motion.div drops animation props)
vi.mock('framer-motion', () => ({
  AnimatePresence: ({ children }: { children: React.ReactNode }) => <div data-testid="animate-presence">{children}</div>,
  motion: {
    div: ({ children, ...props }: any) => <div {...props}>{children}</div>,
  },
}))
// Smoke test: a minimal trigger + content + item composition mounts cleanly.
describe('DropDrawer Utilities', () => {
  it('renders without crashing', () => {
    expect(() => {
      render(
        <DropDrawer>
          <DropDrawerTrigger>Test</DropDrawerTrigger>
          <DropDrawerContent>
            <DropDrawerItem>Item</DropDrawerItem>
          </DropDrawerContent>
        </DropDrawer>
      )
    }).not.toThrow()
  })
})
describe('DropDrawer Component', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('Desktop Mode', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(false)
})
it('renders dropdown menu on desktop', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Item 1</DropDrawerItem>
<DropDrawerItem>Item 2</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('renders dropdown menu structure', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Desktop Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Only the trigger is visible initially
expect(screen.getByText('Open Menu')).toBeInTheDocument()
expect(screen.getByRole('button')).toHaveAttribute('aria-haspopup', 'menu')
})
it('structures dropdown with separators', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Item 1</DropDrawerItem>
<DropDrawerSeparator />
<DropDrawerItem>Item 2</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Verify component structure - content is not visible until opened
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('structures dropdown with labels', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerLabel>Menu Section</DropDrawerLabel>
<DropDrawerItem>Item 1</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Only verify trigger is present - content shows on interaction
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
})
describe('Mobile Mode', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(true)
})
it('renders drawer on mobile', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Mobile Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
it('renders drawer structure', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Mobile Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Verify drawer trigger is present
const trigger = screen.getByText('Open Drawer')
expect(trigger).toBeInTheDocument()
expect(screen.getByRole('button')).toHaveAttribute('aria-haspopup', 'dialog')
})
it('does not render separators in mobile mode', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Item 1</DropDrawerItem>
<DropDrawerSeparator />
<DropDrawerItem>Item 2</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Mobile separators return null, so they shouldn't be in the DOM
const separators = screen.queryAllByRole('separator')
expect(separators).toHaveLength(0)
})
it('renders drawer with labels structure', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerLabel>Drawer Section</DropDrawerLabel>
<DropDrawerItem>Item 1</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Verify drawer structure is present
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
})
describe('DropDrawerItem', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(false)
})
it('can be structured with click handlers', () => {
const handleClick = vi.fn()
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem onClick={handleClick}>Clickable Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Verify structure is valid
expect(screen.getByText('Open Menu')).toBeInTheDocument()
expect(handleClick).not.toHaveBeenCalled()
})
it('can be structured with icons', () => {
const TestIcon = () => <span data-testid="test-icon">Icon</span>
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem icon={<TestIcon />}>Item with Icon</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Structure is valid
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('accepts variant props', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem variant="destructive">
Delete Item
</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid with variants
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('accepts disabled prop', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem disabled>
Disabled Item
</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid with disabled prop
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
})
describe('DropDrawerGroup', () => {
it('structures groups in desktop mode', () => {
mockUseSmallScreen.mockReturnValue(false)
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerGroup>
<DropDrawerItem>Group Item 1</DropDrawerItem>
<DropDrawerItem>Group Item 2</DropDrawerItem>
</DropDrawerGroup>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('structures groups in mobile mode', () => {
mockUseSmallScreen.mockReturnValue(true)
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerGroup>
<DropDrawerItem>Item 1</DropDrawerItem>
<DropDrawerItem>Item 2</DropDrawerItem>
</DropDrawerGroup>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid in mobile mode
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
})
describe('DropDrawerFooter', () => {
it('structures footer in desktop mode', () => {
mockUseSmallScreen.mockReturnValue(false)
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Item</DropDrawerItem>
<DropDrawerFooter>Footer Content</DropDrawerFooter>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('structures footer in mobile mode', () => {
mockUseSmallScreen.mockReturnValue(true)
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Item</DropDrawerItem>
<DropDrawerFooter>Mobile Footer</DropDrawerFooter>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid in mobile mode
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
})
describe('Submenu Components', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(false)
})
it('structures submenu in desktop mode', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerSub>
<DropDrawerSubTrigger>Submenu Trigger</DropDrawerSubTrigger>
<DropDrawerSubContent>
<DropDrawerItem>Submenu Item</DropDrawerItem>
</DropDrawerSubContent>
</DropDrawerSub>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
it('structures submenu in mobile mode', () => {
mockUseSmallScreen.mockReturnValue(true)
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerSub>
<DropDrawerSubTrigger>
Mobile Submenu
</DropDrawerSubTrigger>
<DropDrawerSubContent>
<DropDrawerItem>Submenu Item</DropDrawerItem>
</DropDrawerSubContent>
</DropDrawerSub>
</DropDrawerContent>
</DropDrawer>
)
// Component structure is valid in mobile mode
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
it('handles submenu content correctly in mobile mode', () => {
mockUseSmallScreen.mockReturnValue(true)
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerSub>
<DropDrawerSubTrigger>Mobile Submenu</DropDrawerSubTrigger>
<DropDrawerSubContent>
<DropDrawerItem>Hidden Item</DropDrawerItem>
</DropDrawerSubContent>
</DropDrawerSub>
</DropDrawerContent>
</DropDrawer>
)
// Component handles mobile submenu structure correctly
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
})
})
describe('Accessibility', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(false)
})
it('maintains proper ARIA attributes on triggers', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerGroup>
<DropDrawerItem>Item 1</DropDrawerItem>
</DropDrawerGroup>
</DropDrawerContent>
</DropDrawer>
)
const trigger = screen.getByRole('button')
expect(trigger).toHaveAttribute('aria-haspopup', 'menu')
})
it('supports disabled state', () => {
const handleClick = vi.fn()
mockUseSmallScreen.mockReturnValue(true)
render(
<DropDrawer>
<DropDrawerTrigger>Open Drawer</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem disabled onClick={handleClick}>
Disabled Item
</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Component structure supports disabled prop
expect(screen.getByText('Open Drawer')).toBeInTheDocument()
expect(handleClick).not.toHaveBeenCalled()
})
})
describe('Error Boundaries', () => {
  // An item rendered with no surrounding <DropDrawer> must throw rather than
  // render in an undefined state.
  it('requires proper context usage', () => {
    // Suppress console.error for this test
    const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
    expect(() => {
      render(<DropDrawerItem>Orphan Item</DropDrawerItem>)
    }).toThrow()
    consoleSpy.mockRestore()
  })
})
describe('Custom Props and Styling', () => {
beforeEach(() => {
mockUseSmallScreen.mockReturnValue(false)
})
it('applies custom className', () => {
render(
<DropDrawer>
<DropDrawerTrigger className="custom-trigger">Custom Trigger</DropDrawerTrigger>
<DropDrawerContent className="custom-content">
<DropDrawerItem className="custom-item">Custom Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
const trigger = screen.getByText('Custom Trigger')
expect(trigger).toHaveClass('custom-trigger')
})
it('accepts additional props', () => {
render(
<DropDrawer>
<DropDrawerTrigger>Open Menu</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem data-custom="test-value">Custom Props Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Component structure accepts custom props
expect(screen.getByText('Open Menu')).toBeInTheDocument()
})
})
describe('Responsive Behavior', () => {
it('adapts to different screen sizes', () => {
const { rerender } = render(
<DropDrawer>
<DropDrawerTrigger>Responsive Trigger</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Responsive Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
// Desktop mode
mockUseSmallScreen.mockReturnValue(false)
rerender(
<DropDrawer>
<DropDrawerTrigger>Responsive Trigger</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Responsive Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
let trigger = screen.getByText('Responsive Trigger')
expect(trigger).toHaveAttribute('aria-haspopup', 'menu')
// Mobile mode
mockUseSmallScreen.mockReturnValue(true)
rerender(
<DropDrawer>
<DropDrawerTrigger>Responsive Trigger</DropDrawerTrigger>
<DropDrawerContent>
<DropDrawerItem>Responsive Item</DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
)
trigger = screen.getByText('Responsive Trigger')
expect(trigger).toHaveAttribute('aria-haspopup', 'dialog')
})
})
})

View File

@ -0,0 +1,62 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { RadioGroup, RadioGroupItem } from '../radio-group'
// Tests for the Radix-based RadioGroup wrapper: rendering, selection
// callback, and default-value behavior.
describe('RadioGroup', () => {
  it('renders radio items correctly', () => {
    render(
      <RadioGroup defaultValue="http">
        <div className="flex items-center space-x-2">
          <RadioGroupItem value="http" id="http" />
          <label htmlFor="http">HTTP</label>
        </div>
        <div className="flex items-center space-x-2">
          <RadioGroupItem value="sse" id="sse" />
          <label htmlFor="sse">SSE</label>
        </div>
      </RadioGroup>
    )
    expect(screen.getByLabelText('HTTP')).toBeInTheDocument()
    expect(screen.getByLabelText('SSE')).toBeInTheDocument()
  })
  it('allows selecting different options', async () => {
    const user = userEvent.setup()
    const onValueChange = vi.fn()
    render(
      <RadioGroup defaultValue="http" onValueChange={onValueChange}>
        <div className="flex items-center space-x-2">
          <RadioGroupItem value="http" id="http" />
          <label htmlFor="http">HTTP</label>
        </div>
        <div className="flex items-center space-x-2">
          <RadioGroupItem value="sse" id="sse" />
          <label htmlFor="sse">SSE</label>
        </div>
      </RadioGroup>
    )
    // Clicking the second option reports its value to the callback.
    await user.click(screen.getByLabelText('SSE'))
    expect(onValueChange).toHaveBeenCalledWith('sse')
  })
  it('has correct default selection', () => {
    render(
      <RadioGroup defaultValue="http">
        <div className="flex items-center space-x-2">
          <RadioGroupItem value="http" id="http" />
          <label htmlFor="http">HTTP</label>
        </div>
        <div className="flex items-center space-x-2">
          <RadioGroupItem value="sse" id="sse" />
          <label htmlFor="sse">SSE</label>
        </div>
      </RadioGroup>
    )
    // Only the defaultValue item starts checked.
    expect(screen.getByLabelText('HTTP')).toBeChecked()
    expect(screen.getByLabelText('SSE')).not.toBeChecked()
  })
})

View File

@ -0,0 +1,133 @@
import * as React from 'react'
import { Drawer as DrawerPrimitive } from 'vaul'
import { cn } from '@/lib/utils'
function Drawer({
...props
}: React.ComponentProps<typeof DrawerPrimitive.Root>) {
return <DrawerPrimitive.Root data-slot="drawer" {...props} />
}
function DrawerTrigger({
...props
}: React.ComponentProps<typeof DrawerPrimitive.Trigger>) {
return <DrawerPrimitive.Trigger data-slot="drawer-trigger" {...props} />
}
function DrawerPortal({
...props
}: React.ComponentProps<typeof DrawerPrimitive.Portal>) {
return <DrawerPrimitive.Portal data-slot="drawer-portal" {...props} />
}
function DrawerClose({
...props
}: React.ComponentProps<typeof DrawerPrimitive.Close>) {
return <DrawerPrimitive.Close data-slot="drawer-close" {...props} />
}
// Dimmed, blurred backdrop behind the drawer; fades in/out with open state.
const DrawerOverlay = ({
  className,
  ...rest
}: React.ComponentProps<typeof DrawerPrimitive.Overlay>) => {
  const overlayClasses = cn(
    'data-[state=open]:animate-in backdrop-blur data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 fixed inset-0 z-50 bg-black/50',
    className
  )
  return (
    <DrawerPrimitive.Overlay
      data-slot="drawer-overlay"
      className={overlayClasses}
      {...rest}
    />
  )
}
/**
 * Drawer panel rendered in a portal above the dimmed overlay.
 * Position and sizing are selected by vaul's `data-vaul-drawer-direction`
 * attribute (top / bottom / left / right); a drag handle is shown only for
 * bottom drawers.
 */
function DrawerContent({
  className,
  children,
  ...props
}: React.ComponentProps<typeof DrawerPrimitive.Content>) {
  return (
    <DrawerPortal data-slot="drawer-portal">
      <DrawerOverlay />
      <DrawerPrimitive.Content
        data-slot="drawer-content"
        className={cn(
          'group/drawer-content fixed z-50 flex h-auto flex-col',
          // per-direction placement, size caps, and edge border/rounding
          'data-[vaul-drawer-direction=top]:inset-x-0 data-[vaul-drawer-direction=top]:top-0 data-[vaul-drawer-direction=top]:mb-24 data-[vaul-drawer-direction=top]:max-h-[80vh] data-[vaul-drawer-direction=top]:rounded-b-lg data-[vaul-drawer-direction=top]:border-b',
          'data-[vaul-drawer-direction=bottom]:inset-x-0 data-[vaul-drawer-direction=bottom]:bottom-0 data-[vaul-drawer-direction=bottom]:mt-24 data-[vaul-drawer-direction=bottom]:max-h-[80vh] data-[vaul-drawer-direction=bottom]:rounded-t-lg data-[vaul-drawer-direction=bottom]:border-t',
          'data-[vaul-drawer-direction=right]:inset-y-0 data-[vaul-drawer-direction=right]:right-0 data-[vaul-drawer-direction=right]:w-3/4 data-[vaul-drawer-direction=right]:border-l data-[vaul-drawer-direction=right]:sm:max-w-sm',
          'data-[vaul-drawer-direction=left]:inset-y-0 data-[vaul-drawer-direction=left]:left-0 data-[vaul-drawer-direction=left]:w-3/4 data-[vaul-drawer-direction=left]:border-r data-[vaul-drawer-direction=left]:sm:max-w-sm',
          className
        )}
        {...props}
      >
        {/* drag handle — hidden unless the drawer direction is "bottom" */}
        <div className="bg-main-view-fg/10 mx-auto mt-4 hidden h-2 w-[100px] shrink-0 rounded-full group-data-[vaul-drawer-direction=bottom]/drawer-content:block" />
        {children}
      </DrawerPrimitive.Content>
    </DrawerPortal>
  )
}
function DrawerHeader({ className, ...props }: React.ComponentProps<'div'>) {
return (
<div
data-slot="drawer-header"
className={cn(
'flex flex-col gap-0.5 p-4 group-data-[vaul-drawer-direction=bottom]/drawer-content:text-center group-data-[vaul-drawer-direction=top]/drawer-content:text-center md:gap-1.5 md:text-left',
className
)}
{...props}
/>
)
}
// Footer region pinned to the bottom of the drawer via mt-auto.
const DrawerFooter = ({ className, ...rest }: React.ComponentProps<'div'>) => (
  <div
    data-slot="drawer-footer"
    className={cn('mt-auto flex flex-col gap-2 p-4', className)}
    {...rest}
  />
)
// Accessible drawer title (vaul/Radix wires it to aria-labelledby).
const DrawerTitle = ({
  className,
  ...rest
}: React.ComponentProps<typeof DrawerPrimitive.Title>) => (
  <DrawerPrimitive.Title
    data-slot="drawer-title"
    className={cn('font-semibold', className)}
    {...rest}
  />
)
// Accessible drawer description (vaul/Radix wires it to aria-describedby).
const DrawerDescription = ({
  className,
  ...rest
}: React.ComponentProps<typeof DrawerPrimitive.Description>) => (
  <DrawerPrimitive.Description
    data-slot="drawer-description"
    className={cn('text-sm', className)}
    {...rest}
  />
)
export {
Drawer,
DrawerPortal,
DrawerOverlay,
DrawerTrigger,
DrawerClose,
DrawerContent,
DrawerHeader,
DrawerFooter,
DrawerTitle,
DrawerDescription,
}

View File

@ -61,14 +61,17 @@ function DropdownMenuGroup({
function DropdownMenuItem({ function DropdownMenuItem({
className, className,
inset, inset,
variant = 'default',
...props ...props
}: React.ComponentProps<typeof DropdownMenuPrimitive.Item> & { }: React.ComponentProps<typeof DropdownMenuPrimitive.Item> & {
inset?: boolean inset?: boolean
variant?: 'default' | 'destructive'
}) { }) {
return ( return (
<DropdownMenuPrimitive.Item <DropdownMenuPrimitive.Item
data-slot="dropdown-menu-item" data-slot="dropdown-menu-item"
data-inset={inset} data-inset={inset}
data-variant={variant}
className={cn( className={cn(
"relative cursor-pointer hover:bg-main-view-fg/4 flex items-center gap-2 rounded-sm px-2 py-1 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[inset]:pl-8 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4", "relative cursor-pointer hover:bg-main-view-fg/4 flex items-center gap-2 rounded-sm px-2 py-1 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[inset]:pl-8 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4",
className className
@ -213,7 +216,7 @@ function DropdownMenuSubTrigger({
{...props} {...props}
> >
{children} {children}
<ChevronRightIcon className="ml-auto size-4" /> <ChevronRightIcon className="ml-auto size-4 text-main-view-fg/50" />
</DropdownMenuPrimitive.SubTrigger> </DropdownMenuPrimitive.SubTrigger>
) )
} }

View File

@ -0,0 +1,949 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
'use client'
import { AnimatePresence, motion } from 'framer-motion'
import { ChevronLeftIcon, ChevronRightIcon } from 'lucide-react'
import * as React from 'react'
import {
Drawer,
DrawerClose,
DrawerContent,
DrawerFooter,
DrawerHeader,
DrawerTitle,
DrawerTrigger,
} from '@/components/ui/drawer'
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuSeparator,
DropdownMenuSub,
DropdownMenuSubContent,
DropdownMenuSubTrigger,
DropdownMenuTrigger,
} from '@/components/ui/dropdown-menu'
import { cn } from '@/lib/utils'
import { useSmallScreen } from '@/hooks/useMediaQuery'
// Shared slide/fade settings for mobile submenu transitions.
// The direction is passed through framer-motion's `custom` prop:
// "forward" slides the incoming panel in from the right (and the outgoing
// one out to the left); "backward" does the mirror image.
const ANIMATION_CONFIG = {
  variants: {
    enter: (direction: 'forward' | 'backward') => ({
      x: direction === 'forward' ? '100%' : '-100%',
      opacity: 0,
    }),
    center: {
      x: 0,
      opacity: 1,
    },
    exit: (direction: 'forward' | 'backward') => ({
      x: direction === 'forward' ? '-100%' : '100%',
      opacity: 0,
    }),
  },
  // 300ms with the standard "ease" cubic-bezier curve.
  transition: {
    duration: 0.3,
    ease: [0.25, 0.1, 0.25, 1.0],
  },
} as const
/**
 * Builds the Tailwind class list for a mobile drawer menu item.
 * Items inside a DropDrawerGroup render flush and transparent; standalone
 * items get a card-like background with margin and rounded corners.
 */
const getMobileItemStyles = (
  isInsideGroup: boolean,
  inset?: boolean,
  variant?: string,
  disabled?: boolean
) => {
  const layout =
    'flex cursor-pointer items-center justify-between px-4 py-4 w-full gap-4'
  const grouping = isInsideGroup
    ? 'bg-transparent py-4'
    : 'bg-main-view-fg/50 mx-2 my-1.5 rounded-md'
  return cn(
    layout,
    grouping,
    inset && 'pl-8',
    variant === 'destructive' && 'text-destructive',
    disabled && 'pointer-events-none opacity-50'
  )
}
// Context carrying the mobile-vs-desktop flag for all DropDrawer* parts.
// The default is `null` (not a dummy `{ isMobile: false }` object) so that
// the guard in `useDropDrawerContext` can actually detect components
// rendered outside a <DropDrawer> provider — with a truthy default the
// `!context` check would be dead code and could never fire.
const DropDrawerContext = React.createContext<{ isMobile: boolean } | null>(
  null
)

/**
 * Returns the current DropDrawer context value.
 * @throws Error when called outside a <DropDrawer> provider.
 */
const useDropDrawerContext = () => {
  const context = React.useContext(DropDrawerContext)
  if (!context) {
    throw new Error(
      'DropDrawer components cannot be rendered outside the DropDrawer Context'
    )
  }
  return context
}
/**
 * Convenience hook: exposes the context's `isMobile` flag together with a
 * helper that picks between a mobile and a desktop implementation.
 */
const useComponentSelection = () => {
  const { isMobile } = useDropDrawerContext()
  return {
    isMobile,
    selectComponent: <T, D>(mobileComponent: T, desktopComponent: D) =>
      isMobile ? mobileComponent : desktopComponent,
  }
}
// Hook factory for detecting whether a mobile menu item sits inside a
// DropDrawerGroup (group membership changes the item's styling).
const useGroupDetection = () => {
  // Walks up the DOM from `element` looking for an ancestor carrying the
  // `data-drop-drawer-group` marker attribute.
  const isInGroup = React.useCallback(
    (element: HTMLElement | null): boolean => {
      if (!element) return false
      let parent = element.parentElement
      while (parent) {
        if (parent.hasAttribute('data-drop-drawer-group')) {
          return true
        }
        parent = parent.parentElement
      }
      return false
    },
    []
  )
  // Per-item hook: attach `itemRef` to the item's root element and
  // `isInsideGroup` flips to true once the deferred DOM check runs.
  const useGroupState = () => {
    const { isMobile } = useComponentSelection()
    const itemRef = React.useRef<HTMLDivElement>(null)
    const [isInsideGroup, setIsInsideGroup] = React.useState(false)
    React.useEffect(() => {
      // Desktop items never use group styling, so skip the DOM walk.
      if (!isMobile) return
      // setTimeout(0) defers the check until after the drawer DOM mounts.
      const timer = setTimeout(() => {
        if (itemRef.current) {
          setIsInsideGroup(isInGroup(itemRef.current))
        }
      }, 0)
      return () => clearTimeout(timer)
      // NOTE(review): `isInGroup` is omitted from the deps array — safe in
      // practice because it is a useCallback with an empty dep list, but a
      // lint rule would ask for it; confirm intentional.
    }, [isMobile])
    return { itemRef, isInsideGroup }
  }
  return { isInGroup, useGroupState }
}
// Props common to both the mobile and the desktop variant of a component.
type ConditionalComponentProps<T, D> = {
  children: React.ReactNode
  className?: string
} & (T | D)

// Renders either `mobileComponent` or `desktopComponent` depending on the
// DropDrawer context's `isMobile` flag, forwarding all remaining props.
const ConditionalComponent = <T, D>({
  mobileComponent,
  desktopComponent,
  children,
  ...props
}: {
  mobileComponent: React.ComponentType<any>
  desktopComponent: React.ComponentType<any>
  children: React.ReactNode
} & ConditionalComponentProps<T, D>) => {
  const { selectComponent } = useComponentSelection()
  const Component = selectComponent(mobileComponent, desktopComponent)
  return <Component {...props}>{children}</Component>
}
/**
 * Responsive menu root: renders a bottom Drawer on small screens and a
 * DropdownMenu on desktop, and provides the `isMobile` flag to all nested
 * DropDrawer* components via context.
 */
function DropDrawer({
  children,
  ...props
}:
  | React.ComponentProps<typeof Drawer>
  | React.ComponentProps<typeof DropdownMenu>) {
  const isMobile = useSmallScreen()
  return (
    <DropDrawerContext.Provider value={{ isMobile }}>
      <ConditionalComponent
        mobileComponent={Drawer}
        desktopComponent={DropdownMenu}
        data-slot="drop-drawer"
        {...props}
      >
        {children}
      </ConditionalComponent>
    </DropDrawerContext.Provider>
  )
}
/**
 * Opens the menu: DrawerTrigger on mobile, DropdownMenuTrigger on desktop.
 */
function DropDrawerTrigger({
  className,
  children,
  ...props
}:
  | React.ComponentProps<typeof DrawerTrigger>
  | React.ComponentProps<typeof DropdownMenuTrigger>) {
  return (
    <ConditionalComponent
      mobileComponent={DrawerTrigger}
      desktopComponent={DropdownMenuTrigger}
      data-slot="drop-drawer-trigger"
      className={className}
      {...props}
    >
      {children}
    </ConditionalComponent>
  )
}
/**
 * Menu body. On desktop this is a plain DropdownMenuContent. On mobile it
 * renders DrawerContent with in-drawer submenu navigation: panels slide
 * in/out via framer-motion and a header back button walks a stack of
 * visited submenus.
 *
 * NOTE(review): `SubmenuContext`, `DropDrawerSub`, and
 * `DropDrawerSubContent` are referenced here but declared elsewhere in
 * this file.
 */
function DropDrawerContent({
  className,
  children,
  ...props
}:
  | React.ComponentProps<typeof DrawerContent>
  | React.ComponentProps<typeof DropdownMenuContent>) {
  const { isMobile } = useDropDrawerContext()
  // Currently shown submenu id (null = main menu) and its display title.
  const [activeSubmenu, setActiveSubmenu] = React.useState<string | null>(null)
  const [submenuTitle, setSubmenuTitle] = React.useState<string | null>(null)
  // Breadcrumb stack of visited submenus, consumed by the back button.
  const [submenuStack, setSubmenuStack] = React.useState<
    { id: string; title: string }[]
  >([])
  // Add animation direction state
  const [animationDirection, setAnimationDirection] = React.useState<
    'forward' | 'backward'
  >('forward')
  // Create a ref to store submenu content by ID
  const submenuContentRef = React.useRef<Map<string, React.ReactNode[]>>(
    new Map()
  )
  // Function to navigate to a submenu
  const navigateToSubmenu = React.useCallback((id: string, title: string) => {
    // Set animation direction to forward when navigating to a submenu
    setAnimationDirection('forward')
    setActiveSubmenu(id)
    setSubmenuTitle(title)
    setSubmenuStack((prev) => [...prev, { id, title }])
  }, [])
  // Function to go back to previous menu
  const goBack = React.useCallback(() => {
    // Set animation direction to backward when going back
    setAnimationDirection('backward')
    if (submenuStack.length <= 1) {
      // If we're at the first level, go back to main menu
      setActiveSubmenu(null)
      setSubmenuTitle(null)
      setSubmenuStack([])
    } else {
      // Go back to previous submenu
      const newStack = [...submenuStack]
      newStack.pop() // Remove current
      const previous = newStack[newStack.length - 1]
      setActiveSubmenu(previous.id)
      setSubmenuTitle(previous.title)
      setSubmenuStack(newStack)
    }
  }, [submenuStack])
  // Function to register submenu content
  const registerSubmenuContent = React.useCallback(
    (id: string, content: React.ReactNode[]) => {
      submenuContentRef.current.set(id, content)
    },
    []
  )
  // Recursively scans the children tree for the DropDrawerSub whose id (or
  // data-submenu-id) matches `targetId` and collects the children of its
  // DropDrawerSubContent.
  const extractSubmenuContent = React.useCallback(
    (elements: React.ReactNode, targetId: string): React.ReactNode[] => {
      const result: React.ReactNode[] = []
      const findSubmenuContent = (node: React.ReactNode) => {
        if (!React.isValidElement(node)) return
        const element = node as React.ReactElement
        const props = element.props as {
          'id'?: string
          'data-submenu-id'?: string
          'children'?: React.ReactNode
        }
        if (element.type === DropDrawerSub) {
          const elementId = props.id || props['data-submenu-id']
          if (elementId === targetId) {
            React.Children.forEach(props.children, (child) => {
              if (
                React.isValidElement(child) &&
                child.type === DropDrawerSubContent
              ) {
                const subContentProps = child.props as {
                  children?: React.ReactNode
                }
                React.Children.forEach(
                  subContentProps.children,
                  (contentChild) => {
                    result.push(contentChild)
                  }
                )
              }
            })
            return
          }
        }
        if (props.children) {
          React.Children.forEach(props.children, findSubmenuContent)
        }
      }
      React.Children.forEach(elements, findSubmenuContent)
      return result
    },
    []
  )
  // Get submenu content (always extract fresh to reflect state changes)
  const getSubmenuContent = React.useCallback(
    (id: string) => {
      // Always extract fresh content to ensure state changes are reflected
      const submenuContent = extractSubmenuContent(children, id)
      return submenuContent
    },
    [children, extractSubmenuContent]
  )
  if (isMobile) {
    return (
      <SubmenuContext.Provider
        value={{
          activeSubmenu,
          // External setters may only reset to the main menu; forward
          // navigation must go through navigateToSubmenu.
          setActiveSubmenu: (id) => {
            if (id === null) {
              setActiveSubmenu(null)
              setSubmenuTitle(null)
              setSubmenuStack([])
            }
          },
          submenuTitle,
          setSubmenuTitle,
          navigateToSubmenu,
          registerSubmenuContent,
        }}
      >
        <DrawerContent
          data-slot="drop-drawer-content"
          className={cn(
            'max-h-[90vh] w-full overflow-hidden max-w-none',
            className
          )}
          {...props}
        >
          {activeSubmenu ? (
            <>
              {/* Submenu view: back button + title header, then the panel */}
              <DrawerHeader>
                <div className="flex items-center gap-2">
                  <button
                    onClick={goBack}
                    className="hover:bg-muted/50 rounded-full "
                  >
                    <ChevronLeftIcon className="h-5 w-5 text-main-view-fg/50" />
                  </button>
                  <DrawerTitle className="text-main-view-fg/80 text-sm">
                    {submenuTitle || 'Submenu'}
                  </DrawerTitle>
                </div>
              </DrawerHeader>
              <div className="flex-1 relative overflow-hidden max-h-[70vh]">
                {/* Use AnimatePresence to handle exit animations */}
                <AnimatePresence
                  initial={false}
                  mode="wait"
                  custom={animationDirection}
                >
                  <motion.div
                    key={activeSubmenu || 'main'}
                    custom={animationDirection}
                    variants={ANIMATION_CONFIG.variants}
                    initial="enter"
                    animate="center"
                    exit="exit"
                    transition={ANIMATION_CONFIG.transition as any}
                    className="pb-6 space-y-1.5 w-full h-full overflow-hidden"
                  >
                    {activeSubmenu
                      ? getSubmenuContent(activeSubmenu)
                      : children}
                  </motion.div>
                </AnimatePresence>
              </div>
            </>
          ) : (
            <>
              {/* Main menu view; header is screen-reader-only */}
              <DrawerHeader className="sr-only">
                <DrawerTitle>Menu</DrawerTitle>
              </DrawerHeader>
              <div className="overflow-hidden max-h-[70vh]">
                <AnimatePresence
                  initial={false}
                  mode="wait"
                  custom={animationDirection}
                >
                  <motion.div
                    key="main-menu"
                    custom={animationDirection}
                    variants={ANIMATION_CONFIG.variants}
                    initial="enter"
                    animate="center"
                    exit="exit"
                    transition={ANIMATION_CONFIG.transition as any}
                    className="pb-6 space-y-1.5 w-full overflow-hidden"
                  >
                    {children}
                  </motion.div>
                </AnimatePresence>
              </div>
            </>
          )}
        </DrawerContent>
      </SubmenuContext.Provider>
    )
  }
  // Desktop: plain dropdown; Radix handles submenus natively, so the
  // context only passes the state through unmodified.
  return (
    <SubmenuContext.Provider
      value={{
        activeSubmenu,
        setActiveSubmenu,
        submenuTitle,
        setSubmenuTitle,
        navigateToSubmenu,
        registerSubmenuContent,
      }}
    >
      <DropdownMenuContent
        data-slot="drop-drawer-content"
        sideOffset={4}
        className={cn(
          'max-h-[var(--radix-dropdown-menu-content-available-height)] min-w-[220px] overflow-hidden',
          className
        )}
        {...props}
      >
        {children}
      </DropdownMenuContent>
    </SubmenuContext.Provider>
  )
}
/**
 * Single menu entry. Desktop delegates to DropdownMenuItem; mobile renders
 * a styled div (card-like unless inside a group) wrapped in DrawerClose so
 * tapping it dismisses the drawer — except submenu items, which must keep
 * the drawer open.
 *
 * @param icon optional trailing element (e.g. a Switch) rendered in its own
 *   click-isolated container on mobile.
 */
function DropDrawerItem({
  className,
  children,
  onSelect,
  onClick,
  icon,
  variant = 'default',
  inset,
  disabled,
  ...props
}: React.ComponentProps<typeof DropdownMenuItem> & {
  icon?: React.ReactNode
}) {
  const { isMobile } = useComponentSelection()
  const { useGroupState } = useGroupDetection()
  const { itemRef, isInsideGroup } = useGroupState()
  if (isMobile) {
    const handleClick = (e: React.MouseEvent<HTMLDivElement>) => {
      if (disabled) return
      // If this item only has an icon (like a switch) and no other interactive content,
      // don't handle clicks on the main area - let the icon handle everything
      if (icon && !onClick && !onSelect) {
        return
      }
      // Check if the click came from the icon area (where the Switch is)
      const target = e.target as HTMLElement
      const iconContainer = (e.currentTarget as HTMLElement).querySelector(
        '[data-icon-container]'
      )
      if (iconContainer && iconContainer.contains(target)) {
        // Don't handle the click if it came from the icon area
        return
      }
      if (onClick) onClick(e)
      if (onSelect) onSelect(e as unknown as Event)
    }
    // Only wrap in DrawerClose if it's not a submenu item
    const content = (
      <div
        ref={itemRef}
        data-slot="drop-drawer-item"
        data-variant={variant}
        data-inset={inset}
        data-disabled={disabled}
        className={cn(
          getMobileItemStyles(isInsideGroup, inset, variant, disabled),
          className
        )}
        onClick={handleClick}
        aria-disabled={disabled}
        {...props}
      >
        <div className="flex items-center gap-2">{children}</div>
        {icon && (
          <div className="flex-shrink-0" data-icon-container>
            {icon}
          </div>
        )}
      </div>
    )
    // Check if this is inside a submenu
    // (marker props are presumably injected by the submenu rendering path —
    // TODO confirm against DropDrawerSubContent)
    const isInSubmenu =
      (props as Record<string, unknown>)['data-parent-submenu-id'] ||
      (props as Record<string, unknown>)['data-parent-submenu']
    if (isInSubmenu) {
      return content
    }
    return <DrawerClose asChild>{content}</DrawerClose>
  }
  return (
    <DropdownMenuItem
      data-slot="drop-drawer-item"
      data-variant={variant}
      data-inset={inset}
      className={className}
      onSelect={onSelect}
      onClick={onClick as React.MouseEventHandler<HTMLDivElement>}
      variant={variant}
      inset={inset}
      disabled={disabled}
      {...props}
    >
      <div className="flex w-full items-start justify-between gap-4">
        <div>{children}</div>
        {icon && <div>{icon}</div>}
      </div>
    </DropdownMenuItem>
  )
}
/**
 * Separator between menu entries. Renders nothing on mobile (drawer groups
 * draw their own inline dividers); renders a standard dropdown separator
 * on desktop.
 */
function DropDrawerSeparator({
  className,
  ...props
}: React.ComponentProps<typeof DropdownMenuSeparator>) {
  const { isMobile } = useComponentSelection()
  // Mobile drawers skip separators entirely — DropDrawerGroup injects its
  // own dividers between items instead.
  return isMobile ? null : (
    <DropdownMenuSeparator
      data-slot="drop-drawer-separator"
      className={className}
      {...props}
    />
  )
}
/**
 * Section label for the menu. Mobile renders a DrawerTitle inside an
 * unpadded DrawerHeader; desktop renders a DropdownMenuLabel.
 */
function DropDrawerLabel({
  className,
  children,
  ...props
}:
  | React.ComponentProps<typeof DropdownMenuLabel>
  | React.ComponentProps<typeof DrawerTitle>) {
  const { isMobile } = useComponentSelection()
  if (!isMobile) {
    return (
      <DropdownMenuLabel
        data-slot="drop-drawer-label"
        className={className}
        {...props}
      >
        {children}
      </DropdownMenuLabel>
    )
  }
  return (
    <DrawerHeader className="p-0">
      <DrawerTitle
        data-slot="drop-drawer-label"
        className={cn(
          'px-4 py-2 text-sm font-medium text-main-view-fg/60',
          className
        )}
        {...props}
      >
        {children}
      </DrawerTitle>
    </DrawerHeader>
  )
}
/**
 * Footer region for the menu. Mobile renders a padded DrawerFooter; the
 * dropdown has no footer primitive, so desktop falls back to a plain div.
 */
function DropDrawerFooter({
  className,
  children,
  ...props
}: React.ComponentProps<typeof DrawerFooter> | React.ComponentProps<'div'>) {
  const { isMobile } = useDropDrawerContext()
  if (!isMobile) {
    // DropdownMenu offers no footer equivalent — render a styled div.
    return (
      <div
        data-slot="drop-drawer-footer"
        className={cn('p-2', className)}
        {...props}
      >
        {children}
      </div>
    )
  }
  return (
    <DrawerFooter
      data-slot="drop-drawer-footer"
      className={cn('p-4', className)}
      {...props}
    >
      {children}
    </DrawerFooter>
  )
}
/**
 * Visual grouping of menu items. On mobile the group gets a rounded card
 * treatment and thin dividers injected between its children; on desktop it
 * is a plain `role="group"` wrapper.
 *
 * The `data-drop-drawer-group` attribute is what useGroupDetection's
 * useGroupState looks up so items can style themselves when grouped.
 */
function DropDrawerGroup({
  className,
  children,
  ...props
}: React.ComponentProps<'div'> & {
  children: React.ReactNode
}) {
  const { isMobile } = useDropDrawerContext()
  // Add separators between children on mobile
  const childrenWithSeparators = React.useMemo(() => {
    if (!isMobile) return children
    const childArray = React.Children.toArray(children)
    // Filter out any existing separators
    // (DropDrawerSeparator renders null on mobile anyway; dropping them
    // keeps the injected dividers evenly spaced)
    const filteredChildren = childArray.filter(
      (child) =>
        React.isValidElement(child) && child.type !== DropDrawerSeparator
    )
    // Add separators between items — after every child except the last
    return filteredChildren.flatMap((child, index) => {
      if (index === filteredChildren.length - 1) return [child]
      return [
        child,
        <div
          key={`separator-${index}`}
          className="bg-border h-px"
          aria-hidden="true"
        />,
      ]
    })
  }, [children, isMobile])
  if (isMobile) {
    return (
      <div
        data-drop-drawer-group
        data-slot="drop-drawer-group"
        role="group"
        className={cn(
          'bg-main-view-fg/2 border border-main-view-fg/4 mx-2 my-3 overflow-hidden rounded-xl',
          className
        )}
        {...props}
      >
        {childrenWithSeparators}
      </div>
    )
  }
  // On desktop, use a div with proper role and attributes
  return (
    <div
      data-drop-drawer-group
      data-slot="drop-drawer-group"
      role="group"
      className={className}
      {...props}
    >
      {children}
    </div>
  )
}
// Context for managing submenu state on mobile
// (desktop submenus rely on Radix's own DropdownMenuSub state instead)
interface SubmenuContextType {
  // id of the submenu currently shown in the drawer; null = main menu
  activeSubmenu: string | null
  setActiveSubmenu: (id: string | null) => void
  // title displayed in the drawer header while a submenu is active
  submenuTitle: string | null
  setSubmenuTitle: (title: string | null) => void
  // supplied by the drawer content: switch the view to the given submenu
  navigateToSubmenu?: (id: string, title: string) => void
  // supplied by the drawer content: lets DropDrawerSub publish its items
  registerSubmenuContent?: (id: string, content: React.ReactNode[]) => void
}
// Default value used when no provider is mounted: setters are no-ops and
// navigation/registration are unavailable (consumers must null-check them).
const SubmenuContext = React.createContext<SubmenuContextType>({
  activeSubmenu: null,
  setActiveSubmenu: () => {},
  submenuTitle: null,
  setSubmenuTitle: () => {},
  navigateToSubmenu: undefined,
  registerSubmenuContent: undefined,
})
// Submenu components
// Counter for generating simple numeric IDs
// (module-level so ids stay unique across all DropDrawerSub instances)
let submenuIdCounter = 0
/**
 * Submenu wrapper. Desktop delegates to DropdownMenuSub; mobile registers
 * the DropDrawerSubContent children with the parent drawer (via
 * SubmenuContext) and tags trigger/content with data attributes so taps
 * can resolve which submenu to navigate to.
 *
 * @param id Optional stable submenu id; a generated `submenu-N` id is used
 *   when omitted.
 * @param title Optional title shown in the drawer header for this submenu.
 */
function DropDrawerSub({
  children,
  id,
  title,
  ...props
}: React.ComponentProps<typeof DropdownMenuSub> & {
  id?: string
  title?: string
}) {
  const { isMobile } = useDropDrawerContext()
  const { registerSubmenuContent } = React.useContext(SubmenuContext)
  // Generate a simple numeric ID instead of using React.useId()
  // (lazy useState initializer: the counter increments once per mount)
  const [generatedId] = React.useState(() => `submenu-${submenuIdCounter++}`)
  const submenuId = id || generatedId
  // Extract submenu content to register with parent
  React.useEffect(() => {
    if (!registerSubmenuContent) return
    // Find the SubContent within this Sub
    const contentItems: React.ReactNode[] = []
    React.Children.forEach(children, (child) => {
      if (React.isValidElement(child) && child.type === DropDrawerSubContent) {
        // Add all children of the SubContent to the result
        React.Children.forEach(
          (child.props as { children?: React.ReactNode }).children,
          (contentChild) => {
            contentItems.push(contentChild)
          }
        )
      }
    })
    // Register the content with the parent
    if (contentItems.length > 0) {
      registerSubmenuContent(submenuId, contentItems)
    }
  }, [children, registerSubmenuContent, submenuId])
  if (isMobile) {
    // For mobile, we'll use the context to manage submenu state
    // Process children to pass the submenu ID to the trigger and content
    const processedChildren = React.Children.map(children, (child) => {
      if (!React.isValidElement(child)) return child
      if (child.type === DropDrawerSubTrigger) {
        return React.cloneElement(
          child as React.ReactElement,
          {
            ...(child.props as object),
            'data-parent-submenu-id': submenuId,
            'data-submenu-id': submenuId,
            // Use only data attributes, not custom props
            'data-parent-submenu': submenuId,
            'data-submenu-title': title,
          } as React.HTMLAttributes<HTMLElement>
        )
      }
      if (child.type === DropDrawerSubContent) {
        return React.cloneElement(
          child as React.ReactElement,
          {
            ...(child.props as object),
            'data-parent-submenu-id': submenuId,
            'data-submenu-id': submenuId,
            // Use only data attributes, not custom props
            'data-parent-submenu': submenuId,
          } as React.HTMLAttributes<HTMLElement>
        )
      }
      return child
    })
    return (
      <div
        data-slot="drop-drawer-sub"
        data-submenu-id={submenuId}
        id={submenuId}
      >
        {processedChildren}
      </div>
    )
  }
  // For desktop, use the standard DropdownMenuSub
  return <DropdownMenuSub {...props}>{children}</DropdownMenuSub>
}
/**
 * Trigger row that opens a submenu.
 *
 * Desktop: renders a standard DropdownMenuSubTrigger.
 * Mobile: renders a tappable row that resolves its submenu id (from the
 * closest `[data-submenu-id]` ancestor, falling back to the data props
 * DropDrawerSub clones onto it) and navigates the drawer to that submenu.
 *
 * Fix: the spread props previously still contained the caller's onClick;
 * being spread after `onClick={combinedOnClick}`, it silently overrode the
 * navigation handler. onClick is now stripped from the spread (it is still
 * invoked first inside combinedOnClick).
 */
function DropDrawerSubTrigger({
  className,
  inset,
  children,
  ...props
}: React.ComponentProps<typeof DropdownMenuSubTrigger> & {
  icon?: React.ReactNode
}) {
  const { isMobile } = useComponentSelection()
  const { navigateToSubmenu } = React.useContext(SubmenuContext)
  const { useGroupState } = useGroupDetection()
  const { itemRef, isInsideGroup } = useGroupState()
  if (isMobile) {
    // Resolve the target submenu and navigate the drawer to it.
    const handleClick = (e: React.MouseEvent) => {
      e.preventDefault()
      e.stopPropagation()
      // Prefer the DOM: DropDrawerSub wraps this trigger in a div carrying
      // data-submenu-id, so the closest ancestor is authoritative.
      const element = e.currentTarget as HTMLElement
      let submenuId: string | null =
        element.closest('[data-submenu-id]')?.getAttribute('data-submenu-id') ||
        null
      // Fall back to the data props cloned onto this trigger.
      if (!submenuId) {
        const typedProps = props as Record<string, unknown>
        submenuId =
          (typedProps['data-parent-submenu-id'] as string) ||
          (typedProps['data-parent-submenu'] as string) ||
          null
      }
      if (!submenuId) {
        return
      }
      // Title precedence: explicit data attribute, then string children,
      // then a generic fallback.
      const dataTitle = (props as Record<string, unknown>)[
        'data-submenu-title'
      ] as string
      const title =
        dataTitle || (typeof children === 'string' ? children : 'Submenu')
      // Navigate to the submenu (no-op when no provider is mounted)
      if (navigateToSubmenu) {
        navigateToSubmenu(submenuId, title)
      }
    }
    // Run any caller-supplied onClick first, then our navigation handler.
    const combinedOnClick = (e: React.MouseEvent) => {
      const typedProps = props as Record<string, unknown>
      if (typedProps.onClick) {
        const originalOnClick =
          typedProps.onClick as React.MouseEventHandler<HTMLDivElement>
        originalOnClick(e as React.MouseEvent<HTMLDivElement>)
      }
      handleClick(e)
    }
    // Strip onClick so the spread below cannot override combinedOnClick.
    const { onClick: _originalOnClick, ...restProps } = props as Record<
      string,
      unknown
    >
    // Intentionally NOT wrapped in DrawerClose: tapping a submenu trigger
    // must keep the drawer open while it navigates.
    return (
      <div
        ref={itemRef}
        data-slot="drop-drawer-sub-trigger"
        data-inset={inset}
        className={cn(getMobileItemStyles(isInsideGroup, inset), className)}
        onClick={combinedOnClick}
        {...restProps}
      >
        <div className="flex items-center gap-2 w-full">{children}</div>
        <ChevronRightIcon className="h-5 w-5 text-main-view-fg/50 " />
      </div>
    )
  }
  return (
    <DropdownMenuSubTrigger
      data-slot="drop-drawer-sub-trigger"
      data-inset={inset}
      className={className}
      inset={inset}
      {...props}
    >
      {children}
    </DropdownMenuSubTrigger>
  )
}
/**
 * Submenu body. Desktop renders a DropdownMenuSubContent; mobile renders
 * nothing here — DropDrawerContent displays the items this submenu
 * registered once it becomes active.
 */
function DropDrawerSubContent({
  className,
  sideOffset = 4,
  children,
  ...props
}: React.ComponentProps<typeof DropdownMenuSubContent>) {
  const { isMobile } = useDropDrawerContext()
  // Mobile never renders submenu content inline; the parent drawer swaps
  // its view to the registered items instead.
  return isMobile ? null : (
    <DropdownMenuSubContent
      data-slot="drop-drawer-sub-content"
      sideOffset={sideOffset}
      className={className}
      {...props}
    >
      {children}
    </DropdownMenuSubContent>
  )
}
// Public API: responsive menu primitives that render as a dropdown on
// desktop and as a bottom drawer on mobile, behind one shared interface.
export {
  DropDrawer,
  DropDrawerContent,
  DropDrawerFooter,
  DropDrawerGroup,
  DropDrawerItem,
  DropDrawerLabel,
  DropDrawerSeparator,
  DropDrawerSub,
  DropDrawerSubContent,
  DropDrawerSubTrigger,
  DropDrawerTrigger,
}

View File

@ -0,0 +1,42 @@
import * as React from 'react'
import * as RadioGroupPrimitive from '@radix-ui/react-radio-group'
import { CircleIcon } from 'lucide-react'
import { cn } from '@/lib/utils'
/**
 * Radio group container: a thin wrapper over Radix RadioGroup.Root that
 * applies a default grid layout and forwards the ref plus all other props.
 */
const RadioGroup = React.forwardRef<
  React.ElementRef<typeof RadioGroupPrimitive.Root>,
  React.ComponentPropsWithoutRef<typeof RadioGroupPrimitive.Root>
>(({ className, ...props }, ref) => (
  <RadioGroupPrimitive.Root
    ref={ref}
    className={cn('grid gap-2', className)}
    {...props}
  />
))
RadioGroup.displayName = RadioGroupPrimitive.Root.displayName
/**
 * Individual radio button: Radix RadioGroup.Item styled for the app theme,
 * with a filled-circle indicator shown while selected. Ref and remaining
 * props are forwarded to the Radix primitive.
 */
const RadioGroupItem = React.forwardRef<
  React.ElementRef<typeof RadioGroupPrimitive.Item>,
  React.ComponentPropsWithoutRef<typeof RadioGroupPrimitive.Item>
>(({ className, ...props }, ref) => (
  <RadioGroupPrimitive.Item
    ref={ref}
    className={cn(
      'aspect-square size-4 rounded-full border border-main-view-fg/20 text-main-view-fg ring-offset-main-view focus:outline-none focus-visible:ring-2 focus-visible:ring-main-view-fg/50 focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50',
      className
    )}
    {...props}
  >
    {/* Indicator mounts only while this item is checked */}
    <RadioGroupPrimitive.Indicator className="flex items-center justify-center">
      <CircleIcon className="size-2.5 fill-current text-current" />
    </RadioGroupPrimitive.Indicator>
  </RadioGroupPrimitive.Item>
))
RadioGroupItem.displayName = RadioGroupPrimitive.Item.displayName
export { RadioGroup, RadioGroupItem }

View File

@ -19,5 +19,6 @@ export const localStorageKey = {
mcpGlobalPermissions: 'mcp-global-permissions', mcpGlobalPermissions: 'mcp-global-permissions',
lastUsedModel: 'last-used-model', lastUsedModel: 'last-used-model',
lastUsedAssistant: 'last-used-assistant', lastUsedAssistant: 'last-used-assistant',
favoriteModels: 'favorite-models',
setupCompleted: 'setup-completed', setupCompleted: 'setup-completed',
} }

View File

@ -1,7 +1,7 @@
'use client' 'use client'
import TextareaAutosize from 'react-textarea-autosize' import TextareaAutosize from 'react-textarea-autosize'
import { cn, toGigabytes } from '@/lib/utils' import { cn } from '@/lib/utils'
import { usePrompt } from '@/hooks/usePrompt' import { usePrompt } from '@/hooks/usePrompt'
import { useThreads } from '@/hooks/useThreads' import { useThreads } from '@/hooks/useThreads'
import { useCallback, useEffect, useRef, useState } from 'react' import { useCallback, useEffect, useRef, useState } from 'react'
@ -14,7 +14,7 @@ import {
} from '@/components/ui/tooltip' } from '@/components/ui/tooltip'
import { ArrowRight } from 'lucide-react' import { ArrowRight } from 'lucide-react'
import { import {
IconPaperclip, IconPhoto,
IconWorld, IconWorld,
IconAtom, IconAtom,
IconEye, IconEye,
@ -34,6 +34,7 @@ import DropdownModelProvider from '@/containers/DropdownModelProvider'
import { ModelLoader } from '@/containers/loaders/ModelLoader' import { ModelLoader } from '@/containers/loaders/ModelLoader'
import DropdownToolsAvailable from '@/containers/DropdownToolsAvailable' import DropdownToolsAvailable from '@/containers/DropdownToolsAvailable'
import { getConnectedServers } from '@/services/mcp' import { getConnectedServers } from '@/services/mcp'
import { checkMmprojExists } from '@/services/models'
type ChatInputProps = { type ChatInputProps = {
className?: string className?: string
@ -46,8 +47,13 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
const textareaRef = useRef<HTMLTextAreaElement>(null) const textareaRef = useRef<HTMLTextAreaElement>(null)
const [isFocused, setIsFocused] = useState(false) const [isFocused, setIsFocused] = useState(false)
const [rows, setRows] = useState(1) const [rows, setRows] = useState(1)
const { streamingContent, abortControllers, loadingModel, tools } = const {
useAppState() streamingContent,
abortControllers,
loadingModel,
tools,
cancelToolCall,
} = useAppState()
const { prompt, setPrompt } = usePrompt() const { prompt, setPrompt } = usePrompt()
const { currentThreadId } = useThreads() const { currentThreadId } = useThreads()
const { t } = useTranslation() const { t } = useTranslation()
@ -55,7 +61,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
const maxRows = 10 const maxRows = 10
const { selectedModel } = useModelProvider() const { selectedModel, selectedProvider } = useModelProvider()
const { sendMessage } = useChat() const { sendMessage } = useChat()
const [message, setMessage] = useState('') const [message, setMessage] = useState('')
const [dropdownToolsAvailable, setDropdownToolsAvailable] = useState(false) const [dropdownToolsAvailable, setDropdownToolsAvailable] = useState(false)
@ -70,6 +76,8 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
}> }>
>([]) >([])
const [connectedServers, setConnectedServers] = useState<string[]>([]) const [connectedServers, setConnectedServers] = useState<string[]>([])
const [isDragOver, setIsDragOver] = useState(false)
const [hasMmproj, setHasMmproj] = useState(false)
// Check for connected MCP servers // Check for connected MCP servers
useEffect(() => { useEffect(() => {
@ -91,6 +99,29 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
return () => clearInterval(intervalId) return () => clearInterval(intervalId)
}, []) }, [])
// Check for mmproj existence or vision capability when model changes
useEffect(() => {
const checkMmprojSupport = async () => {
if (selectedModel?.id) {
try {
// Only check mmproj for llamacpp provider
if (selectedProvider === 'llamacpp') {
const hasLocalMmproj = await checkMmprojExists(selectedModel.id)
setHasMmproj(hasLocalMmproj)
} else {
// For non-llamacpp providers, only check vision capability
setHasMmproj(true)
}
} catch (error) {
console.error('Error checking mmproj:', error)
setHasMmproj(false)
}
}
}
checkMmprojSupport()
}, [selectedModel?.id, selectedProvider])
// Check if there are active MCP servers // Check if there are active MCP servers
const hasActiveMCPServers = connectedServers.length > 0 || tools.length > 0 const hasActiveMCPServers = connectedServers.length > 0 || tools.length > 0
@ -99,11 +130,16 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
setMessage('Please select a model to start chatting.') setMessage('Please select a model to start chatting.')
return return
} }
if (!prompt.trim()) { if (!prompt.trim() && uploadedFiles.length === 0) {
return return
} }
setMessage('') setMessage('')
sendMessage(prompt) sendMessage(
prompt,
true,
uploadedFiles.length > 0 ? uploadedFiles : undefined
)
setUploadedFiles([])
} }
useEffect(() => { useEffect(() => {
@ -161,8 +197,9 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
const stopStreaming = useCallback( const stopStreaming = useCallback(
(threadId: string) => { (threadId: string) => {
abortControllers[threadId]?.abort() abortControllers[threadId]?.abort()
cancelToolCall?.()
}, },
[abortControllers] [abortControllers, cancelToolCall]
) )
const fileInputRef = useRef<HTMLInputElement>(null) const fileInputRef = useRef<HTMLInputElement>(null)
@ -185,8 +222,6 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
return 'image/jpeg' return 'image/jpeg'
case 'png': case 'png':
return 'image/png' return 'image/png'
case 'pdf':
return 'application/pdf'
default: default:
return '' return ''
} }
@ -220,17 +255,12 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
const detectedType = file.type || getFileTypeFromExtension(file.name) const detectedType = file.type || getFileTypeFromExtension(file.name)
const actualType = getFileTypeFromExtension(file.name) || detectedType const actualType = getFileTypeFromExtension(file.name) || detectedType
// Check file type // Check file type - images only
const allowedTypes = [ const allowedTypes = ['image/jpg', 'image/jpeg', 'image/png']
'image/jpg',
'image/jpeg',
'image/png',
'application/pdf',
]
if (!allowedTypes.includes(actualType)) { if (!allowedTypes.includes(actualType)) {
setMessage( setMessage(
`File is not supported. Only JPEG, JPG, PNG, and PDF files are allowed.` `File attachments not supported currently. Only JPEG, JPG, and PNG files are allowed.`
) )
// Reset file input to allow re-uploading // Reset file input to allow re-uploading
if (fileInputRef.current) { if (fileInputRef.current) {
@ -281,6 +311,104 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
} }
} }
const handleDragEnter = (e: React.DragEvent) => {
e.preventDefault()
e.stopPropagation()
// Only allow drag if model supports mmproj
if (hasMmproj) {
setIsDragOver(true)
}
}
const handleDragLeave = (e: React.DragEvent) => {
e.preventDefault()
e.stopPropagation()
// Only set dragOver to false if we're leaving the drop zone entirely
// In Tauri, relatedTarget can be null, so we need to handle that case
const relatedTarget = e.relatedTarget as Node | null
if (!relatedTarget || !e.currentTarget.contains(relatedTarget)) {
setIsDragOver(false)
}
}
const handleDragOver = (e: React.DragEvent) => {
e.preventDefault()
e.stopPropagation()
// Ensure drag state is maintained during drag over
if (hasMmproj) {
setIsDragOver(true)
}
}
const handleDrop = (e: React.DragEvent) => {
e.preventDefault()
e.stopPropagation()
setIsDragOver(false)
// Only allow drop if model supports mmproj
if (!hasMmproj) {
return
}
// Check if dataTransfer exists (it might not in some Tauri scenarios)
if (!e.dataTransfer) {
console.warn('No dataTransfer available in drop event')
return
}
const files = e.dataTransfer.files
if (files && files.length > 0) {
// Create a synthetic event to reuse existing file handling logic
const syntheticEvent = {
target: {
files: files,
},
} as React.ChangeEvent<HTMLInputElement>
handleFileChange(syntheticEvent)
}
}
const handlePaste = (e: React.ClipboardEvent) => {
const clipboardItems = e.clipboardData?.items
if (!clipboardItems) return
// Only allow paste if model supports mmproj
if (!hasMmproj) {
return
}
const imageItems = Array.from(clipboardItems).filter((item) =>
item.type.startsWith('image/')
)
if (imageItems.length > 0) {
e.preventDefault()
const files: File[] = []
let processedCount = 0
imageItems.forEach((item) => {
const file = item.getAsFile()
if (file) {
files.push(file)
}
processedCount++
// When all items are processed, handle the valid files
if (processedCount === imageItems.length && files.length > 0) {
const syntheticEvent = {
target: {
files: files,
},
} as unknown as React.ChangeEvent<HTMLInputElement>
handleFileChange(syntheticEvent)
}
})
}
}
return ( return (
<div className="relative"> <div className="relative">
<div className="relative"> <div className="relative">
@ -305,8 +433,14 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
<div <div
className={cn( className={cn(
'relative z-20 px-0 pb-10 border border-main-view-fg/5 rounded-lg text-main-view-fg bg-main-view', 'relative z-20 px-0 pb-10 border border-main-view-fg/5 rounded-lg text-main-view-fg bg-main-view',
isFocused && 'ring-1 ring-main-view-fg/10' isFocused && 'ring-1 ring-main-view-fg/10',
isDragOver && 'ring-2 ring-accent border-accent'
)} )}
data-drop-zone={hasMmproj ? 'true' : undefined}
onDragEnter={hasMmproj ? handleDragEnter : undefined}
onDragLeave={hasMmproj ? handleDragLeave : undefined}
onDragOver={hasMmproj ? handleDragOver : undefined}
onDrop={hasMmproj ? handleDrop : undefined}
> >
{uploadedFiles.length > 0 && ( {uploadedFiles.length > 0 && (
<div className="flex gap-3 items-center p-2 pb-0"> <div className="flex gap-3 items-center p-2 pb-0">
@ -326,25 +460,6 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
alt={`${file.name} - ${index}`} alt={`${file.name} - ${index}`}
/> />
)} )}
{file.type === 'application/pdf' && (
<div className="bg-main-view-fg/4 h-full rounded-lg p-2 max-w-[400px] pr-4">
<div className="flex gap-2 items-center justify-center h-full">
<div className="size-10 rounded-md bg-main-view shrink-0 flex items-center justify-center">
<span className="uppercase font-bold">
{file.name.split('.').pop()}
</span>
</div>
<div className="truncate">
<h6 className="truncate mb-0.5 text-main-view-fg/80">
{file.name}
</h6>
<p className="text-xs text-main-view-fg/70">
{toGigabytes(file.size)}
</p>
</div>
</div>
</div>
)}
<div <div
className="absolute -top-1 -right-2.5 bg-destructive size-5 flex rounded-full items-center justify-center cursor-pointer" className="absolute -top-1 -right-2.5 bg-destructive size-5 flex rounded-full items-center justify-center cursor-pointer"
onClick={() => handleRemoveFile(index)} onClick={() => handleRemoveFile(index)}
@ -363,7 +478,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
rows={1} rows={1}
maxRows={10} maxRows={10}
value={prompt} value={prompt}
data-test-id={'chat-input'} data-testid={'chat-input'}
onChange={(e) => { onChange={(e) => {
setPrompt(e.target.value) setPrompt(e.target.value)
// Count the number of newlines to estimate rows // Count the number of newlines to estimate rows
@ -372,14 +487,21 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
}} }}
onKeyDown={(e) => { onKeyDown={(e) => {
// e.keyCode 229 is for IME input with Safari // e.keyCode 229 is for IME input with Safari
const isComposing = e.nativeEvent.isComposing || e.keyCode === 229; const isComposing =
if (e.key === 'Enter' && !e.shiftKey && prompt.trim() && !isComposing) { e.nativeEvent.isComposing || e.keyCode === 229
if (
e.key === 'Enter' &&
!e.shiftKey &&
prompt.trim() &&
!isComposing
) {
e.preventDefault() e.preventDefault()
// Submit the message when Enter is pressed without Shift // Submit the message when Enter is pressed without Shift
handleSendMesage(prompt) handleSendMesage(prompt)
// When Shift+Enter is pressed, a new line is added (default behavior) // When Shift+Enter is pressed, a new line is added (default behavior)
} }
}} }}
onPaste={hasMmproj ? handlePaste : undefined}
placeholder={t('common:placeholder.chatInput')} placeholder={t('common:placeholder.chatInput')}
autoFocus autoFocus
spellCheck={spellCheckChatInput} spellCheck={spellCheckChatInput}
@ -400,7 +522,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
<div className="px-1 flex items-center gap-1"> <div className="px-1 flex items-center gap-1">
<div <div
className={cn( className={cn(
'px-1 flex items-center gap-1', 'px-1 flex items-center',
streamingContent && 'opacity-50 pointer-events-none' streamingContent && 'opacity-50 pointer-events-none'
)} )}
> >
@ -412,19 +534,22 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
useLastUsedModel={initialMessage} useLastUsedModel={initialMessage}
/> />
)} )}
{/* File attachment - always available */} {/* File attachment - show only for models with mmproj */}
<div {hasMmproj && (
className="h-6 hidden p-1 items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1" <div
onClick={handleAttachmentClick} className="h-6 p-1 ml-1 flex items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1"
> onClick={handleAttachmentClick}
<IconPaperclip size={18} className="text-main-view-fg/50" /> >
<input <IconPhoto size={18} className="text-main-view-fg/50" />
type="file" <input
ref={fileInputRef} type="file"
className="hidden" ref={fileInputRef}
onChange={handleFileChange} className="hidden"
/> multiple
</div> onChange={handleFileChange}
/>
</div>
)}
{/* Microphone - always available - Temp Hide */} {/* Microphone - always available - Temp Hide */}
{/* <div className="h-6 p-1 flex items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1"> {/* <div className="h-6 p-1 flex items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1">
<IconMicrophone size={18} className="text-main-view-fg/50" /> <IconMicrophone size={18} className="text-main-view-fg/50" />
@ -483,7 +608,9 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
initialMessage={initialMessage} initialMessage={initialMessage}
onOpenChange={(isOpen) => { onOpenChange={(isOpen) => {
setDropdownToolsAvailable(isOpen) setDropdownToolsAvailable(isOpen)
setTooltipToolsAvailable(false) if (isOpen) {
setTooltipToolsAvailable(false)
}
}} }}
> >
{(isOpen, toolsCount) => { {(isOpen, toolsCount) => {
@ -566,9 +693,13 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
</Button> </Button>
) : ( ) : (
<Button <Button
variant={!prompt.trim() ? null : 'default'} variant={
!prompt.trim() && uploadedFiles.length === 0
? null
: 'default'
}
size="icon" size="icon"
disabled={!prompt.trim()} disabled={!prompt.trim() && uploadedFiles.length === 0}
data-test-id="send-message-button" data-test-id="send-message-button"
onClick={() => handleSendMesage(prompt)} onClick={() => handleSendMesage(prompt)}
> >
@ -582,6 +713,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
</div> </div>
</div> </div>
</div> </div>
{message && ( {message && (
<div className="bg-main-view-fg/2 -mt-0.5 mx-2 pb-2 px-3 pt-1.5 rounded-b-lg text-xs text-destructive transition-all duration-200 ease-in-out"> <div className="bg-main-view-fg/2 -mt-0.5 mx-2 pb-2 px-3 pt-1.5 rounded-b-lg text-xs text-destructive transition-all duration-200 ease-in-out">
<div className="flex items-center gap-1 justify-between"> <div className="flex items-center gap-1 justify-between">

View File

@ -1,4 +1,4 @@
import { useAppearance } from '@/hooks/useAppearance' import { useAppearance, isDefaultColorAccent } from '@/hooks/useAppearance'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
import { RgbaColor, RgbaColorPicker } from 'react-colorful' import { RgbaColor, RgbaColorPicker } from 'react-colorful'
import { IconColorPicker } from '@tabler/icons-react' import { IconColorPicker } from '@tabler/icons-react'
@ -37,10 +37,11 @@ export function ColorPickerAppAccentColor() {
<div className="flex items-center gap-1.5"> <div className="flex items-center gap-1.5">
{predefineAppAccentBgColor.map((item, i) => { {predefineAppAccentBgColor.map((item, i) => {
const isSelected = const isSelected =
item.r === appAccentBgColor.r && (item.r === appAccentBgColor.r &&
item.g === appAccentBgColor.g && item.g === appAccentBgColor.g &&
item.b === appAccentBgColor.b && item.b === appAccentBgColor.b &&
item.a === appAccentBgColor.a item.a === appAccentBgColor.a) ||
(isDefaultColorAccent(appAccentBgColor) && isDefaultColorAccent(item))
return ( return (
<div <div
key={i} key={i}

View File

@ -1,4 +1,4 @@
import { useAppearance } from '@/hooks/useAppearance' import { useAppearance, isDefaultColor } from '@/hooks/useAppearance'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
import { RgbaColor, RgbaColorPicker } from 'react-colorful' import { RgbaColor, RgbaColorPicker } from 'react-colorful'
import { IconColorPicker } from '@tabler/icons-react' import { IconColorPicker } from '@tabler/icons-react'
@ -60,10 +60,11 @@ export function ColorPickerAppBgColor() {
<div className="flex items-center gap-1.5"> <div className="flex items-center gap-1.5">
{predefineAppBgColor.map((item, i) => { {predefineAppBgColor.map((item, i) => {
const isSelected = const isSelected =
item.r === appBgColor.r && (item.r === appBgColor.r &&
item.g === appBgColor.g && item.g === appBgColor.g &&
item.b === appBgColor.b && item.b === appBgColor.b &&
item.a === appBgColor.a item.a === appBgColor.a) ||
(isDefaultColor(appBgColor) && isDefaultColor(item))
return ( return (
<div <div
key={i} key={i}

View File

@ -1,4 +1,4 @@
import { useAppearance } from '@/hooks/useAppearance' import { useAppearance, isDefaultColorDestructive } from '@/hooks/useAppearance'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
import { RgbaColor, RgbaColorPicker } from 'react-colorful' import { RgbaColor, RgbaColorPicker } from 'react-colorful'
import { IconColorPicker } from '@tabler/icons-react' import { IconColorPicker } from '@tabler/icons-react'
@ -46,10 +46,11 @@ export function ColorPickerAppDestructiveColor() {
<div className="flex items-center gap-1.5"> <div className="flex items-center gap-1.5">
{predefineAppDestructiveBgColor.map((item, i) => { {predefineAppDestructiveBgColor.map((item, i) => {
const isSelected = const isSelected =
item.r === appDestructiveBgColor.r && (item.r === appDestructiveBgColor.r &&
item.g === appDestructiveBgColor.g && item.g === appDestructiveBgColor.g &&
item.b === appDestructiveBgColor.b && item.b === appDestructiveBgColor.b &&
item.a === appDestructiveBgColor.a item.a === appDestructiveBgColor.a) ||
(isDefaultColorDestructive(appDestructiveBgColor) && isDefaultColorDestructive(item))
return ( return (
<div <div
key={i} key={i}

View File

@ -1,4 +1,4 @@
import { useAppearance } from '@/hooks/useAppearance' import { useAppearance, isDefaultColorMainView } from '@/hooks/useAppearance'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
import { RgbaColor, RgbaColorPicker } from 'react-colorful' import { RgbaColor, RgbaColorPicker } from 'react-colorful'
import { IconColorPicker } from '@tabler/icons-react' import { IconColorPicker } from '@tabler/icons-react'
@ -33,10 +33,11 @@ export function ColorPickerAppMainView() {
<div className="flex items-center gap-1.5"> <div className="flex items-center gap-1.5">
{predefineAppMainViewBgColor.map((item, i) => { {predefineAppMainViewBgColor.map((item, i) => {
const isSelected = const isSelected =
item.r === appMainViewBgColor.r && (item.r === appMainViewBgColor.r &&
item.g === appMainViewBgColor.g && item.g === appMainViewBgColor.g &&
item.b === appMainViewBgColor.b && item.b === appMainViewBgColor.b &&
item.a === appMainViewBgColor.a item.a === appMainViewBgColor.a) ||
(isDefaultColorMainView(appMainViewBgColor) && isDefaultColorMainView(item))
return ( return (
<div <div
key={i} key={i}

View File

@ -1,4 +1,4 @@
import { useAppearance } from '@/hooks/useAppearance' import { useAppearance, isDefaultColorPrimary } from '@/hooks/useAppearance'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
import { RgbaColor, RgbaColorPicker } from 'react-colorful' import { RgbaColor, RgbaColorPicker } from 'react-colorful'
import { IconColorPicker } from '@tabler/icons-react' import { IconColorPicker } from '@tabler/icons-react'
@ -42,10 +42,11 @@ export function ColorPickerAppPrimaryColor() {
<div className="flex items-center gap-1.5"> <div className="flex items-center gap-1.5">
{predefineappPrimaryBgColor.map((item, i) => { {predefineappPrimaryBgColor.map((item, i) => {
const isSelected = const isSelected =
item.r === appPrimaryBgColor.r && (item.r === appPrimaryBgColor.r &&
item.g === appPrimaryBgColor.g && item.g === appPrimaryBgColor.g &&
item.b === appPrimaryBgColor.b && item.b === appPrimaryBgColor.b &&
item.a === appPrimaryBgColor.a item.a === appPrimaryBgColor.a) ||
(isDefaultColorPrimary(appPrimaryBgColor) && isDefaultColorPrimary(item))
return ( return (
<div <div
key={i} key={i}

View File

@ -6,10 +6,8 @@ import {
import { Progress } from '@/components/ui/progress' import { Progress } from '@/components/ui/progress'
import { useDownloadStore } from '@/hooks/useDownloadStore' import { useDownloadStore } from '@/hooks/useDownloadStore'
import { useLeftPanel } from '@/hooks/useLeftPanel' import { useLeftPanel } from '@/hooks/useLeftPanel'
import { useModelProvider } from '@/hooks/useModelProvider'
import { useAppUpdater } from '@/hooks/useAppUpdater' import { useAppUpdater } from '@/hooks/useAppUpdater'
import { abortDownload } from '@/services/models' import { abortDownload } from '@/services/models'
import { getProviders } from '@/services/providers'
import { DownloadEvent, DownloadState, events, AppEvent } from '@janhq/core' import { DownloadEvent, DownloadState, events, AppEvent } from '@janhq/core'
import { IconDownload, IconX } from '@tabler/icons-react' import { IconDownload, IconX } from '@tabler/icons-react'
import { useCallback, useEffect, useMemo, useState } from 'react' import { useCallback, useEffect, useMemo, useState } from 'react'
@ -18,7 +16,6 @@ import { useTranslation } from '@/i18n/react-i18next-compat'
export function DownloadManagement() { export function DownloadManagement() {
const { t } = useTranslation() const { t } = useTranslation()
const { setProviders } = useModelProvider()
const { open: isLeftPanelOpen } = useLeftPanel() const { open: isLeftPanelOpen } = useLeftPanel()
const [isPopoverOpen, setIsPopoverOpen] = useState(false) const [isPopoverOpen, setIsPopoverOpen] = useState(false)
const { const {
@ -185,7 +182,6 @@ export function DownloadManagement() {
console.debug('onFileDownloadSuccess', state) console.debug('onFileDownloadSuccess', state)
removeDownload(state.modelId) removeDownload(state.modelId)
removeLocalDownloadingModel(state.modelId) removeLocalDownloadingModel(state.modelId)
getProviders().then(setProviders)
toast.success(t('common:toast.downloadComplete.title'), { toast.success(t('common:toast.downloadComplete.title'), {
id: 'download-complete', id: 'download-complete',
description: t('common:toast.downloadComplete.description', { description: t('common:toast.downloadComplete.description', {
@ -193,7 +189,7 @@ export function DownloadManagement() {
}), }),
}) })
}, },
[removeDownload, removeLocalDownloadingModel, setProviders, t] [removeDownload, removeLocalDownloadingModel, t]
) )
useEffect(() => { useEffect(() => {

View File

@ -17,6 +17,9 @@ import ProvidersAvatar from '@/containers/ProvidersAvatar'
import { Fzf } from 'fzf' import { Fzf } from 'fzf'
import { localStorageKey } from '@/constants/localStorage' import { localStorageKey } from '@/constants/localStorage'
import { useTranslation } from '@/i18n/react-i18next-compat' import { useTranslation } from '@/i18n/react-i18next-compat'
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
import { predefinedProviders } from '@/consts/providers'
import { checkMmprojExistsAndUpdateOffloadMMprojSetting } from '@/services/models'
type DropdownModelProviderProps = { type DropdownModelProviderProps = {
model?: ThreadModel model?: ThreadModel
@ -64,54 +67,75 @@ const DropdownModelProvider = ({
getModelBy, getModelBy,
selectedProvider, selectedProvider,
selectedModel, selectedModel,
updateProvider,
} = useModelProvider() } = useModelProvider()
const [displayModel, setDisplayModel] = useState<string>('') const [displayModel, setDisplayModel] = useState<string>('')
const { updateCurrentThreadModel } = useThreads() const { updateCurrentThreadModel } = useThreads()
const navigate = useNavigate() const navigate = useNavigate()
const { t } = useTranslation() const { t } = useTranslation()
const { favoriteModels } = useFavoriteModel()
// Search state // Search state
const [open, setOpen] = useState(false) const [open, setOpen] = useState(false)
const [searchValue, setSearchValue] = useState('') const [searchValue, setSearchValue] = useState('')
const searchInputRef = useRef<HTMLInputElement>(null) const searchInputRef = useRef<HTMLInputElement>(null)
// Helper function to check if a model exists in providers
const checkModelExists = useCallback(
(providerName: string, modelId: string) => {
const provider = providers.find(
(p) => p.provider === providerName && p.active
)
return provider?.models.find((m) => m.id === modelId)
},
[providers]
)
// Initialize model provider only once // Initialize model provider only once
useEffect(() => { useEffect(() => {
// Auto select model when existing thread is passed const initializeModel = async () => {
if (model) { // Auto select model when existing thread is passed
selectModelProvider(model?.provider as string, model?.id as string) if (model) {
} else if (useLastUsedModel) { selectModelProvider(model?.provider as string, model?.id as string)
// Try to use last used model only when explicitly requested (for new chat) if (!checkModelExists(model.provider, model.id)) {
const lastUsed = getLastUsedModel() selectModelProvider('', '')
if (lastUsed) { }
// Verify the last used model still exists // Check mmproj existence for llamacpp models
const provider = providers.find( if (model?.provider === 'llamacpp') {
(p) => p.provider === lastUsed.provider && p.active await checkMmprojExistsAndUpdateOffloadMMprojSetting(
) model.id as string,
const modelExists = provider?.models.find( updateProvider,
(m) => m.id === lastUsed.model getProviderByName
) )
}
if (provider && modelExists) { } else if (useLastUsedModel) {
selectModelProvider(lastUsed.provider, lastUsed.model) // Try to use last used model only when explicitly requested (for new chat)
} else { const lastUsed = getLastUsedModel()
// Fallback to default model if last used model no longer exists if (lastUsed && checkModelExists(lastUsed.provider, lastUsed.model)) {
selectModelProvider('llamacpp', 'llama3.2:3b') selectModelProvider(lastUsed.provider, lastUsed.model)
if (lastUsed.provider === 'llamacpp') {
await checkMmprojExistsAndUpdateOffloadMMprojSetting(
lastUsed.model,
updateProvider,
getProviderByName
)
}
} else {
selectModelProvider('', '')
} }
} else {
// default model, we should add from setting
selectModelProvider('llamacpp', 'llama3.2:3b')
} }
} else {
// default model for non-new-chat contexts
selectModelProvider('llamacpp', 'llama3.2:3b')
} }
initializeModel()
}, [ }, [
model, model,
selectModelProvider, selectModelProvider,
updateCurrentThreadModel, updateCurrentThreadModel,
providers, providers,
useLastUsedModel, useLastUsedModel,
checkModelExists,
updateProvider,
getProviderByName,
]) ])
// Update display model when selection changes // Update display model when selection changes
@ -151,9 +175,15 @@ const DropdownModelProvider = ({
provider.models.forEach((modelItem) => { provider.models.forEach((modelItem) => {
// Skip models that require API key but don't have one (except llamacpp) // Skip models that require API key but don't have one (except llamacpp)
if (provider.provider !== 'llamacpp' && !provider.api_key?.length) { if (
provider &&
predefinedProviders.some((e) =>
e.provider.includes(provider.provider)
) &&
provider.provider !== 'llamacpp' &&
!provider.api_key?.length
)
return return
}
const capabilities = modelItem.capabilities || [] const capabilities = modelItem.capabilities || []
const capabilitiesString = capabilities.join(' ') const capabilitiesString = capabilities.join(' ')
@ -182,6 +212,13 @@ const DropdownModelProvider = ({
}) })
}, [searchableItems]) }, [searchableItems])
// Get favorite models that are currently available
const favoriteItems = useMemo(() => {
return searchableItems.filter((item) =>
favoriteModels.some((fav) => fav.id === item.model.id)
)
}, [searchableItems, favoriteModels])
// Filter models based on search value // Filter models based on search value
const filteredItems = useMemo(() => { const filteredItems = useMemo(() => {
if (!searchValue) return searchableItems if (!searchValue) return searchableItems
@ -202,7 +239,7 @@ const DropdownModelProvider = ({
}) })
}, [searchableItems, searchValue, fzfInstance]) }, [searchableItems, searchValue, fzfInstance])
// Group filtered items by provider // Group filtered items by provider, excluding favorites when not searching
const groupedItems = useMemo(() => { const groupedItems = useMemo(() => {
const groups: Record<string, SearchableModel[]> = {} const groups: Record<string, SearchableModel[]> = {}
@ -221,14 +258,19 @@ const DropdownModelProvider = ({
if (!groups[providerKey]) { if (!groups[providerKey]) {
groups[providerKey] = [] groups[providerKey] = []
} }
// When not searching, exclude favorite models from regular provider sections
const isFavorite = favoriteModels.some((fav) => fav.id === item.model.id)
if (!searchValue && isFavorite) return // Skip adding this item to regular provider section
groups[providerKey].push(item) groups[providerKey].push(item)
}) })
return groups return groups
}, [filteredItems, providers, searchValue]) }, [filteredItems, providers, searchValue, favoriteModels])
const handleSelect = useCallback( const handleSelect = useCallback(
(searchableModel: SearchableModel) => { async (searchableModel: SearchableModel) => {
selectModelProvider( selectModelProvider(
searchableModel.provider.provider, searchableModel.provider.provider,
searchableModel.model.id searchableModel.model.id
@ -237,6 +279,16 @@ const DropdownModelProvider = ({
id: searchableModel.model.id, id: searchableModel.model.id,
provider: searchableModel.provider.provider, provider: searchableModel.provider.provider,
}) })
// Check mmproj existence for llamacpp models
if (searchableModel.provider.provider === 'llamacpp') {
await checkMmprojExistsAndUpdateOffloadMMprojSetting(
searchableModel.model.id,
updateProvider,
getProviderByName
)
}
// Store the selected model as last used // Store the selected model as last used
if (useLastUsedModel) { if (useLastUsedModel) {
setLastUsedModel( setLastUsedModel(
@ -247,7 +299,13 @@ const DropdownModelProvider = ({
setSearchValue('') setSearchValue('')
setOpen(false) setOpen(false)
}, },
[selectModelProvider, updateCurrentThreadModel, useLastUsedModel] [
selectModelProvider,
updateCurrentThreadModel,
useLastUsedModel,
updateProvider,
getProviderByName,
]
) )
const currentModel = selectedModel?.id const currentModel = selectedModel?.id
@ -330,6 +388,64 @@ const DropdownModelProvider = ({
</div> </div>
) : ( ) : (
<div className="py-1"> <div className="py-1">
{/* Favorites section - only show when not searching */}
{!searchValue && favoriteItems.length > 0 && (
<div className="bg-main-view-fg/2 backdrop-blur-2xl rounded-sm my-1.5 mx-1.5">
{/* Favorites header */}
<div className="flex items-center gap-1.5 px-2 py-1">
<span className="text-sm font-medium text-main-view-fg/80">
{t('common:favorites')}
</span>
</div>
{/* Favorite models */}
{favoriteItems.map((searchableModel) => {
const isSelected =
selectedModel?.id === searchableModel.model.id &&
selectedProvider === searchableModel.provider.provider
const capabilities =
searchableModel.model.capabilities || []
return (
<div
key={`fav-${searchableModel.value}`}
title={searchableModel.model.id}
onClick={() => handleSelect(searchableModel)}
className={cn(
'mx-1 mb-1 px-2 py-1.5 rounded-sm cursor-pointer flex items-center gap-2 transition-all duration-200',
'hover:bg-main-view-fg/4',
isSelected &&
'bg-main-view-fg/8 hover:bg-main-view-fg/8'
)}
>
<div className="flex items-center gap-1 flex-1 min-w-0">
<div className="shrink-0 -ml-1">
<ProvidersAvatar
provider={searchableModel.provider}
/>
</div>
<span className="truncate text-main-view-fg/80 text-sm">
{searchableModel.model.id}
</span>
<div className="flex-1"></div>
{capabilities.length > 0 && (
<div className="flex-shrink-0 -mr-1.5">
<Capabilities capabilities={capabilities} />
</div>
)}
</div>
</div>
)
})}
</div>
)}
{/* Divider between favorites and regular providers */}
{favoriteItems.length > 0 && (
<div className="border-b border-1 border-main-view-fg/8 mx-2"></div>
)}
{/* Regular provider sections */}
{Object.entries(groupedItems).map(([providerKey, models]) => { {Object.entries(groupedItems).map(([providerKey, models]) => {
const providerInfo = providers.find( const providerInfo = providers.find(
(p) => p.provider === providerKey (p) => p.provider === providerKey
@ -340,7 +456,7 @@ const DropdownModelProvider = ({
return ( return (
<div <div
key={providerKey} key={providerKey}
className="bg-main-view-fg/4 backdrop-blur-2xl first:mt-0 rounded-sm my-1.5 mx-1.5 first:mb-0" className="bg-main-view-fg/2 backdrop-blur-2xl first:mt-0 rounded-sm my-1.5 mx-1.5 first:mb-0"
> >
{/* Provider header */} {/* Provider header */}
<div className="flex items-center justify-between px-2 py-1"> <div className="flex items-center justify-between px-2 py-1">
@ -384,15 +500,20 @@ const DropdownModelProvider = ({
return ( return (
<div <div
key={searchableModel.value} key={searchableModel.value}
title={searchableModel.model.id}
onClick={() => handleSelect(searchableModel)} onClick={() => handleSelect(searchableModel)}
className={cn( className={cn(
'mx-1 mb-1 px-2 py-1.5 rounded-sm cursor-pointer flex items-center gap-2 transition-all duration-200', 'mx-1 mb-1 px-2 py-1.5 rounded-sm cursor-pointer flex items-center gap-2 transition-all duration-200',
'hover:bg-main-view-fg/10', 'hover:bg-main-view-fg/4',
isSelected && 'bg-main-view-fg/15' isSelected &&
'bg-main-view-fg/8 hover:bg-main-view-fg/8'
)} )}
> >
<div className="flex items-center gap-2 flex-1 min-w-0"> <div className="flex items-center gap-2 flex-1 min-w-0">
<span className="truncate text-main-view-fg/80 text-sm"> <span
className="truncate text-main-view-fg/80 text-sm"
title={searchableModel.model.id}
>
{searchableModel.model.id} {searchableModel.model.id}
</span> </span>

View File

@ -1,12 +1,18 @@
import { useEffect, useState } from 'react' import { useEffect, useState } from 'react'
import { import {
DropdownMenu, DropDrawer,
DropdownMenuContent, DropDrawerContent,
DropdownMenuItem, DropDrawerItem,
DropdownMenuLabel, DropDrawerSub,
DropdownMenuSeparator, DropDrawerLabel,
DropdownMenuTrigger, DropDrawerSubContent,
} from '@/components/ui/dropdown-menu' DropDrawerSeparator,
DropDrawerSubTrigger,
DropDrawerTrigger,
DropDrawerGroup,
} from '@/components/ui/dropdrawer'
import { Switch } from '@/components/ui/switch' import { Switch } from '@/components/ui/switch'
import { useThreads } from '@/hooks/useThreads' import { useThreads } from '@/hooks/useThreads'
@ -15,6 +21,7 @@ import { useToolAvailable } from '@/hooks/useToolAvailable'
import React from 'react' import React from 'react'
import { useAppState } from '@/hooks/useAppState' import { useAppState } from '@/hooks/useAppState'
import { useTranslation } from '@/i18n/react-i18next-compat' import { useTranslation } from '@/i18n/react-i18next-compat'
import { cn } from '@/lib/utils'
interface DropdownToolsAvailableProps { interface DropdownToolsAvailableProps {
children: (isOpen: boolean, toolsCount: number) => React.ReactNode children: (isOpen: boolean, toolsCount: number) => React.ReactNode
@ -82,6 +89,23 @@ export default function DropdownToolsAvailable({
return false return false
} }
const handleDisableAllServerTools = (
serverName: string,
disable: boolean
) => {
const allToolsByServer = getToolsByServer()
const serverTools = allToolsByServer[serverName] || []
serverTools.forEach((tool) => {
handleToolToggle(tool.name, !disable)
})
}
const areAllServerToolsDisabled = (serverName: string): boolean => {
const allToolsByServer = getToolsByServer()
const serverTools = allToolsByServer[serverName] || []
return serverTools.every((tool) => !isToolChecked(tool.name))
}
const getEnabledToolsCount = (): number => { const getEnabledToolsCount = (): number => {
const disabledTools = initialMessage const disabledTools = initialMessage
? getDefaultDisabledTools() ? getDefaultDisabledTools()
@ -91,69 +115,153 @@ export default function DropdownToolsAvailable({
return tools.filter((tool) => !disabledTools.includes(tool.name)).length return tools.filter((tool) => !disabledTools.includes(tool.name)).length
} }
const getToolsByServer = () => {
const toolsByServer = tools.reduce(
(acc, tool) => {
if (!acc[tool.server]) {
acc[tool.server] = []
}
acc[tool.server].push(tool)
return acc
},
{} as Record<string, typeof tools>
)
return toolsByServer
}
const renderTrigger = () => children(isOpen, getEnabledToolsCount()) const renderTrigger = () => children(isOpen, getEnabledToolsCount())
if (tools.length === 0) { if (tools.length === 0) {
return ( return (
<DropdownMenu onOpenChange={handleOpenChange}> <DropDrawer onOpenChange={handleOpenChange}>
<DropdownMenuTrigger asChild>{renderTrigger()}</DropdownMenuTrigger> <DropDrawerTrigger asChild>{renderTrigger()}</DropDrawerTrigger>
<DropdownMenuContent align="start" className="max-w-64"> <DropDrawerContent align="start" className="max-w-64">
<DropdownMenuItem disabled>{t('common:noToolsAvailable')}</DropdownMenuItem> <DropDrawerItem disabled>
</DropdownMenuContent> {t('common:noToolsAvailable')}
</DropdownMenu> </DropDrawerItem>
</DropDrawerContent>
</DropDrawer>
) )
} }
return ( const toolsByServer = getToolsByServer()
<DropdownMenu onOpenChange={handleOpenChange}>
<DropdownMenuTrigger asChild>{renderTrigger()}</DropdownMenuTrigger>
<DropdownMenuContent return (
<DropDrawer onOpenChange={handleOpenChange}>
<DropDrawerTrigger asChild>{renderTrigger()}</DropDrawerTrigger>
<DropDrawerContent
side="top" side="top"
align="start" align="start"
className="max-w-64 backdrop-blur-xl bg-main-view" className="bg-main-view !overflow-hidden"
onClick={(e) => e.stopPropagation()}
> >
<DropdownMenuLabel className="flex items-center gap-2 sticky -top-1 z-10 px-4 pl-2 py-2 "> <DropDrawerLabel className="flex items-center gap-2 sticky -top-1 z-10 px-4 pl-2 py-1">
Available Tools Available Tools
</DropdownMenuLabel> </DropDrawerLabel>
<DropdownMenuSeparator /> <DropDrawerSeparator />
<div className="max-h-64 overflow-y-auto"> <div className="max-h-64 overflow-y-auto">
{tools.map((tool) => { <DropDrawerGroup>
const isChecked = isToolChecked(tool.name) {Object.entries(toolsByServer).map(([serverName, serverTools]) => (
return ( <DropDrawerSub
<div id={`server-${serverName}`}
key={tool.name} key={serverName}
className="py-2 hover:bg-main-view-fg/5 hover:backdrop-blur-2xl rounded-sm px-2 mx-auto w-full" title={serverName}
> >
<div className="flex items-start justify-center gap-3"> <DropDrawerSubTrigger className="py-2 hover:bg-main-view-fg/5 hover:backdrop-blur-2xl rounded-sm px-2 mx-auto w-full">
<div className="flex items-start justify-between gap-4 w-full"> <div className="flex items-center justify-between w-full">
<div className="overflow-hidden w-full flex flex-col "> <span className="text-sm text-main-view-fg/80">
<div className="truncate"> {serverName}
<span className="text-sm font-medium" title={tool.name}> </span>
{tool.name} <span className="text-xs text-main-view-fg/50 inline-flex items-center mr-1 border border-main-view-fg/20 px-1 rounded-sm">
</span> {
</div> serverTools.filter((tool) => isToolChecked(tool.name))
{tool.description && ( .length
<p className="text-xs text-main-view-fg/70 mt-1 line-clamp-2"> }
{tool.description} </span>
</p>
)}
</div>
<div className="shrink-0 mx-auto">
<Switch
checked={isChecked}
onCheckedChange={(checked) =>
handleToolToggle(tool.name, checked)
}
/>
</div>
</div> </div>
</div> </DropDrawerSubTrigger>
</div> <DropDrawerSubContent className="max-w-64 max-h-70 w-full overflow-hidden">
) <DropDrawerGroup>
})} {serverTools.length > 1 && (
<div className="sticky top-0 z-10 bg-main-view border-b border-main-view-fg/10 px-4 md:px-2 pr-2 py-1.5 flex items-center justify-between">
<span className="text-xs font-medium text-main-view-fg/70">
Disable All Tools
</span>
<div
className={cn(
'flex items-center gap-2',
serverTools.length > 5
? 'mr-3 md:mr-1.5'
: 'mr-2 md:mr-0'
)}
>
<Switch
checked={!areAllServerToolsDisabled(serverName)}
onCheckedChange={(checked) =>
handleDisableAllServerTools(serverName, !checked)
}
/>
</div>
</div>
)}
<div className="max-h-56 overflow-y-auto">
{serverTools.map((tool) => {
const isChecked = isToolChecked(tool.name)
return (
<DropDrawerItem
onClick={(e) => {
handleToolToggle(tool.name, !isChecked)
e.preventDefault()
}}
onSelect={(e) => {
handleToolToggle(tool.name, !isChecked)
e.preventDefault()
}}
key={tool.name}
className="mt-1 first:mt-0 py-1.5"
icon={
<Switch
checked={isChecked}
onCheckedChange={(checked) => {
console.log('checked', checked)
handleToolToggle(tool.name, checked)
}}
onClick={(e) => {
e.stopPropagation()
}}
/>
}
>
<div className="overflow-hidden flex flex-col items-start ">
<div className="truncate">
<span
className="text-sm font-medium text-main-view-fg"
title={tool.name}
>
{tool.name}
</span>
</div>
{tool.description && (
<p
className="text-xs text-main-view-fg/70 mt-1 line-clamp-1"
title={tool.description}
>
{tool.description}
</p>
)}
</div>
</DropDrawerItem>
)
})}
</div>
</DropDrawerGroup>
</DropDrawerSubContent>
</DropDrawerSub>
))}
</DropDrawerGroup>
</div> </div>
</DropdownMenuContent> </DropDrawerContent>
</DropdownMenu> </DropDrawer>
) )
} }

View File

@ -0,0 +1,25 @@
import { IconStar, IconStarFilled } from '@tabler/icons-react'
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
interface FavoriteModelActionProps {
model: Model
}
export function FavoriteModelAction({ model }: FavoriteModelActionProps) {
const { isFavorite, toggleFavorite } = useFavoriteModel()
const isModelFavorite = isFavorite(model.id)
return (
<div
aria-label="Toggle favorite"
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={() => toggleFavorite(model)}
>
{isModelFavorite ? (
<IconStarFilled size={18} className="text-main-view-fg" />
) : (
<IconStar size={18} className="text-main-view-fg/50" />
)}
</div>
)
}

View File

@ -21,15 +21,15 @@ export default function GlobalError({ error }: GlobalErrorProps) {
> >
<path <path
d="M6 8H6.01M6 16H6.01M6 12H18C20.2091 12 22 10.2091 22 8C22 5.79086 20.2091 4 18 4H6C3.79086 4 2 5.79086 2 8C2 10.2091 3.79086 12 6 12ZM6 12C3.79086 12 2 13.7909 2 16C2 18.2091 3.79086 20 6 20H14" d="M6 8H6.01M6 16H6.01M6 12H18C20.2091 12 22 10.2091 22 8C22 5.79086 20.2091 4 18 4H6C3.79086 4 2 5.79086 2 8C2 10.2091 3.79086 12 6 12ZM6 12C3.79086 12 2 13.7909 2 16C2 18.2091 3.79086 20 6 20H14"
stroke-width="2" strokeWidth="2"
stroke-linecap="round" strokeLinecap="round"
stroke-linejoin="round" strokeLinejoin="round"
></path> ></path>
<path <path
d="M17 16L22 21M22 16L17 21" d="M17 16L22 21M22 16L17 21"
stroke-width="2" strokeWidth="2"
stroke-linecap="round" strokeLinecap="round"
stroke-linejoin="round" strokeLinejoin="round"
></path> ></path>
</svg> </svg>
</div> </div>

View File

@ -70,8 +70,8 @@ export function ModelSetting({
models: updatedModels, models: updatedModels,
}) })
// Call debounced stopModel only when updating ctx_len or ngl // Call debounced stopModel only when updating ctx_len, ngl, chat_template, or offload_mmproj
if (key === 'ctx_len' || key === 'ngl' || key === 'chat_template') { if (key === 'ctx_len' || key === 'ngl' || key === 'chat_template' || key === 'offload_mmproj') {
debouncedStopModel(model.id) debouncedStopModel(model.id)
} }
} }

View File

@ -1,3 +1,4 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { ThreadMessage } from '@janhq/core' import { ThreadMessage } from '@janhq/core'
import { RenderMarkdown } from './RenderMarkdown' import { RenderMarkdown } from './RenderMarkdown'
import React, { Fragment, memo, useCallback, useMemo, useState } from 'react' import React, { Fragment, memo, useCallback, useMemo, useState } from 'react'
@ -144,7 +145,7 @@ export const ThreadContent = memo(
isLastMessage?: boolean isLastMessage?: boolean
index?: number index?: number
showAssistant?: boolean showAssistant?: boolean
// eslint-disable-next-line @typescript-eslint/no-explicit-any
streamTools?: any streamTools?: any
contextOverflowModal?: React.ReactNode | null contextOverflowModal?: React.ReactNode | null
updateMessage?: (item: ThreadMessage, message: string) => void updateMessage?: (item: ThreadMessage, message: string) => void
@ -172,9 +173,12 @@ export const ThreadContent = memo(
const { reasoningSegment, textSegment } = useMemo(() => { const { reasoningSegment, textSegment } = useMemo(() => {
// Check for thinking formats // Check for thinking formats
const hasThinkTag = text.includes('<think>') && !text.includes('</think>') const hasThinkTag = text.includes('<think>') && !text.includes('</think>')
const hasAnalysisChannel = text.includes('<|channel|>analysis<|message|>') && !text.includes('<|start|>assistant<|channel|>final<|message|>') const hasAnalysisChannel =
text.includes('<|channel|>analysis<|message|>') &&
if (hasThinkTag || hasAnalysisChannel) return { reasoningSegment: text, textSegment: '' } !text.includes('<|start|>assistant<|channel|>final<|message|>')
if (hasThinkTag || hasAnalysisChannel)
return { reasoningSegment: text, textSegment: '' }
// Check for completed think tag format // Check for completed think tag format
const thinkMatch = text.match(/<think>([\s\S]*?)<\/think>/) const thinkMatch = text.match(/<think>([\s\S]*?)<\/think>/)
@ -187,7 +191,9 @@ export const ThreadContent = memo(
} }
// Check for completed analysis channel format // Check for completed analysis channel format
const analysisMatch = text.match(/<\|channel\|>analysis<\|message\|>([\s\S]*?)<\|start\|>assistant<\|channel\|>final<\|message\|>/) const analysisMatch = text.match(
/<\|channel\|>analysis<\|message\|>([\s\S]*?)<\|start\|>assistant<\|channel\|>final<\|message\|>/
)
if (analysisMatch?.index !== undefined) { if (analysisMatch?.index !== undefined) {
const splitIndex = analysisMatch.index + analysisMatch[0].length const splitIndex = analysisMatch.index + analysisMatch[0].length
return { return {
@ -213,7 +219,36 @@ export const ThreadContent = memo(
} }
if (toSendMessage) { if (toSendMessage) {
deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '') deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '')
sendMessage(toSendMessage.content?.[0]?.text?.value || '') // Extract text content and any attachments
const textContent =
toSendMessage.content?.find((c) => c.type === 'text')?.text?.value ||
''
const attachments = toSendMessage.content
?.filter((c) => (c.type === 'image_url' && c.image_url?.url) || false)
.map((c) => {
if (c.type === 'image_url' && c.image_url?.url) {
const url = c.image_url.url
const [mimeType, base64] = url
.replace('data:', '')
.split(';base64,')
return {
name: 'image', // We don't have the original filename
type: mimeType,
size: 0, // We don't have the original size
base64: base64,
dataUrl: url,
}
}
return null
})
.filter(Boolean) as Array<{
name: string
type: string
size: number
base64: string
dataUrl: string
}>
sendMessage(textContent, true, attachments)
} }
}, [deleteMessage, getMessages, item, sendMessage]) }, [deleteMessage, getMessages, item, sendMessage])
@ -255,22 +290,68 @@ export const ThreadContent = memo(
return ( return (
<Fragment> <Fragment>
{item.content?.[0]?.text && item.role === 'user' && ( {item.role === 'user' && (
<div className="w-full"> <div className="w-full">
<div className="flex justify-end w-full h-full text-start break-words whitespace-normal"> {/* Render attachments above the message bubble */}
<div className="bg-main-view-fg/4 relative text-main-view-fg p-2 rounded-md inline-block max-w-[80%] "> {item.content?.some(
<div className="select-text"> (c) => (c.type === 'image_url' && c.image_url?.url) || false
<RenderMarkdown ) && (
content={item.content?.[0].text.value} <div className="flex justify-end w-full mb-2">
components={linkComponents} <div className="flex flex-wrap gap-2 max-w-[80%] justify-end">
isUser {item.content
/> ?.filter(
(c) =>
(c.type === 'image_url' && c.image_url?.url) || false
)
.map((contentPart, index) => {
// Handle images
if (
contentPart.type === 'image_url' &&
contentPart.image_url?.url
) {
return (
<div key={index} className="relative">
<img
src={contentPart.image_url.url}
alt="Uploaded attachment"
className="size-40 rounded-md object-cover border border-main-view-fg/10"
/>
</div>
)
}
return null
})}
</div> </div>
</div> </div>
</div> )}
{/* Render text content in the message bubble */}
{item.content?.some((c) => c.type === 'text' && c.text?.value) && (
<div className="flex justify-end w-full h-full text-start break-words whitespace-normal">
<div className="bg-main-view-fg/4 relative text-main-view-fg p-2 rounded-md inline-block max-w-[80%] ">
<div className="select-text">
{item.content
?.filter((c) => c.type === 'text' && c.text?.value)
.map((contentPart, index) => (
<div key={index}>
<RenderMarkdown
content={contentPart.text!.value}
components={linkComponents}
isUser
/>
</div>
))}
</div>
</div>
</div>
)}
<div className="flex items-center justify-end gap-2 text-main-view-fg/60 text-xs mt-2"> <div className="flex items-center justify-end gap-2 text-main-view-fg/60 text-xs mt-2">
<EditDialog <EditDialog
message={item.content?.[0]?.text.value} message={
item.content?.find((c) => c.type === 'text')?.text?.value ||
''
}
setMessage={(message) => { setMessage={(message) => {
if (item.updateMessage) { if (item.updateMessage) {
item.updateMessage(item, message) item.updateMessage(item, message)

View File

@ -73,6 +73,11 @@ vi.mock('@/services/mcp', () => ({
vi.mock('@/services/models', () => ({ vi.mock('@/services/models', () => ({
stopAllModels: vi.fn(), stopAllModels: vi.fn(),
checkMmprojExists: vi.fn(() => Promise.resolve(true)),
}))
vi.mock('../MovingBorder', () => ({
MovingBorder: ({ children }: { children: React.ReactNode }) => <div data-testid="moving-border">{children}</div>,
})) }))
describe('ChatInput', () => { describe('ChatInput', () => {
@ -231,7 +236,7 @@ describe('ChatInput', () => {
const sendButton = document.querySelector('[data-test-id="send-message-button"]') const sendButton = document.querySelector('[data-test-id="send-message-button"]')
await user.click(sendButton) await user.click(sendButton)
expect(mockSendMessage).toHaveBeenCalledWith('Hello world') expect(mockSendMessage).toHaveBeenCalledWith('Hello world', true, undefined)
}) })
it('sends message when Enter key is pressed', async () => { it('sends message when Enter key is pressed', async () => {
@ -248,7 +253,7 @@ describe('ChatInput', () => {
const textarea = screen.getByRole('textbox') const textarea = screen.getByRole('textbox')
await user.type(textarea, '{Enter}') await user.type(textarea, '{Enter}')
expect(mockSendMessage).toHaveBeenCalledWith('Hello world') expect(mockSendMessage).toHaveBeenCalledWith('Hello world', true, undefined)
}) })
it('does not send message when Shift+Enter is pressed', async () => { it('does not send message when Shift+Enter is pressed', async () => {
@ -343,9 +348,12 @@ describe('ChatInput', () => {
const user = userEvent.setup() const user = userEvent.setup()
renderWithRouter() renderWithRouter()
// File upload is rendered as hidden input element // Wait for async effects to complete (mmproj check)
const fileInput = document.querySelector('input[type="file"]') await waitFor(() => {
expect(fileInput).toBeInTheDocument() // File upload is rendered as hidden input element
const fileInput = document.querySelector('input[type="file"]')
expect(fileInput).toBeInTheDocument()
})
}) })
it('disables input when streaming', () => { it('disables input when streaming', () => {
@ -361,7 +369,7 @@ describe('ChatInput', () => {
renderWithRouter() renderWithRouter()
}) })
const textarea = screen.getByRole('textbox') const textarea = screen.getByTestId('chat-input')
expect(textarea).toBeDisabled() expect(textarea).toBeDisabled()
}) })
@ -378,4 +386,28 @@ describe('ChatInput', () => {
expect(toolsIcon).toBeInTheDocument() expect(toolsIcon).toBeInTheDocument()
}) })
}) })
it('uses selectedProvider for provider checks', () => {
// Test that the component correctly uses selectedProvider instead of selectedModel.provider
vi.mocked(useModelProvider).mockReturnValue({
selectedModel: {
id: 'test-model',
capabilities: ['vision'],
},
providers: [],
getModelBy: vi.fn(),
selectModelProvider: vi.fn(),
selectedProvider: 'llamacpp',
setProviders: vi.fn(),
getProviderByName: vi.fn(),
updateProvider: vi.fn(),
addProvider: vi.fn(),
deleteProvider: vi.fn(),
deleteModel: vi.fn(),
deletedModels: [],
})
// This test ensures the component renders without errors when using selectedProvider
expect(() => renderWithRouter()).not.toThrow()
})
}) })

View File

@ -63,6 +63,7 @@ export default function AddEditAssistant({
const emojiPickerRef = useRef<HTMLDivElement>(null) const emojiPickerRef = useRef<HTMLDivElement>(null)
const emojiPickerTriggerRef = useRef<HTMLDivElement>(null) const emojiPickerTriggerRef = useRef<HTMLDivElement>(null)
const [nameError, setNameError] = useState<string | null>(null) const [nameError, setNameError] = useState<string | null>(null)
const [toolSteps, setToolSteps] = useState(20)
// Handle click outside emoji picker or trigger // Handle click outside emoji picker or trigger
useEffect(() => { useEffect(() => {
@ -94,6 +95,8 @@ export default function AddEditAssistant({
setDescription(initialData.description) setDescription(initialData.description)
setInstructions(initialData.instructions) setInstructions(initialData.instructions)
setShowEmojiPicker(false) setShowEmojiPicker(false)
setToolSteps(initialData.tool_steps ?? 20)
// Convert parameters object to arrays of keys and values // Convert parameters object to arrays of keys and values
const keys = Object.keys(initialData.parameters || {}) const keys = Object.keys(initialData.parameters || {})
const values = Object.values(initialData.parameters || {}) const values = Object.values(initialData.parameters || {})
@ -125,6 +128,7 @@ export default function AddEditAssistant({
setParamsTypes(['string']) setParamsTypes(['string'])
setNameError(null) setNameError(null)
setShowEmojiPicker(false) setShowEmojiPicker(false)
setToolSteps(20)
} }
const handleParameterChange = ( const handleParameterChange = (
@ -221,6 +225,7 @@ export default function AddEditAssistant({
description, description,
instructions, instructions,
parameters: parameters || {}, parameters: parameters || {},
tool_steps: toolSteps,
} }
onSave(assistant) onSave(assistant)
onOpenChange(false) onOpenChange(false)
@ -329,6 +334,29 @@ export default function AddEditAssistant({
/> />
</div> </div>
<div className="space-y-2 my-4 mt-6">
<div className="flex items-center justify-between">
<label className="text-sm">{t('common:settings')}</label>
</div>
<div className="flex justify-between items-center gap-2">
<div className="w-full">
<p className="text-sm">{t('assistants:maxToolSteps')}</p>
</div>
<Input
value={toolSteps}
type="number"
min={0}
onChange={(e) => {
const newSteps = e.target.value
const stepNumber = Number(newSteps)
setToolSteps(isNaN(stepNumber) ? 20 : stepNumber)
}}
placeholder="20"
className="w-18 text-right"
/>
</div>
</div>
<div className="space-y-2 my-4"> <div className="space-y-2 my-4">
<div className="flex items-center justify-between"> <div className="flex items-center justify-between">
<label className="text-sm"> <label className="text-sm">

View File

@ -8,7 +8,13 @@ import {
} from '@/components/ui/dialog' } from '@/components/ui/dialog'
import { Button } from '@/components/ui/button' import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input' import { Input } from '@/components/ui/input'
import { IconPlus, IconTrash, IconGripVertical } from '@tabler/icons-react' import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group'
import {
IconPlus,
IconTrash,
IconGripVertical,
IconCodeDots,
} from '@tabler/icons-react'
import { MCPServerConfig } from '@/hooks/useMCPServers' import { MCPServerConfig } from '@/hooks/useMCPServers'
import { useTranslation } from '@/i18n/react-i18next-compat' import { useTranslation } from '@/i18n/react-i18next-compat'
import { import {
@ -27,6 +33,8 @@ import {
} from '@dnd-kit/sortable' } from '@dnd-kit/sortable'
import { CSS } from '@dnd-kit/utilities' import { CSS } from '@dnd-kit/utilities'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
import CodeEditor from '@uiw/react-textarea-code-editor'
import '@uiw/react-textarea-code-editor/dist.css'
// Sortable argument item component // Sortable argument item component
function SortableArgItem({ function SortableArgItem({
@ -114,13 +122,34 @@ export default function AddEditMCPServer({
const [args, setArgs] = useState<string[]>(['']) const [args, setArgs] = useState<string[]>([''])
const [envKeys, setEnvKeys] = useState<string[]>(['']) const [envKeys, setEnvKeys] = useState<string[]>([''])
const [envValues, setEnvValues] = useState<string[]>(['']) const [envValues, setEnvValues] = useState<string[]>([''])
const [transportType, setTransportType] = useState<'stdio' | 'http' | 'sse'>(
'stdio'
)
const [url, setUrl] = useState('')
const [headerKeys, setHeaderKeys] = useState<string[]>([''])
const [headerValues, setHeaderValues] = useState<string[]>([''])
const [timeout, setTimeout] = useState('')
const [isToggled, setIsToggled] = useState(false)
const [jsonContent, setJsonContent] = useState('')
const [error, setError] = useState<string | null>(null)
// Reset form when modal opens/closes or editing key changes // Reset form when modal opens/closes or editing key changes
useEffect(() => { useEffect(() => {
if (open && editingKey && initialData) { if (open && editingKey && initialData) {
setServerName(editingKey) setServerName(editingKey)
setCommand(initialData.command) setCommand(initialData.command || '')
setUrl(initialData.url || '')
setTimeout(initialData.timeout ? initialData.timeout.toString() : '')
setArgs(initialData.args?.length > 0 ? initialData.args : ['']) setArgs(initialData.args?.length > 0 ? initialData.args : [''])
setTransportType(initialData?.type || 'stdio')
// Initialize JSON content for toggle mode
try {
const jsonData = { [editingKey]: initialData }
setJsonContent(JSON.stringify(jsonData, null, 2))
} catch {
setJsonContent('')
}
if (initialData.env) { if (initialData.env) {
// Convert env object to arrays of keys and values // Convert env object to arrays of keys and values
@ -130,6 +159,17 @@ export default function AddEditMCPServer({
setEnvKeys(keys.length > 0 ? keys : ['']) setEnvKeys(keys.length > 0 ? keys : [''])
setEnvValues(values.length > 0 ? values : ['']) setEnvValues(values.length > 0 ? values : [''])
} }
if (initialData.headers) {
// Convert headers object to arrays of keys and values
const headerKeysList = Object.keys(initialData.headers)
const headerValuesList = headerKeysList.map(
(key) => initialData.headers![key]
)
setHeaderKeys(headerKeysList.length > 0 ? headerKeysList : [''])
setHeaderValues(headerValuesList.length > 0 ? headerValuesList : [''])
}
} else if (open) { } else if (open) {
// Add mode - reset form // Add mode - reset form
resetForm() resetForm()
@ -139,9 +179,17 @@ export default function AddEditMCPServer({
const resetForm = () => { const resetForm = () => {
setServerName('') setServerName('')
setCommand('') setCommand('')
setUrl('')
setTimeout('')
setArgs(['']) setArgs([''])
setEnvKeys(['']) setEnvKeys([''])
setEnvValues(['']) setEnvValues([''])
setHeaderKeys([''])
setHeaderValues([''])
setTransportType('stdio')
setIsToggled(false)
setJsonContent('')
setError(null)
} }
const handleAddArg = () => { const handleAddArg = () => {
@ -201,7 +249,57 @@ export default function AddEditMCPServer({
setEnvValues(newValues) setEnvValues(newValues)
} }
const handleAddHeader = () => {
setHeaderKeys([...headerKeys, ''])
setHeaderValues([...headerValues, ''])
}
const handleRemoveHeader = (index: number) => {
const newKeys = [...headerKeys]
const newValues = [...headerValues]
newKeys.splice(index, 1)
newValues.splice(index, 1)
setHeaderKeys(newKeys.length > 0 ? newKeys : [''])
setHeaderValues(newValues.length > 0 ? newValues : [''])
}
const handleHeaderKeyChange = (index: number, value: string) => {
const newKeys = [...headerKeys]
newKeys[index] = value
setHeaderKeys(newKeys)
}
const handleHeaderValueChange = (index: number, value: string) => {
const newValues = [...headerValues]
newValues[index] = value
setHeaderValues(newValues)
}
const handleSave = () => { const handleSave = () => {
// Handle JSON mode
if (isToggled) {
try {
const parsedData = JSON.parse(jsonContent)
// Validate that it's an object with server configurations
if (typeof parsedData !== 'object' || parsedData === null) {
setError(t('mcp-servers:editJson.errorFormat'))
return
}
// For each server in the JSON, call onSave
Object.entries(parsedData).forEach(([serverName, config]) => {
onSave(serverName.trim(), config as MCPServerConfig)
})
onOpenChange(false)
resetForm()
setError(null)
return
} catch {
setError(t('mcp-servers:editJson.errorFormat'))
return
}
}
// Handle form mode
// Convert env arrays to object // Convert env arrays to object
const envObj: Record<string, string> = {} const envObj: Record<string, string> = {}
envKeys.forEach((key, index) => { envKeys.forEach((key, index) => {
@ -211,13 +309,28 @@ export default function AddEditMCPServer({
} }
}) })
// Convert headers arrays to object
const headersObj: Record<string, string> = {}
headerKeys.forEach((key, index) => {
const keyName = key.trim()
if (keyName !== '') {
headersObj[keyName] = headerValues[index]?.trim() || ''
}
})
// Filter out empty args // Filter out empty args
const filteredArgs = args.map((arg) => arg.trim()).filter((arg) => arg) const filteredArgs = args.map((arg) => arg.trim()).filter((arg) => arg)
const config: MCPServerConfig = { const config: MCPServerConfig = {
command: command.trim(), command: transportType === 'stdio' ? command.trim() : '',
args: filteredArgs, args: transportType === 'stdio' ? filteredArgs : [],
env: envObj, env: transportType === 'stdio' ? envObj : {},
type: transportType,
...(transportType !== 'stdio' && {
url: url.trim(),
headers: Object.keys(headersObj).length > 0 ? headersObj : undefined,
timeout: timeout.trim() !== '' ? parseInt(timeout) : undefined,
}),
} }
if (serverName.trim() !== '') { if (serverName.trim() !== '') {
@ -229,122 +342,303 @@ export default function AddEditMCPServer({
return ( return (
<Dialog open={open} onOpenChange={onOpenChange}> <Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent> <DialogContent showCloseButton={false}>
<DialogHeader> <DialogHeader>
<DialogTitle> <DialogTitle className="flex items-center justify-between">
{editingKey <span>
? t('mcp-servers:editServer') {editingKey
: t('mcp-servers:addServer')} ? t('mcp-servers:editServer')
: t('mcp-servers:addServer')}
</span>
<div
className={cn(
'size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out',
isToggled && 'bg-main-view-fg/10 text-accent'
)}
title="Add server by JSON"
onClick={() => setIsToggled(!isToggled)}
>
<IconCodeDots className="h-5 w-5 cursor-pointer transition-colors duration-200" />
</div>
</DialogTitle> </DialogTitle>
</DialogHeader> </DialogHeader>
<div className="space-y-4"> {isToggled ? (
<div className="space-y-2"> <div className="space-y-4">
<label className="text-sm mb-2 inline-block"> <div className="space-y-2">
{t('mcp-servers:serverName')} <label className="text-sm mb-2 inline-block">
</label> {t('mcp-servers:editJson.placeholder')}
<Input </label>
value={serverName} <div className="border border-main-view-fg/10 rounded-md overflow-hidden">
onChange={(e) => setServerName(e.target.value)} <CodeEditor
placeholder={t('mcp-servers:enterServerName')} value={jsonContent}
autoFocus language="json"
/> placeholder={`{
</div> "serverName": {
"command": "command",
<div className="space-y-2"> "args": ["arg1", "arg2"],
<label className="text-sm mb-2 inline-block"> "env": {
{t('mcp-servers:command')} "KEY": "value"
</label> }
<Input }
value={command} }`}
onChange={(e) => setCommand(e.target.value)} onChange={(e) => {
placeholder={t('mcp-servers:enterCommand')} setJsonContent(e.target.value)
/> setError(null)
</div> }}
onPaste={() => setError(null)}
<div className="space-y-2"> style={{
<div className="flex items-center justify-between"> fontFamily: 'ui-monospace',
<label className="text-sm">{t('mcp-servers:arguments')}</label> backgroundColor: 'transparent',
<div wordBreak: 'break-all',
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out" overflowWrap: 'anywhere',
onClick={handleAddArg} whiteSpace: 'pre-wrap',
> }}
<IconPlus size={18} className="text-main-view-fg/60" /> className="w-full !text-sm min-h-[300px]"
/>
</div> </div>
{error && <div className="text-destructive text-sm">{error}</div>}
</div>
</div>
) : (
<div className="space-y-4">
<div className="space-y-2">
<label className="text-sm mb-2 inline-block">
{t('mcp-servers:serverName')}
</label>
<Input
value={serverName}
onChange={(e) => setServerName(e.target.value)}
placeholder={t('mcp-servers:enterServerName')}
autoFocus
/>
</div> </div>
<DndContext <div className="space-y-2">
sensors={sensors} <label className="text-sm mb-2 inline-block">
collisionDetection={closestCenter} Transport Type
onDragEnd={(event) => { </label>
const { active, over } = event <RadioGroup
if (active.id !== over?.id) { value={transportType}
const oldIndex = parseInt(active.id.toString()) onValueChange={(value) =>
const newIndex = parseInt(over?.id.toString() || '0') setTransportType(value as 'http' | 'sse')
handleReorderArgs(oldIndex, newIndex)
} }
}} className="flex gap-6"
>
<SortableContext
items={args.map((_, index) => index)}
strategy={verticalListSortingStrategy}
> >
{args.map((arg, index) => ( <div className="flex items-center space-x-2">
<SortableArgItem <RadioGroupItem value="stdio" id="stdio" />
key={index} <label
id={index} htmlFor="stdio"
value={arg} className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
onChange={(value) => handleArgChange(index, value)} >
onRemove={() => handleRemoveArg(index)} STDIO
canRemove={args.length > 1} </label>
placeholder={t('mcp-servers:argument', { </div>
index: index + 1, <div className="flex items-center space-x-2">
})} <RadioGroupItem value="http" id="http" />
/> <label
))} htmlFor="http"
</SortableContext> className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
</DndContext> >
</div> HTTP
</label>
<div className="space-y-2"> </div>
<div className="flex items-center justify-between"> <div className="flex items-center space-x-2">
<label className="text-sm">{t('mcp-servers:envVars')}</label> <RadioGroupItem value="sse" id="sse" />
<div <label
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out" htmlFor="sse"
onClick={handleAddEnv} className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
> >
<IconPlus size={18} className="text-main-view-fg/60" /> SSE
</div> </label>
</div>
</RadioGroup>
</div> </div>
{envKeys.map((key, index) => ( {transportType === 'stdio' ? (
<div key={`env-${index}`} className="flex items-center gap-2"> <div className="space-y-2">
<label className="text-sm mb-2 inline-block">
{t('mcp-servers:command')}
</label>
<Input <Input
value={key} value={command}
onChange={(e) => handleEnvKeyChange(index, e.target.value)} onChange={(e) => setCommand(e.target.value)}
placeholder={t('mcp-servers:key')} placeholder={t('mcp-servers:enterCommand')}
className="flex-1"
/> />
</div>
) : (
<div className="space-y-2">
<label className="text-sm mb-2 inline-block">URL</label>
<Input <Input
value={envValues[index] || ''} value={url}
onChange={(e) => handleEnvValueChange(index, e.target.value)} onChange={(e) => setUrl(e.target.value)}
placeholder={t('mcp-servers:value')} placeholder="Enter URL"
className="flex-1"
/> />
{envKeys.length > 1 && ( </div>
)}
{transportType === 'stdio' && (
<div className="space-y-2">
<div className="flex items-center justify-between">
<label className="text-sm">
{t('mcp-servers:arguments')}
</label>
<div <div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out" className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={() => handleRemoveEnv(index)} onClick={handleAddArg}
> >
<IconTrash size={18} className="text-destructive" /> <IconPlus size={18} className="text-main-view-fg/60" />
</div> </div>
)} </div>
<DndContext
sensors={sensors}
collisionDetection={closestCenter}
onDragEnd={(event) => {
const { active, over } = event
if (active.id !== over?.id) {
const oldIndex = parseInt(active.id.toString())
const newIndex = parseInt(over?.id.toString() || '0')
handleReorderArgs(oldIndex, newIndex)
}
}}
>
<SortableContext
items={args.map((_, index) => index)}
strategy={verticalListSortingStrategy}
>
{args.map((arg, index) => (
<SortableArgItem
key={index}
id={index}
value={arg}
onChange={(value) => handleArgChange(index, value)}
onRemove={() => handleRemoveArg(index)}
canRemove={args.length > 1}
placeholder={t('mcp-servers:argument', {
index: index + 1,
})}
/>
))}
</SortableContext>
</DndContext>
</div> </div>
))} )}
{transportType === 'stdio' && (
<div className="space-y-2">
<div className="flex items-center justify-between">
<label className="text-sm">{t('mcp-servers:envVars')}</label>
<div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={handleAddEnv}
>
<IconPlus size={18} className="text-main-view-fg/60" />
</div>
</div>
{envKeys.map((key, index) => (
<div key={`env-${index}`} className="flex items-center gap-2">
<Input
value={key}
onChange={(e) =>
handleEnvKeyChange(index, e.target.value)
}
placeholder={t('mcp-servers:key')}
className="flex-1"
/>
<Input
value={envValues[index] || ''}
onChange={(e) =>
handleEnvValueChange(index, e.target.value)
}
placeholder={t('mcp-servers:value')}
className="flex-1"
/>
{envKeys.length > 1 && (
<div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={() => handleRemoveEnv(index)}
>
<IconTrash size={18} className="text-destructive" />
</div>
)}
</div>
))}
</div>
)}
{(transportType === 'http' || transportType === 'sse') && (
<>
<div className="space-y-2">
<div className="flex items-center justify-between">
<label className="text-sm">Headers</label>
<div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={handleAddHeader}
>
<IconPlus size={18} className="text-main-view-fg/60" />
</div>
</div>
{headerKeys.map((key, index) => (
<div
key={`header-${index}`}
className="flex items-center gap-2"
>
<Input
value={key}
onChange={(e) =>
handleHeaderKeyChange(index, e.target.value)
}
placeholder="Header name"
className="flex-1"
/>
<Input
value={headerValues[index] || ''}
onChange={(e) =>
handleHeaderValueChange(index, e.target.value)
}
placeholder="Header value"
className="flex-1"
/>
{headerKeys.length > 1 && (
<div
className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out"
onClick={() => handleRemoveHeader(index)}
>
<IconTrash size={18} className="text-destructive" />
</div>
)}
</div>
))}
</div>
<div className="space-y-2">
<label className="text-sm mb-2 inline-block">
Timeout (seconds)
</label>
<Input
value={timeout}
onChange={(e) => setTimeout(e.target.value)}
placeholder="Enter timeout in seconds"
type="number"
/>
</div>
</>
)}
</div> </div>
</div> )}
<DialogFooter> <DialogFooter>
<Button onClick={handleSave}>{t('mcp-servers:save')}</Button> <Button variant="link" onClick={() => onOpenChange(false)}>
{t('common:cancel')}
</Button>
<Button
onClick={handleSave}
disabled={!isToggled && serverName.trim() === ''}
>
{t('mcp-servers:save')}
</Button>
</DialogFooter> </DialogFooter>
</DialogContent> </DialogContent>
</Dialog> </Dialog>

View File

@ -18,6 +18,7 @@ import { IconTrash } from '@tabler/icons-react'
import { useState, useEffect } from 'react' import { useState, useEffect } from 'react'
import { toast } from 'sonner' import { toast } from 'sonner'
import { useTranslation } from '@/i18n/react-i18next-compat' import { useTranslation } from '@/i18n/react-i18next-compat'
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
type DialogDeleteModelProps = { type DialogDeleteModelProps = {
provider: ModelProvider provider: ModelProvider
@ -31,8 +32,12 @@ export const DialogDeleteModel = ({
const { t } = useTranslation() const { t } = useTranslation()
const [selectedModelId, setSelectedModelId] = useState<string>('') const [selectedModelId, setSelectedModelId] = useState<string>('')
const { setProviders, deleteModel: deleteModelCache } = useModelProvider() const { setProviders, deleteModel: deleteModelCache } = useModelProvider()
const { removeFavorite } = useFavoriteModel()
const removeModel = async () => { const removeModel = async () => {
// Remove model from favorites if it exists
removeFavorite(selectedModelId)
deleteModelCache(selectedModelId) deleteModelCache(selectedModelId)
deleteModel(selectedModelId).then(() => { deleteModel(selectedModelId).then(() => {
getProviders().then((providers) => { getProviders().then((providers) => {

View File

@ -18,6 +18,7 @@ import { useRouter } from '@tanstack/react-router'
import { route } from '@/constants/routes' import { route } from '@/constants/routes'
import { useTranslation } from '@/i18n/react-i18next-compat' import { useTranslation } from '@/i18n/react-i18next-compat'
import { predefinedProviders } from '@/consts/providers' import { predefinedProviders } from '@/consts/providers'
import { useFavoriteModel } from '@/hooks/useFavoriteModel'
type Props = { type Props = {
provider?: ProviderObject provider?: ProviderObject
@ -25,6 +26,7 @@ type Props = {
const DeleteProvider = ({ provider }: Props) => { const DeleteProvider = ({ provider }: Props) => {
const { t } = useTranslation() const { t } = useTranslation()
const { deleteProvider, providers } = useModelProvider() const { deleteProvider, providers } = useModelProvider()
const { favoriteModels, removeFavorite } = useFavoriteModel()
const router = useRouter() const router = useRouter()
if ( if (
!provider || !provider ||
@ -34,6 +36,14 @@ const DeleteProvider = ({ provider }: Props) => {
return null return null
const removeProvider = async () => { const removeProvider = async () => {
// Remove favorite models that belong to this provider
const providerModelIds = provider.models.map((model) => model.id)
favoriteModels.forEach((favoriteModel) => {
if (providerModelIds.includes(favoriteModel.id)) {
removeFavorite(favoriteModel.id)
}
})
deleteProvider(provider.provider) deleteProvider(provider.provider)
toast.success(t('providers:deleteProvider.title'), { toast.success(t('providers:deleteProvider.title'), {
id: `delete-provider-${provider.provider}`, id: `delete-provider-${provider.provider}`,

View File

@ -43,19 +43,9 @@ export default function EditJsonMCPserver({
} }
}, [open, initialData, t]) }, [open, initialData, t])
const handlePaste = (e: React.ClipboardEvent) => { const handlePaste = () => {
const pastedText = e.clipboardData.getData('text') // Clear any existing errors when pasting
try { setError(null)
const parsedJson = JSON.parse(pastedText)
const prettifiedJson = JSON.stringify(parsedJson, null, 2)
e.preventDefault()
setJsonContent(prettifiedJson)
setError(null)
} catch (error) {
e.preventDefault()
setError(t('mcp-servers:editJson.errorPaste'))
console.error('Paste error:', error)
}
} }
const handleSave = () => { const handleSave = () => {
@ -80,7 +70,18 @@ export default function EditJsonMCPserver({
</DialogTitle> </DialogTitle>
</DialogHeader> </DialogHeader>
<div className="space-y-2"> <div className="space-y-2">
<div className="border border-main-view-fg/10 rounded-md overflow-hidden"> <div className="border border-main-view-fg/10 rounded-md !overflow-hidden">
<style>{`
.w-tc-editor textarea {
word-break: break-all !important;
overflow-wrap: anywhere !important;
white-space: pre-wrap !important;
}
.w-tc-editor .token.string {
word-break: break-all !important;
overflow-wrap: anywhere !important;
}
`}</style>
<CodeEditor <CodeEditor
value={jsonContent} value={jsonContent}
language="json" language="json"
@ -90,8 +91,11 @@ export default function EditJsonMCPserver({
style={{ style={{
fontFamily: 'ui-monospace', fontFamily: 'ui-monospace',
backgroundColor: 'transparent', backgroundColor: 'transparent',
wordBreak: 'break-all',
overflowWrap: 'anywhere',
whiteSpace: 'pre-wrap',
}} }}
className="w-full !text-sm " className="w-full !text-sm overflow-hidden break-all"
/> />
</div> </div>
{error && <div className="text-destructive text-sm">{error}</div>} {error && <div className="text-destructive text-sm">{error}</div>}

View File

@ -1,253 +0,0 @@
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
DialogTrigger,
} from '@/components/ui/dialog'
import { Switch } from '@/components/ui/switch'
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from '@/components/ui/tooltip'
import { useModelProvider } from '@/hooks/useModelProvider'
import {
IconPencil,
IconEye,
IconTool,
// IconWorld,
// IconAtom,
IconCodeCircle2,
} from '@tabler/icons-react'
import { useState, useEffect } from 'react'
import { useTranslation } from '@/i18n/react-i18next-compat'
// No need to define our own interface, we'll use the existing Model type
type DialogEditModelProps = {
provider: ModelProvider
modelId?: string // Optional model ID to edit
}
export const DialogEditModel = ({
provider,
modelId,
}: DialogEditModelProps) => {
const { t } = useTranslation()
const { updateProvider } = useModelProvider()
const [selectedModelId, setSelectedModelId] = useState<string>('')
const [capabilities, setCapabilities] = useState<Record<string, boolean>>({
completion: false,
vision: false,
tools: false,
reasoning: false,
embeddings: false,
web_search: false,
})
// Initialize with the provided model ID or the first model if available
useEffect(() => {
if (modelId) {
setSelectedModelId(modelId)
} else if (provider.models && provider.models.length > 0) {
setSelectedModelId(provider.models[0].id)
}
}, [provider, modelId])
// Get the currently selected model
const selectedModel = provider.models.find(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(m: any) => m.id === selectedModelId
)
// Initialize capabilities from selected model
useEffect(() => {
if (selectedModel) {
const modelCapabilities = selectedModel.capabilities || []
setCapabilities({
completion: modelCapabilities.includes('completion'),
vision: modelCapabilities.includes('vision'),
tools: modelCapabilities.includes('tools'),
embeddings: modelCapabilities.includes('embeddings'),
web_search: modelCapabilities.includes('web_search'),
reasoning: modelCapabilities.includes('reasoning'),
})
}
}, [selectedModel])
// Track if capabilities were updated by user action
const [capabilitiesUpdated, setCapabilitiesUpdated] = useState(false)
// Update model capabilities - only update local state
const handleCapabilityChange = (capability: string, enabled: boolean) => {
setCapabilities((prev) => ({
...prev,
[capability]: enabled,
}))
// Mark that capabilities were updated by user action
setCapabilitiesUpdated(true)
}
// Use effect to update the provider when capabilities are explicitly changed by user
useEffect(() => {
// Only run if capabilities were updated by user action and we have a selected model
if (!capabilitiesUpdated || !selectedModel) return
// Reset the flag
setCapabilitiesUpdated(false)
// Create updated capabilities array from the state
const updatedCapabilities = Object.entries(capabilities)
.filter(([, isEnabled]) => isEnabled)
.map(([capName]) => capName)
// Find and update the model in the provider
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const updatedModels = provider.models.map((m: any) => {
if (m.id === selectedModelId) {
return {
...m,
capabilities: updatedCapabilities,
}
}
return m
})
// Update the provider with the updated models
updateProvider(provider.provider, {
...provider,
models: updatedModels,
})
}, [
capabilitiesUpdated,
capabilities,
provider,
selectedModel,
selectedModelId,
updateProvider,
])
if (!selectedModel) {
return null
}
return (
<Dialog>
<DialogTrigger asChild>
<div className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out">
<IconPencil size={18} className="text-main-view-fg/50" />
</div>
</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle className="line-clamp-1" title={selectedModel.id}>
{t('providers:editModel.title', { modelId: selectedModel.id })}
</DialogTitle>
<DialogDescription>
{t('providers:editModel.description')}
</DialogDescription>
</DialogHeader>
<div className="py-1">
<h3 className="text-sm font-medium mb-3">
{t('providers:editModel.capabilities')}
</h3>
<div className="space-y-4">
<div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<IconTool className="size-4 text-main-view-fg/70" />
<span className="text-sm">
{t('providers:editModel.tools')}
</span>
</div>
<Switch
id="tools-capability"
checked={capabilities.tools}
onCheckedChange={(checked) =>
handleCapabilityChange('tools', checked)
}
/>
</div>
<div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<IconEye className="size-4 text-main-view-fg/70" />
<span className="text-sm">
{t('providers:editModel.vision')}
</span>
</div>
<Tooltip>
<TooltipTrigger>
<Switch
id="vision-capability"
checked={capabilities.vision}
disabled={true}
onCheckedChange={(checked) =>
handleCapabilityChange('vision', checked)
}
/>
</TooltipTrigger>
<TooltipContent>
{t('providers:editModel.notAvailable')}
</TooltipContent>
</Tooltip>
</div>
<div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<IconCodeCircle2 className="size-4 text-main-view-fg/70" />
<span className="text-sm">
{t('providers:editModel.embeddings')}
</span>
</div>
<Tooltip>
<TooltipTrigger>
<Switch
id="embedding-capability"
disabled={true}
checked={capabilities.embeddings}
onCheckedChange={(checked) =>
handleCapabilityChange('embeddings', checked)
}
/>
</TooltipTrigger>
<TooltipContent>
{t('providers:editModel.notAvailable')}
</TooltipContent>
</Tooltip>
</div>
{/* <div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<IconWorld className="size-4 text-main-view-fg/70" />
<span className="text-sm">Web Search</span>
</div>
<Switch
id="web_search-capability"
checked={capabilities.web_search}
onCheckedChange={(checked) =>
handleCapabilityChange('web_search', checked)
}
/>
</div> */}
{/* <div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<IconAtom className="size-4 text-main-view-fg/70" />
<span className="text-sm">{t('reasoning')}</span>
</div>
<Switch
id="reasoning-capability"
checked={capabilities.reasoning}
onCheckedChange={(checked) =>
handleCapabilityChange('reasoning', checked)
}
/>
</div> */}
</div>
</div>
</DialogContent>
</Dialog>
)
}

View File

@ -0,0 +1,123 @@
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from '@/components/ui/dialog'
import { Button } from '@/components/ui/button'
import { AlertTriangle, ChevronDown, ChevronRight } from 'lucide-react'
import { IconCopy, IconCopyCheck } from '@tabler/icons-react'
import { useTranslation } from '@/i18n/react-i18next-compat'
import { toast } from 'sonner'
import { useState } from 'react'
import { useAppState } from '@/hooks/useAppState'
/**
 * Modal dialog surfacing the current application error held in app state
 * (`errorMessage` from useAppState). Renders a collapsible "Details" pane
 * that auto-scrolls to the bottom of the error text, and a copy button that
 * writes the raw message to the clipboard. Closing the dialog clears the
 * error from app state.
 */
export default function ErrorDialog() {
  const { t } = useTranslation()
  const { errorMessage, setErrorMessage } = useAppState()
  // Briefly true after a copy attempt so the button can show its "copied" state.
  const [isCopying, setIsCopying] = useState(false)
  const [isDetailExpanded, setIsDetailExpanded] = useState(true)

  // Copy the raw error message to the clipboard and toast the outcome.
  const handleCopy = async () => {
    setIsCopying(true)
    try {
      await navigator.clipboard.writeText(errorMessage?.message ?? '')
      toast.success('Copy successful', {
        id: 'copy-model',
        description: 'Model load error information copied to clipboard',
      })
    } catch {
      toast.error('Failed to copy', {
        id: 'copy-model-error',
        description: 'Failed to copy error information to clipboard',
      })
    } finally {
      // Keep the "copied" indicator visible for 2s before resetting.
      setTimeout(() => setIsCopying(false), 2000)
    }
  }

  // open=false (dialog dismissed) clears the error; open=true is a no-op.
  const handleDialogOpen = (open: boolean) => {
    setErrorMessage(open ? errorMessage : undefined)
  }

  return (
    <Dialog open={!!errorMessage} onOpenChange={handleDialogOpen}>
      <DialogContent showCloseButton={false}>
        <DialogHeader>
          <div className="flex items-start gap-3">
            <div className="shrink-0">
              <AlertTriangle className="size-4 text-destructive" />
            </div>
            <div>
              <DialogTitle>{t('common:error')}</DialogTitle>
              <DialogDescription className="mt-1 text-main-view-fg/70">
                {errorMessage?.title ?? 'Something went wrong'}
              </DialogDescription>
            </div>
          </div>
        </DialogHeader>
        <div className="bg-main-view-fg/2 p-2 border border-main-view-fg/5 rounded-lg space-y-2">
          <div>
            <button
              onClick={() => setIsDetailExpanded(!isDetailExpanded)}
              className="flex items-center gap-1 text-sm text-main-view-fg/60 hover:text-main-view-fg/80 transition-colors cursor-pointer"
            >
              {isDetailExpanded ? (
                <ChevronDown className="size-3" />
              ) : (
                <ChevronRight className="size-3" />
              )}
              Details
            </button>
            {isDetailExpanded && (
              <div
                className="mt-2 text-sm text-main-view-fg/70 leading-relaxed max-h-[150px] overflow-y-auto break-all bg-main-view-fg/10 p-2 rounded border border-main-view-fg/5"
                ref={(el) => {
                  // Pin the scroll position to the end so the most recent
                  // (usually most relevant) part of the error is visible.
                  if (el) {
                    el.scrollTop = el.scrollHeight
                  }
                }}
              >
                {errorMessage?.message}
              </div>
            )}
          </div>
          <span className="text-sm text-main-view-fg/60">{errorMessage?.subtitle}</span>
        </div>
        {/* Fixed: `sm:justify-right` is not a valid Tailwind justify-content
            utility (the valid value is `justify-end`), so the footer was never
            right-aligned at the sm breakpoint. */}
        <DialogFooter className="flex flex-col gap-2 sm:flex-row sm:justify-end">
          <Button
            variant="link"
            onClick={() => handleDialogOpen(false)}
            className="flex-1 text-right sm:flex-none"
          >
            {t('common:cancel')}
          </Button>
          <Button
            variant="link"
            onClick={() => handleCopy()}
            disabled={isCopying}
            autoFocus
            className="flex-1 text-right sm:flex-none border border-main-view-fg/20 !px-2"
          >
            {isCopying ? (
              <>
                <IconCopyCheck className="text-accent" />
                {t('common:copied')}
              </>
            ) : (
              <>
                <IconCopy />
                {t('common:copy')}
              </>
            )}
          </Button>
        </DialogFooter>
      </DialogContent>
    </Dialog>
  )
}

View File

@ -0,0 +1,180 @@
import { renderHook, act } from '@testing-library/react'
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
import { SystemEvent } from '@/types/events'
// Mock functions
const mockGetTools = vi.fn()
const mockUpdateTools = vi.fn()
const mockListen = vi.fn()
const mockUnsubscribe = vi.fn()
// Mock the dependencies
vi.mock('@/services/mcp', () => ({
getTools: mockGetTools,
}))
vi.mock('../useAppState', () => ({
useAppState: () => ({
updateTools: mockUpdateTools,
}),
}))
vi.mock('@tauri-apps/api/event', () => ({
listen: mockListen,
}))
// Unit tests for the useTools hook: initial MCP tool fetch, MCP_UPDATE event
// subscription, error handling for both, and effect lifetime semantics.
// The hook module is dynamically imported inside each test so the vi.mock
// factories above are applied before the module is evaluated.
describe('useTools', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    // Default happy-path mocks: listener registration succeeds and the
    // initial tool list is empty.
    mockListen.mockResolvedValue(mockUnsubscribe)
    mockGetTools.mockResolvedValue([])
  })
  afterEach(() => {
    vi.restoreAllMocks()
  })
  it('should call getTools and updateTools on mount', async () => {
    const { useTools } = await import('../useTools')
    const mockTools = [
      { name: 'test-tool', description: 'A test tool' },
      { name: 'another-tool', description: 'Another test tool' },
    ]
    mockGetTools.mockResolvedValue(mockTools)
    renderHook(() => useTools())
    // Wait for async operations to complete
    await act(async () => {
      await new Promise(resolve => setTimeout(resolve, 0))
    })
    expect(mockGetTools).toHaveBeenCalledTimes(1)
    expect(mockUpdateTools).toHaveBeenCalledWith(mockTools)
  })
  it('should set up event listener for MCP_UPDATE', async () => {
    const { useTools } = await import('../useTools')
    renderHook(() => useTools())
    await act(async () => {
      await new Promise(resolve => setTimeout(resolve, 0))
    })
    expect(mockListen).toHaveBeenCalledWith(
      SystemEvent.MCP_UPDATE,
      expect.any(Function)
    )
  })
  it('should call setTools when MCP_UPDATE event is triggered', async () => {
    const { useTools } = await import('../useTools')
    const mockTools = [{ name: 'updated-tool', description: 'Updated tool' }]
    mockGetTools.mockResolvedValue(mockTools)
    // Capture the event callback the hook registers so the test can fire it.
    let eventCallback: () => void
    mockListen.mockImplementation((_event, callback) => {
      eventCallback = callback
      return Promise.resolve(mockUnsubscribe)
    })
    renderHook(() => useTools())
    // Wait for initial setup
    await act(async () => {
      await new Promise(resolve => setTimeout(resolve, 0))
    })
    // Clear the initial calls
    vi.clearAllMocks()
    mockGetTools.mockResolvedValue(mockTools)
    // Trigger the event
    await act(async () => {
      eventCallback()
      await new Promise(resolve => setTimeout(resolve, 0))
    })
    expect(mockGetTools).toHaveBeenCalledTimes(1)
    expect(mockUpdateTools).toHaveBeenCalledWith(mockTools)
  })
  it('should return unsubscribe function for cleanup', async () => {
    const { useTools } = await import('../useTools')
    const { unmount } = renderHook(() => useTools())
    await act(async () => {
      await new Promise(resolve => setTimeout(resolve, 0))
    })
    expect(mockListen).toHaveBeenCalled()
    // Unmount should call the unsubscribe function
    unmount()
    // NOTE(review): this test never asserts that mockUnsubscribe was actually
    // invoked after unmount — it only re-checks that the listener was
    // registered, so a leaked subscription would still pass. Consider adding
    // `expect(mockUnsubscribe).toHaveBeenCalled()` once the hook's cleanup
    // path is confirmed to call through to the resolved unsubscriber.
    expect(mockListen).toHaveBeenCalledWith(
      SystemEvent.MCP_UPDATE,
      expect.any(Function)
    )
  })
  it('should handle getTools errors gracefully', async () => {
    const { useTools } = await import('../useTools')
    // Silence the hook's console.error so the failure path doesn't pollute output.
    const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
    mockGetTools.mockRejectedValue(new Error('Failed to get tools'))
    renderHook(() => useTools())
    await act(async () => {
      // Give enough time for the promise to be handled
      await new Promise(resolve => setTimeout(resolve, 100))
    })
    expect(mockGetTools).toHaveBeenCalledTimes(1)
    // updateTools should not be called if getTools fails
    expect(mockUpdateTools).not.toHaveBeenCalled()
    consoleErrorSpy.mockRestore()
  })
  it('should handle event listener setup errors gracefully', async () => {
    const { useTools } = await import('../useTools')
    const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
    mockListen.mockRejectedValue(new Error('Failed to set up listener'))
    renderHook(() => useTools())
    await act(async () => {
      // Give enough time for the promise to be handled
      await new Promise(resolve => setTimeout(resolve, 100))
    })
    // Initial getTools should still work
    expect(mockGetTools).toHaveBeenCalledTimes(1)
    expect(mockListen).toHaveBeenCalled()
    consoleErrorSpy.mockRestore()
  })
  it('should only set up effect once with empty dependency array', async () => {
    const { useTools } = await import('../useTools')
    const { rerender } = renderHook(() => useTools())
    // Initial render
    expect(mockGetTools).toHaveBeenCalledTimes(1)
    expect(mockListen).toHaveBeenCalledTimes(1)
    // Rerender should not trigger additional calls
    rerender()
    expect(mockGetTools).toHaveBeenCalledTimes(1)
    expect(mockListen).toHaveBeenCalledTimes(1)
  })
})

View File

@ -4,6 +4,12 @@ import { MCPTool } from '@/types/completion'
import { useAssistant } from './useAssistant' import { useAssistant } from './useAssistant'
import { ChatCompletionMessageToolCall } from 'openai/resources' import { ChatCompletionMessageToolCall } from 'openai/resources'
type AppErrorMessage = {
message?: string
title?: string
subtitle: string
}
type AppState = { type AppState = {
streamingContent?: ThreadMessage streamingContent?: ThreadMessage
loadingModel?: boolean loadingModel?: boolean
@ -13,6 +19,8 @@ type AppState = {
tokenSpeed?: TokenSpeed tokenSpeed?: TokenSpeed
currentToolCall?: ChatCompletionMessageToolCall currentToolCall?: ChatCompletionMessageToolCall
showOutOfContextDialog?: boolean showOutOfContextDialog?: boolean
errorMessage?: AppErrorMessage
cancelToolCall?: () => void
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
updateStreamingContent: (content: ThreadMessage | undefined) => void updateStreamingContent: (content: ThreadMessage | undefined) => void
updateCurrentToolCall: ( updateCurrentToolCall: (
@ -24,6 +32,8 @@ type AppState = {
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void updateTokenSpeed: (message: ThreadMessage, increment?: number) => void
resetTokenSpeed: () => void resetTokenSpeed: () => void
setOutOfContextDialog: (show: boolean) => void setOutOfContextDialog: (show: boolean) => void
setCancelToolCall: (cancel: (() => void) | undefined) => void
setErrorMessage: (error: AppErrorMessage | undefined) => void
} }
export const useAppState = create<AppState>()((set) => ({ export const useAppState = create<AppState>()((set) => ({
@ -34,6 +44,7 @@ export const useAppState = create<AppState>()((set) => ({
abortControllers: {}, abortControllers: {},
tokenSpeed: undefined, tokenSpeed: undefined,
currentToolCall: undefined, currentToolCall: undefined,
cancelToolCall: undefined,
updateStreamingContent: (content: ThreadMessage | undefined) => { updateStreamingContent: (content: ThreadMessage | undefined) => {
const assistants = useAssistant.getState().assistants const assistants = useAssistant.getState().assistants
const currentAssistant = useAssistant.getState().currentAssistant const currentAssistant = useAssistant.getState().currentAssistant
@ -112,4 +123,14 @@ export const useAppState = create<AppState>()((set) => ({
showOutOfContextDialog: show, showOutOfContextDialog: show,
})) }))
}, },
setCancelToolCall: (cancel) => {
set(() => ({
cancelToolCall: cancel,
}))
},
setErrorMessage: (error) => {
set(() => ({
errorMessage: error,
}))
},
})) }))

View File

@ -1,4 +1,4 @@
import { useCallback, useEffect, useMemo } from 'react' import { useCallback, useMemo } from 'react'
import { usePrompt } from './usePrompt' import { usePrompt } from './usePrompt'
import { useModelProvider } from './useModelProvider' import { useModelProvider } from './useModelProvider'
import { useThreads } from './useThreads' import { useThreads } from './useThreads'
@ -19,10 +19,7 @@ import {
import { CompletionMessagesBuilder } from '@/lib/messages' import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources' import { ChatCompletionMessageToolCall } from 'openai/resources'
import { useAssistant } from './useAssistant' import { useAssistant } from './useAssistant'
import { getTools } from '@/services/mcp'
import { MCPTool } from '@/types/completion'
import { listen } from '@tauri-apps/api/event'
import { SystemEvent } from '@/types/events'
import { stopModel, startModel, stopAllModels } from '@/services/models' import { stopModel, startModel, stopAllModels } from '@/services/models'
import { useToolApproval } from '@/hooks/useToolApproval' import { useToolApproval } from '@/hooks/useToolApproval'
@ -32,6 +29,10 @@ import { updateSettings } from '@/services/providers'
import { useContextSizeApproval } from './useModelContextApproval' import { useContextSizeApproval } from './useModelContextApproval'
import { useModelLoad } from './useModelLoad' import { useModelLoad } from './useModelLoad'
import { useGeneralSetting } from './useGeneralSetting' import { useGeneralSetting } from './useGeneralSetting'
import {
ReasoningProcessor,
extractReasoningFromMessage,
} from '@/utils/reasoning'
export const useChat = () => { export const useChat = () => {
const { prompt, setPrompt } = usePrompt() const { prompt, setPrompt } = usePrompt()
@ -40,7 +41,6 @@ export const useChat = () => {
tools, tools,
updateTokenSpeed, updateTokenSpeed,
resetTokenSpeed, resetTokenSpeed,
updateTools,
updateStreamingContent, updateStreamingContent,
updateLoadingModel, updateLoadingModel,
setAbortController, setAbortController,
@ -77,22 +77,6 @@ export const useChat = () => {
const selectedAssistant = const selectedAssistant =
assistants.find((a) => a.id === currentAssistant.id) || assistants[0] assistants.find((a) => a.id === currentAssistant.id) || assistants[0]
useEffect(() => {
function setTools() {
getTools().then((data: MCPTool[]) => {
updateTools(data)
})
}
setTools()
let unsubscribe = () => {}
listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
// Unsubscribe from the event when the component unmounts
unsubscribe = unsub
})
return unsubscribe
}, [updateTools])
const getCurrentThread = useCallback(async () => { const getCurrentThread = useCallback(async () => {
let currentThread = retrieveThread() let currentThread = retrieveThread()
@ -219,7 +203,17 @@ export const useChat = () => {
) )
const sendMessage = useCallback( const sendMessage = useCallback(
async (message: string, troubleshooting = true) => { async (
message: string,
troubleshooting = true,
attachments?: Array<{
name: string
type: string
size: number
base64: string
dataUrl: string
}>
) => {
const activeThread = await getCurrentThread() const activeThread = await getCurrentThread()
resetTokenSpeed() resetTokenSpeed()
@ -233,7 +227,7 @@ export const useChat = () => {
updateStreamingContent(emptyThreadContent) updateStreamingContent(emptyThreadContent)
// Do not add new message on retry // Do not add new message on retry
if (troubleshooting) if (troubleshooting)
addMessage(newUserThreadContent(activeThread.id, message)) addMessage(newUserThreadContent(activeThread.id, message, attachments))
updateThreadTimestamp(activeThread.id) updateThreadTimestamp(activeThread.id)
setPrompt('') setPrompt('')
try { try {
@ -247,7 +241,7 @@ export const useChat = () => {
messages, messages,
currentAssistant?.instructions currentAssistant?.instructions
) )
if (troubleshooting) builder.addUserMessage(message) if (troubleshooting) builder.addUserMessage(message, attachments)
let isCompleted = false let isCompleted = false
@ -260,8 +254,8 @@ export const useChat = () => {
}) })
: [] : []
// TODO: Later replaced by Agent setup? let assistantLoopSteps = 0
const followUpWithToolUse = true
while ( while (
!isCompleted && !isCompleted &&
!abortController.signal.aborted && !abortController.signal.aborted &&
@ -270,6 +264,7 @@ export const useChat = () => {
const modelConfig = activeProvider.models.find( const modelConfig = activeProvider.models.find(
(m) => m.id === selectedModel?.id (m) => m.id === selectedModel?.id
) )
assistantLoopSteps += 1
const modelSettings = modelConfig?.settings const modelSettings = modelConfig?.settings
? Object.fromEntries( ? Object.fromEntries(
@ -305,18 +300,27 @@ export const useChat = () => {
const toolCalls: ChatCompletionMessageToolCall[] = [] const toolCalls: ChatCompletionMessageToolCall[] = []
try { try {
if (isCompletionResponse(completion)) { if (isCompletionResponse(completion)) {
accumulatedText = const message = completion.choices[0]?.message
(completion.choices[0]?.message?.content as string) || '' accumulatedText = (message?.content as string) || ''
if (completion.choices[0]?.message?.tool_calls) {
toolCalls.push(...completion.choices[0].message.tool_calls) // Handle reasoning field if there is one
const reasoning = extractReasoningFromMessage(message)
if (reasoning) {
accumulatedText =
`<think>${reasoning}</think>` + accumulatedText
}
if (message?.tool_calls) {
toolCalls.push(...message.tool_calls)
} }
} else { } else {
// High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame) // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
let rafScheduled = false let rafScheduled = false
let rafHandle: number | undefined let rafHandle: number | undefined
let pendingDeltaCount = 0 let pendingDeltaCount = 0
const reasoningProcessor = new ReasoningProcessor()
const scheduleFlush = () => { const scheduleFlush = () => {
if (rafScheduled) return if (rafScheduled || abortController.signal.aborted) return
rafScheduled = true rafScheduled = true
const doSchedule = (cb: () => void) => { const doSchedule = (cb: () => void) => {
if (typeof requestAnimationFrame !== 'undefined') { if (typeof requestAnimationFrame !== 'undefined') {
@ -328,6 +332,12 @@ export const useChat = () => {
} }
} }
doSchedule(() => { doSchedule(() => {
// Check abort status before executing the scheduled callback
if (abortController.signal.aborted) {
rafScheduled = false
return
}
const currentContent = newAssistantThreadContent( const currentContent = newAssistantThreadContent(
activeThread.id, activeThread.id,
accumulatedText, accumulatedText,
@ -348,7 +358,10 @@ export const useChat = () => {
} }
const flushIfPending = () => { const flushIfPending = () => {
if (!rafScheduled) return if (!rafScheduled) return
if (typeof cancelAnimationFrame !== 'undefined' && rafHandle !== undefined) { if (
typeof cancelAnimationFrame !== 'undefined' &&
rafHandle !== undefined
) {
cancelAnimationFrame(rafHandle) cancelAnimationFrame(rafHandle)
} else if (rafHandle !== undefined) { } else if (rafHandle !== undefined) {
clearTimeout(rafHandle) clearTimeout(rafHandle)
@ -371,31 +384,63 @@ export const useChat = () => {
pendingDeltaCount = 0 pendingDeltaCount = 0
rafScheduled = false rafScheduled = false
} }
for await (const part of completion) { try {
// Error message for await (const part of completion) {
if (!part.choices) { // Check if aborted before processing each part
throw new Error( if (abortController.signal.aborted) {
'message' in part break
? (part.message as string) }
: (JSON.stringify(part) ?? '')
)
}
const delta = part.choices[0]?.delta?.content || ''
if (part.choices[0]?.delta?.tool_calls) { // Error message
extractToolCall(part, currentCall, toolCalls) if (!part.choices) {
// Schedule a flush to reflect tool update throw new Error(
scheduleFlush() 'message' in part
? (part.message as string)
: (JSON.stringify(part) ?? '')
)
}
if (part.choices[0]?.delta?.tool_calls) {
extractToolCall(part, currentCall, toolCalls)
// Schedule a flush to reflect tool update
scheduleFlush()
}
const deltaReasoning =
reasoningProcessor.processReasoningChunk(part)
if (deltaReasoning) {
accumulatedText += deltaReasoning
pendingDeltaCount += 1
// Schedule flush for reasoning updates
scheduleFlush()
}
const deltaContent = part.choices[0]?.delta?.content || ''
if (deltaContent) {
accumulatedText += deltaContent
pendingDeltaCount += 1
// Batch UI update on next animation frame
scheduleFlush()
}
} }
if (delta) { } finally {
accumulatedText += delta // Always clean up scheduled RAF when stream ends (either normally or via abort)
pendingDeltaCount += 1 if (rafHandle !== undefined) {
// Batch UI update on next animation frame if (typeof cancelAnimationFrame !== 'undefined') {
scheduleFlush() cancelAnimationFrame(rafHandle)
} else {
clearTimeout(rafHandle)
}
rafHandle = undefined
rafScheduled = false
}
// Only finalize and flush if not aborted
if (!abortController.signal.aborted) {
// Finalize reasoning (close any open think tags)
accumulatedText += reasoningProcessor.finalize()
// Ensure any pending buffered content is rendered at the end
flushIfPending()
} }
} }
// Ensure any pending buffered content is rendered at the end
flushIfPending()
} }
} catch (error) { } catch (error) {
const errorMessage = const errorMessage =
@ -464,7 +509,11 @@ export const useChat = () => {
isCompleted = !toolCalls.length isCompleted = !toolCalls.length
// Do not create agent loop if there is no need for it // Do not create agent loop if there is no need for it
if (!followUpWithToolUse) availableTools = [] // Check if assistant loop steps are within limits
if (assistantLoopSteps >= (currentAssistant?.tool_steps ?? 20)) {
// Stop the assistant tool call if it exceeds the maximum steps
availableTools = []
}
} }
} catch (error) { } catch (error) {
if (!abortController.signal.aborted) { if (!abortController.signal.aborted) {

View File

@ -0,0 +1,53 @@
import { create } from 'zustand'
import { persist, createJSONStorage } from 'zustand/middleware'
import { localStorageKey } from '@/constants/localStorage'
interface FavoriteModelState {
favoriteModels: Model[]
addFavorite: (model: Model) => void
removeFavorite: (modelId: string) => void
isFavorite: (modelId: string) => boolean
toggleFavorite: (model: Model) => void
}
/**
 * Zustand store (persisted to localStorage under
 * `localStorageKey.favoriteModels`) tracking the user's favorite models.
 * Exposes add/remove/toggle mutators and a membership check keyed by model id.
 */
export const useFavoriteModel = create<FavoriteModelState>()(
  persist(
    (set, get) => ({
      favoriteModels: [],
      // Append `model` unless a favorite with the same id is already stored.
      addFavorite: (model: Model) => {
        set((state) => {
          const alreadyFavorited = state.favoriteModels.some(
            (entry) => entry.id === model.id
          )
          return alreadyFavorited
            ? state
            : { favoriteModels: [...state.favoriteModels, model] }
        })
      },
      // Drop the favorite whose id matches `modelId` (no-op when absent).
      removeFavorite: (modelId: string) => {
        set((state) => ({
          favoriteModels: state.favoriteModels.filter(
            (entry) => entry.id !== modelId
          ),
        }))
      },
      // True when a favorite with `modelId` is currently stored.
      isFavorite: (modelId: string) =>
        get().favoriteModels.some((entry) => entry.id === modelId),
      // Flip the favorite status of `model`.
      toggleFavorite: (model: Model) => {
        const store = get()
        if (store.isFavorite(model.id)) {
          store.removeFavorite(model.id)
        } else {
          store.addFavorite(model)
        }
      },
    }),
    {
      name: localStorageKey.favoriteModels,
      storage: createJSONStorage(() => localStorage),
    }
  )
)

View File

@ -7,6 +7,10 @@ export type MCPServerConfig = {
args: string[] args: string[]
env: Record<string, string> env: Record<string, string>
active?: boolean active?: boolean
type?: 'stdio' | 'http' | 'sse'
url?: string
headers?: Record<string, string>
timeout?: number
} }
// Define the structure of all MCP servers // Define the structure of all MCP servers

View File

@ -0,0 +1,31 @@
import { useEffect } from 'react'
import { getTools } from '@/services/mcp'
import { MCPTool } from '@/types/completion'
import { listen } from '@tauri-apps/api/event'
import { SystemEvent } from '@/types/events'
import { useAppState } from './useAppState'
/**
 * React hook that keeps the app-state tool list in sync with the MCP layer:
 * fetches the current tools once on mount and refreshes them whenever a
 * SystemEvent.MCP_UPDATE event fires. The Tauri event listener is torn down
 * on unmount.
 */
export const useTools = () => {
  const { updateTools } = useAppState()

  useEffect(() => {
    // Fetch the current tool list and push it into app state.
    function setTools() {
      getTools().then((data: MCPTool[]) => {
        updateTools(data)
      }).catch((error) => {
        console.error('Failed to fetch MCP tools:', error)
      })
    }
    setTools()
    let unsubscribe = () => {}
    listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
      // Unsubscribe from the event when the component unmounts
      unsubscribe = unsub
    }).catch((error) => {
      console.error('Failed to set up MCP update listener:', error)
    })
    // Bug fix: return a wrapper instead of the variable itself. `listen`
    // resolves in a later microtask, so `return unsubscribe` would hand React
    // the initial no-op reference and the real unsubscriber assigned in
    // `.then` would never run, leaking the event listener on unmount.
    return () => unsubscribe()
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])
}

View File

@ -1,3 +1,4 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { import {
ContentType, ContentType,
ChatCompletionRole, ChatCompletionRole,
@ -31,8 +32,9 @@ import { ulid } from 'ulidx'
import { MCPTool } from '@/types/completion' import { MCPTool } from '@/types/completion'
import { CompletionMessagesBuilder } from './messages' import { CompletionMessagesBuilder } from './messages'
import { ChatCompletionMessageToolCall } from 'openai/resources' import { ChatCompletionMessageToolCall } from 'openai/resources'
import { callTool } from '@/services/mcp' import { callToolWithCancellation } from '@/services/mcp'
import { ExtensionManager } from './extension' import { ExtensionManager } from './extension'
import { useAppState } from '@/hooks/useAppState'
export type ChatCompletionResponse = export type ChatCompletionResponse =
| chatCompletion | chatCompletion
@ -50,11 +52,16 @@ export type ChatCompletionResponse =
*/ */
export const newUserThreadContent = ( export const newUserThreadContent = (
threadId: string, threadId: string,
content: string content: string,
): ThreadMessage => ({ attachments?: Array<{
type: 'text', name: string
role: ChatCompletionRole.User, type: string
content: [ size: number
base64: string
dataUrl: string
}>
): ThreadMessage => {
const contentParts = [
{ {
type: ContentType.Text, type: ContentType.Text,
text: { text: {
@ -62,14 +69,35 @@ export const newUserThreadContent = (
annotations: [], annotations: [],
}, },
}, },
], ]
id: ulid(),
object: 'thread.message', // Add attachments to content array
thread_id: threadId, if (attachments) {
status: MessageStatus.Ready, attachments.forEach((attachment) => {
created_at: 0, if (attachment.type.startsWith('image/')) {
completed_at: 0, contentParts.push({
}) type: ContentType.Image,
image_url: {
url: `data:${attachment.type};base64,${attachment.base64}`,
detail: 'auto',
},
} as any)
}
})
}
return {
type: 'text',
role: ChatCompletionRole.User,
content: contentParts,
id: ulid(),
object: 'thread.message',
thread_id: threadId,
status: MessageStatus.Ready,
created_at: 0,
completed_at: 0,
}
}
/** /**
* @fileoverview Helper functions for creating thread content. * @fileoverview Helper functions for creating thread content.
* These functions are used to create thread content objects * These functions are used to create thread content objects
@ -161,13 +189,11 @@ export const sendCompletion = async (
if ( if (
thread.model.id && thread.model.id &&
!Object.values(models[providerName]).flat().includes(thread.model.id) && !Object.values(models[providerName]).flat().includes(thread.model.id) &&
// eslint-disable-next-line @typescript-eslint/no-explicit-any
!tokenJS.extendedModelExist(providerName as any, thread.model.id) && !tokenJS.extendedModelExist(providerName as any, thread.model.id) &&
provider.provider !== 'llamacpp' provider.provider !== 'llamacpp'
) { ) {
try { try {
tokenJS.extendModelList( tokenJS.extendModelList(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
providerName as any, providerName as any,
thread.model.id, thread.model.id,
// This is to inherit the model capabilities from another built-in model // This is to inherit the model capabilities from another built-in model
@ -200,7 +226,7 @@ export const sendCompletion = async (
? await tokenJS.chat.completions.create( ? await tokenJS.chat.completions.create(
{ {
stream: true, stream: true,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
provider: providerName as any, provider: providerName as any,
model: thread.model?.id, model: thread.model?.id,
messages, messages,
@ -381,13 +407,17 @@ export const postMessageProcessing = async (
) )
: true) : true)
const { promise, cancel } = callToolWithCancellation({
toolName: toolCall.function.name,
arguments: toolCall.function.arguments.length
? JSON.parse(toolCall.function.arguments)
: {},
})
useAppState.getState().setCancelToolCall(cancel)
let result = approved let result = approved
? await callTool({ ? await promise.catch((e) => {
toolName: toolCall.function.name,
arguments: toolCall.function.arguments.length
? JSON.parse(toolCall.function.arguments)
: {},
}).catch((e) => {
console.error('Tool call failed:', e) console.error('Tool call failed:', e)
return { return {
content: [ content: [

View File

@ -1,3 +1,4 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { ChatCompletionMessageParam } from 'token.js' import { ChatCompletionMessageParam } from 'token.js'
import { ChatCompletionMessageToolCall } from 'openai/resources' import { ChatCompletionMessageToolCall } from 'openai/resources'
import { ThreadMessage } from '@janhq/core' import { ThreadMessage } from '@janhq/core'
@ -19,32 +20,106 @@ export class CompletionMessagesBuilder {
this.messages.push( this.messages.push(
...messages ...messages
.filter((e) => !e.metadata?.error) .filter((e) => !e.metadata?.error)
.map<ChatCompletionMessageParam>( .map<ChatCompletionMessageParam>((msg) => {
(msg) => if (msg.role === 'assistant') {
({ return {
role: msg.role, role: msg.role,
content: content: this.normalizeContent(
msg.role === 'assistant' msg.content[0]?.text?.value || '.'
? this.normalizeContent(msg.content[0]?.text?.value || '.') ),
: msg.content[0]?.text?.value || '.', } as ChatCompletionMessageParam
}) as ChatCompletionMessageParam } else {
) // For user messages, handle multimodal content
if (msg.content.length > 1) {
// Multiple content parts (text + images + files)
const content = msg.content.map((contentPart) => {
if (contentPart.type === 'text') {
return {
type: 'text',
text: contentPart.text?.value || '',
}
} else if (contentPart.type === 'image_url') {
return {
type: 'image_url',
image_url: {
url: contentPart.image_url?.url || '',
detail: contentPart.image_url?.detail || 'auto',
},
}
} else {
return contentPart
}
})
return {
role: msg.role,
content,
} as ChatCompletionMessageParam
} else {
// Single text content
return {
role: msg.role,
content: msg.content[0]?.text?.value || '.',
} as ChatCompletionMessageParam
}
}
})
) )
} }
/** /**
* Add a user message to the messages array. * Add a user message to the messages array.
* @param content - The content of the user message. * @param content - The content of the user message.
* @param attachments - Optional attachments for the message.
*/ */
addUserMessage(content: string) { addUserMessage(
content: string,
attachments?: Array<{
name: string
type: string
size: number
base64: string
dataUrl: string
}>
) {
// Ensure no consecutive user messages // Ensure no consecutive user messages
if (this.messages[this.messages.length - 1]?.role === 'user') { if (this.messages[this.messages.length - 1]?.role === 'user') {
this.messages.pop() this.messages.pop()
} }
this.messages.push({
role: 'user', // Handle multimodal content with attachments
content: content, if (attachments && attachments.length > 0) {
}) const messageContent: any[] = [
{
type: 'text',
text: content,
},
]
// Add attachments (images and PDFs)
attachments.forEach((attachment) => {
if (attachment.type.startsWith('image/')) {
messageContent.push({
type: 'image_url',
image_url: {
url: `data:${attachment.type};base64,${attachment.base64}`,
detail: 'auto',
},
})
}
})
this.messages.push({
role: 'user',
content: messageContent,
} as any)
} else {
// Text-only message
this.messages.push({
role: 'user',
content: content,
})
}
} }
/** /**

View File

@ -5,6 +5,7 @@ export const AppRoutes = [
'installExtensions', 'installExtensions',
'getTools', 'getTools',
'callTool', 'callTool',
'cancelToolCall',
'listThreads', 'listThreads',
'createThread', 'createThread',
'modifyThread', 'modifyThread',

View File

@ -29,5 +29,6 @@
"save": "Speichern", "save": "Speichern",
"createNew": "Neuen Assistenten anlegen", "createNew": "Neuen Assistenten anlegen",
"personality": "Persönlichkeit", "personality": "Persönlichkeit",
"capabilities": "Fähigkeiten" "capabilities": "Fähigkeiten",
"maxToolSteps": "Maximale Werkzeugschritte"
} }

View File

@ -29,5 +29,6 @@
"save": "Save", "save": "Save",
"createNew": "Create New Assistant", "createNew": "Create New Assistant",
"personality": "Personality", "personality": "Personality",
"capabilities": "Capabilities" "capabilities": "Capabilities",
"maxToolSteps": "Max tool steps"
} }

View File

@ -29,5 +29,6 @@
"save": "Simpan", "save": "Simpan",
"createNew": "Buat Asisten Baru", "createNew": "Buat Asisten Baru",
"personality": "Kepribadian", "personality": "Kepribadian",
"capabilities": "Kemampuan" "capabilities": "Kemampuan",
"maxToolSteps": "Langkah alat maksimum"
} }

View File

@ -29,5 +29,6 @@
"save": "Lưu", "save": "Lưu",
"createNew": "Tạo Trợ lý Mới", "createNew": "Tạo Trợ lý Mới",
"personality": "Tính cách", "personality": "Tính cách",
"capabilities": "Khả năng" "capabilities": "Khả năng",
"maxToolSteps": "Bước tối đa của công cụ"
} }

View File

@ -29,5 +29,6 @@
"save": "保存", "save": "保存",
"createNew": "创建新助手", "createNew": "创建新助手",
"personality": "个性", "personality": "个性",
"capabilities": "能力" "capabilities": "能力",
"maxToolSteps": "最大工具步骤"
} }

View File

@ -29,5 +29,6 @@
"save": "儲存", "save": "儲存",
"createNew": "建立新助理", "createNew": "建立新助理",
"personality": "個性", "personality": "個性",
"capabilities": "能力" "capabilities": "能力",
"maxToolSteps": "最大工具步驟"
} }

View File

@ -17,6 +17,7 @@ import {
import { useNavigate } from '@tanstack/react-router' import { useNavigate } from '@tanstack/react-router'
import { route } from '@/constants/routes' import { route } from '@/constants/routes'
import { useThreads } from '@/hooks/useThreads' import { useThreads } from '@/hooks/useThreads'
import { AppEvent, events } from '@janhq/core'
export function DataProvider() { export function DataProvider() {
const { setProviders } = useModelProvider() const { setProviders } = useModelProvider()
@ -70,6 +71,13 @@ export function DataProvider() {
} }
}, [checkForUpdate]) }, [checkForUpdate])
useEffect(() => {
events.on(AppEvent.onModelImported, () => {
getProviders().then(setProviders)
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
const handleDeepLink = (urls: string[] | null) => { const handleDeepLink = (urls: string[] | null) => {
if (!urls) return if (!urls) return
console.log('Received deeplink:', urls) console.log('Received deeplink:', urls)

View File

@ -26,9 +26,10 @@ import {
ResizablePanel, ResizablePanel,
ResizableHandle, ResizableHandle,
} from '@/components/ui/resizable' } from '@/components/ui/resizable'
import { useCallback } from 'react' import { useCallback, useEffect } from 'react'
import GlobalError from '@/containers/GlobalError' import GlobalError from '@/containers/GlobalError'
import { GlobalEventHandler } from '@/providers/GlobalEventHandler' import { GlobalEventHandler } from '@/providers/GlobalEventHandler'
import ErrorDialog from '@/containers/dialogs/ErrorDialog'
export const Route = createRootRoute({ export const Route = createRootRoute({
component: RootLayout, component: RootLayout,
@ -65,6 +66,41 @@ const AppLayout = () => {
[setLeftPanelSize, setLeftPanel] [setLeftPanelSize, setLeftPanel]
) )
// Prevent default drag and drop behavior globally
useEffect(() => {
const preventDefaults = (e: DragEvent) => {
e.preventDefault()
e.stopPropagation()
}
const handleGlobalDrop = (e: DragEvent) => {
e.preventDefault()
e.stopPropagation()
// Only prevent if the target is not within a chat input or other valid drop zone
const target = e.target as Element
const isValidDropZone = target?.closest('[data-drop-zone="true"]') ||
target?.closest('.chat-input-drop-zone') ||
target?.closest('[data-tauri-drag-region]')
if (!isValidDropZone) {
// Prevent the file from opening in the window
return false
}
}
// Add event listeners to prevent default drag/drop behavior
window.addEventListener('dragenter', preventDefaults)
window.addEventListener('dragover', preventDefaults)
window.addEventListener('drop', handleGlobalDrop)
return () => {
window.removeEventListener('dragenter', preventDefaults)
window.removeEventListener('dragover', preventDefaults)
window.removeEventListener('drop', handleGlobalDrop)
}
}, [])
return ( return (
<Fragment> <Fragment>
<AnalyticProvider /> <AnalyticProvider />
@ -168,6 +204,7 @@ function RootLayout() {
{/* <TanStackRouterDevtools position="bottom-right" /> */} {/* <TanStackRouterDevtools position="bottom-right" /> */}
<ToolApproval /> <ToolApproval />
<LoadModelErrorDialog /> <LoadModelErrorDialog />
<ErrorDialog />
<OutOfContextPromiseModal /> <OutOfContextPromiseModal />
</TranslationProvider> </TranslationProvider>
</Fragment> </Fragment>

View File

@ -17,8 +17,20 @@ import { useModelProvider } from '@/hooks/useModelProvider'
import { Card, CardItem } from '@/containers/Card' import { Card, CardItem } from '@/containers/Card'
import { RenderMarkdown } from '@/containers/RenderMarkdown' import { RenderMarkdown } from '@/containers/RenderMarkdown'
import { extractModelName, extractDescription } from '@/lib/models' import { extractModelName, extractDescription } from '@/lib/models'
import { IconDownload, IconFileCode, IconSearch } from '@tabler/icons-react' import {
IconDownload,
IconFileCode,
IconEye,
IconSearch,
IconTool,
} from '@tabler/icons-react'
import { Switch } from '@/components/ui/switch' import { Switch } from '@/components/ui/switch'
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from '@/components/ui/tooltip'
import Joyride, { CallBackProps, STATUS } from 'react-joyride' import Joyride, { CallBackProps, STATUS } from 'react-joyride'
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide' import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
import { import {
@ -133,18 +145,24 @@ function Hub() {
if (debouncedSearchValue.length) { if (debouncedSearchValue.length) {
const fuse = new Fuse(filtered, searchOptions) const fuse = new Fuse(filtered, searchOptions)
// Remove domain from search value (e.g., "huggingface.co/author/model" -> "author/model") // Remove domain from search value (e.g., "huggingface.co/author/model" -> "author/model")
const cleanedSearchValue = debouncedSearchValue.replace(/^https?:\/\/[^/]+\//, '') const cleanedSearchValue = debouncedSearchValue.replace(
/^https?:\/\/[^/]+\//,
''
)
filtered = fuse.search(cleanedSearchValue).map((result) => result.item) filtered = fuse.search(cleanedSearchValue).map((result) => result.item)
} }
// Apply downloaded filter // Apply downloaded filter
if (showOnlyDownloaded) { if (showOnlyDownloaded) {
filtered = filtered?.filter((model) => filtered = filtered
model.quants.some((variant) => ?.map((model) => ({
llamaProvider?.models.some( ...model,
(m: { id: string }) => m.id === variant.model_id quants: model.quants.filter((variant) =>
) llamaProvider?.models.some(
) (m: { id: string }) => m.id === variant.model_id
) )
),
}))
.filter((model) => model.quants.length > 0)
} }
// Add HuggingFace repo at the beginning if available // Add HuggingFace repo at the beginning if available
if (huggingFaceRepo) { if (huggingFaceRepo) {
@ -419,43 +437,44 @@ function Hub() {
const isLastStep = currentStepIndex === steps.length - 1 const isLastStep = currentStepIndex === steps.length - 1
const renderFilter = () => { const renderFilter = () => {
return ( if (searchValue.length === 0)
<> return (
<DropdownMenu> <>
<DropdownMenuTrigger> <DropdownMenu>
<span className="flex cursor-pointer items-center gap-1 px-2 py-1 rounded-sm bg-main-view-fg/15 text-sm outline-none text-main-view-fg font-medium"> <DropdownMenuTrigger>
{ <span className="flex cursor-pointer items-center gap-1 px-2 py-1 rounded-sm bg-main-view-fg/15 text-sm outline-none text-main-view-fg font-medium">
sortOptions.find((option) => option.value === sortSelected) {
?.name sortOptions.find((option) => option.value === sortSelected)
} ?.name
}
</span>
</DropdownMenuTrigger>
<DropdownMenuContent side="bottom" align="end">
{sortOptions.map((option) => (
<DropdownMenuItem
className={cn(
'cursor-pointer my-0.5',
sortSelected === option.value && 'bg-main-view-fg/5'
)}
key={option.value}
onClick={() => setSortSelected(option.value)}
>
{option.name}
</DropdownMenuItem>
))}
</DropdownMenuContent>
</DropdownMenu>
<div className="flex items-center gap-2">
<Switch
checked={showOnlyDownloaded}
onCheckedChange={setShowOnlyDownloaded}
/>
<span className="text-xs text-main-view-fg/70 font-medium whitespace-nowrap">
{t('hub:downloaded')}
</span> </span>
</DropdownMenuTrigger> </div>
<DropdownMenuContent side="bottom" align="end"> </>
{sortOptions.map((option) => ( )
<DropdownMenuItem
className={cn(
'cursor-pointer my-0.5',
sortSelected === option.value && 'bg-main-view-fg/5'
)}
key={option.value}
onClick={() => setSortSelected(option.value)}
>
{option.name}
</DropdownMenuItem>
))}
</DropdownMenuContent>
</DropdownMenu>
<div className="flex items-center gap-2">
<Switch
checked={showOnlyDownloaded}
onCheckedChange={setShowOnlyDownloaded}
/>
<span className="text-xs text-main-view-fg/70 font-medium whitespace-nowrap">
{t('hub:downloaded')}
</span>
</div>
</>
)
} }
return ( return (
@ -647,6 +666,45 @@ function Hub() {
?.length || 0} ?.length || 0}
</span> </span>
</div> </div>
{filteredModels[virtualItem.index].tools && (
<div className="flex items-center gap-1">
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div>
<IconTool
size={17}
className="text-main-view-fg/50"
/>
</div>
</TooltipTrigger>
<TooltipContent>
<p>{t('tools')}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
)}
{filteredModels[virtualItem.index].num_mmproj >
0 && (
<div className="flex items-center gap-1">
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div>
<IconEye
size={17}
className="text-main-view-fg/50"
/>
</div>
</TooltipTrigger>
<TooltipContent>
<p>{t('vision')}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
)}
{filteredModels[virtualItem.index].quants.length > {filteredModels[virtualItem.index].quants.length >
1 && ( 1 && (
<div className="flex items-center gap-2 hub-show-variants-step"> <div className="flex items-center gap-2 hub-show-variants-step">

View File

@ -3,6 +3,7 @@ import { createFileRoute, useSearch } from '@tanstack/react-router'
import ChatInput from '@/containers/ChatInput' import ChatInput from '@/containers/ChatInput'
import HeaderPage from '@/containers/HeaderPage' import HeaderPage from '@/containers/HeaderPage'
import { useTranslation } from '@/i18n/react-i18next-compat' import { useTranslation } from '@/i18n/react-i18next-compat'
import { useTools } from '@/hooks/useTools'
import { useModelProvider } from '@/hooks/useModelProvider' import { useModelProvider } from '@/hooks/useModelProvider'
import SetupScreen from '@/containers/SetupScreen' import SetupScreen from '@/containers/SetupScreen'
@ -31,6 +32,7 @@ function Index() {
const search = useSearch({ from: route.home as any }) const search = useSearch({ from: route.home as any })
const selectedModel = search.model const selectedModel = search.model
const { setCurrentThreadId } = useThreads() const { setCurrentThreadId } = useThreads()
useTools()
// Conditional to check if there are any valid providers // Conditional to check if there are any valid providers
// required min 1 api_key or 1 model in llama.cpp // required min 1 api_key or 1 model in llama.cpp

View File

@ -11,7 +11,7 @@ export const Route = createFileRoute(route.localApiServerlogs as any)({
component: LogsViewer, component: LogsViewer,
}) })
const SERVER_LOG_TARGET = 'app_lib::core::server' const SERVER_LOG_TARGET = 'app_lib::core::server::proxy'
const LOG_EVENT_NAME = 'log://log' const LOG_EVENT_NAME = 'log://log'
function LogsViewer() { function LogsViewer() {

View File

@ -21,6 +21,7 @@ import { useToolApproval } from '@/hooks/useToolApproval'
import { toast } from 'sonner' import { toast } from 'sonner'
import { invoke } from '@tauri-apps/api/core' import { invoke } from '@tauri-apps/api/core'
import { useTranslation } from '@/i18n/react-i18next-compat' import { useTranslation } from '@/i18n/react-i18next-compat'
import { useAppState } from '@/hooks/useAppState'
// Function to mask sensitive values // Function to mask sensitive values
const maskSensitiveValue = (value: string) => { const maskSensitiveValue = (value: string) => {
@ -29,6 +30,58 @@ const maskSensitiveValue = (value: string) => {
return value.slice(0, 4) + '*'.repeat(value.length - 8) + value.slice(-4) return value.slice(0, 4) + '*'.repeat(value.length - 8) + value.slice(-4)
} }
// Function to mask sensitive URL parameters
const maskSensitiveUrl = (url: string) => {
if (!url) return url
try {
const urlObj = new URL(url)
const params = urlObj.searchParams
// List of sensitive parameter names (case-insensitive)
const sensitiveParams = [
'api_key',
'apikey',
'key',
'token',
'secret',
'password',
'pwd',
'auth',
'authorization',
'bearer',
'access_token',
'refresh_token',
'client_secret',
'private_key',
'signature',
'hash',
]
// Mask sensitive parameters
sensitiveParams.forEach((paramName) => {
// Check both exact match and case-insensitive match
for (const [key, value] of params.entries()) {
if (key.toLowerCase() === paramName.toLowerCase()) {
params.set(key, maskSensitiveValue(value))
}
}
})
// Reconstruct URL with masked parameters
urlObj.search = params.toString()
return urlObj.toString()
} catch {
// If URL parsing fails, just mask the entire query string after '?'
const queryIndex = url.indexOf('?')
if (queryIndex === -1) return url
const baseUrl = url.substring(0, queryIndex + 1)
const queryString = url.substring(queryIndex + 1)
return baseUrl + maskSensitiveValue(queryString)
}
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
export const Route = createFileRoute(route.settings.mcp_servers as any)({ export const Route = createFileRoute(route.settings.mcp_servers as any)({
component: MCPServers, component: MCPServers,
@ -68,6 +121,7 @@ function MCPServers() {
const [loadingServers, setLoadingServers] = useState<{ const [loadingServers, setLoadingServers] = useState<{
[key: string]: boolean [key: string]: boolean
}>({}) }>({})
const { setErrorMessage } = useAppState()
const handleOpenDialog = (serverKey?: string) => { const handleOpenDialog = (serverKey?: string) => {
if (serverKey) { if (serverKey) {
@ -199,8 +253,9 @@ function MCPServers() {
...(config ?? (mcpServers[serverKey] as MCPServerConfig)), ...(config ?? (mcpServers[serverKey] as MCPServerConfig)),
active: false, active: false,
}) })
toast.error(error, { setErrorMessage({
description: t('mcp-servers:checkParams'), message: error,
subtitle: t('mcp-servers:checkParams'),
}) })
}) })
.finally(() => { .finally(() => {
@ -326,22 +381,56 @@ function MCPServers() {
} }
descriptionOutside={ descriptionOutside={
<div className="text-sm text-main-view-fg/70"> <div className="text-sm text-main-view-fg/70">
<div> <div className="mb-1">
{t('mcp-servers:command')}: {config.command} Transport:{' '}
<span className="uppercase">
{config.type || 'stdio'}
</span>
</div> </div>
<div className="my-1 break-all">
{t('mcp-servers:args')}: {config?.args?.join(', ')} {config.type === 'stdio' || !config.type ? (
</div> <>
{config.env && Object.keys(config.env).length > 0 && ( <div>
<div className="break-all"> {t('mcp-servers:command')}: {config.command}
{t('mcp-servers:env')}:{' '} </div>
{Object.entries(config.env) <div className="my-1 break-all">
.map( {t('mcp-servers:args')}:{' '}
([key, value]) => {config?.args?.join(', ')}
`${key}=${maskSensitiveValue(value)}` </div>
) {config.env &&
.join(', ')} Object.keys(config.env).length > 0 && (
</div> <div className="break-all">
{t('mcp-servers:env')}:{' '}
{Object.entries(config.env)
.map(
([key, value]) =>
`${key}=${maskSensitiveValue(value)}`
)
.join(', ')}
</div>
)}
</>
) : (
<>
<div className="break-all">
URL: {maskSensitiveUrl(config.url || '')}
</div>
{config.headers &&
Object.keys(config.headers).length > 0 && (
<div className="my-1 break-all">
Headers:{' '}
{Object.entries(config.headers)
.map(
([key, value]) =>
`${key}=${maskSensitiveValue(value)}`
)
.join(', ')}
</div>
)}
{config.timeout && (
<div>Timeout: {config.timeout}s</div>
)}
</>
)} )}
</div> </div>
} }

View File

@ -22,10 +22,10 @@ import { useTranslation } from '@/i18n/react-i18next-compat'
import Capabilities from '@/containers/Capabilities' import Capabilities from '@/containers/Capabilities'
import { DynamicControllerSetting } from '@/containers/dynamicControllerSetting' import { DynamicControllerSetting } from '@/containers/dynamicControllerSetting'
import { RenderMarkdown } from '@/containers/RenderMarkdown' import { RenderMarkdown } from '@/containers/RenderMarkdown'
import { DialogEditModel } from '@/containers/dialogs/EditModel'
import { DialogAddModel } from '@/containers/dialogs/AddModel' import { DialogAddModel } from '@/containers/dialogs/AddModel'
import { ModelSetting } from '@/containers/ModelSetting' import { ModelSetting } from '@/containers/ModelSetting'
import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel' import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel'
import { FavoriteModelAction } from '@/containers/FavoriteModelAction'
import Joyride, { CallBackProps, STATUS } from 'react-joyride' import Joyride, { CallBackProps, STATUS } from 'react-joyride'
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide' import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
import { route } from '@/constants/routes' import { route } from '@/constants/routes'
@ -79,6 +79,7 @@ function ProviderDetail() {
const [activeModels, setActiveModels] = useState<string[]>([]) const [activeModels, setActiveModels] = useState<string[]>([])
const [loadingModels, setLoadingModels] = useState<string[]>([]) const [loadingModels, setLoadingModels] = useState<string[]>([])
const [refreshingModels, setRefreshingModels] = useState(false) const [refreshingModels, setRefreshingModels] = useState(false)
const [importingModel, setImportingModel] = useState(false)
const { providerName } = useParams({ from: Route.id }) const { providerName } = useParams({ from: Route.id })
const { getProviderByName, setProviders, updateProvider } = useModelProvider() const { getProviderByName, setProviders, updateProvider } = useModelProvider()
const provider = getProviderByName(providerName) const provider = getProviderByName(providerName)
@ -95,6 +96,60 @@ function ProviderDetail() {
!setting.controller_props.value) !setting.controller_props.value)
) )
const handleImportModel = async () => {
if (!provider) {
return
}
setImportingModel(true)
const selectedFile = await open({
multiple: false,
directory: false,
})
// If the dialog returns a file path, extract just the file name
const fileName =
typeof selectedFile === 'string'
? selectedFile.split(/[\\/]/).pop()?.replace(/\s/g, '-')
: undefined
if (selectedFile && fileName) {
// Check if model already exists
const modelExists = provider.models.some(
(model) => model.name === fileName
)
if (modelExists) {
toast.error('Model already exists', {
description: `${fileName} already imported`,
})
setImportingModel(false)
return
}
try {
await pullModel(fileName, selectedFile)
// Refresh the provider to update the models list
await getProviders().then(setProviders)
toast.success(t('providers:import'), {
id: `import-model-${provider.provider}`,
description: t('providers:importModelSuccess', {
provider: fileName,
}),
})
} catch (error) {
console.error(t('providers:importModelError'), error)
toast.error(t('providers:importModelError'), {
description:
error instanceof Error ? error.message : 'Unknown error occurred',
})
} finally {
setImportingModel(false)
}
} else {
setImportingModel(false)
}
}
useEffect(() => { useEffect(() => {
// Initial data fetch // Initial data fetch
getActiveModels().then((models) => setActiveModels(models || [])) getActiveModels().then((models) => setActiveModels(models || []))
@ -482,52 +537,25 @@ function ProviderDetail() {
variant="link" variant="link"
size="sm" size="sm"
className="hover:no-underline" className="hover:no-underline"
onClick={async () => { disabled={importingModel}
const selectedFile = await open({ onClick={handleImportModel}
multiple: false,
directory: false,
filters: [
{
name: 'GGUF',
extensions: ['gguf'],
},
],
})
// If the dialog returns a file path, extract just the file name
const fileName =
typeof selectedFile === 'string'
? selectedFile.split(/[\\/]/).pop()
: undefined
if (selectedFile && fileName) {
try {
await pullModel(fileName, selectedFile)
} catch (error) {
console.error(
t('providers:importModelError'),
error
)
} finally {
// Refresh the provider to update the models list
getProviders().then(setProviders)
toast.success(t('providers:import'), {
id: `import-model-${provider.provider}`,
description: t(
'providers:importModelSuccess',
{ provider: provider.provider }
),
})
}
}
}}
> >
<div className="cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/15 bg-main-view-fg/10 transition-all duration-200 ease-in-out p-1.5 py-1 gap-1 -mr-2"> <div className="cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/15 bg-main-view-fg/10 transition-all duration-200 ease-in-out p-1.5 py-1 gap-1 -mr-2">
<IconFolderPlus {importingModel ? (
size={18} <IconLoader
className="text-main-view-fg/50" size={18}
/> className="text-main-view-fg/50 animate-spin"
/>
) : (
<IconFolderPlus
size={18}
className="text-main-view-fg/50"
/>
)}
<span className="text-main-view-fg/70"> <span className="text-main-view-fg/70">
{t('providers:import')} {importingModel
? 'Importing...'
: t('providers:import')}
</span> </span>
</div> </div>
</Button> </Button>
@ -554,17 +582,24 @@ function ProviderDetail() {
</div> </div>
} }
actions={ actions={
<div className="flex items-center gap-1"> <div className="flex items-center gap-0.5">
<DialogEditModel
provider={provider}
modelId={model.id}
/>
{model.settings && ( {model.settings && (
<ModelSetting <ModelSetting
provider={provider} provider={provider}
model={model} model={model}
/> />
)} )}
{((provider &&
!predefinedProviders.some(
(p) => p.provider === provider.provider
)) ||
(provider &&
predefinedProviders.some(
(p) => p.provider === provider.provider
) &&
Boolean(provider.api_key?.length))) && (
<FavoriteModelAction model={model} />
)}
<DialogDeleteModel <DialogDeleteModel
provider={provider} provider={provider}
modelId={model.id} modelId={model.id}

View File

@ -25,6 +25,7 @@ import { useCallback, useState } from 'react'
import { openAIProviderSettings } from '@/consts/providers' import { openAIProviderSettings } from '@/consts/providers'
import cloneDeep from 'lodash/cloneDeep' import cloneDeep from 'lodash/cloneDeep'
import { toast } from 'sonner' import { toast } from 'sonner'
import { stopAllModels } from '@/services/models'
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
export const Route = createFileRoute(route.settings.model_providers as any)({ export const Route = createFileRoute(route.settings.model_providers as any)({
@ -169,7 +170,10 @@ function ModelProviders() {
)} )}
<Switch <Switch
checked={provider.active} checked={provider.active}
onCheckedChange={(e) => { onCheckedChange={async (e) => {
if (!e && provider.provider.toLowerCase() === 'llamacpp') {
await stopAllModels()
}
updateProvider(provider.provider, { updateProvider(provider.provider, {
...provider, ...provider,
active: e, active: e,

View File

@ -23,6 +23,7 @@ import { ContentType, ThreadMessage } from '@janhq/core'
import { useTranslation } from '@/i18n/react-i18next-compat' import { useTranslation } from '@/i18n/react-i18next-compat'
import { useChat } from '@/hooks/useChat' import { useChat } from '@/hooks/useChat'
import { useSmallScreen } from '@/hooks/useMediaQuery' import { useSmallScreen } from '@/hooks/useMediaQuery'
import { useTools } from '@/hooks/useTools'
// as route.threadsDetail // as route.threadsDetail
export const Route = createFileRoute('/threads/$threadId')({ export const Route = createFileRoute('/threads/$threadId')({
@ -36,6 +37,8 @@ function ThreadDetail() {
const [isAtBottom, setIsAtBottom] = useState(true) const [isAtBottom, setIsAtBottom] = useState(true)
const [hasScrollbar, setHasScrollbar] = useState(false) const [hasScrollbar, setHasScrollbar] = useState(false)
const lastScrollTopRef = useRef(0) const lastScrollTopRef = useRef(0)
const userIntendedPositionRef = useRef<number | null>(null)
const wasStreamingRef = useRef(false)
const { currentThreadId, setCurrentThreadId } = useThreads() const { currentThreadId, setCurrentThreadId } = useThreads()
const { setCurrentAssistant, assistants } = useAssistant() const { setCurrentAssistant, assistants } = useAssistant()
const { setMessages, deleteMessage } = useMessages() const { setMessages, deleteMessage } = useMessages()
@ -43,6 +46,7 @@ function ThreadDetail() {
const { appMainViewBgColor, chatWidth } = useAppearance() const { appMainViewBgColor, chatWidth } = useAppearance()
const { sendMessage } = useChat() const { sendMessage } = useChat()
const isSmallScreen = useSmallScreen() const isSmallScreen = useSmallScreen()
useTools()
const { messages } = useMessages( const { messages } = useMessages(
useShallow((state) => ({ useShallow((state) => ({
@ -110,6 +114,8 @@ function ThreadDetail() {
scrollToBottom() scrollToBottom()
setIsAtBottom(true) setIsAtBottom(true)
setIsUserScrolling(false) setIsUserScrolling(false)
userIntendedPositionRef.current = null
wasStreamingRef.current = false
checkScrollState() checkScrollState()
return return
} }
@ -121,11 +127,39 @@ function ThreadDetail() {
scrollToBottom() scrollToBottom()
setIsAtBottom(true) setIsAtBottom(true)
setIsUserScrolling(false) setIsUserScrolling(false)
userIntendedPositionRef.current = null
wasStreamingRef.current = false
checkScrollState() checkScrollState()
}, [threadId]) }, [threadId])
// Single useEffect for all auto-scrolling logic // Single useEffect for all auto-scrolling logic
useEffect(() => { useEffect(() => {
// Track streaming state changes
const isCurrentlyStreaming = !!streamingContent
const justFinishedStreaming = wasStreamingRef.current && !isCurrentlyStreaming
wasStreamingRef.current = isCurrentlyStreaming
// If streaming just finished and user had an intended position, restore it
if (justFinishedStreaming && userIntendedPositionRef.current !== null) {
// Small delay to ensure DOM has updated
setTimeout(() => {
if (scrollContainerRef.current && userIntendedPositionRef.current !== null) {
scrollContainerRef.current.scrollTo({
top: userIntendedPositionRef.current,
behavior: 'smooth'
})
userIntendedPositionRef.current = null
setIsUserScrolling(false)
}
}, 100)
return
}
// Clear intended position when streaming starts fresh
if (isCurrentlyStreaming && !wasStreamingRef.current) {
userIntendedPositionRef.current = null
}
// Only auto-scroll when the user is not actively scrolling // Only auto-scroll when the user is not actively scrolling
// AND either at the bottom OR there's streaming content // AND either at the bottom OR there's streaming content
if (!isUserScrolling && (streamingContent || isAtBottom) && messagesCount) { if (!isUserScrolling && (streamingContent || isAtBottom) && messagesCount) {
@ -161,6 +195,11 @@ function ThreadDetail() {
// Detect if this is a user-initiated scroll // Detect if this is a user-initiated scroll
if (Math.abs(scrollTop - lastScrollTopRef.current) > 10) { if (Math.abs(scrollTop - lastScrollTopRef.current) > 10) {
setIsUserScrolling(!isBottom) setIsUserScrolling(!isBottom)
// If user scrolls during streaming and moves away from bottom, record their intended position
if (streamingContent && !isBottom) {
userIntendedPositionRef.current = scrollTop
}
} }
setIsAtBottom(isBottom) setIsAtBottom(isBottom)
setHasScrollbar(hasScroll) setHasScrollbar(hasScroll)
@ -178,6 +217,11 @@ function ThreadDetail() {
// Detect if this is a user-initiated scroll // Detect if this is a user-initiated scroll
if (Math.abs(scrollTop - lastScrollTopRef.current) > 10) { if (Math.abs(scrollTop - lastScrollTopRef.current) > 10) {
setIsUserScrolling(!isBottom) setIsUserScrolling(!isBottom)
// If user scrolls during streaming and moves away from bottom, record their intended position
if (streamingContent && !isBottom) {
userIntendedPositionRef.current = scrollTop
}
} }
setIsAtBottom(isBottom) setIsAtBottom(isBottom)
setHasScrollbar(hasScroll) setHasScrollbar(hasScroll)

View File

@ -290,7 +290,7 @@ describe('models service', () => {
likes: 100, likes: 100,
tags: ['conversational', 'pytorch'], tags: ['conversational', 'pytorch'],
pipeline_tag: 'text-generation', pipeline_tag: 'text-generation',
created_at: '2023-01-01T00:00:00Z', createdAt: '2023-01-01T00:00:00Z',
last_modified: '2023-12-01T00:00:00Z', last_modified: '2023-12-01T00:00:00Z',
private: false, private: false,
disabled: false, disabled: false,
@ -443,7 +443,7 @@ describe('models service', () => {
likes: 100, likes: 100,
tags: ['conversational'], tags: ['conversational'],
pipeline_tag: 'text-generation', pipeline_tag: 'text-generation',
created_at: '2023-01-01T00:00:00Z', createdAt: '2023-01-01T00:00:00Z',
last_modified: '2023-12-01T00:00:00Z', last_modified: '2023-12-01T00:00:00Z',
private: false, private: false,
disabled: false, disabled: false,
@ -471,7 +471,7 @@ describe('models service', () => {
likes: 100, likes: 100,
tags: ['conversational'], tags: ['conversational'],
pipeline_tag: 'text-generation', pipeline_tag: 'text-generation',
created_at: '2023-01-01T00:00:00Z', createdAt: '2023-01-01T00:00:00Z',
last_modified: '2023-12-01T00:00:00Z', last_modified: '2023-12-01T00:00:00Z',
private: false, private: false,
disabled: false, disabled: false,
@ -510,7 +510,7 @@ describe('models service', () => {
likes: 100, likes: 100,
tags: ['conversational'], tags: ['conversational'],
pipeline_tag: 'text-generation', pipeline_tag: 'text-generation',
created_at: '2023-01-01T00:00:00Z', createdAt: '2023-01-01T00:00:00Z',
last_modified: '2023-12-01T00:00:00Z', last_modified: '2023-12-01T00:00:00Z',
private: false, private: false,
disabled: false, disabled: false,
@ -559,7 +559,7 @@ describe('models service', () => {
likes: 75, likes: 75,
tags: ['pytorch', 'transformers', 'text-generation'], tags: ['pytorch', 'transformers', 'text-generation'],
pipeline_tag: 'text-generation', pipeline_tag: 'text-generation',
created_at: '2021-01-01T00:00:00Z', createdAt: '2021-01-01T00:00:00Z',
last_modified: '2021-12-01T00:00:00Z', last_modified: '2021-12-01T00:00:00Z',
private: false, private: false,
disabled: false, disabled: false,
@ -605,6 +605,8 @@ describe('models service', () => {
file_size: '4.0 GB', file_size: '4.0 GB',
}, },
], ],
num_mmproj: 0,
mmproj_models: [],
created_at: '2021-01-01T00:00:00Z', created_at: '2021-01-01T00:00:00Z',
readme: readme:
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md', 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md',
@ -820,7 +822,7 @@ describe('models service', () => {
downloads: 0, downloads: 0,
likes: 0, likes: 0,
tags: [], tags: [],
created_at: '2021-01-01T00:00:00Z', createdAt: '2021-01-01T00:00:00Z',
last_modified: '2021-12-01T00:00:00Z', last_modified: '2021-12-01T00:00:00Z',
private: false, private: false,
disabled: false, disabled: false,

View File

@ -5,7 +5,7 @@ import {
updateSettings, updateSettings,
} from '../providers' } from '../providers'
import { models as providerModels } from 'token.js' import { models as providerModels } from 'token.js'
import { predefinedProviders } from '@/mock/data' import { predefinedProviders } from '@/consts/providers'
import { EngineManager } from '@janhq/core' import { EngineManager } from '@janhq/core'
import { fetchModels } from '../models' import { fetchModels } from '../models'
import { ExtensionManager } from '@/lib/extension' import { ExtensionManager } from '@/lib/extension'
@ -21,7 +21,7 @@ vi.mock('token.js', () => ({
}, },
})) }))
vi.mock('@/mock/data', () => ({ vi.mock('@/consts/providers', () => ({
predefinedProviders: [ predefinedProviders: [
{ {
active: true, active: true,
@ -69,6 +69,7 @@ vi.mock('../models', () => ({
{ id: 'llama-2-7b', name: 'Llama 2 7B', description: 'Llama model' }, { id: 'llama-2-7b', name: 'Llama 2 7B', description: 'Llama model' },
]) ])
), ),
isToolSupported: vi.fn(() => Promise.resolve(false)),
})) }))
vi.mock('@/lib/extension', () => ({ vi.mock('@/lib/extension', () => ({
@ -116,7 +117,7 @@ describe('providers service', () => {
it('should return builtin and runtime providers', async () => { it('should return builtin and runtime providers', async () => {
const providers = await getProviders() const providers = await getProviders()
expect(providers).toHaveLength(9) // 8 runtime + 1 builtin expect(providers).toHaveLength(2) // 1 runtime + 1 builtin (mocked)
expect(providers.some((p) => p.provider === 'llamacpp')).toBe(true) expect(providers.some((p) => p.provider === 'llamacpp')).toBe(true)
expect(providers.some((p) => p.provider === 'openai')).toBe(true) expect(providers.some((p) => p.provider === 'openai')).toBe(true)
}) })
@ -156,7 +157,7 @@ describe('providers service', () => {
provider: 'openai', provider: 'openai',
base_url: 'https://api.openai.com/v1', base_url: 'https://api.openai.com/v1',
api_key: 'test-key', api_key: 'test-key',
} as ModelProvider }
const models = await fetchModelsFromProvider(provider) const models = await fetchModelsFromProvider(provider)
@ -185,7 +186,7 @@ describe('providers service', () => {
provider: 'custom', provider: 'custom',
base_url: 'https://api.custom.com', base_url: 'https://api.custom.com',
api_key: '', api_key: '',
} as ModelProvider }
const models = await fetchModelsFromProvider(provider) const models = await fetchModelsFromProvider(provider)
@ -204,7 +205,7 @@ describe('providers service', () => {
const provider = { const provider = {
provider: 'custom', provider: 'custom',
base_url: 'https://api.custom.com', base_url: 'https://api.custom.com',
} as ModelProvider }
const models = await fetchModelsFromProvider(provider) const models = await fetchModelsFromProvider(provider)
@ -214,7 +215,7 @@ describe('providers service', () => {
it('should throw error when provider has no base_url', async () => { it('should throw error when provider has no base_url', async () => {
const provider = { const provider = {
provider: 'custom', provider: 'custom',
} as ModelProvider }
await expect(fetchModelsFromProvider(provider)).rejects.toThrow( await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
'Provider must have base_url configured' 'Provider must have base_url configured'
@ -232,10 +233,10 @@ describe('providers service', () => {
const provider = { const provider = {
provider: 'custom', provider: 'custom',
base_url: 'https://api.custom.com', base_url: 'https://api.custom.com',
} as ModelProvider }
await expect(fetchModelsFromProvider(provider)).rejects.toThrow( await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
'Cannot connect to custom at https://api.custom.com' 'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.'
) )
}) })
@ -245,10 +246,10 @@ describe('providers service', () => {
const provider = { const provider = {
provider: 'custom', provider: 'custom',
base_url: 'https://api.custom.com', base_url: 'https://api.custom.com',
} as ModelProvider }
await expect(fetchModelsFromProvider(provider)).rejects.toThrow( await expect(fetchModelsFromProvider(provider)).rejects.toThrow(
'Cannot connect to custom at https://api.custom.com' 'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.'
) )
}) })
@ -264,7 +265,7 @@ describe('providers service', () => {
const provider = { const provider = {
provider: 'custom', provider: 'custom',
base_url: 'https://api.custom.com', base_url: 'https://api.custom.com',
} as ModelProvider }
const models = await fetchModelsFromProvider(provider) const models = await fetchModelsFromProvider(provider)
@ -298,7 +299,7 @@ describe('providers service', () => {
controller_type: 'input', controller_type: 'input',
controller_props: { value: 'test-key' }, controller_props: { value: 'test-key' },
}, },
] as ProviderSetting[] ]
await updateSettings('openai', settings) await updateSettings('openai', settings)
@ -324,7 +325,7 @@ describe('providers service', () => {
mockExtensionManager mockExtensionManager
) )
const settings = [] as ProviderSetting[] const settings = []
const result = await updateSettings('nonexistent', settings) const result = await updateSettings('nonexistent', settings)
@ -350,7 +351,7 @@ describe('providers service', () => {
controller_type: 'input', controller_type: 'input',
controller_props: { value: undefined }, controller_props: { value: undefined },
}, },
] as ProviderSetting[] ]
await updateSettings('openai', settings) await updateSettings('openai', settings)

View File

@ -56,3 +56,44 @@ export const callTool = (args: {
}): Promise<{ error: string; content: { text: string }[] }> => { }): Promise<{ error: string; content: { text: string }[] }> => {
return window.core?.api?.callTool(args) return window.core?.api?.callTool(args)
} }
/**
* @description Enhanced function to invoke an MCP tool with cancellation support
* @param args - Tool call arguments
* @param cancellationToken - Optional cancellation token
* @returns Promise with tool result and cancellation function
*/
export const callToolWithCancellation = (args: {
toolName: string
arguments: object
cancellationToken?: string
}): {
promise: Promise<{ error: string; content: { text: string }[] }>
cancel: () => Promise<void>
token: string
} => {
// Generate a unique cancellation token if not provided
const token = args.cancellationToken ?? `tool_call_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
// Create the tool call promise with cancellation token
const promise = window.core?.api?.callTool({
...args,
cancellationToken: token
})
// Create cancel function
const cancel = async () => {
await window.core?.api?.cancelToolCall({ cancellationToken: token })
}
return { promise, cancel, token }
}
/**
* @description This function cancels a running tool call
* @param cancellationToken - The token identifying the tool call to cancel
* @returns
*/
export const cancelToolCall = (cancellationToken: string): Promise<void> => {
return window.core?.api?.cancelToolCall({ cancellationToken })
}

View File

@ -1,3 +1,4 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { sanitizeModelId } from '@/lib/utils' import { sanitizeModelId } from '@/lib/utils'
import { import {
AIEngine, AIEngine,
@ -27,8 +28,10 @@ export interface CatalogModel {
num_quants: number num_quants: number
quants: ModelQuant[] quants: ModelQuant[]
mmproj_models?: MMProjModel[] mmproj_models?: MMProjModel[]
num_mmproj: number
created_at?: string created_at?: string
readme?: string readme?: string
tools?: boolean
} }
export type ModelCatalog = CatalogModel[] export type ModelCatalog = CatalogModel[]
@ -43,7 +46,7 @@ export interface HuggingFaceRepo {
library_name?: string library_name?: string
tags: string[] tags: string[]
pipeline_tag?: string pipeline_tag?: string
created_at: string createdAt: string
last_modified: string last_modified: string
private: boolean private: boolean
disabled: boolean disabled: boolean
@ -154,21 +157,30 @@ export const fetchHuggingFaceRepo = async (
export const convertHfRepoToCatalogModel = ( export const convertHfRepoToCatalogModel = (
repo: HuggingFaceRepo repo: HuggingFaceRepo
): CatalogModel => { ): CatalogModel => {
// Format file size helper
const formatFileSize = (size?: number) => {
if (!size) return 'Unknown size'
if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB`
return `${(size / 1024 ** 3).toFixed(1)} GB`
}
// Extract GGUF files from the repository siblings // Extract GGUF files from the repository siblings
const ggufFiles = const ggufFiles =
repo.siblings?.filter((file) => repo.siblings?.filter((file) =>
file.rfilename.toLowerCase().endsWith('.gguf') file.rfilename.toLowerCase().endsWith('.gguf')
) || [] ) || []
// Convert GGUF files to quants format // Separate regular GGUF files from mmproj files
const quants = ggufFiles.map((file) => { const regularGgufFiles = ggufFiles.filter(
// Format file size (file) => !file.rfilename.toLowerCase().includes('mmproj')
const formatFileSize = (size?: number) => { )
if (!size) return 'Unknown size'
if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB`
return `${(size / 1024 ** 3).toFixed(1)} GB`
}
const mmprojFiles = ggufFiles.filter((file) =>
file.rfilename.toLowerCase().includes('mmproj')
)
// Convert regular GGUF files to quants format
const quants = regularGgufFiles.map((file) => {
// Generate model_id from filename (remove .gguf extension, case-insensitive) // Generate model_id from filename (remove .gguf extension, case-insensitive)
const modelId = file.rfilename.replace(/\.gguf$/i, '') const modelId = file.rfilename.replace(/\.gguf$/i, '')
@ -179,15 +191,28 @@ export const convertHfRepoToCatalogModel = (
} }
}) })
// Convert mmproj files to mmproj_models format
const mmprojModels = mmprojFiles.map((file) => {
const modelId = file.rfilename.replace(/\.gguf$/i, '')
return {
model_id: sanitizeModelId(modelId),
path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`,
file_size: formatFileSize(file.size),
}
})
return { return {
model_name: repo.modelId, model_name: repo.modelId,
description: `**Tags**: ${repo.tags?.join(', ')}`,
developer: repo.author, developer: repo.author,
downloads: repo.downloads || 0, downloads: repo.downloads || 0,
created_at: repo.createdAt,
num_quants: quants.length, num_quants: quants.length,
quants: quants, quants: quants,
created_at: repo.created_at, num_mmproj: mmprojModels.length,
mmproj_models: mmprojModels,
readme: `https://huggingface.co/${repo.modelId}/resolve/main/README.md`, readme: `https://huggingface.co/${repo.modelId}/resolve/main/README.md`,
description: `**Tags**: ${repo.tags?.join(', ')}`,
} }
} }
@ -313,3 +338,150 @@ export const startModel = async (
throw error throw error
}) })
} }
/**
* Check if model support tool use capability
* Returned by backend engine
* @param modelId
* @returns
*/
export const isToolSupported = async (modelId: string): Promise<boolean> => {
const engine = getEngine()
if (!engine) return false
return engine.isToolSupported(modelId)
}
/**
* Checks if mmproj.gguf file exists for a given model ID in the llamacpp provider.
* Also checks if the model has offload_mmproj setting.
* If mmproj.gguf exists, adds offload_mmproj setting with value true.
* @param modelId - The model ID to check for mmproj.gguf
* @param updateProvider - Function to update the provider state
* @param getProviderByName - Function to get provider by name
* @returns Promise<{exists: boolean, settingsUpdated: boolean}> - exists: true if mmproj.gguf exists, settingsUpdated: true if settings were modified
*/
export const checkMmprojExistsAndUpdateOffloadMMprojSetting = async (
modelId: string,
updateProvider?: (providerName: string, data: Partial<ModelProvider>) => void,
getProviderByName?: (providerName: string) => ModelProvider | undefined
): Promise<{ exists: boolean; settingsUpdated: boolean }> => {
let settingsUpdated = false
try {
const engine = getEngine('llamacpp') as AIEngine & {
checkMmprojExists?: (id: string) => Promise<boolean>
}
if (engine && typeof engine.checkMmprojExists === 'function') {
const exists = await engine.checkMmprojExists(modelId)
// If we have the store functions, use them; otherwise fall back to localStorage
if (updateProvider && getProviderByName) {
const provider = getProviderByName('llamacpp')
if (provider) {
const model = provider.models.find((m) => m.id === modelId)
if (model?.settings) {
const hasOffloadMmproj = 'offload_mmproj' in model.settings
// If mmproj exists, add offload_mmproj setting (only if it doesn't exist)
if (exists && !hasOffloadMmproj) {
// Create updated models array with the new setting
const updatedModels = provider.models.map((m) => {
if (m.id === modelId) {
return {
...m,
settings: {
...m.settings,
offload_mmproj: {
key: 'offload_mmproj',
title: 'Offload MMProj',
description:
'Offload multimodal projection layers to GPU',
controller_type: 'checkbox',
controller_props: {
value: true,
},
},
},
}
}
return m
})
// Update the provider with the new models array
updateProvider('llamacpp', { models: updatedModels })
settingsUpdated = true
}
}
}
} else {
// Fall back to localStorage approach for backwards compatibility
try {
const modelProviderData = JSON.parse(
localStorage.getItem('model-provider') || '{}'
)
const llamacppProvider = modelProviderData.state?.providers?.find(
(p: any) => p.provider === 'llamacpp'
)
const model = llamacppProvider?.models?.find(
(m: any) => m.id === modelId
)
if (model?.settings) {
// If mmproj exists, add offload_mmproj setting (only if it doesn't exist)
if (exists) {
if (!model.settings.offload_mmproj) {
model.settings.offload_mmproj = {
key: 'offload_mmproj',
title: 'Offload MMProj',
description: 'Offload multimodal projection layers to GPU',
controller_type: 'checkbox',
controller_props: {
value: true,
},
}
// Save updated settings back to localStorage
localStorage.setItem(
'model-provider',
JSON.stringify(modelProviderData)
)
settingsUpdated = true
}
}
}
} catch (localStorageError) {
console.error(
`Error checking localStorage for model ${modelId}:`,
localStorageError
)
}
}
return { exists, settingsUpdated }
}
} catch (error) {
console.error(`Error checking mmproj for model ${modelId}:`, error)
}
return { exists: false, settingsUpdated }
}
/**
* Checks if mmproj.gguf file exists for a given model ID in the llamacpp provider.
* If mmproj.gguf exists, adds offload_mmproj setting with value true.
* @param modelId - The model ID to check for mmproj.gguf
* @returns Promise<{exists: boolean, settingsUpdated: boolean}> - exists: true if mmproj.gguf exists, settingsUpdated: true if settings were modified
*/
export const checkMmprojExists = async (modelId: string): Promise<boolean> => {
try {
const engine = getEngine('llamacpp') as AIEngine & {
checkMmprojExists?: (id: string) => Promise<boolean>
}
if (engine && typeof engine.checkMmprojExists === 'function') {
return await engine.checkMmprojExists(modelId)
}
} catch (error) {
console.error(`Error checking mmproj for model ${modelId}:`, error)
}
return false
}

View File

@ -1,12 +1,9 @@
import { models as providerModels } from 'token.js' import { models as providerModels } from 'token.js'
import { predefinedProviders } from '@/consts/providers' import { predefinedProviders } from '@/consts/providers'
import { EngineManager, SettingComponentProps } from '@janhq/core' import { EngineManager, SettingComponentProps } from '@janhq/core'
import { import { ModelCapabilities } from '@/types/models'
DefaultToolUseSupportedModels,
ModelCapabilities,
} from '@/types/models'
import { modelSettings } from '@/lib/predefined' import { modelSettings } from '@/lib/predefined'
import { fetchModels } from './models' import { fetchModels, isToolSupported } from './models'
import { ExtensionManager } from '@/lib/extension' import { ExtensionManager } from '@/lib/extension'
import { fetch as fetchTauri } from '@tauri-apps/plugin-http' import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
@ -65,52 +62,41 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
controller_props: setting.controllerProps as unknown, controller_props: setting.controllerProps as unknown,
} }
}) as ProviderSetting[], }) as ProviderSetting[],
models: models.map((model) => ({ models: await Promise.all(
id: model.id, models.map(
model: model.id, async (model) =>
name: model.name, ({
description: model.description, id: model.id,
capabilities: model: model.id,
'capabilities' in model name: model.name,
? (model.capabilities as string[]) description: model.description,
: [ capabilities:
ModelCapabilities.COMPLETION, 'capabilities' in model
...(Object.values(DefaultToolUseSupportedModels).some((v) => ? (model.capabilities as string[])
model.id.toLowerCase().includes(v.toLowerCase()) : (await isToolSupported(model.id))
) ? [ModelCapabilities.TOOLS]
? [ModelCapabilities.TOOLS] : [],
: []), provider: providerName,
], settings: Object.values(modelSettings).reduce(
provider: providerName, (acc, setting) => {
settings: Object.values(modelSettings).reduce( let value = setting.controller_props.value
(acc, setting) => { if (setting.key === 'ctx_len') {
let value = setting.controller_props.value value = 8192 // Default context length for Llama.cpp models
if (setting.key === 'ctx_len') { }
value = 8192 // Default context length for Llama.cpp models acc[setting.key] = {
} ...setting,
// Set temperature to 0.6 for DefaultToolUseSupportedModels controller_props: {
if ( ...setting.controller_props,
Object.values(DefaultToolUseSupportedModels).some((v) => value: value,
model.id.toLowerCase().includes(v.toLowerCase()) },
) }
) { return acc
if (setting.key === 'temperature') value = 0.7 // Default temperature for tool-supported models },
if (setting.key === 'top_k') value = 20 // Default top_k for tool-supported models {} as Record<string, ProviderSetting>
if (setting.key === 'top_p') value = 0.8 // Default top_p for tool-supported models ),
if (setting.key === 'min_p') value = 0 // Default min_p for tool-supported models }) as Model
} )
acc[setting.key] = { ),
...setting,
controller_props: {
...setting.controller_props,
value: value,
},
}
return acc
},
{} as Record<string, ProviderSetting>
),
})),
} }
runtimeProviders.push(provider) runtimeProviders.push(provider)
} }

View File

@ -6,4 +6,5 @@ export type MCPTool = {
name: string name: string
description: string description: string
inputSchema: Record<string, unknown> inputSchema: Record<string, unknown>
server: string
} }

Some files were not shown because too many files have changed in this diff Show More