From 45b845857059b80a5b2a970d589360dce27d288d Mon Sep 17 00:00:00 2001 From: David Date: Mon, 14 Apr 2025 00:20:33 +0700 Subject: [PATCH 01/12] fix: added border for search textfield --- web/containers/ModelSearch/index.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/containers/ModelSearch/index.tsx b/web/containers/ModelSearch/index.tsx index aa40f8331..ceecacd39 100644 --- a/web/containers/ModelSearch/index.tsx +++ b/web/containers/ModelSearch/index.tsx @@ -83,7 +83,7 @@ const ModelSearch = ({ onSearchLocal }: Props) => { value={searchText} clearable={searchText.length > 0} onClear={onClear} - className="border-0 bg-[hsla(var(--app-bg))]" + className="bg-[hsla(var(--app-bg))]" onClick={() => { onSearchLocal?.(inputRef.current?.value ?? '') }} From 57786e5e45ce1139814e92e3951ea07a8ba405ce Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 14 Apr 2025 15:23:31 +0700 Subject: [PATCH 02/12] Merge pull request #4900 from menloresearch/feat/jan-ui-with-tool-use feat: jan UI with Tool use UX --- core/src/types/inference/inferenceEntity.ts | 4 + core/src/types/message/messageEntity.ts | 2 + web/containers/Providers/ModelHandler.tsx | 106 +++--- web/helpers/atoms/ChatMessage.atom.ts | 3 + web/hooks/useSendChatMessage.ts | 323 ++++++++++++------ .../Hub/ModelFilter/ModelSize/index.tsx | 5 +- .../ThreadCenterPanel/ChatInput/index.tsx | 78 ++++- .../TextMessage/ToolCallBlock.tsx | 57 ++++ .../ThreadCenterPanel/TextMessage/index.tsx | 91 +++-- web/utils/messageRequestBuilder.ts | 48 +-- 10 files changed, 485 insertions(+), 232 deletions(-) create mode 100644 web/screens/Thread/ThreadCenterPanel/TextMessage/ToolCallBlock.tsx diff --git a/core/src/types/inference/inferenceEntity.ts b/core/src/types/inference/inferenceEntity.ts index c37e3b079..ac2e48d32 100644 --- a/core/src/types/inference/inferenceEntity.ts +++ b/core/src/types/inference/inferenceEntity.ts @@ -7,6 +7,7 @@ export enum ChatCompletionRole { System = 'system', Assistant = 'assistant', User = 'user', + Tool = 'tool', } /** @@ -18,6 +19,9 @@ export type ChatCompletionMessage = { content?: ChatCompletionMessageContent /** The role of the author of this message. **/ role: ChatCompletionRole + type?: string + output?: string + tool_call_id?: string } export type ChatCompletionMessageContent = diff --git a/core/src/types/message/messageEntity.ts b/core/src/types/message/messageEntity.ts index 280ce75a3..20979c68e 100644 --- a/core/src/types/message/messageEntity.ts +++ b/core/src/types/message/messageEntity.ts @@ -36,6 +36,8 @@ export type ThreadMessage = { type?: string /** The error code which explain what error type. Used in conjunction with MessageStatus.Error */ error_code?: ErrorCode + + tool_call_id?: string } /** diff --git a/web/containers/Providers/ModelHandler.tsx b/web/containers/Providers/ModelHandler.tsx index 786dbd4f0..a1fffa011 100644 --- a/web/containers/Providers/ModelHandler.tsx +++ b/web/containers/Providers/ModelHandler.tsx @@ -114,7 +114,7 @@ export default function ModelHandler() { const onNewMessageResponse = useCallback( async (message: ThreadMessage) => { - if (message.type === MessageRequestType.Thread) { + if (message.type !== MessageRequestType.Summary) { addNewMessage(message) } }, @@ -129,35 +129,20 @@ export default function ModelHandler() { const updateThreadTitle = useCallback( (message: ThreadMessage) => { // Update only when it's finished - if (message.status !== MessageStatus.Ready) { - return - } + if (message.status !== MessageStatus.Ready) return const thread = threadsRef.current?.find((e) => e.id == message.thread_id) - if (!thread) { - console.warn( - `Failed to update title for thread ${message.thread_id}: Thread not found!` - ) - return - } - let messageContent = message.content[0]?.text?.value - if (!messageContent) { - console.warn( - `Failed to update title for thread ${message.thread_id}: Responded content is null!` - ) - return - } + if (!thread || !messageContent) return // No new line character is presented in the title // And non-alphanumeric characters should be removed - if (messageContent.includes('\n')) { + if (messageContent.includes('\n')) messageContent = messageContent.replace(/\n/g, ' ') - } + const match = messageContent.match(/<\/think>(.*)$/) - if (match) { - messageContent = match[1] - } + if (match) messageContent = match[1] + // Remove non-alphanumeric characters const cleanedMessageContent = messageContent .replace(/[^\p{L}\s]+/gu, '') @@ -193,18 +178,13 @@ export default function ModelHandler() { const updateThreadMessage = useCallback( (message: ThreadMessage) => { - if ( - messageGenerationSubscriber.current && - message.thread_id === activeThreadRef.current?.id && - !messageGenerationSubscriber.current!.thread_id - ) { - updateMessage( - message.id, - message.thread_id, - message.content, - message.status - ) - } + updateMessage( + message.id, + message.thread_id, + message.content, + message.metadata, + message.status + ) if (message.status === MessageStatus.Pending) { if (message.content.length) { @@ -243,16 +223,19 @@ export default function ModelHandler() { engines && isLocalEngine(engines, activeModelRef.current.engine) ) { - ;(async () => { - if ( - !(await extensionManager - .get(ExtensionTypeEnum.Model) - ?.isModelLoaded(activeModelRef.current?.id as string)) - ) { - setActiveModel(undefined) - setStateModel({ state: 'start', loading: false, model: undefined }) - } - })() + extensionManager + .get(ExtensionTypeEnum.Model) + ?.isModelLoaded(activeModelRef.current?.id as string) + .then((isLoaded) => { + if (!isLoaded) { + setActiveModel(undefined) + setStateModel({ + state: 'start', + loading: false, + model: undefined, + }) + } + }) } // Mark the thread as not waiting for response updateThreadWaiting(message.thread_id, false) @@ -296,19 +279,10 @@ export default function ModelHandler() { error_code: message.error_code, } } - ;(async () => { - const updatedMessage = await extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.createMessage(message) - .catch(() => undefined) - if (updatedMessage) { - deleteMessage(message.id) - addNewMessage(updatedMessage) - setTokenSpeed((prev) => - prev ? { ...prev, message: updatedMessage.id } : undefined - ) - } - })() + + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.createMessage(message) // Attempt to generate the title of the Thread when needed generateThreadTitle(message, thread) @@ -319,25 +293,21 @@ export default function ModelHandler() { const onMessageResponseUpdate = useCallback( (message: ThreadMessage) => { - switch (message.type) { - case MessageRequestType.Summary: - updateThreadTitle(message) - break - default: - updateThreadMessage(message) - break - } + if (message.type === MessageRequestType.Summary) + updateThreadTitle(message) + else updateThreadMessage(message) }, [updateThreadMessage, updateThreadTitle] ) const generateThreadTitle = (message: ThreadMessage, thread: Thread) => { // If this is the first ever prompt in the thread - if ((thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle) + if ( + !activeModelRef.current || + (thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle + ) return - if (!activeModelRef.current) return - // Check model engine; we don't want to generate a title when it's not a local engine. remote model using first promp if ( activeModelRef.current?.engine !== InferenceEngine.cortex && diff --git a/web/helpers/atoms/ChatMessage.atom.ts b/web/helpers/atoms/ChatMessage.atom.ts index 1847aa422..faae6e298 100644 --- a/web/helpers/atoms/ChatMessage.atom.ts +++ b/web/helpers/atoms/ChatMessage.atom.ts @@ -165,6 +165,7 @@ export const updateMessageAtom = atom( id: string, conversationId: string, text: ThreadContent[], + metadata: Record | undefined, status: MessageStatus ) => { const messages = get(chatMessages)[conversationId] ?? [] @@ -172,6 +173,7 @@ export const updateMessageAtom = atom( if (message) { message.content = text message.status = status + message.metadata = metadata const updatedMessages = [...messages] const newData: Record = { @@ -192,6 +194,7 @@ export const updateMessageAtom = atom( created_at: Date.now() / 1000, completed_at: Date.now() / 1000, object: 'thread.message', + metadata: metadata, }) } } diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index 3242b085c..a76e01325 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -26,6 +26,7 @@ import { ChatCompletionTool, } from 'openai/resources/chat' +import { Stream } from 'openai/streaming' import { ulid } from 'ulidx' import { modelDropdownStateAtom } from '@/containers/ModelDropdown' @@ -258,111 +259,63 @@ export default function useSendChatMessage() { baseURL: `${API_BASE_URL}/v1`, dangerouslyAllowBrowser: true, }) + let parentMessageId: string | undefined while (!isDone) { + let messageId = ulid() + if (!parentMessageId) { + parentMessageId = ulid() + messageId = parentMessageId + } const data = requestBuilder.build() + const message: ThreadMessage = { + id: messageId, + object: 'message', + thread_id: activeThreadRef.current.id, + assistant_id: activeAssistantRef.current.assistant_id, + role: ChatCompletionRole.Assistant, + content: [], + metadata: { + ...(messageId !== parentMessageId + ? { parent_id: parentMessageId } + : {}), + }, + status: MessageStatus.Pending, + created_at: Date.now() / 1000, + completed_at: Date.now() / 1000, + } + events.emit(MessageEvent.OnMessageResponse, message) const response = await openai.chat.completions.create({ - messages: (data.messages ?? []).map((e) => { - return { - role: e.role as OpenAIChatCompletionRole, - content: e.content, - } - }) as ChatCompletionMessageParam[], + messages: requestBuilder.messages as ChatCompletionMessageParam[], model: data.model?.id ?? '', tools: data.tools as ChatCompletionTool[], - stream: false, + stream: data.model?.parameters?.stream ?? false, + tool_choice: 'auto', }) - if (response.choices[0]?.message.content) { - const newMessage: ThreadMessage = { - id: ulid(), - object: 'message', - thread_id: activeThreadRef.current.id, - assistant_id: activeAssistantRef.current.assistant_id, - attachments: [], - role: response.choices[0].message.role as ChatCompletionRole, - content: [ - { - type: ContentType.Text, - text: { - value: response.choices[0].message.content - ? response.choices[0].message.content - : '', - annotations: [], - }, + // Variables to track and accumulate streaming content + if (!message.content.length) { + message.content = [ + { + type: ContentType.Text, + text: { + value: '', + annotations: [], }, - ], - status: MessageStatus.Ready, - created_at: Date.now(), - completed_at: Date.now(), - } - requestBuilder.pushAssistantMessage( - response.choices[0].message.content ?? '' + }, + ] + } + if (data.model?.parameters?.stream) + isDone = await processStreamingResponse( + response as Stream, + requestBuilder, + message + ) + else { + isDone = await processNonStreamingResponse( + response as OpenAI.Chat.Completions.ChatCompletion, + requestBuilder, + message ) - events.emit(MessageEvent.OnMessageUpdate, newMessage) } - - if (response.choices[0]?.message.tool_calls) { - for (const toolCall of response.choices[0].message.tool_calls) { - const id = ulid() - const toolMessage: ThreadMessage = { - id: id, - object: 'message', - thread_id: activeThreadRef.current.id, - assistant_id: activeAssistantRef.current.assistant_id, - attachments: [], - role: ChatCompletionRole.Assistant, - content: [ - { - type: ContentType.Text, - text: { - value: `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`, - annotations: [], - }, - }, - ], - status: MessageStatus.Pending, - created_at: Date.now(), - completed_at: Date.now(), - } - events.emit(MessageEvent.OnMessageUpdate, toolMessage) - const result = await window.core.api.callTool({ - toolName: toolCall.function.name, - arguments: JSON.parse(toolCall.function.arguments), - }) - if (result.error) { - console.error(result.error) - break - } - const message: ThreadMessage = { - id: id, - object: 'message', - thread_id: activeThreadRef.current.id, - assistant_id: activeAssistantRef.current.assistant_id, - attachments: [], - role: ChatCompletionRole.Assistant, - content: [ - { - type: ContentType.Text, - text: { - value: - `Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}` + - (result.content[0]?.text ?? ''), - annotations: [], - }, - }, - ], - status: MessageStatus.Ready, - created_at: Date.now(), - completed_at: Date.now(), - } - requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '') - requestBuilder.pushMessage('Go for the next step') - events.emit(MessageEvent.OnMessageUpdate, message) - } - } - - isDone = - !response.choices[0]?.message.tool_calls || - !response.choices[0]?.message.tool_calls.length } } else { // Request for inference @@ -376,6 +329,184 @@ export default function useSendChatMessage() { setEngineParamsUpdate(false) } + const processNonStreamingResponse = async ( + response: OpenAI.Chat.Completions.ChatCompletion, + requestBuilder: MessageRequestBuilder, + message: ThreadMessage + ): Promise => { + // Handle tool calls in the response + const toolCalls: ChatCompletionMessageToolCall[] = + response.choices[0]?.message?.tool_calls ?? [] + const content = response.choices[0].message?.content + message.content = [ + { + type: ContentType.Text, + text: { + value: content ?? '', + annotations: [], + }, + }, + ] + events.emit(MessageEvent.OnMessageUpdate, message) + await postMessageProcessing( + toolCalls ?? [], + requestBuilder, + message, + content ?? '' + ) + return !toolCalls || !toolCalls.length + } + + const processStreamingResponse = async ( + response: Stream, + requestBuilder: MessageRequestBuilder, + message: ThreadMessage + ): Promise => { + // Variables to track and accumulate streaming content + let currentToolCall: { + id: string + function: { name: string; arguments: string } + } | null = null + let accumulatedContent = '' + const toolCalls: ChatCompletionMessageToolCall[] = [] + // Process the streaming chunks + for await (const chunk of response) { + // Handle tool calls in the chunk + if (chunk.choices[0]?.delta?.tool_calls) { + const deltaToolCalls = chunk.choices[0].delta.tool_calls + + // Handle the beginning of a new tool call + if ( + deltaToolCalls[0]?.index !== undefined && + deltaToolCalls[0]?.function + ) { + const index = deltaToolCalls[0].index + + // Create new tool call if this is the first chunk for it + if (!toolCalls[index]) { + toolCalls[index] = { + id: deltaToolCalls[0]?.id || '', + function: { + name: deltaToolCalls[0]?.function?.name || '', + arguments: deltaToolCalls[0]?.function?.arguments || '', + }, + type: 'function', + } + currentToolCall = toolCalls[index] + } else { + // Continuation of existing tool call + currentToolCall = toolCalls[index] + + // Append to function name or arguments if they exist in this chunk + if (deltaToolCalls[0]?.function?.name) { + currentToolCall!.function.name += deltaToolCalls[0].function.name + } + + if (deltaToolCalls[0]?.function?.arguments) { + currentToolCall!.function.arguments += + deltaToolCalls[0].function.arguments + } + } + } + } + + // Handle regular content in the chunk + if (chunk.choices[0]?.delta?.content) { + const content = chunk.choices[0].delta.content + accumulatedContent += content + + message.content = [ + { + type: ContentType.Text, + text: { + value: accumulatedContent, + annotations: [], + }, + }, + ] + events.emit(MessageEvent.OnMessageUpdate, message) + } + } + + await postMessageProcessing( + toolCalls ?? [], + requestBuilder, + message, + accumulatedContent ?? '' + ) + return !toolCalls || !toolCalls.length + } + + const postMessageProcessing = async ( + toolCalls: ChatCompletionMessageToolCall[], + requestBuilder: MessageRequestBuilder, + message: ThreadMessage, + content: string + ) => { + requestBuilder.pushAssistantMessage({ + content, + role: 'assistant', + refusal: null, + tool_calls: toolCalls, + }) + + // Handle completed tool calls + if (toolCalls.length > 0) { + for (const toolCall of toolCalls) { + const toolId = ulid() + const toolCallsMetadata = + message.metadata?.tool_calls && + Array.isArray(message.metadata?.tool_calls) + ? message.metadata?.tool_calls + : [] + message.metadata = { + ...(message.metadata ?? {}), + tool_calls: [ + ...toolCallsMetadata, + { + tool: { + ...toolCall, + id: toolId, + }, + response: undefined, + state: 'pending', + }, + ], + } + events.emit(MessageEvent.OnMessageUpdate, message) + + const result = await window.core.api.callTool({ + toolName: toolCall.function.name, + arguments: JSON.parse(toolCall.function.arguments), + }) + if (result.error) break + + message.metadata = { + ...(message.metadata ?? {}), + tool_calls: [ + ...toolCallsMetadata, + { + tool: { + ...toolCall, + id: toolId, + }, + response: result, + state: 'ready', + }, + ], + } + + requestBuilder.pushToolMessage( + result.content[0]?.text ?? '', + toolCall.id + ) + events.emit(MessageEvent.OnMessageUpdate, message) + } + } + message.status = MessageStatus.Ready + events.emit(MessageEvent.OnMessageUpdate, message) + } + return { sendChatMessage, resendChatMessage, diff --git a/web/screens/Hub/ModelFilter/ModelSize/index.tsx b/web/screens/Hub/ModelFilter/ModelSize/index.tsx index b95d57f8b..a8d411e33 100644 --- a/web/screens/Hub/ModelFilter/ModelSize/index.tsx +++ b/web/screens/Hub/ModelFilter/ModelSize/index.tsx @@ -1,9 +1,8 @@ -import { useRef, useState } from 'react' +import { useState } from 'react' -import { Slider, Input, Tooltip } from '@janhq/joi' +import { Slider, Input } from '@janhq/joi' import { atom, useAtom } from 'jotai' -import { InfoIcon } from 'lucide-react' export const hubModelSizeMinAtom = atom(0) export const hubModelSizeMaxAtom = atom(100) diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx index 23b137be8..5404fd85e 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx @@ -3,7 +3,15 @@ import { useEffect, useRef, useState } from 'react' import { InferenceEngine } from '@janhq/core' -import { TextArea, Button, Tooltip, useClickOutside, Badge } from '@janhq/joi' +import { + TextArea, + Button, + Tooltip, + useClickOutside, + Badge, + Modal, + ModalClose, +} from '@janhq/joi' import { useAtom, useAtomValue } from 'jotai' import { FileTextIcon, @@ -13,6 +21,7 @@ import { SettingsIcon, ChevronUpIcon, Settings2Icon, + WrenchIcon, } from 'lucide-react' import { twMerge } from 'tailwind-merge' @@ -45,6 +54,7 @@ import { isBlockingSendAtom, } from '@/helpers/atoms/Thread.atom' import { activeTabThreadRightPanelAtom } from '@/helpers/atoms/ThreadRightPanel.atom' +import { ModelTool } from '@/types/model' const ChatInput = () => { const activeThread = useAtomValue(activeThreadAtom) @@ -69,6 +79,8 @@ const ChatInput = () => { const isBlockingSend = useAtomValue(isBlockingSendAtom) const activeAssistant = useAtomValue(activeAssistantAtom) const { stopInference } = useActiveModel() + const [tools, setTools] = useState([]) + const [showToolsModal, setShowToolsModal] = useState(false) const upload = uploader() const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom( @@ -92,6 +104,12 @@ const ChatInput = () => { } }, [activeSettingInputBox, selectedModel, setActiveSettingInputBox]) + useEffect(() => { + window.core?.api?.getTools().then((data: ModelTool[]) => { + setTools(data) + }) + }, []) + const onStopInferenceClick = async () => { stopInference() } @@ -136,6 +154,8 @@ const ChatInput = () => { } } + console.log(tools) + return (
{renderPreview(fileUpload)} @@ -385,6 +405,62 @@ const ChatInput = () => { className="flex-shrink-0 cursor-pointer text-[hsla(var(--text-secondary))]" /> + {tools && tools.length > 0 && ( + <> + setShowToolsModal(true)} + > + + {tools.length} + + + +
+ Jan can use tools provided by specialized servers using + Model Context Protocol.{' '} + + Learn more about MCP + +
+ {tools.map((tool: any) => ( +
+ +
+
{tool.name}
+
+ {tool.description} +
+
+
+ ))} +
+ } + /> + + )} {selectedModel && ( + + + {isExpanded && ( +
+ {result.trim()} +
+ )} + + + ) +} + +export default ToolCallBlock diff --git a/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx b/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx index 7103f7914..522b27a0d 100644 --- a/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx @@ -18,6 +18,8 @@ import ImageMessage from './ImageMessage' import { MarkdownTextMessage } from './MarkdownTextMessage' import ThinkingBlock from './ThinkingBlock' +import ToolCallBlock from './ToolCallBlock' + import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { editMessageAtom, @@ -65,57 +67,62 @@ const MessageContainer: React.FC< [props.content] ) - const attachedFile = useMemo(() => 'attachments' in props, [props]) + const attachedFile = useMemo( + () => 'attachments' in props && props.attachments?.length, + [props] + ) return (
-
- {!isUser && !isSystem && } - + {!(props.metadata && 'parent_id' in props.metadata) && (
- {!isUser && ( - <> - {props.metadata && 'model' in props.metadata - ? (props.metadata?.model as string) - : props.isCurrentMessage - ? selectedModel?.name - : (activeAssistant?.assistant_name ?? props.role)} - - )} + {!isUser && !isSystem && } + +
+ {!isUser && ( + <> + {props.metadata && 'model' in props.metadata + ? (props.metadata?.model as string) + : props.isCurrentMessage + ? selectedModel?.name + : (activeAssistant?.assistant_name ?? props.role)} + + )} +
+ +

+ {props.created_at && + displayDate(props.created_at ?? Date.now() / 1000)} +

+ )} -

- {props.created_at && - displayDate(props.created_at ?? Date.now() / 1000)} -

-
- -
+
@@ -179,6 +186,22 @@ const MessageContainer: React.FC< />
)} + {props.metadata && + 'tool_calls' in props.metadata && + Array.isArray(props.metadata.tool_calls) && + props.metadata.tool_calls.length && ( + <> + {props.metadata.tool_calls.map((toolCall) => ( + + ))} + + )}
diff --git a/web/utils/messageRequestBuilder.ts b/web/utils/messageRequestBuilder.ts index 8733aff2c..fe827da62 100644 --- a/web/utils/messageRequestBuilder.ts +++ b/web/utils/messageRequestBuilder.ts @@ -11,6 +11,7 @@ import { Thread, ThreadMessage, } from '@janhq/core' +import { ChatCompletionMessage as OAIChatCompletionMessage } from 'openai/resources/chat' import { ulid } from 'ulidx' import { Stack } from '@/utils/Stack' @@ -45,12 +46,26 @@ export class MessageRequestBuilder { this.tools = tools } - pushAssistantMessage(message: string) { + pushAssistantMessage(message: OAIChatCompletionMessage) { + const { content, refusal, ...rest } = message + const normalizedMessage = { + ...rest, + ...(content ? { content } : {}), + ...(refusal ? { refusal } : {}), + } + this.messages = [ + ...this.messages, + normalizedMessage as ChatCompletionMessage, + ] + } + + pushToolMessage(message: string, toolCallId: string) { this.messages = [ ...this.messages, { - role: ChatCompletionRole.Assistant, + role: ChatCompletionRole.Tool, content: message, + tool_call_id: toolCallId, }, ] } @@ -140,40 +155,13 @@ export class MessageRequestBuilder { return this } - normalizeMessages = ( - messages: ChatCompletionMessage[] - ): ChatCompletionMessage[] => { - const stack = new Stack() - for (const message of messages) { - if (stack.isEmpty()) { - stack.push(message) - continue - } - const topMessage = stack.peek() - - if (message.role === topMessage.role) { - // add an empty message - stack.push({ - role: - topMessage.role === ChatCompletionRole.User - ? ChatCompletionRole.Assistant - : ChatCompletionRole.User, - content: '.', // some model requires not empty message - }) - } - stack.push(message) - } - - return stack.reverseOutput() - } - build(): MessageRequest { return { id: this.msgId, type: this.type, attachments: [], threadId: this.thread.id, - messages: this.normalizeMessages(this.messages), + messages: this.messages, model: this.model, thread: this.thread, tools: this.tools, From 31f707397712b1876e87dbb80b2df1be13b3d509 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 14 Apr 2025 15:33:53 +0700 Subject: [PATCH 03/12] chore: missing import --- web/hooks/useSendChatMessage.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index a76e01325..bfe3a601f 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -24,6 +24,7 @@ import { ChatCompletionMessageParam, ChatCompletionRole as OpenAIChatCompletionRole, ChatCompletionTool, + ChatCompletionMessageToolCall, } from 'openai/resources/chat' import { Stream } from 'openai/streaming' From b252f716d746c9072abd67ba9888083a0e939ba5 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 15 Apr 2025 18:57:43 +0700 Subject: [PATCH 04/12] refactor: Jan manages threads for a better performance (#4912) * refactor: Jan manages threads for a better performance * test: add tests --- .../conversational-extension/package.json | 4 +- .../conversational-extension/src/index.ts | 91 +-- src-tauri/Cargo.toml | 2 + src-tauri/src/core/cmd.rs | 4 + src-tauri/src/core/fs.rs | 13 +- src-tauri/src/core/mod.rs | 2 + src-tauri/src/core/threads.rs | 598 ++++++++++++++++++ src-tauri/src/core/utils/mod.rs | 48 ++ src-tauri/src/lib.rs | 14 +- .../ThreadCenterPanel/TextMessage/index.tsx | 5 +- web/services/tauriService.ts | 11 + 11 files changed, 703 insertions(+), 89 deletions(-) create mode 100644 src-tauri/src/core/threads.rs create mode 100644 src-tauri/src/core/utils/mod.rs diff --git a/extensions/conversational-extension/package.json b/extensions/conversational-extension/package.json index a5224b99b..693adf6d6 100644 --- a/extensions/conversational-extension/package.json +++ b/extensions/conversational-extension/package.json @@ -23,9 +23,7 @@ "typescript": "^5.7.2" }, "dependencies": { - "@janhq/core": "../../core/package.tgz", - "ky": "^1.7.2", - "p-queue": "^8.0.1" + "@janhq/core": "../../core/package.tgz" }, "engines": { "node": ">=18.0.0" diff --git a/extensions/conversational-extension/src/index.ts b/extensions/conversational-extension/src/index.ts index e2e068939..720291d88 100644 --- a/extensions/conversational-extension/src/index.ts +++ b/extensions/conversational-extension/src/index.ts @@ -4,40 +4,12 @@ import { ThreadAssistantInfo, ThreadMessage, } from '@janhq/core' -import ky, { KyInstance } from 'ky' - -type ThreadList = { - data: Thread[] -} - -type MessageList = { - data: ThreadMessage[] -} /** * JSONConversationalExtension is a ConversationalExtension implementation that provides * functionality for managing threads. */ export default class CortexConversationalExtension extends ConversationalExtension { - api?: KyInstance - /** - * Get the API instance - * @returns - */ - async apiInstance(): Promise { - if (this.api) return this.api - const apiKey = (await window.core?.api.appToken()) - this.api = ky.extend({ - prefixUrl: API_URL, - headers: apiKey - ? { - Authorization: `Bearer ${apiKey}`, - } - : {}, - retry: 10, - }) - return this.api - } /** * Called when the extension is loaded. */ @@ -54,12 +26,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * Returns a Promise that resolves to an array of Conversation objects. */ async listThreads(): Promise { - return this.apiInstance().then((api) => - api - .get('v1/threads?limit=-1') - .json() - .then((e) => e.data) - ) as Promise + return window.core.api.listThreads() } /** @@ -67,9 +34,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @param thread The Thread object to save. */ async createThread(thread: Thread): Promise { - return this.apiInstance().then((api) => - api.post('v1/threads', { json: thread }).json() - ) as Promise + return window.core.api.createThread({ thread }) } /** @@ -77,10 +42,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @param thread The Thread object to save. */ async modifyThread(thread: Thread): Promise { - return this.apiInstance() - .then((api) => api.patch(`v1/threads/${thread.id}`, { json: thread })) - - .then() + return window.core.api.modifyThread({ thread }) } /** @@ -88,9 +50,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @param threadId The ID of the thread to delete. */ async deleteThread(threadId: string): Promise { - return this.apiInstance() - .then((api) => api.delete(`v1/threads/${threadId}`)) - .then() + return window.core.api.deleteThread({ threadId }) } /** @@ -99,13 +59,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @returns A Promise that resolves when the message has been added. */ async createMessage(message: ThreadMessage): Promise { - return this.apiInstance().then((api) => - api - .post(`v1/threads/${message.thread_id}/messages`, { - json: message, - }) - .json() - ) as Promise + return window.core.api.createMessage({ message }) } /** @@ -114,13 +68,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @returns */ async modifyMessage(message: ThreadMessage): Promise { - return this.apiInstance().then((api) => - api - .patch(`v1/threads/${message.thread_id}/messages/${message.id}`, { - json: message, - }) - .json() - ) as Promise + return window.core.api.modifyMessage({ message }) } /** @@ -130,9 +78,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @returns A Promise that resolves when the message has been successfully deleted. */ async deleteMessage(threadId: string, messageId: string): Promise { - return this.apiInstance() - .then((api) => api.delete(`v1/threads/${threadId}/messages/${messageId}`)) - .then() + return window.core.api.deleteMessage({ threadId, messageId }) } /** @@ -141,12 +87,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * @returns A Promise that resolves to an array of ThreadMessage objects. */ async listMessages(threadId: string): Promise { - return this.apiInstance().then((api) => - api - .get(`v1/threads/${threadId}/messages?order=asc&limit=-1`) - .json() - .then((e) => e.data) - ) as Promise + return window.core.api.listMessages({ threadId }) } /** @@ -156,9 +97,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * the details of the assistant associated with the specified thread. */ async getThreadAssistant(threadId: string): Promise { - return this.apiInstance().then((api) => - api.get(`v1/assistants/${threadId}?limit=-1`).json() - ) as Promise + return window.core.api.getThreadAssistant({ threadId }) } /** * Creates a new assistant for the specified thread. @@ -170,11 +109,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi threadId: string, assistant: ThreadAssistantInfo ): Promise { - return this.apiInstance().then((api) => - api - .post(`v1/assistants/${threadId}`, { json: assistant }) - .json() - ) as Promise + return window.core.api.createThreadAssistant(threadId, assistant) } /** @@ -187,10 +122,6 @@ export default class CortexConversationalExtension extends ConversationalExtensi threadId: string, assistant: ThreadAssistantInfo ): Promise { - return this.apiInstance().then((api) => - api - .patch(`v1/assistants/${threadId}`, { json: assistant }) - .json() - ) as Promise + return window.core.api.modifyThreadAssistant({ threadId, assistant }) } } diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 5ef4e7a4e..f8444dcba 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -40,6 +40,8 @@ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "mai "transport-child-process", "tower", ] } +uuid = { version = "1.7", features = ["v4"] } [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] tauri-plugin-updater = "2" +once_cell = "1.18" diff --git a/src-tauri/src/core/cmd.rs b/src-tauri/src/core/cmd.rs index 0fe706a1a..ca3d051af 100644 --- a/src-tauri/src/core/cmd.rs +++ b/src-tauri/src/core/cmd.rs @@ -93,6 +93,10 @@ pub fn update_app_configuration( #[tauri::command] pub fn get_jan_data_folder_path(app_handle: tauri::AppHandle) -> PathBuf { + if cfg!(test) { + return PathBuf::from("./data"); + } + let app_configurations = get_app_configurations(app_handle); PathBuf::from(app_configurations.data_folder) } diff --git a/src-tauri/src/core/fs.rs b/src-tauri/src/core/fs.rs index 9e77a812c..c0d7d423d 100644 --- a/src-tauri/src/core/fs.rs +++ b/src-tauri/src/core/fs.rs @@ -107,6 +107,7 @@ mod tests { use super::*; use std::fs::{self, File}; use std::io::Write; + use serde_json::to_string; use tauri::test::mock_app; #[test] @@ -154,9 +155,11 @@ mod tests { fn test_exists_sync() { let app = mock_app(); let path = "file://test_exists_sync_file"; - let file_path = get_jan_data_folder_path(app.handle().clone()).join(path); + let dir_path = get_jan_data_folder_path(app.handle().clone()); + fs::create_dir_all(&dir_path).unwrap(); + let file_path = dir_path.join("test_exists_sync_file"); File::create(&file_path).unwrap(); - let args = vec![path.to_string()]; + let args: Vec = vec![path.to_string()]; let result = exists_sync(app.handle().clone(), args).unwrap(); assert!(result); fs::remove_file(file_path).unwrap(); @@ -166,7 +169,9 @@ mod tests { fn test_read_file_sync() { let app = mock_app(); let path = "file://test_read_file_sync_file"; - let file_path = get_jan_data_folder_path(app.handle().clone()).join(path); + let dir_path = get_jan_data_folder_path(app.handle().clone()); + fs::create_dir_all(&dir_path).unwrap(); + let file_path = dir_path.join("test_read_file_sync_file"); let mut file = File::create(&file_path).unwrap(); file.write_all(b"test content").unwrap(); let args = vec![path.to_string()]; @@ -184,7 +189,7 @@ mod tests { File::create(dir_path.join("file1.txt")).unwrap(); File::create(dir_path.join("file2.txt")).unwrap(); - let args = vec![path.to_string()]; + let args = vec![dir_path.to_string_lossy().to_string()]; let result = readdir_sync(app.handle().clone(), args).unwrap(); assert_eq!(result.len(), 2); diff --git a/src-tauri/src/core/mod.rs b/src-tauri/src/core/mod.rs index e4f0ee6c4..8d4edde3c 100644 --- a/src-tauri/src/core/mod.rs +++ b/src-tauri/src/core/mod.rs @@ -4,3 +4,5 @@ pub mod mcp; pub mod server; pub mod setup; pub mod state; +pub mod threads; +pub mod utils; \ No newline at end of file diff --git a/src-tauri/src/core/threads.rs b/src-tauri/src/core/threads.rs new file mode 100644 index 000000000..e4b8f2e52 --- /dev/null +++ b/src-tauri/src/core/threads.rs @@ -0,0 +1,598 @@ +/*! + Thread and Message Persistence Module + + This module provides all logic for managing threads and their messages, including creation, modification, deletion, and listing. + Messages for each thread are persisted in a JSONL file (messages.jsonl) per thread directory. + + **Concurrency and Consistency Guarantee:** + - All operations that write or modify messages for a thread are protected by a global, per-thread asynchronous lock. + - This design ensures that only one operation can write to a thread's messages.jsonl file at a time, preventing race conditions. + - As a result, the messages.jsonl file for each thread is always consistent and never corrupted, even under concurrent access. +*/ + +use serde::{Deserialize, Serialize}; +use std::fs::{self, File}; +use std::io::{BufRead, BufReader, Write}; +use tauri::command; +use tauri::Runtime; +use uuid::Uuid; + +// For async file write serialization +use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; + +// Global per-thread locks for message file writes +static MESSAGE_LOCKS: Lazy>>>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +use super::utils::{ + ensure_data_dirs, ensure_thread_dir_exists, get_data_dir, get_messages_path, get_thread_dir, + get_thread_metadata_path, THREADS_FILE, +}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Thread { + pub id: String, + pub object: String, + pub title: String, + pub assistants: Vec, + pub created: i64, + pub updated: i64, + pub metadata: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ThreadMessage { + pub id: String, + pub object: String, + pub thread_id: String, + pub assistant_id: Option, + pub attachments: Option>, + pub role: String, + pub content: Vec, + pub status: String, + pub created_at: i64, + pub completed_at: i64, + pub metadata: Option, + pub type_: Option, + pub error_code: Option, + pub tool_call_id: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Attachment { + pub file_id: Option, + pub tools: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum Tool { + #[serde(rename = "file_search")] + FileSearch, + #[serde(rename = "code_interpreter")] + CodeInterpreter, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ThreadContent { + pub type_: String, + pub text: Option, + pub image_url: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ContentValue { + pub value: String, + pub annotations: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ImageContentValue { + pub detail: Option, + pub url: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ThreadAssistantInfo { + pub assistant_id: String, + pub assistant_name: String, + pub model: ModelInfo, + pub instructions: Option, + pub tools: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ModelInfo { + pub id: String, + pub name: String, + pub settings: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum AssistantTool { + #[serde(rename = "code_interpreter")] + CodeInterpreter, + #[serde(rename = "retrieval")] + Retrieval, + #[serde(rename = "function")] + Function { + name: String, + description: Option, + parameters: Option, + }, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ThreadState { + pub has_more: bool, + pub waiting_for_response: bool, + pub error: Option, + pub last_message: Option, +} + +/// Lists all threads by reading their metadata from the threads directory. +/// Returns a vector of thread metadata as JSON values. +#[command] +pub async fn list_threads( + app_handle: tauri::AppHandle, +) -> Result, String> { + ensure_data_dirs(app_handle.clone())?; + let data_dir = get_data_dir(app_handle.clone()); + let mut threads = Vec::new(); + + if !data_dir.exists() { + return Ok(threads); + } + + for entry in fs::read_dir(&data_dir).map_err(|e| e.to_string())? { + let entry = entry.map_err(|e| e.to_string())?; + let path = entry.path(); + if path.is_dir() { + let thread_metadata_path = path.join(THREADS_FILE); + if thread_metadata_path.exists() { + let data = fs::read_to_string(&thread_metadata_path).map_err(|e| e.to_string())?; + match serde_json::from_str(&data) { + Ok(thread) => threads.push(thread), + Err(e) => { + println!("Failed to parse thread file: {}", e); + continue; // skip invalid thread files + } + } + } + } + } + + Ok(threads) +} + +/// Creates a new thread, assigns it a unique ID, and persists its metadata. +/// Ensures the thread directory exists and writes thread.json. +#[command] +pub async fn create_thread( + app_handle: tauri::AppHandle, + mut thread: serde_json::Value, +) -> Result { + ensure_data_dirs(app_handle.clone())?; + let uuid = Uuid::new_v4().to_string(); + thread["id"] = serde_json::Value::String(uuid.clone()); + let thread_dir = get_thread_dir(app_handle.clone(), &uuid); + if !thread_dir.exists() { + fs::create_dir_all(&thread_dir).map_err(|e| e.to_string())?; + } + let path = get_thread_metadata_path(app_handle.clone(), &uuid); + let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?; + fs::write(path, data).map_err(|e| e.to_string())?; + Ok(thread) +} + +/// Modifies an existing thread's metadata by overwriting its thread.json file. +/// Returns an error if the thread directory does not exist. +#[command] +pub async fn modify_thread( + app_handle: tauri::AppHandle, + thread: serde_json::Value, +) -> Result<(), String> { + let thread_id = thread + .get("id") + .and_then(|id| id.as_str()) + .ok_or("Missing thread id")?; + let thread_dir = get_thread_dir(app_handle.clone(), thread_id); + if !thread_dir.exists() { + return Err("Thread directory does not exist".to_string()); + } + let path = get_thread_metadata_path(app_handle.clone(), thread_id); + let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?; + fs::write(path, data).map_err(|e| e.to_string())?; + Ok(()) +} + +/// Deletes a thread and all its associated files by removing its directory. +#[command] +pub async fn delete_thread( + app_handle: tauri::AppHandle, + thread_id: String, +) -> Result<(), String> { + let thread_dir = get_thread_dir(app_handle.clone(), &thread_id); + if thread_dir.exists() { + fs::remove_dir_all(thread_dir).map_err(|e| e.to_string())?; + } + Ok(()) +} + +/// Lists all messages for a given thread by reading and parsing its messages.jsonl file. +/// Returns a vector of message JSON values. +#[command] +pub async fn list_messages( + app_handle: tauri::AppHandle, + thread_id: String, +) -> Result, String> { + let path = get_messages_path(app_handle, &thread_id); + if !path.exists() { + return Ok(vec![]); + } + + let file = File::open(&path).map_err(|e| { + eprintln!("Error opening file {}: {}", path.display(), e); + e.to_string() + })?; + let reader = BufReader::new(file); + + let mut messages = Vec::new(); + for line in reader.lines() { + let line = line.map_err(|e| { + eprintln!("Error reading line from file {}: {}", path.display(), e); + e.to_string() + })?; + let message: serde_json::Value = serde_json::from_str(&line).map_err(|e| { + eprintln!( + "Error parsing JSON from line in file {}: {}", + path.display(), + e + ); + e.to_string() + })?; + messages.push(message); + } + + Ok(messages) +} + +/// Appends a new message to a thread's messages.jsonl file. +/// Uses a per-thread async lock to prevent race conditions and ensure file consistency. +#[command] +pub async fn create_message( + app_handle: tauri::AppHandle, + mut message: serde_json::Value, +) -> Result { + let thread_id = { + let id = message + .get("thread_id") + .and_then(|v| v.as_str()) + .ok_or("Missing thread_id")?; + id.to_string() + }; + ensure_thread_dir_exists(app_handle.clone(), &thread_id)?; + let path = get_messages_path(app_handle.clone(), &thread_id); + + let uuid = Uuid::new_v4().to_string(); + message["id"] = serde_json::Value::String(uuid); + + // Acquire per-thread lock before writing + { + let mut locks = MESSAGE_LOCKS.lock().await; + let lock = locks + .entry(thread_id.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone(); + drop(locks); // Release the map lock before awaiting the file lock + + let _guard = lock.lock().await; + + let mut file = fs::OpenOptions::new() + .create(true) + .append(true) + .open(path) + .map_err(|e| e.to_string())?; + + let data = serde_json::to_string(&message).map_err(|e| e.to_string())?; + writeln!(file, "{}", data).map_err(|e| e.to_string())?; + } + + Ok(message) +} + +/// Modifies an existing message in a thread's messages.jsonl file. +/// Uses a per-thread async lock to prevent race conditions and ensure file consistency. +/// Rewrites the entire messages.jsonl file for the thread. +#[command] +pub async fn modify_message( + app_handle: tauri::AppHandle, + message: serde_json::Value, +) -> Result { + let thread_id = message + .get("thread_id") + .and_then(|v| v.as_str()) + .ok_or("Missing thread_id")?; + let message_id = message + .get("id") + .and_then(|v| v.as_str()) + .ok_or("Missing message id")?; + + // Acquire per-thread lock before modifying + { + let mut locks = MESSAGE_LOCKS.lock().await; + let lock = locks + .entry(thread_id.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone(); + drop(locks); // Release the map lock before awaiting the file lock + + let _guard = lock.lock().await; + + let mut messages = list_messages(app_handle.clone(), thread_id.to_string()).await?; + if let Some(index) = messages + .iter() + .position(|m| m.get("id").and_then(|v| v.as_str()) == Some(message_id)) + { + messages[index] = message.clone(); + + // Rewrite all messages + let path = get_messages_path(app_handle.clone(), thread_id); + let mut file = File::create(path).map_err(|e| e.to_string())?; + for msg in messages { + let data = serde_json::to_string(&msg).map_err(|e| e.to_string())?; + writeln!(file, "{}", data).map_err(|e| e.to_string())?; + } + } + } + Ok(message) +} + +/// Deletes a message from a thread's messages.jsonl file by message ID. +/// Rewrites the entire messages.jsonl file for the thread. +#[command] +pub async fn delete_message( + app_handle: tauri::AppHandle, + thread_id: String, + message_id: String, +) -> Result<(), String> { + let mut messages = list_messages(app_handle.clone(), thread_id.clone()).await?; + messages.retain(|m| m.get("id").and_then(|v| v.as_str()) != Some(message_id.as_str())); + + // Rewrite remaining messages + let path = get_messages_path(app_handle.clone(), &thread_id); + let mut file = File::create(path).map_err(|e| e.to_string())?; + for msg in messages { + let data = serde_json::to_string(&msg).map_err(|e| e.to_string())?; + writeln!(file, "{}", data).map_err(|e| e.to_string())?; + } + + Ok(()) +} + +/// Retrieves the first assistant associated with a thread. +/// Returns an error if the thread or assistant is not found. +#[command] +pub async fn get_thread_assistant( + app_handle: tauri::AppHandle, + thread_id: String, +) -> Result { + let path = get_thread_metadata_path(app_handle, &thread_id); + if !path.exists() { + return Err("Thread not found".to_string()); + } + let data = fs::read_to_string(&path).map_err(|e| e.to_string())?; + let thread: serde_json::Value = serde_json::from_str(&data).map_err(|e| e.to_string())?; + if let Some(assistants) = thread.get("assistants").and_then(|a| a.as_array()) { + if let Some(first) = assistants.get(0) { + Ok(first.clone()) + } else { + Err("Assistant not found".to_string()) + } + } else { + Err("Assistant not found".to_string()) + } +} + +/// Adds a new assistant to a thread's metadata. +/// Updates thread.json with the new assistant information. +#[command] +pub async fn create_thread_assistant( + app_handle: tauri::AppHandle, + thread_id: String, + assistant: serde_json::Value, +) -> Result { + let path = get_thread_metadata_path(app_handle.clone(), &thread_id); + if !path.exists() { + return Err("Thread not found".to_string()); + } + let mut thread: serde_json::Value = { + let data = fs::read_to_string(&path).map_err(|e| e.to_string())?; + serde_json::from_str(&data).map_err(|e| e.to_string())? + }; + if let Some(assistants) = thread.get_mut("assistants").and_then(|a| a.as_array_mut()) { + assistants.push(assistant.clone()); + } else { + thread["assistants"] = serde_json::Value::Array(vec![assistant.clone()]); + } + let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?; + fs::write(&path, data).map_err(|e| e.to_string())?; + Ok(assistant) +} + +/// Modifies an existing assistant's information in a thread's metadata. +/// Updates thread.json with the modified assistant data. +#[command] +pub async fn modify_thread_assistant( + app_handle: tauri::AppHandle, + thread_id: String, + assistant: serde_json::Value, +) -> Result { + let path = get_thread_metadata_path(app_handle.clone(), &thread_id); + if !path.exists() { + return Err("Thread not found".to_string()); + } + let mut thread: serde_json::Value = { + let data = fs::read_to_string(&path).map_err(|e| e.to_string())?; + serde_json::from_str(&data).map_err(|e| e.to_string())? + }; + let assistant_id = assistant + .get("id") + .and_then(|v| v.as_str()) + .ok_or("Missing assistant_id")?; + if let Some(assistants) = thread + .get_mut("assistants") + .and_then(|a: &mut serde_json::Value| a.as_array_mut()) + { + if let Some(index) = assistants + .iter() + .position(|a| a.get("id").and_then(|v| v.as_str()) == Some(assistant_id)) + { + assistants[index] = assistant.clone(); + let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?; + fs::write(&path, data).map_err(|e| e.to_string())?; + } + } + Ok(assistant) +} + +#[cfg(test)] +mod tests { + use crate::core::cmd::get_jan_data_folder_path; + + use super::*; + use serde_json::json; + use std::fs; + use std::path::PathBuf; + use tauri::test::{mock_app, MockRuntime}; + + // Helper to create a mock app handle with a temp data dir + fn mock_app_with_temp_data_dir() -> (tauri::App, PathBuf) { + let app = mock_app(); + let data_dir = get_jan_data_folder_path(app.handle().clone()); + println!("Mock app data dir: {}", data_dir.display()); + // Patch get_data_dir to use temp dir (requires get_data_dir to be overridable or injectable) + // For now, we assume get_data_dir uses tauri::api::path::app_data_dir(&app_handle) + // and that we can set the environment variable to redirect it. + (app, data_dir) + } + + #[tokio::test] + async fn test_create_and_list_threads() { + let (app, data_dir) = mock_app_with_temp_data_dir(); + // Create a thread + let thread = json!({ + "object": "thread", + "title": "Test Thread", + "assistants": [], + "created": 1234567890, + "updated": 1234567890, + "metadata": null + }); + let created = create_thread(app.handle().clone(), thread.clone()) + .await + .unwrap(); + assert_eq!(created["title"], "Test Thread"); + + // List threads + let threads = list_threads(app.handle().clone()).await.unwrap(); + assert!(threads.len() > 0); + + // Clean up + fs::remove_dir_all(data_dir).unwrap(); + } + + #[tokio::test] + async fn test_create_and_list_messages() { + let (app, data_dir) = mock_app_with_temp_data_dir(); + // Create a thread first + let thread = json!({ + "object": "thread", + "title": "Msg Thread", + "assistants": [], + "created": 123, + "updated": 123, + "metadata": null + }); + let created = create_thread(app.handle().clone(), thread.clone()) + .await + .unwrap(); + let thread_id = created["id"].as_str().unwrap().to_string(); + + // Create a message + let message = json!({ + "object": "message", + "thread_id": thread_id, + "assistant_id": null, + "attachments": null, + "role": "user", + "content": [], + "status": "sent", + "created_at": 123, + "completed_at": 123, + "metadata": null, + "type_": null, + "error_code": null, + "tool_call_id": null + }); + let created_msg = create_message(app.handle().clone(), message).await.unwrap(); + assert_eq!(created_msg["role"], "user"); + + // List messages + let messages = list_messages(app.handle().clone(), thread_id.clone()) + .await + .unwrap(); + assert!(messages.len() > 0); + assert_eq!(messages[0]["role"], "user"); + + // Clean up + fs::remove_dir_all(data_dir).unwrap(); + } + + #[tokio::test] + async fn test_create_and_get_thread_assistant() { + let (app, data_dir) = mock_app_with_temp_data_dir(); + // Create a thread + let thread = json!({ + "object": "thread", + "title": "Assistant Thread", + "assistants": [], + "created": 1, + "updated": 1, + "metadata": null + }); + let created = create_thread(app.handle().clone(), thread.clone()) + .await + .unwrap(); + let thread_id = created["id"].as_str().unwrap().to_string(); + + // Add assistant + let assistant = json!({ + "id": "assistant-1", + "assistant_name": "Test Assistant", + "model": { + "id": "model-1", + "name": "Test Model", + "settings": json!({}) + }, + "instructions": null, + "tools": null + }); + let _ = create_thread_assistant(app.handle().clone(), thread_id.clone(), assistant.clone()) + .await + .unwrap(); + + // Get assistant + let got = get_thread_assistant(app.handle().clone(), thread_id.clone()) + .await + .unwrap(); + assert_eq!(got["assistant_name"], "Test Assistant"); + + // Clean up + fs::remove_dir_all(data_dir).unwrap(); + } +} diff --git a/src-tauri/src/core/utils/mod.rs b/src-tauri/src/core/utils/mod.rs new file mode 100644 index 000000000..7f80e6f3a --- /dev/null +++ b/src-tauri/src/core/utils/mod.rs @@ -0,0 +1,48 @@ +use std::fs; +use std::path::PathBuf; +use tauri::Runtime; + +use super::cmd::get_jan_data_folder_path; + +pub const THREADS_DIR: &str = "threads"; +pub const THREADS_FILE: &str = "thread.json"; +pub const MESSAGES_FILE: &str = "messages.jsonl"; + +pub fn get_data_dir(app_handle: tauri::AppHandle) -> PathBuf { + get_jan_data_folder_path(app_handle).join(THREADS_DIR) +} + +pub fn get_thread_dir(app_handle: tauri::AppHandle, thread_id: &str) -> PathBuf { + get_data_dir(app_handle).join(thread_id) +} + +pub fn get_thread_metadata_path( + app_handle: tauri::AppHandle, + thread_id: &str, +) -> PathBuf { + get_thread_dir(app_handle, thread_id).join(THREADS_FILE) +} + +pub fn get_messages_path(app_handle: tauri::AppHandle, thread_id: &str) -> PathBuf { + get_thread_dir(app_handle, thread_id).join(MESSAGES_FILE) +} + +pub fn ensure_data_dirs(app_handle: tauri::AppHandle) -> Result<(), String> { + let data_dir = get_data_dir(app_handle.clone()); + if !data_dir.exists() { + fs::create_dir_all(&data_dir).map_err(|e| e.to_string())?; + } + Ok(()) +} + +pub fn ensure_thread_dir_exists( + app_handle: tauri::AppHandle, + thread_id: &str, +) -> Result<(), String> { + ensure_data_dirs(app_handle.clone())?; + let thread_dir = get_thread_dir(app_handle, thread_id); + if !thread_dir.exists() { + fs::create_dir(&thread_dir).map_err(|e| e.to_string())?; + } + Ok(()) +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 40cd83f57..f37d97ae2 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -41,7 +41,19 @@ pub fn run() { core::cmd::stop_server, // MCP commands core::cmd::get_tools, - core::cmd::call_tool + core::cmd::call_tool, + // Threads + core::threads::list_threads, + core::threads::create_thread, + core::threads::modify_thread, + core::threads::delete_thread, + core::threads::list_messages, + core::threads::create_message, + core::threads::modify_message, + core::threads::delete_message, + core::threads::get_thread_assistant, + core::threads::create_thread_assistant, + core::threads::modify_thread_assistant ]) .manage(AppState { app_token: Some(generate_app_token()), diff --git a/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx b/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx index 522b27a0d..3af675de5 100644 --- a/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx @@ -68,7 +68,10 @@ const MessageContainer: React.FC< ) const attachedFile = useMemo( - () => 'attachments' in props && props.attachments?.length, + () => + 'attachments' in props && + !!props.attachments?.length && + props.attachments?.length > 0, [props] ) diff --git a/web/services/tauriService.ts b/web/services/tauriService.ts index 488593658..2c592f951 100644 --- a/web/services/tauriService.ts +++ b/web/services/tauriService.ts @@ -8,6 +8,17 @@ export const Routes = [ 'installExtensions', 'getTools', 'callTool', + 'listThreads', + 'createThread', + 'modifyThread', + 'deleteThread', + 'listMessages', + 'createMessage', + 'modifyMessage', + 'deleteMessage', + 'getThreadAssistant', + 'createThreadAssistant', + 'modifyThreadAssistant', ].map((r) => ({ path: `app`, route: r, From 40d63853ecf18ffafd83c1025b1d563c72bbcc12 Mon Sep 17 00:00:00 2001 From: Faisal Amir Date: Mon, 14 Apr 2025 15:26:13 +0700 Subject: [PATCH 05/12] chore: initial commit mcp setting --- web/screens/Settings/MCP/index.tsx | 13 +++++++++++++ web/screens/Settings/SettingDetail/index.tsx | 4 ++++ web/screens/Settings/index.tsx | 1 + 3 files changed, 18 insertions(+) create mode 100644 web/screens/Settings/MCP/index.tsx diff --git a/web/screens/Settings/MCP/index.tsx b/web/screens/Settings/MCP/index.tsx new file mode 100644 index 000000000..905416cea --- /dev/null +++ b/web/screens/Settings/MCP/index.tsx @@ -0,0 +1,13 @@ +import React from 'react' + +const MCP = () => { + return ( +

+ Lorem ipsum dolor sit amet consectetur adipisicing elit. Qui, cumque + deleniti dolorem ducimus nisi rerum cum et sunt maxime, dicta sequi + assumenda sit illum? Minima beatae repudiandae praesentium sed incidunt! +

+ ) +} + +export default MCP diff --git a/web/screens/Settings/SettingDetail/index.tsx b/web/screens/Settings/SettingDetail/index.tsx index 8ceb600e6..43bbd6c1c 100644 --- a/web/screens/Settings/SettingDetail/index.tsx +++ b/web/screens/Settings/SettingDetail/index.tsx @@ -15,6 +15,7 @@ import RemoteEngineSettings from '@/screens/Settings/Engines/RemoteEngineSetting import ExtensionSetting from '@/screens/Settings/ExtensionSetting' import Hardware from '@/screens/Settings/Hardware' import Hotkeys from '@/screens/Settings/Hotkeys' +import MCP from '@/screens/Settings/MCP' import MyModels from '@/screens/Settings/MyModels' import Privacy from '@/screens/Settings/Privacy' @@ -31,6 +32,9 @@ const SettingDetail = () => { case 'Engines': return + case 'MCP Servers': + return + case 'Extensions': return diff --git a/web/screens/Settings/index.tsx b/web/screens/Settings/index.tsx index d126f0d0e..8db1dd0b3 100644 --- a/web/screens/Settings/index.tsx +++ b/web/screens/Settings/index.tsx @@ -19,6 +19,7 @@ export const SettingScreenList = [ 'Privacy', 'Advanced Settings', 'Engines', + 'MCP Servers', 'Extensions', ] as const From ef1a85b58c9d756c8515c01fc296483daef1b372 Mon Sep 17 00:00:00 2001 From: Faisal Amir Date: Mon, 14 Apr 2025 21:32:04 +0700 Subject: [PATCH 06/12] chore: setting mcp --- web/screens/Settings/MCP/index.tsx | 126 +++++++++++++++++++++++++++-- 1 file changed, 120 insertions(+), 6 deletions(-) diff --git a/web/screens/Settings/MCP/index.tsx b/web/screens/Settings/MCP/index.tsx index 905416cea..28f26d5f0 100644 --- a/web/screens/Settings/MCP/index.tsx +++ b/web/screens/Settings/MCP/index.tsx @@ -1,12 +1,126 @@ -import React from 'react' +import React, { useState, useEffect, useCallback } from 'react' + +import { fs, joinPath } from '@janhq/core' +import { Button } from '@janhq/joi' +import { useAtomValue } from 'jotai' + +import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' const MCP = () => { + const janDataFolderPath = useAtomValue(janDataFolderPathAtom) + const [configContent, setConfigContent] = useState('') + const [isSaving, setIsSaving] = useState(false) + const [error, setError] = useState('') + const [success, setSuccess] = useState('') + + console.log(janDataFolderPath, 'janDataFolderPath') + + const readConfigFile = useCallback(async () => { + try { + const configPath = await joinPath([janDataFolderPath, 'mcp_config.json']) + + // Check if the file exists + const fileExists = await fs.existsSync(configPath) + + if (fileExists) { + // Read the file + const content = await fs.readFileSync(configPath, 'utf-8') + setConfigContent(content) + } else { + // Create a default config if it doesn't exist + const defaultConfig = JSON.stringify( + { + servers: [], + settings: { + enabled: true, + }, + }, + null, + 2 + ) + + await fs.writeFileSync(configPath, defaultConfig) + setConfigContent(defaultConfig) + } + + setError('') + } catch (err) { + console.error('Error reading config file:', err) + setError('Failed to read config file') + } + }, [janDataFolderPath]) + + useEffect(() => { + if (janDataFolderPath) { + readConfigFile() + } + }, [janDataFolderPath, readConfigFile]) + + const saveConfigFile = useCallback(async () => { + try { + setIsSaving(true) + setSuccess('') + setError('') + + // Validate JSON + try { + JSON.parse(configContent) + } catch (err) { + setError('Invalid JSON format') + setIsSaving(false) + return + } + + const configPath = await joinPath([janDataFolderPath, 'mcp_config.json']) + + // Write to the file + await fs.writeFileSync(configPath, configContent) + + setSuccess('Config saved successfully') + setIsSaving(false) + } catch (err) { + console.error('Error saving config file:', err) + setError('Failed to save config file') + setIsSaving(false) + } + }, [janDataFolderPath, configContent]) + return ( -

- Lorem ipsum dolor sit amet consectetur adipisicing elit. Qui, cumque - deleniti dolorem ducimus nisi rerum cum et sunt maxime, dicta sequi - assumenda sit illum? Minima beatae repudiandae praesentium sed incidunt! -

+
+

MCP Configuration

+ + {error && ( +
+ {error} +
+ )} + + {success && ( +
+ {success} +
+ )} + +
+ +