From 942f2f51b7277b6f380f7d4c95634c3d2ede973c Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 22 May 2025 20:13:50 +0700 Subject: [PATCH] chore: send chat completion with messages history (#5070) * chore: send chat completion with messages history * chore: handle abort controllers * chore: change max attempts setting * chore: handle stop running models in system monitor screen * Update web-app/src/services/models.ts Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * chore: format time * chore: handle stop model load action --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- .../browser/extensions/engines/AIEngine.ts | 2 +- .../extensions/engines/LocalOAIEngine.ts | 2 +- .../inference-cortex-extension/src/index.ts | 8 ++---- web-app/src/containers/ChatInput.tsx | 2 +- web-app/src/hooks/useChat.ts | 28 +++++++++++++------ web-app/src/lib/completion.ts | 28 +++++++++++-------- web-app/src/lib/messages.ts | 11 ++++++-- web-app/src/routes/system-monitor.tsx | 28 +++++++++++++++---- web-app/src/services/models.ts | 23 +++++++++++++++ 9 files changed, 97 insertions(+), 35 deletions(-) diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 25f83184b..4f96eb93a 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -31,7 +31,7 @@ export abstract class AIEngine extends BaseExtension { /** * Loads the model. */ - async loadModel(model: Partial): Promise { + async loadModel(model: Partial, abortController?: AbortController): Promise { if (model?.engine?.toString() !== this.provider) return Promise.resolve() events.emit(ModelEvent.OnModelReady, model) return Promise.resolve() diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts index b54f8fbde..026c5b2fe 100644 --- a/core/src/browser/extensions/engines/LocalOAIEngine.ts +++ b/core/src/browser/extensions/engines/LocalOAIEngine.ts @@ -29,7 +29,7 @@ export abstract class LocalOAIEngine extends OAIEngine { /** * Load the model. */ - override async loadModel(model: Model & { file_path?: string }): Promise { + override async loadModel(model: Model & { file_path?: string }, abortController?: AbortController): Promise { if (model.engine.toString() !== this.provider) return const modelFolder = 'file_path' in model && model.file_path ? await dirName(model.file_path) : await this.getModelFilePath(model.id) const systemInfo = await systemInformation() diff --git a/extensions/inference-cortex-extension/src/index.ts b/extensions/inference-cortex-extension/src/index.ts index 85c760b09..a9fce82b6 100644 --- a/extensions/inference-cortex-extension/src/index.ts +++ b/extensions/inference-cortex-extension/src/index.ts @@ -184,13 +184,14 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { id: string settings?: object file_path?: string - } + }, + abortController: AbortController ): Promise { // Cortex will handle these settings const { llama_model_path, mmproj, ...settings } = model.settings ?? {} model.settings = settings - const controller = new AbortController() + const controller = abortController ?? 
new AbortController() const { signal } = controller this.abortControllers.set(model.id, controller) @@ -292,7 +293,6 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { * Subscribe to cortex.cpp websocket events */ private subscribeToEvents() { - console.log('Subscribing to events...') this.socket = new WebSocket(`${CORTEX_SOCKET_URL}/events`) this.socket.addEventListener('message', (event) => { @@ -341,13 +341,11 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { * This is to handle the server segfault issue */ this.socket.onclose = (event) => { - console.log('WebSocket closed:', event) // Notify app to update model running state events.emit(ModelEvent.OnModelStopped, {}) // Reconnect to the /events websocket if (this.shouldReconnect) { - console.log(`Attempting to reconnect...`) setTimeout(() => this.subscribeToEvents(), 1000) } } diff --git a/web-app/src/containers/ChatInput.tsx b/web-app/src/containers/ChatInput.tsx index 24b1bf57c..a8ceaed04 100644 --- a/web-app/src/containers/ChatInput.tsx +++ b/web-app/src/containers/ChatInput.tsx @@ -272,7 +272,7 @@ const ChatInput = ({ diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index 4af528dcf..7756442a8 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -33,7 +33,7 @@ export const useChat = () => { const { getCurrentThread: retrieveThread, createThread } = useThreads() const { updateStreamingContent, updateLoadingModel, setAbortController } = useAppState() - const { addMessage } = useMessages() + const { getMessages, addMessage } = useMessages() const router = useRouter() const provider = useMemo(() => { @@ -73,18 +73,22 @@ export const useChat = () => { resetTokenSpeed() if (!activeThread || !provider) return - + const messages = getMessages(activeThread.id) + const abortController = new AbortController() + setAbortController(activeThread.id, abortController) updateStreamingContent(emptyThreadContent) addMessage(newUserThreadContent(activeThread.id, message)) setPrompt('') try { if (selectedModel?.id) { updateLoadingModel(true) - await startModel(provider, selectedModel.id).catch(console.error) + await startModel(provider, selectedModel.id, abortController).catch( + console.error + ) updateLoadingModel(false) } - const builder = new CompletionMessagesBuilder() + const builder = new CompletionMessagesBuilder(messages) if (currentAssistant?.instructions?.length > 0) builder.addSystemMessage(currentAssistant?.instructions || '') // REMARK: Would it possible to not attach the entire message history to the request? @@ -92,9 +96,15 @@ export const useChat = () => { builder.addUserMessage(message) let isCompleted = false - const abortController = new AbortController() - setAbortController(activeThread.id, abortController) - while (!isCompleted) { + + let attempts = 0 + while ( + !isCompleted && + !abortController.signal.aborted && + // TODO: Max attempts can be set in the provider settings later + attempts < 10 + ) { + attempts += 1 const completion = await sendCompletion( activeThread, provider, @@ -143,7 +153,8 @@ export const useChat = () => { const updatedMessage = await postMessageProcessing( toolCalls, builder, - finalContent + finalContent, + abortController ) addMessage(updatedMessage ?? 
finalContent) @@ -163,6 +174,7 @@ export const useChat = () => { getCurrentThread, resetTokenSpeed, provider, + getMessages, updateStreamingContent, addMessage, setPrompt, diff --git a/web-app/src/lib/completion.ts b/web-app/src/lib/completion.ts index 92a17f321..b8cb532aa 100644 --- a/web-app/src/lib/completion.ts +++ b/web-app/src/lib/completion.ts @@ -171,22 +171,26 @@ export const isCompletionResponse = ( */ export const startModel = async ( provider: ProviderObject, - model: string + model: string, + abortController?: AbortController ): Promise => { const providerObj = EngineManager.instance().get( normalizeProvider(provider.provider) ) const modelObj = provider.models.find((m) => m.id === model) if (providerObj && modelObj) - return providerObj?.loadModel({ - id: modelObj.id, - settings: Object.fromEntries( - Object.entries(modelObj.settings ?? {}).map(([key, value]) => [ - key, - value.controller_props?.value, // assuming each setting is { value: ... } - ]) - ), - }) + return providerObj?.loadModel( + { + id: modelObj.id, + settings: Object.fromEntries( + Object.entries(modelObj.settings ?? {}).map(([key, value]) => [ + key, + value.controller_props?.value, // assuming each setting is { value: ... } + ]) + ), + }, + abortController + ) } /** @@ -279,11 +283,13 @@ export const extractToolCall = ( export const postMessageProcessing = async ( calls: ChatCompletionMessageToolCall[], builder: CompletionMessagesBuilder, - message: ThreadMessage + message: ThreadMessage, + abortController: AbortController ) => { // Handle completed tool calls if (calls.length) { for (const toolCall of calls) { + if (abortController.signal.aborted) break const toolId = ulid() const toolCallsMetadata = message.metadata?.tool_calls && diff --git a/web-app/src/lib/messages.ts b/web-app/src/lib/messages.ts index 2f46b8eb0..ddd350ac0 100644 --- a/web-app/src/lib/messages.ts +++ b/web-app/src/lib/messages.ts @@ -1,5 +1,6 @@ import { ChatCompletionMessageParam } from 'token.js' import { ChatCompletionMessageToolCall } from 'openai/resources' +import { ThreadMessage } from '@janhq/core' /** * @fileoverview Helper functions for creating chat completion request. @@ -8,8 +9,14 @@ import { ChatCompletionMessageToolCall } from 'openai/resources' export class CompletionMessagesBuilder { private messages: ChatCompletionMessageParam[] = [] - constructor() {} - + constructor(messages: ThreadMessage[]) { + this.messages = messages + .filter((e) => !e.metadata?.error) + .map((msg) => ({ + role: msg.role, + content: msg.content[0]?.text?.value ?? '.', + }) as ChatCompletionMessageParam) + } /** * Add a system message to the messages array. * @param content - The content of the system message. 
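To make the history seeding concrete, here is a minimal usage sketch of the CompletionMessagesBuilder constructor added above. The thread history is hypothetical sample data and the cast only keeps the snippet short, but the error filtering and the role/content mapping follow the code in this diff; the '@/lib/messages' import path assumes the web-app path alias used elsewhere in this PR.

    import { ThreadMessage } from '@janhq/core'
    import { CompletionMessagesBuilder } from '@/lib/messages'

    // Hypothetical prior turns: the errored assistant reply is dropped by the
    // constructor's metadata?.error filter, the rest map to { role, content }.
    const history = [
      { role: 'user', content: [{ type: 'text', text: { value: 'Hello' } }], metadata: {} },
      { role: 'assistant', content: [{ type: 'text', text: { value: '' } }], metadata: { error: true } },
      { role: 'assistant', content: [{ type: 'text', text: { value: 'Hi! How can I help?' } }], metadata: {} },
    ] as unknown as ThreadMessage[]

    const builder = new CompletionMessagesBuilder(history)
    builder.addSystemMessage('You are a helpful assistant.') // assistant instructions, if any
    builder.addUserMessage('Summarize our conversation so far.')
    // The builder now carries the prior non-errored turns plus the new system
    // and user messages for the chat/completions request.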
diff --git a/web-app/src/routes/system-monitor.tsx b/web-app/src/routes/system-monitor.tsx index 4ea82d035..db4ed12ad 100644 --- a/web-app/src/routes/system-monitor.tsx +++ b/web-app/src/routes/system-monitor.tsx @@ -7,8 +7,9 @@ import type { HardwareData } from '@/hooks/useHardware' import { route } from '@/constants/routes' import { formatDuration, formatMegaBytes } from '@/lib/utils' import { IconDeviceDesktopAnalytics } from '@tabler/icons-react' -import { getActiveModels } from '@/services/models' +import { getActiveModels, stopModel } from '@/services/models' import { ActiveModel } from '@/types/models' +import { Button } from '@/components/ui/button' // eslint-disable-next-line @typescript-eslint/no-explicit-any export const Route = createFileRoute(route.systemMonitor as any)({ @@ -40,6 +41,18 @@ function SystemMonitor() { return () => clearInterval(intervalId) }, [setHardwareData, setActiveModels, updateCPUUsage, updateRAMAvailable]) + const stopRunningModel = (modelId: string) => { + stopModel(modelId) + .then(() => { + setActiveModels((prevModels) => + prevModels.filter((model) => model.id !== modelId) + ) + }) + .catch((error) => { + console.error('Error stopping model:', error) + }) + } + // Calculate RAM usage percentage const ramUsagePercentage = ((hardwareData.ram.total - hardwareData.ram.available) / @@ -154,15 +167,18 @@ function SystemMonitor() {
  [The JSX in this hunk was garbled during extraction. Recoverable changes: the
   Uptime cell renders {model.start_time && formatDuration(model.start_time)}
   instead of {formatDuration(model.start_time)}, the "Status" column header is
   renamed to "Actions", and the static "Running" badge is replaced by a control
   that calls stopRunningModel(model.id) for that row.]
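Because the markup for the hunk above did not survive extraction, the snippet below is only an illustrative reconstruction of how the Actions cell might wire a Stop control to the stopRunningModel handler defined earlier in this route; the exact props and layout in the actual PR may differ.

    import { Button } from '@/components/ui/button'
    import type { ActiveModel } from '@/types/models'

    // Illustrative only: renders a Stop action for one running-model row and
    // delegates to the route's stopRunningModel handler.
    function ModelActionsCell({
      model,
      stopRunningModel,
    }: {
      model: ActiveModel
      stopRunningModel: (modelId: string) => void
    }) {
      return <Button onClick={() => stopRunningModel(model.id)}>Stop</Button>
    }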
diff --git a/web-app/src/services/models.ts b/web-app/src/services/models.ts index 19550ede9..5ea7c3f47 100644 --- a/web-app/src/services/models.ts +++ b/web-app/src/services/models.ts @@ -226,6 +226,29 @@ export const getActiveModels = async (provider?: string) => { } } +/** + * Stops a model for a given provider. + * @param model + * @param provider + * @returns + */ +export const stopModel = async (model: string, provider?: string) => { + const providerName = provider || 'cortex' // we will go down to llama.cpp extension later on + const extension = EngineManager.instance().get(providerName) + + if (!extension) throw new Error('Model extension not found') + + try { + return await extension.unloadModel({ + model, + id: model, + }) + } catch (error) { + console.error('Failed to stop model:', error) + return [] + } +} + /** * Configures the proxy options for model downloads. * @param param0
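Taken together, the signature changes in this PR let one AbortController created per send cancel a pending model load as well as the completion loop. A condensed sketch of that wiring, assuming only the startModel signature added above and the '@/lib/completion' path alias used elsewhere in web-app:

    import { startModel } from '@/lib/completion'

    type Provider = Parameters<typeof startModel>[0]

    // Create the controller up front (as useChat now does), pass it into
    // startModel so the engine's loadModel can honor it, and hand back a
    // cancel function the UI (e.g. the ChatInput stop button) can call later.
    export async function loadWithCancel(provider: Provider, modelId: string) {
      const abortController = new AbortController()
      await startModel(provider, modelId, abortController).catch(console.error)
      return () => abortController.abort()
    }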