Use token.js for non-tools calls (#4973)

* deprecate inference()

* fix tool_choice. only startModel for Cortex

* appease linter

* remove sse

* add stopInferencing support. temporarily with OpenAI

* use abortSignal in token.js

* bump token.js version
Thien Tran 2025-05-14 17:26:42 +08:00 committed by Louis
parent dbe7ef65e2
commit dc23cc2716
8 changed files with 98 additions and 558 deletions
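The heart of this change is the new cancellation path: each request creates an AbortController, parks it on EngineManager, and hands its signal to token.js instead of routing through the removed SSE helper. Below is a minimal sketch of that flow, assuming the token.js-fork options argument that appears in the diff, with a simplified stand-in for @janhq/core's EngineManager and a hypothetical sendMessage() wrapper (the model id and provider are just examples, not the app's wiring).

```ts
import { TokenJS } from 'token.js'

// Simplified stand-in for @janhq/core's EngineManager: it now carries a single
// shared AbortController for the request currently in flight.
class EngineManager {
  private static current: EngineManager | null = null
  public controller: AbortController | null = null
  static instance(): EngineManager {
    return (this.current ??= new EngineManager())
  }
}

// Hypothetical wrapper; provider and model id are illustrative only.
async function sendMessage(apiKey: string, prompt: string) {
  const tokenJS = new TokenJS({ apiKey })

  // One controller per request, shared through EngineManager so the UI can
  // cancel without knowing which provider is serving the call.
  const controller = new AbortController()
  EngineManager.instance().controller = controller

  return tokenJS.chat.completions.create(
    {
      stream: false,
      provider: 'openai',
      model: 'gpt-4o',
      messages: [{ role: 'user', content: prompt }],
    },
    // Second argument carrying an AbortSignal, as used in the diff below; the
    // forked token.js forwards it to fetch so abort() rejects the pending call.
    { signal: controller.signal }
  )
}

// stopInference no longer asks the engine to stop an SSE stream; it simply
// aborts the shared controller (valid while only one request runs at a time).
function stopInference() {
  EngineManager.instance().controller?.abort()
}
```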

View File

@@ -6,6 +6,7 @@ import { AIEngine } from './AIEngine'
*/
export class EngineManager {
public engines = new Map<string, AIEngine>()
public controller: AbortController | null = null
/**
* Registers an engine.

View File

@@ -12,11 +12,7 @@ import {
ChatCompletionRole,
ContentType,
} from '../../../types'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulidx'
jest.mock('./helpers/sse')
jest.mock('ulidx')
jest.mock('../../events')
class TestOAIEngine extends OAIEngine {
@@ -48,79 +44,6 @@ describe('OAIEngine', () => {
)
})
it('should handle inference request', async () => {
const data: MessageRequest = {
model: { engine: 'test-provider', id: 'test-model' } as any,
threadId: 'test-thread',
type: MessageRequestType.Thread,
assistantId: 'test-assistant',
messages: [{ role: ChatCompletionRole.User, content: 'Hello' }],
}
;(ulid as jest.Mock).mockReturnValue('test-id')
;(requestInference as jest.Mock).mockReturnValue({
subscribe: ({ next, complete }: any) => {
next('test response')
complete()
},
})
await engine.inference(data)
expect(requestInference).toHaveBeenCalledWith(
'http://test-inference-url',
expect.objectContaining({ model: 'test-model' }),
expect.any(Object),
expect.any(AbortController),
{ Authorization: 'Bearer test-token' },
undefined
)
expect(events.emit).toHaveBeenCalledWith(
MessageEvent.OnMessageResponse,
expect.objectContaining({ id: 'test-id' })
)
expect(events.emit).toHaveBeenCalledWith(
MessageEvent.OnMessageUpdate,
expect.objectContaining({
content: [
{
type: ContentType.Text,
text: { value: 'test response', annotations: [] },
},
],
status: MessageStatus.Ready,
})
)
})
it('should handle inference error', async () => {
const data: MessageRequest = {
model: { engine: 'test-provider', id: 'test-model' } as any,
threadId: 'test-thread',
type: MessageRequestType.Thread,
assistantId: 'test-assistant',
messages: [{ role: ChatCompletionRole.User, content: 'Hello' }],
}
;(ulid as jest.Mock).mockReturnValue('test-id')
;(requestInference as jest.Mock).mockReturnValue({
subscribe: ({ error }: any) => {
error({ message: 'test error', code: 500 })
},
})
await engine.inference(data)
expect(events.emit).toHaveBeenLastCalledWith(
MessageEvent.OnMessageUpdate,
expect.objectContaining({
status: 'error',
error_code: 500,
})
)
})
it('should stop inference', () => {
engine.stopInference()
expect(engine.isCancelled).toBe(true)

View File

@@ -1,18 +1,9 @@
import { requestInference } from './helpers/sse'
import { ulid } from 'ulidx'
import { AIEngine } from './AIEngine'
import {
ChatCompletionRole,
ContentType,
InferenceEvent,
MessageEvent,
MessageRequest,
MessageRequestType,
MessageStatus,
Model,
ModelInfo,
ThreadContent,
ThreadMessage,
} from '../../../types'
import { events } from '../../events'
@@ -53,112 +44,6 @@ export abstract class OAIEngine extends AIEngine {
*/
override onUnload(): void {}
/*
* Inference request
*/
override async inference(data: MessageRequest) {
if (!data.model?.id) {
events.emit(MessageEvent.OnMessageResponse, {
status: MessageStatus.Error,
content: [
{
type: ContentType.Text,
text: {
value: 'No model ID provided',
annotations: [],
},
},
],
})
return
}
const timestamp = Date.now() / 1000
const message: ThreadMessage = {
id: ulid(),
thread_id: data.thread?.id ?? data.threadId,
type: data.type,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created_at: timestamp,
completed_at: timestamp,
object: 'thread.message',
}
if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message)
}
this.isCancelled = false
this.controller = new AbortController()
const model: ModelInfo = {
...(this.loadedModel ? this.loadedModel : {}),
...data.model,
}
const header = await this.headers()
let requestBody = {
messages: data.messages ?? [],
model: model.id,
stream: true,
tools: data.tools,
...model.parameters,
}
if (this.transformPayload) {
requestBody = this.transformPayload(requestBody)
}
requestInference(
this.inferenceUrl,
requestBody,
model,
this.controller,
header,
this.transformResponse
).subscribe({
next: (content: any) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
}
message.content = [messageContent]
events.emit(MessageEvent.OnMessageUpdate, message)
},
complete: async () => {
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
},
error: async (err: any) => {
if (this.isCancelled || message.content.length) {
message.status = MessageStatus.Stopped
events.emit(MessageEvent.OnMessageUpdate, message)
return
}
message.status = MessageStatus.Error
message.content[0] = {
type: ContentType.Text,
text: {
value:
typeof message === 'string'
? err.message
: (JSON.stringify(err.message) ?? err.detail),
annotations: [],
},
}
message.error_code = err.code
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
}
/**
* Stops the inference.
*/

View File

@@ -1,146 +0,0 @@
import { lastValueFrom, Observable } from 'rxjs'
import { requestInference } from './sse'
import { ReadableStream } from 'stream/web'
describe('requestInference', () => {
it('should send a request to the inference server and return an Observable', () => {
// Mock the fetch function
const mockFetch: any = jest.fn(() =>
Promise.resolve({
ok: true,
json: () =>
Promise.resolve({
choices: [{ message: { content: 'Generated response' } }],
}),
headers: new Headers(),
redirected: false,
status: 200,
statusText: 'OK',
// Add other required properties here
})
)
jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
// Define the test inputs
const inferenceUrl = 'https://inference-server.com'
const requestBody = { message: 'Hello' }
const model = { id: 'model-id', parameters: { stream: false } }
// Call the function
const result = requestInference(inferenceUrl, requestBody, model)
// Assert the expected behavior
expect(result).toBeInstanceOf(Observable)
expect(lastValueFrom(result)).resolves.toEqual('Generated response')
})
it('returns 401 error', () => {
// Mock the fetch function
const mockFetch: any = jest.fn(() =>
Promise.resolve({
ok: false,
json: () =>
Promise.resolve({
error: { message: 'Invalid API Key.', code: 'invalid_api_key' },
}),
headers: new Headers(),
redirected: false,
status: 401,
statusText: 'invalid_api_key',
// Add other required properties here
})
)
jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
// Define the test inputs
const inferenceUrl = 'https://inference-server.com'
const requestBody = { message: 'Hello' }
const model = { id: 'model-id', parameters: { stream: false } }
// Call the function
const result = requestInference(inferenceUrl, requestBody, model)
// Assert the expected behavior
expect(result).toBeInstanceOf(Observable)
expect(lastValueFrom(result)).rejects.toEqual({
message: 'Invalid API Key.',
code: 'invalid_api_key',
})
})
})
it('should handle a successful response with a transformResponse function', () => {
// Mock the fetch function
const mockFetch: any = jest.fn(() =>
Promise.resolve({
ok: true,
json: () =>
Promise.resolve({
choices: [{ message: { content: 'Generated response' } }],
}),
headers: new Headers(),
redirected: false,
status: 200,
statusText: 'OK',
})
)
jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
// Define the test inputs
const inferenceUrl = 'https://inference-server.com'
const requestBody = { message: 'Hello' }
const model = { id: 'model-id', parameters: { stream: false } }
const transformResponse = (data: any) =>
data.choices[0].message.content.toUpperCase()
// Call the function
const result = requestInference(
inferenceUrl,
requestBody,
model,
undefined,
undefined,
transformResponse
)
// Assert the expected behavior
expect(result).toBeInstanceOf(Observable)
expect(lastValueFrom(result)).resolves.toEqual('GENERATED RESPONSE')
})
it('should handle a successful response with streaming enabled', () => {
// Mock the fetch function
const mockFetch: any = jest.fn(() =>
Promise.resolve({
ok: true,
body: new ReadableStream({
start(controller) {
controller.enqueue(
new TextEncoder().encode(
'data: {"choices": [{"delta": {"content": "Streamed"}}]}'
)
)
controller.enqueue(new TextEncoder().encode('data: [DONE]'))
controller.close()
},
}),
headers: new Headers(),
redirected: false,
status: 200,
statusText: 'OK',
})
)
jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
// Define the test inputs
const inferenceUrl = 'https://inference-server.com'
const requestBody = { message: 'Hello' }
const model = { id: 'model-id', parameters: { stream: true } }
// Call the function
const result = requestInference(inferenceUrl, requestBody, model)
// Assert the expected behavior
expect(result).toBeInstanceOf(Observable)
expect(lastValueFrom(result)).resolves.toEqual('Streamed')
})

View File

@@ -1,132 +0,0 @@
import { Observable } from 'rxjs'
import { ErrorCode, ModelRuntimeParams } from '../../../../types'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
inferenceUrl: string,
requestBody: any,
model: {
id: string
parameters?: ModelRuntimeParams
},
controller?: AbortController,
headers?: HeadersInit,
transformResponse?: Function
): Observable<string> {
return new Observable((subscriber) => {
fetch(inferenceUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Access-Control-Allow-Origin': '*',
'Accept': model.parameters?.stream
? 'text/event-stream'
: 'application/json',
...headers,
},
body: JSON.stringify(requestBody),
signal: controller?.signal,
})
.then(async (response) => {
if (!response.ok) {
if (response.status === 401) {
throw {
code: ErrorCode.InvalidApiKey,
message: 'Invalid API Key.',
}
}
let data = await response.json()
try {
handleError(data)
} catch (err) {
subscriber.error(err)
return
}
}
// There could be overriden stream parameter in the model
// that is set in request body (transformed payload)
if (
requestBody?.stream === false ||
model.parameters?.stream === false
) {
const data = await response.json()
try {
handleError(data)
} catch (err) {
subscriber.error(err)
return
}
if (transformResponse) {
subscriber.next(transformResponse(data))
} else {
subscriber.next(
data.choices
? data.choices[0]?.message?.content
: (data.content[0]?.text ?? '')
)
}
} else {
const stream = response.body
const decoder = new TextDecoder('utf-8')
const reader = stream?.getReader()
let content = ''
while (true && reader) {
const { done, value } = await reader.read()
if (done) {
break
}
const text = decoder.decode(value)
const lines = text.trim().split('\n')
let cachedLines = ''
for (const line of lines) {
try {
if (transformResponse) {
content += transformResponse(line)
subscriber.next(content ?? '')
} else {
const toParse = cachedLines + line
if (!line.includes('data: [DONE]')) {
const data = JSON.parse(toParse.replace('data: ', ''))
try {
handleError(data)
} catch (err) {
subscriber.error(err)
return
}
content += data.choices[0]?.delta?.content ?? ''
if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '')
}
if (content !== '') subscriber.next(content)
}
}
} catch {
cachedLines = line
}
}
}
}
subscriber.complete()
})
.catch((err) => subscriber.error(err))
})
}
/**
* Handle error and normalize it to a common format.
* @param data
*/
const handleError = (data: any) => {
if (
data.error ||
data.message ||
data.detail ||
(Array.isArray(data) && data.length && data[0].error)
) {
throw data.error ?? data[0]?.error ?? data
}
}

View File

@@ -157,10 +157,13 @@ export function useActiveModel() {
stopModel()
return
}
if (!activeModel) return
// if (!activeModel) return
const engine = EngineManager.instance().get(InferenceEngine.cortex)
engine?.stopInference()
// const engine = EngineManager.instance().get(InferenceEngine.cortex)
// engine?.stopInference()
// NOTE: this only works correctly if there is only 1 concurrent request
// at any point in time, which is a reasonable assumption to have.
EngineManager.instance().controller?.abort()
}, [activeModel, stateModel, stopModel])
return { activeModel, startModel, stopModel, stopInference, stateModel }

View File

@@ -12,11 +12,9 @@ import {
ThreadAssistantInfo,
events,
MessageEvent,
ContentType,
EngineManager,
InferenceEngine,
MessageStatus,
ChatCompletionRole,
} from '@janhq/core'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
@@ -26,12 +24,10 @@ import {
ChatCompletionTool,
ChatCompletionMessageToolCall,
} from 'openai/resources/chat'
import {
CompletionResponse,
StreamCompletionResponse,
TokenJS,
models,
} from 'token.js'
import { ulid } from 'ulidx'
@@ -208,6 +204,18 @@ export default function useSendChatMessage(
}
// Build Message Request
// TODO: detect if model supports tools
const tools = (await window.core.api.getTools())
?.filter((tool: ModelTool) => !disabledTools.includes(tool.name))
.map((tool: ModelTool) => ({
type: 'function' as const,
function: {
name: tool.name,
description: tool.description?.slice(0, 1024),
parameters: tool.inputSchema,
strict: false,
},
}))
const requestBuilder = new MessageRequestBuilder(
MessageRequestType.Thread,
{
@@ -217,17 +225,7 @@
},
activeThread,
messages ?? currentMessages,
(await window.core.api.getTools())
?.filter((tool: ModelTool) => !disabledTools.includes(tool.name))
.map((tool: ModelTool) => ({
type: 'function' as const,
function: {
name: tool.name,
description: tool.description?.slice(0, 1024),
parameters: tool.inputSchema,
strict: false,
},
}))
(tools && tools.length) ? tools : undefined,
).addSystemMessage(activeAssistant.instructions)
requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
@@ -267,13 +265,15 @@
}
// Start Model if not started
const isCortex = modelRequest.engine == InferenceEngine.cortex ||
modelRequest.engine == InferenceEngine.cortex_llamacpp
const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id
if (base64Blob) {
setFileUpload(undefined)
}
if (modelRef.current?.id !== modelId && modelId) {
if (modelRef.current?.id !== modelId && modelId && isCortex) {
const error = await startModel(modelId).catch((error: Error) => error)
if (error) {
updateThreadWaiting(activeThread.id, false)
@@ -282,92 +282,98 @@
}
setIsGeneratingResponse(true)
if (requestBuilder.tools && requestBuilder.tools.length) {
let isDone = false
let isDone = false
const engine =
engines?.[requestBuilder.model.engine as InferenceEngine]?.[0]
const apiKey = engine?.api_key
const provider = convertBuiltInEngine(engine?.engine)
const engine =
engines?.[requestBuilder.model.engine as InferenceEngine]?.[0]
const apiKey = engine?.api_key
const provider = convertBuiltInEngine(engine?.engine)
const tokenJS = new TokenJS({
apiKey: apiKey ?? (await window.core.api.appToken()),
baseURL: apiKey ? undefined : `${API_BASE_URL}/v1`,
const tokenJS = new TokenJS({
apiKey: apiKey ?? (await window.core.api.appToken()),
baseURL: apiKey ? undefined : `${API_BASE_URL}/v1`,
})
extendBuiltInEngineModels(tokenJS, provider, modelId)
// llama.cpp currently does not support streaming when tools are used.
const useStream = (requestBuilder.tools && isCortex) ?
false :
modelRequest.parameters?.stream
let parentMessageId: string | undefined
while (!isDone) {
let messageId = ulid()
if (!parentMessageId) {
parentMessageId = ulid()
messageId = parentMessageId
}
const data = requestBuilder.build()
const message: ThreadMessage = createMessage({
id: messageId,
thread_id: activeThread.id,
assistant_id: activeAssistant.assistant_id,
metadata: {
...(messageId !== parentMessageId
? { parent_id: parentMessageId }
: {}),
},
})
events.emit(MessageEvent.OnMessageResponse, message)
extendBuiltInEngineModels(tokenJS, provider, modelId)
let parentMessageId: string | undefined
while (!isDone) {
let messageId = ulid()
if (!parentMessageId) {
parentMessageId = ulid()
messageId = parentMessageId
}
const data = requestBuilder.build()
const message: ThreadMessage = createMessage({
id: messageId,
thread_id: activeThread.id,
assistant_id: activeAssistant.assistant_id,
metadata: {
...(messageId !== parentMessageId
? { parent_id: parentMessageId }
: {}),
},
})
events.emit(MessageEvent.OnMessageResponse, message)
// Variables to track and accumulate streaming content
if (
data.model?.parameters?.stream &&
data.model?.engine !== InferenceEngine.cortex &&
data.model?.engine !== InferenceEngine.cortex_llamacpp
) {
const response = await tokenJS.chat.completions.create({
// we need to separate into 2 cases to appease linter
const controller = new AbortController()
EngineManager.instance().controller = controller
if (useStream) {
const response = await tokenJS.chat.completions.create(
{
stream: true,
provider,
messages: requestBuilder.messages as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
tool_choice: 'auto',
})
if (!message.content.length) {
message.content = emptyMessageContent
tool_choice: data.tools ? 'auto' : undefined,
},
{
signal: controller.signal,
}
isDone = await processStreamingResponse(
response,
requestBuilder,
message
)
} else {
const response = await tokenJS.chat.completions.create({
)
// Variables to track and accumulate streaming content
if (!message.content.length) {
message.content = emptyMessageContent
}
isDone = await processStreamingResponse(
response,
requestBuilder,
message
)
} else {
const response = await tokenJS.chat.completions.create(
{
stream: false,
provider,
messages: requestBuilder.messages as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
tool_choice: 'auto',
})
// Variables to track and accumulate streaming content
if (!message.content.length) {
message.content = emptyMessageContent
tool_choice: data.tools ? 'auto' : undefined,
},
{
signal: controller.signal,
}
isDone = await processNonStreamingResponse(
response,
requestBuilder,
message
)
)
// Variables to track and accumulate streaming content
if (!message.content.length) {
message.content = emptyMessageContent
}
message.status = MessageStatus.Ready
events.emit(MessageEvent.OnMessageUpdate, message)
isDone = await processNonStreamingResponse(
response,
requestBuilder,
message
)
}
} else {
// Request for inference
EngineManager.instance()
.get(InferenceEngine.cortex)
?.inference(requestBuilder.build())
message.status = MessageStatus.Ready
events.emit(MessageEvent.OnMessageUpdate, message)
}
} catch (error) {
setIsGeneratingResponse(false)

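For reference, the stream-selection rule introduced in the hunk above can be condensed into a small hypothetical helper (shouldStream does not exist in the codebase), under the diff's stated assumption that llama.cpp cannot stream while tools are attached.

```ts
// Hypothetical condensation of the useStream / isCortex logic above; the
// engine names mirror InferenceEngine.cortex and InferenceEngine.cortex_llamacpp.
function shouldStream(
  engine: string,
  hasTools: boolean,
  requestedStream?: boolean
): boolean {
  const isCortex = engine === 'cortex' || engine === 'cortex_llamacpp'
  // llama.cpp currently cannot stream when tools are attached, so tool calls
  // on Cortex fall back to a non-streaming completion.
  if (hasTools && isCortex) return false
  return requestedStream ?? false
}
```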
View File

@@ -65,7 +65,7 @@
"swr": "^2.2.5",
"tailwind-merge": "^2.0.0",
"tailwindcss": "3.4.17",
"token.js": "npm:token.js-fork@0.7.2",
"token.js": "npm:token.js-fork@0.7.6",
"ulidx": "^2.3.0",
"use-debounce": "^10.0.0",
"uuid": "^9.0.1",