Use token.js for non-tools calls (#4973)
* deprecate inference()
* fix tool_choice
* only startModel for Cortex
* appease linter
* remove sse
* add stopInferencing support (temporarily with OpenAI)
* use abortSignal in token.js
* bump token.js version
This commit is contained in: parent dbe7ef65e2, commit dc23cc2716
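For orientation, here is a minimal sketch of the cancellation flow the diff below wires up: the sender registers an AbortController on the EngineManager singleton and hands its signal to the token.js call, and stopInference() simply aborts that controller. The names ending in "Sketch", the placeholder URL, and the plain fetch call are illustrative stand-ins, not code from this PR.

```ts
// Simplified sketch of the new cancellation flow (assumptions noted above).
class EngineManagerSketch {
  private static _instance = new EngineManagerSketch()
  static instance(): EngineManagerSketch {
    return this._instance
  }
  // Set by whoever sends the current request; aborted by stopInference().
  controller: AbortController | null = null
}

async function sendChatMessageSketch(body: unknown): Promise<string> {
  const controller = new AbortController()
  EngineManagerSketch.instance().controller = controller
  // token.js ultimately issues an HTTP request; passing the signal lets
  // abort() cancel it mid-flight (the promise rejects with an AbortError).
  const res = await fetch('https://example.invalid/v1/chat/completions', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(body),
    signal: controller.signal,
  })
  return res.text()
}

function stopInferenceSketch(): void {
  // Valid as long as at most one request is in flight at a time
  // (the assumption called out in useActiveModel below).
  EngineManagerSketch.instance().controller?.abort()
}
```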
@@ -6,6 +6,7 @@ import { AIEngine } from './AIEngine'
  */
 export class EngineManager {
   public engines = new Map<string, AIEngine>()
+  public controller: AbortController | null = null
 
   /**
    * Registers an engine.
@@ -12,11 +12,7 @@ import {
   ChatCompletionRole,
   ContentType,
 } from '../../../types'
-import { requestInference } from './helpers/sse'
-import { ulid } from 'ulidx'
 
-jest.mock('./helpers/sse')
-jest.mock('ulidx')
 jest.mock('../../events')
 
 class TestOAIEngine extends OAIEngine {
@@ -48,79 +44,6 @@ describe('OAIEngine', () => {
     )
   })
 
-  it('should handle inference request', async () => {
-    const data: MessageRequest = {
-      model: { engine: 'test-provider', id: 'test-model' } as any,
-      threadId: 'test-thread',
-      type: MessageRequestType.Thread,
-      assistantId: 'test-assistant',
-      messages: [{ role: ChatCompletionRole.User, content: 'Hello' }],
-    }
-
-    ;(ulid as jest.Mock).mockReturnValue('test-id')
-    ;(requestInference as jest.Mock).mockReturnValue({
-      subscribe: ({ next, complete }: any) => {
-        next('test response')
-        complete()
-      },
-    })
-
-    await engine.inference(data)
-
-    expect(requestInference).toHaveBeenCalledWith(
-      'http://test-inference-url',
-      expect.objectContaining({ model: 'test-model' }),
-      expect.any(Object),
-      expect.any(AbortController),
-      { Authorization: 'Bearer test-token' },
-      undefined
-    )
-
-    expect(events.emit).toHaveBeenCalledWith(
-      MessageEvent.OnMessageResponse,
-      expect.objectContaining({ id: 'test-id' })
-    )
-    expect(events.emit).toHaveBeenCalledWith(
-      MessageEvent.OnMessageUpdate,
-      expect.objectContaining({
-        content: [
-          {
-            type: ContentType.Text,
-            text: { value: 'test response', annotations: [] },
-          },
-        ],
-        status: MessageStatus.Ready,
-      })
-    )
-  })
-
-  it('should handle inference error', async () => {
-    const data: MessageRequest = {
-      model: { engine: 'test-provider', id: 'test-model' } as any,
-      threadId: 'test-thread',
-      type: MessageRequestType.Thread,
-      assistantId: 'test-assistant',
-      messages: [{ role: ChatCompletionRole.User, content: 'Hello' }],
-    }
-
-    ;(ulid as jest.Mock).mockReturnValue('test-id')
-    ;(requestInference as jest.Mock).mockReturnValue({
-      subscribe: ({ error }: any) => {
-        error({ message: 'test error', code: 500 })
-      },
-    })
-
-    await engine.inference(data)
-
-    expect(events.emit).toHaveBeenLastCalledWith(
-      MessageEvent.OnMessageUpdate,
-      expect.objectContaining({
-        status: 'error',
-        error_code: 500,
-      })
-    )
-  })
-
   it('should stop inference', () => {
     engine.stopInference()
     expect(engine.isCancelled).toBe(true)
@@ -1,18 +1,9 @@
-import { requestInference } from './helpers/sse'
-import { ulid } from 'ulidx'
 import { AIEngine } from './AIEngine'
 import {
-  ChatCompletionRole,
-  ContentType,
   InferenceEvent,
   MessageEvent,
   MessageRequest,
-  MessageRequestType,
-  MessageStatus,
   Model,
-  ModelInfo,
-  ThreadContent,
-  ThreadMessage,
 } from '../../../types'
 import { events } from '../../events'
 
@@ -53,112 +44,6 @@ export abstract class OAIEngine extends AIEngine {
    */
   override onUnload(): void {}
 
-  /*
-   * Inference request
-   */
-  override async inference(data: MessageRequest) {
-    if (!data.model?.id) {
-      events.emit(MessageEvent.OnMessageResponse, {
-        status: MessageStatus.Error,
-        content: [
-          {
-            type: ContentType.Text,
-            text: {
-              value: 'No model ID provided',
-              annotations: [],
-            },
-          },
-        ],
-      })
-      return
-    }
-
-    const timestamp = Date.now() / 1000
-    const message: ThreadMessage = {
-      id: ulid(),
-      thread_id: data.thread?.id ?? data.threadId,
-      type: data.type,
-      assistant_id: data.assistantId,
-      role: ChatCompletionRole.Assistant,
-      content: [],
-      status: MessageStatus.Pending,
-      created_at: timestamp,
-      completed_at: timestamp,
-      object: 'thread.message',
-    }
-
-    if (data.type !== MessageRequestType.Summary) {
-      events.emit(MessageEvent.OnMessageResponse, message)
-    }
-
-    this.isCancelled = false
-    this.controller = new AbortController()
-
-    const model: ModelInfo = {
-      ...(this.loadedModel ? this.loadedModel : {}),
-      ...data.model,
-    }
-
-    const header = await this.headers()
-    let requestBody = {
-      messages: data.messages ?? [],
-      model: model.id,
-      stream: true,
-      tools: data.tools,
-      ...model.parameters,
-    }
-    if (this.transformPayload) {
-      requestBody = this.transformPayload(requestBody)
-    }
-
-    requestInference(
-      this.inferenceUrl,
-      requestBody,
-      model,
-      this.controller,
-      header,
-      this.transformResponse
-    ).subscribe({
-      next: (content: any) => {
-        const messageContent: ThreadContent = {
-          type: ContentType.Text,
-          text: {
-            value: content.trim(),
-            annotations: [],
-          },
-        }
-        message.content = [messageContent]
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-      complete: async () => {
-        message.status = message.content.length
-          ? MessageStatus.Ready
-          : MessageStatus.Error
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-      error: async (err: any) => {
-        if (this.isCancelled || message.content.length) {
-          message.status = MessageStatus.Stopped
-          events.emit(MessageEvent.OnMessageUpdate, message)
-          return
-        }
-        message.status = MessageStatus.Error
-        message.content[0] = {
-          type: ContentType.Text,
-          text: {
-            value:
-              typeof message === 'string'
-                ? err.message
-                : (JSON.stringify(err.message) ?? err.detail),
-            annotations: [],
-          },
-        }
-        message.error_code = err.code
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-    })
-  }
-
   /**
    * Stops the inference.
    */
@@ -1,146 +0,0 @@
-import { lastValueFrom, Observable } from 'rxjs'
-import { requestInference } from './sse'
-
-import { ReadableStream } from 'stream/web'
-describe('requestInference', () => {
-  it('should send a request to the inference server and return an Observable', () => {
-    // Mock the fetch function
-    const mockFetch: any = jest.fn(() =>
-      Promise.resolve({
-        ok: true,
-        json: () =>
-          Promise.resolve({
-            choices: [{ message: { content: 'Generated response' } }],
-          }),
-        headers: new Headers(),
-        redirected: false,
-        status: 200,
-        statusText: 'OK',
-        // Add other required properties here
-      })
-    )
-    jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
-
-    // Define the test inputs
-    const inferenceUrl = 'https://inference-server.com'
-    const requestBody = { message: 'Hello' }
-    const model = { id: 'model-id', parameters: { stream: false } }
-
-    // Call the function
-    const result = requestInference(inferenceUrl, requestBody, model)
-
-    // Assert the expected behavior
-    expect(result).toBeInstanceOf(Observable)
-    expect(lastValueFrom(result)).resolves.toEqual('Generated response')
-  })
-
-  it('returns 401 error', () => {
-    // Mock the fetch function
-    const mockFetch: any = jest.fn(() =>
-      Promise.resolve({
-        ok: false,
-        json: () =>
-          Promise.resolve({
-            error: { message: 'Invalid API Key.', code: 'invalid_api_key' },
-          }),
-        headers: new Headers(),
-        redirected: false,
-        status: 401,
-        statusText: 'invalid_api_key',
-        // Add other required properties here
-      })
-    )
-    jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
-
-    // Define the test inputs
-    const inferenceUrl = 'https://inference-server.com'
-    const requestBody = { message: 'Hello' }
-    const model = { id: 'model-id', parameters: { stream: false } }
-
-    // Call the function
-    const result = requestInference(inferenceUrl, requestBody, model)
-
-    // Assert the expected behavior
-    expect(result).toBeInstanceOf(Observable)
-    expect(lastValueFrom(result)).rejects.toEqual({
-      message: 'Invalid API Key.',
-      code: 'invalid_api_key',
-    })
-  })
-})
-
-it('should handle a successful response with a transformResponse function', () => {
-  // Mock the fetch function
-  const mockFetch: any = jest.fn(() =>
-    Promise.resolve({
-      ok: true,
-      json: () =>
-        Promise.resolve({
-          choices: [{ message: { content: 'Generated response' } }],
-        }),
-      headers: new Headers(),
-      redirected: false,
-      status: 200,
-      statusText: 'OK',
-    })
-  )
-  jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
-
-  // Define the test inputs
-  const inferenceUrl = 'https://inference-server.com'
-  const requestBody = { message: 'Hello' }
-  const model = { id: 'model-id', parameters: { stream: false } }
-  const transformResponse = (data: any) =>
-    data.choices[0].message.content.toUpperCase()
-
-  // Call the function
-  const result = requestInference(
-    inferenceUrl,
-    requestBody,
-    model,
-    undefined,
-    undefined,
-    transformResponse
-  )
-
-  // Assert the expected behavior
-  expect(result).toBeInstanceOf(Observable)
-  expect(lastValueFrom(result)).resolves.toEqual('GENERATED RESPONSE')
-})
-
-it('should handle a successful response with streaming enabled', () => {
-  // Mock the fetch function
-  const mockFetch: any = jest.fn(() =>
-    Promise.resolve({
-      ok: true,
-      body: new ReadableStream({
-        start(controller) {
-          controller.enqueue(
-            new TextEncoder().encode(
-              'data: {"choices": [{"delta": {"content": "Streamed"}}]}'
-            )
-          )
-          controller.enqueue(new TextEncoder().encode('data: [DONE]'))
-          controller.close()
-        },
-      }),
-      headers: new Headers(),
-      redirected: false,
-      status: 200,
-      statusText: 'OK',
-    })
-  )
-  jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
-
-  // Define the test inputs
-  const inferenceUrl = 'https://inference-server.com'
-  const requestBody = { message: 'Hello' }
-  const model = { id: 'model-id', parameters: { stream: true } }
-
-  // Call the function
-  const result = requestInference(inferenceUrl, requestBody, model)
-
-  // Assert the expected behavior
-  expect(result).toBeInstanceOf(Observable)
-  expect(lastValueFrom(result)).resolves.toEqual('Streamed')
-})
@@ -1,132 +0,0 @@
-import { Observable } from 'rxjs'
-import { ErrorCode, ModelRuntimeParams } from '../../../../types'
-/**
- * Sends a request to the inference server to generate a response based on the recent messages.
- * @param recentMessages - An array of recent messages to use as context for the inference.
- * @returns An Observable that emits the generated response as a string.
- */
-export function requestInference(
-  inferenceUrl: string,
-  requestBody: any,
-  model: {
-    id: string
-    parameters?: ModelRuntimeParams
-  },
-  controller?: AbortController,
-  headers?: HeadersInit,
-  transformResponse?: Function
-): Observable<string> {
-  return new Observable((subscriber) => {
-    fetch(inferenceUrl, {
-      method: 'POST',
-      headers: {
-        'Content-Type': 'application/json',
-        'Access-Control-Allow-Origin': '*',
-        'Accept': model.parameters?.stream
-          ? 'text/event-stream'
-          : 'application/json',
-        ...headers,
-      },
-      body: JSON.stringify(requestBody),
-      signal: controller?.signal,
-    })
-      .then(async (response) => {
-        if (!response.ok) {
-          if (response.status === 401) {
-            throw {
-              code: ErrorCode.InvalidApiKey,
-              message: 'Invalid API Key.',
-            }
-          }
-          let data = await response.json()
-          try {
-            handleError(data)
-          } catch (err) {
-            subscriber.error(err)
-            return
-          }
-        }
-        // There could be overriden stream parameter in the model
-        // that is set in request body (transformed payload)
-        if (
-          requestBody?.stream === false ||
-          model.parameters?.stream === false
-        ) {
-          const data = await response.json()
-          try {
-            handleError(data)
-          } catch (err) {
-            subscriber.error(err)
-            return
-          }
-          if (transformResponse) {
-            subscriber.next(transformResponse(data))
-          } else {
-            subscriber.next(
-              data.choices
-                ? data.choices[0]?.message?.content
-                : (data.content[0]?.text ?? '')
-            )
-          }
-        } else {
-          const stream = response.body
-          const decoder = new TextDecoder('utf-8')
-          const reader = stream?.getReader()
-          let content = ''
-
-          while (true && reader) {
-            const { done, value } = await reader.read()
-            if (done) {
-              break
-            }
-            const text = decoder.decode(value)
-            const lines = text.trim().split('\n')
-            let cachedLines = ''
-            for (const line of lines) {
-              try {
-                if (transformResponse) {
-                  content += transformResponse(line)
-                  subscriber.next(content ?? '')
-                } else {
-                  const toParse = cachedLines + line
-                  if (!line.includes('data: [DONE]')) {
-                    const data = JSON.parse(toParse.replace('data: ', ''))
-                    try {
-                      handleError(data)
-                    } catch (err) {
-                      subscriber.error(err)
-                      return
-                    }
-                    content += data.choices[0]?.delta?.content ?? ''
-                    if (content.startsWith('assistant: ')) {
-                      content = content.replace('assistant: ', '')
-                    }
-                    if (content !== '') subscriber.next(content)
-                  }
-                }
-              } catch {
-                cachedLines = line
-              }
-            }
-          }
-        }
-        subscriber.complete()
-      })
-      .catch((err) => subscriber.error(err))
-  })
-}
-
-/**
- * Handle error and normalize it to a common format.
- * @param data
- */
-const handleError = (data: any) => {
-  if (
-    data.error ||
-    data.message ||
-    data.detail ||
-    (Array.isArray(data) && data.length && data[0].error)
-  ) {
-    throw data.error ?? data[0]?.error ?? data
-  }
-}
@@ -157,10 +157,13 @@ export function useActiveModel() {
       stopModel()
       return
     }
-    if (!activeModel) return
+    // if (!activeModel) return
 
-    const engine = EngineManager.instance().get(InferenceEngine.cortex)
-    engine?.stopInference()
+    // const engine = EngineManager.instance().get(InferenceEngine.cortex)
+    // engine?.stopInference()
+    // NOTE: this only works correctly if there is only 1 concurrent request
+    // at any point in time, which is a reasonable assumption to have.
+    EngineManager.instance().controller?.abort()
   }, [activeModel, stateModel, stopModel])
 
   return { activeModel, startModel, stopModel, stopInference, stateModel }
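The abort reaches the sender as an exception; the sketch below shows the standard fetch/AbortController behaviour the new stopInference path relies on (it is not code from this PR, and the URL is a placeholder).

```ts
// Aborting the shared controller rejects the in-flight request with a
// DOMException named 'AbortError', which lands in the sender's catch block
// (in useSendChatMessage that catch block resets the generating state).
async function runWithAbortSketch(signal: AbortSignal): Promise<void> {
  try {
    await fetch('https://example.invalid/v1/chat/completions', { signal })
  } catch (err) {
    if (err instanceof Error && err.name === 'AbortError') {
      return // in this sketch, treated as "stopped" rather than a failure
    }
    throw err
  }
}

const controller = new AbortController()
void runWithAbortSketch(controller.signal)
controller.abort() // what EngineManager.instance().controller?.abort() triggers
```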
@@ -12,11 +12,9 @@ import {
   ThreadAssistantInfo,
   events,
   MessageEvent,
-  ContentType,
   EngineManager,
   InferenceEngine,
   MessageStatus,
-  ChatCompletionRole,
 } from '@janhq/core'
 import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
 import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
@@ -26,12 +24,10 @@ import {
   ChatCompletionTool,
   ChatCompletionMessageToolCall,
 } from 'openai/resources/chat'
 
 import {
   CompletionResponse,
   StreamCompletionResponse,
   TokenJS,
-  models,
 } from 'token.js'
 import { ulid } from 'ulidx'
 
@@ -208,16 +204,8 @@ export default function useSendChatMessage(
     }
 
     // Build Message Request
-    const requestBuilder = new MessageRequestBuilder(
-      MessageRequestType.Thread,
-      {
-        ...modelRequest,
-        settings: settingParams,
-        parameters: runtimeParams,
-      },
-      activeThread,
-      messages ?? currentMessages,
-      (await window.core.api.getTools())
+    // TODO: detect if model supports tools
+    const tools = (await window.core.api.getTools())
       ?.filter((tool: ModelTool) => !disabledTools.includes(tool.name))
       .map((tool: ModelTool) => ({
         type: 'function' as const,
@@ -228,6 +216,16 @@ export default function useSendChatMessage(
           strict: false,
         },
       }))
+    const requestBuilder = new MessageRequestBuilder(
+      MessageRequestType.Thread,
+      {
+        ...modelRequest,
+        settings: settingParams,
+        parameters: runtimeParams,
+      },
+      activeThread,
+      messages ?? currentMessages,
+      (tools && tools.length) ? tools : undefined,
     ).addSystemMessage(activeAssistant.instructions)
 
     requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
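For reference, the tool mapping above produces entries in the standard OpenAI function-tool shape (the middle of the `.map(...)` body falls outside the hunks shown). The concrete tool below is a hypothetical example, not one shipped by Jan.

```ts
import type { ChatCompletionTool } from 'openai/resources/chat'

// Hypothetical tool in the shape built from ModelTool above.
const exampleTool: ChatCompletionTool = {
  type: 'function',
  function: {
    name: 'get_weather',
    description: 'Return the current weather for a city',
    parameters: {
      type: 'object',
      properties: { city: { type: 'string' } },
      required: ['city'],
    },
    strict: false,
  },
}
```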
@@ -267,13 +265,15 @@ export default function useSendChatMessage(
     }
 
     // Start Model if not started
+    const isCortex = modelRequest.engine == InferenceEngine.cortex ||
+      modelRequest.engine == InferenceEngine.cortex_llamacpp
     const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id
 
     if (base64Blob) {
       setFileUpload(undefined)
     }
 
-    if (modelRef.current?.id !== modelId && modelId) {
+    if (modelRef.current?.id !== modelId && modelId && isCortex) {
       const error = await startModel(modelId).catch((error: Error) => error)
       if (error) {
         updateThreadWaiting(activeThread.id, false)
@@ -282,7 +282,6 @@ export default function useSendChatMessage(
     }
     setIsGeneratingResponse(true)
 
-    if (requestBuilder.tools && requestBuilder.tools.length) {
       let isDone = false
 
       const engine =
@@ -297,6 +296,11 @@ export default function useSendChatMessage(
 
       extendBuiltInEngineModels(tokenJS, provider, modelId)
 
+      // llama.cpp currently does not support streaming when tools are used.
+      const useStream = (requestBuilder.tools && isCortex) ?
+        false :
+        modelRequest.parameters?.stream
+
       let parentMessageId: string | undefined
       while (!isDone) {
         let messageId = ulid()
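In plain terms, the useStream flag above disables streaming only when tools are present and the request is served by llama.cpp (cortex). A stripped-down restatement follows, with hasTools, isCortex, and streamParam standing in for requestBuilder.tools, the engine check, and modelRequest.parameters?.stream; the `?? false` default is an assumption for the sketch, not taken from the diff.

```ts
// Minimal restatement of the useStream decision above.
function shouldStream(
  hasTools: boolean,
  isCortex: boolean,
  streamParam: boolean | undefined
): boolean {
  if (hasTools && isCortex) return false // llama.cpp: no streaming with tools yet
  return streamParam ?? false
}

// e.g. shouldStream(true, true, true) === false; shouldStream(true, false, true) === true
```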
@@ -316,40 +320,47 @@ export default function useSendChatMessage(
           },
         })
         events.emit(MessageEvent.OnMessageResponse, message)
-        // Variables to track and accumulate streaming content
 
-        if (
-          data.model?.parameters?.stream &&
-          data.model?.engine !== InferenceEngine.cortex &&
-          data.model?.engine !== InferenceEngine.cortex_llamacpp
-        ) {
-          const response = await tokenJS.chat.completions.create({
+        // we need to separate into 2 cases to appease linter
+        const controller = new AbortController()
+        EngineManager.instance().controller = controller
+        if (useStream) {
+          const response = await tokenJS.chat.completions.create(
+            {
               stream: true,
               provider,
               messages: requestBuilder.messages as ChatCompletionMessageParam[],
               model: data.model?.id ?? '',
               tools: data.tools as ChatCompletionTool[],
-              tool_choice: 'auto',
-          })
+              tool_choice: data.tools ? 'auto' : undefined,
+            },
+            {
+              signal: controller.signal,
+            }
+          )
+          // Variables to track and accumulate streaming content
           if (!message.content.length) {
             message.content = emptyMessageContent
           }
 
           isDone = await processStreamingResponse(
             response,
             requestBuilder,
             message
           )
         } else {
-          const response = await tokenJS.chat.completions.create({
+          const response = await tokenJS.chat.completions.create(
+            {
               stream: false,
               provider,
               messages: requestBuilder.messages as ChatCompletionMessageParam[],
               model: data.model?.id ?? '',
               tools: data.tools as ChatCompletionTool[],
-              tool_choice: 'auto',
-          })
+              tool_choice: data.tools ? 'auto' : undefined,
+            },
+            {
+              signal: controller.signal,
+            }
+          )
           // Variables to track and accumulate streaming content
           if (!message.content.length) {
             message.content = emptyMessageContent
@@ -360,15 +371,10 @@ export default function useSendChatMessage(
             message
           )
         }
 
         message.status = MessageStatus.Ready
         events.emit(MessageEvent.OnMessageUpdate, message)
       }
-    } else {
-      // Request for inference
-      EngineManager.instance()
-        .get(InferenceEngine.cortex)
-        ?.inference(requestBuilder.build())
-    }
     } catch (error) {
       setIsGeneratingResponse(false)
       updateThreadWaiting(activeThread.id, false)
@@ -65,7 +65,7 @@
     "swr": "^2.2.5",
     "tailwind-merge": "^2.0.0",
     "tailwindcss": "3.4.17",
-    "token.js": "npm:token.js-fork@0.7.2",
+    "token.js": "npm:token.js-fork@0.7.6",
     "ulidx": "^2.3.0",
     "use-debounce": "^10.0.0",
     "uuid": "^9.0.1",