feat: Jan supports multiple assistants (#5024)
* feat: Jan supports multiple assistants * chore: persists current assistant to threads.json * chore: update assistant persistence * chore: simplify persistence objects
This commit is contained in:
parent
f6433544af
commit
2dac53e9ca
@ -27,8 +27,8 @@ export type Thread = {
|
||||
* @stored
|
||||
*/
|
||||
export type ThreadAssistantInfo = {
|
||||
assistant_id: string
|
||||
assistant_name: string
|
||||
id: string
|
||||
name: string
|
||||
model: ModelInfo
|
||||
instructions?: string
|
||||
tools?: AssistantTool[]
|
||||
|
||||
@ -97,8 +97,8 @@ pub struct ImageContentValue {
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ThreadAssistantInfo {
|
||||
pub assistant_id: String,
|
||||
pub assistant_name: String,
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub model: ModelInfo,
|
||||
pub instructions: Option<String>,
|
||||
pub tools: Option<Vec<AssistantTool>>,
|
||||
@ -456,16 +456,16 @@ pub async fn modify_thread_assistant<R: Runtime>(
|
||||
serde_json::from_str(&data).map_err(|e| e.to_string())?
|
||||
};
|
||||
let assistant_id = assistant
|
||||
.get("assistant_id")
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or("Missing assistant_id")?;
|
||||
.ok_or("Missing id")?;
|
||||
if let Some(assistants) = thread
|
||||
.get_mut("assistants")
|
||||
.and_then(|a: &mut serde_json::Value| a.as_array_mut())
|
||||
{
|
||||
if let Some(index) = assistants
|
||||
.iter()
|
||||
.position(|a| a.get("assistant_id").and_then(|v| v.as_str()) == Some(assistant_id))
|
||||
.position(|a| a.get("id").and_then(|v| v.as_str()) == Some(assistant_id))
|
||||
{
|
||||
assistants[index] = assistant.clone();
|
||||
let data = serde_json::to_string_pretty(&thread).map_err(|e| e.to_string())?;
|
||||
|
||||
@ -9,20 +9,25 @@ import {
|
||||
import { useAssistant } from '@/hooks/useAssistant'
|
||||
import AddEditAssistant from './dialogs/AddEditAssistant'
|
||||
import { IconCirclePlus, IconSettings } from '@tabler/icons-react'
|
||||
import { useThreads } from '@/hooks/useThreads'
|
||||
|
||||
const DropdownAssistant = () => {
|
||||
const { assistants, addAssistant, updateAssistant } = useAssistant()
|
||||
const {
|
||||
assistants,
|
||||
currentAssistant,
|
||||
addAssistant,
|
||||
updateAssistant,
|
||||
setCurrentAssistant,
|
||||
} = useAssistant()
|
||||
const { updateCurrentThreadAssistant } = useThreads()
|
||||
const [dropdownOpen, setDropdownOpen] = useState(false)
|
||||
const [dialogOpen, setDialogOpen] = useState(false)
|
||||
const [editingAssistantId, setEditingAssistantId] = useState<string | null>(
|
||||
null
|
||||
)
|
||||
const [selectedAssistantId, setSelectedAssistantId] = useState<string | null>(
|
||||
assistants[0]?.id || null
|
||||
)
|
||||
|
||||
const selectedAssistant =
|
||||
assistants.find((a) => a.id === selectedAssistantId) || assistants[0]
|
||||
assistants.find((a) => a.id === currentAssistant.id) || assistants[0]
|
||||
|
||||
return (
|
||||
<>
|
||||
@ -63,7 +68,10 @@ const DropdownAssistant = () => {
|
||||
<DropdownMenuItem className="flex justify-between items-center">
|
||||
<span
|
||||
className="truncate text-main-view-fg/70 flex-1 cursor-pointer"
|
||||
onClick={() => setSelectedAssistantId(assistant.id)}
|
||||
onClick={() => {
|
||||
setCurrentAssistant(assistant)
|
||||
updateCurrentThreadAssistant(assistant)
|
||||
}}
|
||||
>
|
||||
{assistant.name}
|
||||
</span>
|
||||
|
||||
@ -2,24 +2,17 @@ import { localStoregeKey } from '@/constants/localStorage'
|
||||
import { create } from 'zustand'
|
||||
import { persist } from 'zustand/middleware'
|
||||
|
||||
export type Assistant = {
|
||||
avatar?: string
|
||||
id: string
|
||||
name: string
|
||||
created_at: number
|
||||
description?: string
|
||||
instructions: string
|
||||
parameters: Record<string, unknown>
|
||||
}
|
||||
|
||||
interface AssistantState {
|
||||
assistants: Assistant[]
|
||||
currentAssistant: Assistant
|
||||
addAssistant: (assistant: Assistant) => void
|
||||
updateAssistant: (assistant: Assistant) => void
|
||||
deleteAssistant: (id: string) => void
|
||||
setCurrentAssistant: (assistant: Assistant) => void
|
||||
}
|
||||
|
||||
const defaultAssistant: Assistant = {
|
||||
export const defaultAssistant: Assistant = {
|
||||
avatar: '',
|
||||
id: 'jan',
|
||||
name: 'Jan',
|
||||
@ -33,6 +26,7 @@ export const useAssistant = create<AssistantState>()(
|
||||
persist(
|
||||
(set, get) => ({
|
||||
assistants: [defaultAssistant],
|
||||
currentAssistant: defaultAssistant,
|
||||
addAssistant: (assistant) =>
|
||||
set({ assistants: [...get().assistants, assistant] }),
|
||||
updateAssistant: (assistant) =>
|
||||
@ -43,6 +37,9 @@ export const useAssistant = create<AssistantState>()(
|
||||
}),
|
||||
deleteAssistant: (id) =>
|
||||
set({ assistants: get().assistants.filter((a) => a.id !== id) }),
|
||||
setCurrentAssistant: (assistant) => {
|
||||
set({ currentAssistant: assistant })
|
||||
},
|
||||
}),
|
||||
{
|
||||
name: localStoregeKey.assistant,
|
||||
|
||||
@ -18,10 +18,12 @@ import {
|
||||
} from '@/lib/completion'
|
||||
import { CompletionMessagesBuilder } from '@/lib/messages'
|
||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
import { useAssistant } from './useAssistant'
|
||||
|
||||
export const useChat = () => {
|
||||
const { prompt, setPrompt } = usePrompt()
|
||||
const { tools } = useAppState()
|
||||
const { currentAssistant } = useAssistant()
|
||||
|
||||
const { getProviderByName, selectedModel, selectedProvider } =
|
||||
useModelProvider()
|
||||
@ -43,7 +45,8 @@ export const useChat = () => {
|
||||
id: selectedModel?.id ?? defaultModel(selectedProvider),
|
||||
provider: selectedProvider,
|
||||
},
|
||||
prompt
|
||||
prompt,
|
||||
currentAssistant
|
||||
)
|
||||
router.navigate({
|
||||
to: route.threadsDetail,
|
||||
@ -58,6 +61,7 @@ export const useChat = () => {
|
||||
router,
|
||||
selectedModel?.id,
|
||||
selectedProvider,
|
||||
currentAssistant,
|
||||
])
|
||||
|
||||
const sendMessage = useCallback(
|
||||
@ -79,6 +83,8 @@ export const useChat = () => {
|
||||
}
|
||||
|
||||
const builder = new CompletionMessagesBuilder()
|
||||
if (currentAssistant?.instructions?.length > 0)
|
||||
builder.addSystemMessage(currentAssistant?.instructions || '')
|
||||
// REMARK: Would it possible to not attach the entire message history to the request?
|
||||
// TODO: If not amend messages history here
|
||||
builder.addUserMessage(message)
|
||||
@ -143,9 +149,10 @@ export const useChat = () => {
|
||||
addMessage,
|
||||
setPrompt,
|
||||
selectedModel,
|
||||
tools,
|
||||
currentAssistant?.instructions,
|
||||
setAbortController,
|
||||
updateLoadingModel,
|
||||
tools,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@ -17,9 +17,14 @@ type ThreadState = {
|
||||
deleteAllThreads: () => void
|
||||
unstarAllThreads: () => void
|
||||
setCurrentThreadId: (threadId?: string) => void
|
||||
createThread: (model: ThreadModel, title?: string) => Promise<Thread>
|
||||
createThread: (
|
||||
model: ThreadModel,
|
||||
title?: string,
|
||||
assistant?: Assistant
|
||||
) => Promise<Thread>
|
||||
updateCurrentThreadModel: (model: ThreadModel) => void
|
||||
getFilteredThreads: (searchTerm: string) => Thread[]
|
||||
updateCurrentThreadAssistant: (assistant: Assistant) => void
|
||||
searchIndex: Fuse<Thread> | null
|
||||
}
|
||||
|
||||
@ -152,18 +157,18 @@ export const useThreads = create<ThreadState>()(
|
||||
setCurrentThreadId: (threadId) => {
|
||||
set({ currentThreadId: threadId })
|
||||
},
|
||||
createThread: async (model, title) => {
|
||||
createThread: async (model, title, assistant) => {
|
||||
const newThread: Thread = {
|
||||
id: ulid(),
|
||||
title: title ?? 'New Thread',
|
||||
model,
|
||||
order: 1,
|
||||
updated: Date.now() / 1000,
|
||||
assistants: assistant ? [assistant] : [],
|
||||
}
|
||||
set((state) => ({
|
||||
searchIndex: new Fuse(Object.values(state.threads), fuseOptions),
|
||||
}))
|
||||
console.log('newThread', newThread)
|
||||
return await createThread(newThread).then((createdThread) => {
|
||||
set((state) => ({
|
||||
threads: {
|
||||
@ -175,6 +180,26 @@ export const useThreads = create<ThreadState>()(
|
||||
return createdThread
|
||||
})
|
||||
},
|
||||
updateCurrentThreadAssistant: (assistant) => {
|
||||
set((state) => {
|
||||
if (!state.currentThreadId) return { ...state }
|
||||
const currentThread = state.getCurrentThread()
|
||||
if (currentThread)
|
||||
updateThread({
|
||||
...currentThread,
|
||||
assistants: [{ ...assistant, model: currentThread.model }],
|
||||
})
|
||||
return {
|
||||
threads: {
|
||||
...state.threads,
|
||||
[state.currentThreadId as string]: {
|
||||
...state.threads[state.currentThreadId as string],
|
||||
assistants: [assistant],
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
},
|
||||
updateCurrentThreadModel: (model) => {
|
||||
set((state) => {
|
||||
if (!state.currentThreadId) return { ...state }
|
||||
|
||||
@ -1,11 +1,30 @@
|
||||
import { ChatCompletionMessageParam } from 'token.js'
|
||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
|
||||
/**
|
||||
* @fileoverview Helper functions for creating chat completion request.
|
||||
* These functions are used to create chat completion request objects
|
||||
*/
|
||||
export class CompletionMessagesBuilder {
|
||||
private messages: ChatCompletionMessageParam[] = []
|
||||
|
||||
constructor() {}
|
||||
|
||||
/**
|
||||
* Add a system message to the messages array.
|
||||
* @param content - The content of the system message.
|
||||
*/
|
||||
addSystemMessage(content: string) {
|
||||
this.messages.push({
|
||||
role: 'system',
|
||||
content: content,
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a user message to the messages array.
|
||||
* @param content - The content of the user message.
|
||||
*/
|
||||
addUserMessage(content: string) {
|
||||
this.messages.push({
|
||||
role: 'user',
|
||||
@ -13,15 +32,30 @@ export class CompletionMessagesBuilder {
|
||||
})
|
||||
}
|
||||
|
||||
addAssistantMessage(content: string, refusal?: string, calls?: ChatCompletionMessageToolCall[]) {
|
||||
/**
|
||||
* Add an assistant message to the messages array.
|
||||
* @param content - The content of the assistant message.
|
||||
* @param refusal - Optional refusal message.
|
||||
* @param calls - Optional tool calls associated with the message.
|
||||
*/
|
||||
addAssistantMessage(
|
||||
content: string,
|
||||
refusal?: string,
|
||||
calls?: ChatCompletionMessageToolCall[]
|
||||
) {
|
||||
this.messages.push({
|
||||
role: 'assistant',
|
||||
content: content,
|
||||
refusal: refusal,
|
||||
tool_calls: calls
|
||||
tool_calls: calls,
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a tool message to the messages array.
|
||||
* @param content - The content of the tool message.
|
||||
* @param toolCallId - The ID of the tool call associated with the message.
|
||||
*/
|
||||
addToolMessage(content: string, toolCallId: string) {
|
||||
this.messages.push({
|
||||
role: 'tool',
|
||||
@ -30,6 +64,10 @@ export class CompletionMessagesBuilder {
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the messages array.
|
||||
* @returns The array of chat completion messages.
|
||||
*/
|
||||
getMessages(): ChatCompletionMessageParam[] {
|
||||
return this.messages
|
||||
}
|
||||
|
||||
@ -15,6 +15,8 @@ type SearchParams = {
|
||||
}
|
||||
}
|
||||
import DropdownAssistant from '@/containers/DropdownAssistant'
|
||||
import { useEffect } from 'react'
|
||||
import { useThreads } from '@/hooks/useThreads'
|
||||
|
||||
export const Route = createFileRoute(route.home as any)({
|
||||
component: Index,
|
||||
@ -28,6 +30,7 @@ function Index() {
|
||||
const { providers } = useModelProvider()
|
||||
const search = useSearch({ from: route.home as any })
|
||||
const selectedModel = search.model
|
||||
const { setCurrentThreadId } = useThreads()
|
||||
|
||||
// Conditional to check if there are any valid providers
|
||||
// required min 1 api_key or 1 model in llama.cpp
|
||||
@ -37,6 +40,10 @@ function Index() {
|
||||
(provider.provider === 'llama.cpp' && provider.models.length)
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
setCurrentThreadId(undefined)
|
||||
}, [setCurrentThreadId])
|
||||
|
||||
if (!hasValidProviders) {
|
||||
return <SetupScreen />
|
||||
}
|
||||
|
||||
@ -16,6 +16,7 @@ import { useMessages } from '@/hooks/useMessages'
|
||||
import { fetchMessages } from '@/services/messages'
|
||||
import { useAppState } from '@/hooks/useAppState'
|
||||
import DropdownAssistant from '@/containers/DropdownAssistant'
|
||||
import { useAssistant } from '@/hooks/useAssistant'
|
||||
|
||||
// as route.threadsDetail
|
||||
export const Route = createFileRoute('/threads/$threadId')({
|
||||
@ -28,6 +29,7 @@ function ThreadDetail() {
|
||||
const [isAtBottom, setIsAtBottom] = useState(true)
|
||||
const lastScrollTopRef = useRef(0)
|
||||
const { currentThreadId, getThreadById, setCurrentThreadId } = useThreads()
|
||||
const { setCurrentAssistant, assistants } = useAssistant()
|
||||
const { setMessages } = useMessages()
|
||||
const { streamingContent, loadingModel } = useAppState()
|
||||
|
||||
@ -45,9 +47,16 @@ function ThreadDetail() {
|
||||
const isFirstRender = useRef(true)
|
||||
|
||||
useEffect(() => {
|
||||
if (currentThreadId !== threadId) setCurrentThreadId(threadId)
|
||||
if (currentThreadId !== threadId) {
|
||||
setCurrentThreadId(threadId)
|
||||
const assistant = assistants.find(
|
||||
(assistant) => assistant.id === thread?.assistants?.[0]?.id
|
||||
)
|
||||
if (assistant) setCurrentAssistant(assistant)
|
||||
}
|
||||
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [threadId, currentThreadId])
|
||||
}, [threadId, currentThreadId, assistants])
|
||||
|
||||
useEffect(() => {
|
||||
fetchMessages(threadId).then((fetchedMessages) => {
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { defaultAssistant } from '@/hooks/useAssistant'
|
||||
import { ExtensionManager } from '@/lib/extension'
|
||||
import { ConversationalExtension, ExtensionTypeEnum } from '@janhq/core'
|
||||
|
||||
@ -20,9 +21,10 @@ export const fetchThreads = async (): Promise<Thread[]> => {
|
||||
order: e.metadata?.order,
|
||||
isFavorite: e.metadata?.is_favorite,
|
||||
model: {
|
||||
id: e.assistants?.[0]?.model.id,
|
||||
provider: e.assistants?.[0]?.model.engine,
|
||||
id: e.assistants?.[0]?.model?.id,
|
||||
provider: e.assistants?.[0]?.model?.engine,
|
||||
},
|
||||
assistants: e.assistants ?? [defaultAssistant],
|
||||
} as Thread
|
||||
})
|
||||
})
|
||||
@ -50,8 +52,8 @@ export const createThread = async (thread: Thread): Promise<Thread> => {
|
||||
id: thread.model?.id ?? '*',
|
||||
engine: thread.model?.provider ?? 'llama.cpp',
|
||||
},
|
||||
assistant_id: 'jan',
|
||||
assistant_name: 'Jan',
|
||||
id: 'jan',
|
||||
name: 'Jan',
|
||||
},
|
||||
],
|
||||
metadata: {
|
||||
@ -63,10 +65,11 @@ export const createThread = async (thread: Thread): Promise<Thread> => {
|
||||
...e,
|
||||
updated: e.updated,
|
||||
model: {
|
||||
id: e.assistants?.[0]?.model.id,
|
||||
provider: e.assistants?.[0]?.model.engine,
|
||||
id: e.assistants?.[0]?.model?.id,
|
||||
provider: e.assistants?.[0]?.model?.engine,
|
||||
},
|
||||
order: 1,
|
||||
assistants: e.assistants ?? [defaultAssistant],
|
||||
} as Thread
|
||||
})
|
||||
.catch(() => thread) ?? thread
|
||||
@ -82,14 +85,24 @@ export const updateThread = (thread: Thread) => {
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.modifyThread({
|
||||
...thread,
|
||||
assistants: [
|
||||
assistants: thread.assistants?.map((e) => {
|
||||
return {
|
||||
model: {
|
||||
id: thread.model?.id ?? '*',
|
||||
engine: thread.model?.provider ?? 'llama.cpp',
|
||||
},
|
||||
id: e.id,
|
||||
name: e.name,
|
||||
instructions: e.instructions,
|
||||
}
|
||||
}) ?? [
|
||||
{
|
||||
model: {
|
||||
id: thread.model?.id ?? '*',
|
||||
engine: (thread.model?.provider ?? 'llama.cpp'),
|
||||
engine: thread.model?.provider ?? 'llama.cpp',
|
||||
},
|
||||
assistant_id: 'jan',
|
||||
assistant_name: 'Jan',
|
||||
id: 'jan',
|
||||
name: 'Jan',
|
||||
},
|
||||
],
|
||||
metadata: {
|
||||
|
||||
17
web-app/src/types/threads.d.ts
vendored
17
web-app/src/types/threads.d.ts
vendored
@ -31,11 +31,12 @@ type ThreadContent = {
|
||||
type ChatCompletionRole = 'system' | 'assistant' | 'user' | 'tool'
|
||||
|
||||
type ThreadModel = {
|
||||
id: string
|
||||
provider: string
|
||||
}
|
||||
id: string
|
||||
provider: string
|
||||
}
|
||||
|
||||
type Thread = {
|
||||
assistants?: ThreadAssistantInfo[]
|
||||
id: string
|
||||
title: string
|
||||
isFavorite?: boolean
|
||||
@ -44,3 +45,13 @@ type Thread = {
|
||||
updated: number
|
||||
order?: number
|
||||
}
|
||||
|
||||
type Assistant = {
|
||||
avatar?: string
|
||||
id: string
|
||||
name: string
|
||||
created_at: number
|
||||
description?: string
|
||||
instructions: string
|
||||
parameters: Record<string, unknown>
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user