diff --git a/extensions/inference-cortex-extension/resources/default_settings.json b/extensions/inference-cortex-extension/resources/default_settings.json
index f9f1bf651..c7c88b842 100644
--- a/extensions/inference-cortex-extension/resources/default_settings.json
+++ b/extensions/inference-cortex-extension/resources/default_settings.json
@@ -1,4 +1,13 @@
 [
+  {
+    "key": "auto_unload_models",
+    "title": "Auto-Unload Old Models",
+    "description": "Automatically unloads models that are not in use to free up memory. Ensures only one model is loaded at a time.",
+    "controllerType": "checkbox",
+    "controllerProps": {
+      "value": true
+    }
+  },
   {
     "key": "cont_batching",
     "title": "Continuous Batching",
diff --git a/extensions/inference-cortex-extension/src/index.ts b/extensions/inference-cortex-extension/src/index.ts
index a9fce82b6..befec36b0 100644
--- a/extensions/inference-cortex-extension/src/index.ts
+++ b/extensions/inference-cortex-extension/src/index.ts
@@ -37,6 +37,7 @@ enum Settings {
   use_mmap = 'use_mmap',
   cpu_threads = 'cpu_threads',
   huggingfaceToken = 'hugging-face-access-token',
+  auto_unload_models = 'auto_unload_models',
 }
 
 type LoadedModelResponse = { data: { engine: string; id: string }[] }
@@ -61,7 +62,7 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
   use_mmap: boolean = true
   cache_type: string = 'f16'
   cpu_threads?: number
-
+  auto_unload_models: boolean = true
   /**
    * The URL for making inference requests.
    */
@@ -126,6 +127,10 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
     this.flash_attn = await this.getSetting(Settings.flash_attn, true)
     this.use_mmap = await this.getSetting(Settings.use_mmap, true)
     this.cache_type = await this.getSetting(Settings.cache_type, 'f16')
+    this.auto_unload_models = await this.getSetting(
+      Settings.auto_unload_models,
+      true
+    )
     const threads_number = Number(
       await this.getSetting(Settings.cpu_threads, '')
     )
@@ -176,6 +181,8 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
       if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number
     } else if (key === Settings.huggingfaceToken) {
       this.updateCortexConfig({ huggingface_token: value })
+    } else if (key === Settings.auto_unload_models) {
+      this.auto_unload_models = value as boolean
     }
   }
 
@@ -205,7 +212,15 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
       console.log(`Model ${model.id} already loaded`)
      return
     }
-
+    if (this.auto_unload_models) {
+      // Unload any previously loaded models that are not the one being started
+      for (const lastUsedModel of loadedModels) {
+        if (lastUsedModel.id !== model.id) {
+          console.log(`Unloading last used model: ${lastUsedModel.id}`)
+          await this.unloadModel(lastUsedModel as Model)
+        }
+      }
+    }
     return await this.apiInstance().then((api) =>
       api
         .post('v1/models/start', {
diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts
index fcda8d0e5..f2ccfe022 100644
--- a/web-app/src/hooks/useChat.ts
+++ b/web-app/src/hooks/useChat.ts
@@ -126,11 +126,9 @@ export const useChat = () => {
     let availableTools = selectedModel?.capabilities?.includes('tools')
       ? tools
       : []
-    while (
-      !isCompleted &&
-      !abortController.signal.aborted
-      // TODO: Max attempts can be set in the provider settings later
-    ) {
+    // TODO: Replace this with the Agent setup later?
+    const followUpWithToolUse = true
+    while (!isCompleted && !abortController.signal.aborted) {
       const completion = await sendCompletion(
         activeThread,
         provider,
@@ -200,7 +198,8 @@ export const useChat = () => {
       addMessage(updatedMessage ?? finalContent)
 
       isCompleted = !toolCalls.length
-      availableTools = []
+      // Only clear the tools (and end the agent loop) when follow-up tool use is disabled
+      if (!followUpWithToolUse) availableTools = []
     }
   } catch (error) {
     toast.error(
diff --git a/web-app/src/lib/completion.ts b/web-app/src/lib/completion.ts
index 6546efe00..c436ef956 100644
--- a/web-app/src/lib/completion.ts
+++ b/web-app/src/lib/completion.ts
@@ -308,7 +308,7 @@ export const postMessageProcessing = async (
       }
       builder.addToolMessage(result.content[0]?.text ?? '', toolCall.id)
       // update message metadata
-      return message
     }
+    return message
   }
 }
diff --git a/web-app/src/services/providers.ts b/web-app/src/services/providers.ts
index a9bd16e9f..80ea03282 100644
--- a/web-app/src/services/providers.ts
+++ b/web-app/src/services/providers.ts
@@ -77,7 +77,22 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
           ? (model.capabilities as string[])
           : [ModelCapabilities.COMPLETION],
         provider: providerName,
-        settings: modelSettings,
+        settings: Object.values(modelSettings).reduce(
+          (acc, setting) => {
+            const value = model[
+              setting.key as keyof typeof model
+            ] as keyof typeof setting.controller_props.value
+            acc[setting.key] = {
+              ...setting,
+              controller_props: {
+                ...setting.controller_props,
+                value: value ?? setting.controller_props.value,
+              },
+            }
+            return acc
+          },
+          {} as Record<string, ProviderSetting>
+        ),
       })),
     }
     runtimeProviders.push(provider)