diff --git a/extensions/llamacpp-extension/src/type.d.ts b/extensions/llamacpp-extension/src/type.d.ts new file mode 100644 index 000000000..88fc84a17 --- /dev/null +++ b/extensions/llamacpp-extension/src/type.d.ts @@ -0,0 +1,12 @@ +export {} + +declare global { + interface RequestInit { + /** + * Tauri HTTP plugin option for connection timeout in milliseconds. + */ + connectTimeout?: number + } +} + + diff --git a/src-tauri/plugins/tauri-plugin-hardware/src/vendor/amd.rs b/src-tauri/plugins/tauri-plugin-hardware/src/vendor/amd.rs index 62d90ca1b..7521fd2b0 100644 --- a/src-tauri/plugins/tauri-plugin-hardware/src/vendor/amd.rs +++ b/src-tauri/plugins/tauri-plugin-hardware/src/vendor/amd.rs @@ -126,13 +126,13 @@ mod windows_impl { pub iOSDisplayIndex: c_int, } - type ADL_MAIN_MALLOC_CALLBACK = Option *mut c_void>; - type ADL_MAIN_CONTROL_CREATE = unsafe extern "C" fn(ADL_MAIN_MALLOC_CALLBACK, c_int) -> c_int; - type ADL_MAIN_CONTROL_DESTROY = unsafe extern "C" fn() -> c_int; - type ADL_ADAPTER_NUMBEROFADAPTERS_GET = unsafe extern "C" fn(*mut c_int) -> c_int; - type ADL_ADAPTER_ADAPTERINFO_GET = unsafe extern "C" fn(*mut AdapterInfo, c_int) -> c_int; - type ADL_ADAPTER_ACTIVE_GET = unsafe extern "C" fn(c_int, *mut c_int) -> c_int; - type ADL_GET_DEDICATED_VRAM_USAGE = + type AdlMainMallocCallback = Option *mut c_void>; + type ADLMAINCONTROLCREATE = unsafe extern "C" fn(AdlMainMallocCallback, c_int) -> c_int; + type ADLMAINCONTROLDESTROY = unsafe extern "C" fn() -> c_int; + type AdlAdapterNumberofadaptersGet = unsafe extern "C" fn(*mut c_int) -> c_int; + type AdlAdapterAdapterinfoGet = unsafe extern "C" fn(*mut AdapterInfo, c_int) -> c_int; + type AdlAdapterActiveGet = unsafe extern "C" fn(c_int, *mut c_int) -> c_int; + type AdlGetDedicatedVramUsage = unsafe extern "C" fn(*mut c_void, c_int, *mut c_int) -> c_int; // === ADL Memory Allocator === @@ -144,24 +144,24 @@ mod windows_impl { unsafe { let lib = Library::new("atiadlxx.dll").or_else(|_| 
Library::new("atiadlxy.dll"))?; - let adl_main_control_create: Symbol = - lib.get(b"ADL_Main_Control_Create")?; - let adl_main_control_destroy: Symbol = - lib.get(b"ADL_Main_Control_Destroy")?; - let adl_adapter_number_of_adapters_get: Symbol = - lib.get(b"ADL_Adapter_NumberOfAdapters_Get")?; - let adl_adapter_adapter_info_get: Symbol = - lib.get(b"ADL_Adapter_AdapterInfo_Get")?; - let adl_adapter_active_get: Symbol = - lib.get(b"ADL_Adapter_Active_Get")?; - let adl_get_dedicated_vram_usage: Symbol = + let adlmaincontrolcreate: Symbol = + lib.get(b"ADL_Main_Control_Create")?; + let adlmaincontroldestroy: Symbol = + lib.get(b"ADL_Main_Control_Destroy")?; + let adl_adapter_number_of_adapters_get: Symbol = + lib.get(b"ADL_Adapter_NumberOfAdapters_Get")?; + let adl_adapter_adapter_info_get: Symbol = + lib.get(b"ADL_Adapter_AdapterInfo_Get")?; + let AdlAdapterActiveGet: Symbol = + lib.get(b"ADL_Adapter_Active_Get")?; + let AdlGetDedicatedVramUsage: Symbol = lib.get(b"ADL2_Adapter_DedicatedVRAMUsage_Get")?; // TODO: try to put nullptr here. 
then we don't need direct libc dep - if adl_main_control_create(Some(adl_malloc), 1) != 0 { + if adlmaincontrolcreate(Some(adl_malloc), 1) != 0 { return Err("ADL initialization error!".into()); } - // NOTE: after this call, we must call ADL_Main_Control_Destroy + // NOTE: after this call, we must call AdlMainControlDestroy // whenver we encounter an error let mut num_adapters: c_int = 0; @@ -184,11 +184,11 @@ mod windows_impl { for adapter in adapter_info.iter() { let mut is_active = 0; - adl_adapter_active_get(adapter.iAdapterIndex, &mut is_active); + AdlAdapterActiveGet(adapter.iAdapterIndex, &mut is_active); if is_active != 0 { let mut vram_mb = 0; - let _ = adl_get_dedicated_vram_usage( + let _ = AdlGetDedicatedVramUsage( ptr::null_mut(), adapter.iAdapterIndex, &mut vram_mb, @@ -202,7 +202,7 @@ mod windows_impl { } } - adl_main_control_destroy(); + adlmaincontroldestroy(); Ok(vram_usages) } diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/process.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/process.rs index 3de983c51..06d83fcb0 100644 --- a/src-tauri/plugins/tauri-plugin-llamacpp/src/process.rs +++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/process.rs @@ -1,8 +1,6 @@ use std::collections::HashSet; -use std::time::Duration; use sysinfo::{Pid, System}; use tauri::{Manager, Runtime, State}; -use tokio::time::timeout; use crate::state::{LlamacppState, SessionInfo}; use jan_utils::generate_random_port; @@ -56,6 +54,8 @@ pub async fn get_random_available_port( pub async fn graceful_terminate_process(child: &mut tokio::process::Child) { use nix::sys::signal::{kill, Signal}; use nix::unistd::Pid; + use std::time::Duration; + use tokio::time::timeout; if let Some(raw_pid) = child.id() { let raw_pid = raw_pid as i32; diff --git a/src-tauri/utils/src/system.rs b/src-tauri/utils/src/system.rs index cf281b3cb..efb137550 100644 --- a/src-tauri/utils/src/system.rs +++ b/src-tauri/utils/src/system.rs @@ -81,7 +81,6 @@ pub fn setup_library_path(library_path: 
Option<&str>, command: &mut tokio::proce pub fn setup_windows_process_flags(command: &mut tokio::process::Command) { #[cfg(all(windows, target_arch = "x86_64"))] { - use std::os::windows::process::CommandExt; const CREATE_NO_WINDOW: u32 = 0x0800_0000; const CREATE_NEW_PROCESS_GROUP: u32 = 0x0000_0200; command.creation_flags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP); diff --git a/web-app/src/hooks/__tests__/useChat.test.ts b/web-app/src/hooks/__tests__/useChat.test.ts index 3f89e24cd..a694af9a5 100644 --- a/web-app/src/hooks/__tests__/useChat.test.ts +++ b/web-app/src/hooks/__tests__/useChat.test.ts @@ -1,20 +1,30 @@ import { renderHook, act, waitFor } from '@testing-library/react' import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' -import { useChat } from '../useChat' -import * as completionLib from '@/lib/completion' -import * as messagesLib from '@/lib/messages' import { MessageStatus, ContentType } from '@janhq/core' -// Store mock functions for assertions -let mockAddMessage: ReturnType -let mockUpdateMessage: ReturnType -let mockGetMessages: ReturnType -let mockStartModel: ReturnType -let mockSendCompletion: ReturnType -let mockPostMessageProcessing: ReturnType -let mockCompletionMessagesBuilder: any -let mockSetPrompt: ReturnType -let mockResetTokenSpeed: ReturnType +// Store mock functions for assertions - initialize immediately +const mockAddMessage = vi.fn() +const mockUpdateMessage = vi.fn() +const mockGetMessages = vi.fn(() => []) +const mockStartModel = vi.fn(() => Promise.resolve()) +const mockSendCompletion = vi.fn(() => Promise.resolve({ + choices: [{ + message: { + content: 'AI response', + role: 'assistant', + }, + }], +})) +const mockPostMessageProcessing = vi.fn((toolCalls, builder, content) => + Promise.resolve(content) +) +const mockCompletionMessagesBuilder = { + addUserMessage: vi.fn(), + addAssistantMessage: vi.fn(), + getMessages: vi.fn(() => []), +} +const mockSetPrompt = vi.fn() +const mockResetTokenSpeed = 
vi.fn() // Mock dependencies vi.mock('../usePrompt', () => ({ @@ -231,12 +241,12 @@ vi.mock('@/lib/completion', () => ({ extractToolCall: vi.fn(), newUserThreadContent: vi.fn((threadId, content) => ({ thread_id: threadId, - content: [{ type: ContentType.Text, text: { value: content, annotations: [] } }], + content: [{ type: 'text', text: { value: content, annotations: [] } }], role: 'user' })), newAssistantThreadContent: vi.fn((threadId, content) => ({ thread_id: threadId, - content: [{ type: ContentType.Text, text: { value: content, annotations: [] } }], + content: [{ type: 'text', text: { value: content, annotations: [] } }], role: 'assistant' })), sendCompletion: mockSendCompletion, @@ -274,33 +284,37 @@ vi.mock('sonner', () => ({ }, })) +// Import after mocks to avoid hoisting issues +const { useChat } = await import('../useChat') +const completionLib = await import('@/lib/completion') +const messagesLib = await import('@/lib/messages') + describe('useChat', () => { beforeEach(() => { - // Reset all mocks - mockAddMessage = vi.fn() - mockUpdateMessage = vi.fn() - mockGetMessages = vi.fn(() => []) - mockStartModel = vi.fn(() => Promise.resolve()) - mockSetPrompt = vi.fn() - mockResetTokenSpeed = vi.fn() - mockSendCompletion = vi.fn(() => Promise.resolve({ + // Clear mock call history + vi.clearAllMocks() + + // Reset mock implementations + mockAddMessage.mockClear() + mockUpdateMessage.mockClear() + mockGetMessages.mockReturnValue([]) + mockStartModel.mockResolvedValue(undefined) + mockSetPrompt.mockClear() + mockResetTokenSpeed.mockClear() + mockSendCompletion.mockResolvedValue({ choices: [{ message: { content: 'AI response', role: 'assistant', }, }], - })) - mockPostMessageProcessing = vi.fn((toolCalls, builder, content) => + }) + mockPostMessageProcessing.mockImplementation((toolCalls, builder, content) => Promise.resolve(content) ) - mockCompletionMessagesBuilder = { - addUserMessage: vi.fn(), - addAssistantMessage: vi.fn(), - getMessages: vi.fn(() => []), - 
} - - vi.clearAllMocks() + mockCompletionMessagesBuilder.addUserMessage.mockClear() + mockCompletionMessagesBuilder.addAssistantMessage.mockClear() + mockCompletionMessagesBuilder.getMessages.mockReturnValue([]) }) afterEach(() => { @@ -344,7 +358,7 @@ describe('useChat', () => { id: 'msg-123', thread_id: 'test-thread', role: 'assistant', - content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }], + content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }], status: MessageStatus.Stopped, metadata: {}, } @@ -369,7 +383,7 @@ describe('useChat', () => { id: 'msg-123', thread_id: 'test-thread', role: 'assistant', - content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }], + content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }], status: MessageStatus.Stopped, metadata: {}, } @@ -393,13 +407,13 @@ describe('useChat', () => { id: 'msg-1', thread_id: 'test-thread', role: 'user', - content: [{ type: ContentType.Text, text: { value: 'Hello', annotations: [] } }], + content: [{ type: 'text', text: { value: 'Hello', annotations: [] } }], } const stoppedMessage = { id: 'msg-123', thread_id: 'test-thread', role: 'assistant', - content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }], + content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }], status: MessageStatus.Stopped, } mockGetMessages.mockReturnValue([userMsg, stoppedMessage]) @@ -423,7 +437,7 @@ describe('useChat', () => { id: 'msg-123', thread_id: 'test-thread', role: 'assistant', - content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }], + content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }], status: MessageStatus.Stopped, metadata: {}, } @@ -450,7 +464,7 @@ describe('useChat', () => { id: 'msg-123', thread_id: 'test-thread', role: 'assistant', - content: [{ type: ContentType.Text, text: { value: 'Partial 
response', annotations: [] } }], + content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }], status: MessageStatus.Stopped, metadata: {}, } @@ -520,7 +534,7 @@ describe('useChat', () => { id: 'msg-123', thread_id: 'test-thread', role: 'assistant', - content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }], + content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }], status: MessageStatus.Stopped, metadata: {}, } @@ -563,7 +577,7 @@ describe('useChat', () => { id: 'msg-123', thread_id: 'test-thread', role: 'assistant', - content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }], + content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }], status: MessageStatus.Stopped, metadata: {}, } diff --git a/web-app/src/hooks/useAppState.ts b/web-app/src/hooks/useAppState.ts index 59e2e6dda..646294a8d 100644 --- a/web-app/src/hooks/useAppState.ts +++ b/web-app/src/hooks/useAppState.ts @@ -38,6 +38,11 @@ type AppState = { updateTools: (tools: MCPTool[]) => void setAbortController: (threadId: string, controller: AbortController) => void updateTokenSpeed: (message: ThreadMessage, increment?: number) => void + setTokenSpeed: ( + message: ThreadMessage, + speed: number, + completionTokens: number + ) => void resetTokenSpeed: () => void clearAppState: () => void setOutOfContextDialog: (show: boolean) => void @@ -96,6 +101,17 @@ export const useAppState = create()((set) => ({ }, })) }, + setTokenSpeed: (message, speed, completionTokens) => { + set((state) => ({ + tokenSpeed: { + ...state.tokenSpeed, + lastTimestamp: new Date().getTime(), + tokenSpeed: speed, + tokenCount: completionTokens, + message: message.id, + }, + })) + }, updateTokenSpeed: (message, increment = 1) => set((state) => { const currentTimestamp = new Date().getTime() // Get current time in milliseconds diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index 8fae57777..fb66cd639 100644 --- 
a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -19,7 +19,10 @@ import { } from '@/lib/completion' import { CompletionMessagesBuilder } from '@/lib/messages' import { renderInstructions } from '@/lib/instructionTemplate' -import { ChatCompletionMessageToolCall } from 'openai/resources' +import { + ChatCompletionMessageToolCall, + CompletionUsage, +} from 'openai/resources' import { MessageStatus, ContentType } from '@janhq/core' import { useServiceHub } from '@/hooks/useServiceHub' @@ -98,7 +101,10 @@ const processStreamingCompletion = async ( currentCall: ChatCompletionMessageToolCall | null, updateStreamingContent: (content: ThreadMessage | undefined) => void, updateTokenSpeed: (message: ThreadMessage, increment?: number) => void, + setTokenSpeed: (message: ThreadMessage, tokensPerSecond: number, totalTokens: number) => void, updatePromptProgress: (progress: unknown) => void, + timeToFirstToken: number, + tokenUsageRef: { current: CompletionUsage | undefined }, continueFromMessageId?: string, updateMessage?: (message: ThreadMessage) => void, continueFromMessage?: ThreadMessage @@ -128,7 +134,14 @@ const processStreamingCompletion = async ( updateStreamingContent(currentContent) } - if (pendingDeltaCount > 0) { + if (tokenUsageRef.current) { + setTokenSpeed( + currentContent, + tokenUsageRef.current.completion_tokens / + Math.max((Date.now() - timeToFirstToken) / 1000, 1), + tokenUsageRef.current.completion_tokens + ) + } else if (pendingDeltaCount > 0) { updateTokenSpeed(currentContent, pendingDeltaCount) } pendingDeltaCount = 0 @@ -183,6 +196,10 @@ const processStreamingCompletion = async ( ) } + if ('usage' in part && part.usage) { + tokenUsageRef.current = part.usage + } + if (part.choices[0]?.delta?.tool_calls) { extractToolCall(part, currentCall, toolCalls) // Schedule a flush to reflect tool update @@ -221,6 +238,7 @@ export const useChat = () => { updateStreamingContent, updateLoadingModel, setAbortController, + setTokenSpeed, ] = 
useAppState( useShallow((state) => [ state.updateTokenSpeed, @@ -228,6 +246,7 @@ export const useChat = () => { state.updateStreamingContent, state.updateLoadingModel, state.setAbortController, + state.setTokenSpeed, ]) ) const updatePromptProgress = useAppState( @@ -541,10 +560,18 @@ export const useChat = () => { if (!completion) throw new Error('No completion received') const currentCall: ChatCompletionMessageToolCall | null = null const toolCalls: ChatCompletionMessageToolCall[] = [] + const timeToFirstToken = Date.now() + let tokenUsage: CompletionUsage | undefined = undefined try { if (isCompletionResponse(completion)) { const message = completion.choices[0]?.message - accumulatedTextRef.value = (message?.content as string) || '' + // When continuing, append to existing content; otherwise replace + const newContent = (message?.content as string) || '' + if (continueFromMessageId && accumulatedTextRef.value) { + accumulatedTextRef.value += newContent + } else { + accumulatedTextRef.value = newContent + } // Handle reasoning field if there is one const reasoning = extractReasoningFromMessage(message) @@ -556,7 +583,11 @@ export const useChat = () => { if (message?.tool_calls) { toolCalls.push(...message.tool_calls) } + if ('usage' in completion) { + tokenUsage = completion.usage + } } else { + const tokenUsageRef = { current: tokenUsage } await processStreamingCompletion( completion, abortController, @@ -566,11 +597,15 @@ export const useChat = () => { currentCall, updateStreamingContent, updateTokenSpeed, + setTokenSpeed, updatePromptProgress, + timeToFirstToken, + tokenUsageRef, continueFromMessageId, updateMessage, continueFromMessage ) + tokenUsage = tokenUsageRef.current } } catch (error) { const errorMessage = @@ -797,6 +832,7 @@ export const useChat = () => { allowAllMCPPermissions, showApprovalModal, updateTokenSpeed, + setTokenSpeed, showIncreaseContextSizeModal, increaseModelContextSize, toggleOnContextShifting, diff --git 
a/web-app/src/hooks/useThreadScrolling.tsx b/web-app/src/hooks/useThreadScrolling.tsx index a3c6d7ed2..bdc4df9b1 100644 --- a/web-app/src/hooks/useThreadScrolling.tsx +++ b/web-app/src/hooks/useThreadScrolling.tsx @@ -54,7 +54,6 @@ export const useThreadScrolling = ( } }, [scrollContainerRef]) - const handleScroll = useCallback((e: Event) => { const target = e.target as HTMLDivElement const { scrollTop, scrollHeight, clientHeight } = target @@ -69,7 +68,7 @@ export const useThreadScrolling = ( setIsAtBottom(isBottom) setHasScrollbar(hasScroll) lastScrollTopRef.current = scrollTop - }, [streamingContent]) + }, [streamingContent, setIsAtBottom, setHasScrollbar]) useEffect(() => { const scrollContainer = scrollContainerRef.current @@ -90,7 +89,7 @@ export const useThreadScrolling = ( setIsAtBottom(isBottom) setHasScrollbar(hasScroll) - }, [scrollContainerRef]) + }, [scrollContainerRef, setIsAtBottom, setHasScrollbar]) useEffect(() => { if (!scrollContainerRef.current) return