Merge pull request #4900 from menloresearch/feat/jan-ui-with-tool-use

feat: jan UI with Tool use UX
Louis 2025-04-14 15:23:31 +07:00 committed by GitHub
parent a8e418c4d3
commit 57786e5e45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 485 additions and 232 deletions

View File

@@ -7,6 +7,7 @@ export enum ChatCompletionRole {
System = 'system',
Assistant = 'assistant',
User = 'user',
Tool = 'tool',
}
/**
@@ -18,6 +19,9 @@ export type ChatCompletionMessage = {
content?: ChatCompletionMessageContent
/** The role of the author of this message. **/
role: ChatCompletionRole
type?: string
output?: string
tool_call_id?: string
}
export type ChatCompletionMessageContent =
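For illustration, a tool-result message under the extended type could look like this (a minimal sketch; the import path and field values are assumptions, not part of the diff):

```typescript
import { ChatCompletionMessage, ChatCompletionRole } from '@janhq/core'

// Hypothetical tool-result message: the new Tool role carries a tool's
// output back to the model, linked to the originating call by tool_call_id.
const toolResult: ChatCompletionMessage = {
  role: ChatCompletionRole.Tool,
  content: '31°C and sunny',
  tool_call_id: 'call_abc123', // echoes the id from the assistant's tool_calls
}
```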

View File

@@ -36,6 +36,8 @@ export type ThreadMessage = {
type?: string
/** The error code which explains the error type. Used in conjunction with MessageStatus.Error */
error_code?: ErrorCode
tool_call_id?: string
}
/**
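Rather than emitting separate messages per tool call, this PR stores tool activity in message metadata; an illustrative fragment matching what `postMessageProcessing` writes later in this diff (values hypothetical):

```typescript
// Illustrative metadata entry for one tool call; state moves
// 'pending' -> 'ready' once the MCP tool returns.
const metadata = {
  tool_calls: [
    {
      tool: {
        id: 'tool-call-1', // ulid assigned per call
        type: 'function',
        function: { name: 'get_weather', arguments: '{"city":"Hanoi"}' },
      },
      response: undefined, // filled with the tool result when ready
      state: 'pending',
    },
  ],
}
```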

View File

@@ -114,7 +114,7 @@ export default function ModelHandler() {
const onNewMessageResponse = useCallback(
async (message: ThreadMessage) => {
if (message.type === MessageRequestType.Thread) {
if (message.type !== MessageRequestType.Summary) {
addNewMessage(message)
}
},
@@ -129,35 +129,20 @@ export default function ModelHandler() {
const updateThreadTitle = useCallback(
(message: ThreadMessage) => {
// Update only when it's finished
if (message.status !== MessageStatus.Ready) {
return
}
if (message.status !== MessageStatus.Ready) return
const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
if (!thread) {
console.warn(
`Failed to update title for thread ${message.thread_id}: Thread not found!`
)
return
}
let messageContent = message.content[0]?.text?.value
if (!messageContent) {
console.warn(
`Failed to update title for thread ${message.thread_id}: Responded content is null!`
)
return
}
if (!thread || !messageContent) return
// No newline characters should be present in the title,
// and non-alphanumeric characters should be removed
if (messageContent.includes('\n')) {
if (messageContent.includes('\n'))
messageContent = messageContent.replace(/\n/g, ' ')
}
const match = messageContent.match(/<\/think>(.*)$/)
if (match) {
messageContent = match[1]
}
if (match) messageContent = match[1]
// Remove non-alphanumeric characters
const cleanedMessageContent = messageContent
.replace(/[^\p{L}\s]+/gu, '')
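Traced on a hypothetical response, the cleanup turns raw model output into a plain title (values illustrative):

```typescript
// Hypothetical walk-through of the title cleanup above:
let messageContent = '<think>reasoning</think>Hello,\nWorld! 🎉'
messageContent = messageContent.replace(/\n/g, ' ') // drop newlines
const match = messageContent.match(/<\/think>(.*)$/) // keep text after </think>
if (match) messageContent = match[1]
const cleaned = messageContent.replace(/[^\p{L}\s]+/gu, '') // strip non-letters
// cleaned === 'Hello World ' (punctuation and the emoji removed)
```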
@@ -193,18 +178,13 @@ export default function ModelHandler() {
const updateThreadMessage = useCallback(
(message: ThreadMessage) => {
if (
messageGenerationSubscriber.current &&
message.thread_id === activeThreadRef.current?.id &&
!messageGenerationSubscriber.current!.thread_id
) {
updateMessage(
message.id,
message.thread_id,
message.content,
message.status
)
}
updateMessage(
message.id,
message.thread_id,
message.content,
message.metadata,
message.status
)
if (message.status === MessageStatus.Pending) {
if (message.content.length) {
@@ -243,16 +223,19 @@
engines &&
isLocalEngine(engines, activeModelRef.current.engine)
) {
;(async () => {
if (
!(await extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.isModelLoaded(activeModelRef.current?.id as string))
) {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: undefined })
}
})()
extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.isModelLoaded(activeModelRef.current?.id as string)
.then((isLoaded) => {
if (!isLoaded) {
setActiveModel(undefined)
setStateModel({
state: 'start',
loading: false,
model: undefined,
})
}
})
}
// Mark the thread as not waiting for response
updateThreadWaiting(message.thread_id, false)
@@ -296,19 +279,10 @@
error_code: message.error_code,
}
}
;(async () => {
const updatedMessage = await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.createMessage(message)
.catch(() => undefined)
if (updatedMessage) {
deleteMessage(message.id)
addNewMessage(updatedMessage)
setTokenSpeed((prev) =>
prev ? { ...prev, message: updatedMessage.id } : undefined
)
}
})()
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.createMessage(message)
// Attempt to generate the title of the Thread when needed
generateThreadTitle(message, thread)
@@ -319,25 +293,21 @@
const onMessageResponseUpdate = useCallback(
(message: ThreadMessage) => {
switch (message.type) {
case MessageRequestType.Summary:
updateThreadTitle(message)
break
default:
updateThreadMessage(message)
break
}
if (message.type === MessageRequestType.Summary)
updateThreadTitle(message)
else updateThreadMessage(message)
},
[updateThreadMessage, updateThreadTitle]
)
const generateThreadTitle = (message: ThreadMessage, thread: Thread) => {
// If this is the first ever prompt in the thread
if ((thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle)
if (
!activeModelRef.current ||
(thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle
)
return
if (!activeModelRef.current) return
// Check the model engine; we don't want to generate a title unless it's a local engine. Remote models use the first prompt instead.
if (
activeModelRef.current?.engine !== InferenceEngine.cortex &&

View File

@@ -165,6 +165,7 @@ export const updateMessageAtom = atom(
id: string,
conversationId: string,
text: ThreadContent[],
metadata: Record<string, unknown> | undefined,
status: MessageStatus
) => {
const messages = get(chatMessages)[conversationId] ?? []
@@ -172,6 +173,7 @@
if (message) {
message.content = text
message.status = status
message.metadata = metadata
const updatedMessages = [...messages]
const newData: Record<string, ThreadMessage[]> = {
@@ -192,6 +194,7 @@
created_at: Date.now() / 1000,
completed_at: Date.now() / 1000,
object: 'thread.message',
metadata: metadata,
})
}
}
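A hypothetical call site for the widened setter, with the new `metadata` argument in position (names from the diff, values illustrative):

```typescript
import { useSetAtom } from 'jotai'

// metadata (e.g. tool_calls) now travels with every content/status update:
const updateMessage = useSetAtom(updateMessageAtom)
updateMessage(
  message.id,
  message.thread_id,
  message.content,
  { tool_calls: [] }, // metadata (may be undefined)
  MessageStatus.Pending
)
```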

View File

@@ -26,6 +26,7 @@ import {
ChatCompletionTool,
} from 'openai/resources/chat'
import { Stream } from 'openai/streaming'
import { ulid } from 'ulidx'
import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
@@ -258,111 +259,63 @@ export default function useSendChatMessage() {
baseURL: `${API_BASE_URL}/v1`,
dangerouslyAllowBrowser: true,
})
let parentMessageId: string | undefined
while (!isDone) {
let messageId = ulid()
if (!parentMessageId) {
parentMessageId = ulid()
messageId = parentMessageId
}
const data = requestBuilder.build()
const message: ThreadMessage = {
id: messageId,
object: 'message',
thread_id: activeThreadRef.current.id,
assistant_id: activeAssistantRef.current.assistant_id,
role: ChatCompletionRole.Assistant,
content: [],
metadata: {
...(messageId !== parentMessageId
? { parent_id: parentMessageId }
: {}),
},
status: MessageStatus.Pending,
created_at: Date.now() / 1000,
completed_at: Date.now() / 1000,
}
events.emit(MessageEvent.OnMessageResponse, message)
const response = await openai.chat.completions.create({
messages: (data.messages ?? []).map((e) => {
return {
role: e.role as OpenAIChatCompletionRole,
content: e.content,
}
}) as ChatCompletionMessageParam[],
messages: requestBuilder.messages as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
stream: false,
stream: data.model?.parameters?.stream ?? false,
tool_choice: 'auto',
})
if (response.choices[0]?.message.content) {
const newMessage: ThreadMessage = {
id: ulid(),
object: 'message',
thread_id: activeThreadRef.current.id,
assistant_id: activeAssistantRef.current.assistant_id,
attachments: [],
role: response.choices[0].message.role as ChatCompletionRole,
content: [
{
type: ContentType.Text,
text: {
value: response.choices[0].message.content
? response.choices[0].message.content
: '',
annotations: [],
},
// Seed empty text content if the message has none yet
if (!message.content.length) {
message.content = [
{
type: ContentType.Text,
text: {
value: '',
annotations: [],
},
],
status: MessageStatus.Ready,
created_at: Date.now(),
completed_at: Date.now(),
}
requestBuilder.pushAssistantMessage(
response.choices[0].message.content ?? ''
},
]
}
if (data.model?.parameters?.stream)
isDone = await processStreamingResponse(
response as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
requestBuilder,
message
)
else {
isDone = await processNonStreamingResponse(
response as OpenAI.Chat.Completions.ChatCompletion,
requestBuilder,
message
)
events.emit(MessageEvent.OnMessageUpdate, newMessage)
}
if (response.choices[0]?.message.tool_calls) {
for (const toolCall of response.choices[0].message.tool_calls) {
const id = ulid()
const toolMessage: ThreadMessage = {
id: id,
object: 'message',
thread_id: activeThreadRef.current.id,
assistant_id: activeAssistantRef.current.assistant_id,
attachments: [],
role: ChatCompletionRole.Assistant,
content: [
{
type: ContentType.Text,
text: {
value: `<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`,
annotations: [],
},
},
],
status: MessageStatus.Pending,
created_at: Date.now(),
completed_at: Date.now(),
}
events.emit(MessageEvent.OnMessageUpdate, toolMessage)
const result = await window.core.api.callTool({
toolName: toolCall.function.name,
arguments: JSON.parse(toolCall.function.arguments),
})
if (result.error) {
console.error(result.error)
break
}
const message: ThreadMessage = {
id: id,
object: 'message',
thread_id: activeThreadRef.current.id,
assistant_id: activeAssistantRef.current.assistant_id,
attachments: [],
role: ChatCompletionRole.Assistant,
content: [
{
type: ContentType.Text,
text: {
value:
`<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}</think>` +
(result.content[0]?.text ?? ''),
annotations: [],
},
},
],
status: MessageStatus.Ready,
created_at: Date.now(),
completed_at: Date.now(),
}
requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '')
requestBuilder.pushMessage('Go for the next step')
events.emit(MessageEvent.OnMessageUpdate, message)
}
}
isDone =
!response.choices[0]?.message.tool_calls ||
!response.choices[0]?.message.tool_calls.length
}
} else {
// Request for inference
@@ -376,6 +329,184 @@ export default function useSendChatMessage() {
setEngineParamsUpdate(false)
}
const processNonStreamingResponse = async (
response: OpenAI.Chat.Completions.ChatCompletion,
requestBuilder: MessageRequestBuilder,
message: ThreadMessage
): Promise<boolean> => {
// Handle tool calls in the response
const toolCalls: ChatCompletionMessageToolCall[] =
response.choices[0]?.message?.tool_calls ?? []
const content = response.choices[0].message?.content
message.content = [
{
type: ContentType.Text,
text: {
value: content ?? '',
annotations: [],
},
},
]
events.emit(MessageEvent.OnMessageUpdate, message)
await postMessageProcessing(
toolCalls ?? [],
requestBuilder,
message,
content ?? ''
)
return !toolCalls || !toolCalls.length
}
const processStreamingResponse = async (
response: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
requestBuilder: MessageRequestBuilder,
message: ThreadMessage
): Promise<boolean> => {
// Variables to track and accumulate streaming content
let currentToolCall: {
id: string
function: { name: string; arguments: string }
} | null = null
let accumulatedContent = ''
const toolCalls: ChatCompletionMessageToolCall[] = []
// Process the streaming chunks
for await (const chunk of response) {
// Handle tool calls in the chunk
if (chunk.choices[0]?.delta?.tool_calls) {
const deltaToolCalls = chunk.choices[0].delta.tool_calls
// Handle the beginning of a new tool call
if (
deltaToolCalls[0]?.index !== undefined &&
deltaToolCalls[0]?.function
) {
const index = deltaToolCalls[0].index
// Create new tool call if this is the first chunk for it
if (!toolCalls[index]) {
toolCalls[index] = {
id: deltaToolCalls[0]?.id || '',
function: {
name: deltaToolCalls[0]?.function?.name || '',
arguments: deltaToolCalls[0]?.function?.arguments || '',
},
type: 'function',
}
currentToolCall = toolCalls[index]
} else {
// Continuation of existing tool call
currentToolCall = toolCalls[index]
// Append to function name or arguments if they exist in this chunk
if (deltaToolCalls[0]?.function?.name) {
currentToolCall!.function.name += deltaToolCalls[0].function.name
}
if (deltaToolCalls[0]?.function?.arguments) {
currentToolCall!.function.arguments +=
deltaToolCalls[0].function.arguments
}
}
}
}
// Handle regular content in the chunk
if (chunk.choices[0]?.delta?.content) {
const content = chunk.choices[0].delta.content
accumulatedContent += content
message.content = [
{
type: ContentType.Text,
text: {
value: accumulatedContent,
annotations: [],
},
},
]
events.emit(MessageEvent.OnMessageUpdate, message)
}
}
await postMessageProcessing(
toolCalls ?? [],
requestBuilder,
message,
accumulatedContent ?? ''
)
return !toolCalls || !toolCalls.length
}
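For context, OpenAI-compatible servers stream a tool call as fragments keyed by `index`: the first chunk usually carries the `id`, `type`, and function name, later chunks append argument text. A hypothetical delta sequence the accumulator above would stitch together:

```typescript
// Hypothetical tool-call fragments arriving across streamed chunks:
const deltas = [
  {
    index: 0,
    id: 'call_1',
    type: 'function',
    function: { name: 'get_weather', arguments: '' },
  },
  { index: 0, function: { arguments: '{"city":' } },
  { index: 0, function: { arguments: '"Hanoi"}' } },
]
// After accumulation:
// toolCalls[0] = {
//   id: 'call_1',
//   type: 'function',
//   function: { name: 'get_weather', arguments: '{"city":"Hanoi"}' },
// }
```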
const postMessageProcessing = async (
toolCalls: ChatCompletionMessageToolCall[],
requestBuilder: MessageRequestBuilder,
message: ThreadMessage,
content: string
) => {
requestBuilder.pushAssistantMessage({
content,
role: 'assistant',
refusal: null,
tool_calls: toolCalls,
})
// Handle completed tool calls
if (toolCalls.length > 0) {
for (const toolCall of toolCalls) {
const toolId = ulid()
const toolCallsMetadata =
message.metadata?.tool_calls &&
Array.isArray(message.metadata?.tool_calls)
? message.metadata?.tool_calls
: []
message.metadata = {
...(message.metadata ?? {}),
tool_calls: [
...toolCallsMetadata,
{
tool: {
...toolCall,
id: toolId,
},
response: undefined,
state: 'pending',
},
],
}
events.emit(MessageEvent.OnMessageUpdate, message)
const result = await window.core.api.callTool({
toolName: toolCall.function.name,
arguments: JSON.parse(toolCall.function.arguments),
})
if (result.error) break
message.metadata = {
...(message.metadata ?? {}),
tool_calls: [
...toolCallsMetadata,
{
tool: {
...toolCall,
id: toolId,
},
response: result,
state: 'ready',
},
],
}
requestBuilder.pushToolMessage(
result.content[0]?.text ?? '',
toolCall.id
)
events.emit(MessageEvent.OnMessageUpdate, message)
}
}
message.status = MessageStatus.Ready
events.emit(MessageEvent.OnMessageUpdate, message)
}
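Taken together, the helpers drive a simple agent loop: call the model, render content as it arrives, execute requested tools over MCP, push the results back into the request, and repeat until a turn finishes without tool calls. A condensed sketch of the control flow (simplified, not the literal diff code; `openai`, `data`, `requestBuilder`, and `message` refer to names in the surrounding hook):

```typescript
// Condensed sketch of the tool-use loop implemented above:
const streaming = data.model?.parameters?.stream === true
let isDone = false
while (!isDone) {
  const response = await openai.chat.completions.create({
    /* messages, model, tools, stream, tool_choice as in the diff */
  })
  // Each helper emits UI updates, runs tools via window.core.api.callTool,
  // records tool_calls metadata, and returns true once the model answers
  // without requesting any tools.
  isDone = streaming
    ? await processStreamingResponse(response, requestBuilder, message)
    : await processNonStreamingResponse(response, requestBuilder, message)
}
```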
return {
sendChatMessage,
resendChatMessage,

View File

@@ -1,9 +1,8 @@
import { useRef, useState } from 'react'
import { useState } from 'react'
import { Slider, Input, Tooltip } from '@janhq/joi'
import { Slider, Input } from '@janhq/joi'
import { atom, useAtom } from 'jotai'
import { InfoIcon } from 'lucide-react'
export const hubModelSizeMinAtom = atom(0)
export const hubModelSizeMaxAtom = atom(100)

View File

@@ -3,7 +3,15 @@ import { useEffect, useRef, useState } from 'react'
import { InferenceEngine } from '@janhq/core'
import { TextArea, Button, Tooltip, useClickOutside, Badge } from '@janhq/joi'
import {
TextArea,
Button,
Tooltip,
useClickOutside,
Badge,
Modal,
ModalClose,
} from '@janhq/joi'
import { useAtom, useAtomValue } from 'jotai'
import {
FileTextIcon,
@@ -13,6 +21,7 @@ import {
SettingsIcon,
ChevronUpIcon,
Settings2Icon,
WrenchIcon,
} from 'lucide-react'
import { twMerge } from 'tailwind-merge'
@@ -45,6 +54,7 @@ import {
isBlockingSendAtom,
} from '@/helpers/atoms/Thread.atom'
import { activeTabThreadRightPanelAtom } from '@/helpers/atoms/ThreadRightPanel.atom'
import { ModelTool } from '@/types/model'
const ChatInput = () => {
const activeThread = useAtomValue(activeThreadAtom)
@@ -69,6 +79,8 @@ const ChatInput = () => {
const isBlockingSend = useAtomValue(isBlockingSendAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const { stopInference } = useActiveModel()
const [tools, setTools] = useState<ModelTool[]>([])
const [showToolsModal, setShowToolsModal] = useState(false)
const upload = uploader()
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
@@ -92,6 +104,12 @@ const ChatInput = () => {
}
}, [activeSettingInputBox, selectedModel, setActiveSettingInputBox])
useEffect(() => {
window.core?.api?.getTools().then((data: ModelTool[]) => {
setTools(data)
})
}, [])
const onStopInferenceClick = async () => {
stopInference()
}
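`ModelTool` is imported from `@/types/model` but its definition isn't part of this diff; judging by the fields the modal reads below, a plausible shape (an assumption, not the actual type) is:

```typescript
// Assumed shape, inferred from usage (tool.name, tool.description):
interface ModelTool {
  name: string
  description: string
}
```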
@@ -136,6 +154,8 @@
}
}
console.log(tools)
return (
<div className="relative p-4 pb-2">
{renderPreview(fileUpload)}
@@ -385,6 +405,62 @@
className="flex-shrink-0 cursor-pointer text-[hsla(var(--text-secondary))]"
/>
</Badge>
{tools && tools.length > 0 && (
<>
<Badge
theme="secondary"
className={twMerge(
'flex cursor-pointer items-center gap-x-1'
)}
variant={'outline'}
onClick={() => setShowToolsModal(true)}
>
<WrenchIcon
size={16}
className="flex-shrink-0 cursor-pointer text-[hsla(var(--text-secondary))]"
/>
<span className="text-xs">{tools.length}</span>
</Badge>
<Modal
open={showToolsModal}
onOpenChange={setShowToolsModal}
title="Available MCP Tools"
content={
<div className="overflow-y-auto">
<div className="mb-2 py-2 text-sm text-[hsla(var(--text-secondary))]">
Jan can use tools provided by specialized servers using
the Model Context Protocol.{' '}
<a
href="https://modelcontextprotocol.io/introduction"
target="_blank"
className="text-[hsla(var(--app-link))]"
>
Learn more about MCP
</a>
</div>
{tools.map((tool: ModelTool) => (
<div
key={tool.name}
className="flex items-center gap-x-3 px-4 py-3 hover:bg-[hsla(var(--dropdown-menu-hover-bg))]"
>
<WrenchIcon
size={16}
className="flex-shrink-0 text-[hsla(var(--text-secondary))]"
/>
<div>
<div className="font-medium">{tool.name}</div>
<div className="text-sm text-[hsla(var(--text-secondary))]">
{tool.description}
</div>
</div>
</div>
))}
</div>
}
/>
</>
)}
</div>
{selectedModel && (
<Button

View File

@@ -0,0 +1,57 @@
import React from 'react'
import { atom, useAtom } from 'jotai'
import { ChevronDown, ChevronUp, Loader } from 'lucide-react'
import { MarkdownTextMessage } from './MarkdownTextMessage'
interface Props {
result: string
name: string
id: string
loading: boolean
}
const toolCallBlockStateAtom = atom<{ [id: string]: boolean }>({})
const ToolCallBlock = ({ id, name, result, loading }: Props) => {
const [collapseState, setCollapseState] = useAtom(toolCallBlockStateAtom)
const isExpanded = collapseState[id] ?? false
const handleClick = () => {
setCollapseState((prev) => ({ ...prev, [id]: !isExpanded }))
}
return (
<div className="mx-auto w-full">
<div className="mb-4 rounded-lg border border-dashed border-[hsla(var(--app-border))] p-2">
<div
className="flex cursor-pointer items-center gap-3"
onClick={handleClick}
>
{loading && (
<Loader className="h-4 w-4 animate-spin text-[hsla(var(--primary-bg))]" />
)}
<button className="flex items-center gap-2 focus:outline-none">
{isExpanded ? (
<ChevronUp className="h-4 w-4" />
) : (
<ChevronDown className="h-4 w-4" />
)}
<span className="font-medium">
{' '}
View result from <span className="font-bold">{name}</span>
</span>
</button>
</div>
{isExpanded && (
<div className="mt-2 overflow-x-hidden pl-6 text-[hsla(var(--text-secondary))]">
<span>{result.trim()} </span>
</div>
)}
</div>
</div>
)
}
export default ToolCallBlock
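A hypothetical usage, mirroring how MessageContainer renders the block later in this diff (prop values illustrative):

```tsx
// Hypothetical: render a finished tool call's result, collapsed by default.
<ToolCallBlock
  id="tool-call-1"
  name="get_weather"
  result='{"content":[{"type":"text","text":"31°C and sunny"}]}'
  loading={false}
/>
```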

View File

@@ -18,6 +18,8 @@ import ImageMessage from './ImageMessage'
import { MarkdownTextMessage } from './MarkdownTextMessage'
import ThinkingBlock from './ThinkingBlock'
import ToolCallBlock from './ToolCallBlock'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import {
editMessageAtom,
@@ -65,57 +67,62 @@ const MessageContainer: React.FC<
[props.content]
)
const attachedFile = useMemo(() => 'attachments' in props, [props])
const attachedFile = useMemo(
() => 'attachments' in props && props.attachments?.length,
[props]
)
return (
<div
className={twMerge(
'group relative mx-auto px-4 py-2',
'group relative mx-auto px-4',
!(props.metadata && 'parent_id' in props.metadata) && 'py-2',
chatWidth === 'compact' && 'max-w-[700px]',
isUser && 'pb-4 pt-0'
)}
>
<div
className={twMerge(
'mb-2 flex items-center justify-start',
!isUser && 'mt-2 gap-x-2'
)}
>
{!isUser && !isSystem && <LogoMark width={28} />}
{!(props.metadata && 'parent_id' in props.metadata) && (
<div
className={twMerge(
'font-extrabold capitalize',
isUser && 'text-gray-500'
'mb-2 flex items-center justify-start',
!isUser && 'mt-2 gap-x-2'
)}
>
{!isUser && (
<>
{props.metadata && 'model' in props.metadata
? (props.metadata?.model as string)
: props.isCurrentMessage
? selectedModel?.name
: (activeAssistant?.assistant_name ?? props.role)}
</>
)}
{!isUser && !isSystem && <LogoMark width={28} />}
<div
className={twMerge(
'font-extrabold capitalize',
isUser && 'text-gray-500'
)}
>
{!isUser && (
<>
{props.metadata && 'model' in props.metadata
? (props.metadata?.model as string)
: props.isCurrentMessage
? selectedModel?.name
: (activeAssistant?.assistant_name ?? props.role)}
</>
)}
</div>
<p className="text-xs font-medium text-gray-400">
{props.created_at &&
displayDate(props.created_at ?? Date.now() / 1000)}
</p>
</div>
)}
<p className="text-xs font-medium text-gray-400">
{props.created_at &&
displayDate(props.created_at ?? Date.now() / 1000)}
</p>
</div>
<div className="flex w-full flex-col ">
<div className="flex w-full flex-col">
<div
className={twMerge(
'absolute right-0 order-1 flex cursor-pointer items-center justify-start gap-x-2 transition-all',
isUser
? twMerge(
'hidden group-hover:absolute group-hover:right-4 group-hover:top-4 group-hover:z-50 group-hover:flex',
image && 'group-hover:-top-2'
)
: 'relative left-0 order-2 flex w-full justify-between opacity-0 group-hover:opacity-100',
twMerge(
'hidden group-hover:absolute group-hover:-bottom-4 group-hover:right-4 group-hover:z-50 group-hover:flex',
image && 'group-hover:-top-2'
),
props.isCurrentMessage && 'opacity-100'
)}
>
@@ -179,6 +186,22 @@
/>
</div>
)}
{props.metadata &&
'tool_calls' in props.metadata &&
Array.isArray(props.metadata.tool_calls) &&
props.metadata.tool_calls.length > 0 && (
<>
{props.metadata.tool_calls.map((toolCall) => (
<ToolCallBlock
id={toolCall.tool?.id}
name={toolCall.tool?.function?.name ?? ''}
key={toolCall.tool?.id}
result={JSON.stringify(toolCall.response)}
loading={toolCall.state === 'pending'}
/>
))}
</>
)}
</>
</div>
</div>

View File

@@ -11,6 +11,7 @@ import {
Thread,
ThreadMessage,
} from '@janhq/core'
import { ChatCompletionMessage as OAIChatCompletionMessage } from 'openai/resources/chat'
import { ulid } from 'ulidx'
import { Stack } from '@/utils/Stack'
@@ -45,12 +46,26 @@ export class MessageRequestBuilder {
this.tools = tools
}
pushAssistantMessage(message: string) {
pushAssistantMessage(message: OAIChatCompletionMessage) {
const { content, refusal, ...rest } = message
const normalizedMessage = {
...rest,
...(content ? { content } : {}),
...(refusal ? { refusal } : {}),
}
this.messages = [
...this.messages,
normalizedMessage as ChatCompletionMessage,
]
}
pushToolMessage(message: string, toolCallId: string) {
this.messages = [
...this.messages,
{
role: ChatCompletionRole.Assistant,
role: ChatCompletionRole.Tool,
content: message,
tool_call_id: toolCallId,
},
]
}
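The two pushers pair up to reproduce the OpenAI tool-calling transcript shape: the assistant turn that requested the tool, then a `tool` turn answering that call id. A hypothetical sequence (values illustrative):

```typescript
// Hypothetical: record an assistant tool request, then its result.
requestBuilder.pushAssistantMessage({
  role: 'assistant',
  content: null, // dropped by the normalizer above when empty
  refusal: null,
  tool_calls: [
    {
      id: 'call_1',
      type: 'function',
      function: { name: 'get_weather', arguments: '{"city":"Hanoi"}' },
    },
  ],
})
requestBuilder.pushToolMessage('31°C and sunny', 'call_1')
```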
@@ -140,40 +155,13 @@
return this
}
normalizeMessages = (
messages: ChatCompletionMessage[]
): ChatCompletionMessage[] => {
const stack = new Stack<ChatCompletionMessage>()
for (const message of messages) {
if (stack.isEmpty()) {
stack.push(message)
continue
}
const topMessage = stack.peek()
if (message.role === topMessage.role) {
// add an empty message
stack.push({
role:
topMessage.role === ChatCompletionRole.User
? ChatCompletionRole.Assistant
: ChatCompletionRole.User,
content: '.', // some model requires not empty message
})
}
stack.push(message)
}
return stack.reverseOutput()
}
build(): MessageRequest {
return {
id: this.msgId,
type: this.type,
attachments: [],
threadId: this.thread.id,
messages: this.normalizeMessages(this.messages),
messages: this.messages,
model: this.model,
thread: this.thread,
tools: this.tools,