chore: fix model settings are not applied accordingly on change (#5231)
* chore: fix model settings are not applied accordingly on change * chore: handle failed tool call * chore: stop inference and model on reject
This commit is contained in:
parent
dcb3f794d3
commit
51a321219d
@ -23,8 +23,8 @@
|
||||
"description": "Number of prompts that can be processed simultaneously by the model.",
|
||||
"controllerType": "input",
|
||||
"controllerProps": {
|
||||
"value": "4",
|
||||
"placeholder": "4",
|
||||
"value": "1",
|
||||
"placeholder": "1",
|
||||
"type": "number",
|
||||
"textAlign": "right"
|
||||
}
|
||||
|
||||
@ -55,7 +55,7 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
shouldReconnect = true
|
||||
|
||||
/** Default Engine model load settings */
|
||||
n_parallel: number = 4
|
||||
n_parallel?: number
|
||||
cont_batching: boolean = true
|
||||
caching_enabled: boolean = true
|
||||
flash_attn: boolean = true
|
||||
@ -114,8 +114,10 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
// Register Settings
|
||||
this.registerSettings(SETTINGS)
|
||||
|
||||
this.n_parallel =
|
||||
Number(await this.getSetting<string>(Settings.n_parallel, '4')) ?? 4
|
||||
const numParallel = await this.getSetting<string>(Settings.n_parallel, '')
|
||||
if (numParallel.length > 0 && parseInt(numParallel) > 0) {
|
||||
this.n_parallel = parseInt(numParallel)
|
||||
}
|
||||
this.cont_batching = await this.getSetting<boolean>(
|
||||
Settings.cont_batching,
|
||||
true
|
||||
@ -184,7 +186,9 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
*/
|
||||
onSettingUpdate<T>(key: string, value: T): void {
|
||||
if (key === Settings.n_parallel && typeof value === 'string') {
|
||||
this.n_parallel = Number(value) ?? 1
|
||||
if (value.length > 0 && parseInt(value) > 0) {
|
||||
this.n_parallel = parseInt(value)
|
||||
}
|
||||
} else if (key === Settings.cont_batching && typeof value === 'boolean') {
|
||||
this.cont_batching = value as boolean
|
||||
} else if (key === Settings.caching_enabled && typeof value === 'boolean') {
|
||||
|
||||
@ -35,6 +35,7 @@ import DropdownModelProvider from '@/containers/DropdownModelProvider'
|
||||
import { ModelLoader } from '@/containers/loaders/ModelLoader'
|
||||
import DropdownToolsAvailable from '@/containers/DropdownToolsAvailable'
|
||||
import { getConnectedServers } from '@/services/mcp'
|
||||
import { stopAllModels } from '@/services/models'
|
||||
|
||||
type ChatInputProps = {
|
||||
className?: string
|
||||
@ -161,6 +162,7 @@ const ChatInput = ({
|
||||
const stopStreaming = useCallback(
|
||||
(threadId: string) => {
|
||||
abortControllers[threadId]?.abort()
|
||||
stopAllModels()
|
||||
},
|
||||
[abortControllers]
|
||||
)
|
||||
|
||||
@ -61,6 +61,10 @@ export const useChat = () => {
|
||||
return getProviderByName(selectedProvider)
|
||||
}, [selectedProvider, getProviderByName])
|
||||
|
||||
const currentProviderId = useMemo(() => {
|
||||
return provider?.provider || selectedProvider
|
||||
}, [provider, selectedProvider])
|
||||
|
||||
useEffect(() => {
|
||||
function setTools() {
|
||||
getTools().then((data: MCPTool[]) => {
|
||||
@ -109,7 +113,10 @@ export const useChat = () => {
|
||||
const activeThread = await getCurrentThread()
|
||||
|
||||
resetTokenSpeed()
|
||||
if (!activeThread || !provider) return
|
||||
const activeProvider = currentProviderId
|
||||
? getProviderByName(currentProviderId)
|
||||
: provider
|
||||
if (!activeThread || !activeProvider) return
|
||||
const messages = getMessages(activeThread.id)
|
||||
const abortController = new AbortController()
|
||||
setAbortController(activeThread.id, abortController)
|
||||
@ -120,9 +127,11 @@ export const useChat = () => {
|
||||
try {
|
||||
if (selectedModel?.id) {
|
||||
updateLoadingModel(true)
|
||||
await startModel(provider, selectedModel.id, abortController).catch(
|
||||
console.error
|
||||
)
|
||||
await startModel(
|
||||
activeProvider,
|
||||
selectedModel.id,
|
||||
abortController
|
||||
).catch(console.error)
|
||||
updateLoadingModel(false)
|
||||
}
|
||||
|
||||
@ -148,7 +157,7 @@ export const useChat = () => {
|
||||
while (!isCompleted && !abortController.signal.aborted) {
|
||||
const completion = await sendCompletion(
|
||||
activeThread,
|
||||
provider,
|
||||
activeProvider,
|
||||
builder.getMessages(),
|
||||
abortController,
|
||||
availableTools,
|
||||
@ -194,7 +203,7 @@ export const useChat = () => {
|
||||
accumulatedText.length === 0 &&
|
||||
toolCalls.length === 0 &&
|
||||
activeThread.model?.id &&
|
||||
provider.provider === 'llama.cpp'
|
||||
activeProvider.provider === 'llama.cpp'
|
||||
) {
|
||||
await stopModel(activeThread.model.id, 'cortex')
|
||||
throw new Error('No response received from the model')
|
||||
@ -235,6 +244,8 @@ export const useChat = () => {
|
||||
[
|
||||
getCurrentThread,
|
||||
resetTokenSpeed,
|
||||
currentProviderId,
|
||||
getProviderByName,
|
||||
provider,
|
||||
getMessages,
|
||||
setAbortController,
|
||||
@ -246,11 +257,11 @@ export const useChat = () => {
|
||||
currentAssistant,
|
||||
tools,
|
||||
updateLoadingModel,
|
||||
updateTokenSpeed,
|
||||
approvedTools,
|
||||
showApprovalModal,
|
||||
getDisabledToolsForThread,
|
||||
approvedTools,
|
||||
allowAllMCPPermissions,
|
||||
showApprovalModal,
|
||||
updateTokenSpeed,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@ -304,6 +304,17 @@ export const postMessageProcessing = async (
|
||||
arguments: toolCall.function.arguments.length
|
||||
? JSON.parse(toolCall.function.arguments)
|
||||
: {},
|
||||
}).catch((e) => {
|
||||
console.error('Tool call failed:', e)
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `Error calling tool ${toolCall.function.name}: ${e.message}`,
|
||||
},
|
||||
],
|
||||
error: true,
|
||||
}
|
||||
})
|
||||
: {
|
||||
content: [
|
||||
|
||||
@ -98,13 +98,15 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
|
||||
'inferenceUrl' in value
|
||||
? (value.inferenceUrl as string).replace('/chat/completions', '')
|
||||
: '',
|
||||
settings: (await value.getSettings()).map((setting) => ({
|
||||
settings: (await value.getSettings()).map((setting) => {
|
||||
return {
|
||||
key: setting.key,
|
||||
title: setting.title,
|
||||
description: setting.description,
|
||||
controller_type: setting.controllerType as unknown,
|
||||
controller_props: setting.controllerProps as unknown,
|
||||
})) as ProviderSetting[],
|
||||
}
|
||||
}) as ProviderSetting[],
|
||||
models: models.map((model) => ({
|
||||
id: model.id,
|
||||
model: model.id,
|
||||
@ -117,9 +119,13 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
|
||||
provider: providerName,
|
||||
settings: Object.values(modelSettings).reduce(
|
||||
(acc, setting) => {
|
||||
const value = model[
|
||||
let value = model[
|
||||
setting.key as keyof typeof model
|
||||
] as keyof typeof setting.controller_props.value
|
||||
if (setting.key === 'ctx_len') {
|
||||
// @ts-expect-error dynamic type
|
||||
value = 4096 // Default context length for Llama.cpp models
|
||||
}
|
||||
acc[setting.key] = {
|
||||
...setting,
|
||||
controller_props: {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user