✨enhancement: out of context troubleshooting (#5275)
* ✨enhancement: out of context troubleshooting
* 🔧refactor: clean up
This commit is contained in:
parent d131752419
commit e20c801ff0
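At a high level, the diff below rewires the out-of-context dialog from a yes/no confirm into a three-way choice (increase the context size, truncate the input, or dismiss) and teaches the chat loop to apply the chosen fix and retry. A minimal, self-contained sketch of that contract follows; `showOutOfContextModal`, `growContextWindow`, and `enableContextShift` are placeholder stubs, not the real hook and services touched by this commit.

```ts
type OutOfContextChoice = 'ctx_len' | 'context_shift' | undefined

// Placeholder stubs standing in for the real pieces in this commit:
// useOutOfContextPromiseModal, increaseModelContextSize, toggleOnContextShifting.
const showOutOfContextModal = async (): Promise<OutOfContextChoice> => 'ctx_len'
const growContextWindow = async (): Promise<void> => {
  /* bump ctx_len and restart the model */
}
const enableContextShift = async (): Promise<void> => {
  /* set context_shift = true and restart the model */
}

// The chat loop catches the out-of-context error, asks the user which fix to apply,
// applies it, and retries the completion instead of failing the whole request.
async function recoverFromOutOfContext(): Promise<boolean> {
  const choice = await showOutOfContextModal()
  if (choice === 'ctx_len') {
    await growContextWindow() // expand the context window
    return true // caller `continue`s its completion loop
  }
  if (choice === 'context_shift') {
    await enableContextShift() // let the backend drop old history instead
    return true
  }
  return false // dismissed: caller rethrows the original error
}

recoverFromOutOfContext().then((retry) => console.log('retry completion:', retry))
```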
@ -14,7 +14,9 @@ import { Button } from '@/components/ui/button'
export function useOutOfContextPromiseModal() {
const [isOpen, setIsOpen] = useState(false)
const [modalProps, setModalProps] = useState<{
resolveRef: ((value: unknown) => void) | null
resolveRef:
| ((value: 'ctx_len' | 'context_shift' | undefined) => void)
| null
}>({
resolveRef: null,
})
@ -33,17 +35,23 @@ export function useOutOfContextPromiseModal() {
return null
}

const handleConfirm = () => {
const handleContextLength = () => {
setIsOpen(false)
if (modalProps.resolveRef) {
modalProps.resolveRef(true)
modalProps.resolveRef('ctx_len')
}
}

const handleContextShift = () => {
setIsOpen(false)
if (modalProps.resolveRef) {
modalProps.resolveRef('context_shift')
}
}
const handleCancel = () => {
setIsOpen(false)
if (modalProps.resolveRef) {
modalProps.resolveRef(false)
modalProps.resolveRef(undefined)
}
}
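The handlers above settle a promise whose `resolve` function the hook parks in state, which is what lets the chat code simply `await showModal?.()`. A framework-free sketch of that resolver-in-state pattern (no React, simplified `Choice` type), assuming the hook stores `resolveRef` as the hunk shows:

```ts
type Choice = 'ctx_len' | 'context_shift' | undefined

// Resolver-in-state: open() hands back a promise, and the stored resolve function
// is called later by whichever "button handler" fires first.
class PromiseModal {
  private resolveRef: ((value: Choice) => void) | null = null

  open(): Promise<Choice> {
    return new Promise<Choice>((resolve) => {
      this.resolveRef = resolve // kept until a handler settles it
    })
  }

  handleContextLength() { this.settle('ctx_len') }
  handleContextShift() { this.settle('context_shift') }
  handleCancel() { this.settle(undefined) }

  private settle(value: Choice) {
    this.resolveRef?.(value)
    this.resolveRef = null
  }
}

// Usage mirrors `await showModal?.()` in useChat.
async function demo() {
  const modal = new PromiseModal()
  const pending = modal.open()
  modal.handleContextShift() // e.g. the user clicked "Truncate Input"
  console.log(await pending) // -> 'context_shift'
}
demo()
```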
@ -64,7 +72,7 @@ export function useOutOfContextPromiseModal() {
<DialogDescription>
{t(
'outOfContextError.description',
'This chat is reaching the AI’s memory limit, like a whiteboard filling up. We can expand the memory window (called context size) so it remembers more, but it may use more of your computer’s memory.'
'This chat is reaching the AI’s memory limit, like a whiteboard filling up. We can expand the memory window (called context size) so it remembers more, but it may use more of your computer’s memory. We can also truncate the input, which means it will forget some of the chat history to make room for new messages.'
)}
<br />
<br />
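The updated copy explains the trade-off to the user: growing the context window remembers more at the cost of memory, while truncation forgets older history. Purely as an illustration of the truncation idea (the commit itself only flips the provider's `context_shift` setting and restarts the model; the actual history dropping happens inside the backend), a naive message-level version under an assumed 4-characters-per-token estimate might look like:

```ts
interface ChatMessage {
  role: 'system' | 'user' | 'assistant'
  content: string
}

// Very rough token estimate; real backends count model-specific tokens.
const estimateTokens = (text: string): number => Math.ceil(text.length / 4)

// Keep the system prompt plus the newest turns that fit the budget,
// dropping the oldest ones first — the "forget some of the chat history" idea.
function truncateHistory(messages: ChatMessage[], maxTokens: number): ChatMessage[] {
  const system = messages.filter((m) => m.role === 'system')
  const rest = messages.filter((m) => m.role !== 'system')

  let used = system.reduce((sum, m) => sum + estimateTokens(m.content), 0)
  const kept: ChatMessage[] = []
  for (let i = rest.length - 1; i >= 0; i--) {
    const cost = estimateTokens(rest[i].content)
    if (used + cost > maxTokens) break
    kept.unshift(rest[i])
    used += cost
  }
  return [...system, ...kept]
}

console.log(
  truncateHistory(
    [
      { role: 'system', content: 'You are helpful.' },
      { role: 'user', content: 'old question '.repeat(50) },
      { role: 'user', content: 'newest question' },
    ],
    64
  )
)
```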
@ -77,14 +85,17 @@ export function useOutOfContextPromiseModal() {
<Button
variant="default"
className="bg-transparent border border-main-view-fg/20 hover:bg-main-view-fg/4"
onClick={() => setIsOpen(false)}
onClick={() => {
handleContextShift()
setIsOpen(false)
}}
>
{t('common.cancel', 'Cancel')}
{t('outOfContextError.truncateInput', 'Truncate Input')}
</Button>
<Button
asChild
onClick={() => {
handleConfirm()
handleContextLength()
setIsOpen(false)
}}
>
@ -29,6 +29,7 @@ import { stopModel, startModel, stopAllModels } from '@/services/models'
import { useToolApproval } from '@/hooks/useToolApproval'
import { useToolAvailable } from '@/hooks/useToolAvailable'
import { OUT_OF_CONTEXT_SIZE } from '@/utils/error'
import { updateSettings } from '@/services/providers'

export const useChat = () => {
const { prompt, setPrompt } = usePrompt()
@ -110,19 +111,41 @@ export const useChat = () => {
currentAssistant,
])

const restartModel = useCallback(
async (
provider: ProviderObject,
modelId: string,
abortController: AbortController
) => {
await stopAllModels()
await new Promise((resolve) => setTimeout(resolve, 1000))
updateLoadingModel(true)
await startModel(provider, modelId, abortController).catch(console.error)
updateLoadingModel(false)
await new Promise((resolve) => setTimeout(resolve, 1000))
},
[updateLoadingModel]
)

const increaseModelContextSize = useCallback(
(model: Model, provider: ProviderObject) => {
async (
modelId: string,
provider: ProviderObject,
controller: AbortController
) => {
/**
* Should increase the context size of the model by 2x
* If the context size is not set or too low, it defaults to 8192.
*/
const model = provider.models.find((m) => m.id === modelId)
if (!model) return undefined
const ctxSize = Math.max(
model.settings?.ctx_len?.controller_props.value
? typeof model.settings.ctx_len.controller_props.value === 'string'
? parseInt(model.settings.ctx_len.controller_props.value as string)
: (model.settings.ctx_len.controller_props.value as number)
: 8192,
8192
: 16384,
16384
)
const updatedModel = {
...model,
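The doc comment and the `Math.max` expression above normalize a `ctx_len` value that may arrive as a string, a number, or be missing, and this hunk raises the floor from 8192 to 16384. A standalone sketch of just that normalization, with a hypothetical `RawSetting` alias standing in for `model.settings?.ctx_len?.controller_props.value`:

```ts
// Hypothetical stand-in for model.settings?.ctx_len?.controller_props.value,
// which the hunk shows may be a string, a number, or missing entirely.
type RawSetting = string | number | undefined

const CTX_FLOOR = 16384 // the new minimum used in this diff (previously 8192)

// Coerce the stored value to a number and never go below the floor.
function normalizeCtxLen(raw: RawSetting, floor: number = CTX_FLOOR): number {
  const parsed =
    raw === undefined
      ? floor
      : typeof raw === 'string'
        ? parseInt(raw, 10)
        : raw
  return Math.max(Number.isFinite(parsed) ? parsed : floor, floor)
}

console.log(normalizeCtxLen('8192'))    // -> 16384 (below the floor)
console.log(normalizeCtxLen(32768))     // -> 32768 (kept as-is)
console.log(normalizeCtxLen(undefined)) // -> 16384
```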
@ -153,9 +176,54 @@ export const useChat = () => {
models: updatedModels,
})
}
stopAllModels()
const updatedProvider = getProviderByName(provider.provider)
if (updatedProvider)
await restartModel(updatedProvider, model.id, controller)

console.log(
updatedProvider?.models.find((e) => e.id === model.id)?.settings
?.ctx_len?.controller_props.value
)
return updatedProvider
},
[updateProvider]
[getProviderByName, restartModel, updateProvider]
)
const toggleOnContextShifting = useCallback(
async (
modelId: string,
provider: ProviderObject,
controller: AbortController
) => {
const providerName = provider.provider
const newSettings = [...provider.settings]
const settingKey = 'context_shift'
// Handle different value types by forcing the type
// Use type assertion to bypass type checking
const settingIndex = provider.settings.findIndex(
(s) => s.key === settingKey
)
;(
newSettings[settingIndex].controller_props as {
value: string | boolean | number
}
).value = true

// Create update object with updated settings
const updateObj: Partial<ModelProvider> = {
settings: newSettings,
}

await updateSettings(providerName, updateObj.settings ?? [])
updateProvider(providerName, {
...provider,
...updateObj,
})
const updatedProvider = getProviderByName(providerName)
if (updatedProvider)
await restartModel(updatedProvider, modelId, controller)
return updatedProvider
},
[updateProvider, getProviderByName, restartModel]
)

const sendMessage = useCallback(
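`toggleOnContextShifting` flips the matching provider setting via a type assertion, persists it with `updateSettings`, and restarts the model so the flag takes effect. A generic sketch of that update-one-setting-then-restart flow, with simplified `Provider`/`Setting` shapes and stubbed service calls in place of the app's real ones (it also guards the `findIndex` miss that the original code assumes cannot happen):

```ts
// Simplified shapes; the real ModelProvider settings carry controller_props etc.
interface Setting {
  key: string
  value: string | number | boolean
}
interface Provider {
  provider: string
  settings: Setting[]
}

// Stubs standing in for the persist / restart helpers in the services layer.
const persistSettings = async (_name: string, _settings: Setting[]): Promise<void> => {}
const restartModel = async (_provider: Provider): Promise<void> => {}

// Copy the settings array, flip the flag for the matching key (guarding the case
// where the key is missing), persist the change, then restart the model.
async function enableSetting(provider: Provider, key: string): Promise<Provider> {
  const index = provider.settings.findIndex((s) => s.key === key)
  if (index === -1) return provider // nothing to toggle

  const settings = provider.settings.map((s, i) =>
    i === index ? { ...s, value: true } : s
  )
  const updated: Provider = { ...provider, settings }

  await persistSettings(updated.provider, settings)
  await restartModel(updated) // e.g. stop all models, then start this one again
  return updated
}

enableSetting(
  { provider: 'llamacpp', settings: [{ key: 'context_shift', value: false }] },
  'context_shift'
).then((p) => console.log(p.settings))
```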
@ -167,7 +235,7 @@ export const useChat = () => {
const activeThread = await getCurrentThread()

resetTokenSpeed()
const activeProvider = currentProviderId
let activeProvider = currentProviderId
? getProviderByName(currentProviderId)
: provider
if (!activeThread || !activeProvider) return
@ -210,7 +278,11 @@ export const useChat = () => {
// TODO: Later replaced by Agent setup?
const followUpWithToolUse = true
while (!isCompleted && !abortController.signal.aborted) {
while (
!isCompleted &&
!abortController.signal.aborted &&
activeProvider
) {
const completion = await sendCompletion(
activeThread,
activeProvider,
@ -229,56 +301,90 @@ export const useChat = () => {
let accumulatedText = ''
const currentCall: ChatCompletionMessageToolCall | null = null
const toolCalls: ChatCompletionMessageToolCall[] = []
if (isCompletionResponse(completion)) {
accumulatedText = completion.choices[0]?.message?.content || ''
if (completion.choices[0]?.message?.tool_calls) {
toolCalls.push(...completion.choices[0].message.tool_calls)
}
} else {
for await (const part of completion) {
// Error message
if (!part.choices) {
throw new Error(
'message' in part
? (part.message as string)
: (JSON.stringify(part) ?? '')
)
try {
if (isCompletionResponse(completion)) {
accumulatedText = completion.choices[0]?.message?.content || ''
if (completion.choices[0]?.message?.tool_calls) {
toolCalls.push(...completion.choices[0].message.tool_calls)
}
const delta = part.choices[0]?.delta?.content || ''
} else {
for await (const part of completion) {
// Error message
if (!part.choices) {
throw new Error(
'message' in part
? (part.message as string)
: (JSON.stringify(part) ?? '')
)
}
const delta = part.choices[0]?.delta?.content || ''

if (part.choices[0]?.delta?.tool_calls) {
const calls = extractToolCall(part, currentCall, toolCalls)
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
{
tool_calls: calls.map((e) => ({
...e,
state: 'pending',
})),
}
)
updateStreamingContent(currentContent)
await new Promise((resolve) => setTimeout(resolve, 0))
if (part.choices[0]?.delta?.tool_calls) {
const calls = extractToolCall(part, currentCall, toolCalls)
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
{
tool_calls: calls.map((e) => ({
...e,
state: 'pending',
})),
}
)
updateStreamingContent(currentContent)
await new Promise((resolve) => setTimeout(resolve, 0))
}
if (delta) {
accumulatedText += delta
// Create a new object each time to avoid reference issues
// Use a timeout to prevent React from batching updates too quickly
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
{
tool_calls: toolCalls.map((e) => ({
...e,
state: 'pending',
})),
}
)
updateStreamingContent(currentContent)
updateTokenSpeed(currentContent)
await new Promise((resolve) => setTimeout(resolve, 0))
}
}
if (delta) {
accumulatedText += delta
// Create a new object each time to avoid reference issues
// Use a timeout to prevent React from batching updates too quickly
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
{
tool_calls: toolCalls.map((e) => ({
...e,
state: 'pending',
})),
}
}
} catch (error) {
const errorMessage =
error && typeof error === 'object' && 'message' in error
? error.message
: error
if (
typeof errorMessage === 'string' &&
errorMessage.includes(OUT_OF_CONTEXT_SIZE) &&
selectedModel &&
troubleshooting
) {
const method = await showModal?.()
if (method === 'ctx_len') {
/// Increase context size
activeProvider = await increaseModelContextSize(
selectedModel.id,
activeProvider,
abortController
)
updateStreamingContent(currentContent)
updateTokenSpeed(currentContent)
await new Promise((resolve) => setTimeout(resolve, 0))
}
continue
} else if (method === 'context_shift' && selectedModel?.id) {
/// Enable context_shift
activeProvider = await toggleOnContextShifting(
selectedModel?.id,
activeProvider,
abortController
)
continue
} else throw error
} else {
throw error
}
}
// TODO: Remove this check when integrating new llama.cpp extension
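Throughout the streaming branch the code rebuilds the content object on every chunk and then awaits `setTimeout(resolve, 0)`; per the inline comments, the goal is to hand React a fresh reference and yield to the event loop so updates are not batched away. A small self-contained illustration of that yield-per-chunk pattern, with `render` standing in for `updateStreamingContent`:

```ts
// Stand-in for updateStreamingContent: in the app this pushes a new snapshot into state.
const render = (snapshot: { text: string }) => console.log('render:', snapshot.text)

// Yield control back to the event loop between chunks.
const yieldToEventLoop = () => new Promise<void>((resolve) => setTimeout(resolve, 0))

async function streamChunks(chunks: string[]): Promise<string> {
  let accumulated = ''
  for (const delta of chunks) {
    accumulated += delta
    // A new object every iteration, mirroring the "avoid reference issues" comment.
    const snapshot = { text: accumulated }
    render(snapshot)
    // Let the UI (or any queued work) run before processing the next chunk.
    await yieldToEventLoop()
  }
  return accumulated
}

streamChunks(['Hel', 'lo ', 'wor', 'ld']).then((full) => console.log('done:', full))
```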
@ -320,21 +426,7 @@ export const useChat = () => {
error && typeof error === 'object' && 'message' in error
? error.message
: error
if (
typeof errorMessage === 'string' &&
errorMessage.includes(OUT_OF_CONTEXT_SIZE) &&
selectedModel &&
troubleshooting
) {
showModal?.().then((confirmed) => {
if (confirmed) {
increaseModelContextSize(selectedModel, activeProvider)
setTimeout(() => {
sendMessage(message, showModal, false) // Retry sending the message without troubleshooting
}, 1000)
}
})
}

toast.error(`Error sending message: ${errorMessage}`)
console.error('Error sending message:', error)
} finally {
@ -355,7 +447,8 @@ export const useChat = () => {
updateThreadTimestamp,
setPrompt,
selectedModel,
currentAssistant,
currentAssistant?.instructions,
currentAssistant.parameters,
tools,
updateLoadingModel,
getDisabledToolsForThread,
@ -364,6 +457,7 @@ export const useChat = () => {
showApprovalModal,
updateTokenSpeed,
increaseModelContextSize,
toggleOnContextShifting,
]
)
@ -9,6 +9,7 @@ import {
getActiveModels,
importModel,
startModel,
stopAllModels,
stopModel,
} from '@/services/models'
import {
@ -299,6 +300,8 @@ function ProviderDetail() {
...provider,
...updateObj,
})

stopAllModels()
}
}}
/>
@ -296,7 +296,8 @@ export const startModel = async (
normalizeProvider(provider.provider)
)
const modelObj = provider.models.find((m) => m.id === model)
if (providerObj && modelObj)

if (providerObj && modelObj) {
return providerObj?.loadModel(
{
id: modelObj.id,
@ -309,6 +310,7 @@ export const startModel = async (
},
abortController
)
}
}

/**