Compare commits
12 Commits
dev...feat/retai
| Author | SHA1 | Date |
|---|---|---|
|  | 35264e9a22 |  |
|  | e7c9275488 |  |
|  | 4ac45aba23 |  |
|  | 34036d895a |  |
|  | 7127ff1244 |  |
|  | 1c0e135077 |  |
|  | 99473ed568 |  |
|  | 52f73af08c |  |
|  | ccca331d6c |  |
|  | f4b187ba11 |  |
|  | 4ea9d296ea |  |
|  | 2e86d4e421 |  |
@@ -20,6 +20,13 @@ export interface MessageInterface {
    */
   listMessages(threadId: string): Promise<ThreadMessage[]>

+  /**
+   * Updates an existing message in a thread.
+   * @param {ThreadMessage} message - The message to be updated (must have existing ID).
+   * @returns {Promise<ThreadMessage>} A promise that resolves to the updated message.
+   */
+  modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
+
   /**
    * Deletes a specific message from a thread.
    * @param {string} threadId - The ID of the thread from which the message will be deleted.
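The hunk above adds `modifyMessage` to the conversational extension contract. Below is a minimal sketch of how a consumer of this interface might call it; the `markLastMessageEdited` helper and the `edited` flag are illustrative assumptions, only the two interface methods come from the diff:

```ts
import type { ThreadMessage } from '@janhq/core'

// Hypothetical consumer of the MessageInterface contract shown above.
async function markLastMessageEdited(
  ext: {
    listMessages(threadId: string): Promise<ThreadMessage[]>
    modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
  },
  threadId: string
): Promise<ThreadMessage | undefined> {
  const messages = await ext.listMessages(threadId)
  const last = messages[messages.length - 1]
  if (!last) return undefined
  // modifyMessage expects a message that already has an ID and
  // resolves to the updated message.
  return ext.modifyMessage({
    ...last,
    metadata: { ...last.metadata, edited: true },
  })
}
```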
@@ -176,7 +176,6 @@ const ChatInput = ({
   const mcpExtension = extensionManager.get<MCPExtension>(ExtensionTypeEnum.MCP)
   const MCPToolComponent = mcpExtension?.getToolComponent?.()

-
   const handleSendMesage = async (prompt: string) => {
     if (!selectedModel) {
       setMessage('Please select a model to start chatting.')
@@ -3,6 +3,8 @@ import { useMessages } from '@/hooks/useMessages'
 import { useTranslation } from '@/i18n/react-i18next-compat'
 import { Play } from 'lucide-react'
 import { useShallow } from 'zustand/react/shallow'
+import { useMemo } from 'react'
+import { MessageStatus } from '@janhq/core'

 export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
   const { t } = useTranslation()
@@ -13,7 +15,36 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
     }))
   )
   const sendMessage = useChat()

+  // Detect if last message is a partial assistant response (user stopped midway)
+  const isPartialResponse = useMemo(() => {
+    if (!messages || messages.length < 2) return false
+    const lastMessage = messages[messages.length - 1]
+    const secondLastMessage = messages[messages.length - 2]
+
+    return (
+      lastMessage?.role === 'assistant' &&
+      lastMessage?.status === MessageStatus.Stopped &&
+      secondLastMessage?.role === 'user' &&
+      !lastMessage?.metadata?.tool_calls
+    )
+  }, [messages])
+
   const generateAIResponse = () => {
+    if (isPartialResponse) {
+      const partialMessage = messages[messages.length - 1]
+      const userMessage = messages[messages.length - 2]
+      if (userMessage?.content?.[0]?.text?.value) {
+        sendMessage(
+          userMessage.content[0].text.value,
+          false,
+          undefined,
+          partialMessage.id
+        )
+      }
+      return
+    }
+
     const latestUserMessage = messages[messages.length - 1]
     if (
       latestUserMessage?.content?.[0]?.text?.value &&
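The `useMemo` block above encodes the continuation heuristic. Read in isolation it is a pure predicate over the last two messages; here is a standalone sketch, with the message type narrowed for illustration and the status value assumed to be a plain string:

```ts
type Msg = {
  role: 'user' | 'assistant'
  status?: string
  metadata?: { tool_calls?: unknown }
}

// A stopped assistant message directly after a user message, with no
// pending tool calls, is treated as a resumable partial response.
function isPartialResponse(messages: Msg[], stoppedStatus = 'stopped'): boolean {
  if (messages.length < 2) return false
  const last = messages[messages.length - 1]
  const prev = messages[messages.length - 2]
  return (
    last.role === 'assistant' &&
    last.status === stoppedStatus &&
    prev.role === 'user' &&
    !last.metadata?.tool_calls
  )
}

// Example: true for a [user, stopped-assistant] tail.
console.log(
  isPartialResponse([
    { role: 'user' },
    { role: 'assistant', status: 'stopped', metadata: {} },
  ])
)
```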
@@ -39,7 +70,11 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
       className="mx-2 bg-main-view-fg/10 px-2 border border-main-view-fg/5 flex items-center justify-center rounded-xl gap-x-2 cursor-pointer pointer-events-auto"
       onClick={generateAIResponse}
     >
-      <p className="text-xs">{t('common:generateAiResponse')}</p>
+      <p className="text-xs">
+        {isPartialResponse
+          ? t('common:continueAiResponse')
+          : t('common:generateAiResponse')}
+      </p>
       <Play size={12} />
     </div>
   )
@@ -8,6 +8,7 @@ import { cn } from '@/lib/utils'
 import { ArrowDown } from 'lucide-react'
 import { useTranslation } from '@/i18n/react-i18next-compat'
 import { useAppState } from '@/hooks/useAppState'
+import { MessageStatus } from '@janhq/core'

 const ScrollToBottom = ({
   threadId,
@@ -28,11 +29,20 @@ const ScrollToBottom = ({

   const streamingContent = useAppState((state) => state.streamingContent)

+  // Check if last message is a partial assistant response and show continue button (user interrupted)
+  const isPartialResponse =
+    messages.length >= 2 &&
+    messages[messages.length - 1]?.role === 'assistant' &&
+    messages[messages.length - 1]?.status === MessageStatus.Stopped &&
+    messages[messages.length - 2]?.role === 'user' &&
+    !messages[messages.length - 1]?.metadata?.tool_calls
+
   const showGenerateAIResponseBtn =
-    (messages[messages.length - 1]?.role === 'user' ||
+    ((messages[messages.length - 1]?.role === 'user' ||
       (messages[messages.length - 1]?.metadata &&
-        'tool_calls' in (messages[messages.length - 1].metadata ?? {}))) &&
-    !streamingContent
+        'tool_calls' in (messages[messages.length - 1].metadata ?? {})) ||
+      isPartialResponse) &&
+      !streamingContent)

   return (
     <div
@@ -2,6 +2,7 @@ import { useAppState } from '@/hooks/useAppState'
 import { ThreadContent } from './ThreadContent'
 import { memo, useMemo } from 'react'
 import { useMessages } from '@/hooks/useMessages'
+import { MessageStatus } from '@janhq/core'

 type Props = {
   threadId: string
@@ -48,12 +49,19 @@ export const StreamingContent = memo(({ threadId }: Props) => {
     return extractReasoningSegment(text)
   }, [lastAssistant])

-  if (!streamingContent || streamingContent.thread_id !== threadId) return null
+  if (!streamingContent || streamingContent.thread_id !== threadId) {
+    return null
+  }

   if (streamingReasoning && streamingReasoning === lastAssistantReasoning) {
     return null
   }

+  // Don't show streaming content if there's already a stopped message
+  if (lastAssistant?.status === MessageStatus.Stopped) {
+    return null
+  }
+
   // Pass a new object to ThreadContent to avoid reference issues
   // The streaming content is always the last message
   return (
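The early returns added above form a render guard for the streaming bubble. A condensed sketch of the same ordering as a pure function; all names here are illustrative, not part of the diff:

```ts
type StreamGuardInput = {
  streamingThreadId?: string
  threadId: string
  streamingReasoning?: string
  lastAssistantReasoning?: string
  lastAssistantStopped: boolean
}

// Render only when: right thread, reasoning not already persisted on the
// last assistant message, and no stopped message occupying the last slot.
function shouldRenderStreaming(i: StreamGuardInput): boolean {
  if (!i.streamingThreadId || i.streamingThreadId !== i.threadId) return false
  if (i.streamingReasoning && i.streamingReasoning === i.lastAssistantReasoning)
    return false
  if (i.lastAssistantStopped) return false
  return true
}
```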
@@ -1,6 +1,30 @@
-import { renderHook, act } from '@testing-library/react'
-import { describe, it, expect, vi, beforeEach } from 'vitest'
-import { useChat } from '../useChat'
+import { renderHook, act, waitFor } from '@testing-library/react'
+import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
+import { MessageStatus, ContentType } from '@janhq/core'
+
+// Store mock functions for assertions - initialize immediately
+const mockAddMessage = vi.fn()
+const mockUpdateMessage = vi.fn()
+const mockGetMessages = vi.fn(() => [])
+const mockStartModel = vi.fn(() => Promise.resolve())
+const mockSendCompletion = vi.fn(() => Promise.resolve({
+  choices: [{
+    message: {
+      content: 'AI response',
+      role: 'assistant',
+    },
+  }],
+}))
+const mockPostMessageProcessing = vi.fn((toolCalls, builder, content) =>
+  Promise.resolve(content)
+)
+const mockCompletionMessagesBuilder = {
+  addUserMessage: vi.fn(),
+  addAssistantMessage: vi.fn(),
+  getMessages: vi.fn(() => []),
+}
+const mockSetPrompt = vi.fn()
+const mockResetTokenSpeed = vi.fn()

 // Mock dependencies
 vi.mock('../usePrompt', () => ({
@@ -8,11 +32,16 @@ vi.mock('../usePrompt', () => ({
     (selector: any) => {
       const state = {
         prompt: 'test prompt',
-        setPrompt: vi.fn(),
+        setPrompt: mockSetPrompt,
       }
       return selector ? selector(state) : state
     },
-    { getState: () => ({ prompt: 'test prompt', setPrompt: vi.fn() }) }
+    {
+      getState: () => ({
+        prompt: 'test prompt',
+        setPrompt: mockSetPrompt
+      })
+    }
   ),
 }))
@@ -22,39 +51,58 @@ vi.mock('../useAppState', () => ({
       const state = {
         tools: [],
         updateTokenSpeed: vi.fn(),
-        resetTokenSpeed: vi.fn(),
+        resetTokenSpeed: mockResetTokenSpeed,
         updateTools: vi.fn(),
        updateStreamingContent: vi.fn(),
        updatePromptProgress: vi.fn(),
        updateLoadingModel: vi.fn(),
        setAbortController: vi.fn(),
        streamingContent: undefined,
      }
      return selector ? selector(state) : state
    },
+    {
+      getState: vi.fn(() => ({
+        tools: [],
+        tokenSpeed: { tokensPerSecond: 10 },
+        streamingContent: undefined,
+      }))
+    }
   ),
 }))

 vi.mock('../useAssistant', () => ({
-  useAssistant: (selector: any) => {
-    const state = {
-      assistants: [{
-        id: 'test-assistant',
-        instructions: 'test instructions',
-        parameters: { stream: true },
-      }],
-      currentAssistant: {
-        id: 'test-assistant',
-        instructions: 'test instructions',
-        parameters: { stream: true },
-      },
-    }
-    return selector ? selector(state) : state
-  },
+  useAssistant: Object.assign(
+    (selector: any) => {
+      const state = {
+        assistants: [{
+          id: 'test-assistant',
+          instructions: 'test instructions',
+          parameters: { stream: true },
+        }],
+        currentAssistant: {
+          id: 'test-assistant',
+          instructions: 'test instructions',
+          parameters: { stream: true },
+        },
+      }
+      return selector ? selector(state) : state
+    },
+    {
+      getState: () => ({
+        assistants: [{
+          id: 'test-assistant',
+          instructions: 'test instructions',
+          parameters: { stream: true },
+        }],
+        currentAssistant: {
+          id: 'test-assistant',
+          instructions: 'test instructions',
+          parameters: { stream: true },
+        },
+      })
+    }
+  ),
 }))

 vi.mock('../useModelProvider', () => ({
@@ -62,14 +110,15 @@ vi.mock('../useModelProvider', () => ({
     (selector: any) => {
       const state = {
         getProviderByName: vi.fn(() => ({
-          provider: 'openai',
+          provider: 'llamacpp',
           models: [],
           settings: [],
         })),
         selectedModel: {
           id: 'test-model',
           capabilities: ['tools'],
         },
-        selectedProvider: 'openai',
+        selectedProvider: 'llamacpp',
         updateProvider: vi.fn(),
       }
       return selector ? selector(state) : state
@@ -77,14 +126,15 @@ vi.mock('../useModelProvider', () => ({
     {
       getState: () => ({
         getProviderByName: vi.fn(() => ({
-          provider: 'openai',
+          provider: 'llamacpp',
           models: [],
           settings: [],
         })),
         selectedModel: {
           id: 'test-model',
           capabilities: ['tools'],
         },
-        selectedProvider: 'openai',
+        selectedProvider: 'llamacpp',
         updateProvider: vi.fn(),
       })
     }
@@ -96,11 +146,11 @@ vi.mock('../useThreads', () => ({
     const state = {
       getCurrentThread: vi.fn(() => ({
         id: 'test-thread',
-        model: { id: 'test-model', provider: 'openai' },
+        model: { id: 'test-model', provider: 'llamacpp' },
       })),
       createThread: vi.fn(() => Promise.resolve({
         id: 'test-thread',
-        model: { id: 'test-model', provider: 'openai' },
+        model: { id: 'test-model', provider: 'llamacpp' },
       })),
       updateThreadTimestamp: vi.fn(),
     }
@@ -111,22 +161,33 @@ vi.mock('../useThreads', () => ({
 vi.mock('../useMessages', () => ({
   useMessages: (selector: any) => {
     const state = {
-      getMessages: vi.fn(() => []),
-      addMessage: vi.fn(),
+      getMessages: mockGetMessages,
+      addMessage: mockAddMessage,
+      updateMessage: mockUpdateMessage,
       setMessages: vi.fn(),
     }
     return selector ? selector(state) : state
   },
 }))

 vi.mock('../useToolApproval', () => ({
-  useToolApproval: (selector: any) => {
-    const state = {
-      approvedTools: [],
-      showApprovalModal: vi.fn(),
-      allowAllMCPPermissions: false,
-    }
-    return selector ? selector(state) : state
-  },
+  useToolApproval: Object.assign(
+    (selector: any) => {
+      const state = {
+        approvedTools: [],
+        showApprovalModal: vi.fn(),
+        allowAllMCPPermissions: false,
+      }
+      return selector ? selector(state) : state
+    },
+    {
+      getState: () => ({
+        approvedTools: [],
+        showApprovalModal: vi.fn(),
+        allowAllMCPPermissions: false,
+      })
+    }
+  ),
 }))

 vi.mock('../useToolAvailable', () => ({
@@ -162,38 +223,57 @@ vi.mock('@tanstack/react-router', () => ({
   })),
 }))

+vi.mock('../useServiceHub', () => ({
+  useServiceHub: vi.fn(() => ({
+    models: () => ({
+      startModel: mockStartModel,
+      stopModel: vi.fn(() => Promise.resolve()),
+      stopAllModels: vi.fn(() => Promise.resolve()),
+    }),
+    providers: () => ({
+      updateSettings: vi.fn(() => Promise.resolve()),
+    }),
+  })),
+}))
+
 vi.mock('@/lib/completion', () => ({
   emptyThreadContent: { thread_id: 'test-thread', content: '' },
   extractToolCall: vi.fn(),
-  newUserThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'user message' })),
-  newAssistantThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'assistant message' })),
-  sendCompletion: vi.fn(),
-  postMessageProcessing: vi.fn(),
-  isCompletionResponse: vi.fn(),
+  newUserThreadContent: vi.fn((threadId, content) => ({
+    thread_id: threadId,
+    content: [{ type: 'text', text: { value: content, annotations: [] } }],
+    role: 'user'
+  })),
+  newAssistantThreadContent: vi.fn((threadId, content) => ({
+    thread_id: threadId,
+    content: [{ type: 'text', text: { value: content, annotations: [] } }],
+    role: 'assistant'
+  })),
+  sendCompletion: mockSendCompletion,
+  postMessageProcessing: mockPostMessageProcessing,
+  isCompletionResponse: vi.fn(() => true),
 }))

 vi.mock('@/lib/messages', () => ({
-  CompletionMessagesBuilder: vi.fn(() => ({
-    addUserMessage: vi.fn(),
-    addAssistantMessage: vi.fn(),
-    getMessages: vi.fn(() => []),
-  })),
+  CompletionMessagesBuilder: vi.fn(() => mockCompletionMessagesBuilder),
 }))

 vi.mock('@/lib/instructionTemplate', () => ({
   renderInstructions: vi.fn((instructions: string) => instructions),
 }))

 vi.mock('@/utils/reasoning', () => ({
   ReasoningProcessor: vi.fn(() => ({
     processReasoningChunk: vi.fn(() => null),
     finalize: vi.fn(() => ''),
   })),
   extractReasoningFromMessage: vi.fn(() => null),
 }))

 vi.mock('@/services/mcp', () => ({
   getTools: vi.fn(() => Promise.resolve([])),
 }))

 vi.mock('@/services/models', () => ({
   startModel: vi.fn(() => Promise.resolve()),
   stopModel: vi.fn(() => Promise.resolve()),
   stopAllModels: vi.fn(() => Promise.resolve()),
 }))

 vi.mock('@/services/providers', () => ({
   updateSettings: vi.fn(() => Promise.resolve()),
 }))

 vi.mock('@tauri-apps/api/event', () => ({
   listen: vi.fn(() => Promise.resolve(vi.fn())),
 }))
@@ -204,9 +284,41 @@ vi.mock('sonner', () => ({
   },
 }))

+// Import after mocks to avoid hoisting issues
+const { useChat } = await import('../useChat')
+const completionLib = await import('@/lib/completion')
+const messagesLib = await import('@/lib/messages')
+
 describe('useChat', () => {
   beforeEach(() => {
+    // Clear mock call history
     vi.clearAllMocks()
+
+    // Reset mock implementations
+    mockAddMessage.mockClear()
+    mockUpdateMessage.mockClear()
+    mockGetMessages.mockReturnValue([])
+    mockStartModel.mockResolvedValue(undefined)
+    mockSetPrompt.mockClear()
+    mockResetTokenSpeed.mockClear()
+    mockSendCompletion.mockResolvedValue({
+      choices: [{
+        message: {
+          content: 'AI response',
+          role: 'assistant',
+        },
+      }],
+    })
+    mockPostMessageProcessing.mockImplementation((toolCalls, builder, content) =>
+      Promise.resolve(content)
+    )
+    mockCompletionMessagesBuilder.addUserMessage.mockClear()
+    mockCompletionMessagesBuilder.addAssistantMessage.mockClear()
+    mockCompletionMessagesBuilder.getMessages.mockReturnValue([])
   })

+  afterEach(() => {
+    vi.clearAllTimers()
+  })
+
   it('returns sendMessage function', () => {
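A detail worth noting in this setup: Vitest hoists `vi.mock` calls above imports, which is why the module under test is loaded with a top-level `await import(...)` after the mock declarations. A minimal sketch of the same pattern, assuming only that `@/lib/completion` exports a `sendCompletion` function:

```ts
import { vi, describe, it, expect } from 'vitest'

// vi.mock factories are hoisted to the top of the file, so they must not
// capture bindings that are only initialized later in module evaluation.
vi.mock('@/lib/completion', () => ({
  sendCompletion: vi.fn(() => Promise.resolve('mocked')),
}))

// Importing after the mocks guarantees the imported module sees them.
const completionLib = await import('@/lib/completion')

describe('hoisting-safe mocking', () => {
  it('uses the mocked implementation', async () => {
    await expect(
      (completionLib.sendCompletion as ReturnType<typeof vi.fn>)()
    ).resolves.toBe('mocked')
  })
})
```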
@@ -216,13 +328,270 @@ describe('useChat', () => {
     expect(typeof result.current).toBe('function')
   })

-  it('sends message successfully', async () => {
-    const { result } = renderHook(() => useChat())
-
-    await act(async () => {
-      await result.current('Hello world')
-    })
-
-    expect(result.current).toBeDefined()
-  })
+  describe('Continue with AI response functionality', () => {
+    it('should add new user message when troubleshooting is true and no continueFromMessageId', async () => {
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('Hello world', true, undefined, undefined, undefined)
+      })
+
+      expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
+        'test-thread',
+        'Hello world',
+        undefined
+      )
+      expect(mockAddMessage).toHaveBeenCalledWith(
+        expect.objectContaining({
+          thread_id: 'test-thread',
+          role: 'user',
+        })
+      )
+      expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
+        'Hello world',
+        undefined
+      )
+    })
+
+    it('should NOT add new user message when continueFromMessageId is provided', async () => {
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      expect(completionLib.newUserThreadContent).not.toHaveBeenCalled()
+      const userMessageCalls = mockAddMessage.mock.calls.filter(
+        (call: any) => call[0]?.role === 'user'
+      )
+      expect(userMessageCalls).toHaveLength(0)
+      expect(mockCompletionMessagesBuilder.addUserMessage).not.toHaveBeenCalled()
+    })
+
+    it('should add partial assistant message to builder when continuing', async () => {
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      // Should be called twice: once with partial message (line 517-521), once after completion (line 689)
+      const assistantCalls = mockCompletionMessagesBuilder.addAssistantMessage.mock.calls
+      expect(assistantCalls.length).toBeGreaterThanOrEqual(1)
+      // First call should be with the partial response content
+      expect(assistantCalls[0]).toEqual([
+        'Partial response',
+        undefined,
+        []
+      ])
+    })
+
+    it('should filter out stopped message from context when continuing', async () => {
+      const userMsg = {
+        id: 'msg-1',
+        thread_id: 'test-thread',
+        role: 'user',
+        content: [{ type: 'text', text: { value: 'Hello', annotations: [] } }],
+      }
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
+        status: MessageStatus.Stopped,
+      }
+      mockGetMessages.mockReturnValue([userMsg, stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      // The CompletionMessagesBuilder is called with filtered messages (line 507-512)
+      // The stopped message should be filtered out from the context
+      expect(messagesLib.CompletionMessagesBuilder).toHaveBeenCalled()
+      const builderCall = (messagesLib.CompletionMessagesBuilder as any).mock.calls[0]
+      expect(builderCall[0]).toEqual([userMsg]) // stopped message filtered out
+      expect(builderCall[1]).toEqual('test instructions')
+    })
+
+    it('should update existing message instead of adding new one when continuing', async () => {
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      // finalizeMessage is called at line 700-708, which should update the message
+      expect(mockUpdateMessage).toHaveBeenCalledWith(
+        expect.objectContaining({
+          id: 'msg-123',
+          status: MessageStatus.Ready,
+        })
+      )
+    })
+
+    it('should start with previous content when continuing', async () => {
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+
+      mockSendCompletion.mockResolvedValue({
+        choices: [{
+          message: {
+            content: ' continued',
+            role: 'assistant',
+          },
+        }],
+      })
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      // The accumulated text should contain the previous content plus new content
+      // accumulatedTextRef starts with 'Partial response' (line 490)
+      // Then gets ' continued' appended (line 585)
+      expect(mockUpdateMessage).toHaveBeenCalledWith(
+        expect.objectContaining({
+          id: 'msg-123',
+          content: expect.arrayContaining([
+            expect.objectContaining({
+              text: expect.objectContaining({
+                value: 'Partial response continued',
+              })
+            })
+          ])
+        })
+      )
+    })
+
+    it('should handle attachments correctly when not continuing', async () => {
+      const { result } = renderHook(() => useChat())
+      const attachments = [
+        {
+          name: 'test.png',
+          type: 'image/png',
+          size: 1024,
+          base64: 'base64data',
+          dataUrl: 'data:image/png;base64,base64data',
+        },
+      ]
+
+      await act(async () => {
+        await result.current('Message with attachment', true, attachments, undefined, undefined)
+      })
+
+      expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
+        'test-thread',
+        'Message with attachment',
+        attachments
+      )
+      expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
+        'Message with attachment',
+        attachments
+      )
+    })
+
+    it('should preserve message status as Ready after continuation completes', async () => {
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      // finalContent is created at line 678-683 with status Ready when continuing
+      expect(mockUpdateMessage).toHaveBeenCalledWith(
+        expect.objectContaining({
+          id: 'msg-123',
+          status: MessageStatus.Ready,
+        })
+      )
+    })
+  })
+
+  describe('Normal message sending', () => {
+    it('sends message successfully without continuation', async () => {
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('Hello world')
+      })
+
+      expect(mockSendCompletion).toHaveBeenCalled()
+      expect(mockStartModel).toHaveBeenCalled()
+    })
+  })
+
+  describe('Error handling', () => {
+    it('should handle errors gracefully during continuation', async () => {
+      mockSendCompletion.mockRejectedValueOnce(new Error('API Error'))
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      expect(result.current).toBeDefined()
+    })
+  })
 })
@@ -225,9 +225,25 @@ describe('useMessages', () => {
       })
     )

-    // Wait for async operation
+    // Message should be immediately available (optimistic update)
+    expect(result.current.messages['thread1']).toContainEqual(
+      expect.objectContaining({
+        id: messageToAdd.id,
+        thread_id: messageToAdd.thread_id,
+        role: messageToAdd.role,
+        content: messageToAdd.content,
+        metadata: expect.objectContaining({
+          assistant: expect.objectContaining({
+            id: expect.any(String),
+            name: expect.any(String),
+          }),
+        }),
+      })
+    )
+
+    // Verify persistence was attempted
     await vi.waitFor(() => {
-      expect(result.current.messages['thread1']).toContainEqual(mockCreatedMessage)
+      expect(mockCreateMessage).toHaveBeenCalled()
     })
   })
@@ -4,7 +4,7 @@ import { MCPTool } from '@/types/completion'
 import { useAssistant } from './useAssistant'
 import { ChatCompletionMessageToolCall } from 'openai/resources'

-type PromptProgress = {
+export type PromptProgress = {
   cache: number
   processed: number
   time_ms: number
@@ -3,7 +3,7 @@ import { flushSync } from 'react-dom'
 import { usePrompt } from './usePrompt'
 import { useModelProvider } from './useModelProvider'
 import { useThreads } from './useThreads'
-import { useAppState } from './useAppState'
+import { useAppState, type PromptProgress } from './useAppState'
 import { useMessages } from './useMessages'
 import { useRouter } from '@tanstack/react-router'
 import { defaultModel } from '@/lib/models'
@@ -23,6 +23,7 @@ import {
   ChatCompletionMessageToolCall,
   CompletionUsage,
 } from 'openai/resources'
+import { MessageStatus, ContentType, ThreadMessage } from '@janhq/core'

 import { useServiceHub } from '@/hooks/useServiceHub'
 import { useToolApproval } from '@/hooks/useToolApproval'
@@ -38,6 +39,198 @@ import { useAssistant } from './useAssistant'
 import { useShallow } from 'zustand/shallow'
 import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat'

+// Helper to create thread content with consistent structure
+const createThreadContent = (
+  threadId: string,
+  text: string,
+  toolCalls: ChatCompletionMessageToolCall[],
+  messageId?: string
+) => {
+  const content = newAssistantThreadContent(threadId, text, {
+    tool_calls: toolCalls.map((e) => ({
+      ...e,
+      state: 'pending',
+    })),
+  })
+  // If continuing from a message, preserve the message ID
+  if (messageId) {
+    return { ...content, id: messageId }
+  }
+  return content
+}
+
+// Helper to cancel animation frame cross-platform
+const cancelFrame = (handle: number | undefined) => {
+  if (handle === undefined) return
+  if (typeof cancelAnimationFrame !== 'undefined') {
+    cancelAnimationFrame(handle)
+  } else {
+    clearTimeout(handle)
+  }
+}
+
+// Helper to finalize and save a message
+const finalizeMessage = (
+  finalContent: ThreadMessage,
+  addMessage: (message: ThreadMessage) => void,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updatePromptProgress: (progress: PromptProgress | undefined) => void,
+  updateThreadTimestamp: (threadId: string) => void,
+  updateMessage?: (message: ThreadMessage) => void,
+  continueFromMessageId?: string
+) => {
+  // If continuing from a message, update it; otherwise add new message
+  if (continueFromMessageId && updateMessage) {
+    updateMessage({ ...finalContent, id: continueFromMessageId })
+  } else {
+    addMessage(finalContent)
+  }
+  updateStreamingContent(emptyThreadContent)
+  updatePromptProgress(undefined)
+  updateThreadTimestamp(finalContent.thread_id)
+}
+
+// Helper to process streaming completion
+const processStreamingCompletion = async (
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  completion: AsyncIterable<any>,
+  abortController: AbortController,
+  activeThread: Thread,
+  accumulatedText: { value: string },
+  toolCalls: ChatCompletionMessageToolCall[],
+  currentCall: ChatCompletionMessageToolCall | null,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
+  setTokenSpeed: (message: ThreadMessage, tokensPerSecond: number, totalTokens: number) => void,
+  updatePromptProgress: (progress: PromptProgress | undefined) => void,
+  timeToFirstToken: number,
+  tokenUsageRef: { current: CompletionUsage | undefined },
+  continueFromMessageId?: string,
+  updateMessage?: (message: ThreadMessage) => void,
+  continueFromMessage?: ThreadMessage
+) => {
+  // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
+  let rafScheduled = false
+  let rafHandle: number | undefined
+  let pendingDeltaCount = 0
+  const reasoningProcessor = new ReasoningProcessor()
+
+  const flushStreamingContent = () => {
+    const currentContent = createThreadContent(
+      activeThread.id,
+      accumulatedText.value,
+      toolCalls,
+      continueFromMessageId
+    )
+
+    // When continuing, update the message directly instead of using streamingContent
+    if (continueFromMessageId && updateMessage && continueFromMessage) {
+      updateMessage({
+        ...continueFromMessage, // Preserve original message metadata
+        content: currentContent.content, // Update content
+        status: MessageStatus.Stopped, // Keep as Stopped while streaming
+      })
+    } else {
+      updateStreamingContent(currentContent)
+    }
+
+    if (tokenUsageRef.current) {
+      setTokenSpeed(
+        currentContent,
+        tokenUsageRef.current.completion_tokens /
+          Math.max((Date.now() - timeToFirstToken) / 1000, 1),
+        tokenUsageRef.current.completion_tokens
+      )
+    } else if (pendingDeltaCount > 0) {
+      updateTokenSpeed(currentContent, pendingDeltaCount)
+    }
+    pendingDeltaCount = 0
+    rafScheduled = false
+  }
+
+  const scheduleFlush = () => {
+    if (rafScheduled || abortController.signal.aborted) return
+    rafScheduled = true
+    const doSchedule = (cb: () => void) => {
+      if (typeof requestAnimationFrame !== 'undefined') {
+        rafHandle = requestAnimationFrame(() => cb())
+      } else {
+        // Fallback for non-browser test environments
+        const t = setTimeout(() => cb(), 0) as unknown as number
+        rafHandle = t
+      }
+    }
+    doSchedule(() => {
+      // Check abort status before executing the scheduled callback
+      if (abortController.signal.aborted) {
+        rafScheduled = false
+        return
+      }
+      flushStreamingContent()
+    })
+  }
+
+  try {
+    for await (const part of completion) {
+      // Check if aborted before processing each part
+      if (abortController.signal.aborted) {
+        break
+      }
+
+      // Handle prompt progress if available
+      if ('prompt_progress' in part && part.prompt_progress) {
+        // Force immediate state update to ensure we see intermediate values
+        flushSync(() => {
+          updatePromptProgress(part.prompt_progress)
+        })
+        // Add a small delay to make progress visible
+        await new Promise((resolve) => setTimeout(resolve, 100))
+      }
+
+      // Error message
+      if (!part.choices) {
+        throw new Error(
+          'message' in part
+            ? (part.message as string)
+            : (JSON.stringify(part) ?? '')
+        )
+      }
+
+      if ('usage' in part && part.usage) {
+        tokenUsageRef.current = part.usage
+      }
+
+      if (part.choices[0]?.delta?.tool_calls) {
+        extractToolCall(part, currentCall, toolCalls)
+        // Schedule a flush to reflect tool update
+        scheduleFlush()
+      }
+      const deltaReasoning = reasoningProcessor.processReasoningChunk(part)
+      if (deltaReasoning) {
+        accumulatedText.value += deltaReasoning
+        pendingDeltaCount += 1
+        // Schedule flush for reasoning updates
+        scheduleFlush()
+      }
+      const deltaContent = part.choices[0]?.delta?.content || ''
+      if (deltaContent) {
+        accumulatedText.value += deltaContent
+        pendingDeltaCount += 1
+        // Batch UI update on next animation frame
+        scheduleFlush()
+      }
+    }
+  } finally {
+    // Always clean up scheduled RAF when stream ends (either normally or via abort)
+    cancelFrame(rafHandle)
+    rafHandle = undefined
+    rafScheduled = false
+
+    // Finalize reasoning (close any open think tags)
+    accumulatedText.value += reasoningProcessor.finalize()
+  }
+}
+
 export const useChat = () => {
   const [
     updateTokenSpeed,
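The extracted `processStreamingCompletion` batches per-token UI updates behind `requestAnimationFrame`, with a `setTimeout` fallback for non-browser test environments. The scheduler at its core, reduced to a reusable sketch (names are illustrative):

```ts
// Generic rAF-batched flusher: many producers call schedule(), but the
// flush callback runs at most once per animation frame.
function createFrameBatcher(flush: () => void) {
  let scheduled = false
  let handle: number | undefined

  const schedule = () => {
    if (scheduled) return
    scheduled = true
    const run = () => {
      scheduled = false
      flush()
    }
    handle =
      typeof requestAnimationFrame !== 'undefined'
        ? requestAnimationFrame(run)
        : (setTimeout(run, 0) as unknown as number)
  }

  const cancel = () => {
    if (handle === undefined) return
    if (typeof cancelAnimationFrame !== 'undefined') cancelAnimationFrame(handle)
    else clearTimeout(handle)
    scheduled = false
  }

  return { schedule, cancel }
}

// Usage: call schedule() for every streamed delta; the UI update then
// runs once per frame instead of once per token.
const batcher = createFrameBatcher(() => console.log('flush UI'))
batcher.schedule()
batcher.schedule() // coalesced into the same frame as the first call
```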
@@ -86,6 +279,7 @@ export const useChat = () => {

   const getMessages = useMessages((state) => state.getMessages)
   const addMessage = useMessages((state) => state.addMessage)
+  const updateMessage = useMessages((state) => state.updateMessage)
   const setMessages = useMessages((state) => state.setMessages)
   const setModelLoadError = useModelLoad((state) => state.setModelLoadError)
   const router = useRouter()
@@ -149,7 +343,7 @@ export const useChat = () => {
       })
     }
     return currentThread
-  }, [createThread, retrieveThread, router, setMessages])
+  }, [createThread, retrieveThread, router, setMessages, serviceHub])

   const restartModel = useCallback(
     async (provider: ProviderObject, modelId: string) => {
@@ -264,7 +458,8 @@ export const useChat = () => {
       base64: string
       dataUrl: string
     }>,
-    projectId?: string
+    projectId?: string,
+    continueFromMessageId?: string
   ) => {
     const activeThread = await getCurrentThread(projectId)
     const selectedProvider = useModelProvider.getState().selectedProvider
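With the added parameter, callers distinguish a fresh send from a continuation by passing the ID of the stopped assistant message in the last position. A hedged sketch of the assumed call shape; the `demo` wrapper and the 'msg-123' value are illustrative:

```ts
// Assumed shape of the function returned by useChat after this change.
type SendMessage = (
  message: string,
  troubleshooting?: boolean,
  attachments?: Array<{
    name: string
    type: string
    size: number
    base64: string
    dataUrl: string
  }>,
  projectId?: string,
  continueFromMessageId?: string
) => Promise<void>

async function demo(sendMessage: SendMessage) {
  // Fresh message: no continuation ID.
  await sendMessage('Hello world', true)

  // Continue a stopped response: empty prompt plus the ID of the
  // partial assistant message.
  await sendMessage('', true, undefined, undefined, 'msg-123')
}
```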
@@ -277,26 +472,54 @@ export const useChat = () => {
     setAbortController(activeThread.id, abortController)
     updateStreamingContent(emptyThreadContent)
     updatePromptProgress(undefined)
-    // Do not add new message on retry
-    if (troubleshooting)
+
+    // Find the message to continue from if provided
+    const continueFromMessage = continueFromMessageId
+      ? messages.find((m) => m.id === continueFromMessageId)
+      : undefined
+
+    // Do not add new message on retry or when continuing
+    if (troubleshooting && !continueFromMessageId)
       addMessage(newUserThreadContent(activeThread.id, message, attachments))
     updateThreadTimestamp(activeThread.id)
     usePrompt.getState().setPrompt('')
     const selectedModel = useModelProvider.getState().selectedModel
+
+    // If continuing, start with the previous content
+    const accumulatedTextRef = {
+      value: continueFromMessage?.content?.[0]?.text?.value || ''
+    }
+    let currentAssistant: Assistant | undefined | null
+
     try {
       if (selectedModel?.id) {
         updateLoadingModel(true)
         await serviceHub.models().startModel(activeProvider, selectedModel.id)
         updateLoadingModel(false)
       }
-      const currentAssistant = useAssistant.getState().currentAssistant
+      currentAssistant = useAssistant.getState().currentAssistant
+
+      // Filter out the stopped message from context if continuing
+      const contextMessages = continueFromMessageId
+        ? messages.filter((m) => m.id !== continueFromMessageId)
+        : messages
+
       const builder = new CompletionMessagesBuilder(
-        messages,
+        contextMessages,
         currentAssistant
           ? renderInstructions(currentAssistant.instructions)
           : undefined
       )
-      if (troubleshooting) builder.addUserMessage(message, attachments)
+      if (troubleshooting && !continueFromMessageId) {
+        builder.addUserMessage(message, attachments)
+      } else if (continueFromMessage) {
+        // When continuing, add the partial assistant response to the context
+        builder.addAssistantMessage(
+          continueFromMessage.content?.[0]?.text?.value || '',
+          undefined,
+          []
+        )
+      }

       let isCompleted = false
@@ -349,7 +572,6 @@ export const useChat = () => {
       )

       if (!completion) throw new Error('No completion received')
-      let accumulatedText = ''
       const currentCall: ChatCompletionMessageToolCall | null = null
       const toolCalls: ChatCompletionMessageToolCall[] = []
       const timeToFirstToken = Date.now()
@@ -357,13 +579,19 @@ export const useChat = () => {
       try {
         if (isCompletionResponse(completion)) {
           const message = completion.choices[0]?.message
-          accumulatedText = (message?.content as string) || ''
+          // When continuing, append to existing content; otherwise replace
+          const newContent = (message?.content as string) || ''
+          if (continueFromMessageId && accumulatedTextRef.value) {
+            accumulatedTextRef.value += newContent
+          } else {
+            accumulatedTextRef.value = newContent
+          }

           // Handle reasoning field if there is one
           const reasoning = extractReasoningFromMessage(message)
           if (reasoning) {
-            accumulatedText =
-              `<think>${reasoning}</think>` + accumulatedText
+            accumulatedTextRef.value =
+              `<think>${reasoning}</think>` + accumulatedTextRef.value
           }

           if (message?.tool_calls) {
@@ -373,161 +601,25 @@ export const useChat = () => {
             tokenUsage = completion.usage
           }
         } else {
-          // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
-          let rafScheduled = false
-          let rafHandle: number | undefined
-          let pendingDeltaCount = 0
-          const reasoningProcessor = new ReasoningProcessor()
-          const scheduleFlush = () => {
-            if (rafScheduled || abortController.signal.aborted) return
-            rafScheduled = true
-            const doSchedule = (cb: () => void) => {
-              if (typeof requestAnimationFrame !== 'undefined') {
-                rafHandle = requestAnimationFrame(() => cb())
-              } else {
-                // Fallback for non-browser test environments
-                const t = setTimeout(() => cb(), 0) as unknown as number
-                rafHandle = t
-              }
-            }
-            doSchedule(() => {
-              // Check abort status before executing the scheduled callback
-              if (abortController.signal.aborted) {
-                rafScheduled = false
-                return
-              }
-
-              const currentContent = newAssistantThreadContent(
-                activeThread.id,
-                accumulatedText,
-                {
-                  tool_calls: toolCalls.map((e) => ({
-                    ...e,
-                    state: 'pending',
-                  })),
-                }
-              )
-              updateStreamingContent(currentContent)
-              if (tokenUsage) {
-                setTokenSpeed(
-                  currentContent,
-                  tokenUsage.completion_tokens /
-                    Math.max((Date.now() - timeToFirstToken) / 1000, 1),
-                  tokenUsage.completion_tokens
-                )
-              } else if (pendingDeltaCount > 0) {
-                updateTokenSpeed(currentContent, pendingDeltaCount)
-              }
-              pendingDeltaCount = 0
-              rafScheduled = false
-            })
-          }
-          const flushIfPending = () => {
-            if (!rafScheduled) return
-            if (
-              typeof cancelAnimationFrame !== 'undefined' &&
-              rafHandle !== undefined
-            ) {
-              cancelAnimationFrame(rafHandle)
-            } else if (rafHandle !== undefined) {
-              clearTimeout(rafHandle)
-            }
-            // Do an immediate flush
-            const currentContent = newAssistantThreadContent(
-              activeThread.id,
-              accumulatedText,
-              {
-                tool_calls: toolCalls.map((e) => ({
-                  ...e,
-                  state: 'pending',
-                })),
-              }
-            )
-            updateStreamingContent(currentContent)
-            if (tokenUsage) {
-              setTokenSpeed(
-                currentContent,
-                tokenUsage.completion_tokens /
-                  Math.max((Date.now() - timeToFirstToken) / 1000, 1),
-                tokenUsage.completion_tokens
-              )
-            } else if (pendingDeltaCount > 0) {
-              updateTokenSpeed(currentContent, pendingDeltaCount)
-            }
-            pendingDeltaCount = 0
-            rafScheduled = false
-          }
-          try {
-            for await (const part of completion) {
-              // Check if aborted before processing each part
-              if (abortController.signal.aborted) {
-                break
-              }
-
-              // Handle prompt progress if available
-              if ('prompt_progress' in part && part.prompt_progress) {
-                // Force immediate state update to ensure we see intermediate values
-                flushSync(() => {
-                  updatePromptProgress(part.prompt_progress)
-                })
-                // Add a small delay to make progress visible
-                await new Promise((resolve) => setTimeout(resolve, 100))
-              }
-
-              // Error message
-              if (!part.choices) {
-                throw new Error(
-                  'message' in part
-                    ? (part.message as string)
-                    : (JSON.stringify(part) ?? '')
-                )
-              }
-
-              if ('usage' in part && part.usage) {
-                tokenUsage = part.usage
-              }
-
-              if (part.choices[0]?.delta?.tool_calls) {
-                extractToolCall(part, currentCall, toolCalls)
-                // Schedule a flush to reflect tool update
-                scheduleFlush()
-              }
-              const deltaReasoning =
-                reasoningProcessor.processReasoningChunk(part)
-              if (deltaReasoning) {
-                accumulatedText += deltaReasoning
-                pendingDeltaCount += 1
-                // Schedule flush for reasoning updates
-                scheduleFlush()
-              }
-              const deltaContent = part.choices[0]?.delta?.content || ''
-              if (deltaContent) {
-                accumulatedText += deltaContent
-                pendingDeltaCount += 1
-                // Batch UI update on next animation frame
-                scheduleFlush()
-              }
-            }
-          } finally {
-            // Always clean up scheduled RAF when stream ends (either normally or via abort)
-            if (rafHandle !== undefined) {
-              if (typeof cancelAnimationFrame !== 'undefined') {
-                cancelAnimationFrame(rafHandle)
-              } else {
-                clearTimeout(rafHandle)
-              }
-              rafHandle = undefined
-              rafScheduled = false
-            }
-
-            // Only finalize and flush if not aborted
-            if (!abortController.signal.aborted) {
-              // Finalize reasoning (close any open think tags)
-              accumulatedText += reasoningProcessor.finalize()
-              // Ensure any pending buffered content is rendered at the end
-              flushIfPending()
-            }
-          }
+          const tokenUsageRef = { current: tokenUsage }
+          await processStreamingCompletion(
+            completion,
+            abortController,
+            activeThread,
+            accumulatedTextRef,
+            toolCalls,
+            currentCall,
+            updateStreamingContent,
+            updateTokenSpeed,
+            setTokenSpeed,
+            updatePromptProgress,
+            timeToFirstToken,
+            tokenUsageRef,
+            continueFromMessageId,
+            updateMessage,
+            continueFromMessage
+          )
+          tokenUsage = tokenUsageRef.current
         }
       } catch (error) {
         const errorMessage =
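One mechanical point behind this refactor: the accumulated text moves from a local `let accumulatedText = ''` to a `{ value: string }` wrapper (`accumulatedTextRef`) so the extracted helper can append to the same buffer the caller later reads. A minimal illustration of why the wrapper object is needed:

```ts
// Strings are passed by value: the helper's appends would be lost.
function appendWrong(text: string) {
  text += ' world' // mutates only the local parameter
}

// Wrapping in an object shares one mutable cell across the boundary.
function appendRight(ref: { value: string }) {
  ref.value += ' world'
}

const plain = 'hello'
appendWrong(plain)
console.log(plain) // 'hello' - unchanged

const boxed = { value: 'hello' }
appendRight(boxed)
console.log(boxed.value) // 'hello world'
```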
@@ -561,7 +653,7 @@ export const useChat = () => {
       }
       // TODO: Remove this check when integrating new llama.cpp extension
       if (
-        accumulatedText.length === 0 &&
+        accumulatedTextRef.value.length === 0 &&
         toolCalls.length === 0 &&
         activeThread.model?.id &&
         activeProvider?.provider === 'llamacpp'
@@ -573,16 +665,29 @@ export const useChat = () => {
       }

       // Create a final content object for adding to the thread
-      const finalContent = newAssistantThreadContent(
+      let finalContent = newAssistantThreadContent(
         activeThread.id,
-        accumulatedText,
+        accumulatedTextRef.value,
         {
           tokenSpeed: useAppState.getState().tokenSpeed,
           assistant: currentAssistant,
         }
       )

-      builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
+      // If continuing from a message, preserve the ID and set status to Ready
+      if (continueFromMessageId) {
+        finalContent = {
+          ...finalContent,
+          id: continueFromMessageId,
+          status: MessageStatus.Ready,
+        }
+      }
+
+      // Normal completion flow (abort is handled after loop exits)
+      // Don't add assistant message to builder if continuing - it's already there
+      if (!continueFromMessageId) {
+        builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls)
+      }
       const updatedMessage = await postMessageProcessing(
         toolCalls,
         builder,
@@ -592,10 +697,15 @@ export const useChat = () => {
         allowAllMCPPermissions ? undefined : showApprovalModal,
         allowAllMCPPermissions
       )
-      addMessage(updatedMessage ?? finalContent)
-      updateStreamingContent(emptyThreadContent)
-      updatePromptProgress(undefined)
-      updateThreadTimestamp(activeThread.id)
+      finalizeMessage(
+        updatedMessage ?? finalContent,
+        addMessage,
+        updateStreamingContent,
+        updatePromptProgress,
+        updateThreadTimestamp,
+        updateMessage,
+        continueFromMessageId
+      )

       isCompleted = !toolCalls.length
       // Do not create agent loop if there is no need for it
@@ -605,8 +715,109 @@ export const useChat = () => {
           availableTools = []
         }
       }
+
+      // IMPORTANT: Check if aborted AFTER the while loop exits
+      // The while loop exits when abort is true, so we handle it here
+      // Only save interrupted messages for llamacpp provider
+      // Other providers (OpenAI, Claude, etc.) handle streaming differently
+      if (
+        abortController.signal.aborted &&
+        accumulatedTextRef.value.length > 0 &&
+        activeProvider?.provider === 'llamacpp'
+      ) {
+        // If continuing, update the existing message; otherwise add new
+        if (continueFromMessageId && continueFromMessage) {
+          // Preserve the original message metadata
+          updateMessage({
+            ...continueFromMessage,
+            content: [
+              {
+                type: ContentType.Text,
+                text: {
+                  value: accumulatedTextRef.value,
+                  annotations: [],
+                },
+              },
+            ],
+            status: MessageStatus.Stopped,
+            metadata: {
+              ...continueFromMessage.metadata,
+              tokenSpeed: useAppState.getState().tokenSpeed,
+              assistant: currentAssistant,
+            },
+          })
+        } else {
+          // Create final content for the partial message with Stopped status
+          const partialContent = {
+            ...newAssistantThreadContent(
+              activeThread.id,
+              accumulatedTextRef.value,
+              {
+                tokenSpeed: useAppState.getState().tokenSpeed,
+                assistant: currentAssistant,
+              }
+            ),
+            status: MessageStatus.Stopped,
+          }
+          addMessage(partialContent)
+        }
+        updatePromptProgress(undefined)
+        updateThreadTimestamp(activeThread.id)
+      }
     } catch (error) {
-      if (!abortController.signal.aborted) {
+      // If aborted, save the partial message even though an error occurred
+      // Only save for llamacpp provider - other providers handle streaming differently
+      const streamingContent = useAppState.getState().streamingContent
+      const hasPartialContent = accumulatedTextRef.value.length > 0 ||
+        (streamingContent && streamingContent.content?.[0]?.text?.value)
+
+      if (
+        abortController.signal.aborted &&
+        hasPartialContent &&
+        activeProvider?.provider === 'llamacpp'
+      ) {
+        // Use streaming content if available, otherwise use accumulatedTextRef
+        const contentText = streamingContent?.content?.[0]?.text?.value || accumulatedTextRef.value
+
+        // If continuing, update the existing message; otherwise add new
+        if (continueFromMessageId && continueFromMessage) {
+          // Preserve the original message metadata
+          updateMessage({
+            ...continueFromMessage,
+            content: [
+              {
+                type: ContentType.Text,
+                text: {
+                  value: contentText,
+                  annotations: [],
+                },
+              },
+            ],
+            status: MessageStatus.Stopped,
+            metadata: {
+              ...continueFromMessage.metadata,
+              tokenSpeed: useAppState.getState().tokenSpeed,
+              assistant: currentAssistant,
+            },
+          })
+        } else {
+          const partialContent = {
+            ...newAssistantThreadContent(
+              activeThread.id,
+              contentText,
+              {
+                tokenSpeed: useAppState.getState().tokenSpeed,
+                assistant: currentAssistant,
+              }
+            ),
+            status: MessageStatus.Stopped,
+          }
+          addMessage(partialContent)
+        }
+        updatePromptProgress(undefined)
+        updateThreadTimestamp(activeThread.id)
+      } else if (!abortController.signal.aborted) {
         // Only show error if not aborted
         if (error && typeof error === 'object' && 'message' in error) {
           setModelLoadError(error as ErrorObject)
         } else {
@@ -628,12 +839,14 @@ export const useChat = () => {
     updateStreamingContent,
     updatePromptProgress,
     addMessage,
+    updateMessage,
     updateThreadTimestamp,
     updateLoadingModel,
     getDisabledToolsForThread,
     allowAllMCPPermissions,
     showApprovalModal,
     updateTokenSpeed,
+    setTokenSpeed,
     showIncreaseContextSizeModal,
     increaseModelContextSize,
     toggleOnContextShifting,
@@ -8,6 +8,7 @@ type MessageState = {
   getMessages: (threadId: string) => ThreadMessage[]
   setMessages: (threadId: string, messages: ThreadMessage[]) => void
   addMessage: (message: ThreadMessage) => void
+  updateMessage: (message: ThreadMessage) => void
   deleteMessage: (threadId: string, messageId: string) => void
   clearAllMessages: () => void
 }
@@ -40,16 +41,52 @@ export const useMessages = create<MessageState>()((set, get) => ({
         assistant: selectedAssistant,
       },
     }
-    getServiceHub().messages().createMessage(newMessage).then((createdMessage) => {
-      set((state) => ({
-        messages: {
-          ...state.messages,
-          [message.thread_id]: [
-            ...(state.messages[message.thread_id] || []),
-            createdMessage,
-          ],
-        },
-      }))
-    })
+
+    // Optimistically update state immediately for instant UI feedback
+    set((state) => ({
+      messages: {
+        ...state.messages,
+        [message.thread_id]: [
+          ...(state.messages[message.thread_id] || []),
+          newMessage,
+        ],
+      },
+    }))
+
+    // Persist to storage asynchronously
+    getServiceHub().messages().createMessage(newMessage).catch((error) => {
+      console.error('Failed to persist message:', error)
+    })
   },
+  updateMessage: (message) => {
+    const assistants = useAssistant.getState().assistants
+    const currentAssistant = useAssistant.getState().currentAssistant
+
+    const selectedAssistant =
+      assistants.find((a) => a.id === currentAssistant?.id) || assistants[0]
+
+    const updatedMessage = {
+      ...message,
+      metadata: {
+        ...message.metadata,
+        assistant: selectedAssistant,
+      },
+    }
+
+    // Optimistically update state immediately for instant UI feedback
+    set((state) => ({
+      messages: {
+        ...state.messages,
+        [message.thread_id]: (state.messages[message.thread_id] || []).map((m) =>
+          m.id === message.id ? updatedMessage : m
+        ),
+      },
+    }))
+
+    // Persist to storage asynchronously using modifyMessage instead of createMessage
+    // to prevent duplicates when updating existing messages
+    getServiceHub().messages().modifyMessage(updatedMessage).catch((error) => {
+      console.error('Failed to persist message update:', error)
+    })
+  },
   deleteMessage: (threadId, messageId) => {
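Both store actions now follow the same optimistic shape: update local state synchronously, then persist in the background and only log on failure. A store-agnostic sketch of that pattern; `persistMessage` is a stand-in for the service hub call, not a real API:

```ts
type Message = { id: string; thread_id: string }

// Stand-in for the async persistence layer (assumed, not the real API).
declare function persistMessage(m: Message): Promise<Message>

function makeAddMessage(
  setState: (
    updater: (prev: Record<string, Message[]>) => Record<string, Message[]>
  ) => void
) {
  return (message: Message) => {
    // 1) Optimistic: the UI sees the message immediately.
    setState((prev) => ({
      ...prev,
      [message.thread_id]: [...(prev[message.thread_id] ?? []), message],
    }))
    // 2) Fire-and-forget persistence; failures are logged, not thrown,
    // so a storage error never rolls back the visible chat.
    persistMessage(message).catch((error) =>
      console.error('Failed to persist message:', error)
    )
  }
}
```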
@@ -23,8 +23,9 @@ export const useTools = () => {
       updateTools(data)

       // Initialize default disabled tools for new users (only once)
-      if (!isDefaultsInitialized() && data.length > 0 && mcpExtension?.getDefaultDisabledTools) {
-        const defaultDisabled = await mcpExtension.getDefaultDisabledTools()
+      const mcpExt = mcpExtension as MCPExtension & { getDefaultDisabledTools?: () => Promise<string[]> }
+      if (!isDefaultsInitialized() && data.length > 0 && mcpExt?.getDefaultDisabledTools) {
+        const defaultDisabled = await mcpExt.getDefaultDisabledTools()
         if (defaultDisabled.length > 0) {
           setDefaultDisabledTools(defaultDisabled)
           markDefaultsAsInitialized()
@@ -135,6 +135,7 @@
   "enterApiKey": "Enter API Key",
   "scrollToBottom": "Scroll to bottom",
   "generateAiResponse": "Generate AI Response",
+  "continueAiResponse": "Continue with AI Response",
   "addModel": {
     "title": "Add Model",
     "modelId": "Model ID",
@@ -40,6 +40,20 @@ export class DefaultMessagesService implements MessagesService {
     )
   }

+  async modifyMessage(message: ThreadMessage): Promise<ThreadMessage> {
+    // Don't modify messages on server for temporary chat - it's local only
+    if (message.thread_id === TEMPORARY_CHAT_ID) {
+      return message
+    }
+
+    return (
+      ExtensionManager.getInstance()
+        .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
+        ?.modifyMessage(message)
+        ?.catch(() => message) ?? message
+    )
+  }
+
   async deleteMessage(threadId: string, messageId: string): Promise<void> {
     // Don't delete messages on server for temporary chat - it's local only
     if (threadId === TEMPORARY_CHAT_ID) {
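The service deliberately degrades to the input message when the conversational extension is missing or its call rejects, so callers never need a try/catch. The same fallback chain in isolation, with simplified types:

```ts
type ThreadMessageLike = { id: string; thread_id: string }

interface ConversationalLike {
  modifyMessage(m: ThreadMessageLike): Promise<ThreadMessageLike>
}

// ?. short-circuits when the extension is absent; .catch swallows a
// failed call; ?? supplies the original message as the final fallback.
async function safeModify(
  ext: ConversationalLike | undefined,
  message: ThreadMessageLike
): Promise<ThreadMessageLike> {
  return (await ext?.modifyMessage(message)?.catch(() => message)) ?? message
}
```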
@@ -7,5 +7,6 @@ import { ThreadMessage } from '@janhq/core'
 export interface MessagesService {
   fetchMessages(threadId: string): Promise<ThreadMessage[]>
   createMessage(message: ThreadMessage): Promise<ThreadMessage>
+  modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
   deleteMessage(threadId: string, messageId: string): Promise<void>
 }