Merge remote-tracking branch 'origin/dev' into feat/retain-interruption-message
# Conflicts: # web-app/src/hooks/useChat.ts
This commit is contained in:
commit
1c0e135077
12
extensions/llamacpp-extension/src/type.d.ts
vendored
Normal file
12
extensions/llamacpp-extension/src/type.d.ts
vendored
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
export {}
|
||||||
|
|
||||||
|
declare global {
|
||||||
|
interface RequestInit {
|
||||||
|
/**
|
||||||
|
* Tauri HTTP plugin option for connection timeout in milliseconds.
|
||||||
|
*/
|
||||||
|
connectTimeout?: number
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -126,13 +126,13 @@ mod windows_impl {
|
|||||||
pub iOSDisplayIndex: c_int,
|
pub iOSDisplayIndex: c_int,
|
||||||
}
|
}
|
||||||
|
|
||||||
type ADL_MAIN_MALLOC_CALLBACK = Option<unsafe extern "C" fn(i32) -> *mut c_void>;
|
type AdlMainMallocCallback = Option<unsafe extern "C" fn(i32) -> *mut c_void>;
|
||||||
type ADL_MAIN_CONTROL_CREATE = unsafe extern "C" fn(ADL_MAIN_MALLOC_CALLBACK, c_int) -> c_int;
|
type ADLMAINCONTROLCREATE = unsafe extern "C" fn(AdlMainMallocCallback, c_int) -> c_int;
|
||||||
type ADL_MAIN_CONTROL_DESTROY = unsafe extern "C" fn() -> c_int;
|
type ADLMAINCONTROLDESTROY = unsafe extern "C" fn() -> c_int;
|
||||||
type ADL_ADAPTER_NUMBEROFADAPTERS_GET = unsafe extern "C" fn(*mut c_int) -> c_int;
|
type AdlAdapterNumberofadaptersGet = unsafe extern "C" fn(*mut c_int) -> c_int;
|
||||||
type ADL_ADAPTER_ADAPTERINFO_GET = unsafe extern "C" fn(*mut AdapterInfo, c_int) -> c_int;
|
type AdlAdapterAdapterinfoGet = unsafe extern "C" fn(*mut AdapterInfo, c_int) -> c_int;
|
||||||
type ADL_ADAPTER_ACTIVE_GET = unsafe extern "C" fn(c_int, *mut c_int) -> c_int;
|
type AdlAdapterActiveGet = unsafe extern "C" fn(c_int, *mut c_int) -> c_int;
|
||||||
type ADL_GET_DEDICATED_VRAM_USAGE =
|
type AdlGetDedicatedVramUsage =
|
||||||
unsafe extern "C" fn(*mut c_void, c_int, *mut c_int) -> c_int;
|
unsafe extern "C" fn(*mut c_void, c_int, *mut c_int) -> c_int;
|
||||||
|
|
||||||
// === ADL Memory Allocator ===
|
// === ADL Memory Allocator ===
|
||||||
@ -144,24 +144,24 @@ mod windows_impl {
|
|||||||
unsafe {
|
unsafe {
|
||||||
let lib = Library::new("atiadlxx.dll").or_else(|_| Library::new("atiadlxy.dll"))?;
|
let lib = Library::new("atiadlxx.dll").or_else(|_| Library::new("atiadlxy.dll"))?;
|
||||||
|
|
||||||
let adl_main_control_create: Symbol<ADL_MAIN_CONTROL_CREATE> =
|
let adlmaincontrolcreate: Symbol<ADLMAINCONTROLCREATE> =
|
||||||
lib.get(b"ADL_Main_Control_Create")?;
|
lib.get(b"AdlMainControlCreate")?;
|
||||||
let adl_main_control_destroy: Symbol<ADL_MAIN_CONTROL_DESTROY> =
|
let adlmaincontroldestroy: Symbol<ADLMAINCONTROLDESTROY> =
|
||||||
lib.get(b"ADL_Main_Control_Destroy")?;
|
lib.get(b"AdlMainControlDestroy")?;
|
||||||
let adl_adapter_number_of_adapters_get: Symbol<ADL_ADAPTER_NUMBEROFADAPTERS_GET> =
|
let adl_adapter_number_of_adapters_get: Symbol<AdlAdapterNumberofadaptersGet> =
|
||||||
lib.get(b"ADL_Adapter_NumberOfAdapters_Get")?;
|
lib.get(b"AdlAdapterNumberofadaptersGet")?;
|
||||||
let adl_adapter_adapter_info_get: Symbol<ADL_ADAPTER_ADAPTERINFO_GET> =
|
let adl_adapter_adapter_info_get: Symbol<AdlAdapterAdapterinfoGet> =
|
||||||
lib.get(b"ADL_Adapter_AdapterInfo_Get")?;
|
lib.get(b"AdlAdapterAdapterinfoGet")?;
|
||||||
let adl_adapter_active_get: Symbol<ADL_ADAPTER_ACTIVE_GET> =
|
let AdlAdapterActiveGet: Symbol<AdlAdapterActiveGet> =
|
||||||
lib.get(b"ADL_Adapter_Active_Get")?;
|
lib.get(b"AdlAdapterActiveGet")?;
|
||||||
let adl_get_dedicated_vram_usage: Symbol<ADL_GET_DEDICATED_VRAM_USAGE> =
|
let AdlGetDedicatedVramUsage: Symbol<AdlGetDedicatedVramUsage> =
|
||||||
lib.get(b"ADL2_Adapter_DedicatedVRAMUsage_Get")?;
|
lib.get(b"ADL2_Adapter_DedicatedVRAMUsage_Get")?;
|
||||||
|
|
||||||
// TODO: try to put nullptr here. then we don't need direct libc dep
|
// TODO: try to put nullptr here. then we don't need direct libc dep
|
||||||
if adl_main_control_create(Some(adl_malloc), 1) != 0 {
|
if adlmaincontrolcreate(Some(adl_malloc), 1) != 0 {
|
||||||
return Err("ADL initialization error!".into());
|
return Err("ADL initialization error!".into());
|
||||||
}
|
}
|
||||||
// NOTE: after this call, we must call ADL_Main_Control_Destroy
|
// NOTE: after this call, we must call AdlMainControlDestroy
|
||||||
// whenver we encounter an error
|
// whenver we encounter an error
|
||||||
|
|
||||||
let mut num_adapters: c_int = 0;
|
let mut num_adapters: c_int = 0;
|
||||||
@ -184,11 +184,11 @@ mod windows_impl {
|
|||||||
|
|
||||||
for adapter in adapter_info.iter() {
|
for adapter in adapter_info.iter() {
|
||||||
let mut is_active = 0;
|
let mut is_active = 0;
|
||||||
adl_adapter_active_get(adapter.iAdapterIndex, &mut is_active);
|
AdlAdapterActiveGet(adapter.iAdapterIndex, &mut is_active);
|
||||||
|
|
||||||
if is_active != 0 {
|
if is_active != 0 {
|
||||||
let mut vram_mb = 0;
|
let mut vram_mb = 0;
|
||||||
let _ = adl_get_dedicated_vram_usage(
|
let _ = AdlGetDedicatedVramUsage(
|
||||||
ptr::null_mut(),
|
ptr::null_mut(),
|
||||||
adapter.iAdapterIndex,
|
adapter.iAdapterIndex,
|
||||||
&mut vram_mb,
|
&mut vram_mb,
|
||||||
@ -202,7 +202,7 @@ mod windows_impl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
adl_main_control_destroy();
|
adlmaincontroldestroy();
|
||||||
|
|
||||||
Ok(vram_usages)
|
Ok(vram_usages)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::time::Duration;
|
|
||||||
use sysinfo::{Pid, System};
|
use sysinfo::{Pid, System};
|
||||||
use tauri::{Manager, Runtime, State};
|
use tauri::{Manager, Runtime, State};
|
||||||
use tokio::time::timeout;
|
|
||||||
|
|
||||||
use crate::state::{LlamacppState, SessionInfo};
|
use crate::state::{LlamacppState, SessionInfo};
|
||||||
use jan_utils::generate_random_port;
|
use jan_utils::generate_random_port;
|
||||||
@ -56,6 +54,8 @@ pub async fn get_random_available_port<R: Runtime>(
|
|||||||
pub async fn graceful_terminate_process(child: &mut tokio::process::Child) {
|
pub async fn graceful_terminate_process(child: &mut tokio::process::Child) {
|
||||||
use nix::sys::signal::{kill, Signal};
|
use nix::sys::signal::{kill, Signal};
|
||||||
use nix::unistd::Pid;
|
use nix::unistd::Pid;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
|
||||||
if let Some(raw_pid) = child.id() {
|
if let Some(raw_pid) = child.id() {
|
||||||
let raw_pid = raw_pid as i32;
|
let raw_pid = raw_pid as i32;
|
||||||
|
|||||||
@ -81,7 +81,6 @@ pub fn setup_library_path(library_path: Option<&str>, command: &mut tokio::proce
|
|||||||
pub fn setup_windows_process_flags(command: &mut tokio::process::Command) {
|
pub fn setup_windows_process_flags(command: &mut tokio::process::Command) {
|
||||||
#[cfg(all(windows, target_arch = "x86_64"))]
|
#[cfg(all(windows, target_arch = "x86_64"))]
|
||||||
{
|
{
|
||||||
use std::os::windows::process::CommandExt;
|
|
||||||
const CREATE_NO_WINDOW: u32 = 0x0800_0000;
|
const CREATE_NO_WINDOW: u32 = 0x0800_0000;
|
||||||
const CREATE_NEW_PROCESS_GROUP: u32 = 0x0000_0200;
|
const CREATE_NEW_PROCESS_GROUP: u32 = 0x0000_0200;
|
||||||
command.creation_flags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP);
|
command.creation_flags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP);
|
||||||
|
|||||||
@ -1,20 +1,30 @@
|
|||||||
import { renderHook, act, waitFor } from '@testing-library/react'
|
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||||
import { useChat } from '../useChat'
|
|
||||||
import * as completionLib from '@/lib/completion'
|
|
||||||
import * as messagesLib from '@/lib/messages'
|
|
||||||
import { MessageStatus, ContentType } from '@janhq/core'
|
import { MessageStatus, ContentType } from '@janhq/core'
|
||||||
|
|
||||||
// Store mock functions for assertions
|
// Store mock functions for assertions - initialize immediately
|
||||||
let mockAddMessage: ReturnType<typeof vi.fn>
|
const mockAddMessage = vi.fn()
|
||||||
let mockUpdateMessage: ReturnType<typeof vi.fn>
|
const mockUpdateMessage = vi.fn()
|
||||||
let mockGetMessages: ReturnType<typeof vi.fn>
|
const mockGetMessages = vi.fn(() => [])
|
||||||
let mockStartModel: ReturnType<typeof vi.fn>
|
const mockStartModel = vi.fn(() => Promise.resolve())
|
||||||
let mockSendCompletion: ReturnType<typeof vi.fn>
|
const mockSendCompletion = vi.fn(() => Promise.resolve({
|
||||||
let mockPostMessageProcessing: ReturnType<typeof vi.fn>
|
choices: [{
|
||||||
let mockCompletionMessagesBuilder: any
|
message: {
|
||||||
let mockSetPrompt: ReturnType<typeof vi.fn>
|
content: 'AI response',
|
||||||
let mockResetTokenSpeed: ReturnType<typeof vi.fn>
|
role: 'assistant',
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
}))
|
||||||
|
const mockPostMessageProcessing = vi.fn((toolCalls, builder, content) =>
|
||||||
|
Promise.resolve(content)
|
||||||
|
)
|
||||||
|
const mockCompletionMessagesBuilder = {
|
||||||
|
addUserMessage: vi.fn(),
|
||||||
|
addAssistantMessage: vi.fn(),
|
||||||
|
getMessages: vi.fn(() => []),
|
||||||
|
}
|
||||||
|
const mockSetPrompt = vi.fn()
|
||||||
|
const mockResetTokenSpeed = vi.fn()
|
||||||
|
|
||||||
// Mock dependencies
|
// Mock dependencies
|
||||||
vi.mock('../usePrompt', () => ({
|
vi.mock('../usePrompt', () => ({
|
||||||
@ -231,12 +241,12 @@ vi.mock('@/lib/completion', () => ({
|
|||||||
extractToolCall: vi.fn(),
|
extractToolCall: vi.fn(),
|
||||||
newUserThreadContent: vi.fn((threadId, content) => ({
|
newUserThreadContent: vi.fn((threadId, content) => ({
|
||||||
thread_id: threadId,
|
thread_id: threadId,
|
||||||
content: [{ type: ContentType.Text, text: { value: content, annotations: [] } }],
|
content: [{ type: 'text', text: { value: content, annotations: [] } }],
|
||||||
role: 'user'
|
role: 'user'
|
||||||
})),
|
})),
|
||||||
newAssistantThreadContent: vi.fn((threadId, content) => ({
|
newAssistantThreadContent: vi.fn((threadId, content) => ({
|
||||||
thread_id: threadId,
|
thread_id: threadId,
|
||||||
content: [{ type: ContentType.Text, text: { value: content, annotations: [] } }],
|
content: [{ type: 'text', text: { value: content, annotations: [] } }],
|
||||||
role: 'assistant'
|
role: 'assistant'
|
||||||
})),
|
})),
|
||||||
sendCompletion: mockSendCompletion,
|
sendCompletion: mockSendCompletion,
|
||||||
@ -274,33 +284,37 @@ vi.mock('sonner', () => ({
|
|||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
// Import after mocks to avoid hoisting issues
|
||||||
|
const { useChat } = await import('../useChat')
|
||||||
|
const completionLib = await import('@/lib/completion')
|
||||||
|
const messagesLib = await import('@/lib/messages')
|
||||||
|
|
||||||
describe('useChat', () => {
|
describe('useChat', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
// Reset all mocks
|
// Clear mock call history
|
||||||
mockAddMessage = vi.fn()
|
vi.clearAllMocks()
|
||||||
mockUpdateMessage = vi.fn()
|
|
||||||
mockGetMessages = vi.fn(() => [])
|
// Reset mock implementations
|
||||||
mockStartModel = vi.fn(() => Promise.resolve())
|
mockAddMessage.mockClear()
|
||||||
mockSetPrompt = vi.fn()
|
mockUpdateMessage.mockClear()
|
||||||
mockResetTokenSpeed = vi.fn()
|
mockGetMessages.mockReturnValue([])
|
||||||
mockSendCompletion = vi.fn(() => Promise.resolve({
|
mockStartModel.mockResolvedValue(undefined)
|
||||||
|
mockSetPrompt.mockClear()
|
||||||
|
mockResetTokenSpeed.mockClear()
|
||||||
|
mockSendCompletion.mockResolvedValue({
|
||||||
choices: [{
|
choices: [{
|
||||||
message: {
|
message: {
|
||||||
content: 'AI response',
|
content: 'AI response',
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
},
|
},
|
||||||
}],
|
}],
|
||||||
}))
|
})
|
||||||
mockPostMessageProcessing = vi.fn((toolCalls, builder, content) =>
|
mockPostMessageProcessing.mockImplementation((toolCalls, builder, content) =>
|
||||||
Promise.resolve(content)
|
Promise.resolve(content)
|
||||||
)
|
)
|
||||||
mockCompletionMessagesBuilder = {
|
mockCompletionMessagesBuilder.addUserMessage.mockClear()
|
||||||
addUserMessage: vi.fn(),
|
mockCompletionMessagesBuilder.addAssistantMessage.mockClear()
|
||||||
addAssistantMessage: vi.fn(),
|
mockCompletionMessagesBuilder.getMessages.mockReturnValue([])
|
||||||
getMessages: vi.fn(() => []),
|
|
||||||
}
|
|
||||||
|
|
||||||
vi.clearAllMocks()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
@ -344,7 +358,7 @@ describe('useChat', () => {
|
|||||||
id: 'msg-123',
|
id: 'msg-123',
|
||||||
thread_id: 'test-thread',
|
thread_id: 'test-thread',
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }],
|
content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
|
||||||
status: MessageStatus.Stopped,
|
status: MessageStatus.Stopped,
|
||||||
metadata: {},
|
metadata: {},
|
||||||
}
|
}
|
||||||
@ -369,7 +383,7 @@ describe('useChat', () => {
|
|||||||
id: 'msg-123',
|
id: 'msg-123',
|
||||||
thread_id: 'test-thread',
|
thread_id: 'test-thread',
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }],
|
content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
|
||||||
status: MessageStatus.Stopped,
|
status: MessageStatus.Stopped,
|
||||||
metadata: {},
|
metadata: {},
|
||||||
}
|
}
|
||||||
@ -393,13 +407,13 @@ describe('useChat', () => {
|
|||||||
id: 'msg-1',
|
id: 'msg-1',
|
||||||
thread_id: 'test-thread',
|
thread_id: 'test-thread',
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: [{ type: ContentType.Text, text: { value: 'Hello', annotations: [] } }],
|
content: [{ type: 'text', text: { value: 'Hello', annotations: [] } }],
|
||||||
}
|
}
|
||||||
const stoppedMessage = {
|
const stoppedMessage = {
|
||||||
id: 'msg-123',
|
id: 'msg-123',
|
||||||
thread_id: 'test-thread',
|
thread_id: 'test-thread',
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
|
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
|
||||||
status: MessageStatus.Stopped,
|
status: MessageStatus.Stopped,
|
||||||
}
|
}
|
||||||
mockGetMessages.mockReturnValue([userMsg, stoppedMessage])
|
mockGetMessages.mockReturnValue([userMsg, stoppedMessage])
|
||||||
@ -423,7 +437,7 @@ describe('useChat', () => {
|
|||||||
id: 'msg-123',
|
id: 'msg-123',
|
||||||
thread_id: 'test-thread',
|
thread_id: 'test-thread',
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
|
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
|
||||||
status: MessageStatus.Stopped,
|
status: MessageStatus.Stopped,
|
||||||
metadata: {},
|
metadata: {},
|
||||||
}
|
}
|
||||||
@ -450,7 +464,7 @@ describe('useChat', () => {
|
|||||||
id: 'msg-123',
|
id: 'msg-123',
|
||||||
thread_id: 'test-thread',
|
thread_id: 'test-thread',
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }],
|
content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
|
||||||
status: MessageStatus.Stopped,
|
status: MessageStatus.Stopped,
|
||||||
metadata: {},
|
metadata: {},
|
||||||
}
|
}
|
||||||
@ -520,7 +534,7 @@ describe('useChat', () => {
|
|||||||
id: 'msg-123',
|
id: 'msg-123',
|
||||||
thread_id: 'test-thread',
|
thread_id: 'test-thread',
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
|
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
|
||||||
status: MessageStatus.Stopped,
|
status: MessageStatus.Stopped,
|
||||||
metadata: {},
|
metadata: {},
|
||||||
}
|
}
|
||||||
@ -563,7 +577,7 @@ describe('useChat', () => {
|
|||||||
id: 'msg-123',
|
id: 'msg-123',
|
||||||
thread_id: 'test-thread',
|
thread_id: 'test-thread',
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
|
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
|
||||||
status: MessageStatus.Stopped,
|
status: MessageStatus.Stopped,
|
||||||
metadata: {},
|
metadata: {},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -38,6 +38,11 @@ type AppState = {
|
|||||||
updateTools: (tools: MCPTool[]) => void
|
updateTools: (tools: MCPTool[]) => void
|
||||||
setAbortController: (threadId: string, controller: AbortController) => void
|
setAbortController: (threadId: string, controller: AbortController) => void
|
||||||
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void
|
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void
|
||||||
|
setTokenSpeed: (
|
||||||
|
message: ThreadMessage,
|
||||||
|
speed: number,
|
||||||
|
completionTokens: number
|
||||||
|
) => void
|
||||||
resetTokenSpeed: () => void
|
resetTokenSpeed: () => void
|
||||||
clearAppState: () => void
|
clearAppState: () => void
|
||||||
setOutOfContextDialog: (show: boolean) => void
|
setOutOfContextDialog: (show: boolean) => void
|
||||||
@ -96,6 +101,17 @@ export const useAppState = create<AppState>()((set) => ({
|
|||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
},
|
},
|
||||||
|
setTokenSpeed: (message, speed, completionTokens) => {
|
||||||
|
set((state) => ({
|
||||||
|
tokenSpeed: {
|
||||||
|
...state.tokenSpeed,
|
||||||
|
lastTimestamp: new Date().getTime(),
|
||||||
|
tokenSpeed: speed,
|
||||||
|
tokenCount: completionTokens,
|
||||||
|
message: message.id,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
},
|
||||||
updateTokenSpeed: (message, increment = 1) =>
|
updateTokenSpeed: (message, increment = 1) =>
|
||||||
set((state) => {
|
set((state) => {
|
||||||
const currentTimestamp = new Date().getTime() // Get current time in milliseconds
|
const currentTimestamp = new Date().getTime() // Get current time in milliseconds
|
||||||
|
|||||||
@ -19,7 +19,10 @@ import {
|
|||||||
} from '@/lib/completion'
|
} from '@/lib/completion'
|
||||||
import { CompletionMessagesBuilder } from '@/lib/messages'
|
import { CompletionMessagesBuilder } from '@/lib/messages'
|
||||||
import { renderInstructions } from '@/lib/instructionTemplate'
|
import { renderInstructions } from '@/lib/instructionTemplate'
|
||||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
import {
|
||||||
|
ChatCompletionMessageToolCall,
|
||||||
|
CompletionUsage,
|
||||||
|
} from 'openai/resources'
|
||||||
import { MessageStatus, ContentType } from '@janhq/core'
|
import { MessageStatus, ContentType } from '@janhq/core'
|
||||||
|
|
||||||
import { useServiceHub } from '@/hooks/useServiceHub'
|
import { useServiceHub } from '@/hooks/useServiceHub'
|
||||||
@ -98,7 +101,10 @@ const processStreamingCompletion = async (
|
|||||||
currentCall: ChatCompletionMessageToolCall | null,
|
currentCall: ChatCompletionMessageToolCall | null,
|
||||||
updateStreamingContent: (content: ThreadMessage | undefined) => void,
|
updateStreamingContent: (content: ThreadMessage | undefined) => void,
|
||||||
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
|
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
|
||||||
|
setTokenSpeed: (message: ThreadMessage, tokensPerSecond: number, totalTokens: number) => void,
|
||||||
updatePromptProgress: (progress: unknown) => void,
|
updatePromptProgress: (progress: unknown) => void,
|
||||||
|
timeToFirstToken: number,
|
||||||
|
tokenUsageRef: { current: CompletionUsage | undefined },
|
||||||
continueFromMessageId?: string,
|
continueFromMessageId?: string,
|
||||||
updateMessage?: (message: ThreadMessage) => void,
|
updateMessage?: (message: ThreadMessage) => void,
|
||||||
continueFromMessage?: ThreadMessage
|
continueFromMessage?: ThreadMessage
|
||||||
@ -128,7 +134,14 @@ const processStreamingCompletion = async (
|
|||||||
updateStreamingContent(currentContent)
|
updateStreamingContent(currentContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pendingDeltaCount > 0) {
|
if (tokenUsageRef.current) {
|
||||||
|
setTokenSpeed(
|
||||||
|
currentContent,
|
||||||
|
tokenUsageRef.current.completion_tokens /
|
||||||
|
Math.max((Date.now() - timeToFirstToken) / 1000, 1),
|
||||||
|
tokenUsageRef.current.completion_tokens
|
||||||
|
)
|
||||||
|
} else if (pendingDeltaCount > 0) {
|
||||||
updateTokenSpeed(currentContent, pendingDeltaCount)
|
updateTokenSpeed(currentContent, pendingDeltaCount)
|
||||||
}
|
}
|
||||||
pendingDeltaCount = 0
|
pendingDeltaCount = 0
|
||||||
@ -183,6 +196,10 @@ const processStreamingCompletion = async (
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ('usage' in part && part.usage) {
|
||||||
|
tokenUsageRef.current = part.usage
|
||||||
|
}
|
||||||
|
|
||||||
if (part.choices[0]?.delta?.tool_calls) {
|
if (part.choices[0]?.delta?.tool_calls) {
|
||||||
extractToolCall(part, currentCall, toolCalls)
|
extractToolCall(part, currentCall, toolCalls)
|
||||||
// Schedule a flush to reflect tool update
|
// Schedule a flush to reflect tool update
|
||||||
@ -221,6 +238,7 @@ export const useChat = () => {
|
|||||||
updateStreamingContent,
|
updateStreamingContent,
|
||||||
updateLoadingModel,
|
updateLoadingModel,
|
||||||
setAbortController,
|
setAbortController,
|
||||||
|
setTokenSpeed,
|
||||||
] = useAppState(
|
] = useAppState(
|
||||||
useShallow((state) => [
|
useShallow((state) => [
|
||||||
state.updateTokenSpeed,
|
state.updateTokenSpeed,
|
||||||
@ -228,6 +246,7 @@ export const useChat = () => {
|
|||||||
state.updateStreamingContent,
|
state.updateStreamingContent,
|
||||||
state.updateLoadingModel,
|
state.updateLoadingModel,
|
||||||
state.setAbortController,
|
state.setAbortController,
|
||||||
|
state.setTokenSpeed,
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
const updatePromptProgress = useAppState(
|
const updatePromptProgress = useAppState(
|
||||||
@ -541,10 +560,18 @@ export const useChat = () => {
|
|||||||
if (!completion) throw new Error('No completion received')
|
if (!completion) throw new Error('No completion received')
|
||||||
const currentCall: ChatCompletionMessageToolCall | null = null
|
const currentCall: ChatCompletionMessageToolCall | null = null
|
||||||
const toolCalls: ChatCompletionMessageToolCall[] = []
|
const toolCalls: ChatCompletionMessageToolCall[] = []
|
||||||
|
const timeToFirstToken = Date.now()
|
||||||
|
let tokenUsage: CompletionUsage | undefined = undefined
|
||||||
try {
|
try {
|
||||||
if (isCompletionResponse(completion)) {
|
if (isCompletionResponse(completion)) {
|
||||||
const message = completion.choices[0]?.message
|
const message = completion.choices[0]?.message
|
||||||
accumulatedTextRef.value = (message?.content as string) || ''
|
// When continuing, append to existing content; otherwise replace
|
||||||
|
const newContent = (message?.content as string) || ''
|
||||||
|
if (continueFromMessageId && accumulatedTextRef.value) {
|
||||||
|
accumulatedTextRef.value += newContent
|
||||||
|
} else {
|
||||||
|
accumulatedTextRef.value = newContent
|
||||||
|
}
|
||||||
|
|
||||||
// Handle reasoning field if there is one
|
// Handle reasoning field if there is one
|
||||||
const reasoning = extractReasoningFromMessage(message)
|
const reasoning = extractReasoningFromMessage(message)
|
||||||
@ -556,7 +583,11 @@ export const useChat = () => {
|
|||||||
if (message?.tool_calls) {
|
if (message?.tool_calls) {
|
||||||
toolCalls.push(...message.tool_calls)
|
toolCalls.push(...message.tool_calls)
|
||||||
}
|
}
|
||||||
|
if ('usage' in completion) {
|
||||||
|
tokenUsage = completion.usage
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
|
const tokenUsageRef = { current: tokenUsage }
|
||||||
await processStreamingCompletion(
|
await processStreamingCompletion(
|
||||||
completion,
|
completion,
|
||||||
abortController,
|
abortController,
|
||||||
@ -566,11 +597,15 @@ export const useChat = () => {
|
|||||||
currentCall,
|
currentCall,
|
||||||
updateStreamingContent,
|
updateStreamingContent,
|
||||||
updateTokenSpeed,
|
updateTokenSpeed,
|
||||||
|
setTokenSpeed,
|
||||||
updatePromptProgress,
|
updatePromptProgress,
|
||||||
|
timeToFirstToken,
|
||||||
|
tokenUsageRef,
|
||||||
continueFromMessageId,
|
continueFromMessageId,
|
||||||
updateMessage,
|
updateMessage,
|
||||||
continueFromMessage
|
continueFromMessage
|
||||||
)
|
)
|
||||||
|
tokenUsage = tokenUsageRef.current
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const errorMessage =
|
const errorMessage =
|
||||||
@ -797,6 +832,7 @@ export const useChat = () => {
|
|||||||
allowAllMCPPermissions,
|
allowAllMCPPermissions,
|
||||||
showApprovalModal,
|
showApprovalModal,
|
||||||
updateTokenSpeed,
|
updateTokenSpeed,
|
||||||
|
setTokenSpeed,
|
||||||
showIncreaseContextSizeModal,
|
showIncreaseContextSizeModal,
|
||||||
increaseModelContextSize,
|
increaseModelContextSize,
|
||||||
toggleOnContextShifting,
|
toggleOnContextShifting,
|
||||||
|
|||||||
@ -54,7 +54,6 @@ export const useThreadScrolling = (
|
|||||||
}
|
}
|
||||||
}, [scrollContainerRef])
|
}, [scrollContainerRef])
|
||||||
|
|
||||||
|
|
||||||
const handleScroll = useCallback((e: Event) => {
|
const handleScroll = useCallback((e: Event) => {
|
||||||
const target = e.target as HTMLDivElement
|
const target = e.target as HTMLDivElement
|
||||||
const { scrollTop, scrollHeight, clientHeight } = target
|
const { scrollTop, scrollHeight, clientHeight } = target
|
||||||
@ -69,7 +68,7 @@ export const useThreadScrolling = (
|
|||||||
setIsAtBottom(isBottom)
|
setIsAtBottom(isBottom)
|
||||||
setHasScrollbar(hasScroll)
|
setHasScrollbar(hasScroll)
|
||||||
lastScrollTopRef.current = scrollTop
|
lastScrollTopRef.current = scrollTop
|
||||||
}, [streamingContent])
|
}, [streamingContent, setIsAtBottom, setHasScrollbar])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const scrollContainer = scrollContainerRef.current
|
const scrollContainer = scrollContainerRef.current
|
||||||
@ -90,7 +89,7 @@ export const useThreadScrolling = (
|
|||||||
|
|
||||||
setIsAtBottom(isBottom)
|
setIsAtBottom(isBottom)
|
||||||
setHasScrollbar(hasScroll)
|
setHasScrollbar(hasScroll)
|
||||||
}, [scrollContainerRef])
|
}, [scrollContainerRef, setIsAtBottom, setHasScrollbar])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!scrollContainerRef.current) return
|
if (!scrollContainerRef.current) return
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user