diff --git a/core/src/browser/extensions/engines/EngineManager.ts b/core/src/browser/extensions/engines/EngineManager.ts
index 90ce75ac5..7bf7a9924 100644
--- a/core/src/browser/extensions/engines/EngineManager.ts
+++ b/core/src/browser/extensions/engines/EngineManager.ts
@@ -6,6 +6,7 @@ import { AIEngine } from './AIEngine'
  */
 export class EngineManager {
   public engines = new Map<string, AIEngine>()
+  public controller: AbortController | null = null
 
   /**
    * Registers an engine.
diff --git a/core/src/browser/extensions/engines/OAIEngine.test.ts b/core/src/browser/extensions/engines/OAIEngine.test.ts
index 66537d0be..0e985fd1b 100644
--- a/core/src/browser/extensions/engines/OAIEngine.test.ts
+++ b/core/src/browser/extensions/engines/OAIEngine.test.ts
@@ -12,11 +12,7 @@ import {
   ChatCompletionRole,
   ContentType,
 } from '../../../types'
-import { requestInference } from './helpers/sse'
-import { ulid } from 'ulidx'
 
-jest.mock('./helpers/sse')
-jest.mock('ulidx')
 jest.mock('../../events')
 
 class TestOAIEngine extends OAIEngine {
@@ -48,79 +44,6 @@ describe('OAIEngine', () => {
     )
   })
 
-  it('should handle inference request', async () => {
-    const data: MessageRequest = {
-      model: { engine: 'test-provider', id: 'test-model' } as any,
-      threadId: 'test-thread',
-      type: MessageRequestType.Thread,
-      assistantId: 'test-assistant',
-      messages: [{ role: ChatCompletionRole.User, content: 'Hello' }],
-    }
-
-    ;(ulid as jest.Mock).mockReturnValue('test-id')
-    ;(requestInference as jest.Mock).mockReturnValue({
-      subscribe: ({ next, complete }: any) => {
-        next('test response')
-        complete()
-      },
-    })
-
-    await engine.inference(data)
-
-    expect(requestInference).toHaveBeenCalledWith(
-      'http://test-inference-url',
-      expect.objectContaining({ model: 'test-model' }),
-      expect.any(Object),
-      expect.any(AbortController),
-      { Authorization: 'Bearer test-token' },
-      undefined
-    )
-
-    expect(events.emit).toHaveBeenCalledWith(
-      MessageEvent.OnMessageResponse,
-      expect.objectContaining({ id: 'test-id' })
-    )
-    expect(events.emit).toHaveBeenCalledWith(
-      MessageEvent.OnMessageUpdate,
-      expect.objectContaining({
-        content: [
-          {
-            type: ContentType.Text,
-            text: { value: 'test response', annotations: [] },
-          },
-        ],
-        status: MessageStatus.Ready,
-      })
-    )
-  })
-
-  it('should handle inference error', async () => {
-    const data: MessageRequest = {
-      model: { engine: 'test-provider', id: 'test-model' } as any,
-      threadId: 'test-thread',
-      type: MessageRequestType.Thread,
-      assistantId: 'test-assistant',
-      messages: [{ role: ChatCompletionRole.User, content: 'Hello' }],
-    }
-
-    ;(ulid as jest.Mock).mockReturnValue('test-id')
-    ;(requestInference as jest.Mock).mockReturnValue({
-      subscribe: ({ error }: any) => {
-        error({ message: 'test error', code: 500 })
-      },
-    })
-
-    await engine.inference(data)
-
-    expect(events.emit).toHaveBeenLastCalledWith(
-      MessageEvent.OnMessageUpdate,
-      expect.objectContaining({
-        status: 'error',
-        error_code: 500,
-      })
-    )
-  })
-
   it('should stop inference', () => {
     engine.stopInference()
     expect(engine.isCancelled).toBe(true)
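The new `controller` field makes EngineManager the single owner of request cancellation. A minimal standalone sketch of the pattern, with illustrative names (`Registry`, `startRequest`, and `stopRequest` are not part of the codebase):

```ts
// Sketch: one shared AbortController hung off a singleton registry.
// Whoever starts a request installs a fresh controller; anyone can cancel.
class Registry {
  private static _instance: Registry | null = null
  static instance(): Registry {
    return (this._instance ??= new Registry())
  }
  controller: AbortController | null = null
}

async function startRequest(url: string, body: unknown): Promise<Response> {
  const controller = new AbortController()
  Registry.instance().controller = controller // replace any previous controller
  return fetch(url, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(body),
    signal: controller.signal, // abort() rejects this fetch with an AbortError
  })
}

function stopRequest(): void {
  // Mirrors useActiveModel.stopInference() further down in this patch.
  Registry.instance().controller?.abort()
}
```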
diff --git a/core/src/browser/extensions/engines/OAIEngine.ts b/core/src/browser/extensions/engines/OAIEngine.ts
index 50940a97b..3502aa1f7 100644
--- a/core/src/browser/extensions/engines/OAIEngine.ts
+++ b/core/src/browser/extensions/engines/OAIEngine.ts
@@ -1,18 +1,9 @@
-import { requestInference } from './helpers/sse'
-import { ulid } from 'ulidx'
 import { AIEngine } from './AIEngine'
 import {
-  ChatCompletionRole,
-  ContentType,
   InferenceEvent,
   MessageEvent,
   MessageRequest,
-  MessageRequestType,
-  MessageStatus,
   Model,
-  ModelInfo,
-  ThreadContent,
-  ThreadMessage,
 } from '../../../types'
 import { events } from '../../events'
 
@@ -53,112 +44,6 @@ export abstract class OAIEngine extends AIEngine {
    */
   override onUnload(): void {}
 
-  /*
-   * Inference request
-   */
-  override async inference(data: MessageRequest) {
-    if (!data.model?.id) {
-      events.emit(MessageEvent.OnMessageResponse, {
-        status: MessageStatus.Error,
-        content: [
-          {
-            type: ContentType.Text,
-            text: {
-              value: 'No model ID provided',
-              annotations: [],
-            },
-          },
-        ],
-      })
-      return
-    }
-
-    const timestamp = Date.now() / 1000
-    const message: ThreadMessage = {
-      id: ulid(),
-      thread_id: data.thread?.id ?? data.threadId,
-      type: data.type,
-      assistant_id: data.assistantId,
-      role: ChatCompletionRole.Assistant,
-      content: [],
-      status: MessageStatus.Pending,
-      created_at: timestamp,
-      completed_at: timestamp,
-      object: 'thread.message',
-    }
-
-    if (data.type !== MessageRequestType.Summary) {
-      events.emit(MessageEvent.OnMessageResponse, message)
-    }
-
-    this.isCancelled = false
-    this.controller = new AbortController()
-
-    const model: ModelInfo = {
-      ...(this.loadedModel ? this.loadedModel : {}),
-      ...data.model,
-    }
-
-    const header = await this.headers()
-    let requestBody = {
-      messages: data.messages ?? [],
-      model: model.id,
-      stream: true,
-      tools: data.tools,
-      ...model.parameters,
-    }
-    if (this.transformPayload) {
-      requestBody = this.transformPayload(requestBody)
-    }
-
-    requestInference(
-      this.inferenceUrl,
-      requestBody,
-      model,
-      this.controller,
-      header,
-      this.transformResponse
-    ).subscribe({
-      next: (content: any) => {
-        const messageContent: ThreadContent = {
-          type: ContentType.Text,
-          text: {
-            value: content.trim(),
-            annotations: [],
-          },
-        }
-        message.content = [messageContent]
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-      complete: async () => {
-        message.status = message.content.length
-          ? MessageStatus.Ready
-          : MessageStatus.Error
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-      error: async (err: any) => {
-        if (this.isCancelled || message.content.length) {
-          message.status = MessageStatus.Stopped
-          events.emit(MessageEvent.OnMessageUpdate, message)
-          return
-        }
-        message.status = MessageStatus.Error
-        message.content[0] = {
-          type: ContentType.Text,
-          text: {
-            value:
-              typeof message === 'string'
-                ? err.message
-                : (JSON.stringify(err.message) ?? err.detail),
-            annotations: [],
-          },
-        }
-        message.error_code = err.code
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-    })
-  }
-
   /**
    * Stops the inference.
    */
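With `inference()` removed, OAIEngine no longer drives the message lifecycle; the web layer has to reproduce the same event contract. A condensed sketch of that contract as the deleted method implemented it, where the `chunks` iterable stands in for the old SSE stream and the cast papers over fields the sketch omits:

```ts
import { ulid } from 'ulidx'
import {
  ChatCompletionRole,
  ContentType,
  MessageEvent,
  MessageStatus,
  ThreadMessage,
  events,
} from '@janhq/core'

// Announce one Pending message, mutate it in place while content streams,
// then flip it to Ready (or Error if nothing ever arrived).
async function streamAssistantMessage(
  threadId: string,
  assistantId: string,
  chunks: AsyncIterable<string>
): Promise<void> {
  const timestamp = Date.now() / 1000
  const message = {
    id: ulid(),
    thread_id: threadId,
    assistant_id: assistantId,
    role: ChatCompletionRole.Assistant,
    content: [],
    status: MessageStatus.Pending,
    created_at: timestamp,
    completed_at: timestamp,
    object: 'thread.message',
  } as unknown as ThreadMessage // cast: sketch omits fields like `type`

  events.emit(MessageEvent.OnMessageResponse, message)

  let value = ''
  for await (const chunk of chunks) {
    value += chunk
    message.content = [
      { type: ContentType.Text, text: { value, annotations: [] } },
    ]
    events.emit(MessageEvent.OnMessageUpdate, message)
  }

  message.status = message.content.length
    ? MessageStatus.Ready
    : MessageStatus.Error
  events.emit(MessageEvent.OnMessageUpdate, message)
}
```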
diff --git a/core/src/browser/extensions/engines/helpers/sse.test.ts b/core/src/browser/extensions/engines/helpers/sse.test.ts
deleted file mode 100644
index f8c2ac6b4..000000000
--- a/core/src/browser/extensions/engines/helpers/sse.test.ts
+++ /dev/null
@@ -1,146 +0,0 @@
-import { lastValueFrom, Observable } from 'rxjs'
-import { requestInference } from './sse'
-
-import { ReadableStream } from 'stream/web'
-
-describe('requestInference', () => {
-  it('should send a request to the inference server and return an Observable', () => {
-    // Mock the fetch function
-    const mockFetch: any = jest.fn(() =>
-      Promise.resolve({
-        ok: true,
-        json: () =>
-          Promise.resolve({
-            choices: [{ message: { content: 'Generated response' } }],
-          }),
-        headers: new Headers(),
-        redirected: false,
-        status: 200,
-        statusText: 'OK',
-        // Add other required properties here
-      })
-    )
-    jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
-
-    // Define the test inputs
-    const inferenceUrl = 'https://inference-server.com'
-    const requestBody = { message: 'Hello' }
-    const model = { id: 'model-id', parameters: { stream: false } }
-
-    // Call the function
-    const result = requestInference(inferenceUrl, requestBody, model)
-
-    // Assert the expected behavior
-    expect(result).toBeInstanceOf(Observable)
-    expect(lastValueFrom(result)).resolves.toEqual('Generated response')
-  })
-
-  it('returns 401 error', () => {
-    // Mock the fetch function
-    const mockFetch: any = jest.fn(() =>
-      Promise.resolve({
-        ok: false,
-        json: () =>
-          Promise.resolve({
-            error: { message: 'Invalid API Key.', code: 'invalid_api_key' },
-          }),
-        headers: new Headers(),
-        redirected: false,
-        status: 401,
-        statusText: 'invalid_api_key',
-        // Add other required properties here
-      })
-    )
-    jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
-
-    // Define the test inputs
-    const inferenceUrl = 'https://inference-server.com'
-    const requestBody = { message: 'Hello' }
-    const model = { id: 'model-id', parameters: { stream: false } }
-
-    // Call the function
-    const result = requestInference(inferenceUrl, requestBody, model)
-
-    // Assert the expected behavior
-    expect(result).toBeInstanceOf(Observable)
-    expect(lastValueFrom(result)).rejects.toEqual({
-      message: 'Invalid API Key.',
-      code: 'invalid_api_key',
-    })
-  })
-})
-
-it('should handle a successful response with a transformResponse function', () => {
-  // Mock the fetch function
-  const mockFetch: any = jest.fn(() =>
-    Promise.resolve({
-      ok: true,
-      json: () =>
-        Promise.resolve({
-          choices: [{ message: { content: 'Generated response' } }],
-        }),
-      headers: new Headers(),
-      redirected: false,
-      status: 200,
-      statusText: 'OK',
-    })
-  )
-  jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
-
-  // Define the test inputs
-  const inferenceUrl = 'https://inference-server.com'
-  const requestBody = { message: 'Hello' }
-  const model = { id: 'model-id', parameters: { stream: false } }
-  const transformResponse = (data: any) =>
-    data.choices[0].message.content.toUpperCase()
-
-  // Call the function
-  const result = requestInference(
-    inferenceUrl,
-    requestBody,
-    model,
-    undefined,
-    undefined,
-    transformResponse
-  )
-
-  // Assert the expected behavior
-  expect(result).toBeInstanceOf(Observable)
-  expect(lastValueFrom(result)).resolves.toEqual('GENERATED RESPONSE')
-})
-it('should handle a successful response with streaming enabled', () => {
-  // Mock the fetch function
-  const mockFetch: any = jest.fn(() =>
-    Promise.resolve({
-      ok: true,
-      body: new ReadableStream({
-        start(controller) {
-          controller.enqueue(
-            new TextEncoder().encode(
-              'data: {"choices": [{"delta": {"content": "Streamed"}}]}'
-            )
-          )
-          controller.enqueue(new TextEncoder().encode('data: [DONE]'))
-          controller.close()
-        },
-      }),
-      headers: new Headers(),
-      redirected: false,
-      status: 200,
-      statusText: 'OK',
-    })
-  )
-  jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
-
-  // Define the test inputs
-  const inferenceUrl = 'https://inference-server.com'
-  const requestBody = { message: 'Hello' }
-  const model = { id: 'model-id', parameters: { stream: true } }
-
-  // Call the function
-  const result = requestInference(inferenceUrl, requestBody, model)
-
-  // Assert the expected behavior
-  expect(result).toBeInstanceOf(Observable)
-  expect(lastValueFrom(result)).resolves.toEqual('Streamed')
-})
diff --git a/core/src/browser/extensions/engines/helpers/sse.ts b/core/src/browser/extensions/engines/helpers/sse.ts
deleted file mode 100644
index 5c63008ff..000000000
--- a/core/src/browser/extensions/engines/helpers/sse.ts
+++ /dev/null
@@ -1,132 +0,0 @@
-import { Observable } from 'rxjs'
-
-import { ErrorCode, ModelRuntimeParams } from '../../../../types'
-
-/**
- * Sends a request to the inference server to generate a response based on the recent messages.
- * @param recentMessages - An array of recent messages to use as context for the inference.
- * @returns An Observable that emits the generated response as a string.
- */
-export function requestInference(
-  inferenceUrl: string,
-  requestBody: any,
-  model: {
-    id: string
-    parameters?: ModelRuntimeParams
-  },
-  controller?: AbortController,
-  headers?: HeadersInit,
-  transformResponse?: Function
-): Observable<string> {
-  return new Observable<string>((subscriber) => {
-    fetch(inferenceUrl, {
-      method: 'POST',
-      headers: {
-        'Content-Type': 'application/json',
-        'Access-Control-Allow-Origin': '*',
-        'Accept': model.parameters?.stream
-          ? 'text/event-stream'
-          : 'application/json',
-        ...headers,
-      },
-      body: JSON.stringify(requestBody),
-      signal: controller?.signal,
-    })
-      .then(async (response) => {
-        if (!response.ok) {
-          if (response.status === 401) {
-            throw {
-              code: ErrorCode.InvalidApiKey,
-              message: 'Invalid API Key.',
-            }
-          }
-          let data = await response.json()
-          try {
-            handleError(data)
-          } catch (err) {
-            subscriber.error(err)
-            return
-          }
-        }
-        // There could be overriden stream parameter in the model
-        // that is set in request body (transformed payload)
-        if (
-          requestBody?.stream === false ||
-          model.parameters?.stream === false
-        ) {
-          const data = await response.json()
-          try {
-            handleError(data)
-          } catch (err) {
-            subscriber.error(err)
-            return
-          }
-          if (transformResponse) {
-            subscriber.next(transformResponse(data))
-          } else {
-            subscriber.next(
-              data.choices
-                ? data.choices[0]?.message?.content
-                : (data.content[0]?.text ?? '')
-            )
-          }
-        } else {
-          const stream = response.body
-          const decoder = new TextDecoder('utf-8')
-          const reader = stream?.getReader()
-          let content = ''
-
-          while (true && reader) {
-            const { done, value } = await reader.read()
-            if (done) {
-              break
-            }
-            const text = decoder.decode(value)
-            const lines = text.trim().split('\n')
-            let cachedLines = ''
-            for (const line of lines) {
-              try {
-                if (transformResponse) {
-                  content += transformResponse(line)
-                  subscriber.next(content ?? '')
-                } else {
-                  const toParse = cachedLines + line
-                  if (!line.includes('data: [DONE]')) {
-                    const data = JSON.parse(toParse.replace('data: ', ''))
-                    try {
-                      handleError(data)
-                    } catch (err) {
-                      subscriber.error(err)
-                      return
-                    }
-                    content += data.choices[0]?.delta?.content ?? ''
-                    if (content.startsWith('assistant: ')) {
-                      content = content.replace('assistant: ', '')
-                    }
-                    if (content !== '') subscriber.next(content)
-                  }
-                }
-              } catch {
-                cachedLines = line
-              }
-            }
-          }
-        }
-        subscriber.complete()
-      })
-      .catch((err) => subscriber.error(err))
-  })
-}
-
-/**
- * Handle error and normalize it to a common format.
- * @param data
- */
-const handleError = (data: any) => {
-  if (
-    data.error ||
-    data.message ||
-    data.detail ||
-    (Array.isArray(data) && data.length && data[0].error)
-  ) {
-    throw data.error ?? data[0]?.error ?? data
-  }
-}
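Deleting this helper removes all hand-rolled SSE framing; streaming is consumed through token.js instead. A consumption sketch, assuming the fork mirrors the OpenAI SDK's streaming shape (an async iterable of chunks carrying `choices[0].delta`); the provider, model, and key below are illustrative:

```ts
import { TokenJS } from 'token.js'

async function streamCompletion(
  prompt: string,
  signal: AbortSignal
): Promise<string> {
  // Illustrative provider/model/key; the app derives these from the engine.
  const tokenJS = new TokenJS({ apiKey: process.env.OPENAI_API_KEY })
  const stream = await tokenJS.chat.completions.create(
    {
      stream: true,
      provider: 'openai',
      model: 'gpt-4o-mini',
      messages: [{ role: 'user', content: prompt }],
    },
    { signal } // aborting the shared controller tears the stream down
  )

  // No manual 'data:' framing, '[DONE]' sentinel, or partial-line buffering:
  // the SDK yields parsed chunks directly.
  let content = ''
  for await (const chunk of stream) {
    content += chunk.choices[0]?.delta?.content ?? ''
  }
  return content
}
```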
diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts
index b4b8a5033..0031c13aa 100644
--- a/web/hooks/useActiveModel.ts
+++ b/web/hooks/useActiveModel.ts
@@ -157,10 +157,13 @@ export function useActiveModel() {
       stopModel()
       return
     }
-    if (!activeModel) return
+    // if (!activeModel) return
 
-    const engine = EngineManager.instance().get(InferenceEngine.cortex)
-    engine?.stopInference()
+    // const engine = EngineManager.instance().get(InferenceEngine.cortex)
+    // engine?.stopInference()
+    // NOTE: this only works correctly if there is only 1 concurrent request
+    // at any point in time, which is a reasonable assumption to have.
+    EngineManager.instance().controller?.abort()
   }, [activeModel, stateModel, stopModel])
 
   return { activeModel, startModel, stopModel, stopInference, stateModel }
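The NOTE above is the load-bearing assumption: a global `abort()` is only safe while at most one request is in flight. If concurrency ever appears, one hedged way to keep a finished request from releasing or cancelling a newer one is to clear the slot only when it still holds your own controller (a sketch, not part of this patch):

```ts
import { EngineManager } from '@janhq/core'

async function withSharedController<T>(
  run: (signal: AbortSignal) => Promise<T>
): Promise<T> {
  const controller = new AbortController()
  EngineManager.instance().controller = controller
  try {
    return await run(controller.signal)
  } finally {
    // Only release the slot if no newer request has replaced our controller.
    if (EngineManager.instance().controller === controller) {
      EngineManager.instance().controller = null
    }
  }
}
```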
diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts
index dbfc168bb..67501ab84 100644
--- a/web/hooks/useSendChatMessage.ts
+++ b/web/hooks/useSendChatMessage.ts
@@ -12,11 +12,9 @@ import {
   ThreadAssistantInfo,
   events,
   MessageEvent,
-  ContentType,
   EngineManager,
   InferenceEngine,
   MessageStatus,
-  ChatCompletionRole,
 } from '@janhq/core'
 import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
 import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
@@ -26,12 +24,10 @@ import {
   ChatCompletionTool,
   ChatCompletionMessageToolCall,
 } from 'openai/resources/chat'
-
 import {
   CompletionResponse,
   StreamCompletionResponse,
   TokenJS,
-  models,
 } from 'token.js'
 import { ulid } from 'ulidx'
 
@@ -208,6 +204,18 @@
     }
 
     // Build Message Request
+    // TODO: detect if model supports tools
+    const tools = (await window.core.api.getTools())
+      ?.filter((tool: ModelTool) => !disabledTools.includes(tool.name))
+      .map((tool: ModelTool) => ({
+        type: 'function' as const,
+        function: {
+          name: tool.name,
+          description: tool.description?.slice(0, 1024),
+          parameters: tool.inputSchema,
+          strict: false,
+        },
+      }))
     const requestBuilder = new MessageRequestBuilder(
       MessageRequestType.Thread,
       {
@@ -217,17 +225,7 @@
       },
       activeThread,
       messages ?? currentMessages,
-      (await window.core.api.getTools())
-        ?.filter((tool: ModelTool) => !disabledTools.includes(tool.name))
-        .map((tool: ModelTool) => ({
-          type: 'function' as const,
-          function: {
-            name: tool.name,
-            description: tool.description?.slice(0, 1024),
-            parameters: tool.inputSchema,
-            strict: false,
-          },
-        }))
+      (tools && tools.length) ? tools : undefined,
     ).addSystemMessage(activeAssistant.instructions)
 
     requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
@@ -267,13 +265,15 @@
     }
 
     // Start Model if not started
+    const isCortex = modelRequest.engine == InferenceEngine.cortex ||
+      modelRequest.engine == InferenceEngine.cortex_llamacpp
     const modelId = selectedModel?.id ?? activeAssistantRef.current?.model.id
 
     if (base64Blob) {
       setFileUpload(undefined)
     }
 
-    if (modelRef.current?.id !== modelId && modelId) {
+    if (modelRef.current?.id !== modelId && modelId && isCortex) {
       const error = await startModel(modelId).catch((error: Error) => error)
       if (error) {
         updateThreadWaiting(activeThread.id, false)
@@ -282,92 +282,98 @@
     }
     setIsGeneratingResponse(true)
 
-    if (requestBuilder.tools && requestBuilder.tools.length) {
-      let isDone = false
+    let isDone = false
 
-      const engine =
-        engines?.[requestBuilder.model.engine as InferenceEngine]?.[0]
-      const apiKey = engine?.api_key
-      const provider = convertBuiltInEngine(engine?.engine)
+    const engine =
+      engines?.[requestBuilder.model.engine as InferenceEngine]?.[0]
+    const apiKey = engine?.api_key
+    const provider = convertBuiltInEngine(engine?.engine)
 
-      const tokenJS = new TokenJS({
-        apiKey: apiKey ?? (await window.core.api.appToken()),
-        baseURL: apiKey ? undefined : `${API_BASE_URL}/v1`,
+    const tokenJS = new TokenJS({
+      apiKey: apiKey ?? (await window.core.api.appToken()),
+      baseURL: apiKey ? undefined : `${API_BASE_URL}/v1`,
+    })
+
+    extendBuiltInEngineModels(tokenJS, provider, modelId)
+
+    // llama.cpp currently does not support streaming when tools are used.
+    const useStream = (requestBuilder.tools && isCortex) ?
+      false :
+      modelRequest.parameters?.stream
+
+    let parentMessageId: string | undefined
+    while (!isDone) {
+      let messageId = ulid()
+      if (!parentMessageId) {
+        parentMessageId = ulid()
+        messageId = parentMessageId
+      }
+      const data = requestBuilder.build()
+      const message: ThreadMessage = createMessage({
+        id: messageId,
+        thread_id: activeThread.id,
+        assistant_id: activeAssistant.assistant_id,
+        metadata: {
+          ...(messageId !== parentMessageId
+            ? { parent_id: parentMessageId }
+            : {}),
+        },
       })
+      events.emit(MessageEvent.OnMessageResponse, message)
 
-      extendBuiltInEngineModels(tokenJS, provider, modelId)
-
-      let parentMessageId: string | undefined
-      while (!isDone) {
-        let messageId = ulid()
-        if (!parentMessageId) {
-          parentMessageId = ulid()
-          messageId = parentMessageId
-        }
-        const data = requestBuilder.build()
-        const message: ThreadMessage = createMessage({
-          id: messageId,
-          thread_id: activeThread.id,
-          assistant_id: activeAssistant.assistant_id,
-          metadata: {
-            ...(messageId !== parentMessageId
-              ? { parent_id: parentMessageId }
-              : {}),
-          },
-        })
-        events.emit(MessageEvent.OnMessageResponse, message)
-        // Variables to track and accumulate streaming content
-
-        if (
-          data.model?.parameters?.stream &&
-          data.model?.engine !== InferenceEngine.cortex &&
-          data.model?.engine !== InferenceEngine.cortex_llamacpp
-        ) {
-          const response = await tokenJS.chat.completions.create({
+      // we need to separate into 2 cases to appease linter
+      const controller = new AbortController()
+      EngineManager.instance().controller = controller
+      if (useStream) {
+        const response = await tokenJS.chat.completions.create(
+          {
             stream: true,
             provider,
             messages: requestBuilder.messages as ChatCompletionMessageParam[],
             model: data.model?.id ?? '',
             tools: data.tools as ChatCompletionTool[],
-            tool_choice: 'auto',
-          })
-
-          if (!message.content.length) {
-            message.content = emptyMessageContent
+            tool_choice: data.tools ? 'auto' : undefined,
+          },
+          {
+            signal: controller.signal,
           }
-
-          isDone = await processStreamingResponse(
-            response,
-            requestBuilder,
-            message
-          )
-        } else {
-          const response = await tokenJS.chat.completions.create({
+        )
+        // Variables to track and accumulate streaming content
+        if (!message.content.length) {
+          message.content = emptyMessageContent
+        }
+        isDone = await processStreamingResponse(
+          response,
+          requestBuilder,
+          message
+        )
+      } else {
+        const response = await tokenJS.chat.completions.create(
+          {
             stream: false,
             provider,
             messages: requestBuilder.messages as ChatCompletionMessageParam[],
             model: data.model?.id ?? '',
             tools: data.tools as ChatCompletionTool[],
-            tool_choice: 'auto',
-          })
-          // Variables to track and accumulate streaming content
-          if (!message.content.length) {
-            message.content = emptyMessageContent
+            tool_choice: data.tools ? 'auto' : undefined,
+          },
+          {
+            signal: controller.signal,
           }
-          isDone = await processNonStreamingResponse(
-            response,
-            requestBuilder,
-            message
-          )
+        )
+        // Variables to track and accumulate streaming content
+        if (!message.content.length) {
+          message.content = emptyMessageContent
        }
-        message.status = MessageStatus.Ready
-        events.emit(MessageEvent.OnMessageUpdate, message)
+        isDone = await processNonStreamingResponse(
+          response,
+          requestBuilder,
+          message
+        )
       }
-    } else {
-      // Request for inference
-      EngineManager.instance()
-        .get(InferenceEngine.cortex)
-        ?.inference(requestBuilder.build())
+
+      message.status = MessageStatus.Ready
+      events.emit(MessageEvent.OnMessageUpdate, message)
     }
   } catch (error) {
     setIsGeneratingResponse(false)
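The restructured block is effectively a tool-use loop: issue a completion, run any requested tools, feed their results back, and repeat until a turn returns no tool calls. Schematically, with `complete` standing in for `tokenJS.chat.completions.create` plus response processing and `runTool` for the tool dispatch (both illustrative):

```ts
import {
  ChatCompletionMessageParam,
  ChatCompletionMessageToolCall,
} from 'openai/resources/chat'

type Turn = {
  content: string
  toolCalls: ChatCompletionMessageToolCall[]
}

async function toolUseLoop(
  messages: ChatCompletionMessageParam[],
  complete: (msgs: ChatCompletionMessageParam[]) => Promise<Turn>,
  runTool: (call: ChatCompletionMessageToolCall) => Promise<string>
): Promise<string> {
  for (;;) {
    const turn = await complete(messages)
    if (turn.toolCalls.length === 0) return turn.content // isDone
    // Echo the assistant turn, then feed one tool result per call back in.
    messages.push({
      role: 'assistant',
      content: turn.content,
      tool_calls: turn.toolCalls,
    })
    for (const call of turn.toolCalls) {
      messages.push({
        role: 'tool',
        tool_call_id: call.id,
        content: await runTool(call),
      })
    }
  }
}
```

Note the `useStream` guard above: when tools are attached and the engine is llama.cpp-based, each iteration of this loop falls back to a non-streaming completion, per the inline comment.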
diff --git a/web/package.json b/web/package.json
index cf76a95f2..dd02f3de5 100644
--- a/web/package.json
+++ b/web/package.json
@@ -65,7 +65,7 @@
     "swr": "^2.2.5",
     "tailwind-merge": "^2.0.0",
     "tailwindcss": "3.4.17",
-    "token.js": "npm:token.js-fork@0.7.2",
+    "token.js": "npm:token.js-fork@0.7.6",
     "ulidx": "^2.3.0",
     "use-debounce": "^10.0.0",
     "uuid": "^9.0.1",