diff --git a/core/src/types/thread/threadEntity.ts b/core/src/types/thread/threadEntity.ts index dd88b10ec..ab61787e6 100644 --- a/core/src/types/thread/threadEntity.ts +++ b/core/src/types/thread/threadEntity.ts @@ -27,8 +27,8 @@ export type Thread = { * @stored */ export type ThreadAssistantInfo = { - assistant_id: string - assistant_name: string + id: string + name: string model: ModelInfo instructions?: string tools?: AssistantTool[] diff --git a/src-tauri/src/core/threads.rs b/src-tauri/src/core/threads.rs index 051837992..3554e287c 100644 --- a/src-tauri/src/core/threads.rs +++ b/src-tauri/src/core/threads.rs @@ -97,8 +97,8 @@ pub struct ImageContentValue { #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ThreadAssistantInfo { - pub assistant_id: String, - pub assistant_name: String, + pub id: String, + pub name: String, pub model: ModelInfo, pub instructions: Option, pub tools: Option>, @@ -456,16 +456,16 @@ pub async fn modify_thread_assistant( serde_json::from_str(&data).map_err(|e| e.to_string())? }; let assistant_id = assistant - .get("assistant_id") + .get("id") .and_then(|v| v.as_str()) - .ok_or("Missing assistant_id")?; + .ok_or("Missing id")?; if let Some(assistants) = thread .get_mut("assistants") .and_then(|a: &mut serde_json::Value| a.as_array_mut()) { if let Some(index) = assistants .iter() - .position(|a| a.get("assistant_id").and_then(|v| v.as_str()) == Some(assistant_id)) + .position(|a| a.get("id").and_then(|v| v.as_str()) == Some(assistant_id)) { assistants[index] = assistant.clone(); let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?; diff --git a/web-app/src/containers/DropdownAssistant.tsx b/web-app/src/containers/DropdownAssistant.tsx index 5c6b12c9a..51cdc4c7b 100644 --- a/web-app/src/containers/DropdownAssistant.tsx +++ b/web-app/src/containers/DropdownAssistant.tsx @@ -9,20 +9,25 @@ import { import { useAssistant } from '@/hooks/useAssistant' import AddEditAssistant from './dialogs/AddEditAssistant' import { IconCirclePlus, IconSettings } from '@tabler/icons-react' +import { useThreads } from '@/hooks/useThreads' const DropdownAssistant = () => { - const { assistants, addAssistant, updateAssistant } = useAssistant() + const { + assistants, + currentAssistant, + addAssistant, + updateAssistant, + setCurrentAssistant, + } = useAssistant() + const { updateCurrentThreadAssistant } = useThreads() const [dropdownOpen, setDropdownOpen] = useState(false) const [dialogOpen, setDialogOpen] = useState(false) const [editingAssistantId, setEditingAssistantId] = useState( null ) - const [selectedAssistantId, setSelectedAssistantId] = useState( - assistants[0]?.id || null - ) const selectedAssistant = - assistants.find((a) => a.id === selectedAssistantId) || assistants[0] + assistants.find((a) => a.id === currentAssistant.id) || assistants[0] return ( <> @@ -63,7 +68,10 @@ const DropdownAssistant = () => { setSelectedAssistantId(assistant.id)} + onClick={() => { + setCurrentAssistant(assistant) + updateCurrentThreadAssistant(assistant) + }} > {assistant.name} diff --git a/web-app/src/hooks/useAssistant.ts b/web-app/src/hooks/useAssistant.ts index 1e2e0835e..e22083029 100644 --- a/web-app/src/hooks/useAssistant.ts +++ b/web-app/src/hooks/useAssistant.ts @@ -2,24 +2,17 @@ import { localStoregeKey } from '@/constants/localStorage' import { create } from 'zustand' import { persist } from 'zustand/middleware' -export type Assistant = { - avatar?: string - id: string - name: string - created_at: number - description?: string - instructions: string - parameters: Record -} interface AssistantState { assistants: Assistant[] + currentAssistant: Assistant addAssistant: (assistant: Assistant) => void updateAssistant: (assistant: Assistant) => void deleteAssistant: (id: string) => void + setCurrentAssistant: (assistant: Assistant) => void } -const defaultAssistant: Assistant = { +export const defaultAssistant: Assistant = { avatar: '', id: 'jan', name: 'Jan', @@ -33,6 +26,7 @@ export const useAssistant = create()( persist( (set, get) => ({ assistants: [defaultAssistant], + currentAssistant: defaultAssistant, addAssistant: (assistant) => set({ assistants: [...get().assistants, assistant] }), updateAssistant: (assistant) => @@ -43,6 +37,9 @@ export const useAssistant = create()( }), deleteAssistant: (id) => set({ assistants: get().assistants.filter((a) => a.id !== id) }), + setCurrentAssistant: (assistant) => { + set({ currentAssistant: assistant }) + }, }), { name: localStoregeKey.assistant, diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index 12effe82d..17220f88d 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -18,10 +18,12 @@ import { } from '@/lib/completion' import { CompletionMessagesBuilder } from '@/lib/messages' import { ChatCompletionMessageToolCall } from 'openai/resources' +import { useAssistant } from './useAssistant' export const useChat = () => { const { prompt, setPrompt } = usePrompt() const { tools } = useAppState() + const { currentAssistant } = useAssistant() const { getProviderByName, selectedModel, selectedProvider } = useModelProvider() @@ -43,7 +45,8 @@ export const useChat = () => { id: selectedModel?.id ?? defaultModel(selectedProvider), provider: selectedProvider, }, - prompt + prompt, + currentAssistant ) router.navigate({ to: route.threadsDetail, @@ -58,6 +61,7 @@ export const useChat = () => { router, selectedModel?.id, selectedProvider, + currentAssistant, ]) const sendMessage = useCallback( @@ -79,6 +83,8 @@ export const useChat = () => { } const builder = new CompletionMessagesBuilder() + if (currentAssistant?.instructions?.length > 0) + builder.addSystemMessage(currentAssistant?.instructions || '') // REMARK: Would it possible to not attach the entire message history to the request? // TODO: If not amend messages history here builder.addUserMessage(message) @@ -143,9 +149,10 @@ export const useChat = () => { addMessage, setPrompt, selectedModel, - tools, + currentAssistant?.instructions, setAbortController, updateLoadingModel, + tools, ] ) diff --git a/web-app/src/hooks/useThreads.ts b/web-app/src/hooks/useThreads.ts index f91dedf95..f83a9d69d 100644 --- a/web-app/src/hooks/useThreads.ts +++ b/web-app/src/hooks/useThreads.ts @@ -17,9 +17,14 @@ type ThreadState = { deleteAllThreads: () => void unstarAllThreads: () => void setCurrentThreadId: (threadId?: string) => void - createThread: (model: ThreadModel, title?: string) => Promise + createThread: ( + model: ThreadModel, + title?: string, + assistant?: Assistant + ) => Promise updateCurrentThreadModel: (model: ThreadModel) => void getFilteredThreads: (searchTerm: string) => Thread[] + updateCurrentThreadAssistant: (assistant: Assistant) => void searchIndex: Fuse | null } @@ -152,18 +157,18 @@ export const useThreads = create()( setCurrentThreadId: (threadId) => { set({ currentThreadId: threadId }) }, - createThread: async (model, title) => { + createThread: async (model, title, assistant) => { const newThread: Thread = { id: ulid(), title: title ?? 'New Thread', model, order: 1, updated: Date.now() / 1000, + assistants: assistant ? [assistant] : [], } set((state) => ({ searchIndex: new Fuse(Object.values(state.threads), fuseOptions), })) - console.log('newThread', newThread) return await createThread(newThread).then((createdThread) => { set((state) => ({ threads: { @@ -175,6 +180,26 @@ export const useThreads = create()( return createdThread }) }, + updateCurrentThreadAssistant: (assistant) => { + set((state) => { + if (!state.currentThreadId) return { ...state } + const currentThread = state.getCurrentThread() + if (currentThread) + updateThread({ + ...currentThread, + assistants: [{ ...assistant, model: currentThread.model }], + }) + return { + threads: { + ...state.threads, + [state.currentThreadId as string]: { + ...state.threads[state.currentThreadId as string], + assistants: [assistant], + }, + }, + } + }) + }, updateCurrentThreadModel: (model) => { set((state) => { if (!state.currentThreadId) return { ...state } diff --git a/web-app/src/lib/messages.ts b/web-app/src/lib/messages.ts index 1175a6549..2f46b8eb0 100644 --- a/web-app/src/lib/messages.ts +++ b/web-app/src/lib/messages.ts @@ -1,11 +1,30 @@ import { ChatCompletionMessageParam } from 'token.js' import { ChatCompletionMessageToolCall } from 'openai/resources' +/** + * @fileoverview Helper functions for creating chat completion request. + * These functions are used to create chat completion request objects + */ export class CompletionMessagesBuilder { private messages: ChatCompletionMessageParam[] = [] constructor() {} + /** + * Add a system message to the messages array. + * @param content - The content of the system message. + */ + addSystemMessage(content: string) { + this.messages.push({ + role: 'system', + content: content, + }) + } + + /** + * Add a user message to the messages array. + * @param content - The content of the user message. + */ addUserMessage(content: string) { this.messages.push({ role: 'user', @@ -13,15 +32,30 @@ export class CompletionMessagesBuilder { }) } - addAssistantMessage(content: string, refusal?: string, calls?: ChatCompletionMessageToolCall[]) { + /** + * Add an assistant message to the messages array. + * @param content - The content of the assistant message. + * @param refusal - Optional refusal message. + * @param calls - Optional tool calls associated with the message. + */ + addAssistantMessage( + content: string, + refusal?: string, + calls?: ChatCompletionMessageToolCall[] + ) { this.messages.push({ role: 'assistant', content: content, refusal: refusal, - tool_calls: calls + tool_calls: calls, }) } + /** + * Add a tool message to the messages array. + * @param content - The content of the tool message. + * @param toolCallId - The ID of the tool call associated with the message. + */ addToolMessage(content: string, toolCallId: string) { this.messages.push({ role: 'tool', @@ -30,6 +64,10 @@ export class CompletionMessagesBuilder { }) } + /** + * Return the messages array. + * @returns The array of chat completion messages. + */ getMessages(): ChatCompletionMessageParam[] { return this.messages } diff --git a/web-app/src/routes/index.tsx b/web-app/src/routes/index.tsx index 667964640..8297c1022 100644 --- a/web-app/src/routes/index.tsx +++ b/web-app/src/routes/index.tsx @@ -15,6 +15,8 @@ type SearchParams = { } } import DropdownAssistant from '@/containers/DropdownAssistant' +import { useEffect } from 'react' +import { useThreads } from '@/hooks/useThreads' export const Route = createFileRoute(route.home as any)({ component: Index, @@ -28,6 +30,7 @@ function Index() { const { providers } = useModelProvider() const search = useSearch({ from: route.home as any }) const selectedModel = search.model + const { setCurrentThreadId } = useThreads() // Conditional to check if there are any valid providers // required min 1 api_key or 1 model in llama.cpp @@ -37,6 +40,10 @@ function Index() { (provider.provider === 'llama.cpp' && provider.models.length) ) + useEffect(() => { + setCurrentThreadId(undefined) + }, [setCurrentThreadId]) + if (!hasValidProviders) { return } diff --git a/web-app/src/routes/threads/$threadId.tsx b/web-app/src/routes/threads/$threadId.tsx index d99350ca6..b02e3183e 100644 --- a/web-app/src/routes/threads/$threadId.tsx +++ b/web-app/src/routes/threads/$threadId.tsx @@ -16,6 +16,7 @@ import { useMessages } from '@/hooks/useMessages' import { fetchMessages } from '@/services/messages' import { useAppState } from '@/hooks/useAppState' import DropdownAssistant from '@/containers/DropdownAssistant' +import { useAssistant } from '@/hooks/useAssistant' // as route.threadsDetail export const Route = createFileRoute('/threads/$threadId')({ @@ -28,6 +29,7 @@ function ThreadDetail() { const [isAtBottom, setIsAtBottom] = useState(true) const lastScrollTopRef = useRef(0) const { currentThreadId, getThreadById, setCurrentThreadId } = useThreads() + const { setCurrentAssistant, assistants } = useAssistant() const { setMessages } = useMessages() const { streamingContent, loadingModel } = useAppState() @@ -45,9 +47,16 @@ function ThreadDetail() { const isFirstRender = useRef(true) useEffect(() => { - if (currentThreadId !== threadId) setCurrentThreadId(threadId) + if (currentThreadId !== threadId) { + setCurrentThreadId(threadId) + const assistant = assistants.find( + (assistant) => assistant.id === thread?.assistants?.[0]?.id + ) + if (assistant) setCurrentAssistant(assistant) + } + // eslint-disable-next-line react-hooks/exhaustive-deps - }, [threadId, currentThreadId]) + }, [threadId, currentThreadId, assistants]) useEffect(() => { fetchMessages(threadId).then((fetchedMessages) => { diff --git a/web-app/src/services/threads.ts b/web-app/src/services/threads.ts index 1af6cfb1e..4b26779ce 100644 --- a/web-app/src/services/threads.ts +++ b/web-app/src/services/threads.ts @@ -1,3 +1,4 @@ +import { defaultAssistant } from '@/hooks/useAssistant' import { ExtensionManager } from '@/lib/extension' import { ConversationalExtension, ExtensionTypeEnum } from '@janhq/core' @@ -20,9 +21,10 @@ export const fetchThreads = async (): Promise => { order: e.metadata?.order, isFavorite: e.metadata?.is_favorite, model: { - id: e.assistants?.[0]?.model.id, - provider: e.assistants?.[0]?.model.engine, + id: e.assistants?.[0]?.model?.id, + provider: e.assistants?.[0]?.model?.engine, }, + assistants: e.assistants ?? [defaultAssistant], } as Thread }) }) @@ -50,8 +52,8 @@ export const createThread = async (thread: Thread): Promise => { id: thread.model?.id ?? '*', engine: thread.model?.provider ?? 'llama.cpp', }, - assistant_id: 'jan', - assistant_name: 'Jan', + id: 'jan', + name: 'Jan', }, ], metadata: { @@ -63,10 +65,11 @@ export const createThread = async (thread: Thread): Promise => { ...e, updated: e.updated, model: { - id: e.assistants?.[0]?.model.id, - provider: e.assistants?.[0]?.model.engine, + id: e.assistants?.[0]?.model?.id, + provider: e.assistants?.[0]?.model?.engine, }, order: 1, + assistants: e.assistants ?? [defaultAssistant], } as Thread }) .catch(() => thread) ?? thread @@ -82,14 +85,24 @@ export const updateThread = (thread: Thread) => { .get(ExtensionTypeEnum.Conversational) ?.modifyThread({ ...thread, - assistants: [ + assistants: thread.assistants?.map((e) => { + return { + model: { + id: thread.model?.id ?? '*', + engine: thread.model?.provider ?? 'llama.cpp', + }, + id: e.id, + name: e.name, + instructions: e.instructions, + } + }) ?? [ { model: { id: thread.model?.id ?? '*', - engine: (thread.model?.provider ?? 'llama.cpp'), + engine: thread.model?.provider ?? 'llama.cpp', }, - assistant_id: 'jan', - assistant_name: 'Jan', + id: 'jan', + name: 'Jan', }, ], metadata: { diff --git a/web-app/src/types/threads.d.ts b/web-app/src/types/threads.d.ts index e76e4b519..4f15039a5 100644 --- a/web-app/src/types/threads.d.ts +++ b/web-app/src/types/threads.d.ts @@ -31,11 +31,12 @@ type ThreadContent = { type ChatCompletionRole = 'system' | 'assistant' | 'user' | 'tool' type ThreadModel = { - id: string - provider: string - } + id: string + provider: string +} type Thread = { + assistants?: ThreadAssistantInfo[] id: string title: string isFavorite?: boolean @@ -44,3 +45,13 @@ type Thread = { updated: number order?: number } + +type Assistant = { + avatar?: string + id: string + name: string + created_at: number + description?: string + instructions: string + parameters: Record +}