jan/web/hooks/useCreateNewThread.ts

230 lines
6.9 KiB
TypeScript

import { useCallback } from 'react'
import {
ConversationalExtension,
ExtensionTypeEnum,
Thread,
ThreadAssistantInfo,
ThreadState,
AssistantTool,
Model,
Assistant,
} from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { useDebouncedCallback } from 'use-debounce'
import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction'
import { fileUploadAtom } from '@/containers/Providers/Jotai'
import { toaster } from '@/containers/Toast'
import { isLocalEngine } from '@/utils/modelEngine'
import { useActiveModel } from './useActiveModel'
import useRecommendedModel from './useRecommendedModel'
import useSetActiveThread from './useSetActiveThread'
import { extensionManager } from '@/extension'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
threadsAtom,
updateThreadAtom,
setThreadModelParamsAtom,
isGeneratingResponseAtom,
createNewThreadAtom,
} from '@/helpers/atoms/Thread.atom'
export const useCreateNewThread = () => {
const createNewThread = useSetAtom(createNewThreadAtom)
const { setActiveThread } = useSetActiveThread()
const updateThread = useSetAtom(updateThreadAtom)
const setFileUpload = useSetAtom(fileUploadAtom)
const setSelectedModel = useSetAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
const copyOverInstructionEnabled = useAtomValue(
copyOverInstructionEnabledAtom
)
const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom)
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const threads = useAtomValue(threadsAtom)
const { stopInference } = useActiveModel()
const { recommendedModel } = useRecommendedModel()
const requestCreateNewThread = async (
assistant: (ThreadAssistantInfo & { id: string; name: string }) | Assistant,
model?: Model | undefined
) => {
// Stop generating if any
setIsGeneratingResponse(false)
stopInference()
const defaultModel = model || recommendedModel
if (!model) {
// if we have model, which means user wants to create new thread from Model hub. Allow them.
// check last thread message, if there empty last message use can not create thread
const lastMessage = threads[0]?.metadata?.lastMessage
if (!lastMessage && threads.length) {
return toaster({
title: 'No new thread created.',
description: `To avoid piling up empty threads, please reuse previous one before creating new.`,
type: 'warning',
})
}
}
// modify assistant tools when experimental on, retieval toggle enabled in default
const assistantTools: AssistantTool = {
type: 'retrieval',
enabled: true,
settings: assistant.tools && assistant.tools[0].settings,
}
// Default context length is 8192
const defaultContextLength = Math.min(
8192,
defaultModel?.settings.ctx_len ?? 8192
)
const overriddenSettings = {
ctx_len: !isLocalEngine(defaultModel?.engine)
? undefined
: defaultContextLength,
}
// Use ctx length by default
const overriddenParameters = {
max_tokens: !isLocalEngine(defaultModel?.engine)
? (defaultModel?.parameters.token_limit ?? 8192)
: defaultContextLength,
}
const createdAt = Date.now()
let instructions: string | undefined = assistant.instructions
if (copyOverInstructionEnabled) {
instructions = activeAssistant?.instructions ?? undefined
}
const assistantInfo: ThreadAssistantInfo = {
assistant_id: assistant.id,
assistant_name: assistant.name,
tools: experimentalEnabled ? [assistantTools] : assistant.tools,
model: {
id: defaultModel?.id ?? '*',
settings: { ...defaultModel?.settings, ...overriddenSettings },
parameters: { ...defaultModel?.parameters, ...overriddenParameters },
engine: defaultModel?.engine,
},
instructions,
}
const thread: Partial<Thread> = {
object: 'thread',
title: 'New Thread',
assistants: [assistantInfo],
created: createdAt,
updated: createdAt,
metadata: {
title: 'New Thread',
updated_at: Date.now(),
},
}
// add the new thread on top of the thread list to the state
//TODO: Why do we have thread list then thread states? Should combine them
try {
const createdThread = await persistNewThread(thread, assistantInfo)
if (!createdThread) throw 'Thread created failed.'
createNewThread(createdThread)
setSelectedModel(defaultModel)
setThreadModelParams(createdThread.id, {
...defaultModel?.settings,
...defaultModel?.parameters,
...overriddenSettings,
})
// Delete the file upload state
setFileUpload(undefined)
setActiveThread(createdThread)
} catch (ex) {
return toaster({
title: 'Thread created failed.',
description: `To avoid piling up empty threads, please reuse previous one before creating new.`,
type: 'error',
})
}
}
const updateThreadExtension = (thread: Thread) => {
return extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyThread(thread)
}
const updateAssistantExtension = (
threadId: string,
assistant: ThreadAssistantInfo
) => {
return extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyThreadAssistant(threadId, assistant)
}
const updateThreadCallback = useDebouncedCallback(updateThreadExtension, 300)
const updateAssistantCallback = useDebouncedCallback(
updateAssistantExtension,
300
)
const updateThreadMetadata = useCallback(
async (thread: Thread) => {
updateThread(thread)
updateThreadCallback(thread)
if (thread.assistants && thread.assistants?.length > 0) {
setActiveAssistant(thread.assistants[0])
updateAssistantCallback(thread.id, thread.assistants[0])
}
},
[
updateThread,
setActiveAssistant,
updateThreadCallback,
updateAssistantCallback,
]
)
const persistNewThread = async (
thread: Partial<Thread>,
assistantInfo: ThreadAssistantInfo
): Promise<Thread | undefined> => {
return await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.createThread(thread)
.then(async (thread) => {
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.createThreadAssistant(thread.id, assistantInfo)
.catch(console.error)
return thread
})
.catch(() => undefined)
}
return {
requestCreateNewThread,
updateThreadMetadata,
}
}