Compare commits
12 Commits
dev
...
feat/retai
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
35264e9a22 | ||
|
|
e7c9275488 | ||
|
|
4ac45aba23 | ||
|
|
34036d895a | ||
|
|
7127ff1244 | ||
|
|
1c0e135077 | ||
|
|
99473ed568 | ||
|
|
52f73af08c | ||
|
|
ccca331d6c | ||
|
|
f4b187ba11 | ||
|
|
4ea9d296ea | ||
|
|
2e86d4e421 |
@ -20,6 +20,13 @@ export interface MessageInterface {
|
|||||||
*/
|
*/
|
||||||
listMessages(threadId: string): Promise<ThreadMessage[]>
|
listMessages(threadId: string): Promise<ThreadMessage[]>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Updates an existing message in a thread.
|
||||||
|
* @param {ThreadMessage} message - The message to be updated (must have existing ID).
|
||||||
|
* @returns {Promise<ThreadMessage>} A promise that resolves to the updated message.
|
||||||
|
*/
|
||||||
|
modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deletes a specific message from a thread.
|
* Deletes a specific message from a thread.
|
||||||
* @param {string} threadId - The ID of the thread from which the message will be deleted.
|
* @param {string} threadId - The ID of the thread from which the message will be deleted.
|
||||||
|
|||||||
@ -176,7 +176,6 @@ const ChatInput = ({
|
|||||||
const mcpExtension = extensionManager.get<MCPExtension>(ExtensionTypeEnum.MCP)
|
const mcpExtension = extensionManager.get<MCPExtension>(ExtensionTypeEnum.MCP)
|
||||||
const MCPToolComponent = mcpExtension?.getToolComponent?.()
|
const MCPToolComponent = mcpExtension?.getToolComponent?.()
|
||||||
|
|
||||||
|
|
||||||
const handleSendMesage = async (prompt: string) => {
|
const handleSendMesage = async (prompt: string) => {
|
||||||
if (!selectedModel) {
|
if (!selectedModel) {
|
||||||
setMessage('Please select a model to start chatting.')
|
setMessage('Please select a model to start chatting.')
|
||||||
|
|||||||
@ -3,6 +3,8 @@ import { useMessages } from '@/hooks/useMessages'
|
|||||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||||
import { Play } from 'lucide-react'
|
import { Play } from 'lucide-react'
|
||||||
import { useShallow } from 'zustand/react/shallow'
|
import { useShallow } from 'zustand/react/shallow'
|
||||||
|
import { useMemo } from 'react'
|
||||||
|
import { MessageStatus } from '@janhq/core'
|
||||||
|
|
||||||
export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
|
export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
@ -13,7 +15,36 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
|
|||||||
}))
|
}))
|
||||||
)
|
)
|
||||||
const sendMessage = useChat()
|
const sendMessage = useChat()
|
||||||
|
|
||||||
|
// Detect if last message is a partial assistant response (user stopped midway)
|
||||||
|
const isPartialResponse = useMemo(() => {
|
||||||
|
if (!messages || messages.length < 2) return false
|
||||||
|
const lastMessage = messages[messages.length - 1]
|
||||||
|
const secondLastMessage = messages[messages.length - 2]
|
||||||
|
|
||||||
|
return (
|
||||||
|
lastMessage?.role === 'assistant' &&
|
||||||
|
lastMessage?.status === MessageStatus.Stopped &&
|
||||||
|
secondLastMessage?.role === 'user' &&
|
||||||
|
!lastMessage?.metadata?.tool_calls
|
||||||
|
)
|
||||||
|
}, [messages])
|
||||||
|
|
||||||
const generateAIResponse = () => {
|
const generateAIResponse = () => {
|
||||||
|
if (isPartialResponse) {
|
||||||
|
const partialMessage = messages[messages.length - 1]
|
||||||
|
const userMessage = messages[messages.length - 2]
|
||||||
|
if (userMessage?.content?.[0]?.text?.value) {
|
||||||
|
sendMessage(
|
||||||
|
userMessage.content[0].text.value,
|
||||||
|
false,
|
||||||
|
undefined,
|
||||||
|
partialMessage.id
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
const latestUserMessage = messages[messages.length - 1]
|
const latestUserMessage = messages[messages.length - 1]
|
||||||
if (
|
if (
|
||||||
latestUserMessage?.content?.[0]?.text?.value &&
|
latestUserMessage?.content?.[0]?.text?.value &&
|
||||||
@ -39,7 +70,11 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
|
|||||||
className="mx-2 bg-main-view-fg/10 px-2 border border-main-view-fg/5 flex items-center justify-center rounded-xl gap-x-2 cursor-pointer pointer-events-auto"
|
className="mx-2 bg-main-view-fg/10 px-2 border border-main-view-fg/5 flex items-center justify-center rounded-xl gap-x-2 cursor-pointer pointer-events-auto"
|
||||||
onClick={generateAIResponse}
|
onClick={generateAIResponse}
|
||||||
>
|
>
|
||||||
<p className="text-xs">{t('common:generateAiResponse')}</p>
|
<p className="text-xs">
|
||||||
|
{isPartialResponse
|
||||||
|
? t('common:continueAiResponse')
|
||||||
|
: t('common:generateAiResponse')}
|
||||||
|
</p>
|
||||||
<Play size={12} />
|
<Play size={12} />
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import { cn } from '@/lib/utils'
|
|||||||
import { ArrowDown } from 'lucide-react'
|
import { ArrowDown } from 'lucide-react'
|
||||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||||
import { useAppState } from '@/hooks/useAppState'
|
import { useAppState } from '@/hooks/useAppState'
|
||||||
|
import { MessageStatus } from '@janhq/core'
|
||||||
|
|
||||||
const ScrollToBottom = ({
|
const ScrollToBottom = ({
|
||||||
threadId,
|
threadId,
|
||||||
@ -28,11 +29,20 @@ const ScrollToBottom = ({
|
|||||||
|
|
||||||
const streamingContent = useAppState((state) => state.streamingContent)
|
const streamingContent = useAppState((state) => state.streamingContent)
|
||||||
|
|
||||||
|
// Check if last message is a partial assistant response and show continue buton (user interrupted)
|
||||||
|
const isPartialResponse =
|
||||||
|
messages.length >= 2 &&
|
||||||
|
messages[messages.length - 1]?.role === 'assistant' &&
|
||||||
|
messages[messages.length - 1]?.status === MessageStatus.Stopped &&
|
||||||
|
messages[messages.length - 2]?.role === 'user' &&
|
||||||
|
!messages[messages.length - 1]?.metadata?.tool_calls
|
||||||
|
|
||||||
const showGenerateAIResponseBtn =
|
const showGenerateAIResponseBtn =
|
||||||
(messages[messages.length - 1]?.role === 'user' ||
|
((messages[messages.length - 1]?.role === 'user' ||
|
||||||
(messages[messages.length - 1]?.metadata &&
|
(messages[messages.length - 1]?.metadata &&
|
||||||
'tool_calls' in (messages[messages.length - 1].metadata ?? {}))) &&
|
'tool_calls' in (messages[messages.length - 1].metadata ?? {})) ||
|
||||||
!streamingContent
|
isPartialResponse) &&
|
||||||
|
!streamingContent)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import { useAppState } from '@/hooks/useAppState'
|
|||||||
import { ThreadContent } from './ThreadContent'
|
import { ThreadContent } from './ThreadContent'
|
||||||
import { memo, useMemo } from 'react'
|
import { memo, useMemo } from 'react'
|
||||||
import { useMessages } from '@/hooks/useMessages'
|
import { useMessages } from '@/hooks/useMessages'
|
||||||
|
import { MessageStatus } from '@janhq/core'
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
threadId: string
|
threadId: string
|
||||||
@ -48,12 +49,19 @@ export const StreamingContent = memo(({ threadId }: Props) => {
|
|||||||
return extractReasoningSegment(text)
|
return extractReasoningSegment(text)
|
||||||
}, [lastAssistant])
|
}, [lastAssistant])
|
||||||
|
|
||||||
if (!streamingContent || streamingContent.thread_id !== threadId) return null
|
if (!streamingContent || streamingContent.thread_id !== threadId) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
if (streamingReasoning && streamingReasoning === lastAssistantReasoning) {
|
if (streamingReasoning && streamingReasoning === lastAssistantReasoning) {
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Don't show streaming content if there's already a stopped message
|
||||||
|
if (lastAssistant?.status === MessageStatus.Stopped) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
// Pass a new object to ThreadContent to avoid reference issues
|
// Pass a new object to ThreadContent to avoid reference issues
|
||||||
// The streaming content is always the last message
|
// The streaming content is always the last message
|
||||||
return (
|
return (
|
||||||
|
|||||||
@ -1,6 +1,30 @@
|
|||||||
import { renderHook, act } from '@testing-library/react'
|
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||||
import { useChat } from '../useChat'
|
import { MessageStatus, ContentType } from '@janhq/core'
|
||||||
|
|
||||||
|
// Store mock functions for assertions - initialize immediately
|
||||||
|
const mockAddMessage = vi.fn()
|
||||||
|
const mockUpdateMessage = vi.fn()
|
||||||
|
const mockGetMessages = vi.fn(() => [])
|
||||||
|
const mockStartModel = vi.fn(() => Promise.resolve())
|
||||||
|
const mockSendCompletion = vi.fn(() => Promise.resolve({
|
||||||
|
choices: [{
|
||||||
|
message: {
|
||||||
|
content: 'AI response',
|
||||||
|
role: 'assistant',
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
}))
|
||||||
|
const mockPostMessageProcessing = vi.fn((toolCalls, builder, content) =>
|
||||||
|
Promise.resolve(content)
|
||||||
|
)
|
||||||
|
const mockCompletionMessagesBuilder = {
|
||||||
|
addUserMessage: vi.fn(),
|
||||||
|
addAssistantMessage: vi.fn(),
|
||||||
|
getMessages: vi.fn(() => []),
|
||||||
|
}
|
||||||
|
const mockSetPrompt = vi.fn()
|
||||||
|
const mockResetTokenSpeed = vi.fn()
|
||||||
|
|
||||||
// Mock dependencies
|
// Mock dependencies
|
||||||
vi.mock('../usePrompt', () => ({
|
vi.mock('../usePrompt', () => ({
|
||||||
@ -8,11 +32,16 @@ vi.mock('../usePrompt', () => ({
|
|||||||
(selector: any) => {
|
(selector: any) => {
|
||||||
const state = {
|
const state = {
|
||||||
prompt: 'test prompt',
|
prompt: 'test prompt',
|
||||||
setPrompt: vi.fn(),
|
setPrompt: mockSetPrompt,
|
||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
return selector ? selector(state) : state
|
||||||
},
|
},
|
||||||
{ getState: () => ({ prompt: 'test prompt', setPrompt: vi.fn() }) }
|
{
|
||||||
|
getState: () => ({
|
||||||
|
prompt: 'test prompt',
|
||||||
|
setPrompt: mockSetPrompt
|
||||||
|
})
|
||||||
|
}
|
||||||
),
|
),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -22,39 +51,58 @@ vi.mock('../useAppState', () => ({
|
|||||||
const state = {
|
const state = {
|
||||||
tools: [],
|
tools: [],
|
||||||
updateTokenSpeed: vi.fn(),
|
updateTokenSpeed: vi.fn(),
|
||||||
resetTokenSpeed: vi.fn(),
|
resetTokenSpeed: mockResetTokenSpeed,
|
||||||
updateTools: vi.fn(),
|
updateTools: vi.fn(),
|
||||||
updateStreamingContent: vi.fn(),
|
updateStreamingContent: vi.fn(),
|
||||||
updatePromptProgress: vi.fn(),
|
updatePromptProgress: vi.fn(),
|
||||||
updateLoadingModel: vi.fn(),
|
updateLoadingModel: vi.fn(),
|
||||||
setAbortController: vi.fn(),
|
setAbortController: vi.fn(),
|
||||||
|
streamingContent: undefined,
|
||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
return selector ? selector(state) : state
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
getState: vi.fn(() => ({
|
getState: vi.fn(() => ({
|
||||||
|
tools: [],
|
||||||
tokenSpeed: { tokensPerSecond: 10 },
|
tokenSpeed: { tokensPerSecond: 10 },
|
||||||
|
streamingContent: undefined,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('../useAssistant', () => ({
|
vi.mock('../useAssistant', () => ({
|
||||||
useAssistant: (selector: any) => {
|
useAssistant: Object.assign(
|
||||||
const state = {
|
(selector: any) => {
|
||||||
assistants: [{
|
const state = {
|
||||||
id: 'test-assistant',
|
assistants: [{
|
||||||
instructions: 'test instructions',
|
id: 'test-assistant',
|
||||||
parameters: { stream: true },
|
instructions: 'test instructions',
|
||||||
}],
|
parameters: { stream: true },
|
||||||
currentAssistant: {
|
}],
|
||||||
id: 'test-assistant',
|
currentAssistant: {
|
||||||
instructions: 'test instructions',
|
id: 'test-assistant',
|
||||||
parameters: { stream: true },
|
instructions: 'test instructions',
|
||||||
},
|
parameters: { stream: true },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return selector ? selector(state) : state
|
||||||
|
},
|
||||||
|
{
|
||||||
|
getState: () => ({
|
||||||
|
assistants: [{
|
||||||
|
id: 'test-assistant',
|
||||||
|
instructions: 'test instructions',
|
||||||
|
parameters: { stream: true },
|
||||||
|
}],
|
||||||
|
currentAssistant: {
|
||||||
|
id: 'test-assistant',
|
||||||
|
instructions: 'test instructions',
|
||||||
|
parameters: { stream: true },
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
),
|
||||||
},
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('../useModelProvider', () => ({
|
vi.mock('../useModelProvider', () => ({
|
||||||
@ -62,14 +110,15 @@ vi.mock('../useModelProvider', () => ({
|
|||||||
(selector: any) => {
|
(selector: any) => {
|
||||||
const state = {
|
const state = {
|
||||||
getProviderByName: vi.fn(() => ({
|
getProviderByName: vi.fn(() => ({
|
||||||
provider: 'openai',
|
provider: 'llamacpp',
|
||||||
models: [],
|
models: [],
|
||||||
|
settings: [],
|
||||||
})),
|
})),
|
||||||
selectedModel: {
|
selectedModel: {
|
||||||
id: 'test-model',
|
id: 'test-model',
|
||||||
capabilities: ['tools'],
|
capabilities: ['tools'],
|
||||||
},
|
},
|
||||||
selectedProvider: 'openai',
|
selectedProvider: 'llamacpp',
|
||||||
updateProvider: vi.fn(),
|
updateProvider: vi.fn(),
|
||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
return selector ? selector(state) : state
|
||||||
@ -77,14 +126,15 @@ vi.mock('../useModelProvider', () => ({
|
|||||||
{
|
{
|
||||||
getState: () => ({
|
getState: () => ({
|
||||||
getProviderByName: vi.fn(() => ({
|
getProviderByName: vi.fn(() => ({
|
||||||
provider: 'openai',
|
provider: 'llamacpp',
|
||||||
models: [],
|
models: [],
|
||||||
|
settings: [],
|
||||||
})),
|
})),
|
||||||
selectedModel: {
|
selectedModel: {
|
||||||
id: 'test-model',
|
id: 'test-model',
|
||||||
capabilities: ['tools'],
|
capabilities: ['tools'],
|
||||||
},
|
},
|
||||||
selectedProvider: 'openai',
|
selectedProvider: 'llamacpp',
|
||||||
updateProvider: vi.fn(),
|
updateProvider: vi.fn(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -96,11 +146,11 @@ vi.mock('../useThreads', () => ({
|
|||||||
const state = {
|
const state = {
|
||||||
getCurrentThread: vi.fn(() => ({
|
getCurrentThread: vi.fn(() => ({
|
||||||
id: 'test-thread',
|
id: 'test-thread',
|
||||||
model: { id: 'test-model', provider: 'openai' },
|
model: { id: 'test-model', provider: 'llamacpp' },
|
||||||
})),
|
})),
|
||||||
createThread: vi.fn(() => Promise.resolve({
|
createThread: vi.fn(() => Promise.resolve({
|
||||||
id: 'test-thread',
|
id: 'test-thread',
|
||||||
model: { id: 'test-model', provider: 'openai' },
|
model: { id: 'test-model', provider: 'llamacpp' },
|
||||||
})),
|
})),
|
||||||
updateThreadTimestamp: vi.fn(),
|
updateThreadTimestamp: vi.fn(),
|
||||||
}
|
}
|
||||||
@ -111,22 +161,33 @@ vi.mock('../useThreads', () => ({
|
|||||||
vi.mock('../useMessages', () => ({
|
vi.mock('../useMessages', () => ({
|
||||||
useMessages: (selector: any) => {
|
useMessages: (selector: any) => {
|
||||||
const state = {
|
const state = {
|
||||||
getMessages: vi.fn(() => []),
|
getMessages: mockGetMessages,
|
||||||
addMessage: vi.fn(),
|
addMessage: mockAddMessage,
|
||||||
|
updateMessage: mockUpdateMessage,
|
||||||
|
setMessages: vi.fn(),
|
||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
return selector ? selector(state) : state
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('../useToolApproval', () => ({
|
vi.mock('../useToolApproval', () => ({
|
||||||
useToolApproval: (selector: any) => {
|
useToolApproval: Object.assign(
|
||||||
const state = {
|
(selector: any) => {
|
||||||
approvedTools: [],
|
const state = {
|
||||||
showApprovalModal: vi.fn(),
|
approvedTools: [],
|
||||||
allowAllMCPPermissions: false,
|
showApprovalModal: vi.fn(),
|
||||||
|
allowAllMCPPermissions: false,
|
||||||
|
}
|
||||||
|
return selector ? selector(state) : state
|
||||||
|
},
|
||||||
|
{
|
||||||
|
getState: () => ({
|
||||||
|
approvedTools: [],
|
||||||
|
showApprovalModal: vi.fn(),
|
||||||
|
allowAllMCPPermissions: false,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
),
|
||||||
},
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('../useToolAvailable', () => ({
|
vi.mock('../useToolAvailable', () => ({
|
||||||
@ -162,38 +223,57 @@ vi.mock('@tanstack/react-router', () => ({
|
|||||||
})),
|
})),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
vi.mock('../useServiceHub', () => ({
|
||||||
|
useServiceHub: vi.fn(() => ({
|
||||||
|
models: () => ({
|
||||||
|
startModel: mockStartModel,
|
||||||
|
stopModel: vi.fn(() => Promise.resolve()),
|
||||||
|
stopAllModels: vi.fn(() => Promise.resolve()),
|
||||||
|
}),
|
||||||
|
providers: () => ({
|
||||||
|
updateSettings: vi.fn(() => Promise.resolve()),
|
||||||
|
}),
|
||||||
|
})),
|
||||||
|
}))
|
||||||
|
|
||||||
vi.mock('@/lib/completion', () => ({
|
vi.mock('@/lib/completion', () => ({
|
||||||
emptyThreadContent: { thread_id: 'test-thread', content: '' },
|
emptyThreadContent: { thread_id: 'test-thread', content: '' },
|
||||||
extractToolCall: vi.fn(),
|
extractToolCall: vi.fn(),
|
||||||
newUserThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'user message' })),
|
newUserThreadContent: vi.fn((threadId, content) => ({
|
||||||
newAssistantThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'assistant message' })),
|
thread_id: threadId,
|
||||||
sendCompletion: vi.fn(),
|
content: [{ type: 'text', text: { value: content, annotations: [] } }],
|
||||||
postMessageProcessing: vi.fn(),
|
role: 'user'
|
||||||
isCompletionResponse: vi.fn(),
|
})),
|
||||||
|
newAssistantThreadContent: vi.fn((threadId, content) => ({
|
||||||
|
thread_id: threadId,
|
||||||
|
content: [{ type: 'text', text: { value: content, annotations: [] } }],
|
||||||
|
role: 'assistant'
|
||||||
|
})),
|
||||||
|
sendCompletion: mockSendCompletion,
|
||||||
|
postMessageProcessing: mockPostMessageProcessing,
|
||||||
|
isCompletionResponse: vi.fn(() => true),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@/lib/messages', () => ({
|
vi.mock('@/lib/messages', () => ({
|
||||||
CompletionMessagesBuilder: vi.fn(() => ({
|
CompletionMessagesBuilder: vi.fn(() => mockCompletionMessagesBuilder),
|
||||||
addUserMessage: vi.fn(),
|
}))
|
||||||
addAssistantMessage: vi.fn(),
|
|
||||||
getMessages: vi.fn(() => []),
|
vi.mock('@/lib/instructionTemplate', () => ({
|
||||||
|
renderInstructions: vi.fn((instructions: string) => instructions),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/utils/reasoning', () => ({
|
||||||
|
ReasoningProcessor: vi.fn(() => ({
|
||||||
|
processReasoningChunk: vi.fn(() => null),
|
||||||
|
finalize: vi.fn(() => ''),
|
||||||
})),
|
})),
|
||||||
|
extractReasoningFromMessage: vi.fn(() => null),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@/services/mcp', () => ({
|
vi.mock('@/services/mcp', () => ({
|
||||||
getTools: vi.fn(() => Promise.resolve([])),
|
getTools: vi.fn(() => Promise.resolve([])),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@/services/models', () => ({
|
|
||||||
startModel: vi.fn(() => Promise.resolve()),
|
|
||||||
stopModel: vi.fn(() => Promise.resolve()),
|
|
||||||
stopAllModels: vi.fn(() => Promise.resolve()),
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/services/providers', () => ({
|
|
||||||
updateSettings: vi.fn(() => Promise.resolve()),
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@tauri-apps/api/event', () => ({
|
vi.mock('@tauri-apps/api/event', () => ({
|
||||||
listen: vi.fn(() => Promise.resolve(vi.fn())),
|
listen: vi.fn(() => Promise.resolve(vi.fn())),
|
||||||
}))
|
}))
|
||||||
@ -204,9 +284,41 @@ vi.mock('sonner', () => ({
|
|||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
// Import after mocks to avoid hoisting issues
|
||||||
|
const { useChat } = await import('../useChat')
|
||||||
|
const completionLib = await import('@/lib/completion')
|
||||||
|
const messagesLib = await import('@/lib/messages')
|
||||||
|
|
||||||
describe('useChat', () => {
|
describe('useChat', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
|
// Clear mock call history
|
||||||
vi.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
|
|
||||||
|
// Reset mock implementations
|
||||||
|
mockAddMessage.mockClear()
|
||||||
|
mockUpdateMessage.mockClear()
|
||||||
|
mockGetMessages.mockReturnValue([])
|
||||||
|
mockStartModel.mockResolvedValue(undefined)
|
||||||
|
mockSetPrompt.mockClear()
|
||||||
|
mockResetTokenSpeed.mockClear()
|
||||||
|
mockSendCompletion.mockResolvedValue({
|
||||||
|
choices: [{
|
||||||
|
message: {
|
||||||
|
content: 'AI response',
|
||||||
|
role: 'assistant',
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
mockPostMessageProcessing.mockImplementation((toolCalls, builder, content) =>
|
||||||
|
Promise.resolve(content)
|
||||||
|
)
|
||||||
|
mockCompletionMessagesBuilder.addUserMessage.mockClear()
|
||||||
|
mockCompletionMessagesBuilder.addAssistantMessage.mockClear()
|
||||||
|
mockCompletionMessagesBuilder.getMessages.mockReturnValue([])
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.clearAllTimers()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('returns sendMessage function', () => {
|
it('returns sendMessage function', () => {
|
||||||
@ -216,13 +328,270 @@ describe('useChat', () => {
|
|||||||
expect(typeof result.current).toBe('function')
|
expect(typeof result.current).toBe('function')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('sends message successfully', async () => {
|
describe('Continue with AI response functionality', () => {
|
||||||
const { result } = renderHook(() => useChat())
|
it('should add new user message when troubleshooting is true and no continueFromMessageId', async () => {
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await result.current('Hello world')
|
await result.current('Hello world', true, undefined, undefined, undefined)
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
|
||||||
|
'test-thread',
|
||||||
|
'Hello world',
|
||||||
|
undefined
|
||||||
|
)
|
||||||
|
expect(mockAddMessage).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'user',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
|
||||||
|
'Hello world',
|
||||||
|
undefined
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
expect(result.current).toBeDefined()
|
it('should NOT add new user message when continueFromMessageId is provided', async () => {
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(completionLib.newUserThreadContent).not.toHaveBeenCalled()
|
||||||
|
const userMessageCalls = mockAddMessage.mock.calls.filter(
|
||||||
|
(call: any) => call[0]?.role === 'user'
|
||||||
|
)
|
||||||
|
expect(userMessageCalls).toHaveLength(0)
|
||||||
|
expect(mockCompletionMessagesBuilder.addUserMessage).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should add partial assistant message to builder when continuing', async () => {
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
// Should be called twice: once with partial message (line 517-521), once after completion (line 689)
|
||||||
|
const assistantCalls = mockCompletionMessagesBuilder.addAssistantMessage.mock.calls
|
||||||
|
expect(assistantCalls.length).toBeGreaterThanOrEqual(1)
|
||||||
|
// First call should be with the partial response content
|
||||||
|
expect(assistantCalls[0]).toEqual([
|
||||||
|
'Partial response',
|
||||||
|
undefined,
|
||||||
|
[]
|
||||||
|
])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should filter out stopped message from context when continuing', async () => {
|
||||||
|
const userMsg = {
|
||||||
|
id: 'msg-1',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'user',
|
||||||
|
content: [{ type: 'text', text: { value: 'Hello', annotations: [] } }],
|
||||||
|
}
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([userMsg, stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
// The CompletionMessagesBuilder is called with filtered messages (line 507-512)
|
||||||
|
// The stopped message should be filtered out from the context
|
||||||
|
expect(messagesLib.CompletionMessagesBuilder).toHaveBeenCalled()
|
||||||
|
const builderCall = (messagesLib.CompletionMessagesBuilder as any).mock.calls[0]
|
||||||
|
expect(builderCall[0]).toEqual([userMsg]) // stopped message filtered out
|
||||||
|
expect(builderCall[1]).toEqual('test instructions')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should update existing message instead of adding new one when continuing', async () => {
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
// finalizeMessage is called at line 700-708, which should update the message
|
||||||
|
expect(mockUpdateMessage).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
id: 'msg-123',
|
||||||
|
status: MessageStatus.Ready,
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should start with previous content when continuing', async () => {
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
mockSendCompletion.mockResolvedValue({
|
||||||
|
choices: [{
|
||||||
|
message: {
|
||||||
|
content: ' continued',
|
||||||
|
role: 'assistant',
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
// The accumulated text should contain the previous content plus new content
|
||||||
|
// accumulatedTextRef starts with 'Partial response' (line 490)
|
||||||
|
// Then gets ' continued' appended (line 585)
|
||||||
|
expect(mockUpdateMessage).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
id: 'msg-123',
|
||||||
|
content: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
text: expect.objectContaining({
|
||||||
|
value: 'Partial response continued',
|
||||||
|
})
|
||||||
|
})
|
||||||
|
])
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle attachments correctly when not continuing', async () => {
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
const attachments = [
|
||||||
|
{
|
||||||
|
name: 'test.png',
|
||||||
|
type: 'image/png',
|
||||||
|
size: 1024,
|
||||||
|
base64: 'base64data',
|
||||||
|
dataUrl: '',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('Message with attachment', true, attachments, undefined, undefined)
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
|
||||||
|
'test-thread',
|
||||||
|
'Message with attachment',
|
||||||
|
attachments
|
||||||
|
)
|
||||||
|
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
|
||||||
|
'Message with attachment',
|
||||||
|
attachments
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should preserve message status as Ready after continuation completes', async () => {
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
// finalContent is created at line 678-683 with status Ready when continuing
|
||||||
|
expect(mockUpdateMessage).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
id: 'msg-123',
|
||||||
|
status: MessageStatus.Ready,
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Normal message sending', () => {
|
||||||
|
it('sends message successfully without continuation', async () => {
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('Hello world')
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(mockSendCompletion).toHaveBeenCalled()
|
||||||
|
expect(mockStartModel).toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Error handling', () => {
|
||||||
|
it('should handle errors gracefully during continuation', async () => {
|
||||||
|
mockSendCompletion.mockRejectedValueOnce(new Error('API Error'))
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.current).toBeDefined()
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@ -225,9 +225,25 @@ describe('useMessages', () => {
|
|||||||
})
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
// Wait for async operation
|
// Message should be immediately available (optimistic update)
|
||||||
|
expect(result.current.messages['thread1']).toContainEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
id: messageToAdd.id,
|
||||||
|
thread_id: messageToAdd.thread_id,
|
||||||
|
role: messageToAdd.role,
|
||||||
|
content: messageToAdd.content,
|
||||||
|
metadata: expect.objectContaining({
|
||||||
|
assistant: expect.objectContaining({
|
||||||
|
id: expect.any(String),
|
||||||
|
name: expect.any(String),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
// Verify persistence was attempted
|
||||||
await vi.waitFor(() => {
|
await vi.waitFor(() => {
|
||||||
expect(result.current.messages['thread1']).toContainEqual(mockCreatedMessage)
|
expect(mockCreateMessage).toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import { MCPTool } from '@/types/completion'
|
|||||||
import { useAssistant } from './useAssistant'
|
import { useAssistant } from './useAssistant'
|
||||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||||
|
|
||||||
type PromptProgress = {
|
export type PromptProgress = {
|
||||||
cache: number
|
cache: number
|
||||||
processed: number
|
processed: number
|
||||||
time_ms: number
|
time_ms: number
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import { flushSync } from 'react-dom'
|
|||||||
import { usePrompt } from './usePrompt'
|
import { usePrompt } from './usePrompt'
|
||||||
import { useModelProvider } from './useModelProvider'
|
import { useModelProvider } from './useModelProvider'
|
||||||
import { useThreads } from './useThreads'
|
import { useThreads } from './useThreads'
|
||||||
import { useAppState } from './useAppState'
|
import { useAppState, type PromptProgress } from './useAppState'
|
||||||
import { useMessages } from './useMessages'
|
import { useMessages } from './useMessages'
|
||||||
import { useRouter } from '@tanstack/react-router'
|
import { useRouter } from '@tanstack/react-router'
|
||||||
import { defaultModel } from '@/lib/models'
|
import { defaultModel } from '@/lib/models'
|
||||||
@ -23,6 +23,7 @@ import {
|
|||||||
ChatCompletionMessageToolCall,
|
ChatCompletionMessageToolCall,
|
||||||
CompletionUsage,
|
CompletionUsage,
|
||||||
} from 'openai/resources'
|
} from 'openai/resources'
|
||||||
|
import { MessageStatus, ContentType, ThreadMessage } from '@janhq/core'
|
||||||
|
|
||||||
import { useServiceHub } from '@/hooks/useServiceHub'
|
import { useServiceHub } from '@/hooks/useServiceHub'
|
||||||
import { useToolApproval } from '@/hooks/useToolApproval'
|
import { useToolApproval } from '@/hooks/useToolApproval'
|
||||||
@ -38,6 +39,198 @@ import { useAssistant } from './useAssistant'
|
|||||||
import { useShallow } from 'zustand/shallow'
|
import { useShallow } from 'zustand/shallow'
|
||||||
import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat'
|
import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat'
|
||||||
|
|
||||||
|
// Helper to create thread content with consistent structure
|
||||||
|
const createThreadContent = (
|
||||||
|
threadId: string,
|
||||||
|
text: string,
|
||||||
|
toolCalls: ChatCompletionMessageToolCall[],
|
||||||
|
messageId?: string
|
||||||
|
) => {
|
||||||
|
const content = newAssistantThreadContent(threadId, text, {
|
||||||
|
tool_calls: toolCalls.map((e) => ({
|
||||||
|
...e,
|
||||||
|
state: 'pending',
|
||||||
|
})),
|
||||||
|
})
|
||||||
|
// If continuing from a message, preserve the message ID
|
||||||
|
if (messageId) {
|
||||||
|
return { ...content, id: messageId }
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to cancel animation frame cross-platform
|
||||||
|
const cancelFrame = (handle: number | undefined) => {
|
||||||
|
if (handle === undefined) return
|
||||||
|
if (typeof cancelAnimationFrame !== 'undefined') {
|
||||||
|
cancelAnimationFrame(handle)
|
||||||
|
} else {
|
||||||
|
clearTimeout(handle)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to finalize and save a message
|
||||||
|
const finalizeMessage = (
|
||||||
|
finalContent: ThreadMessage,
|
||||||
|
addMessage: (message: ThreadMessage) => void,
|
||||||
|
updateStreamingContent: (content: ThreadMessage | undefined) => void,
|
||||||
|
updatePromptProgress: (progress: PromptProgress | undefined) => void,
|
||||||
|
updateThreadTimestamp: (threadId: string) => void,
|
||||||
|
updateMessage?: (message: ThreadMessage) => void,
|
||||||
|
continueFromMessageId?: string
|
||||||
|
) => {
|
||||||
|
// If continuing from a message, update it; otherwise add new message
|
||||||
|
if (continueFromMessageId && updateMessage) {
|
||||||
|
updateMessage({ ...finalContent, id: continueFromMessageId })
|
||||||
|
} else {
|
||||||
|
addMessage(finalContent)
|
||||||
|
}
|
||||||
|
updateStreamingContent(emptyThreadContent)
|
||||||
|
updatePromptProgress(undefined)
|
||||||
|
updateThreadTimestamp(finalContent.thread_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to process streaming completion
|
||||||
|
const processStreamingCompletion = async (
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
completion: AsyncIterable<any>,
|
||||||
|
abortController: AbortController,
|
||||||
|
activeThread: Thread,
|
||||||
|
accumulatedText: { value: string },
|
||||||
|
toolCalls: ChatCompletionMessageToolCall[],
|
||||||
|
currentCall: ChatCompletionMessageToolCall | null,
|
||||||
|
updateStreamingContent: (content: ThreadMessage | undefined) => void,
|
||||||
|
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
|
||||||
|
setTokenSpeed: (message: ThreadMessage, tokensPerSecond: number, totalTokens: number) => void,
|
||||||
|
updatePromptProgress: (progress: PromptProgress | undefined) => void,
|
||||||
|
timeToFirstToken: number,
|
||||||
|
tokenUsageRef: { current: CompletionUsage | undefined },
|
||||||
|
continueFromMessageId?: string,
|
||||||
|
updateMessage?: (message: ThreadMessage) => void,
|
||||||
|
continueFromMessage?: ThreadMessage
|
||||||
|
) => {
|
||||||
|
// High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
|
||||||
|
let rafScheduled = false
|
||||||
|
let rafHandle: number | undefined
|
||||||
|
let pendingDeltaCount = 0
|
||||||
|
const reasoningProcessor = new ReasoningProcessor()
|
||||||
|
|
||||||
|
const flushStreamingContent = () => {
|
||||||
|
const currentContent = createThreadContent(
|
||||||
|
activeThread.id,
|
||||||
|
accumulatedText.value,
|
||||||
|
toolCalls,
|
||||||
|
continueFromMessageId
|
||||||
|
)
|
||||||
|
|
||||||
|
// When continuing, update the message directly instead of using streamingContent
|
||||||
|
if (continueFromMessageId && updateMessage && continueFromMessage) {
|
||||||
|
updateMessage({
|
||||||
|
...continueFromMessage, // Preserve original message metadata
|
||||||
|
content: currentContent.content, // Update content
|
||||||
|
status: MessageStatus.Stopped, // Keep as Stopped while streaming
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
updateStreamingContent(currentContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tokenUsageRef.current) {
|
||||||
|
setTokenSpeed(
|
||||||
|
currentContent,
|
||||||
|
tokenUsageRef.current.completion_tokens /
|
||||||
|
Math.max((Date.now() - timeToFirstToken) / 1000, 1),
|
||||||
|
tokenUsageRef.current.completion_tokens
|
||||||
|
)
|
||||||
|
} else if (pendingDeltaCount > 0) {
|
||||||
|
updateTokenSpeed(currentContent, pendingDeltaCount)
|
||||||
|
}
|
||||||
|
pendingDeltaCount = 0
|
||||||
|
rafScheduled = false
|
||||||
|
}
|
||||||
|
|
||||||
|
const scheduleFlush = () => {
|
||||||
|
if (rafScheduled || abortController.signal.aborted) return
|
||||||
|
rafScheduled = true
|
||||||
|
const doSchedule = (cb: () => void) => {
|
||||||
|
if (typeof requestAnimationFrame !== 'undefined') {
|
||||||
|
rafHandle = requestAnimationFrame(() => cb())
|
||||||
|
} else {
|
||||||
|
// Fallback for non-browser test environments
|
||||||
|
const t = setTimeout(() => cb(), 0) as unknown as number
|
||||||
|
rafHandle = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
doSchedule(() => {
|
||||||
|
// Check abort status before executing the scheduled callback
|
||||||
|
if (abortController.signal.aborted) {
|
||||||
|
rafScheduled = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flushStreamingContent()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
for await (const part of completion) {
|
||||||
|
// Check if aborted before processing each part
|
||||||
|
if (abortController.signal.aborted) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle prompt progress if available
|
||||||
|
if ('prompt_progress' in part && part.prompt_progress) {
|
||||||
|
// Force immediate state update to ensure we see intermediate values
|
||||||
|
flushSync(() => {
|
||||||
|
updatePromptProgress(part.prompt_progress)
|
||||||
|
})
|
||||||
|
// Add a small delay to make progress visible
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 100))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error message
|
||||||
|
if (!part.choices) {
|
||||||
|
throw new Error(
|
||||||
|
'message' in part
|
||||||
|
? (part.message as string)
|
||||||
|
: (JSON.stringify(part) ?? '')
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ('usage' in part && part.usage) {
|
||||||
|
tokenUsageRef.current = part.usage
|
||||||
|
}
|
||||||
|
|
||||||
|
if (part.choices[0]?.delta?.tool_calls) {
|
||||||
|
extractToolCall(part, currentCall, toolCalls)
|
||||||
|
// Schedule a flush to reflect tool update
|
||||||
|
scheduleFlush()
|
||||||
|
}
|
||||||
|
const deltaReasoning = reasoningProcessor.processReasoningChunk(part)
|
||||||
|
if (deltaReasoning) {
|
||||||
|
accumulatedText.value += deltaReasoning
|
||||||
|
pendingDeltaCount += 1
|
||||||
|
// Schedule flush for reasoning updates
|
||||||
|
scheduleFlush()
|
||||||
|
}
|
||||||
|
const deltaContent = part.choices[0]?.delta?.content || ''
|
||||||
|
if (deltaContent) {
|
||||||
|
accumulatedText.value += deltaContent
|
||||||
|
pendingDeltaCount += 1
|
||||||
|
// Batch UI update on next animation frame
|
||||||
|
scheduleFlush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
// Always clean up scheduled RAF when stream ends (either normally or via abort)
|
||||||
|
cancelFrame(rafHandle)
|
||||||
|
rafHandle = undefined
|
||||||
|
rafScheduled = false
|
||||||
|
|
||||||
|
// Finalize reasoning (close any open think tags)
|
||||||
|
accumulatedText.value += reasoningProcessor.finalize()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export const useChat = () => {
|
export const useChat = () => {
|
||||||
const [
|
const [
|
||||||
updateTokenSpeed,
|
updateTokenSpeed,
|
||||||
@ -86,6 +279,7 @@ export const useChat = () => {
|
|||||||
|
|
||||||
const getMessages = useMessages((state) => state.getMessages)
|
const getMessages = useMessages((state) => state.getMessages)
|
||||||
const addMessage = useMessages((state) => state.addMessage)
|
const addMessage = useMessages((state) => state.addMessage)
|
||||||
|
const updateMessage = useMessages((state) => state.updateMessage)
|
||||||
const setMessages = useMessages((state) => state.setMessages)
|
const setMessages = useMessages((state) => state.setMessages)
|
||||||
const setModelLoadError = useModelLoad((state) => state.setModelLoadError)
|
const setModelLoadError = useModelLoad((state) => state.setModelLoadError)
|
||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
@ -149,7 +343,7 @@ export const useChat = () => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
return currentThread
|
return currentThread
|
||||||
}, [createThread, retrieveThread, router, setMessages])
|
}, [createThread, retrieveThread, router, setMessages, serviceHub])
|
||||||
|
|
||||||
const restartModel = useCallback(
|
const restartModel = useCallback(
|
||||||
async (provider: ProviderObject, modelId: string) => {
|
async (provider: ProviderObject, modelId: string) => {
|
||||||
@ -264,7 +458,8 @@ export const useChat = () => {
|
|||||||
base64: string
|
base64: string
|
||||||
dataUrl: string
|
dataUrl: string
|
||||||
}>,
|
}>,
|
||||||
projectId?: string
|
projectId?: string,
|
||||||
|
continueFromMessageId?: string
|
||||||
) => {
|
) => {
|
||||||
const activeThread = await getCurrentThread(projectId)
|
const activeThread = await getCurrentThread(projectId)
|
||||||
const selectedProvider = useModelProvider.getState().selectedProvider
|
const selectedProvider = useModelProvider.getState().selectedProvider
|
||||||
@ -277,26 +472,54 @@ export const useChat = () => {
|
|||||||
setAbortController(activeThread.id, abortController)
|
setAbortController(activeThread.id, abortController)
|
||||||
updateStreamingContent(emptyThreadContent)
|
updateStreamingContent(emptyThreadContent)
|
||||||
updatePromptProgress(undefined)
|
updatePromptProgress(undefined)
|
||||||
// Do not add new message on retry
|
|
||||||
if (troubleshooting)
|
// Find the message to continue from if provided
|
||||||
|
const continueFromMessage = continueFromMessageId
|
||||||
|
? messages.find((m) => m.id === continueFromMessageId)
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
// Do not add new message on retry or when continuing
|
||||||
|
if (troubleshooting && !continueFromMessageId)
|
||||||
addMessage(newUserThreadContent(activeThread.id, message, attachments))
|
addMessage(newUserThreadContent(activeThread.id, message, attachments))
|
||||||
updateThreadTimestamp(activeThread.id)
|
updateThreadTimestamp(activeThread.id)
|
||||||
usePrompt.getState().setPrompt('')
|
usePrompt.getState().setPrompt('')
|
||||||
const selectedModel = useModelProvider.getState().selectedModel
|
const selectedModel = useModelProvider.getState().selectedModel
|
||||||
|
|
||||||
|
// If continuing, start with the previous content
|
||||||
|
const accumulatedTextRef = {
|
||||||
|
value: continueFromMessage?.content?.[0]?.text?.value || ''
|
||||||
|
}
|
||||||
|
let currentAssistant: Assistant | undefined | null
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (selectedModel?.id) {
|
if (selectedModel?.id) {
|
||||||
updateLoadingModel(true)
|
updateLoadingModel(true)
|
||||||
await serviceHub.models().startModel(activeProvider, selectedModel.id)
|
await serviceHub.models().startModel(activeProvider, selectedModel.id)
|
||||||
updateLoadingModel(false)
|
updateLoadingModel(false)
|
||||||
}
|
}
|
||||||
const currentAssistant = useAssistant.getState().currentAssistant
|
currentAssistant = useAssistant.getState().currentAssistant
|
||||||
|
|
||||||
|
// Filter out the stopped message from context if continuing
|
||||||
|
const contextMessages = continueFromMessageId
|
||||||
|
? messages.filter((m) => m.id !== continueFromMessageId)
|
||||||
|
: messages
|
||||||
|
|
||||||
const builder = new CompletionMessagesBuilder(
|
const builder = new CompletionMessagesBuilder(
|
||||||
messages,
|
contextMessages,
|
||||||
currentAssistant
|
currentAssistant
|
||||||
? renderInstructions(currentAssistant.instructions)
|
? renderInstructions(currentAssistant.instructions)
|
||||||
: undefined
|
: undefined
|
||||||
)
|
)
|
||||||
if (troubleshooting) builder.addUserMessage(message, attachments)
|
if (troubleshooting && !continueFromMessageId) {
|
||||||
|
builder.addUserMessage(message, attachments)
|
||||||
|
} else if (continueFromMessage) {
|
||||||
|
// When continuing, add the partial assistant response to the context
|
||||||
|
builder.addAssistantMessage(
|
||||||
|
continueFromMessage.content?.[0]?.text?.value || '',
|
||||||
|
undefined,
|
||||||
|
[]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
let isCompleted = false
|
let isCompleted = false
|
||||||
|
|
||||||
@ -349,7 +572,6 @@ export const useChat = () => {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if (!completion) throw new Error('No completion received')
|
if (!completion) throw new Error('No completion received')
|
||||||
let accumulatedText = ''
|
|
||||||
const currentCall: ChatCompletionMessageToolCall | null = null
|
const currentCall: ChatCompletionMessageToolCall | null = null
|
||||||
const toolCalls: ChatCompletionMessageToolCall[] = []
|
const toolCalls: ChatCompletionMessageToolCall[] = []
|
||||||
const timeToFirstToken = Date.now()
|
const timeToFirstToken = Date.now()
|
||||||
@ -357,13 +579,19 @@ export const useChat = () => {
|
|||||||
try {
|
try {
|
||||||
if (isCompletionResponse(completion)) {
|
if (isCompletionResponse(completion)) {
|
||||||
const message = completion.choices[0]?.message
|
const message = completion.choices[0]?.message
|
||||||
accumulatedText = (message?.content as string) || ''
|
// When continuing, append to existing content; otherwise replace
|
||||||
|
const newContent = (message?.content as string) || ''
|
||||||
|
if (continueFromMessageId && accumulatedTextRef.value) {
|
||||||
|
accumulatedTextRef.value += newContent
|
||||||
|
} else {
|
||||||
|
accumulatedTextRef.value = newContent
|
||||||
|
}
|
||||||
|
|
||||||
// Handle reasoning field if there is one
|
// Handle reasoning field if there is one
|
||||||
const reasoning = extractReasoningFromMessage(message)
|
const reasoning = extractReasoningFromMessage(message)
|
||||||
if (reasoning) {
|
if (reasoning) {
|
||||||
accumulatedText =
|
accumulatedTextRef.value =
|
||||||
`<think>${reasoning}</think>` + accumulatedText
|
`<think>${reasoning}</think>` + accumulatedTextRef.value
|
||||||
}
|
}
|
||||||
|
|
||||||
if (message?.tool_calls) {
|
if (message?.tool_calls) {
|
||||||
@ -373,161 +601,25 @@ export const useChat = () => {
|
|||||||
tokenUsage = completion.usage
|
tokenUsage = completion.usage
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
|
const tokenUsageRef = { current: tokenUsage }
|
||||||
let rafScheduled = false
|
await processStreamingCompletion(
|
||||||
let rafHandle: number | undefined
|
completion,
|
||||||
let pendingDeltaCount = 0
|
abortController,
|
||||||
const reasoningProcessor = new ReasoningProcessor()
|
activeThread,
|
||||||
const scheduleFlush = () => {
|
accumulatedTextRef,
|
||||||
if (rafScheduled || abortController.signal.aborted) return
|
toolCalls,
|
||||||
rafScheduled = true
|
currentCall,
|
||||||
const doSchedule = (cb: () => void) => {
|
updateStreamingContent,
|
||||||
if (typeof requestAnimationFrame !== 'undefined') {
|
updateTokenSpeed,
|
||||||
rafHandle = requestAnimationFrame(() => cb())
|
setTokenSpeed,
|
||||||
} else {
|
updatePromptProgress,
|
||||||
// Fallback for non-browser test environments
|
timeToFirstToken,
|
||||||
const t = setTimeout(() => cb(), 0) as unknown as number
|
tokenUsageRef,
|
||||||
rafHandle = t
|
continueFromMessageId,
|
||||||
}
|
updateMessage,
|
||||||
}
|
continueFromMessage
|
||||||
doSchedule(() => {
|
)
|
||||||
// Check abort status before executing the scheduled callback
|
tokenUsage = tokenUsageRef.current
|
||||||
if (abortController.signal.aborted) {
|
|
||||||
rafScheduled = false
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
const currentContent = newAssistantThreadContent(
|
|
||||||
activeThread.id,
|
|
||||||
accumulatedText,
|
|
||||||
{
|
|
||||||
tool_calls: toolCalls.map((e) => ({
|
|
||||||
...e,
|
|
||||||
state: 'pending',
|
|
||||||
})),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
updateStreamingContent(currentContent)
|
|
||||||
if (tokenUsage) {
|
|
||||||
setTokenSpeed(
|
|
||||||
currentContent,
|
|
||||||
tokenUsage.completion_tokens /
|
|
||||||
Math.max((Date.now() - timeToFirstToken) / 1000, 1),
|
|
||||||
tokenUsage.completion_tokens
|
|
||||||
)
|
|
||||||
} else if (pendingDeltaCount > 0) {
|
|
||||||
updateTokenSpeed(currentContent, pendingDeltaCount)
|
|
||||||
}
|
|
||||||
pendingDeltaCount = 0
|
|
||||||
rafScheduled = false
|
|
||||||
})
|
|
||||||
}
|
|
||||||
const flushIfPending = () => {
|
|
||||||
if (!rafScheduled) return
|
|
||||||
if (
|
|
||||||
typeof cancelAnimationFrame !== 'undefined' &&
|
|
||||||
rafHandle !== undefined
|
|
||||||
) {
|
|
||||||
cancelAnimationFrame(rafHandle)
|
|
||||||
} else if (rafHandle !== undefined) {
|
|
||||||
clearTimeout(rafHandle)
|
|
||||||
}
|
|
||||||
// Do an immediate flush
|
|
||||||
const currentContent = newAssistantThreadContent(
|
|
||||||
activeThread.id,
|
|
||||||
accumulatedText,
|
|
||||||
{
|
|
||||||
tool_calls: toolCalls.map((e) => ({
|
|
||||||
...e,
|
|
||||||
state: 'pending',
|
|
||||||
})),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
updateStreamingContent(currentContent)
|
|
||||||
if (tokenUsage) {
|
|
||||||
setTokenSpeed(
|
|
||||||
currentContent,
|
|
||||||
tokenUsage.completion_tokens /
|
|
||||||
Math.max((Date.now() - timeToFirstToken) / 1000, 1),
|
|
||||||
tokenUsage.completion_tokens
|
|
||||||
)
|
|
||||||
} else if (pendingDeltaCount > 0) {
|
|
||||||
updateTokenSpeed(currentContent, pendingDeltaCount)
|
|
||||||
}
|
|
||||||
pendingDeltaCount = 0
|
|
||||||
rafScheduled = false
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
for await (const part of completion) {
|
|
||||||
// Check if aborted before processing each part
|
|
||||||
if (abortController.signal.aborted) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle prompt progress if available
|
|
||||||
if ('prompt_progress' in part && part.prompt_progress) {
|
|
||||||
// Force immediate state update to ensure we see intermediate values
|
|
||||||
flushSync(() => {
|
|
||||||
updatePromptProgress(part.prompt_progress)
|
|
||||||
})
|
|
||||||
// Add a small delay to make progress visible
|
|
||||||
await new Promise((resolve) => setTimeout(resolve, 100))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error message
|
|
||||||
if (!part.choices) {
|
|
||||||
throw new Error(
|
|
||||||
'message' in part
|
|
||||||
? (part.message as string)
|
|
||||||
: (JSON.stringify(part) ?? '')
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ('usage' in part && part.usage) {
|
|
||||||
tokenUsage = part.usage
|
|
||||||
}
|
|
||||||
|
|
||||||
if (part.choices[0]?.delta?.tool_calls) {
|
|
||||||
extractToolCall(part, currentCall, toolCalls)
|
|
||||||
// Schedule a flush to reflect tool update
|
|
||||||
scheduleFlush()
|
|
||||||
}
|
|
||||||
const deltaReasoning =
|
|
||||||
reasoningProcessor.processReasoningChunk(part)
|
|
||||||
if (deltaReasoning) {
|
|
||||||
accumulatedText += deltaReasoning
|
|
||||||
pendingDeltaCount += 1
|
|
||||||
// Schedule flush for reasoning updates
|
|
||||||
scheduleFlush()
|
|
||||||
}
|
|
||||||
const deltaContent = part.choices[0]?.delta?.content || ''
|
|
||||||
if (deltaContent) {
|
|
||||||
accumulatedText += deltaContent
|
|
||||||
pendingDeltaCount += 1
|
|
||||||
// Batch UI update on next animation frame
|
|
||||||
scheduleFlush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
// Always clean up scheduled RAF when stream ends (either normally or via abort)
|
|
||||||
if (rafHandle !== undefined) {
|
|
||||||
if (typeof cancelAnimationFrame !== 'undefined') {
|
|
||||||
cancelAnimationFrame(rafHandle)
|
|
||||||
} else {
|
|
||||||
clearTimeout(rafHandle)
|
|
||||||
}
|
|
||||||
rafHandle = undefined
|
|
||||||
rafScheduled = false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only finalize and flush if not aborted
|
|
||||||
if (!abortController.signal.aborted) {
|
|
||||||
// Finalize reasoning (close any open think tags)
|
|
||||||
accumulatedText += reasoningProcessor.finalize()
|
|
||||||
// Ensure any pending buffered content is rendered at the end
|
|
||||||
flushIfPending()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const errorMessage =
|
const errorMessage =
|
||||||
@ -561,7 +653,7 @@ export const useChat = () => {
|
|||||||
}
|
}
|
||||||
// TODO: Remove this check when integrating new llama.cpp extension
|
// TODO: Remove this check when integrating new llama.cpp extension
|
||||||
if (
|
if (
|
||||||
accumulatedText.length === 0 &&
|
accumulatedTextRef.value.length === 0 &&
|
||||||
toolCalls.length === 0 &&
|
toolCalls.length === 0 &&
|
||||||
activeThread.model?.id &&
|
activeThread.model?.id &&
|
||||||
activeProvider?.provider === 'llamacpp'
|
activeProvider?.provider === 'llamacpp'
|
||||||
@ -573,16 +665,29 @@ export const useChat = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a final content object for adding to the thread
|
// Create a final content object for adding to the thread
|
||||||
const finalContent = newAssistantThreadContent(
|
let finalContent = newAssistantThreadContent(
|
||||||
activeThread.id,
|
activeThread.id,
|
||||||
accumulatedText,
|
accumulatedTextRef.value,
|
||||||
{
|
{
|
||||||
tokenSpeed: useAppState.getState().tokenSpeed,
|
tokenSpeed: useAppState.getState().tokenSpeed,
|
||||||
assistant: currentAssistant,
|
assistant: currentAssistant,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
|
// If continuing from a message, preserve the ID and set status to Ready
|
||||||
|
if (continueFromMessageId) {
|
||||||
|
finalContent = {
|
||||||
|
...finalContent,
|
||||||
|
id: continueFromMessageId,
|
||||||
|
status: MessageStatus.Ready,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normal completion flow (abort is handled after loop exits)
|
||||||
|
// Don't add assistant message to builder if continuing - it's already there
|
||||||
|
if (!continueFromMessageId) {
|
||||||
|
builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls)
|
||||||
|
}
|
||||||
const updatedMessage = await postMessageProcessing(
|
const updatedMessage = await postMessageProcessing(
|
||||||
toolCalls,
|
toolCalls,
|
||||||
builder,
|
builder,
|
||||||
@ -592,10 +697,15 @@ export const useChat = () => {
|
|||||||
allowAllMCPPermissions ? undefined : showApprovalModal,
|
allowAllMCPPermissions ? undefined : showApprovalModal,
|
||||||
allowAllMCPPermissions
|
allowAllMCPPermissions
|
||||||
)
|
)
|
||||||
addMessage(updatedMessage ?? finalContent)
|
finalizeMessage(
|
||||||
updateStreamingContent(emptyThreadContent)
|
updatedMessage ?? finalContent,
|
||||||
updatePromptProgress(undefined)
|
addMessage,
|
||||||
updateThreadTimestamp(activeThread.id)
|
updateStreamingContent,
|
||||||
|
updatePromptProgress,
|
||||||
|
updateThreadTimestamp,
|
||||||
|
updateMessage,
|
||||||
|
continueFromMessageId
|
||||||
|
)
|
||||||
|
|
||||||
isCompleted = !toolCalls.length
|
isCompleted = !toolCalls.length
|
||||||
// Do not create agent loop if there is no need for it
|
// Do not create agent loop if there is no need for it
|
||||||
@ -605,8 +715,109 @@ export const useChat = () => {
|
|||||||
availableTools = []
|
availableTools = []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IMPORTANT: Check if aborted AFTER the while loop exits
|
||||||
|
// The while loop exits when abort is true, so we handle it here
|
||||||
|
// Only save interrupted messages for llamacpp provider
|
||||||
|
// Other providers (OpenAI, Claude, etc.) handle streaming differently
|
||||||
|
if (
|
||||||
|
abortController.signal.aborted &&
|
||||||
|
accumulatedTextRef.value.length > 0 &&
|
||||||
|
activeProvider?.provider === 'llamacpp'
|
||||||
|
) {
|
||||||
|
// If continuing, update the existing message; otherwise add new
|
||||||
|
if (continueFromMessageId && continueFromMessage) {
|
||||||
|
// Preserve the original message metadata
|
||||||
|
updateMessage({
|
||||||
|
...continueFromMessage,
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: ContentType.Text,
|
||||||
|
text: {
|
||||||
|
value: accumulatedTextRef.value,
|
||||||
|
annotations: [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {
|
||||||
|
...continueFromMessage.metadata,
|
||||||
|
tokenSpeed: useAppState.getState().tokenSpeed,
|
||||||
|
assistant: currentAssistant,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// Create final content for the partial message with Stopped status
|
||||||
|
const partialContent = {
|
||||||
|
...newAssistantThreadContent(
|
||||||
|
activeThread.id,
|
||||||
|
accumulatedTextRef.value,
|
||||||
|
{
|
||||||
|
tokenSpeed: useAppState.getState().tokenSpeed,
|
||||||
|
assistant: currentAssistant,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
}
|
||||||
|
addMessage(partialContent)
|
||||||
|
}
|
||||||
|
updatePromptProgress(undefined)
|
||||||
|
updateThreadTimestamp(activeThread.id)
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (!abortController.signal.aborted) {
|
// If aborted, save the partial message even though an error occurred
|
||||||
|
// Only save for llamacpp provider - other providers handle streaming differently
|
||||||
|
const streamingContent = useAppState.getState().streamingContent
|
||||||
|
const hasPartialContent = accumulatedTextRef.value.length > 0 ||
|
||||||
|
(streamingContent && streamingContent.content?.[0]?.text?.value)
|
||||||
|
|
||||||
|
if (
|
||||||
|
abortController.signal.aborted &&
|
||||||
|
hasPartialContent &&
|
||||||
|
activeProvider?.provider === 'llamacpp'
|
||||||
|
) {
|
||||||
|
// Use streaming content if available, otherwise use accumulatedTextRef
|
||||||
|
const contentText = streamingContent?.content?.[0]?.text?.value || accumulatedTextRef.value
|
||||||
|
|
||||||
|
// If continuing, update the existing message; otherwise add new
|
||||||
|
if (continueFromMessageId && continueFromMessage) {
|
||||||
|
// Preserve the original message metadata
|
||||||
|
updateMessage({
|
||||||
|
...continueFromMessage,
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: ContentType.Text,
|
||||||
|
text: {
|
||||||
|
value: contentText,
|
||||||
|
annotations: [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {
|
||||||
|
...continueFromMessage.metadata,
|
||||||
|
tokenSpeed: useAppState.getState().tokenSpeed,
|
||||||
|
assistant: currentAssistant,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
const partialContent = {
|
||||||
|
...newAssistantThreadContent(
|
||||||
|
activeThread.id,
|
||||||
|
contentText,
|
||||||
|
{
|
||||||
|
tokenSpeed: useAppState.getState().tokenSpeed,
|
||||||
|
assistant: currentAssistant,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
}
|
||||||
|
addMessage(partialContent)
|
||||||
|
}
|
||||||
|
updatePromptProgress(undefined)
|
||||||
|
updateThreadTimestamp(activeThread.id)
|
||||||
|
} else if (!abortController.signal.aborted) {
|
||||||
|
// Only show error if not aborted
|
||||||
if (error && typeof error === 'object' && 'message' in error) {
|
if (error && typeof error === 'object' && 'message' in error) {
|
||||||
setModelLoadError(error as ErrorObject)
|
setModelLoadError(error as ErrorObject)
|
||||||
} else {
|
} else {
|
||||||
@ -628,12 +839,14 @@ export const useChat = () => {
|
|||||||
updateStreamingContent,
|
updateStreamingContent,
|
||||||
updatePromptProgress,
|
updatePromptProgress,
|
||||||
addMessage,
|
addMessage,
|
||||||
|
updateMessage,
|
||||||
updateThreadTimestamp,
|
updateThreadTimestamp,
|
||||||
updateLoadingModel,
|
updateLoadingModel,
|
||||||
getDisabledToolsForThread,
|
getDisabledToolsForThread,
|
||||||
allowAllMCPPermissions,
|
allowAllMCPPermissions,
|
||||||
showApprovalModal,
|
showApprovalModal,
|
||||||
updateTokenSpeed,
|
updateTokenSpeed,
|
||||||
|
setTokenSpeed,
|
||||||
showIncreaseContextSizeModal,
|
showIncreaseContextSizeModal,
|
||||||
increaseModelContextSize,
|
increaseModelContextSize,
|
||||||
toggleOnContextShifting,
|
toggleOnContextShifting,
|
||||||
|
|||||||
@ -8,6 +8,7 @@ type MessageState = {
|
|||||||
getMessages: (threadId: string) => ThreadMessage[]
|
getMessages: (threadId: string) => ThreadMessage[]
|
||||||
setMessages: (threadId: string, messages: ThreadMessage[]) => void
|
setMessages: (threadId: string, messages: ThreadMessage[]) => void
|
||||||
addMessage: (message: ThreadMessage) => void
|
addMessage: (message: ThreadMessage) => void
|
||||||
|
updateMessage: (message: ThreadMessage) => void
|
||||||
deleteMessage: (threadId: string, messageId: string) => void
|
deleteMessage: (threadId: string, messageId: string) => void
|
||||||
clearAllMessages: () => void
|
clearAllMessages: () => void
|
||||||
}
|
}
|
||||||
@ -40,16 +41,52 @@ export const useMessages = create<MessageState>()((set, get) => ({
|
|||||||
assistant: selectedAssistant,
|
assistant: selectedAssistant,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
getServiceHub().messages().createMessage(newMessage).then((createdMessage) => {
|
|
||||||
set((state) => ({
|
// Optimistically update state immediately for instant UI feedback
|
||||||
messages: {
|
set((state) => ({
|
||||||
...state.messages,
|
messages: {
|
||||||
[message.thread_id]: [
|
...state.messages,
|
||||||
...(state.messages[message.thread_id] || []),
|
[message.thread_id]: [
|
||||||
createdMessage,
|
...(state.messages[message.thread_id] || []),
|
||||||
],
|
newMessage,
|
||||||
},
|
],
|
||||||
}))
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Persist to storage asynchronously
|
||||||
|
getServiceHub().messages().createMessage(newMessage).catch((error) => {
|
||||||
|
console.error('Failed to persist message:', error)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
updateMessage: (message) => {
|
||||||
|
const assistants = useAssistant.getState().assistants
|
||||||
|
const currentAssistant = useAssistant.getState().currentAssistant
|
||||||
|
|
||||||
|
const selectedAssistant =
|
||||||
|
assistants.find((a) => a.id === currentAssistant?.id) || assistants[0]
|
||||||
|
|
||||||
|
const updatedMessage = {
|
||||||
|
...message,
|
||||||
|
metadata: {
|
||||||
|
...message.metadata,
|
||||||
|
assistant: selectedAssistant,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optimistically update state immediately for instant UI feedback
|
||||||
|
set((state) => ({
|
||||||
|
messages: {
|
||||||
|
...state.messages,
|
||||||
|
[message.thread_id]: (state.messages[message.thread_id] || []).map((m) =>
|
||||||
|
m.id === message.id ? updatedMessage : m
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Persist to storage asynchronously using modifyMessage instead of createMessage
|
||||||
|
// to prevent duplicates when updating existing messages
|
||||||
|
getServiceHub().messages().modifyMessage(updatedMessage).catch((error) => {
|
||||||
|
console.error('Failed to persist message update:', error)
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
deleteMessage: (threadId, messageId) => {
|
deleteMessage: (threadId, messageId) => {
|
||||||
|
|||||||
@ -23,8 +23,9 @@ export const useTools = () => {
|
|||||||
updateTools(data)
|
updateTools(data)
|
||||||
|
|
||||||
// Initialize default disabled tools for new users (only once)
|
// Initialize default disabled tools for new users (only once)
|
||||||
if (!isDefaultsInitialized() && data.length > 0 && mcpExtension?.getDefaultDisabledTools) {
|
const mcpExt = mcpExtension as MCPExtension & { getDefaultDisabledTools?: () => Promise<string[]> }
|
||||||
const defaultDisabled = await mcpExtension.getDefaultDisabledTools()
|
if (!isDefaultsInitialized() && data.length > 0 && mcpExt?.getDefaultDisabledTools) {
|
||||||
|
const defaultDisabled = await mcpExt.getDefaultDisabledTools()
|
||||||
if (defaultDisabled.length > 0) {
|
if (defaultDisabled.length > 0) {
|
||||||
setDefaultDisabledTools(defaultDisabled)
|
setDefaultDisabledTools(defaultDisabled)
|
||||||
markDefaultsAsInitialized()
|
markDefaultsAsInitialized()
|
||||||
|
|||||||
@ -135,6 +135,7 @@
|
|||||||
"enterApiKey": "Enter API Key",
|
"enterApiKey": "Enter API Key",
|
||||||
"scrollToBottom": "Scroll to bottom",
|
"scrollToBottom": "Scroll to bottom",
|
||||||
"generateAiResponse": "Generate AI Response",
|
"generateAiResponse": "Generate AI Response",
|
||||||
|
"continueAiResponse": "Continue with AI Response",
|
||||||
"addModel": {
|
"addModel": {
|
||||||
"title": "Add Model",
|
"title": "Add Model",
|
||||||
"modelId": "Model ID",
|
"modelId": "Model ID",
|
||||||
|
|||||||
@ -40,6 +40,20 @@ export class DefaultMessagesService implements MessagesService {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async modifyMessage(message: ThreadMessage): Promise<ThreadMessage> {
|
||||||
|
// Don't modify messages on server for temporary chat - it's local only
|
||||||
|
if (message.thread_id === TEMPORARY_CHAT_ID) {
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
ExtensionManager.getInstance()
|
||||||
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
|
?.modifyMessage(message)
|
||||||
|
?.catch(() => message) ?? message
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
async deleteMessage(threadId: string, messageId: string): Promise<void> {
|
async deleteMessage(threadId: string, messageId: string): Promise<void> {
|
||||||
// Don't delete messages on server for temporary chat - it's local only
|
// Don't delete messages on server for temporary chat - it's local only
|
||||||
if (threadId === TEMPORARY_CHAT_ID) {
|
if (threadId === TEMPORARY_CHAT_ID) {
|
||||||
|
|||||||
@ -7,5 +7,6 @@ import { ThreadMessage } from '@janhq/core'
|
|||||||
export interface MessagesService {
|
export interface MessagesService {
|
||||||
fetchMessages(threadId: string): Promise<ThreadMessage[]>
|
fetchMessages(threadId: string): Promise<ThreadMessage[]>
|
||||||
createMessage(message: ThreadMessage): Promise<ThreadMessage>
|
createMessage(message: ThreadMessage): Promise<ThreadMessage>
|
||||||
|
modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
|
||||||
deleteMessage(threadId: string, messageId: string): Promise<void>
|
deleteMessage(threadId: string, messageId: string): Promise<void>
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user