chore: send chat completion with messages history (#5070)
* chore: send chat completion with messages history * chore: handle abort controllers * chore: change max attempts setting * chore: handle stop running models in system monitor screen * Update web-app/src/services/models.ts Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * chore: format time * chore: handle stop model load action --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
570bb8290f
commit
942f2f51b7
@ -31,7 +31,7 @@ export abstract class AIEngine extends BaseExtension {
|
||||
/**
|
||||
* Loads the model.
|
||||
*/
|
||||
async loadModel(model: Partial<Model>): Promise<any> {
|
||||
async loadModel(model: Partial<Model>, abortController?: AbortController): Promise<any> {
|
||||
if (model?.engine?.toString() !== this.provider) return Promise.resolve()
|
||||
events.emit(ModelEvent.OnModelReady, model)
|
||||
return Promise.resolve()
|
||||
|
||||
@ -29,7 +29,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
|
||||
/**
|
||||
* Load the model.
|
||||
*/
|
||||
override async loadModel(model: Model & { file_path?: string }): Promise<void> {
|
||||
override async loadModel(model: Model & { file_path?: string }, abortController?: AbortController): Promise<void> {
|
||||
if (model.engine.toString() !== this.provider) return
|
||||
const modelFolder = 'file_path' in model && model.file_path ? await dirName(model.file_path) : await this.getModelFilePath(model.id)
|
||||
const systemInfo = await systemInformation()
|
||||
|
||||
@ -184,13 +184,14 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
id: string
|
||||
settings?: object
|
||||
file_path?: string
|
||||
}
|
||||
},
|
||||
abortController: AbortController
|
||||
): Promise<void> {
|
||||
// Cortex will handle these settings
|
||||
const { llama_model_path, mmproj, ...settings } = model.settings ?? {}
|
||||
model.settings = settings
|
||||
|
||||
const controller = new AbortController()
|
||||
const controller = abortController ?? new AbortController()
|
||||
const { signal } = controller
|
||||
|
||||
this.abortControllers.set(model.id, controller)
|
||||
@ -292,7 +293,6 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
* Subscribe to cortex.cpp websocket events
|
||||
*/
|
||||
private subscribeToEvents() {
|
||||
console.log('Subscribing to events...')
|
||||
this.socket = new WebSocket(`${CORTEX_SOCKET_URL}/events`)
|
||||
|
||||
this.socket.addEventListener('message', (event) => {
|
||||
@ -341,13 +341,11 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
* This is to handle the server segfault issue
|
||||
*/
|
||||
this.socket.onclose = (event) => {
|
||||
console.log('WebSocket closed:', event)
|
||||
// Notify app to update model running state
|
||||
events.emit(ModelEvent.OnModelStopped, {})
|
||||
|
||||
// Reconnect to the /events websocket
|
||||
if (this.shouldReconnect) {
|
||||
console.log(`Attempting to reconnect...`)
|
||||
setTimeout(() => this.subscribeToEvents(), 1000)
|
||||
}
|
||||
}
|
||||
|
||||
@ -272,7 +272,7 @@ const ChatInput = ({
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="icon"
|
||||
onClick={() => stopStreaming(streamingContent.thread_id)}
|
||||
onClick={() => stopStreaming(currentThreadId ?? streamingContent.thread_id)}
|
||||
>
|
||||
<IconPlayerStopFilled />
|
||||
</Button>
|
||||
|
||||
@ -33,7 +33,7 @@ export const useChat = () => {
|
||||
const { getCurrentThread: retrieveThread, createThread } = useThreads()
|
||||
const { updateStreamingContent, updateLoadingModel, setAbortController } =
|
||||
useAppState()
|
||||
const { addMessage } = useMessages()
|
||||
const { getMessages, addMessage } = useMessages()
|
||||
const router = useRouter()
|
||||
|
||||
const provider = useMemo(() => {
|
||||
@ -73,18 +73,22 @@ export const useChat = () => {
|
||||
|
||||
resetTokenSpeed()
|
||||
if (!activeThread || !provider) return
|
||||
|
||||
const messages = getMessages(activeThread.id)
|
||||
const abortController = new AbortController()
|
||||
setAbortController(activeThread.id, abortController)
|
||||
updateStreamingContent(emptyThreadContent)
|
||||
addMessage(newUserThreadContent(activeThread.id, message))
|
||||
setPrompt('')
|
||||
try {
|
||||
if (selectedModel?.id) {
|
||||
updateLoadingModel(true)
|
||||
await startModel(provider, selectedModel.id).catch(console.error)
|
||||
await startModel(provider, selectedModel.id, abortController).catch(
|
||||
console.error
|
||||
)
|
||||
updateLoadingModel(false)
|
||||
}
|
||||
|
||||
const builder = new CompletionMessagesBuilder()
|
||||
const builder = new CompletionMessagesBuilder(messages)
|
||||
if (currentAssistant?.instructions?.length > 0)
|
||||
builder.addSystemMessage(currentAssistant?.instructions || '')
|
||||
// REMARK: Would it possible to not attach the entire message history to the request?
|
||||
@ -92,9 +96,15 @@ export const useChat = () => {
|
||||
builder.addUserMessage(message)
|
||||
|
||||
let isCompleted = false
|
||||
const abortController = new AbortController()
|
||||
setAbortController(activeThread.id, abortController)
|
||||
while (!isCompleted) {
|
||||
|
||||
let attempts = 0
|
||||
while (
|
||||
!isCompleted &&
|
||||
!abortController.signal.aborted &&
|
||||
// TODO: Max attempts can be set in the provider settings later
|
||||
attempts < 10
|
||||
) {
|
||||
attempts += 1
|
||||
const completion = await sendCompletion(
|
||||
activeThread,
|
||||
provider,
|
||||
@ -143,7 +153,8 @@ export const useChat = () => {
|
||||
const updatedMessage = await postMessageProcessing(
|
||||
toolCalls,
|
||||
builder,
|
||||
finalContent
|
||||
finalContent,
|
||||
abortController
|
||||
)
|
||||
addMessage(updatedMessage ?? finalContent)
|
||||
|
||||
@ -163,6 +174,7 @@ export const useChat = () => {
|
||||
getCurrentThread,
|
||||
resetTokenSpeed,
|
||||
provider,
|
||||
getMessages,
|
||||
updateStreamingContent,
|
||||
addMessage,
|
||||
setPrompt,
|
||||
|
||||
@ -171,22 +171,26 @@ export const isCompletionResponse = (
|
||||
*/
|
||||
export const startModel = async (
|
||||
provider: ProviderObject,
|
||||
model: string
|
||||
model: string,
|
||||
abortController?: AbortController
|
||||
): Promise<void> => {
|
||||
const providerObj = EngineManager.instance().get(
|
||||
normalizeProvider(provider.provider)
|
||||
)
|
||||
const modelObj = provider.models.find((m) => m.id === model)
|
||||
if (providerObj && modelObj)
|
||||
return providerObj?.loadModel({
|
||||
id: modelObj.id,
|
||||
settings: Object.fromEntries(
|
||||
Object.entries(modelObj.settings ?? {}).map(([key, value]) => [
|
||||
key,
|
||||
value.controller_props?.value, // assuming each setting is { value: ... }
|
||||
])
|
||||
),
|
||||
})
|
||||
return providerObj?.loadModel(
|
||||
{
|
||||
id: modelObj.id,
|
||||
settings: Object.fromEntries(
|
||||
Object.entries(modelObj.settings ?? {}).map(([key, value]) => [
|
||||
key,
|
||||
value.controller_props?.value, // assuming each setting is { value: ... }
|
||||
])
|
||||
),
|
||||
},
|
||||
abortController
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -279,11 +283,13 @@ export const extractToolCall = (
|
||||
export const postMessageProcessing = async (
|
||||
calls: ChatCompletionMessageToolCall[],
|
||||
builder: CompletionMessagesBuilder,
|
||||
message: ThreadMessage
|
||||
message: ThreadMessage,
|
||||
abortController: AbortController
|
||||
) => {
|
||||
// Handle completed tool calls
|
||||
if (calls.length) {
|
||||
for (const toolCall of calls) {
|
||||
if (abortController.signal.aborted) break
|
||||
const toolId = ulid()
|
||||
const toolCallsMetadata =
|
||||
message.metadata?.tool_calls &&
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import { ChatCompletionMessageParam } from 'token.js'
|
||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
import { ThreadMessage } from '@janhq/core'
|
||||
|
||||
/**
|
||||
* @fileoverview Helper functions for creating chat completion request.
|
||||
@ -8,8 +9,14 @@ import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
export class CompletionMessagesBuilder {
|
||||
private messages: ChatCompletionMessageParam[] = []
|
||||
|
||||
constructor() {}
|
||||
|
||||
constructor(messages: ThreadMessage[]) {
|
||||
this.messages = messages
|
||||
.filter((e) => !e.metadata?.error)
|
||||
.map<ChatCompletionMessageParam>((msg) => ({
|
||||
role: msg.role,
|
||||
content: msg.content[0]?.text?.value ?? '.',
|
||||
}) as ChatCompletionMessageParam)
|
||||
}
|
||||
/**
|
||||
* Add a system message to the messages array.
|
||||
* @param content - The content of the system message.
|
||||
|
||||
@ -7,8 +7,9 @@ import type { HardwareData } from '@/hooks/useHardware'
|
||||
import { route } from '@/constants/routes'
|
||||
import { formatDuration, formatMegaBytes } from '@/lib/utils'
|
||||
import { IconDeviceDesktopAnalytics } from '@tabler/icons-react'
|
||||
import { getActiveModels } from '@/services/models'
|
||||
import { getActiveModels, stopModel } from '@/services/models'
|
||||
import { ActiveModel } from '@/types/models'
|
||||
import { Button } from '@/components/ui/button'
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export const Route = createFileRoute(route.systemMonitor as any)({
|
||||
@ -40,6 +41,18 @@ function SystemMonitor() {
|
||||
return () => clearInterval(intervalId)
|
||||
}, [setHardwareData, setActiveModels, updateCPUUsage, updateRAMAvailable])
|
||||
|
||||
const stopRunningModel = (modelId: string) => {
|
||||
stopModel(modelId)
|
||||
.then(() => {
|
||||
setActiveModels((prevModels) =>
|
||||
prevModels.filter((model) => model.id !== modelId)
|
||||
)
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Error stopping model:', error)
|
||||
})
|
||||
}
|
||||
|
||||
// Calculate RAM usage percentage
|
||||
const ramUsagePercentage =
|
||||
((hardwareData.ram.total - hardwareData.ram.available) /
|
||||
@ -154,15 +167,18 @@ function SystemMonitor() {
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-main-view-fg/70">Uptime</span>
|
||||
<span className="text-main-view-fg">
|
||||
{formatDuration(model.start_time)}
|
||||
{model.start_time && formatDuration(model.start_time)}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-main-view-fg/70">Status</span>
|
||||
<span className="text-main-view-fg/70">Actions</span>
|
||||
<span className="text-main-view-fg">
|
||||
<div className="bg-green-500/20 px-1 font-bold py-0.5 rounded text-green-700 text-xs">
|
||||
Running
|
||||
</div>
|
||||
<Button
|
||||
variant="destructive"
|
||||
onClick={() => stopRunningModel(model.id)}
|
||||
>
|
||||
Stop
|
||||
</Button>
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -226,6 +226,29 @@ export const getActiveModels = async (provider?: string) => {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops a model for a given provider.
|
||||
* @param model
|
||||
* @param provider
|
||||
* @returns
|
||||
*/
|
||||
export const stopModel = async (model: string, provider?: string) => {
|
||||
const providerName = provider || 'cortex' // we will go down to llama.cpp extension later on
|
||||
const extension = EngineManager.instance().get(providerName)
|
||||
|
||||
if (!extension) throw new Error('Model extension not found')
|
||||
|
||||
try {
|
||||
return await extension.unloadModel({
|
||||
model,
|
||||
id: model,
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('Failed to stop model:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Configures the proxy options for model downloads.
|
||||
* @param param0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user