feat: Jan supports multiple assistants (#5024)

* feat: Jan supports multiple assistants

* chore: persists current assistant to threads.json

* chore: update assistant persistence

* chore: simplify persistence objects
This commit is contained in:
Louis 2025-05-20 00:57:26 +07:00 committed by GitHub
parent f6433544af
commit 2dac53e9ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 160 additions and 45 deletions

View File

@ -27,8 +27,8 @@ export type Thread = {
* @stored * @stored
*/ */
export type ThreadAssistantInfo = { export type ThreadAssistantInfo = {
assistant_id: string id: string
assistant_name: string name: string
model: ModelInfo model: ModelInfo
instructions?: string instructions?: string
tools?: AssistantTool[] tools?: AssistantTool[]

View File

@ -97,8 +97,8 @@ pub struct ImageContentValue {
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ThreadAssistantInfo { pub struct ThreadAssistantInfo {
pub assistant_id: String, pub id: String,
pub assistant_name: String, pub name: String,
pub model: ModelInfo, pub model: ModelInfo,
pub instructions: Option<String>, pub instructions: Option<String>,
pub tools: Option<Vec<AssistantTool>>, pub tools: Option<Vec<AssistantTool>>,
@ -456,16 +456,16 @@ pub async fn modify_thread_assistant<R: Runtime>(
serde_json::from_str(&data).map_err(|e| e.to_string())? serde_json::from_str(&data).map_err(|e| e.to_string())?
}; };
let assistant_id = assistant let assistant_id = assistant
.get("assistant_id") .get("id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.ok_or("Missing assistant_id")?; .ok_or("Missing id")?;
if let Some(assistants) = thread if let Some(assistants) = thread
.get_mut("assistants") .get_mut("assistants")
.and_then(|a: &mut serde_json::Value| a.as_array_mut()) .and_then(|a: &mut serde_json::Value| a.as_array_mut())
{ {
if let Some(index) = assistants if let Some(index) = assistants
.iter() .iter()
.position(|a| a.get("assistant_id").and_then(|v| v.as_str()) == Some(assistant_id)) .position(|a| a.get("id").and_then(|v| v.as_str()) == Some(assistant_id))
{ {
assistants[index] = assistant.clone(); assistants[index] = assistant.clone();
let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?; let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?;

View File

@ -9,20 +9,25 @@ import {
import { useAssistant } from '@/hooks/useAssistant' import { useAssistant } from '@/hooks/useAssistant'
import AddEditAssistant from './dialogs/AddEditAssistant' import AddEditAssistant from './dialogs/AddEditAssistant'
import { IconCirclePlus, IconSettings } from '@tabler/icons-react' import { IconCirclePlus, IconSettings } from '@tabler/icons-react'
import { useThreads } from '@/hooks/useThreads'
const DropdownAssistant = () => { const DropdownAssistant = () => {
const { assistants, addAssistant, updateAssistant } = useAssistant() const {
assistants,
currentAssistant,
addAssistant,
updateAssistant,
setCurrentAssistant,
} = useAssistant()
const { updateCurrentThreadAssistant } = useThreads()
const [dropdownOpen, setDropdownOpen] = useState(false) const [dropdownOpen, setDropdownOpen] = useState(false)
const [dialogOpen, setDialogOpen] = useState(false) const [dialogOpen, setDialogOpen] = useState(false)
const [editingAssistantId, setEditingAssistantId] = useState<string | null>( const [editingAssistantId, setEditingAssistantId] = useState<string | null>(
null null
) )
const [selectedAssistantId, setSelectedAssistantId] = useState<string | null>(
assistants[0]?.id || null
)
const selectedAssistant = const selectedAssistant =
assistants.find((a) => a.id === selectedAssistantId) || assistants[0] assistants.find((a) => a.id === currentAssistant.id) || assistants[0]
return ( return (
<> <>
@ -63,7 +68,10 @@ const DropdownAssistant = () => {
<DropdownMenuItem className="flex justify-between items-center"> <DropdownMenuItem className="flex justify-between items-center">
<span <span
className="truncate text-main-view-fg/70 flex-1 cursor-pointer" className="truncate text-main-view-fg/70 flex-1 cursor-pointer"
onClick={() => setSelectedAssistantId(assistant.id)} onClick={() => {
setCurrentAssistant(assistant)
updateCurrentThreadAssistant(assistant)
}}
> >
{assistant.name} {assistant.name}
</span> </span>

View File

@ -2,24 +2,17 @@ import { localStoregeKey } from '@/constants/localStorage'
import { create } from 'zustand' import { create } from 'zustand'
import { persist } from 'zustand/middleware' import { persist } from 'zustand/middleware'
export type Assistant = {
avatar?: string
id: string
name: string
created_at: number
description?: string
instructions: string
parameters: Record<string, unknown>
}
interface AssistantState { interface AssistantState {
assistants: Assistant[] assistants: Assistant[]
currentAssistant: Assistant
addAssistant: (assistant: Assistant) => void addAssistant: (assistant: Assistant) => void
updateAssistant: (assistant: Assistant) => void updateAssistant: (assistant: Assistant) => void
deleteAssistant: (id: string) => void deleteAssistant: (id: string) => void
setCurrentAssistant: (assistant: Assistant) => void
} }
const defaultAssistant: Assistant = { export const defaultAssistant: Assistant = {
avatar: '', avatar: '',
id: 'jan', id: 'jan',
name: 'Jan', name: 'Jan',
@ -33,6 +26,7 @@ export const useAssistant = create<AssistantState>()(
persist( persist(
(set, get) => ({ (set, get) => ({
assistants: [defaultAssistant], assistants: [defaultAssistant],
currentAssistant: defaultAssistant,
addAssistant: (assistant) => addAssistant: (assistant) =>
set({ assistants: [...get().assistants, assistant] }), set({ assistants: [...get().assistants, assistant] }),
updateAssistant: (assistant) => updateAssistant: (assistant) =>
@ -43,6 +37,9 @@ export const useAssistant = create<AssistantState>()(
}), }),
deleteAssistant: (id) => deleteAssistant: (id) =>
set({ assistants: get().assistants.filter((a) => a.id !== id) }), set({ assistants: get().assistants.filter((a) => a.id !== id) }),
setCurrentAssistant: (assistant) => {
set({ currentAssistant: assistant })
},
}), }),
{ {
name: localStoregeKey.assistant, name: localStoregeKey.assistant,

View File

@ -18,10 +18,12 @@ import {
} from '@/lib/completion' } from '@/lib/completion'
import { CompletionMessagesBuilder } from '@/lib/messages' import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources' import { ChatCompletionMessageToolCall } from 'openai/resources'
import { useAssistant } from './useAssistant'
export const useChat = () => { export const useChat = () => {
const { prompt, setPrompt } = usePrompt() const { prompt, setPrompt } = usePrompt()
const { tools } = useAppState() const { tools } = useAppState()
const { currentAssistant } = useAssistant()
const { getProviderByName, selectedModel, selectedProvider } = const { getProviderByName, selectedModel, selectedProvider } =
useModelProvider() useModelProvider()
@ -43,7 +45,8 @@ export const useChat = () => {
id: selectedModel?.id ?? defaultModel(selectedProvider), id: selectedModel?.id ?? defaultModel(selectedProvider),
provider: selectedProvider, provider: selectedProvider,
}, },
prompt prompt,
currentAssistant
) )
router.navigate({ router.navigate({
to: route.threadsDetail, to: route.threadsDetail,
@ -58,6 +61,7 @@ export const useChat = () => {
router, router,
selectedModel?.id, selectedModel?.id,
selectedProvider, selectedProvider,
currentAssistant,
]) ])
const sendMessage = useCallback( const sendMessage = useCallback(
@ -79,6 +83,8 @@ export const useChat = () => {
} }
const builder = new CompletionMessagesBuilder() const builder = new CompletionMessagesBuilder()
if (currentAssistant?.instructions?.length > 0)
builder.addSystemMessage(currentAssistant?.instructions || '')
// REMARK: Would it possible to not attach the entire message history to the request? // REMARK: Would it possible to not attach the entire message history to the request?
// TODO: If not amend messages history here // TODO: If not amend messages history here
builder.addUserMessage(message) builder.addUserMessage(message)
@ -143,9 +149,10 @@ export const useChat = () => {
addMessage, addMessage,
setPrompt, setPrompt,
selectedModel, selectedModel,
tools, currentAssistant?.instructions,
setAbortController, setAbortController,
updateLoadingModel, updateLoadingModel,
tools,
] ]
) )

View File

@ -17,9 +17,14 @@ type ThreadState = {
deleteAllThreads: () => void deleteAllThreads: () => void
unstarAllThreads: () => void unstarAllThreads: () => void
setCurrentThreadId: (threadId?: string) => void setCurrentThreadId: (threadId?: string) => void
createThread: (model: ThreadModel, title?: string) => Promise<Thread> createThread: (
model: ThreadModel,
title?: string,
assistant?: Assistant
) => Promise<Thread>
updateCurrentThreadModel: (model: ThreadModel) => void updateCurrentThreadModel: (model: ThreadModel) => void
getFilteredThreads: (searchTerm: string) => Thread[] getFilteredThreads: (searchTerm: string) => Thread[]
updateCurrentThreadAssistant: (assistant: Assistant) => void
searchIndex: Fuse<Thread> | null searchIndex: Fuse<Thread> | null
} }
@ -152,18 +157,18 @@ export const useThreads = create<ThreadState>()(
setCurrentThreadId: (threadId) => { setCurrentThreadId: (threadId) => {
set({ currentThreadId: threadId }) set({ currentThreadId: threadId })
}, },
createThread: async (model, title) => { createThread: async (model, title, assistant) => {
const newThread: Thread = { const newThread: Thread = {
id: ulid(), id: ulid(),
title: title ?? 'New Thread', title: title ?? 'New Thread',
model, model,
order: 1, order: 1,
updated: Date.now() / 1000, updated: Date.now() / 1000,
assistants: assistant ? [assistant] : [],
} }
set((state) => ({ set((state) => ({
searchIndex: new Fuse(Object.values(state.threads), fuseOptions), searchIndex: new Fuse(Object.values(state.threads), fuseOptions),
})) }))
console.log('newThread', newThread)
return await createThread(newThread).then((createdThread) => { return await createThread(newThread).then((createdThread) => {
set((state) => ({ set((state) => ({
threads: { threads: {
@ -175,6 +180,26 @@ export const useThreads = create<ThreadState>()(
return createdThread return createdThread
}) })
}, },
updateCurrentThreadAssistant: (assistant) => {
set((state) => {
if (!state.currentThreadId) return { ...state }
const currentThread = state.getCurrentThread()
if (currentThread)
updateThread({
...currentThread,
assistants: [{ ...assistant, model: currentThread.model }],
})
return {
threads: {
...state.threads,
[state.currentThreadId as string]: {
...state.threads[state.currentThreadId as string],
assistants: [assistant],
},
},
}
})
},
updateCurrentThreadModel: (model) => { updateCurrentThreadModel: (model) => {
set((state) => { set((state) => {
if (!state.currentThreadId) return { ...state } if (!state.currentThreadId) return { ...state }

View File

@ -1,11 +1,30 @@
import { ChatCompletionMessageParam } from 'token.js' import { ChatCompletionMessageParam } from 'token.js'
import { ChatCompletionMessageToolCall } from 'openai/resources' import { ChatCompletionMessageToolCall } from 'openai/resources'
/**
* @fileoverview Helper functions for creating chat completion request.
* These functions are used to create chat completion request objects
*/
export class CompletionMessagesBuilder { export class CompletionMessagesBuilder {
private messages: ChatCompletionMessageParam[] = [] private messages: ChatCompletionMessageParam[] = []
constructor() {} constructor() {}
/**
* Add a system message to the messages array.
* @param content - The content of the system message.
*/
addSystemMessage(content: string) {
this.messages.push({
role: 'system',
content: content,
})
}
/**
* Add a user message to the messages array.
* @param content - The content of the user message.
*/
addUserMessage(content: string) { addUserMessage(content: string) {
this.messages.push({ this.messages.push({
role: 'user', role: 'user',
@ -13,15 +32,30 @@ export class CompletionMessagesBuilder {
}) })
} }
addAssistantMessage(content: string, refusal?: string, calls?: ChatCompletionMessageToolCall[]) { /**
* Add an assistant message to the messages array.
* @param content - The content of the assistant message.
* @param refusal - Optional refusal message.
* @param calls - Optional tool calls associated with the message.
*/
addAssistantMessage(
content: string,
refusal?: string,
calls?: ChatCompletionMessageToolCall[]
) {
this.messages.push({ this.messages.push({
role: 'assistant', role: 'assistant',
content: content, content: content,
refusal: refusal, refusal: refusal,
tool_calls: calls tool_calls: calls,
}) })
} }
/**
* Add a tool message to the messages array.
* @param content - The content of the tool message.
* @param toolCallId - The ID of the tool call associated with the message.
*/
addToolMessage(content: string, toolCallId: string) { addToolMessage(content: string, toolCallId: string) {
this.messages.push({ this.messages.push({
role: 'tool', role: 'tool',
@ -30,6 +64,10 @@ export class CompletionMessagesBuilder {
}) })
} }
/**
* Return the messages array.
* @returns The array of chat completion messages.
*/
getMessages(): ChatCompletionMessageParam[] { getMessages(): ChatCompletionMessageParam[] {
return this.messages return this.messages
} }

View File

@ -15,6 +15,8 @@ type SearchParams = {
} }
} }
import DropdownAssistant from '@/containers/DropdownAssistant' import DropdownAssistant from '@/containers/DropdownAssistant'
import { useEffect } from 'react'
import { useThreads } from '@/hooks/useThreads'
export const Route = createFileRoute(route.home as any)({ export const Route = createFileRoute(route.home as any)({
component: Index, component: Index,
@ -28,6 +30,7 @@ function Index() {
const { providers } = useModelProvider() const { providers } = useModelProvider()
const search = useSearch({ from: route.home as any }) const search = useSearch({ from: route.home as any })
const selectedModel = search.model const selectedModel = search.model
const { setCurrentThreadId } = useThreads()
// Conditional to check if there are any valid providers // Conditional to check if there are any valid providers
// required min 1 api_key or 1 model in llama.cpp // required min 1 api_key or 1 model in llama.cpp
@ -37,6 +40,10 @@ function Index() {
(provider.provider === 'llama.cpp' && provider.models.length) (provider.provider === 'llama.cpp' && provider.models.length)
) )
useEffect(() => {
setCurrentThreadId(undefined)
}, [setCurrentThreadId])
if (!hasValidProviders) { if (!hasValidProviders) {
return <SetupScreen /> return <SetupScreen />
} }

View File

@ -16,6 +16,7 @@ import { useMessages } from '@/hooks/useMessages'
import { fetchMessages } from '@/services/messages' import { fetchMessages } from '@/services/messages'
import { useAppState } from '@/hooks/useAppState' import { useAppState } from '@/hooks/useAppState'
import DropdownAssistant from '@/containers/DropdownAssistant' import DropdownAssistant from '@/containers/DropdownAssistant'
import { useAssistant } from '@/hooks/useAssistant'
// as route.threadsDetail // as route.threadsDetail
export const Route = createFileRoute('/threads/$threadId')({ export const Route = createFileRoute('/threads/$threadId')({
@ -28,6 +29,7 @@ function ThreadDetail() {
const [isAtBottom, setIsAtBottom] = useState(true) const [isAtBottom, setIsAtBottom] = useState(true)
const lastScrollTopRef = useRef(0) const lastScrollTopRef = useRef(0)
const { currentThreadId, getThreadById, setCurrentThreadId } = useThreads() const { currentThreadId, getThreadById, setCurrentThreadId } = useThreads()
const { setCurrentAssistant, assistants } = useAssistant()
const { setMessages } = useMessages() const { setMessages } = useMessages()
const { streamingContent, loadingModel } = useAppState() const { streamingContent, loadingModel } = useAppState()
@ -45,9 +47,16 @@ function ThreadDetail() {
const isFirstRender = useRef(true) const isFirstRender = useRef(true)
useEffect(() => { useEffect(() => {
if (currentThreadId !== threadId) setCurrentThreadId(threadId) if (currentThreadId !== threadId) {
setCurrentThreadId(threadId)
const assistant = assistants.find(
(assistant) => assistant.id === thread?.assistants?.[0]?.id
)
if (assistant) setCurrentAssistant(assistant)
}
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [threadId, currentThreadId]) }, [threadId, currentThreadId, assistants])
useEffect(() => { useEffect(() => {
fetchMessages(threadId).then((fetchedMessages) => { fetchMessages(threadId).then((fetchedMessages) => {

View File

@ -1,3 +1,4 @@
import { defaultAssistant } from '@/hooks/useAssistant'
import { ExtensionManager } from '@/lib/extension' import { ExtensionManager } from '@/lib/extension'
import { ConversationalExtension, ExtensionTypeEnum } from '@janhq/core' import { ConversationalExtension, ExtensionTypeEnum } from '@janhq/core'
@ -20,9 +21,10 @@ export const fetchThreads = async (): Promise<Thread[]> => {
order: e.metadata?.order, order: e.metadata?.order,
isFavorite: e.metadata?.is_favorite, isFavorite: e.metadata?.is_favorite,
model: { model: {
id: e.assistants?.[0]?.model.id, id: e.assistants?.[0]?.model?.id,
provider: e.assistants?.[0]?.model.engine, provider: e.assistants?.[0]?.model?.engine,
}, },
assistants: e.assistants ?? [defaultAssistant],
} as Thread } as Thread
}) })
}) })
@ -50,8 +52,8 @@ export const createThread = async (thread: Thread): Promise<Thread> => {
id: thread.model?.id ?? '*', id: thread.model?.id ?? '*',
engine: thread.model?.provider ?? 'llama.cpp', engine: thread.model?.provider ?? 'llama.cpp',
}, },
assistant_id: 'jan', id: 'jan',
assistant_name: 'Jan', name: 'Jan',
}, },
], ],
metadata: { metadata: {
@ -63,10 +65,11 @@ export const createThread = async (thread: Thread): Promise<Thread> => {
...e, ...e,
updated: e.updated, updated: e.updated,
model: { model: {
id: e.assistants?.[0]?.model.id, id: e.assistants?.[0]?.model?.id,
provider: e.assistants?.[0]?.model.engine, provider: e.assistants?.[0]?.model?.engine,
}, },
order: 1, order: 1,
assistants: e.assistants ?? [defaultAssistant],
} as Thread } as Thread
}) })
.catch(() => thread) ?? thread .catch(() => thread) ?? thread
@ -82,14 +85,24 @@ export const updateThread = (thread: Thread) => {
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyThread({ ?.modifyThread({
...thread, ...thread,
assistants: [ assistants: thread.assistants?.map((e) => {
return {
model: {
id: thread.model?.id ?? '*',
engine: thread.model?.provider ?? 'llama.cpp',
},
id: e.id,
name: e.name,
instructions: e.instructions,
}
}) ?? [
{ {
model: { model: {
id: thread.model?.id ?? '*', id: thread.model?.id ?? '*',
engine: (thread.model?.provider ?? 'llama.cpp'), engine: thread.model?.provider ?? 'llama.cpp',
}, },
assistant_id: 'jan', id: 'jan',
assistant_name: 'Jan', name: 'Jan',
}, },
], ],
metadata: { metadata: {

View File

@ -36,6 +36,7 @@ type ThreadModel = {
} }
type Thread = { type Thread = {
assistants?: ThreadAssistantInfo[]
id: string id: string
title: string title: string
isFavorite?: boolean isFavorite?: boolean
@ -44,3 +45,13 @@ type Thread = {
updated: number updated: number
order?: number order?: number
} }
type Assistant = {
avatar?: string
id: string
name: string
created_at: number
description?: string
instructions: string
parameters: Record<string, unknown>
}