From 51a321219d0ec813138f20a9833b0162d0142c6a Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 10 Jun 2025 16:26:42 +0700
Subject: [PATCH] chore: fix model settings not being applied on change (#5231)

* chore: fix model settings not being applied on change
* chore: handle failed tool call
* chore: stop inference and model on reject
---
 .../resources/default_settings.json         |  4 +--
 .../inference-cortex-extension/src/index.ts | 12 +++++---
 web-app/src/containers/ChatInput.tsx        |  2 ++
 web-app/src/hooks/useChat.ts                | 29 +++++++++++++------
 web-app/src/lib/completion.ts               | 11 +++++++
 web-app/src/services/providers.ts           | 22 +++++++++----
 6 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/extensions/inference-cortex-extension/resources/default_settings.json b/extensions/inference-cortex-extension/resources/default_settings.json
index 881f74404..d825affb2 100644
--- a/extensions/inference-cortex-extension/resources/default_settings.json
+++ b/extensions/inference-cortex-extension/resources/default_settings.json
@@ -23,8 +23,8 @@
     "description": "Number of prompts that can be processed simultaneously by the model.",
     "controllerType": "input",
     "controllerProps": {
-      "value": "4",
-      "placeholder": "4",
+      "value": "1",
+      "placeholder": "1",
       "type": "number",
       "textAlign": "right"
     }
diff --git a/extensions/inference-cortex-extension/src/index.ts b/extensions/inference-cortex-extension/src/index.ts
index 78786ade0..b217a4f48 100644
--- a/extensions/inference-cortex-extension/src/index.ts
+++ b/extensions/inference-cortex-extension/src/index.ts
@@ -55,7 +55,7 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
   shouldReconnect = true

   /** Default Engine model load settings */
-  n_parallel: number = 4
+  n_parallel?: number
   cont_batching: boolean = true
   caching_enabled: boolean = true
   flash_attn: boolean = true
@@ -114,8 +114,10 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
     // Register Settings
     this.registerSettings(SETTINGS)
-    this.n_parallel =
-      Number(await this.getSetting<string>(Settings.n_parallel, '4')) ?? 4
+    const numParallel = await this.getSetting<string>(Settings.n_parallel, '')
+    if (numParallel.length > 0 && parseInt(numParallel) > 0) {
+      this.n_parallel = parseInt(numParallel)
+    }
     this.cont_batching = await this.getSetting<boolean>(
       Settings.cont_batching,
       true
     )
@@ -184,7 +186,9 @@
    */
  onSettingUpdate<T>(key: string, value: T): void {
    if (key === Settings.n_parallel && typeof value === 'string') {
-      this.n_parallel = Number(value) ?? 1
+      if (value.length > 0 && parseInt(value) > 0) {
+        this.n_parallel = parseInt(value)
+      }
    } else if (key === Settings.cont_batching && typeof value === 'boolean') {
      this.cont_batching = value as boolean
    } else if (key === Settings.caching_enabled && typeof value === 'boolean') {
diff --git a/web-app/src/containers/ChatInput.tsx b/web-app/src/containers/ChatInput.tsx
index 44c0b15b3..4f858f94f 100644
--- a/web-app/src/containers/ChatInput.tsx
+++ b/web-app/src/containers/ChatInput.tsx
@@ -35,6 +35,7 @@ import DropdownModelProvider from '@/containers/DropdownModelProvider'
 import { ModelLoader } from '@/containers/loaders/ModelLoader'
 import DropdownToolsAvailable from '@/containers/DropdownToolsAvailable'
 import { getConnectedServers } from '@/services/mcp'
+import { stopAllModels } from '@/services/models'

 type ChatInputProps = {
   className?: string
@@ -161,6 +162,7 @@
   const stopStreaming = useCallback(
     (threadId: string) => {
       abortControllers[threadId]?.abort()
+      stopAllModels()
     },
     [abortControllers]
   )
diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts
index 6d747f5e5..449b5daa9 100644
--- a/web-app/src/hooks/useChat.ts
+++ b/web-app/src/hooks/useChat.ts
@@ -61,6 +61,10 @@
     return getProviderByName(selectedProvider)
   }, [selectedProvider, getProviderByName])

+  const currentProviderId = useMemo(() => {
+    return provider?.provider || selectedProvider
+  }, [provider, selectedProvider])
+
   useEffect(() => {
     function setTools() {
       getTools().then((data: MCPTool[]) => {
@@ -109,7 +113,10 @@
     const activeThread = await getCurrentThread()
     resetTokenSpeed()

-    if (!activeThread || !provider) return
+    const activeProvider = currentProviderId
+      ? getProviderByName(currentProviderId)
+      : provider
+    if (!activeThread || !activeProvider) return
     const messages = getMessages(activeThread.id)
     const abortController = new AbortController()
     setAbortController(activeThread.id, abortController)
@@ -120,9 +127,11 @@
     try {
       if (selectedModel?.id) {
         updateLoadingModel(true)
-        await startModel(provider, selectedModel.id, abortController).catch(
-          console.error
-        )
+        await startModel(
+          activeProvider,
+          selectedModel.id,
+          abortController
+        ).catch(console.error)
         updateLoadingModel(false)
       }

@@ -148,7 +157,7 @@
       while (!isCompleted && !abortController.signal.aborted) {
         const completion = await sendCompletion(
           activeThread,
-          provider,
+          activeProvider,
           builder.getMessages(),
           abortController,
           availableTools,
@@ -194,7 +203,7 @@
           accumulatedText.length === 0 &&
           toolCalls.length === 0 &&
           activeThread.model?.id &&
-          provider.provider === 'llama.cpp'
+          activeProvider.provider === 'llama.cpp'
         ) {
           await stopModel(activeThread.model.id, 'cortex')
           throw new Error('No response received from the model')
@@ -235,6 +244,8 @@
     [
       getCurrentThread,
       resetTokenSpeed,
+      currentProviderId,
+      getProviderByName,
       provider,
       getMessages,
       setAbortController,
@@ -246,11 +257,11 @@
       currentAssistant,
       tools,
       updateLoadingModel,
-      updateTokenSpeed,
-      approvedTools,
-      showApprovalModal,
       getDisabledToolsForThread,
+      approvedTools,
       allowAllMCPPermissions,
+      showApprovalModal,
+      updateTokenSpeed,
     ]
   )

diff --git a/web-app/src/lib/completion.ts b/web-app/src/lib/completion.ts
index 927a38669..2f3d227de 100644
--- a/web-app/src/lib/completion.ts
+++ b/web-app/src/lib/completion.ts
@@ -304,6 +304,17 @@ export const postMessageProcessing = async (
           arguments: toolCall.function.arguments.length
             ? JSON.parse(toolCall.function.arguments)
             : {},
+          }).catch((e) => {
+            console.error('Tool call failed:', e)
+            return {
+              content: [
+                {
+                  type: 'text',
+                  text: `Error calling tool ${toolCall.function.name}: ${e.message}`,
+                },
+              ],
+              error: true,
+            }
           })
         : {
             content: [
diff --git a/web-app/src/services/providers.ts b/web-app/src/services/providers.ts
index 3b877d39f..96d340cb4 100644
--- a/web-app/src/services/providers.ts
+++ b/web-app/src/services/providers.ts
@@ -98,13 +98,15 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
         'inferenceUrl' in value
           ? (value.inferenceUrl as string).replace('/chat/completions', '')
           : '',
-      settings: (await value.getSettings()).map((setting) => ({
-        key: setting.key,
-        title: setting.title,
-        description: setting.description,
-        controller_type: setting.controllerType as unknown,
-        controller_props: setting.controllerProps as unknown,
-      })) as ProviderSetting[],
+      settings: (await value.getSettings()).map((setting) => {
+        return {
+          key: setting.key,
+          title: setting.title,
+          description: setting.description,
+          controller_type: setting.controllerType as unknown,
+          controller_props: setting.controllerProps as unknown,
+        }
+      }) as ProviderSetting[],
       models: models.map((model) => ({
         id: model.id,
         model: model.id,
@@ -117,9 +119,13 @@
         provider: providerName,
         settings: Object.values(modelSettings).reduce(
           (acc, setting) => {
-            const value = model[
+            let value = model[
               setting.key as keyof typeof model
             ] as keyof typeof setting.controller_props.value
+            if (setting.key === 'ctx_len') {
+              // @ts-expect-error dynamic type
+              value = 4096 // Default context length for Llama.cpp models
+            }
             acc[setting.key] = {
               ...setting,
               controller_props: {
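
The n_parallel guard above hinges on JavaScript semantics that are easy to miss: `Number(value) ?? 1` never falls back to 1, because `??` only fires on null or undefined, while `Number('')` evaluates to 0 and `Number('abc')` to NaN, both of which pass through `??` untouched. A standalone sketch of the guard the patch switches to; the helper name is illustrative, not part of the extension:

// Parse a numeric setting string, treating empty or invalid input as "unset"
// so the engine default applies instead of a forced value.
function readParallelSetting(raw: string): number | undefined {
  if (raw.length === 0) return undefined
  const parsed = parseInt(raw, 10)
  // NaN > 0 is false, so non-numeric input is rejected here as well.
  return parsed > 0 ? parsed : undefined
}

console.log(readParallelSetting(''))   // undefined -> keep engine default
console.log(readParallelSetting('4'))  // 4
console.log(readParallelSetting('-1')) // undefined -> rejected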
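The completion.ts hunk folds a rejected `callTool` promise into the same content shape a successful call returns, flagged with `error: true`, so one failing MCP tool no longer aborts the whole turn; the error text is handed back to the model as an ordinary tool result. A minimal sketch of that pattern, with `ToolResult` and `safeToolCall` as illustrative stand-ins for the app's own types:

// Illustrative result shape; the real app defines its own.
type ToolContent = { type: 'text'; text: string }
type ToolResult = { content: ToolContent[]; error?: boolean }

// Wrap a tool invocation so a rejection becomes a regular result the
// chat loop can append as a tool message instead of throwing.
async function safeToolCall(
  name: string,
  invoke: () => Promise<ToolResult>
): Promise<ToolResult> {
  try {
    return await invoke()
  } catch (e) {
    const message = e instanceof Error ? e.message : String(e)
    console.error('Tool call failed:', e)
    return {
      content: [{ type: 'text', text: `Error calling tool ${name}: ${message}` }],
      error: true,
    }
  }
}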
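The ChatInput.tsx hunk pairs aborting the stream with `stopAllModels()`, so rejecting or stopping a response also unloads the local model rather than leaving it running. A sketch of that shape, with a stubbed `stopAllModels` standing in for the real '@/services/models' service:

// Stand-in for the '@/services/models' import; the real service asks
// the backend to unload every running model.
async function stopAllModels(): Promise<void> {
  /* unload local models */
}

const abortControllers = new Map<string, AbortController>()

// Mirrors stopStreaming: cancel the in-flight completion for the thread,
// then stop the models. The patch calls stopAllModels() without awaiting,
// so the UI is not blocked on the unload.
function stopStreaming(threadId: string): void {
  abortControllers.get(threadId)?.abort()
  void stopAllModels()
}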