Merge pull request #4860 from menloresearch/feat/mcp-jan-frontend

feat: Jan Tool Use - MCP frontend implementation
This commit is contained in:
Louis 2025-04-07 02:24:08 +07:00 committed by GitHub
commit 7e2498cc79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 308 additions and 559 deletions

View File

@ -40,12 +40,13 @@ export abstract class AIEngine extends BaseExtension {
* Stops the model. * Stops the model.
*/ */
async unloadModel(model?: Model): Promise<any> { async unloadModel(model?: Model): Promise<any> {
if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve() if (model?.engine && model.engine.toString() !== this.provider)
return Promise.resolve()
events.emit(ModelEvent.OnModelStopped, model ?? {}) events.emit(ModelEvent.OnModelStopped, model ?? {})
return Promise.resolve() return Promise.resolve()
} }
/* /**
* Inference request * Inference request
*/ */
inference(data: MessageRequest) {} inference(data: MessageRequest) {}

View File

@ -76,7 +76,7 @@ export abstract class OAIEngine extends AIEngine {
const timestamp = Date.now() / 1000 const timestamp = Date.now() / 1000
const message: ThreadMessage = { const message: ThreadMessage = {
id: ulid(), id: ulid(),
thread_id: data.threadId, thread_id: data.thread?.id ?? data.threadId,
type: data.type, type: data.type,
assistant_id: data.assistantId, assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant, role: ChatCompletionRole.Assistant,
@ -104,6 +104,7 @@ export abstract class OAIEngine extends AIEngine {
messages: data.messages ?? [], messages: data.messages ?? [],
model: model.id, model: model.id,
stream: true, stream: true,
tools: data.tools,
...model.parameters, ...model.parameters,
} }
if (this.transformPayload) { if (this.transformPayload) {

View File

@ -28,12 +28,6 @@ export * from './extension'
*/ */
export * from './extensions' export * from './extensions'
/**
* Export all base tools.
* @module
*/
export * from './tools'
/** /**
* Export all base models. * Export all base models.
* @module * @module

View File

@ -1,5 +0,0 @@
it('should not throw any errors when imported', () => {
expect(() => require('./index')).not.toThrow();
})

View File

@ -1,2 +0,0 @@
export * from './manager'
export * from './tool'

View File

@ -1,47 +0,0 @@
import { AssistantTool, MessageRequest } from '../../types'
import { InferenceTool } from './tool'
/**
* Manages the registration and retrieval of inference tools.
*/
export class ToolManager {
public tools = new Map<string, InferenceTool>()
/**
* Registers a tool.
* @param tool - The tool to register.
*/
register<T extends InferenceTool>(tool: T) {
this.tools.set(tool.name, tool)
}
/**
* Retrieves a tool by it's name.
* @param name - The name of the tool to retrieve.
* @returns The tool, if found.
*/
get<T extends InferenceTool>(name: string): T | undefined {
return this.tools.get(name) as T | undefined
}
/*
** Process the message request with the tools.
*/
process(request: MessageRequest, tools: AssistantTool[]): Promise<MessageRequest> {
return tools.reduce((prevPromise, currentTool) => {
return prevPromise.then((prevResult) => {
return currentTool.enabled
? this.get(currentTool.type)?.process(prevResult, currentTool) ??
Promise.resolve(prevResult)
: Promise.resolve(prevResult)
})
}, Promise.resolve(request))
}
/**
* The instance of the tool manager.
*/
static instance(): ToolManager {
return (window.core?.toolManager as ToolManager) ?? new ToolManager()
}
}

View File

@ -1,63 +0,0 @@
import { ToolManager } from '../../browser/tools/manager'
import { InferenceTool } from '../../browser/tools/tool'
import { AssistantTool, MessageRequest } from '../../types'
class MockInferenceTool implements InferenceTool {
name = 'mockTool'
process(request: MessageRequest, tool: AssistantTool): Promise<MessageRequest> {
return Promise.resolve(request)
}
}
it('should register a tool', () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
expect(manager.get(tool.name)).toBe(tool)
})
it('should retrieve a tool by its name', () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
const retrievedTool = manager.get(tool.name)
expect(retrievedTool).toBe(tool)
})
it('should return undefined for a non-existent tool', () => {
const manager = new ToolManager()
const retrievedTool = manager.get('nonExistentTool')
expect(retrievedTool).toBeUndefined()
})
it('should process the message request with enabled tools', async () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
const request: MessageRequest = { message: 'test' } as any
const tools: AssistantTool[] = [{ type: 'mockTool', enabled: true }] as any
const result = await manager.process(request, tools)
expect(result).toBe(request)
})
it('should skip processing for disabled tools', async () => {
const manager = new ToolManager()
const tool = new MockInferenceTool()
manager.register(tool)
const request: MessageRequest = { message: 'test' } as any
const tools: AssistantTool[] = [{ type: 'mockTool', enabled: false }] as any
const result = await manager.process(request, tools)
expect(result).toBe(request)
})
it('should throw an error when process is called without implementation', () => {
class TestTool extends InferenceTool {
name = 'testTool'
}
const tool = new TestTool()
expect(() => tool.process({} as MessageRequest)).toThrowError()
})

View File

@ -1,12 +0,0 @@
import { AssistantTool, MessageRequest } from '../../types'
/**
* Represents a base inference tool.
*/
export abstract class InferenceTool {
abstract name: string
/*
** Process a message request and return the processed message request.
*/
abstract process(request: MessageRequest, tool?: AssistantTool): Promise<MessageRequest>
}

View File

@ -43,6 +43,9 @@ export type ThreadMessage = {
* @data_transfer_object * @data_transfer_object
*/ */
export type MessageRequest = { export type MessageRequest = {
/**
* The id of the message request.
*/
id?: string id?: string
/** /**
@ -71,6 +74,11 @@ export type MessageRequest = {
// TODO: deprecate threadId field // TODO: deprecate threadId field
thread?: Thread thread?: Thread
/**
* ChatCompletion tools
*/
tools?: MessageTool[]
/** Engine name to process */ /** Engine name to process */
engine?: string engine?: string
@ -78,6 +86,24 @@ export type MessageRequest = {
type?: string type?: string
} }
/**
* ChatCompletion Tool parameters
*/
export type MessageTool = {
type: string
function: MessageFunction
}
/**
* ChatCompletion Tool's function parameters
*/
export type MessageFunction = {
name: string
description?: string
parameters?: Record<string, unknown>
strict?: boolean
}
/** /**
* The status of the message. * The status of the message.
* @data_transfer_object * @data_transfer_object

View File

@ -8,17 +8,10 @@
"author": "Jan <service@jan.ai>", "author": "Jan <service@jan.ai>",
"license": "AGPL-3.0", "license": "AGPL-3.0",
"scripts": { "scripts": {
"clean:modules": "rimraf node_modules/pdf-parse/test && cd node_modules/pdf-parse/lib/pdf.js && rimraf v1.9.426 v1.10.88 v2.0.550", "build": "rolldown -c rolldown.config.mjs",
"build-universal-hnswlib": "[ \"$IS_TEST\" = \"true\" ] && echo \"Skip universal build\" || (cd node_modules/hnswlib-node && arch -x86_64 npx node-gyp rebuild --arch=x64 && mv build/Release/addon.node ./addon-amd64.node && node-gyp rebuild --arch=arm64 && mv build/Release/addon.node ./addon-arm64.node && lipo -create -output build/Release/addon.node ./addon-arm64.node ./addon-amd64.node && rm ./addon-arm64.node && rm ./addon-amd64.node)", "build:publish": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install"
"build": "yarn clean:modules && rolldown -c rolldown.config.mjs",
"build:publish:linux": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install",
"build:publish:darwin": "rimraf *.tgz --glob || true && yarn build-universal-hnswlib && yarn build && ../../.github/scripts/auto-sign.sh && npm pack && cpx *.tgz ../../pre-install",
"build:publish:win32": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install",
"build:publish": "run-script-os",
"build:dev": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install"
}, },
"devDependencies": { "devDependencies": {
"@types/pdf-parse": "^1.1.4",
"cpx": "^1.5.0", "cpx": "^1.5.0",
"rimraf": "^3.0.2", "rimraf": "^3.0.2",
"rolldown": "1.0.0-beta.1", "rolldown": "1.0.0-beta.1",
@ -27,11 +20,6 @@
}, },
"dependencies": { "dependencies": {
"@janhq/core": "../../core/package.tgz", "@janhq/core": "../../core/package.tgz",
"@langchain/community": "0.0.13",
"hnswlib-node": "^1.4.2",
"langchain": "^0.0.214",
"node-gyp": "^11.0.0",
"pdf-parse": "^1.1.1",
"ts-loader": "^9.5.0" "ts-loader": "^9.5.0"
}, },
"files": [ "files": [
@ -40,8 +28,7 @@
"README.md" "README.md"
], ],
"bundleDependencies": [ "bundleDependencies": [
"@janhq/core", "@janhq/core"
"hnswlib-node"
], ],
"installConfig": { "installConfig": {
"hoistingLimits": "workspaces" "hoistingLimits": "workspaces"

View File

@ -13,22 +13,5 @@ export default defineConfig([
NODE: JSON.stringify(`${pkgJson.name}/${pkgJson.node}`), NODE: JSON.stringify(`${pkgJson.name}/${pkgJson.node}`),
VERSION: JSON.stringify(pkgJson.version), VERSION: JSON.stringify(pkgJson.version),
}, },
}, }
{
input: 'src/node/index.ts',
external: ['@janhq/core/node', 'path', 'hnswlib-node'],
output: {
format: 'cjs',
file: 'dist/node/index.js',
sourcemap: false,
inlineDynamicImports: true,
},
resolve: {
extensions: ['.js', '.ts'],
},
define: {
CORTEX_API_URL: JSON.stringify(`http://127.0.0.1:${process.env.CORTEX_API_PORT ?? "39291"}`),
},
platform: 'node',
},
]) ])

View File

@ -1,12 +1,7 @@
import { Assistant, AssistantExtension, ToolManager } from '@janhq/core' import { Assistant, AssistantExtension } from '@janhq/core'
import { RetrievalTool } from './tools/retrieval'
export default class JanAssistantExtension extends AssistantExtension { export default class JanAssistantExtension extends AssistantExtension {
async onLoad() {}
async onLoad() {
// Register the retrieval tool
ToolManager.instance().register(new RetrievalTool())
}
/** /**
* Called when the extension is unloaded. * Called when the extension is unloaded.

View File

@ -1,45 +0,0 @@
import { getJanDataFolderPath } from '@janhq/core/node'
import { retrieval } from './retrieval'
import path from 'path'
export function toolRetrievalUpdateTextSplitter(
chunkSize: number,
chunkOverlap: number
) {
retrieval.updateTextSplitter(chunkSize, chunkOverlap)
}
export async function toolRetrievalIngestNewDocument(
thread: string,
file: string,
model: string,
engine: string,
useTimeWeighted: boolean
) {
const threadPath = path.join(getJanDataFolderPath(), 'threads', thread)
const filePath = path.join(getJanDataFolderPath(), 'files', file)
retrieval.updateEmbeddingEngine(model, engine)
return retrieval
.ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted)
.catch((err) => {
console.error(err)
})
}
export async function toolRetrievalLoadThreadMemory(threadId: string) {
return retrieval
.loadRetrievalAgent(
path.join(getJanDataFolderPath(), 'threads', threadId, 'memory')
)
.catch((err) => {
console.error(err)
})
}
export async function toolRetrievalQueryResult(
query: string,
useTimeWeighted: boolean = false
) {
return retrieval.generateResult(query, useTimeWeighted).catch((err) => {
console.error(err)
})
}

View File

@ -1,121 +0,0 @@
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'
import { formatDocumentsAsString } from 'langchain/util/document'
import { PDFLoader } from 'langchain/document_loaders/fs/pdf'
import { TimeWeightedVectorStoreRetriever } from 'langchain/retrievers/time_weighted'
import { MemoryVectorStore } from 'langchain/vectorstores/memory'
import { HNSWLib } from 'langchain/vectorstores/hnswlib'
import { OpenAIEmbeddings } from 'langchain/embeddings/openai'
export class Retrieval {
public chunkSize: number = 100
public chunkOverlap?: number = 0
private retriever: any
private embeddingModel?: OpenAIEmbeddings = undefined
private textSplitter?: RecursiveCharacterTextSplitter
// to support time-weighted retrieval
private timeWeightedVectorStore: MemoryVectorStore
private timeWeightedretriever: any | TimeWeightedVectorStoreRetriever
constructor(chunkSize: number = 4000, chunkOverlap: number = 200) {
this.updateTextSplitter(chunkSize, chunkOverlap)
this.initialize()
}
private async initialize() {
const apiKey = await window.core?.api.appToken()
// declare time-weighted retriever and storage
this.timeWeightedVectorStore = new MemoryVectorStore(
new OpenAIEmbeddings(
{ openAIApiKey: apiKey },
{ basePath: `${CORTEX_API_URL}/v1` }
)
)
this.timeWeightedretriever = new TimeWeightedVectorStoreRetriever({
vectorStore: this.timeWeightedVectorStore,
memoryStream: [],
searchKwargs: 2,
})
}
public updateTextSplitter(chunkSize: number, chunkOverlap: number): void {
this.chunkSize = chunkSize
this.chunkOverlap = chunkOverlap
this.textSplitter = new RecursiveCharacterTextSplitter({
chunkSize: chunkSize,
chunkOverlap: chunkOverlap,
})
}
public async updateEmbeddingEngine(model: string, engine: string) {
const apiKey = await window.core?.api.appToken()
this.embeddingModel = new OpenAIEmbeddings(
{ openAIApiKey: apiKey, model },
// TODO: Raw settings
{ basePath: `${CORTEX_API_URL}/v1` }
)
// update time-weighted embedding model
this.timeWeightedVectorStore.embeddings = this.embeddingModel
}
public ingestAgentKnowledge = async (
filePath: string,
memoryPath: string,
useTimeWeighted: boolean
): Promise<any> => {
const loader = new PDFLoader(filePath, {
splitPages: true,
})
if (!this.embeddingModel) return Promise.reject()
const doc = await loader.load()
const docs = await this.textSplitter!.splitDocuments(doc)
const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel)
// add documents with metadata by using the time-weighted retriever in order to support time-weighted retrieval
if (useTimeWeighted && this.timeWeightedretriever) {
await (
this.timeWeightedretriever as TimeWeightedVectorStoreRetriever
).addDocuments(docs)
}
return vectorStore.save(memoryPath)
}
public loadRetrievalAgent = async (memoryPath: string): Promise<void> => {
if (!this.embeddingModel) return Promise.reject()
const vectorStore = await HNSWLib.load(memoryPath, this.embeddingModel)
this.retriever = vectorStore.asRetriever(2)
return Promise.resolve()
}
public generateResult = async (
query: string,
useTimeWeighted: boolean
): Promise<string> => {
if (useTimeWeighted) {
if (!this.timeWeightedretriever) {
return Promise.resolve(' ')
}
// use invoke because getRelevantDocuments is deprecated
const relevantDocs = await this.timeWeightedretriever.invoke(query)
const serializedDoc = formatDocumentsAsString(relevantDocs)
return Promise.resolve(serializedDoc)
}
if (!this.retriever) {
return Promise.resolve(' ')
}
// should use invoke(query) because getRelevantDocuments is deprecated
const relevantDocs = await this.retriever.getRelevantDocuments(query)
const serializedDoc = formatDocumentsAsString(relevantDocs)
return Promise.resolve(serializedDoc)
}
}
export const retrieval = new Retrieval()

View File

@ -1,118 +0,0 @@
import {
AssistantTool,
executeOnMain,
fs,
InferenceTool,
joinPath,
MessageRequest,
} from '@janhq/core'
export class RetrievalTool extends InferenceTool {
private _threadDir = 'file://threads'
private retrievalThreadId: string | undefined = undefined
name: string = 'retrieval'
async process(
data: MessageRequest,
tool?: AssistantTool
): Promise<MessageRequest> {
if (!data.model || !data.messages) {
return Promise.resolve(data)
}
const latestMessage = data.messages[data.messages.length - 1]
// 1. Ingest the document if needed
if (
latestMessage &&
latestMessage.content &&
typeof latestMessage.content !== 'string' &&
latestMessage.content.length > 1
) {
const docFile = latestMessage.content[1]?.doc_url?.url
if (docFile) {
await executeOnMain(
NODE,
'toolRetrievalIngestNewDocument',
data.thread?.id,
docFile,
data.model?.id,
data.model?.engine,
tool?.useTimeWeightedRetriever ?? false
)
} else {
return Promise.resolve(data)
}
} else if (
// Check whether we need to ingest document or not
// Otherwise wrong context will be sent
!(await fs.existsSync(
await joinPath([this._threadDir, data.threadId, 'memory'])
))
) {
// No document ingested, reroute the result to inference engine
return Promise.resolve(data)
}
// 2. Load agent on thread changed
if (this.retrievalThreadId !== data.threadId) {
await executeOnMain(NODE, 'toolRetrievalLoadThreadMemory', data.threadId)
this.retrievalThreadId = data.threadId
// Update the text splitter
await executeOnMain(
NODE,
'toolRetrievalUpdateTextSplitter',
tool?.settings?.chunk_size ?? 4000,
tool?.settings?.chunk_overlap ?? 200
)
}
// 3. Using the retrieval template with the result and query
if (latestMessage.content) {
const prompt =
typeof latestMessage.content === 'string'
? latestMessage.content
: latestMessage.content[0].text
// Retrieve the result
const retrievalResult = await executeOnMain(
NODE,
'toolRetrievalQueryResult',
prompt,
tool?.useTimeWeightedRetriever ?? false
)
console.debug('toolRetrievalQueryResult', retrievalResult)
// Update message content
if (retrievalResult)
data.messages[data.messages.length - 1].content =
tool?.settings?.retrieval_template
?.replace('{CONTEXT}', retrievalResult)
.replace('{QUESTION}', prompt)
}
// 4. Reroute the result to inference engine
return Promise.resolve(this.normalize(data))
}
// Filter out all the messages that are not text
// TODO: Remove it until engines can handle multiple content types
normalize(request: MessageRequest): MessageRequest {
request.messages = request.messages?.map((message) => {
if (
message.content &&
typeof message.content !== 'string' &&
(message.content.length ?? 0) > 0
) {
return {
...message,
content: [message.content[0]],
}
}
return message
})
return request
}
}

View File

@ -21,7 +21,6 @@
"dev:electron": "yarn copy:assets && yarn workspace jan dev", "dev:electron": "yarn copy:assets && yarn workspace jan dev",
"dev:web:standalone": "concurrently \"yarn workspace @janhq/web dev\" \"wait-on http://localhost:3000 && rsync -av --prune-empty-dirs --include '*/' --include 'dist/***' --include 'package.json' --include 'tsconfig.json' --exclude '*' ./extensions/ web/.next/static/extensions/\"", "dev:web:standalone": "concurrently \"yarn workspace @janhq/web dev\" \"wait-on http://localhost:3000 && rsync -av --prune-empty-dirs --include '*/' --include 'dist/***' --include 'package.json' --include 'tsconfig.json' --exclude '*' ./extensions/ web/.next/static/extensions/\"",
"dev:web": "yarn workspace @janhq/web dev", "dev:web": "yarn workspace @janhq/web dev",
"dev:web:tauri": "IS_TAURI=true yarn workspace @janhq/web dev",
"dev:server": "yarn workspace @janhq/server dev", "dev:server": "yarn workspace @janhq/server dev",
"dev": "concurrently -n \"NEXT,ELECTRON\" -c \"yellow,blue\" --kill-others \"yarn dev:web\" \"yarn dev:electron\"", "dev": "concurrently -n \"NEXT,ELECTRON\" -c \"yellow,blue\" --kill-others \"yarn dev:web\" \"yarn dev:electron\"",
"install:cortex:linux:darwin": "cd src-tauri/binaries && ./download.sh", "install:cortex:linux:darwin": "cd src-tauri/binaries && ./download.sh",

View File

@ -21,9 +21,7 @@ tauri-build = { version = "2.0.2", features = [] }
serde_json = "1.0" serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
log = "0.4" log = "0.4"
tauri = { version = "2.1.0", features = [ tauri = { version = "2.1.0", features = [ "protocol-asset", "macos-private-api",
"protocol-asset",
'macos-private-api',
"test", "test",
] } ] }
tauri-plugin-log = "2.0.0-rc" tauri-plugin-log = "2.0.0-rc"
@ -36,10 +34,12 @@ tauri-plugin-store = "2"
hyper = { version = "0.14", features = ["server"] } hyper = { version = "0.14", features = ["server"] }
reqwest = { version = "0.11", features = ["json"] } reqwest = { version = "0.11", features = ["json"] }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tracing = "0.1.41"
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
"client", "client",
"transport-sse", "transport-sse",
"transport-child-process", "transport-child-process",
"tower", "tower",
] } ] }
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
tauri-plugin-updater = "2"

View File

@ -2,18 +2,15 @@
"$schema": "../gen/schemas/desktop-schema.json", "$schema": "../gen/schemas/desktop-schema.json",
"identifier": "default", "identifier": "default",
"description": "enables the default permissions", "description": "enables the default permissions",
"windows": [ "windows": ["main"],
"main"
],
"remote": { "remote": {
"urls": [ "urls": ["http://*"]
"http://*"
]
}, },
"permissions": [ "permissions": [
"core:default", "core:default",
"shell:allow-spawn", "shell:allow-spawn",
"shell:allow-open", "shell:allow-open",
"log:default",
{ {
"identifier": "http:default", "identifier": "http:default",
"allow": [ "allow": [
@ -55,4 +52,4 @@
}, },
"store:default" "store:default"
] ]
} }

View File

@ -35,7 +35,7 @@ pub fn get_app_configurations<R: Runtime>(app_handle: tauri::AppHandle<R>) -> Ap
let default_data_folder = default_data_folder_path(app_handle.clone()); let default_data_folder = default_data_folder_path(app_handle.clone());
if !configuration_file.exists() { if !configuration_file.exists() {
println!( log::info!(
"App config not found, creating default config at {:?}", "App config not found, creating default config at {:?}",
configuration_file configuration_file
); );
@ -46,7 +46,7 @@ pub fn get_app_configurations<R: Runtime>(app_handle: tauri::AppHandle<R>) -> Ap
&configuration_file, &configuration_file,
serde_json::to_string(&app_default_configuration).unwrap(), serde_json::to_string(&app_default_configuration).unwrap(),
) { ) {
eprintln!("Failed to create default config: {}", err); log::error!("Failed to create default config: {}", err);
} }
return app_default_configuration; return app_default_configuration;
@ -56,7 +56,7 @@ pub fn get_app_configurations<R: Runtime>(app_handle: tauri::AppHandle<R>) -> Ap
Ok(content) => match serde_json::from_str::<AppConfiguration>(&content) { Ok(content) => match serde_json::from_str::<AppConfiguration>(&content) {
Ok(app_configurations) => app_configurations, Ok(app_configurations) => app_configurations,
Err(err) => { Err(err) => {
eprintln!( log::error!(
"Failed to parse app config, returning default config instead. Error: {}", "Failed to parse app config, returning default config instead. Error: {}",
err err
); );
@ -64,7 +64,7 @@ pub fn get_app_configurations<R: Runtime>(app_handle: tauri::AppHandle<R>) -> Ap
} }
}, },
Err(err) => { Err(err) => {
eprintln!( log::error!(
"Failed to read app config, returning default config instead. Error: {}", "Failed to read app config, returning default config instead. Error: {}",
err err
); );
@ -79,7 +79,7 @@ pub fn update_app_configuration(
configuration: AppConfiguration, configuration: AppConfiguration,
) -> Result<(), String> { ) -> Result<(), String> {
let configuration_file = get_configuration_file_path(app_handle); let configuration_file = get_configuration_file_path(app_handle);
println!( log::info!(
"update_app_configuration, configuration_file: {:?}", "update_app_configuration, configuration_file: {:?}",
configuration_file configuration_file
); );
@ -136,7 +136,7 @@ pub fn read_theme(app_handle: tauri::AppHandle, theme_name: String) -> Result<St
#[tauri::command] #[tauri::command]
pub fn get_configuration_file_path<R: Runtime>(app_handle: tauri::AppHandle<R>) -> PathBuf { pub fn get_configuration_file_path<R: Runtime>(app_handle: tauri::AppHandle<R>) -> PathBuf {
let app_path = app_handle.path().app_data_dir().unwrap_or_else(|err| { let app_path = app_handle.path().app_data_dir().unwrap_or_else(|err| {
eprintln!( log::error!(
"Failed to get app data directory: {}. Using home directory instead.", "Failed to get app data directory: {}. Using home directory instead.",
err err
); );
@ -215,7 +215,7 @@ pub fn open_file_explorer(path: String) {
#[tauri::command] #[tauri::command]
pub fn install_extensions(app: AppHandle) { pub fn install_extensions(app: AppHandle) {
if let Err(err) = setup::install_extensions(app, true) { if let Err(err) = setup::install_extensions(app, true) {
eprintln!("Failed to install extensions: {}", err); log::error!("Failed to install extensions: {}", err);
} }
} }
@ -223,7 +223,7 @@ pub fn install_extensions(app: AppHandle) {
pub fn get_active_extensions(app: AppHandle) -> Vec<serde_json::Value> { pub fn get_active_extensions(app: AppHandle) -> Vec<serde_json::Value> {
let mut path = get_jan_extensions_path(app); let mut path = get_jan_extensions_path(app);
path.push("extensions.json"); path.push("extensions.json");
println!("get jan extensions, path: {:?}", path); log::info!("get jan extensions, path: {:?}", path);
let contents = fs::read_to_string(path); let contents = fs::read_to_string(path);
let contents: Vec<serde_json::Value> = match contents { let contents: Vec<serde_json::Value> = match contents {

View File

@ -73,7 +73,7 @@ pub fn readdir_sync<R: Runtime>(
} }
let path = resolve_path(app_handle, &args[0]); let path = resolve_path(app_handle, &args[0]);
println!("Reading directory: {:?}", path); log::error!("Reading directory: {:?}", path);
let entries = fs::read_dir(&path).map_err(|e| e.to_string())?; let entries = fs::read_dir(&path).map_err(|e| e.to_string())?;
let paths: Vec<String> = entries let paths: Vec<String> = entries
.filter_map(|entry| entry.ok()) .filter_map(|entry| entry.ok())

View File

@ -17,7 +17,7 @@ pub async fn run_mcp_commands(
app_path: String, app_path: String,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>, servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
) -> Result<(), String> { ) -> Result<(), String> {
println!( log::info!(
"Load MCP configs from {}", "Load MCP configs from {}",
app_path.clone() + "/mcp_config.json" app_path.clone() + "/mcp_config.json"
); );
@ -29,17 +29,20 @@ pub async fn run_mcp_commands(
.map_err(|e| format!("Failed to parse config: {}", e))?; .map_err(|e| format!("Failed to parse config: {}", e))?;
if let Some(server_map) = mcp_servers.get("mcpServers").and_then(Value::as_object) { if let Some(server_map) = mcp_servers.get("mcpServers").and_then(Value::as_object) {
println!("MCP Servers: {server_map:#?}"); log::info!("MCP Servers: {server_map:#?}");
for (name, config) in server_map { for (name, config) in server_map {
if let Some((command, args)) = extract_command_args(config) { if let Some((command, args)) = extract_command_args(config) {
let mut cmd = Command::new(command); let mut cmd = Command::new(command);
args.iter().filter_map(Value::as_str).for_each(|arg| { cmd.arg(arg); }); args.iter().filter_map(Value::as_str).for_each(|arg| {
cmd.arg(arg);
let service = ().serve(TokioChildProcess::new(&mut cmd).map_err(|e| e.to_string())?) });
.await
.map_err(|e| e.to_string())?; let service =
().serve(TokioChildProcess::new(&mut cmd).map_err(|e| e.to_string())?)
.await
.map_err(|e| e.to_string())?;
servers_state.lock().await.insert(name.clone(), service); servers_state.lock().await.insert(name.clone(), service);
} }
} }
@ -50,7 +53,7 @@ pub async fn run_mcp_commands(
for (_, service) in servers_map.iter() { for (_, service) in servers_map.iter() {
// Initialize // Initialize
let _server_info = service.peer_info(); let _server_info = service.peer_info();
println!("Connected to server: {_server_info:#?}"); log::info!("Connected to server: {_server_info:#?}");
} }
Ok(()) Ok(())
} }

View File

@ -1,6 +1,6 @@
pub mod cmd; pub mod cmd;
pub mod fs; pub mod fs;
pub mod mcp;
pub mod server;
pub mod setup; pub mod setup;
pub mod state; pub mod state;
pub mod server;
pub mod mcp;

View File

@ -6,7 +6,6 @@ use std::net::SocketAddr;
use std::sync::LazyLock; use std::sync::LazyLock;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tracing::{debug, error, info};
/// Server handle type for managing the proxy server lifecycle /// Server handle type for managing the proxy server lifecycle
type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>; type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>;
@ -24,7 +23,7 @@ struct ProxyConfig {
/// Removes a prefix from a path, ensuring proper formatting /// Removes a prefix from a path, ensuring proper formatting
fn remove_prefix(path: &str, prefix: &str) -> String { fn remove_prefix(path: &str, prefix: &str) -> String {
debug!("Processing path: {}, removing prefix: {}", path, prefix); log::debug!("Processing path: {}, removing prefix: {}", path, prefix);
if !prefix.is_empty() && path.starts_with(prefix) { if !prefix.is_empty() && path.starts_with(prefix) {
let result = path[prefix.len()..].to_string(); let result = path[prefix.len()..].to_string();
@ -42,7 +41,6 @@ fn remove_prefix(path: &str, prefix: &str) -> String {
fn get_destination_path(original_path: &str, prefix: &str) -> String { fn get_destination_path(original_path: &str, prefix: &str) -> String {
let removed_prefix_path = remove_prefix(original_path, prefix); let removed_prefix_path = remove_prefix(original_path, prefix);
println!("Removed prefix path: {}", removed_prefix_path);
// Special paths don't need the /v1 prefix // Special paths don't need the /v1 prefix
if !original_path.contains(prefix) if !original_path.contains(prefix)
|| removed_prefix_path.contains("/healthz") || removed_prefix_path.contains("/healthz")
@ -81,7 +79,7 @@ async fn proxy_request(
// Build the outbound request // Build the outbound request
let upstream_url = build_upstream_url(&config.upstream, &path); let upstream_url = build_upstream_url(&config.upstream, &path);
debug!("Proxying request to: {}", upstream_url); log::debug!("Proxying request to: {}", upstream_url);
let mut outbound_req = client.request(req.method().clone(), &upstream_url); let mut outbound_req = client.request(req.method().clone(), &upstream_url);
@ -100,7 +98,7 @@ async fn proxy_request(
match outbound_req.body(req.into_body()).send().await { match outbound_req.body(req.into_body()).send().await {
Ok(response) => { Ok(response) => {
let status = response.status(); let status = response.status();
debug!("Received response with status: {}", status); log::debug!("Received response with status: {}", status);
let mut builder = Response::builder().status(status); let mut builder = Response::builder().status(status);
@ -113,7 +111,7 @@ async fn proxy_request(
match response.bytes().await { match response.bytes().await {
Ok(bytes) => Ok(builder.body(Body::from(bytes)).unwrap()), Ok(bytes) => Ok(builder.body(Body::from(bytes)).unwrap()),
Err(e) => { Err(e) => {
error!("Failed to read response body: {}", e); log::error!("Failed to read response body: {}", e);
Ok(Response::builder() Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR) .status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Error reading upstream response")) .body(Body::from("Error reading upstream response"))
@ -122,7 +120,7 @@ async fn proxy_request(
} }
} }
Err(e) => { Err(e) => {
error!("Proxy request failed: {}", e); log::error!("Proxy request failed: {}", e);
Ok(Response::builder() Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY) .status(StatusCode::BAD_GATEWAY)
.body(Body::from(format!("Upstream error: {}", e))) .body(Body::from(format!("Upstream error: {}", e)))
@ -175,12 +173,12 @@ pub async fn start_server(
// Create and start the server // Create and start the server
let server = Server::bind(&addr).serve(make_svc); let server = Server::bind(&addr).serve(make_svc);
info!("Proxy server started on http://{}", addr); log::info!("Proxy server started on http://{}", addr);
// Spawn server task // Spawn server task
let server_handle = tokio::spawn(async move { let server_handle = tokio::spawn(async move {
if let Err(e) = server.await { if let Err(e) = server.await {
error!("Server error: {}", e); log::error!("Server error: {}", e);
return Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>); return Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>);
} }
Ok(()) Ok(())
@ -196,9 +194,9 @@ pub async fn stop_server() -> Result<(), Box<dyn std::error::Error + Send + Sync
if let Some(handle) = handle_guard.take() { if let Some(handle) = handle_guard.take() {
handle.abort(); handle.abort();
info!("Proxy server stopped"); log::info!("Proxy server stopped");
} else { } else {
debug!("No server was running"); log::debug!("No server was running");
} }
Ok(()) Ok(())

View File

@ -11,6 +11,7 @@ use tauri_plugin_shell::process::CommandEvent;
use tauri_plugin_shell::ShellExt; use tauri_plugin_shell::ShellExt;
use tauri_plugin_store::StoreExt; use tauri_plugin_store::StoreExt;
// MCP
use super::{ use super::{
cmd::{get_jan_data_folder_path, get_jan_extensions_path}, cmd::{get_jan_data_folder_path, get_jan_extensions_path},
mcp::run_mcp_commands, mcp::run_mcp_commands,
@ -39,7 +40,7 @@ pub fn install_extensions(app: tauri::AppHandle, force: bool) -> Result<(), Stri
// Attempt to remove extensions folder // Attempt to remove extensions folder
if extensions_path.exists() { if extensions_path.exists() {
fs::remove_dir_all(&extensions_path).unwrap_or_else(|_| { fs::remove_dir_all(&extensions_path).unwrap_or_else(|_| {
println!("Failed to remove existing extensions folder, it may not exist."); log::info!("Failed to remove existing extensions folder, it may not exist.");
}); });
} }
@ -66,7 +67,7 @@ pub fn install_extensions(app: tauri::AppHandle, force: bool) -> Result<(), Stri
let path = entry.path(); let path = entry.path();
if path.extension().map_or(false, |ext| ext == "tgz") { if path.extension().map_or(false, |ext| ext == "tgz") {
println!("Installing extension from {:?}", path); log::info!("Installing extension from {:?}", path);
let tar_gz = File::open(&path).map_err(|e| e.to_string())?; let tar_gz = File::open(&path).map_err(|e| e.to_string())?;
let gz_decoder = GzDecoder::new(tar_gz); let gz_decoder = GzDecoder::new(tar_gz);
let mut archive = Archive::new(gz_decoder); let mut archive = Archive::new(gz_decoder);
@ -132,7 +133,7 @@ pub fn install_extensions(app: tauri::AppHandle, force: bool) -> Result<(), Stri
extensions_list.push(new_extension); extensions_list.push(new_extension);
println!("Installed extension to {:?}", extension_dir); log::info!("Installed extension to {:?}", extension_dir);
} }
} }
fs::write( fs::write(
@ -186,7 +187,7 @@ pub fn setup_mcp(app: &App) {
let servers = state.mcp_servers.clone(); let servers = state.mcp_servers.clone();
tauri::async_runtime::spawn(async move { tauri::async_runtime::spawn(async move {
if let Err(e) = run_mcp_commands(app_path_str, servers).await { if let Err(e) = run_mcp_commands(app_path_str, servers).await {
eprintln!("Failed to run mcp commands: {}", e); log::error!("Failed to run mcp commands: {}", e);
} }
}); });
} }
@ -252,7 +253,7 @@ pub fn setup_sidecar(app: &App) -> Result<(), String> {
while let Some(event) = rx.recv().await { while let Some(event) = rx.recv().await {
if let CommandEvent::Stdout(line_bytes) = event { if let CommandEvent::Stdout(line_bytes) = event {
let line = String::from_utf8_lossy(&line_bytes); let line = String::from_utf8_lossy(&line_bytes);
println!("Outputs: {:?}", line) log::info!("Outputs: {:?}", line)
} }
} }
}); });
@ -268,7 +269,7 @@ pub fn setup_sidecar(app: &App) -> Result<(), String> {
fn copy_dir_all(src: PathBuf, dst: PathBuf) -> Result<(), String> { fn copy_dir_all(src: PathBuf, dst: PathBuf) -> Result<(), String> {
fs::create_dir_all(&dst).map_err(|e| e.to_string())?; fs::create_dir_all(&dst).map_err(|e| e.to_string())?;
println!("Copying from {:?} to {:?}", src, dst); log::info!("Copying from {:?} to {:?}", src, dst);
for entry in fs::read_dir(src).map_err(|e| e.to_string())? { for entry in fs::read_dir(src).map_err(|e| e.to_string())? {
let entry = entry.map_err(|e| e.to_string())?; let entry = entry.map_err(|e| e.to_string())?;
let ty = entry.file_type().map_err(|e| e.to_string())?; let ty = entry.file_type().map_err(|e| e.to_string())?;
@ -293,10 +294,10 @@ pub fn setup_engine_binaries(app: &App) -> Result<(), String> {
.join("resources"); .join("resources");
if let Err(e) = copy_dir_all(binaries_dir, app_data_dir.clone()) { if let Err(e) = copy_dir_all(binaries_dir, app_data_dir.clone()) {
eprintln!("Failed to copy binaries: {}", e); log::error!("Failed to copy binaries: {}", e);
} }
if let Err(e) = copy_dir_all(themes_dir, app_data_dir.clone()) { if let Err(e) = copy_dir_all(themes_dir, app_data_dir.clone()) {
eprintln!("Failed to copy themes: {}", e); log::error!("Failed to copy themes: {}", e);
} }
Ok(()) Ok(())
} }

View File

@ -7,7 +7,7 @@ use tokio::sync::Mutex;
#[derive(Default)] #[derive(Default)]
pub struct AppState { pub struct AppState {
pub app_token: Option<String>, pub app_token: Option<String>,
pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>> pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
} }
pub fn generate_app_token() -> String { pub fn generate_app_token() -> String {
rand::thread_rng() rand::thread_rng()

View File

@ -1,5 +1,6 @@
mod core; mod core;
use core::{ use core::{
cmd::get_jan_data_folder_path,
setup::{self, setup_engine_binaries, setup_mcp, setup_sidecar}, setup::{self, setup_engine_binaries, setup_mcp, setup_sidecar},
state::{generate_app_token, AppState}, state::{generate_app_token, AppState},
}; };
@ -11,8 +12,8 @@ use tokio::sync::Mutex;
#[cfg_attr(mobile, tauri::mobile_entry_point)] #[cfg_attr(mobile, tauri::mobile_entry_point)]
pub fn run() { pub fn run() {
tauri::Builder::default() tauri::Builder::default()
.plugin(tauri_plugin_store::Builder::new().build())
.plugin(tauri_plugin_http::init()) .plugin(tauri_plugin_http::init())
.plugin(tauri_plugin_store::Builder::new().build())
.plugin(tauri_plugin_shell::init()) .plugin(tauri_plugin_shell::init())
.invoke_handler(tauri::generate_handler![ .invoke_handler(tauri::generate_handler![
// FS commands - Deperecate soon // FS commands - Deperecate soon
@ -47,25 +48,25 @@ pub fn run() {
mcp_servers: Arc::new(Mutex::new(HashMap::new())), mcp_servers: Arc::new(Mutex::new(HashMap::new())),
}) })
.setup(|app| { .setup(|app| {
if cfg!(debug_assertions) { app.handle().plugin(
app.handle().plugin( tauri_plugin_log::Builder::default()
tauri_plugin_log::Builder::default() .targets([if cfg!(debug_assertions) {
.level(log::LevelFilter::Info) tauri_plugin_log::Target::new(tauri_plugin_log::TargetKind::Stdout)
.build(), } else {
)?; tauri_plugin_log::Target::new(tauri_plugin_log::TargetKind::Folder {
} path: get_jan_data_folder_path(app.handle().clone()).join("logs"),
file_name: Some("app".to_string()),
})
}])
.build(),
)?;
// Install extensions // Install extensions
if let Err(e) = setup::install_extensions(app.handle().clone(), false) { if let Err(e) = setup::install_extensions(app.handle().clone(), false) {
eprintln!("Failed to install extensions: {}", e); log::error!("Failed to install extensions: {}", e);
} }
setup_mcp(app); setup_mcp(app);
setup_sidecar(app).expect("Failed to setup sidecar"); setup_sidecar(app).expect("Failed to setup sidecar");
setup_engine_binaries(app).expect("Failed to setup engine binaries"); setup_engine_binaries(app).expect("Failed to setup engine binaries");
Ok(()) Ok(())
}) })
.on_window_event(|window, event| match event { .on_window_event(|window, event| match event {

View File

@ -1,13 +1,13 @@
{ {
"$schema": "../node_modules/@tauri-apps/cli/config.schema.json", "$schema": "https://schema.tauri.app/config/2",
"productName": "Jan", "productName": "Jan",
"version": "0.1.0", "version": "0.1.0",
"identifier": "jan.ai", "identifier": "jan.ai",
"build": { "build": {
"frontendDist": "../web/out", "frontendDist": "../web/out",
"devUrl": "http://localhost:3000", "devUrl": "http://localhost:3000",
"beforeDevCommand": "yarn dev:web:tauri", "beforeDevCommand": "IS_TAURI=true yarn dev:web",
"beforeBuildCommand": "yarn build:web" "beforeBuildCommand": "IS_TAURI=true yarn build:web"
}, },
"app": { "app": {
"macOSPrivateApi": true, "macOSPrivateApi": true,
@ -27,9 +27,10 @@
"csp": { "csp": {
"default-src": "'self' customprotocol: asset: http://localhost:* http://127.0.0.1:* ws://localhost:* ws://127.0.0.1:*", "default-src": "'self' customprotocol: asset: http://localhost:* http://127.0.0.1:* ws://localhost:* ws://127.0.0.1:*",
"connect-src": "ipc: http://ipc.localhost", "connect-src": "ipc: http://ipc.localhost",
"font-src": ["https://fonts.gstatic.com"], "font-src": ["https://fonts.gstatic.com blob: data:"],
"img-src": "'self' asset: http://asset.localhost blob: data:", "img-src": "'self' asset: http://asset.localhost blob: data:",
"style-src": "'unsafe-inline' 'self' https://fonts.googleapis.com" "style-src": "'unsafe-inline' 'self' https://fonts.googleapis.com",
"script-src": "'self' asset: $APPDATA/**.*"
}, },
"assetProtocol": { "assetProtocol": {
"enable": true, "enable": true,
@ -40,9 +41,18 @@
} }
} }
}, },
"plugins": {
"updater": {
"pubkey": "",
"endpoints": [
"https://github.com/menloresearch/jan/releases/latest/download/latest.json"
]
}
},
"bundle": { "bundle": {
"active": true, "active": true,
"targets": "all", "targets": "all",
"createUpdaterArtifacts": true,
"icon": [ "icon": [
"icons/32x32.png", "icons/32x32.png",
"icons/128x128.png", "icons/128x128.png",

View File

@ -1,7 +1,8 @@
import React, { useEffect, useState } from 'react' import React, { useEffect, useRef, useState } from 'react'
import { Button, Modal } from '@janhq/joi' import { Button, Modal } from '@janhq/joi'
import { check, Update } from '@tauri-apps/plugin-updater'
import { useAtom } from 'jotai' import { useAtom } from 'jotai'
import { useGetLatestRelease } from '@/hooks/useGetLatestRelease' import { useGetLatestRelease } from '@/hooks/useGetLatestRelease'
@ -16,6 +17,7 @@ const ModalAppUpdaterChangelog = () => {
const [appUpdateAvailable, setAppUpdateAvailable] = useAtom( const [appUpdateAvailable, setAppUpdateAvailable] = useAtom(
appUpdateAvailableAtom appUpdateAvailableAtom
) )
const updaterRef = useRef<Update | null>(null)
const [open, setOpen] = useState(appUpdateAvailable) const [open, setOpen] = useState(appUpdateAvailable)
@ -26,6 +28,17 @@ const ModalAppUpdaterChangelog = () => {
const beta = VERSION.includes('beta') const beta = VERSION.includes('beta')
const nightly = VERSION.includes('-') const nightly = VERSION.includes('-')
const checkForUpdate = async () => {
const update = await check()
if (update) {
setAppUpdateAvailable(true)
updaterRef.current = update
}
}
useEffect(() => {
checkForUpdate()
}, [])
const { release } = useGetLatestRelease(beta ? true : false) const { release } = useGetLatestRelease(beta ? true : false)
return ( return (
@ -73,8 +86,8 @@ const ModalAppUpdaterChangelog = () => {
</Button> </Button>
<Button <Button
autoFocus autoFocus
onClick={() => { onClick={async () => {
window.core?.api?.appUpdateDownload() await updaterRef.current?.downloadAndInstall((event) => {})
setOpen(false) setOpen(false)
setAppUpdateAvailable(false) setAppUpdateAvailable(false)
}} }}

View File

@ -22,7 +22,8 @@ export const useLoadTheme = () => {
const setNativeTheme = useCallback( const setNativeTheme = useCallback(
(nativeTheme: NativeThemeProps) => { (nativeTheme: NativeThemeProps) => {
if (!('setNativeThemeDark' in window.core.api)) return if (!window.electronAPI) return
if (nativeTheme === 'dark') { if (nativeTheme === 'dark') {
window?.core?.api?.setNativeThemeDark() window?.core?.api?.setNativeThemeDark()
setTheme('dark') setTheme('dark')
@ -58,21 +59,20 @@ export const useLoadTheme = () => {
setThemeOptions(themesOptions) setThemeOptions(themesOptions)
if (!selectedIdTheme.length) return setSelectedIdTheme('joi-light') if (!selectedIdTheme.length) return setSelectedIdTheme('joi-light')
const theme: Theme = JSON.parse( const theme: Theme = JSON.parse(
await window.core.api.readTheme({ await window.core.api.readTheme({
theme: selectedIdTheme, themeName: selectedIdTheme,
}) })
) )
setThemeData(theme) setThemeData(theme)
setNativeTheme(theme.nativeTheme) setNativeTheme(theme.nativeTheme)
applyTheme(theme) applyTheme(theme)
}, []) }, [selectedIdTheme])
const configureTheme = useCallback(async () => { const configureTheme = useCallback(async () => {
if (!themeData || !themeOptions) { if (!themeData || !themeOptions) {
await getThemes() getThemes()
} else { } else {
applyTheme(themeData) applyTheme(themeData)
} }

View File

@ -1,19 +1,30 @@
import { useEffect, useRef } from 'react' import { useEffect, useRef } from 'react'
import { import {
ChatCompletionRole,
MessageRequestType, MessageRequestType,
ExtensionTypeEnum, ExtensionTypeEnum,
Thread, Thread,
ThreadMessage, ThreadMessage,
Model, Model,
ConversationalExtension, ConversationalExtension,
EngineManager,
ThreadAssistantInfo, ThreadAssistantInfo,
events,
MessageEvent,
ContentType,
EngineManager,
InferenceEngine, InferenceEngine,
} from '@janhq/core' } from '@janhq/core'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core' import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { OpenAI } from 'openai'
import {
ChatCompletionMessageParam,
ChatCompletionRole,
ChatCompletionTool,
} from 'openai/resources/chat'
import { ulid } from 'ulidx'
import { modelDropdownStateAtom } from '@/containers/ModelDropdown' import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
import { import {
@ -46,6 +57,7 @@ import {
updateThreadAtom, updateThreadAtom,
updateThreadWaitingForResponseAtom, updateThreadWaitingForResponseAtom,
} from '@/helpers/atoms/Thread.atom' } from '@/helpers/atoms/Thread.atom'
import { ModelTool } from '@/types/model'
export const reloadModelAtom = atom(false) export const reloadModelAtom = atom(false)
@ -99,7 +111,7 @@ export default function useSendChatMessage() {
const newConvoData = Array.from(currentMessages) const newConvoData = Array.from(currentMessages)
let toSendMessage = newConvoData.pop() let toSendMessage = newConvoData.pop()
while (toSendMessage && toSendMessage?.role !== ChatCompletionRole.User) { while (toSendMessage && toSendMessage?.role !== 'user') {
await extensionManager await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.deleteMessage(toSendMessage.thread_id, toSendMessage.id) ?.deleteMessage(toSendMessage.thread_id, toSendMessage.id)
@ -172,7 +184,16 @@ export default function useSendChatMessage() {
parameters: runtimeParams, parameters: runtimeParams,
}, },
activeThreadRef.current, activeThreadRef.current,
messages ?? currentMessages messages ?? currentMessages,
(await window.core.api.getTools())?.map((tool: ModelTool) => ({
type: 'function' as const,
function: {
name: tool.name,
description: tool.description?.slice(0, 1024),
parameters: tool.inputSchema,
strict: false,
},
}))
).addSystemMessage(activeAssistantRef.current?.instructions) ).addSystemMessage(activeAssistantRef.current?.instructions)
requestBuilder.pushMessage(prompt, base64Blob, fileUpload) requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
@ -228,10 +249,125 @@ export default function useSendChatMessage() {
} }
setIsGeneratingResponse(true) setIsGeneratingResponse(true)
// Request for inference if (requestBuilder.tools && requestBuilder.tools.length) {
EngineManager.instance() let isDone = false
.get(InferenceEngine.cortex) const openai = new OpenAI({
?.inference(requestBuilder.build()) apiKey: await window.core.api.appToken(),
baseURL: `${API_BASE_URL}/v1`,
dangerouslyAllowBrowser: true,
})
while (!isDone) {
const data = requestBuilder.build()
const response = await openai.chat.completions.create({
messages: (data.messages ?? []).map((e) => {
return {
role: e.role as ChatCompletionRole,
content: e.content,
}
}) as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
stream: false,
})
if (response.choices[0]?.message.content) {
const newMessage: ThreadMessage = {
id: ulid(),
object: 'message',
thread_id: activeThreadRef.current.id,
assistant_id: activeAssistantRef.current.assistant_id,
attachments: [],
role: response.choices[0].message.role as any,
content: [
{
type: ContentType.Text,
text: {
value: response.choices[0].message.content
? (response.choices[0].message.content as any)
: '',
annotations: [],
},
},
],
status: 'ready' as any,
created_at: Date.now(),
completed_at: Date.now(),
}
requestBuilder.pushAssistantMessage(
(response.choices[0].message.content as any) ?? ''
)
events.emit(MessageEvent.OnMessageUpdate, newMessage)
}
if (response.choices[0]?.message.tool_calls) {
for (const toolCall of response.choices[0].message.tool_calls) {
const id = ulid()
const toolMessage: ThreadMessage = {
id: id,
object: 'message',
thread_id: activeThreadRef.current.id,
assistant_id: activeAssistantRef.current.assistant_id,
attachments: [],
role: 'assistant' as any,
content: [
{
type: ContentType.Text,
text: {
value: `<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`,
annotations: [],
},
},
],
status: 'pending' as any,
created_at: Date.now(),
completed_at: Date.now(),
}
events.emit(MessageEvent.OnMessageUpdate, toolMessage)
const result = await window.core.api.callTool({
toolName: toolCall.function.name,
arguments: JSON.parse(toolCall.function.arguments),
})
if (result.error) {
console.error(result.error)
break
}
const message: ThreadMessage = {
id: id,
object: 'message',
thread_id: activeThreadRef.current.id,
assistant_id: activeAssistantRef.current.assistant_id,
attachments: [],
role: 'assistant' as any,
content: [
{
type: ContentType.Text,
text: {
value:
`<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}</think>` +
(result.content[0]?.text ?? ''),
annotations: [],
},
},
],
status: 'ready' as any,
created_at: Date.now(),
completed_at: Date.now(),
}
requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '')
requestBuilder.pushMessage('Go for the next step')
events.emit(MessageEvent.OnMessageUpdate, message)
}
}
isDone =
!response.choices[0]?.message.tool_calls ||
!response.choices[0]?.message.tool_calls.length
}
} else {
// Request for inference
EngineManager.instance()
.get(InferenceEngine.cortex)
?.inference(requestBuilder.build())
}
// Reset states // Reset states
setReloadModel(false) setReloadModel(false)

View File

@ -22,6 +22,7 @@
"@tanstack/react-virtual": "^3.10.9", "@tanstack/react-virtual": "^3.10.9",
"@tauri-apps/api": "^2.4.0", "@tauri-apps/api": "^2.4.0",
"@tauri-apps/plugin-http": "^2.4.2", "@tauri-apps/plugin-http": "^2.4.2",
"@tauri-apps/plugin-updater": "~2",
"@uppy/core": "^4.3.0", "@uppy/core": "^4.3.0",
"@uppy/react": "^4.0.4", "@uppy/react": "^4.0.4",
"@uppy/xhr-upload": "^4.2.3", "@uppy/xhr-upload": "^4.2.3",
@ -36,6 +37,7 @@
"marked": "^9.1.2", "marked": "^9.1.2",
"next": "14.2.3", "next": "14.2.3",
"next-themes": "^0.2.1", "next-themes": "^0.2.1",
"openai": "^4.90.0",
"postcss": "8.4.31", "postcss": "8.4.31",
"postcss-url": "10.1.3", "postcss-url": "10.1.3",
"posthog-js": "^1.194.6", "posthog-js": "^1.194.6",

View File

@ -1,4 +1,4 @@
import { EngineManager, ToolManager } from '@janhq/core' import { EngineManager } from '@janhq/core'
import { appService } from './appService' import { appService } from './appService'
import { EventEmitter } from './eventsService' import { EventEmitter } from './eventsService'
@ -16,7 +16,6 @@ export const setupCoreServices = () => {
window.core = { window.core = {
events: new EventEmitter(), events: new EventEmitter(),
engineManager: new EngineManager(), engineManager: new EngineManager(),
toolManager: new ToolManager(),
api: { api: {
...(window.electronAPI ?? (IS_TAURI ? tauriAPI : restAPI)), ...(window.electronAPI ?? (IS_TAURI ? tauriAPI : restAPI)),
...appService, ...appService,

View File

@ -2,3 +2,9 @@
* ModelParams types * ModelParams types
*/ */
export type ModelParams = ModelRuntimeParams | ModelSettingParams export type ModelParams = ModelRuntimeParams | ModelSettingParams
export type ModelTool = {
name: string
description: string
inputSchema: string
}

View File

@ -6,6 +6,7 @@ import {
ChatCompletionRole, ChatCompletionRole,
MessageRequest, MessageRequest,
MessageRequestType, MessageRequestType,
MessageTool,
ModelInfo, ModelInfo,
Thread, Thread,
ThreadMessage, ThreadMessage,
@ -22,12 +23,14 @@ export class MessageRequestBuilder {
messages: ChatCompletionMessage[] messages: ChatCompletionMessage[]
model: ModelInfo model: ModelInfo
thread: Thread thread: Thread
tools?: MessageTool[]
constructor( constructor(
type: MessageRequestType, type: MessageRequestType,
model: ModelInfo, model: ModelInfo,
thread: Thread, thread: Thread,
messages: ThreadMessage[] messages: ThreadMessage[],
tools?: MessageTool[]
) { ) {
this.msgId = ulid() this.msgId = ulid()
this.type = type this.type = type
@ -39,14 +42,20 @@ export class MessageRequestBuilder {
role: msg.role, role: msg.role,
content: msg.content[0]?.text?.value ?? '.', content: msg.content[0]?.text?.value ?? '.',
})) }))
this.tools = tools
} }
pushAssistantMessage(message: string) {
this.messages = [
...this.messages,
{
role: ChatCompletionRole.Assistant,
content: message,
},
]
}
// Chainable // Chainable
pushMessage( pushMessage(message: string, base64Blob?: string, fileInfo?: FileInfo) {
message: string,
base64Blob: string | undefined,
fileInfo?: FileInfo
) {
if (base64Blob && fileInfo?.type === 'pdf') if (base64Blob && fileInfo?.type === 'pdf')
return this.addDocMessage(message, fileInfo?.name) return this.addDocMessage(message, fileInfo?.name)
else if (base64Blob && fileInfo?.type === 'image') { else if (base64Blob && fileInfo?.type === 'image') {
@ -167,6 +176,7 @@ export class MessageRequestBuilder {
messages: this.normalizeMessages(this.messages), messages: this.normalizeMessages(this.messages),
model: this.model, model: this.model,
thread: this.thread, thread: this.thread,
tools: this.tools,
} }
} }
} }