fix: generate response button disappear on tool call (#5988)

* fix: generate a response button should appear when an incomplete tool call message is present

* fix: wording

* fix: do not send duplicate messages on regenerating

* fix: tests
This commit is contained in:
Louis 2025-07-30 21:04:12 +07:00 committed by GitHub
parent f58d745585
commit 76bcf33f80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 159 additions and 92 deletions

View File

@ -52,7 +52,7 @@ export default function LoadModelErrorDialog() {
<div> <div>
<DialogTitle>{t('common:error')}</DialogTitle> <DialogTitle>{t('common:error')}</DialogTitle>
<DialogDescription className="mt-1 text-main-view-fg/70"> <DialogDescription className="mt-1 text-main-view-fg/70">
Failed to load model Something went wrong
</DialogDescription> </DialogDescription>
</div> </div>
</div> </div>

View File

@ -247,8 +247,7 @@ export const useChat = () => {
messages, messages,
currentAssistant?.instructions currentAssistant?.instructions
) )
if (troubleshooting) builder.addUserMessage(message)
builder.addUserMessage(message)
let isCompleted = false let isCompleted = false

View File

@ -1,3 +1,4 @@
import { describe, it, expect } from 'vitest'
import { CompletionMessagesBuilder } from '../messages' import { CompletionMessagesBuilder } from '../messages'
import { ThreadMessage } from '@janhq/core' import { ThreadMessage } from '@janhq/core'
import { ChatCompletionMessageToolCall } from 'openai/resources' import { ChatCompletionMessageToolCall } from 'openai/resources'
@ -66,7 +67,10 @@ describe('CompletionMessagesBuilder', () => {
it('should normalize assistant message content', () => { it('should normalize assistant message content', () => {
const messages: ThreadMessage[] = [ const messages: ThreadMessage[] = [
createMockThreadMessage('assistant', '<think>Let me think...</think>Hello there!'), createMockThreadMessage(
'assistant',
'<think>Let me think...</think>Hello there!'
),
] ]
const builder = new CompletionMessagesBuilder(messages) const builder = new CompletionMessagesBuilder(messages)
@ -78,14 +82,19 @@ describe('CompletionMessagesBuilder', () => {
it('should preserve user message content without normalization', () => { it('should preserve user message content without normalization', () => {
const messages: ThreadMessage[] = [ const messages: ThreadMessage[] = [
createMockThreadMessage('user', '<think>This should not be normalized</think>Hello'), createMockThreadMessage(
'user',
'<think>This should not be normalized</think>Hello'
),
] ]
const builder = new CompletionMessagesBuilder(messages) const builder = new CompletionMessagesBuilder(messages)
const result = builder.getMessages() const result = builder.getMessages()
expect(result).toHaveLength(1) expect(result).toHaveLength(1)
expect(result[0].content).toBe('<think>This should not be normalized</think>Hello') expect(result[0].content).toBe(
'<think>This should not be normalized</think>Hello'
)
}) })
it('should handle messages with empty content', () => { it('should handle messages with empty content', () => {
@ -104,7 +113,9 @@ describe('CompletionMessagesBuilder', () => {
it('should handle messages with missing text value', () => { it('should handle messages with missing text value', () => {
const message: ThreadMessage = { const message: ThreadMessage = {
...createMockThreadMessage('user', ''), ...createMockThreadMessage('user', ''),
content: [{ type: 'text' as any, text: { value: '', annotations: [] } }], content: [
{ type: 'text' as any, text: { value: '', annotations: [] } },
],
} }
const builder = new CompletionMessagesBuilder([message]) const builder = new CompletionMessagesBuilder([message])
@ -129,16 +140,15 @@ describe('CompletionMessagesBuilder', () => {
}) })
}) })
it('should add multiple user messages', () => { it('should not add consecutive user messages', () => {
const builder = new CompletionMessagesBuilder([]) const builder = new CompletionMessagesBuilder([])
builder.addUserMessage('First message') builder.addUserMessage('First message')
builder.addUserMessage('Second message') builder.addUserMessage('Second message')
const result = builder.getMessages() const result = builder.getMessages()
expect(result).toHaveLength(2) expect(result).toHaveLength(1)
expect(result[0].content).toBe('First message') expect(result[0].content).toBe('Second message')
expect(result[1].content).toBe('Second message')
}) })
it('should handle empty user message', () => { it('should handle empty user message', () => {
@ -171,7 +181,10 @@ describe('CompletionMessagesBuilder', () => {
it('should add assistant message with refusal', () => { it('should add assistant message with refusal', () => {
const builder = new CompletionMessagesBuilder([]) const builder = new CompletionMessagesBuilder([])
builder.addAssistantMessage('I cannot help with that', 'Content policy violation') builder.addAssistantMessage(
'I cannot help with that',
'Content policy violation'
)
const result = builder.getMessages() const result = builder.getMessages()
expect(result).toHaveLength(1) expect(result).toHaveLength(1)
@ -196,7 +209,11 @@ describe('CompletionMessagesBuilder', () => {
}, },
] ]
builder.addAssistantMessage('Let me check the weather', undefined, toolCalls) builder.addAssistantMessage(
'Let me check the weather',
undefined,
toolCalls
)
const result = builder.getMessages() const result = builder.getMessages()
expect(result).toHaveLength(1) expect(result).toHaveLength(1)
@ -282,19 +299,21 @@ describe('CompletionMessagesBuilder', () => {
const threadMessages: ThreadMessage[] = [ const threadMessages: ThreadMessage[] = [
createMockThreadMessage('user', 'Hello'), createMockThreadMessage('user', 'Hello'),
] ]
const builder = new CompletionMessagesBuilder(threadMessages, 'You are helpful') const builder = new CompletionMessagesBuilder(
threadMessages,
'You are helpful'
)
builder.addUserMessage('How are you?') builder.addUserMessage('How are you?')
builder.addAssistantMessage('I am well, thank you!') builder.addAssistantMessage('I am well, thank you!')
builder.addToolMessage('Tool response', 'call_123') builder.addToolMessage('Tool response', 'call_123')
const result = builder.getMessages() const result = builder.getMessages()
expect(result).toHaveLength(5) expect(result).toHaveLength(4)
expect(result[0].role).toBe('system') expect(result[0].role).toBe('system')
expect(result[1].role).toBe('user') expect(result[1].role).toBe('user')
expect(result[2].role).toBe('user') expect(result[2].role).toBe('assistant')
expect(result[3].role).toBe('assistant') expect(result[3].role).toBe('tool')
expect(result[4].role).toBe('tool')
}) })
it('should return the same array reference (not immutable)', () => { it('should return the same array reference (not immutable)', () => {
@ -317,7 +336,9 @@ describe('CompletionMessagesBuilder', () => {
it('should remove thinking content from the beginning', () => { it('should remove thinking content from the beginning', () => {
const builder = new CompletionMessagesBuilder([]) const builder = new CompletionMessagesBuilder([])
builder.addAssistantMessage('<think>Let me analyze this...</think>The answer is 42.') builder.addAssistantMessage(
'<think>Let me analyze this...</think>The answer is 42.'
)
const result = builder.getMessages() const result = builder.getMessages()
expect(result[0].content).toBe('The answer is 42.') expect(result[0].content).toBe('The answer is 42.')
@ -326,7 +347,9 @@ describe('CompletionMessagesBuilder', () => {
it('should handle nested thinking tags', () => { it('should handle nested thinking tags', () => {
const builder = new CompletionMessagesBuilder([]) const builder = new CompletionMessagesBuilder([])
builder.addAssistantMessage('<think>First thought<think>Nested</think>More thinking</think>Final answer') builder.addAssistantMessage(
'<think>First thought<think>Nested</think>More thinking</think>Final answer'
)
const result = builder.getMessages() const result = builder.getMessages()
expect(result[0].content).toBe('More thinking</think>Final answer') expect(result[0].content).toBe('More thinking</think>Final answer')
@ -335,7 +358,9 @@ describe('CompletionMessagesBuilder', () => {
it('should handle multiple thinking blocks', () => { it('should handle multiple thinking blocks', () => {
const builder = new CompletionMessagesBuilder([]) const builder = new CompletionMessagesBuilder([])
builder.addAssistantMessage('<think>First</think>Answer<think>Second</think>More content') builder.addAssistantMessage(
'<think>First</think>Answer<think>Second</think>More content'
)
const result = builder.getMessages() const result = builder.getMessages()
expect(result[0].content).toBe('Answer<think>Second</think>More content') expect(result[0].content).toBe('Answer<think>Second</think>More content')
@ -362,16 +387,22 @@ describe('CompletionMessagesBuilder', () => {
it('should handle unclosed thinking tags', () => { it('should handle unclosed thinking tags', () => {
const builder = new CompletionMessagesBuilder([]) const builder = new CompletionMessagesBuilder([])
builder.addAssistantMessage('<think>Unclosed thinking tag... Regular content') builder.addAssistantMessage(
'<think>Unclosed thinking tag... Regular content'
)
const result = builder.getMessages() const result = builder.getMessages()
expect(result[0].content).toBe('<think>Unclosed thinking tag... Regular content') expect(result[0].content).toBe(
'<think>Unclosed thinking tag... Regular content'
)
}) })
it('should handle thinking tags with whitespace', () => { it('should handle thinking tags with whitespace', () => {
const builder = new CompletionMessagesBuilder([]) const builder = new CompletionMessagesBuilder([])
builder.addAssistantMessage('<think> \n Some thinking \n </think> \n Clean answer') builder.addAssistantMessage(
'<think> \n Some thinking \n </think> \n Clean answer'
)
const result = builder.getMessages() const result = builder.getMessages()
expect(result[0].content).toBe('Clean answer') expect(result[0].content).toBe('Clean answer')
@ -382,10 +413,16 @@ describe('CompletionMessagesBuilder', () => {
it('should handle complex conversation flow', () => { it('should handle complex conversation flow', () => {
const threadMessages: ThreadMessage[] = [ const threadMessages: ThreadMessage[] = [
createMockThreadMessage('user', 'What is the weather like?'), createMockThreadMessage('user', 'What is the weather like?'),
createMockThreadMessage('assistant', '<think>I need to call weather API</think>Let me check the weather for you.'), createMockThreadMessage(
'assistant',
'<think>I need to call weather API</think>Let me check the weather for you.'
),
] ]
const builder = new CompletionMessagesBuilder(threadMessages, 'You are a weather assistant') const builder = new CompletionMessagesBuilder(
threadMessages,
'You are a weather assistant'
)
// Add tool call and response // Add tool call and response
const toolCalls: ChatCompletionMessageToolCall[] = [ const toolCalls: ChatCompletionMessageToolCall[] = [
@ -399,9 +436,18 @@ describe('CompletionMessagesBuilder', () => {
}, },
] ]
builder.addAssistantMessage('Calling weather service...', undefined, toolCalls) builder.addAssistantMessage(
builder.addToolMessage('{"temperature": 72, "condition": "sunny"}', 'call_weather') 'Calling weather service...',
builder.addAssistantMessage('<think>The weather is nice</think>The weather is 72°F and sunny!') undefined,
toolCalls
)
builder.addToolMessage(
'{"temperature": 72, "condition": "sunny"}',
'call_weather'
)
builder.addAssistantMessage(
'<think>The weather is nice</think>The weather is 72°F and sunny!'
)
const result = builder.getMessages() const result = builder.getMessages()

View File

@ -26,7 +26,7 @@ export class CompletionMessagesBuilder {
content: content:
msg.role === 'assistant' msg.role === 'assistant'
? this.normalizeContent(msg.content[0]?.text?.value || '.') ? this.normalizeContent(msg.content[0]?.text?.value || '.')
: (msg.content[0]?.text?.value || '.'), : msg.content[0]?.text?.value || '.',
}) as ChatCompletionMessageParam }) as ChatCompletionMessageParam
) )
) )
@ -37,6 +37,10 @@ export class CompletionMessagesBuilder {
* @param content - The content of the user message. * @param content - The content of the user message.
*/ */
addUserMessage(content: string) { addUserMessage(content: string) {
// Ensure no consecutive user messages
if (this.messages[this.messages.length - 1]?.role === 'user') {
this.messages.pop()
}
this.messages.push({ this.messages.push({
role: 'user', role: 'user',
content: content, content: content,

View File

@ -39,7 +39,7 @@ function ThreadDetail() {
const lastScrollTopRef = useRef(0) const lastScrollTopRef = useRef(0)
const { currentThreadId, setCurrentThreadId } = useThreads() const { currentThreadId, setCurrentThreadId } = useThreads()
const { setCurrentAssistant, assistants } = useAssistant() const { setCurrentAssistant, assistants } = useAssistant()
const { setMessages } = useMessages() const { setMessages, deleteMessage } = useMessages()
const { streamingContent } = useAppState() const { streamingContent } = useAppState()
const { appMainViewBgColor, chatWidth } = useAppearance() const { appMainViewBgColor, chatWidth } = useAppearance()
const { sendMessage } = useChat() const { sendMessage } = useChat()
@ -221,8 +221,23 @@ function ThreadDetail() {
// used when there is a sent/added user message and no assistant message (error or manual deletion) // used when there is a sent/added user message and no assistant message (error or manual deletion)
const generateAIResponse = () => { const generateAIResponse = () => {
const latestUserMessage = messages[messages.length - 1] const latestUserMessage = messages[messages.length - 1]
if (latestUserMessage?.content?.[0]?.text?.value) { if (
latestUserMessage?.content?.[0]?.text?.value &&
latestUserMessage.role === 'user'
) {
sendMessage(latestUserMessage.content[0].text.value, false) sendMessage(latestUserMessage.content[0].text.value, false)
} else if (latestUserMessage?.metadata?.tool_calls) {
// Only regenerate assistant message is allowed
const threadMessages = [...messages]
let toSendMessage = threadMessages.pop()
while (toSendMessage && toSendMessage?.role !== 'user') {
deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '')
toSendMessage = threadMessages.pop()
}
if (toSendMessage) {
deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '')
sendMessage(toSendMessage.content?.[0]?.text?.value || '')
}
} }
} }
@ -232,7 +247,10 @@ function ThreadDetail() {
const showScrollToBottomBtn = !isAtBottom && hasScrollbar const showScrollToBottomBtn = !isAtBottom && hasScrollbar
const showGenerateAIResponseBtn = const showGenerateAIResponseBtn =
messages[messages.length - 1]?.role === 'user' && !streamingContent (messages[messages.length - 1]?.role === 'user' ||
(messages[messages.length - 1]?.metadata &&
'tool_calls' in (messages[messages.length - 1].metadata ?? {}))) &&
!streamingContent
return ( return (
<div className="flex flex-col h-full"> <div className="flex flex-col h-full">