chore: send chat completion with messages history (#5070)
* chore: send chat completion with messages history
* chore: handle abort controllers
* chore: change max attempts setting
* chore: handle stop running models in system monitor screen
* Update web-app/src/services/models.ts
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* chore: format time
* chore: handle stop model load action
---------
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
570bb8290f
commit
942f2f51b7
@ -31,7 +31,7 @@ export abstract class AIEngine extends BaseExtension {
|
|||||||
/**
|
/**
|
||||||
* Loads the model.
|
* Loads the model.
|
||||||
*/
|
*/
|
||||||
async loadModel(model: Partial<Model>): Promise<any> {
|
async loadModel(model: Partial<Model>, abortController?: AbortController): Promise<any> {
|
||||||
if (model?.engine?.toString() !== this.provider) return Promise.resolve()
|
if (model?.engine?.toString() !== this.provider) return Promise.resolve()
|
||||||
events.emit(ModelEvent.OnModelReady, model)
|
events.emit(ModelEvent.OnModelReady, model)
|
||||||
return Promise.resolve()
|
return Promise.resolve()
|
||||||
|
|||||||
@ -29,7 +29,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
|
|||||||
/**
|
/**
|
||||||
* Load the model.
|
* Load the model.
|
||||||
*/
|
*/
|
||||||
override async loadModel(model: Model & { file_path?: string }): Promise<void> {
|
override async loadModel(model: Model & { file_path?: string }, abortController?: AbortController): Promise<void> {
|
||||||
if (model.engine.toString() !== this.provider) return
|
if (model.engine.toString() !== this.provider) return
|
||||||
const modelFolder = 'file_path' in model && model.file_path ? await dirName(model.file_path) : await this.getModelFilePath(model.id)
|
const modelFolder = 'file_path' in model && model.file_path ? await dirName(model.file_path) : await this.getModelFilePath(model.id)
|
||||||
const systemInfo = await systemInformation()
|
const systemInfo = await systemInformation()
|
||||||
|
|||||||
@ -184,13 +184,14 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
|||||||
id: string
|
id: string
|
||||||
settings?: object
|
settings?: object
|
||||||
file_path?: string
|
file_path?: string
|
||||||
}
|
},
|
||||||
|
abortController: AbortController
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
// Cortex will handle these settings
|
// Cortex will handle these settings
|
||||||
const { llama_model_path, mmproj, ...settings } = model.settings ?? {}
|
const { llama_model_path, mmproj, ...settings } = model.settings ?? {}
|
||||||
model.settings = settings
|
model.settings = settings
|
||||||
|
|
||||||
const controller = new AbortController()
|
const controller = abortController ?? new AbortController()
|
||||||
const { signal } = controller
|
const { signal } = controller
|
||||||
|
|
||||||
this.abortControllers.set(model.id, controller)
|
this.abortControllers.set(model.id, controller)
|
||||||
@ -292,7 +293,6 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
|||||||
* Subscribe to cortex.cpp websocket events
|
* Subscribe to cortex.cpp websocket events
|
||||||
*/
|
*/
|
||||||
private subscribeToEvents() {
|
private subscribeToEvents() {
|
||||||
console.log('Subscribing to events...')
|
|
||||||
this.socket = new WebSocket(`${CORTEX_SOCKET_URL}/events`)
|
this.socket = new WebSocket(`${CORTEX_SOCKET_URL}/events`)
|
||||||
|
|
||||||
this.socket.addEventListener('message', (event) => {
|
this.socket.addEventListener('message', (event) => {
|
||||||
@ -341,13 +341,11 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
|||||||
* This is to handle the server segfault issue
|
* This is to handle the server segfault issue
|
||||||
*/
|
*/
|
||||||
this.socket.onclose = (event) => {
|
this.socket.onclose = (event) => {
|
||||||
console.log('WebSocket closed:', event)
|
|
||||||
// Notify app to update model running state
|
// Notify app to update model running state
|
||||||
events.emit(ModelEvent.OnModelStopped, {})
|
events.emit(ModelEvent.OnModelStopped, {})
|
||||||
|
|
||||||
// Reconnect to the /events websocket
|
// Reconnect to the /events websocket
|
||||||
if (this.shouldReconnect) {
|
if (this.shouldReconnect) {
|
||||||
console.log(`Attempting to reconnect...`)
|
|
||||||
setTimeout(() => this.subscribeToEvents(), 1000)
|
setTimeout(() => this.subscribeToEvents(), 1000)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -272,7 +272,7 @@ const ChatInput = ({
|
|||||||
<Button
|
<Button
|
||||||
variant="destructive"
|
variant="destructive"
|
||||||
size="icon"
|
size="icon"
|
||||||
onClick={() => stopStreaming(streamingContent.thread_id)}
|
onClick={() => stopStreaming(currentThreadId ?? streamingContent.thread_id)}
|
||||||
>
|
>
|
||||||
<IconPlayerStopFilled />
|
<IconPlayerStopFilled />
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@ -33,7 +33,7 @@ export const useChat = () => {
|
|||||||
const { getCurrentThread: retrieveThread, createThread } = useThreads()
|
const { getCurrentThread: retrieveThread, createThread } = useThreads()
|
||||||
const { updateStreamingContent, updateLoadingModel, setAbortController } =
|
const { updateStreamingContent, updateLoadingModel, setAbortController } =
|
||||||
useAppState()
|
useAppState()
|
||||||
const { addMessage } = useMessages()
|
const { getMessages, addMessage } = useMessages()
|
||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
|
|
||||||
const provider = useMemo(() => {
|
const provider = useMemo(() => {
|
||||||
@ -73,18 +73,22 @@ export const useChat = () => {
|
|||||||
|
|
||||||
resetTokenSpeed()
|
resetTokenSpeed()
|
||||||
if (!activeThread || !provider) return
|
if (!activeThread || !provider) return
|
||||||
|
const messages = getMessages(activeThread.id)
|
||||||
|
const abortController = new AbortController()
|
||||||
|
setAbortController(activeThread.id, abortController)
|
||||||
updateStreamingContent(emptyThreadContent)
|
updateStreamingContent(emptyThreadContent)
|
||||||
addMessage(newUserThreadContent(activeThread.id, message))
|
addMessage(newUserThreadContent(activeThread.id, message))
|
||||||
setPrompt('')
|
setPrompt('')
|
||||||
try {
|
try {
|
||||||
if (selectedModel?.id) {
|
if (selectedModel?.id) {
|
||||||
updateLoadingModel(true)
|
updateLoadingModel(true)
|
||||||
await startModel(provider, selectedModel.id).catch(console.error)
|
await startModel(provider, selectedModel.id, abortController).catch(
|
||||||
|
console.error
|
||||||
|
)
|
||||||
updateLoadingModel(false)
|
updateLoadingModel(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
const builder = new CompletionMessagesBuilder()
|
const builder = new CompletionMessagesBuilder(messages)
|
||||||
if (currentAssistant?.instructions?.length > 0)
|
if (currentAssistant?.instructions?.length > 0)
|
||||||
builder.addSystemMessage(currentAssistant?.instructions || '')
|
builder.addSystemMessage(currentAssistant?.instructions || '')
|
||||||
// REMARK: Would it possible to not attach the entire message history to the request?
|
// REMARK: Would it possible to not attach the entire message history to the request?
|
||||||
@ -92,9 +96,15 @@ export const useChat = () => {
|
|||||||
builder.addUserMessage(message)
|
builder.addUserMessage(message)
|
||||||
|
|
||||||
let isCompleted = false
|
let isCompleted = false
|
||||||
const abortController = new AbortController()
|
|
||||||
setAbortController(activeThread.id, abortController)
|
let attempts = 0
|
||||||
while (!isCompleted) {
|
while (
|
||||||
|
!isCompleted &&
|
||||||
|
!abortController.signal.aborted &&
|
||||||
|
// TODO: Max attempts can be set in the provider settings later
|
||||||
|
attempts < 10
|
||||||
|
) {
|
||||||
|
attempts += 1
|
||||||
const completion = await sendCompletion(
|
const completion = await sendCompletion(
|
||||||
activeThread,
|
activeThread,
|
||||||
provider,
|
provider,
|
||||||
@ -143,7 +153,8 @@ export const useChat = () => {
|
|||||||
const updatedMessage = await postMessageProcessing(
|
const updatedMessage = await postMessageProcessing(
|
||||||
toolCalls,
|
toolCalls,
|
||||||
builder,
|
builder,
|
||||||
finalContent
|
finalContent,
|
||||||
|
abortController
|
||||||
)
|
)
|
||||||
addMessage(updatedMessage ?? finalContent)
|
addMessage(updatedMessage ?? finalContent)
|
||||||
|
|
||||||
@ -163,6 +174,7 @@ export const useChat = () => {
|
|||||||
getCurrentThread,
|
getCurrentThread,
|
||||||
resetTokenSpeed,
|
resetTokenSpeed,
|
||||||
provider,
|
provider,
|
||||||
|
getMessages,
|
||||||
updateStreamingContent,
|
updateStreamingContent,
|
||||||
addMessage,
|
addMessage,
|
||||||
setPrompt,
|
setPrompt,
|
||||||
|
|||||||
@ -171,14 +171,16 @@ export const isCompletionResponse = (
|
|||||||
*/
|
*/
|
||||||
export const startModel = async (
|
export const startModel = async (
|
||||||
provider: ProviderObject,
|
provider: ProviderObject,
|
||||||
model: string
|
model: string,
|
||||||
|
abortController?: AbortController
|
||||||
): Promise<void> => {
|
): Promise<void> => {
|
||||||
const providerObj = EngineManager.instance().get(
|
const providerObj = EngineManager.instance().get(
|
||||||
normalizeProvider(provider.provider)
|
normalizeProvider(provider.provider)
|
||||||
)
|
)
|
||||||
const modelObj = provider.models.find((m) => m.id === model)
|
const modelObj = provider.models.find((m) => m.id === model)
|
||||||
if (providerObj && modelObj)
|
if (providerObj && modelObj)
|
||||||
return providerObj?.loadModel({
|
return providerObj?.loadModel(
|
||||||
|
{
|
||||||
id: modelObj.id,
|
id: modelObj.id,
|
||||||
settings: Object.fromEntries(
|
settings: Object.fromEntries(
|
||||||
Object.entries(modelObj.settings ?? {}).map(([key, value]) => [
|
Object.entries(modelObj.settings ?? {}).map(([key, value]) => [
|
||||||
@ -186,7 +188,9 @@ export const startModel = async (
|
|||||||
value.controller_props?.value, // assuming each setting is { value: ... }
|
value.controller_props?.value, // assuming each setting is { value: ... }
|
||||||
])
|
])
|
||||||
),
|
),
|
||||||
})
|
},
|
||||||
|
abortController
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -279,11 +283,13 @@ export const extractToolCall = (
|
|||||||
export const postMessageProcessing = async (
|
export const postMessageProcessing = async (
|
||||||
calls: ChatCompletionMessageToolCall[],
|
calls: ChatCompletionMessageToolCall[],
|
||||||
builder: CompletionMessagesBuilder,
|
builder: CompletionMessagesBuilder,
|
||||||
message: ThreadMessage
|
message: ThreadMessage,
|
||||||
|
abortController: AbortController
|
||||||
) => {
|
) => {
|
||||||
// Handle completed tool calls
|
// Handle completed tool calls
|
||||||
if (calls.length) {
|
if (calls.length) {
|
||||||
for (const toolCall of calls) {
|
for (const toolCall of calls) {
|
||||||
|
if (abortController.signal.aborted) break
|
||||||
const toolId = ulid()
|
const toolId = ulid()
|
||||||
const toolCallsMetadata =
|
const toolCallsMetadata =
|
||||||
message.metadata?.tool_calls &&
|
message.metadata?.tool_calls &&
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import { ChatCompletionMessageParam } from 'token.js'
|
import { ChatCompletionMessageParam } from 'token.js'
|
||||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||||
|
import { ThreadMessage } from '@janhq/core'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @fileoverview Helper functions for creating chat completion request.
|
* @fileoverview Helper functions for creating chat completion request.
|
||||||
@ -8,8 +9,14 @@ import { ChatCompletionMessageToolCall } from 'openai/resources'
|
|||||||
export class CompletionMessagesBuilder {
|
export class CompletionMessagesBuilder {
|
||||||
private messages: ChatCompletionMessageParam[] = []
|
private messages: ChatCompletionMessageParam[] = []
|
||||||
|
|
||||||
constructor() {}
|
constructor(messages: ThreadMessage[]) {
|
||||||
|
this.messages = messages
|
||||||
|
.filter((e) => !e.metadata?.error)
|
||||||
|
.map<ChatCompletionMessageParam>((msg) => ({
|
||||||
|
role: msg.role,
|
||||||
|
content: msg.content[0]?.text?.value ?? '.',
|
||||||
|
}) as ChatCompletionMessageParam)
|
||||||
|
}
|
||||||
/**
|
/**
|
||||||
* Add a system message to the messages array.
|
* Add a system message to the messages array.
|
||||||
* @param content - The content of the system message.
|
* @param content - The content of the system message.
|
||||||
|
|||||||
@ -7,8 +7,9 @@ import type { HardwareData } from '@/hooks/useHardware'
|
|||||||
import { route } from '@/constants/routes'
|
import { route } from '@/constants/routes'
|
||||||
import { formatDuration, formatMegaBytes } from '@/lib/utils'
|
import { formatDuration, formatMegaBytes } from '@/lib/utils'
|
||||||
import { IconDeviceDesktopAnalytics } from '@tabler/icons-react'
|
import { IconDeviceDesktopAnalytics } from '@tabler/icons-react'
|
||||||
import { getActiveModels } from '@/services/models'
|
import { getActiveModels, stopModel } from '@/services/models'
|
||||||
import { ActiveModel } from '@/types/models'
|
import { ActiveModel } from '@/types/models'
|
||||||
|
import { Button } from '@/components/ui/button'
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
export const Route = createFileRoute(route.systemMonitor as any)({
|
export const Route = createFileRoute(route.systemMonitor as any)({
|
||||||
@ -40,6 +41,18 @@ function SystemMonitor() {
|
|||||||
return () => clearInterval(intervalId)
|
return () => clearInterval(intervalId)
|
||||||
}, [setHardwareData, setActiveModels, updateCPUUsage, updateRAMAvailable])
|
}, [setHardwareData, setActiveModels, updateCPUUsage, updateRAMAvailable])
|
||||||
|
|
||||||
|
const stopRunningModel = (modelId: string) => {
|
||||||
|
stopModel(modelId)
|
||||||
|
.then(() => {
|
||||||
|
setActiveModels((prevModels) =>
|
||||||
|
prevModels.filter((model) => model.id !== modelId)
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Error stopping model:', error)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate RAM usage percentage
|
// Calculate RAM usage percentage
|
||||||
const ramUsagePercentage =
|
const ramUsagePercentage =
|
||||||
((hardwareData.ram.total - hardwareData.ram.available) /
|
((hardwareData.ram.total - hardwareData.ram.available) /
|
||||||
@ -154,15 +167,18 @@ function SystemMonitor() {
|
|||||||
<div className="flex justify-between items-center">
|
<div className="flex justify-between items-center">
|
||||||
<span className="text-main-view-fg/70">Uptime</span>
|
<span className="text-main-view-fg/70">Uptime</span>
|
||||||
<span className="text-main-view-fg">
|
<span className="text-main-view-fg">
|
||||||
{formatDuration(model.start_time)}
|
{model.start_time && formatDuration(model.start_time)}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<div className="flex justify-between items-center">
|
<div className="flex justify-between items-center">
|
||||||
<span className="text-main-view-fg/70">Status</span>
|
<span className="text-main-view-fg/70">Actions</span>
|
||||||
<span className="text-main-view-fg">
|
<span className="text-main-view-fg">
|
||||||
<div className="bg-green-500/20 px-1 font-bold py-0.5 rounded text-green-700 text-xs">
|
<Button
|
||||||
Running
|
variant="destructive"
|
||||||
</div>
|
onClick={() => stopRunningModel(model.id)}
|
||||||
|
>
|
||||||
|
Stop
|
||||||
|
</Button>
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -226,6 +226,29 @@ export const getActiveModels = async (provider?: string) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stops a model for a given provider.
|
||||||
|
* @param model
|
||||||
|
* @param provider
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
export const stopModel = async (model: string, provider?: string) => {
|
||||||
|
const providerName = provider || 'cortex' // we will go down to llama.cpp extension later on
|
||||||
|
const extension = EngineManager.instance().get(providerName)
|
||||||
|
|
||||||
|
if (!extension) throw new Error('Model extension not found')
|
||||||
|
|
||||||
|
try {
|
||||||
|
return await extension.unloadModel({
|
||||||
|
model,
|
||||||
|
id: model,
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to stop model:', error)
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Configures the proxy options for model downloads.
|
* Configures the proxy options for model downloads.
|
||||||
* @param param0
|
* @param param0
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user