feat: Jan supports multiple assistants (#5024)

* feat: Jan supports multiple assistants

* chore: persists current assistant to threads.json

* chore: update assistant persistence

* chore: simplify persistence objects
This commit is contained in:
Louis 2025-05-20 00:57:26 +07:00 committed by GitHub
parent f6433544af
commit 2dac53e9ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 160 additions and 45 deletions

View File

@ -27,8 +27,8 @@ export type Thread = {
* @stored
*/
export type ThreadAssistantInfo = {
assistant_id: string
assistant_name: string
id: string
name: string
model: ModelInfo
instructions?: string
tools?: AssistantTool[]

View File

@ -97,8 +97,8 @@ pub struct ImageContentValue {
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ThreadAssistantInfo {
pub assistant_id: String,
pub assistant_name: String,
pub id: String,
pub name: String,
pub model: ModelInfo,
pub instructions: Option<String>,
pub tools: Option<Vec<AssistantTool>>,
@ -456,16 +456,16 @@ pub async fn modify_thread_assistant<R: Runtime>(
serde_json::from_str(&data).map_err(|e| e.to_string())?
};
let assistant_id = assistant
.get("assistant_id")
.get("id")
.and_then(|v| v.as_str())
.ok_or("Missing assistant_id")?;
.ok_or("Missing id")?;
if let Some(assistants) = thread
.get_mut("assistants")
.and_then(|a: &mut serde_json::Value| a.as_array_mut())
{
if let Some(index) = assistants
.iter()
.position(|a| a.get("assistant_id").and_then(|v| v.as_str()) == Some(assistant_id))
.position(|a| a.get("id").and_then(|v| v.as_str()) == Some(assistant_id))
{
assistants[index] = assistant.clone();
let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?;

View File

@ -9,20 +9,25 @@ import {
import { useAssistant } from '@/hooks/useAssistant'
import AddEditAssistant from './dialogs/AddEditAssistant'
import { IconCirclePlus, IconSettings } from '@tabler/icons-react'
import { useThreads } from '@/hooks/useThreads'
const DropdownAssistant = () => {
const { assistants, addAssistant, updateAssistant } = useAssistant()
const {
assistants,
currentAssistant,
addAssistant,
updateAssistant,
setCurrentAssistant,
} = useAssistant()
const { updateCurrentThreadAssistant } = useThreads()
const [dropdownOpen, setDropdownOpen] = useState(false)
const [dialogOpen, setDialogOpen] = useState(false)
const [editingAssistantId, setEditingAssistantId] = useState<string | null>(
null
)
const [selectedAssistantId, setSelectedAssistantId] = useState<string | null>(
assistants[0]?.id || null
)
const selectedAssistant =
assistants.find((a) => a.id === selectedAssistantId) || assistants[0]
assistants.find((a) => a.id === currentAssistant.id) || assistants[0]
return (
<>
@ -63,7 +68,10 @@ const DropdownAssistant = () => {
<DropdownMenuItem className="flex justify-between items-center">
<span
className="truncate text-main-view-fg/70 flex-1 cursor-pointer"
onClick={() => setSelectedAssistantId(assistant.id)}
onClick={() => {
setCurrentAssistant(assistant)
updateCurrentThreadAssistant(assistant)
}}
>
{assistant.name}
</span>

View File

@ -2,24 +2,17 @@ import { localStoregeKey } from '@/constants/localStorage'
import { create } from 'zustand'
import { persist } from 'zustand/middleware'
export type Assistant = {
avatar?: string
id: string
name: string
created_at: number
description?: string
instructions: string
parameters: Record<string, unknown>
}
interface AssistantState {
assistants: Assistant[]
currentAssistant: Assistant
addAssistant: (assistant: Assistant) => void
updateAssistant: (assistant: Assistant) => void
deleteAssistant: (id: string) => void
setCurrentAssistant: (assistant: Assistant) => void
}
const defaultAssistant: Assistant = {
export const defaultAssistant: Assistant = {
avatar: '',
id: 'jan',
name: 'Jan',
@ -33,6 +26,7 @@ export const useAssistant = create<AssistantState>()(
persist(
(set, get) => ({
assistants: [defaultAssistant],
currentAssistant: defaultAssistant,
addAssistant: (assistant) =>
set({ assistants: [...get().assistants, assistant] }),
updateAssistant: (assistant) =>
@ -43,6 +37,9 @@ export const useAssistant = create<AssistantState>()(
}),
deleteAssistant: (id) =>
set({ assistants: get().assistants.filter((a) => a.id !== id) }),
setCurrentAssistant: (assistant) => {
set({ currentAssistant: assistant })
},
}),
{
name: localStoregeKey.assistant,

View File

@ -18,10 +18,12 @@ import {
} from '@/lib/completion'
import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'
import { useAssistant } from './useAssistant'
export const useChat = () => {
const { prompt, setPrompt } = usePrompt()
const { tools } = useAppState()
const { currentAssistant } = useAssistant()
const { getProviderByName, selectedModel, selectedProvider } =
useModelProvider()
@ -43,7 +45,8 @@ export const useChat = () => {
id: selectedModel?.id ?? defaultModel(selectedProvider),
provider: selectedProvider,
},
prompt
prompt,
currentAssistant
)
router.navigate({
to: route.threadsDetail,
@ -58,6 +61,7 @@ export const useChat = () => {
router,
selectedModel?.id,
selectedProvider,
currentAssistant,
])
const sendMessage = useCallback(
@ -79,6 +83,8 @@ export const useChat = () => {
}
const builder = new CompletionMessagesBuilder()
if (currentAssistant?.instructions?.length > 0)
builder.addSystemMessage(currentAssistant?.instructions || '')
// REMARK: Would it possible to not attach the entire message history to the request?
// TODO: If not amend messages history here
builder.addUserMessage(message)
@ -143,9 +149,10 @@ export const useChat = () => {
addMessage,
setPrompt,
selectedModel,
tools,
currentAssistant?.instructions,
setAbortController,
updateLoadingModel,
tools,
]
)

View File

@ -17,9 +17,14 @@ type ThreadState = {
deleteAllThreads: () => void
unstarAllThreads: () => void
setCurrentThreadId: (threadId?: string) => void
createThread: (model: ThreadModel, title?: string) => Promise<Thread>
createThread: (
model: ThreadModel,
title?: string,
assistant?: Assistant
) => Promise<Thread>
updateCurrentThreadModel: (model: ThreadModel) => void
getFilteredThreads: (searchTerm: string) => Thread[]
updateCurrentThreadAssistant: (assistant: Assistant) => void
searchIndex: Fuse<Thread> | null
}
@ -152,18 +157,18 @@ export const useThreads = create<ThreadState>()(
setCurrentThreadId: (threadId) => {
set({ currentThreadId: threadId })
},
createThread: async (model, title) => {
createThread: async (model, title, assistant) => {
const newThread: Thread = {
id: ulid(),
title: title ?? 'New Thread',
model,
order: 1,
updated: Date.now() / 1000,
assistants: assistant ? [assistant] : [],
}
set((state) => ({
searchIndex: new Fuse(Object.values(state.threads), fuseOptions),
}))
console.log('newThread', newThread)
return await createThread(newThread).then((createdThread) => {
set((state) => ({
threads: {
@ -175,6 +180,26 @@ export const useThreads = create<ThreadState>()(
return createdThread
})
},
updateCurrentThreadAssistant: (assistant) => {
set((state) => {
if (!state.currentThreadId) return { ...state }
const currentThread = state.getCurrentThread()
if (currentThread)
updateThread({
...currentThread,
assistants: [{ ...assistant, model: currentThread.model }],
})
return {
threads: {
...state.threads,
[state.currentThreadId as string]: {
...state.threads[state.currentThreadId as string],
assistants: [assistant],
},
},
}
})
},
updateCurrentThreadModel: (model) => {
set((state) => {
if (!state.currentThreadId) return { ...state }

View File

@ -1,11 +1,30 @@
import { ChatCompletionMessageParam } from 'token.js'
import { ChatCompletionMessageToolCall } from 'openai/resources'
/**
* @fileoverview Helper functions for creating chat completion request.
* These functions are used to create chat completion request objects
*/
export class CompletionMessagesBuilder {
private messages: ChatCompletionMessageParam[] = []
constructor() {}
/**
* Add a system message to the messages array.
* @param content - The content of the system message.
*/
addSystemMessage(content: string) {
this.messages.push({
role: 'system',
content: content,
})
}
/**
* Add a user message to the messages array.
* @param content - The content of the user message.
*/
addUserMessage(content: string) {
this.messages.push({
role: 'user',
@ -13,15 +32,30 @@ export class CompletionMessagesBuilder {
})
}
addAssistantMessage(content: string, refusal?: string, calls?: ChatCompletionMessageToolCall[]) {
/**
* Add an assistant message to the messages array.
* @param content - The content of the assistant message.
* @param refusal - Optional refusal message.
* @param calls - Optional tool calls associated with the message.
*/
addAssistantMessage(
content: string,
refusal?: string,
calls?: ChatCompletionMessageToolCall[]
) {
this.messages.push({
role: 'assistant',
content: content,
refusal: refusal,
tool_calls: calls
tool_calls: calls,
})
}
/**
* Add a tool message to the messages array.
* @param content - The content of the tool message.
* @param toolCallId - The ID of the tool call associated with the message.
*/
addToolMessage(content: string, toolCallId: string) {
this.messages.push({
role: 'tool',
@ -30,6 +64,10 @@ export class CompletionMessagesBuilder {
})
}
/**
* Return the messages array.
* @returns The array of chat completion messages.
*/
getMessages(): ChatCompletionMessageParam[] {
return this.messages
}

View File

@ -15,6 +15,8 @@ type SearchParams = {
}
}
import DropdownAssistant from '@/containers/DropdownAssistant'
import { useEffect } from 'react'
import { useThreads } from '@/hooks/useThreads'
export const Route = createFileRoute(route.home as any)({
component: Index,
@ -28,6 +30,7 @@ function Index() {
const { providers } = useModelProvider()
const search = useSearch({ from: route.home as any })
const selectedModel = search.model
const { setCurrentThreadId } = useThreads()
// Conditional to check if there are any valid providers
// required min 1 api_key or 1 model in llama.cpp
@ -37,6 +40,10 @@ function Index() {
(provider.provider === 'llama.cpp' && provider.models.length)
)
useEffect(() => {
setCurrentThreadId(undefined)
}, [setCurrentThreadId])
if (!hasValidProviders) {
return <SetupScreen />
}

View File

@ -16,6 +16,7 @@ import { useMessages } from '@/hooks/useMessages'
import { fetchMessages } from '@/services/messages'
import { useAppState } from '@/hooks/useAppState'
import DropdownAssistant from '@/containers/DropdownAssistant'
import { useAssistant } from '@/hooks/useAssistant'
// as route.threadsDetail
export const Route = createFileRoute('/threads/$threadId')({
@ -28,6 +29,7 @@ function ThreadDetail() {
const [isAtBottom, setIsAtBottom] = useState(true)
const lastScrollTopRef = useRef(0)
const { currentThreadId, getThreadById, setCurrentThreadId } = useThreads()
const { setCurrentAssistant, assistants } = useAssistant()
const { setMessages } = useMessages()
const { streamingContent, loadingModel } = useAppState()
@ -45,9 +47,16 @@ function ThreadDetail() {
const isFirstRender = useRef(true)
useEffect(() => {
if (currentThreadId !== threadId) setCurrentThreadId(threadId)
if (currentThreadId !== threadId) {
setCurrentThreadId(threadId)
const assistant = assistants.find(
(assistant) => assistant.id === thread?.assistants?.[0]?.id
)
if (assistant) setCurrentAssistant(assistant)
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [threadId, currentThreadId])
}, [threadId, currentThreadId, assistants])
useEffect(() => {
fetchMessages(threadId).then((fetchedMessages) => {

View File

@ -1,3 +1,4 @@
import { defaultAssistant } from '@/hooks/useAssistant'
import { ExtensionManager } from '@/lib/extension'
import { ConversationalExtension, ExtensionTypeEnum } from '@janhq/core'
@ -20,9 +21,10 @@ export const fetchThreads = async (): Promise<Thread[]> => {
order: e.metadata?.order,
isFavorite: e.metadata?.is_favorite,
model: {
id: e.assistants?.[0]?.model.id,
provider: e.assistants?.[0]?.model.engine,
id: e.assistants?.[0]?.model?.id,
provider: e.assistants?.[0]?.model?.engine,
},
assistants: e.assistants ?? [defaultAssistant],
} as Thread
})
})
@ -50,8 +52,8 @@ export const createThread = async (thread: Thread): Promise<Thread> => {
id: thread.model?.id ?? '*',
engine: thread.model?.provider ?? 'llama.cpp',
},
assistant_id: 'jan',
assistant_name: 'Jan',
id: 'jan',
name: 'Jan',
},
],
metadata: {
@ -63,10 +65,11 @@ export const createThread = async (thread: Thread): Promise<Thread> => {
...e,
updated: e.updated,
model: {
id: e.assistants?.[0]?.model.id,
provider: e.assistants?.[0]?.model.engine,
id: e.assistants?.[0]?.model?.id,
provider: e.assistants?.[0]?.model?.engine,
},
order: 1,
assistants: e.assistants ?? [defaultAssistant],
} as Thread
})
.catch(() => thread) ?? thread
@ -82,14 +85,24 @@ export const updateThread = (thread: Thread) => {
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyThread({
...thread,
assistants: [
assistants: thread.assistants?.map((e) => {
return {
model: {
id: thread.model?.id ?? '*',
engine: thread.model?.provider ?? 'llama.cpp',
},
id: e.id,
name: e.name,
instructions: e.instructions,
}
}) ?? [
{
model: {
id: thread.model?.id ?? '*',
engine: (thread.model?.provider ?? 'llama.cpp'),
engine: thread.model?.provider ?? 'llama.cpp',
},
assistant_id: 'jan',
assistant_name: 'Jan',
id: 'jan',
name: 'Jan',
},
],
metadata: {

View File

@ -31,11 +31,12 @@ type ThreadContent = {
type ChatCompletionRole = 'system' | 'assistant' | 'user' | 'tool'
type ThreadModel = {
id: string
provider: string
}
id: string
provider: string
}
type Thread = {
assistants?: ThreadAssistantInfo[]
id: string
title: string
isFavorite?: boolean
@ -44,3 +45,13 @@ type Thread = {
updated: number
order?: number
}
type Assistant = {
avatar?: string
id: string
name: string
created_at: number
description?: string
instructions: string
parameters: Record<string, unknown>
}