Merge pull request #4900 from menloresearch/feat/jan-ui-with-tool-use
feat: jan UI with Tool use UX
This commit is contained in:
parent
a8e418c4d3
commit
57786e5e45
@ -7,6 +7,7 @@ export enum ChatCompletionRole {
|
|||||||
System = 'system',
|
System = 'system',
|
||||||
Assistant = 'assistant',
|
Assistant = 'assistant',
|
||||||
User = 'user',
|
User = 'user',
|
||||||
|
Tool = 'tool',
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -18,6 +19,9 @@ export type ChatCompletionMessage = {
|
|||||||
content?: ChatCompletionMessageContent
|
content?: ChatCompletionMessageContent
|
||||||
/** The role of the author of this message. **/
|
/** The role of the author of this message. **/
|
||||||
role: ChatCompletionRole
|
role: ChatCompletionRole
|
||||||
|
type?: string
|
||||||
|
output?: string
|
||||||
|
tool_call_id?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ChatCompletionMessageContent =
|
export type ChatCompletionMessageContent =
|
||||||
|
|||||||
@ -36,6 +36,8 @@ export type ThreadMessage = {
|
|||||||
type?: string
|
type?: string
|
||||||
/** The error code which explain what error type. Used in conjunction with MessageStatus.Error */
|
/** The error code which explain what error type. Used in conjunction with MessageStatus.Error */
|
||||||
error_code?: ErrorCode
|
error_code?: ErrorCode
|
||||||
|
|
||||||
|
tool_call_id?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -114,7 +114,7 @@ export default function ModelHandler() {
|
|||||||
|
|
||||||
const onNewMessageResponse = useCallback(
|
const onNewMessageResponse = useCallback(
|
||||||
async (message: ThreadMessage) => {
|
async (message: ThreadMessage) => {
|
||||||
if (message.type === MessageRequestType.Thread) {
|
if (message.type !== MessageRequestType.Summary) {
|
||||||
addNewMessage(message)
|
addNewMessage(message)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -129,35 +129,20 @@ export default function ModelHandler() {
|
|||||||
const updateThreadTitle = useCallback(
|
const updateThreadTitle = useCallback(
|
||||||
(message: ThreadMessage) => {
|
(message: ThreadMessage) => {
|
||||||
// Update only when it's finished
|
// Update only when it's finished
|
||||||
if (message.status !== MessageStatus.Ready) {
|
if (message.status !== MessageStatus.Ready) return
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
|
const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
|
||||||
if (!thread) {
|
|
||||||
console.warn(
|
|
||||||
`Failed to update title for thread ${message.thread_id}: Thread not found!`
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
let messageContent = message.content[0]?.text?.value
|
let messageContent = message.content[0]?.text?.value
|
||||||
if (!messageContent) {
|
if (!thread || !messageContent) return
|
||||||
console.warn(
|
|
||||||
`Failed to update title for thread ${message.thread_id}: Responded content is null!`
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// No new line character is presented in the title
|
// No new line character is presented in the title
|
||||||
// And non-alphanumeric characters should be removed
|
// And non-alphanumeric characters should be removed
|
||||||
if (messageContent.includes('\n')) {
|
if (messageContent.includes('\n'))
|
||||||
messageContent = messageContent.replace(/\n/g, ' ')
|
messageContent = messageContent.replace(/\n/g, ' ')
|
||||||
}
|
|
||||||
const match = messageContent.match(/<\/think>(.*)$/)
|
const match = messageContent.match(/<\/think>(.*)$/)
|
||||||
if (match) {
|
if (match) messageContent = match[1]
|
||||||
messageContent = match[1]
|
|
||||||
}
|
|
||||||
// Remove non-alphanumeric characters
|
// Remove non-alphanumeric characters
|
||||||
const cleanedMessageContent = messageContent
|
const cleanedMessageContent = messageContent
|
||||||
.replace(/[^\p{L}\s]+/gu, '')
|
.replace(/[^\p{L}\s]+/gu, '')
|
||||||
@ -193,18 +178,13 @@ export default function ModelHandler() {
|
|||||||
|
|
||||||
const updateThreadMessage = useCallback(
|
const updateThreadMessage = useCallback(
|
||||||
(message: ThreadMessage) => {
|
(message: ThreadMessage) => {
|
||||||
if (
|
|
||||||
messageGenerationSubscriber.current &&
|
|
||||||
message.thread_id === activeThreadRef.current?.id &&
|
|
||||||
!messageGenerationSubscriber.current!.thread_id
|
|
||||||
) {
|
|
||||||
updateMessage(
|
updateMessage(
|
||||||
message.id,
|
message.id,
|
||||||
message.thread_id,
|
message.thread_id,
|
||||||
message.content,
|
message.content,
|
||||||
|
message.metadata,
|
||||||
message.status
|
message.status
|
||||||
)
|
)
|
||||||
}
|
|
||||||
|
|
||||||
if (message.status === MessageStatus.Pending) {
|
if (message.status === MessageStatus.Pending) {
|
||||||
if (message.content.length) {
|
if (message.content.length) {
|
||||||
@ -243,16 +223,19 @@ export default function ModelHandler() {
|
|||||||
engines &&
|
engines &&
|
||||||
isLocalEngine(engines, activeModelRef.current.engine)
|
isLocalEngine(engines, activeModelRef.current.engine)
|
||||||
) {
|
) {
|
||||||
;(async () => {
|
extensionManager
|
||||||
if (
|
|
||||||
!(await extensionManager
|
|
||||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||||
?.isModelLoaded(activeModelRef.current?.id as string))
|
?.isModelLoaded(activeModelRef.current?.id as string)
|
||||||
) {
|
.then((isLoaded) => {
|
||||||
|
if (!isLoaded) {
|
||||||
setActiveModel(undefined)
|
setActiveModel(undefined)
|
||||||
setStateModel({ state: 'start', loading: false, model: undefined })
|
setStateModel({
|
||||||
|
state: 'start',
|
||||||
|
loading: false,
|
||||||
|
model: undefined,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
})()
|
})
|
||||||
}
|
}
|
||||||
// Mark the thread as not waiting for response
|
// Mark the thread as not waiting for response
|
||||||
updateThreadWaiting(message.thread_id, false)
|
updateThreadWaiting(message.thread_id, false)
|
||||||
@ -296,19 +279,10 @@ export default function ModelHandler() {
|
|||||||
error_code: message.error_code,
|
error_code: message.error_code,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
;(async () => {
|
|
||||||
const updatedMessage = await extensionManager
|
extensionManager
|
||||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||||
?.createMessage(message)
|
?.createMessage(message)
|
||||||
.catch(() => undefined)
|
|
||||||
if (updatedMessage) {
|
|
||||||
deleteMessage(message.id)
|
|
||||||
addNewMessage(updatedMessage)
|
|
||||||
setTokenSpeed((prev) =>
|
|
||||||
prev ? { ...prev, message: updatedMessage.id } : undefined
|
|
||||||
)
|
|
||||||
}
|
|
||||||
})()
|
|
||||||
|
|
||||||
// Attempt to generate the title of the Thread when needed
|
// Attempt to generate the title of the Thread when needed
|
||||||
generateThreadTitle(message, thread)
|
generateThreadTitle(message, thread)
|
||||||
@ -319,25 +293,21 @@ export default function ModelHandler() {
|
|||||||
|
|
||||||
const onMessageResponseUpdate = useCallback(
|
const onMessageResponseUpdate = useCallback(
|
||||||
(message: ThreadMessage) => {
|
(message: ThreadMessage) => {
|
||||||
switch (message.type) {
|
if (message.type === MessageRequestType.Summary)
|
||||||
case MessageRequestType.Summary:
|
|
||||||
updateThreadTitle(message)
|
updateThreadTitle(message)
|
||||||
break
|
else updateThreadMessage(message)
|
||||||
default:
|
|
||||||
updateThreadMessage(message)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
[updateThreadMessage, updateThreadTitle]
|
[updateThreadMessage, updateThreadTitle]
|
||||||
)
|
)
|
||||||
|
|
||||||
const generateThreadTitle = (message: ThreadMessage, thread: Thread) => {
|
const generateThreadTitle = (message: ThreadMessage, thread: Thread) => {
|
||||||
// If this is the first ever prompt in the thread
|
// If this is the first ever prompt in the thread
|
||||||
if ((thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle)
|
if (
|
||||||
|
!activeModelRef.current ||
|
||||||
|
(thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if (!activeModelRef.current) return
|
|
||||||
|
|
||||||
// Check model engine; we don't want to generate a title when it's not a local engine. remote model using first promp
|
// Check model engine; we don't want to generate a title when it's not a local engine. remote model using first promp
|
||||||
if (
|
if (
|
||||||
activeModelRef.current?.engine !== InferenceEngine.cortex &&
|
activeModelRef.current?.engine !== InferenceEngine.cortex &&
|
||||||
|
|||||||
@ -165,6 +165,7 @@ export const updateMessageAtom = atom(
|
|||||||
id: string,
|
id: string,
|
||||||
conversationId: string,
|
conversationId: string,
|
||||||
text: ThreadContent[],
|
text: ThreadContent[],
|
||||||
|
metadata: Record<string, unknown> | undefined,
|
||||||
status: MessageStatus
|
status: MessageStatus
|
||||||
) => {
|
) => {
|
||||||
const messages = get(chatMessages)[conversationId] ?? []
|
const messages = get(chatMessages)[conversationId] ?? []
|
||||||
@ -172,6 +173,7 @@ export const updateMessageAtom = atom(
|
|||||||
if (message) {
|
if (message) {
|
||||||
message.content = text
|
message.content = text
|
||||||
message.status = status
|
message.status = status
|
||||||
|
message.metadata = metadata
|
||||||
const updatedMessages = [...messages]
|
const updatedMessages = [...messages]
|
||||||
|
|
||||||
const newData: Record<string, ThreadMessage[]> = {
|
const newData: Record<string, ThreadMessage[]> = {
|
||||||
@ -192,6 +194,7 @@ export const updateMessageAtom = atom(
|
|||||||
created_at: Date.now() / 1000,
|
created_at: Date.now() / 1000,
|
||||||
completed_at: Date.now() / 1000,
|
completed_at: Date.now() / 1000,
|
||||||
object: 'thread.message',
|
object: 'thread.message',
|
||||||
|
metadata: metadata,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -26,6 +26,7 @@ import {
|
|||||||
ChatCompletionTool,
|
ChatCompletionTool,
|
||||||
} from 'openai/resources/chat'
|
} from 'openai/resources/chat'
|
||||||
|
|
||||||
|
import { Stream } from 'openai/streaming'
|
||||||
import { ulid } from 'ulidx'
|
import { ulid } from 'ulidx'
|
||||||
|
|
||||||
import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
|
import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
|
||||||
@ -258,111 +259,63 @@ export default function useSendChatMessage() {
|
|||||||
baseURL: `${API_BASE_URL}/v1`,
|
baseURL: `${API_BASE_URL}/v1`,
|
||||||
dangerouslyAllowBrowser: true,
|
dangerouslyAllowBrowser: true,
|
||||||
})
|
})
|
||||||
|
let parentMessageId: string | undefined
|
||||||
while (!isDone) {
|
while (!isDone) {
|
||||||
const data = requestBuilder.build()
|
let messageId = ulid()
|
||||||
const response = await openai.chat.completions.create({
|
if (!parentMessageId) {
|
||||||
messages: (data.messages ?? []).map((e) => {
|
parentMessageId = ulid()
|
||||||
return {
|
messageId = parentMessageId
|
||||||
role: e.role as OpenAIChatCompletionRole,
|
|
||||||
content: e.content,
|
|
||||||
}
|
}
|
||||||
}) as ChatCompletionMessageParam[],
|
const data = requestBuilder.build()
|
||||||
|
const message: ThreadMessage = {
|
||||||
|
id: messageId,
|
||||||
|
object: 'message',
|
||||||
|
thread_id: activeThreadRef.current.id,
|
||||||
|
assistant_id: activeAssistantRef.current.assistant_id,
|
||||||
|
role: ChatCompletionRole.Assistant,
|
||||||
|
content: [],
|
||||||
|
metadata: {
|
||||||
|
...(messageId !== parentMessageId
|
||||||
|
? { parent_id: parentMessageId }
|
||||||
|
: {}),
|
||||||
|
},
|
||||||
|
status: MessageStatus.Pending,
|
||||||
|
created_at: Date.now() / 1000,
|
||||||
|
completed_at: Date.now() / 1000,
|
||||||
|
}
|
||||||
|
events.emit(MessageEvent.OnMessageResponse, message)
|
||||||
|
const response = await openai.chat.completions.create({
|
||||||
|
messages: requestBuilder.messages as ChatCompletionMessageParam[],
|
||||||
model: data.model?.id ?? '',
|
model: data.model?.id ?? '',
|
||||||
tools: data.tools as ChatCompletionTool[],
|
tools: data.tools as ChatCompletionTool[],
|
||||||
stream: false,
|
stream: data.model?.parameters?.stream ?? false,
|
||||||
|
tool_choice: 'auto',
|
||||||
})
|
})
|
||||||
if (response.choices[0]?.message.content) {
|
// Variables to track and accumulate streaming content
|
||||||
const newMessage: ThreadMessage = {
|
if (!message.content.length) {
|
||||||
id: ulid(),
|
message.content = [
|
||||||
object: 'message',
|
|
||||||
thread_id: activeThreadRef.current.id,
|
|
||||||
assistant_id: activeAssistantRef.current.assistant_id,
|
|
||||||
attachments: [],
|
|
||||||
role: response.choices[0].message.role as ChatCompletionRole,
|
|
||||||
content: [
|
|
||||||
{
|
{
|
||||||
type: ContentType.Text,
|
type: ContentType.Text,
|
||||||
text: {
|
text: {
|
||||||
value: response.choices[0].message.content
|
value: '',
|
||||||
? response.choices[0].message.content
|
|
||||||
: '',
|
|
||||||
annotations: [],
|
annotations: [],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
]
|
||||||
status: MessageStatus.Ready,
|
|
||||||
created_at: Date.now(),
|
|
||||||
completed_at: Date.now(),
|
|
||||||
}
|
}
|
||||||
requestBuilder.pushAssistantMessage(
|
if (data.model?.parameters?.stream)
|
||||||
response.choices[0].message.content ?? ''
|
isDone = await processStreamingResponse(
|
||||||
|
response as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
|
||||||
|
requestBuilder,
|
||||||
|
message
|
||||||
|
)
|
||||||
|
else {
|
||||||
|
isDone = await processNonStreamingResponse(
|
||||||
|
response as OpenAI.Chat.Completions.ChatCompletion,
|
||||||
|
requestBuilder,
|
||||||
|
message
|
||||||
)
|
)
|
||||||
events.emit(MessageEvent.OnMessageUpdate, newMessage)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (response.choices[0]?.message.tool_calls) {
|
|
||||||
for (const toolCall of response.choices[0].message.tool_calls) {
|
|
||||||
const id = ulid()
|
|
||||||
const toolMessage: ThreadMessage = {
|
|
||||||
id: id,
|
|
||||||
object: 'message',
|
|
||||||
thread_id: activeThreadRef.current.id,
|
|
||||||
assistant_id: activeAssistantRef.current.assistant_id,
|
|
||||||
attachments: [],
|
|
||||||
role: ChatCompletionRole.Assistant,
|
|
||||||
content: [
|
|
||||||
{
|
|
||||||
type: ContentType.Text,
|
|
||||||
text: {
|
|
||||||
value: `<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}`,
|
|
||||||
annotations: [],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
status: MessageStatus.Pending,
|
|
||||||
created_at: Date.now(),
|
|
||||||
completed_at: Date.now(),
|
|
||||||
}
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, toolMessage)
|
|
||||||
const result = await window.core.api.callTool({
|
|
||||||
toolName: toolCall.function.name,
|
|
||||||
arguments: JSON.parse(toolCall.function.arguments),
|
|
||||||
})
|
|
||||||
if (result.error) {
|
|
||||||
console.error(result.error)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
const message: ThreadMessage = {
|
|
||||||
id: id,
|
|
||||||
object: 'message',
|
|
||||||
thread_id: activeThreadRef.current.id,
|
|
||||||
assistant_id: activeAssistantRef.current.assistant_id,
|
|
||||||
attachments: [],
|
|
||||||
role: ChatCompletionRole.Assistant,
|
|
||||||
content: [
|
|
||||||
{
|
|
||||||
type: ContentType.Text,
|
|
||||||
text: {
|
|
||||||
value:
|
|
||||||
`<think>Executing Tool ${toolCall.function.name} with arguments ${toolCall.function.arguments}</think>` +
|
|
||||||
(result.content[0]?.text ?? ''),
|
|
||||||
annotations: [],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
status: MessageStatus.Ready,
|
|
||||||
created_at: Date.now(),
|
|
||||||
completed_at: Date.now(),
|
|
||||||
}
|
|
||||||
requestBuilder.pushAssistantMessage(result.content[0]?.text ?? '')
|
|
||||||
requestBuilder.pushMessage('Go for the next step')
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
isDone =
|
|
||||||
!response.choices[0]?.message.tool_calls ||
|
|
||||||
!response.choices[0]?.message.tool_calls.length
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Request for inference
|
// Request for inference
|
||||||
@ -376,6 +329,184 @@ export default function useSendChatMessage() {
|
|||||||
setEngineParamsUpdate(false)
|
setEngineParamsUpdate(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const processNonStreamingResponse = async (
|
||||||
|
response: OpenAI.Chat.Completions.ChatCompletion,
|
||||||
|
requestBuilder: MessageRequestBuilder,
|
||||||
|
message: ThreadMessage
|
||||||
|
): Promise<boolean> => {
|
||||||
|
// Handle tool calls in the response
|
||||||
|
const toolCalls: ChatCompletionMessageToolCall[] =
|
||||||
|
response.choices[0]?.message?.tool_calls ?? []
|
||||||
|
const content = response.choices[0].message?.content
|
||||||
|
message.content = [
|
||||||
|
{
|
||||||
|
type: ContentType.Text,
|
||||||
|
text: {
|
||||||
|
value: content ?? '',
|
||||||
|
annotations: [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
events.emit(MessageEvent.OnMessageUpdate, message)
|
||||||
|
await postMessageProcessing(
|
||||||
|
toolCalls ?? [],
|
||||||
|
requestBuilder,
|
||||||
|
message,
|
||||||
|
content ?? ''
|
||||||
|
)
|
||||||
|
return !toolCalls || !toolCalls.length
|
||||||
|
}
|
||||||
|
|
||||||
|
const processStreamingResponse = async (
|
||||||
|
response: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
|
||||||
|
requestBuilder: MessageRequestBuilder,
|
||||||
|
message: ThreadMessage
|
||||||
|
): Promise<boolean> => {
|
||||||
|
// Variables to track and accumulate streaming content
|
||||||
|
let currentToolCall: {
|
||||||
|
id: string
|
||||||
|
function: { name: string; arguments: string }
|
||||||
|
} | null = null
|
||||||
|
let accumulatedContent = ''
|
||||||
|
const toolCalls: ChatCompletionMessageToolCall[] = []
|
||||||
|
// Process the streaming chunks
|
||||||
|
for await (const chunk of response) {
|
||||||
|
// Handle tool calls in the chunk
|
||||||
|
if (chunk.choices[0]?.delta?.tool_calls) {
|
||||||
|
const deltaToolCalls = chunk.choices[0].delta.tool_calls
|
||||||
|
|
||||||
|
// Handle the beginning of a new tool call
|
||||||
|
if (
|
||||||
|
deltaToolCalls[0]?.index !== undefined &&
|
||||||
|
deltaToolCalls[0]?.function
|
||||||
|
) {
|
||||||
|
const index = deltaToolCalls[0].index
|
||||||
|
|
||||||
|
// Create new tool call if this is the first chunk for it
|
||||||
|
if (!toolCalls[index]) {
|
||||||
|
toolCalls[index] = {
|
||||||
|
id: deltaToolCalls[0]?.id || '',
|
||||||
|
function: {
|
||||||
|
name: deltaToolCalls[0]?.function?.name || '',
|
||||||
|
arguments: deltaToolCalls[0]?.function?.arguments || '',
|
||||||
|
},
|
||||||
|
type: 'function',
|
||||||
|
}
|
||||||
|
currentToolCall = toolCalls[index]
|
||||||
|
} else {
|
||||||
|
// Continuation of existing tool call
|
||||||
|
currentToolCall = toolCalls[index]
|
||||||
|
|
||||||
|
// Append to function name or arguments if they exist in this chunk
|
||||||
|
if (deltaToolCalls[0]?.function?.name) {
|
||||||
|
currentToolCall!.function.name += deltaToolCalls[0].function.name
|
||||||
|
}
|
||||||
|
|
||||||
|
if (deltaToolCalls[0]?.function?.arguments) {
|
||||||
|
currentToolCall!.function.arguments +=
|
||||||
|
deltaToolCalls[0].function.arguments
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle regular content in the chunk
|
||||||
|
if (chunk.choices[0]?.delta?.content) {
|
||||||
|
const content = chunk.choices[0].delta.content
|
||||||
|
accumulatedContent += content
|
||||||
|
|
||||||
|
message.content = [
|
||||||
|
{
|
||||||
|
type: ContentType.Text,
|
||||||
|
text: {
|
||||||
|
value: accumulatedContent,
|
||||||
|
annotations: [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
events.emit(MessageEvent.OnMessageUpdate, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await postMessageProcessing(
|
||||||
|
toolCalls ?? [],
|
||||||
|
requestBuilder,
|
||||||
|
message,
|
||||||
|
accumulatedContent ?? ''
|
||||||
|
)
|
||||||
|
return !toolCalls || !toolCalls.length
|
||||||
|
}
|
||||||
|
|
||||||
|
const postMessageProcessing = async (
|
||||||
|
toolCalls: ChatCompletionMessageToolCall[],
|
||||||
|
requestBuilder: MessageRequestBuilder,
|
||||||
|
message: ThreadMessage,
|
||||||
|
content: string
|
||||||
|
) => {
|
||||||
|
requestBuilder.pushAssistantMessage({
|
||||||
|
content,
|
||||||
|
role: 'assistant',
|
||||||
|
refusal: null,
|
||||||
|
tool_calls: toolCalls,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle completed tool calls
|
||||||
|
if (toolCalls.length > 0) {
|
||||||
|
for (const toolCall of toolCalls) {
|
||||||
|
const toolId = ulid()
|
||||||
|
const toolCallsMetadata =
|
||||||
|
message.metadata?.tool_calls &&
|
||||||
|
Array.isArray(message.metadata?.tool_calls)
|
||||||
|
? message.metadata?.tool_calls
|
||||||
|
: []
|
||||||
|
message.metadata = {
|
||||||
|
...(message.metadata ?? {}),
|
||||||
|
tool_calls: [
|
||||||
|
...toolCallsMetadata,
|
||||||
|
{
|
||||||
|
tool: {
|
||||||
|
...toolCall,
|
||||||
|
id: toolId,
|
||||||
|
},
|
||||||
|
response: undefined,
|
||||||
|
state: 'pending',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
events.emit(MessageEvent.OnMessageUpdate, message)
|
||||||
|
|
||||||
|
const result = await window.core.api.callTool({
|
||||||
|
toolName: toolCall.function.name,
|
||||||
|
arguments: JSON.parse(toolCall.function.arguments),
|
||||||
|
})
|
||||||
|
if (result.error) break
|
||||||
|
|
||||||
|
message.metadata = {
|
||||||
|
...(message.metadata ?? {}),
|
||||||
|
tool_calls: [
|
||||||
|
...toolCallsMetadata,
|
||||||
|
{
|
||||||
|
tool: {
|
||||||
|
...toolCall,
|
||||||
|
id: toolId,
|
||||||
|
},
|
||||||
|
response: result,
|
||||||
|
state: 'ready',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
requestBuilder.pushToolMessage(
|
||||||
|
result.content[0]?.text ?? '',
|
||||||
|
toolCall.id
|
||||||
|
)
|
||||||
|
events.emit(MessageEvent.OnMessageUpdate, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
message.status = MessageStatus.Ready
|
||||||
|
events.emit(MessageEvent.OnMessageUpdate, message)
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
sendChatMessage,
|
sendChatMessage,
|
||||||
resendChatMessage,
|
resendChatMessage,
|
||||||
|
|||||||
@ -1,9 +1,8 @@
|
|||||||
import { useRef, useState } from 'react'
|
import { useState } from 'react'
|
||||||
|
|
||||||
import { Slider, Input, Tooltip } from '@janhq/joi'
|
import { Slider, Input } from '@janhq/joi'
|
||||||
|
|
||||||
import { atom, useAtom } from 'jotai'
|
import { atom, useAtom } from 'jotai'
|
||||||
import { InfoIcon } from 'lucide-react'
|
|
||||||
|
|
||||||
export const hubModelSizeMinAtom = atom(0)
|
export const hubModelSizeMinAtom = atom(0)
|
||||||
export const hubModelSizeMaxAtom = atom(100)
|
export const hubModelSizeMaxAtom = atom(100)
|
||||||
|
|||||||
@ -3,7 +3,15 @@ import { useEffect, useRef, useState } from 'react'
|
|||||||
|
|
||||||
import { InferenceEngine } from '@janhq/core'
|
import { InferenceEngine } from '@janhq/core'
|
||||||
|
|
||||||
import { TextArea, Button, Tooltip, useClickOutside, Badge } from '@janhq/joi'
|
import {
|
||||||
|
TextArea,
|
||||||
|
Button,
|
||||||
|
Tooltip,
|
||||||
|
useClickOutside,
|
||||||
|
Badge,
|
||||||
|
Modal,
|
||||||
|
ModalClose,
|
||||||
|
} from '@janhq/joi'
|
||||||
import { useAtom, useAtomValue } from 'jotai'
|
import { useAtom, useAtomValue } from 'jotai'
|
||||||
import {
|
import {
|
||||||
FileTextIcon,
|
FileTextIcon,
|
||||||
@ -13,6 +21,7 @@ import {
|
|||||||
SettingsIcon,
|
SettingsIcon,
|
||||||
ChevronUpIcon,
|
ChevronUpIcon,
|
||||||
Settings2Icon,
|
Settings2Icon,
|
||||||
|
WrenchIcon,
|
||||||
} from 'lucide-react'
|
} from 'lucide-react'
|
||||||
|
|
||||||
import { twMerge } from 'tailwind-merge'
|
import { twMerge } from 'tailwind-merge'
|
||||||
@ -45,6 +54,7 @@ import {
|
|||||||
isBlockingSendAtom,
|
isBlockingSendAtom,
|
||||||
} from '@/helpers/atoms/Thread.atom'
|
} from '@/helpers/atoms/Thread.atom'
|
||||||
import { activeTabThreadRightPanelAtom } from '@/helpers/atoms/ThreadRightPanel.atom'
|
import { activeTabThreadRightPanelAtom } from '@/helpers/atoms/ThreadRightPanel.atom'
|
||||||
|
import { ModelTool } from '@/types/model'
|
||||||
|
|
||||||
const ChatInput = () => {
|
const ChatInput = () => {
|
||||||
const activeThread = useAtomValue(activeThreadAtom)
|
const activeThread = useAtomValue(activeThreadAtom)
|
||||||
@ -69,6 +79,8 @@ const ChatInput = () => {
|
|||||||
const isBlockingSend = useAtomValue(isBlockingSendAtom)
|
const isBlockingSend = useAtomValue(isBlockingSendAtom)
|
||||||
const activeAssistant = useAtomValue(activeAssistantAtom)
|
const activeAssistant = useAtomValue(activeAssistantAtom)
|
||||||
const { stopInference } = useActiveModel()
|
const { stopInference } = useActiveModel()
|
||||||
|
const [tools, setTools] = useState<any>([])
|
||||||
|
const [showToolsModal, setShowToolsModal] = useState(false)
|
||||||
|
|
||||||
const upload = uploader()
|
const upload = uploader()
|
||||||
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
|
const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom(
|
||||||
@ -92,6 +104,12 @@ const ChatInput = () => {
|
|||||||
}
|
}
|
||||||
}, [activeSettingInputBox, selectedModel, setActiveSettingInputBox])
|
}, [activeSettingInputBox, selectedModel, setActiveSettingInputBox])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
window.core?.api?.getTools().then((data: ModelTool[]) => {
|
||||||
|
setTools(data)
|
||||||
|
})
|
||||||
|
}, [])
|
||||||
|
|
||||||
const onStopInferenceClick = async () => {
|
const onStopInferenceClick = async () => {
|
||||||
stopInference()
|
stopInference()
|
||||||
}
|
}
|
||||||
@ -136,6 +154,8 @@ const ChatInput = () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
console.log(tools)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="relative p-4 pb-2">
|
<div className="relative p-4 pb-2">
|
||||||
{renderPreview(fileUpload)}
|
{renderPreview(fileUpload)}
|
||||||
@ -385,6 +405,62 @@ const ChatInput = () => {
|
|||||||
className="flex-shrink-0 cursor-pointer text-[hsla(var(--text-secondary))]"
|
className="flex-shrink-0 cursor-pointer text-[hsla(var(--text-secondary))]"
|
||||||
/>
|
/>
|
||||||
</Badge>
|
</Badge>
|
||||||
|
{tools && tools.length > 0 && (
|
||||||
|
<>
|
||||||
|
<Badge
|
||||||
|
theme="secondary"
|
||||||
|
className={twMerge(
|
||||||
|
'flex cursor-pointer items-center gap-x-1'
|
||||||
|
)}
|
||||||
|
variant={'outline'}
|
||||||
|
onClick={() => setShowToolsModal(true)}
|
||||||
|
>
|
||||||
|
<WrenchIcon
|
||||||
|
size={16}
|
||||||
|
className="flex-shrink-0 cursor-pointer text-[hsla(var(--text-secondary))]"
|
||||||
|
/>
|
||||||
|
<span className="text-xs">{tools.length}</span>
|
||||||
|
</Badge>
|
||||||
|
|
||||||
|
<Modal
|
||||||
|
open={showToolsModal}
|
||||||
|
onOpenChange={setShowToolsModal}
|
||||||
|
title="Available MCP Tools"
|
||||||
|
content={
|
||||||
|
<div className="overflow-y-auto">
|
||||||
|
<div className="mb-2 py-2 text-sm text-[hsla(var(--text-secondary))]">
|
||||||
|
Jan can use tools provided by specialized servers using
|
||||||
|
Model Context Protocol.{' '}
|
||||||
|
<a
|
||||||
|
href="https://modelcontextprotocol.io/introduction"
|
||||||
|
target="_blank"
|
||||||
|
className="text-[hsla(var(--app-link))]"
|
||||||
|
>
|
||||||
|
Learn more about MCP
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
{tools.map((tool: any) => (
|
||||||
|
<div
|
||||||
|
key={tool.name}
|
||||||
|
className="flex items-center gap-x-3 px-4 py-3 hover:bg-[hsla(var(--dropdown-menu-hover-bg))]"
|
||||||
|
>
|
||||||
|
<WrenchIcon
|
||||||
|
size={16}
|
||||||
|
className="flex-shrink-0 text-[hsla(var(--text-secondary))]"
|
||||||
|
/>
|
||||||
|
<div>
|
||||||
|
<div className="font-medium">{tool.name}</div>
|
||||||
|
<div className="text-sm text-[hsla(var(--text-secondary))]">
|
||||||
|
{tool.description}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
{selectedModel && (
|
{selectedModel && (
|
||||||
<Button
|
<Button
|
||||||
|
|||||||
@ -0,0 +1,57 @@
|
|||||||
|
import React from 'react'
|
||||||
|
|
||||||
|
import { atom, useAtom } from 'jotai'
|
||||||
|
import { ChevronDown, ChevronUp, Loader } from 'lucide-react'
|
||||||
|
|
||||||
|
import { MarkdownTextMessage } from './MarkdownTextMessage'
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
result: string
|
||||||
|
name: string
|
||||||
|
id: number
|
||||||
|
loading: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
const toolCallBlockStateAtom = atom<{ [id: number]: boolean }>({})
|
||||||
|
|
||||||
|
const ToolCallBlock = ({ id, name, result, loading }: Props) => {
|
||||||
|
const [collapseState, setCollapseState] = useAtom(toolCallBlockStateAtom)
|
||||||
|
|
||||||
|
const isExpanded = collapseState[id] ?? false
|
||||||
|
const handleClick = () => {
|
||||||
|
setCollapseState((prev) => ({ ...prev, [id]: !isExpanded }))
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<div className="mx-auto w-full">
|
||||||
|
<div className="mb-4 rounded-lg border border-dashed border-[hsla(var(--app-border))] p-2">
|
||||||
|
<div
|
||||||
|
className="flex cursor-pointer items-center gap-3"
|
||||||
|
onClick={handleClick}
|
||||||
|
>
|
||||||
|
{loading && (
|
||||||
|
<Loader className="h-4 w-4 animate-spin text-[hsla(var(--primary-bg))]" />
|
||||||
|
)}
|
||||||
|
<button className="flex items-center gap-2 focus:outline-none">
|
||||||
|
{isExpanded ? (
|
||||||
|
<ChevronUp className="h-4 w-4" />
|
||||||
|
) : (
|
||||||
|
<ChevronDown className="h-4 w-4" />
|
||||||
|
)}
|
||||||
|
<span className="font-medium">
|
||||||
|
{' '}
|
||||||
|
View result from <span className="font-bold">{name}</span>
|
||||||
|
</span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{isExpanded && (
|
||||||
|
<div className="mt-2 overflow-x-hidden pl-6 text-[hsla(var(--text-secondary))]">
|
||||||
|
<span>{result.trim()} </span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default ToolCallBlock
|
||||||
@ -18,6 +18,8 @@ import ImageMessage from './ImageMessage'
|
|||||||
import { MarkdownTextMessage } from './MarkdownTextMessage'
|
import { MarkdownTextMessage } from './MarkdownTextMessage'
|
||||||
import ThinkingBlock from './ThinkingBlock'
|
import ThinkingBlock from './ThinkingBlock'
|
||||||
|
|
||||||
|
import ToolCallBlock from './ToolCallBlock'
|
||||||
|
|
||||||
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
|
||||||
import {
|
import {
|
||||||
editMessageAtom,
|
editMessageAtom,
|
||||||
@ -65,16 +67,21 @@ const MessageContainer: React.FC<
|
|||||||
[props.content]
|
[props.content]
|
||||||
)
|
)
|
||||||
|
|
||||||
const attachedFile = useMemo(() => 'attachments' in props, [props])
|
const attachedFile = useMemo(
|
||||||
|
() => 'attachments' in props && props.attachments?.length,
|
||||||
|
[props]
|
||||||
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={twMerge(
|
className={twMerge(
|
||||||
'group relative mx-auto px-4 py-2',
|
'group relative mx-auto px-4',
|
||||||
|
!(props.metadata && 'parent_id' in props.metadata) && 'py-2',
|
||||||
chatWidth === 'compact' && 'max-w-[700px]',
|
chatWidth === 'compact' && 'max-w-[700px]',
|
||||||
isUser && 'pb-4 pt-0'
|
isUser && 'pb-4 pt-0'
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
|
{!(props.metadata && 'parent_id' in props.metadata) && (
|
||||||
<div
|
<div
|
||||||
className={twMerge(
|
className={twMerge(
|
||||||
'mb-2 flex items-center justify-start',
|
'mb-2 flex items-center justify-start',
|
||||||
@ -105,17 +112,17 @@ const MessageContainer: React.FC<
|
|||||||
displayDate(props.created_at ?? Date.now() / 1000)}
|
displayDate(props.created_at ?? Date.now() / 1000)}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
<div className="flex w-full flex-col ">
|
<div className="flex w-full flex-col">
|
||||||
<div
|
<div
|
||||||
className={twMerge(
|
className={twMerge(
|
||||||
'absolute right-0 order-1 flex cursor-pointer items-center justify-start gap-x-2 transition-all',
|
'absolute right-0 order-1 flex cursor-pointer items-center justify-start gap-x-2 transition-all',
|
||||||
isUser
|
twMerge(
|
||||||
? twMerge(
|
'hidden group-hover:absolute group-hover:-bottom-4 group-hover:right-4 group-hover:z-50 group-hover:flex',
|
||||||
'hidden group-hover:absolute group-hover:right-4 group-hover:top-4 group-hover:z-50 group-hover:flex',
|
|
||||||
image && 'group-hover:-top-2'
|
image && 'group-hover:-top-2'
|
||||||
)
|
),
|
||||||
: 'relative left-0 order-2 flex w-full justify-between opacity-0 group-hover:opacity-100',
|
|
||||||
props.isCurrentMessage && 'opacity-100'
|
props.isCurrentMessage && 'opacity-100'
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
@ -179,6 +186,22 @@ const MessageContainer: React.FC<
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
{props.metadata &&
|
||||||
|
'tool_calls' in props.metadata &&
|
||||||
|
Array.isArray(props.metadata.tool_calls) &&
|
||||||
|
props.metadata.tool_calls.length && (
|
||||||
|
<>
|
||||||
|
{props.metadata.tool_calls.map((toolCall) => (
|
||||||
|
<ToolCallBlock
|
||||||
|
id={toolCall.tool?.id}
|
||||||
|
name={toolCall.tool?.function?.name ?? ''}
|
||||||
|
key={toolCall.tool?.id}
|
||||||
|
result={JSON.stringify(toolCall.response)}
|
||||||
|
loading={toolCall.state === 'pending'}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</>
|
</>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import {
|
|||||||
Thread,
|
Thread,
|
||||||
ThreadMessage,
|
ThreadMessage,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
|
import { ChatCompletionMessage as OAIChatCompletionMessage } from 'openai/resources/chat'
|
||||||
import { ulid } from 'ulidx'
|
import { ulid } from 'ulidx'
|
||||||
|
|
||||||
import { Stack } from '@/utils/Stack'
|
import { Stack } from '@/utils/Stack'
|
||||||
@ -45,12 +46,26 @@ export class MessageRequestBuilder {
|
|||||||
this.tools = tools
|
this.tools = tools
|
||||||
}
|
}
|
||||||
|
|
||||||
pushAssistantMessage(message: string) {
|
pushAssistantMessage(message: OAIChatCompletionMessage) {
|
||||||
|
const { content, refusal, ...rest } = message
|
||||||
|
const normalizedMessage = {
|
||||||
|
...rest,
|
||||||
|
...(content ? { content } : {}),
|
||||||
|
...(refusal ? { refusal } : {}),
|
||||||
|
}
|
||||||
|
this.messages = [
|
||||||
|
...this.messages,
|
||||||
|
normalizedMessage as ChatCompletionMessage,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
pushToolMessage(message: string, toolCallId: string) {
|
||||||
this.messages = [
|
this.messages = [
|
||||||
...this.messages,
|
...this.messages,
|
||||||
{
|
{
|
||||||
role: ChatCompletionRole.Assistant,
|
role: ChatCompletionRole.Tool,
|
||||||
content: message,
|
content: message,
|
||||||
|
tool_call_id: toolCallId,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -140,40 +155,13 @@ export class MessageRequestBuilder {
|
|||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
normalizeMessages = (
|
|
||||||
messages: ChatCompletionMessage[]
|
|
||||||
): ChatCompletionMessage[] => {
|
|
||||||
const stack = new Stack<ChatCompletionMessage>()
|
|
||||||
for (const message of messages) {
|
|
||||||
if (stack.isEmpty()) {
|
|
||||||
stack.push(message)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
const topMessage = stack.peek()
|
|
||||||
|
|
||||||
if (message.role === topMessage.role) {
|
|
||||||
// add an empty message
|
|
||||||
stack.push({
|
|
||||||
role:
|
|
||||||
topMessage.role === ChatCompletionRole.User
|
|
||||||
? ChatCompletionRole.Assistant
|
|
||||||
: ChatCompletionRole.User,
|
|
||||||
content: '.', // some model requires not empty message
|
|
||||||
})
|
|
||||||
}
|
|
||||||
stack.push(message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return stack.reverseOutput()
|
|
||||||
}
|
|
||||||
|
|
||||||
build(): MessageRequest {
|
build(): MessageRequest {
|
||||||
return {
|
return {
|
||||||
id: this.msgId,
|
id: this.msgId,
|
||||||
type: this.type,
|
type: this.type,
|
||||||
attachments: [],
|
attachments: [],
|
||||||
threadId: this.thread.id,
|
threadId: this.thread.id,
|
||||||
messages: this.normalizeMessages(this.messages),
|
messages: this.messages,
|
||||||
model: this.model,
|
model: this.model,
|
||||||
thread: this.thread,
|
thread: this.thread,
|
||||||
tools: this.tools,
|
tools: this.tools,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user