From 7a6890bd7f2b3318faf341d130af3a698a2813d4 Mon Sep 17 00:00:00 2001
From: Louis
Date: Thu, 13 Feb 2025 18:32:33 +0700
Subject: [PATCH] chore: remote engine error handling (#4646)

* chore: Gemini error handling
* chore: remote provider error handling
* chore: remote provider error handling
* chore: fix anthropic unsupported parameters
* chore: fix tests
---
 .../extensions/engines/OAIEngine.test.ts      |  22 ++-
 .../browser/extensions/engines/OAIEngine.ts   |  13 +-
 .../extensions/engines/helpers/sse.test.ts    | 150 ++++++++++--------
 .../browser/extensions/engines/helpers/sse.ts |  56 ++++---
 core/src/types/model/modelEntity.ts           |   1 +
 .../models/anthropic.json                     |   3 +
 .../models/cohere.json                        |   2 +
 .../models/mistral.json                       |   3 +
 .../models/nvidia.json                        |   1 +
 .../models/openai.json                        |  18 +--
 .../resources/anthropic.json                  |   2 +-
 .../engine-management-extension/src/index.ts  |   2 +-
 .../bin/version.txt                           |   2 +-
 web/containers/AutoLink/index.tsx             |  36 +++--
 web/utils/componentSettings.ts                |   4 +
 15 files changed, 185 insertions(+), 130 deletions(-)

diff --git a/core/src/browser/extensions/engines/OAIEngine.test.ts b/core/src/browser/extensions/engines/OAIEngine.test.ts
index 81348786c..66537d0be 100644
--- a/core/src/browser/extensions/engines/OAIEngine.test.ts
+++ b/core/src/browser/extensions/engines/OAIEngine.test.ts
@@ -38,8 +38,14 @@ describe('OAIEngine', () => {
   it('should subscribe to events on load', () => {
     engine.onLoad()
-    expect(events.on).toHaveBeenCalledWith(MessageEvent.OnMessageSent, expect.any(Function))
-    expect(events.on).toHaveBeenCalledWith(InferenceEvent.OnInferenceStopped, expect.any(Function))
+    expect(events.on).toHaveBeenCalledWith(
+      MessageEvent.OnMessageSent,
+      expect.any(Function)
+    )
+    expect(events.on).toHaveBeenCalledWith(
+      InferenceEvent.OnInferenceStopped,
+      expect.any(Function)
+    )
   })

   it('should handle inference request', async () => {
@@ -77,7 +83,12 @@ describe('OAIEngine', () => {
     expect(events.emit).toHaveBeenCalledWith(
       MessageEvent.OnMessageUpdate,
       expect.objectContaining({
-        content: [{ type: ContentType.Text, text: { value: 'test response', annotations: [] } }],
+        content: [
+          {
+            type: ContentType.Text,
+            text: { value: 'test response', annotations: [] },
+          },
+        ],
         status: MessageStatus.Ready,
       })
     )
@@ -101,11 +112,10 @@ describe('OAIEngine', () => {

     await engine.inference(data)

-    expect(events.emit).toHaveBeenCalledWith(
+    expect(events.emit).toHaveBeenLastCalledWith(
       MessageEvent.OnMessageUpdate,
       expect.objectContaining({
-        content: [{ type: ContentType.Text, text: { value: 'test error', annotations: [] } }],
-        status: MessageStatus.Error,
+        status: 'error',
         error_code: 500,
       })
     )
diff --git a/core/src/browser/extensions/engines/OAIEngine.ts b/core/src/browser/extensions/engines/OAIEngine.ts
index 6b4c20a19..61032357c 100644
--- a/core/src/browser/extensions/engines/OAIEngine.ts
+++ b/core/src/browser/extensions/engines/OAIEngine.ts
@@ -42,7 +42,9 @@ export abstract class OAIEngine extends AIEngine {
    */
   override onLoad() {
     super.onLoad()
-    events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
+    events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
+      this.inference(data)
+    )
     events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
   }

@@ -128,7 +130,9 @@ export abstract class OAIEngine extends AIEngine {
         events.emit(MessageEvent.OnMessageUpdate, message)
       },
       complete: async () => {
-        message.status = message.content.length ? MessageStatus.Ready : MessageStatus.Error
+        message.status = message.content.length
+          ? MessageStatus.Ready
+          : MessageStatus.Error
         events.emit(MessageEvent.OnMessageUpdate, message)
       },
       error: async (err: any) => {
@@ -141,7 +145,10 @@ export abstract class OAIEngine extends AIEngine {
         message.content[0] = {
           type: ContentType.Text,
           text: {
-            value: err.message,
+            value:
+              typeof err.message === 'string'
+                ? err.message
+                : (JSON.stringify(err.message) ?? err.detail),
             annotations: [],
           },
         }
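
The rewritten handler above assumes a provider error's message field can be
either a plain string or a structured object. A minimal standalone TypeScript
sketch of that fallback chain (renderErrorText is a name coined here for
illustration; the sample error shapes are common provider formats, not
fixtures from this patch):

    // Mirrors the value fallback in the hunk above: use the message as-is
    // when it is a string, otherwise serialize it, falling back to err.detail.
    const renderErrorText = (err: any): string =>
      typeof err.message === 'string'
        ? err.message
        : (JSON.stringify(err.message) ?? err.detail)

    renderErrorText({ message: 'Invalid API Key.' })        // 'Invalid API Key.'
    renderErrorText({ message: { reason: 'overloaded' } })  // '{"reason":"overloaded"}'
    renderErrorText({ detail: 'Not Found' })                // 'Not Found'
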
diff --git a/core/src/browser/extensions/engines/helpers/sse.test.ts b/core/src/browser/extensions/engines/helpers/sse.test.ts
index 0b78aa9b5..f8c2ac6b4 100644
--- a/core/src/browser/extensions/engines/helpers/sse.test.ts
+++ b/core/src/browser/extensions/engines/helpers/sse.test.ts
@@ -1,14 +1,17 @@
 import { lastValueFrom, Observable } from 'rxjs'
 import { requestInference } from './sse'
-import { ReadableStream } from 'stream/web';
+import { ReadableStream } from 'stream/web'

 describe('requestInference', () => {
   it('should send a request to the inference server and return an Observable', () => {
     // Mock the fetch function
     const mockFetch: any = jest.fn(() =>
       Promise.resolve({
         ok: true,
-        json: () => Promise.resolve({ choices: [{ message: { content: 'Generated response' } }] }),
+        json: () =>
+          Promise.resolve({
+            choices: [{ message: { content: 'Generated response' } }],
+          }),
         headers: new Headers(),
         redirected: false,
         status: 200,
@@ -36,7 +39,10 @@ describe('requestInference', () => {
     const mockFetch: any = jest.fn(() =>
       Promise.resolve({
         ok: false,
-        json: () => Promise.resolve({ error: { message: 'Wrong API Key', code: 'invalid_api_key' } }),
+        json: () =>
+          Promise.resolve({
+            error: { message: 'Invalid API Key.', code: 'invalid_api_key' },
+          }),
         headers: new Headers(),
         redirected: false,
         status: 401,
@@ -56,69 +62,85 @@ describe('requestInference', () => {
     // Assert the expected behavior
     expect(result).toBeInstanceOf(Observable)
-    expect(lastValueFrom(result)).rejects.toEqual({ message: 'Wrong API Key', code: 'invalid_api_key' })
+    expect(lastValueFrom(result)).rejects.toEqual({
+      message: 'Invalid API Key.',
+      code: 'invalid_api_key',
+    })
   })
 })

-  it('should handle a successful response with a transformResponse function', () => {
-    // Mock the fetch function
-    const mockFetch: any = jest.fn(() =>
-      Promise.resolve({
-        ok: true,
-        json: () => Promise.resolve({ choices: [{ message: { content: 'Generated response' } }] }),
-        headers: new Headers(),
-        redirected: false,
-        status: 200,
-        statusText: 'OK',
-      })
-    )
-    jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
-
-    // Define the test inputs
-    const inferenceUrl = 'https://inference-server.com'
-    const requestBody = { message: 'Hello' }
-    const model = { id: 'model-id', parameters: { stream: false } }
-    const transformResponse = (data: any) => data.choices[0].message.content.toUpperCase()
-
-    // Call the function
-    const result = requestInference(inferenceUrl, requestBody, model, undefined, undefined, transformResponse)
-
-    // Assert the expected behavior
-    expect(result).toBeInstanceOf(Observable)
-    expect(lastValueFrom(result)).resolves.toEqual('GENERATED RESPONSE')
-  })
-
-
-  it('should handle a successful response with streaming enabled', () => {
-    // Mock the fetch function
-    const mockFetch: any = jest.fn(() =>
-      Promise.resolve({
-        ok: true,
-        body: new ReadableStream({
-          start(controller) {
-            controller.enqueue(new TextEncoder().encode('data: {"choices": [{"delta": {"content": "Streamed"}}]}'));
-            controller.enqueue(new TextEncoder().encode('data: [DONE]'));
-            controller.close();
-          }
+it('should handle a successful response with a transformResponse function', () => {
+  // Mock the fetch function
+  const mockFetch: any = jest.fn(() =>
+    Promise.resolve({
+      ok: true,
+      json: () =>
+        Promise.resolve({
+          choices: [{ message: { content: 'Generated response' } }],
         }),
-        headers: new Headers(),
-        redirected: false,
-        status: 200,
-        statusText: 'OK',
-      })
-    );
-    jest.spyOn(global, 'fetch').mockImplementation(mockFetch);
-
-    // Define the test inputs
-    const inferenceUrl = 'https://inference-server.com';
-    const requestBody = { message: 'Hello' };
-    const model = { id: 'model-id', parameters: { stream: true } };
-
-    // Call the function
-    const result = requestInference(inferenceUrl, requestBody, model);
-
-    // Assert the expected behavior
-    expect(result).toBeInstanceOf(Observable);
-    expect(lastValueFrom(result)).resolves.toEqual('Streamed');
-  });
+      headers: new Headers(),
+      redirected: false,
+      status: 200,
+      statusText: 'OK',
+    })
+  )
+  jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
+
+  // Define the test inputs
+  const inferenceUrl = 'https://inference-server.com'
+  const requestBody = { message: 'Hello' }
+  const model = { id: 'model-id', parameters: { stream: false } }
+  const transformResponse = (data: any) =>
+    data.choices[0].message.content.toUpperCase()
+
+  // Call the function
+  const result = requestInference(
+    inferenceUrl,
+    requestBody,
+    model,
+    undefined,
+    undefined,
+    transformResponse
+  )
+
+  // Assert the expected behavior
+  expect(result).toBeInstanceOf(Observable)
+  expect(lastValueFrom(result)).resolves.toEqual('GENERATED RESPONSE')
+})
+
+it('should handle a successful response with streaming enabled', () => {
+  // Mock the fetch function
+  const mockFetch: any = jest.fn(() =>
+    Promise.resolve({
+      ok: true,
+      body: new ReadableStream({
+        start(controller) {
+          controller.enqueue(
+            new TextEncoder().encode(
+              'data: {"choices": [{"delta": {"content": "Streamed"}}]}'
+            )
+          )
+          controller.enqueue(new TextEncoder().encode('data: [DONE]'))
+          controller.close()
+        },
+      }),
+      headers: new Headers(),
+      redirected: false,
+      status: 200,
+      statusText: 'OK',
+    })
+  )
+  jest.spyOn(global, 'fetch').mockImplementation(mockFetch)
+
+  // Define the test inputs
+  const inferenceUrl = 'https://inference-server.com'
+  const requestBody = { message: 'Hello' }
+  const model = { id: 'model-id', parameters: { stream: true } }
+
+  // Call the function
+  const result = requestInference(inferenceUrl, requestBody, model)
+
+  // Assert the expected behavior
+  expect(result).toBeInstanceOf(Observable)
+  expect(lastValueFrom(result)).resolves.toEqual('Streamed')
+})
diff --git a/core/src/browser/extensions/engines/helpers/sse.ts b/core/src/browser/extensions/engines/helpers/sse.ts
index 55cde56b4..5c63008ff 100644
--- a/core/src/browser/extensions/engines/helpers/sse.ts
+++ b/core/src/browser/extensions/engines/helpers/sse.ts
@@ -32,20 +32,19 @@ export function requestInference(
     })
     .then(async (response) => {
       if (!response.ok) {
-        const data = await response.json()
-        let errorCode = ErrorCode.Unknown
-        if (data.error) {
-          errorCode = data.error.code ?? data.error.type ?? ErrorCode.Unknown
-        } else if (response.status === 401) {
-          errorCode = ErrorCode.InvalidApiKey
+        if (response.status === 401) {
+          throw {
+            code: ErrorCode.InvalidApiKey,
+            message: 'Invalid API Key.',
+          }
         }
-        const error = {
-          message: data.error?.message ?? data.message ?? 'Error occurred.',
-          code: errorCode,
+        let data = await response.json()
+        try {
+          handleError(data)
+        } catch (err) {
+          subscriber.error(err)
+          return
         }
-        subscriber.error(error)
-        subscriber.complete()
-        return
-      }
       }
       // There could be an overridden stream parameter in the model
       // that is set in the request body (transformed payload)
@@ -54,9 +53,10 @@ export function requestInference(
         model.parameters?.stream === false
       ) {
         const data = await response.json()
-        if (data.error || data.message) {
-          subscriber.error(data.error ?? data)
-          subscriber.complete()
+        try {
+          handleError(data)
+        } catch (err) {
+          subscriber.error(err)
           return
         }
         if (transformResponse) {
@@ -91,13 +91,10 @@ export function requestInference(
               const toParse = cachedLines + line
               if (!line.includes('data: [DONE]')) {
                 const data = JSON.parse(toParse.replace('data: ', ''))
-                if (
-                  'error' in data ||
-                  'message' in data ||
-                  'detail' in data
-                ) {
-                  subscriber.error(data.error ?? data)
-                  subscriber.complete()
+                try {
+                  handleError(data)
+                } catch (err) {
+                  subscriber.error(err)
                   return
                 }
                 content += data.choices[0]?.delta?.content ?? ''
@@ -118,3 +115,18 @@ export function requestInference(
       .catch((err) => subscriber.error(err))
   })
 }
+
+/**
+ * Detect an error in a response payload and throw it in a normalized shape.
+ * @param data - parsed response payload to inspect
+ */
+const handleError = (data: any) => {
+  if (
+    data.error ||
+    data.message ||
+    data.detail ||
+    (Array.isArray(data) && data.length && data[0].error)
+  ) {
+    throw data.error ?? data[0]?.error ?? data
+  }
+}
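
handleError gives every transport path (non-OK responses, non-streaming
bodies, and individual SSE chunks) a single detection rule: any payload
carrying error, message, or detail, or an array whose first element carries
error, is treated as a failure, and the innermost error object found is
thrown so each subscriber.error receives the same shape. A standalone
TypeScript sketch of that contract (the helper is copied inline because it is
module-private; the sample payloads illustrate common provider formats and
are not fixtures from this patch):

    // Inlined copy of the handleError helper introduced above.
    const handleError = (data: any) => {
      if (
        data.error ||
        data.message ||
        data.detail ||
        (Array.isArray(data) && data.length && data[0].error)
      ) {
        throw data.error ?? data[0]?.error ?? data
      }
    }

    const payloads = [
      { error: { message: 'Invalid API Key.', code: 'invalid_api_key' } }, // OpenAI-style: throws the inner error
      { message: 'Model not found', code: 404 }, // bare message: throws the payload itself
      { detail: 'Rate limit exceeded' }, // FastAPI-style detail: throws the payload itself
      [{ error: { message: 'Bad request' } }], // array-wrapped: throws the first element's error
    ]
    for (const payload of payloads) {
      try {
        handleError(payload)
      } catch (err) {
        console.log(err) // normalized error, ready for subscriber.error(err)
      }
    }
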
"max_tokens": 32000, "temperature": 0.7, + "max_temperature": 1.0, "top_p": 0.95, "stream": true }, @@ -22,6 +23,7 @@ "inference_params": { "max_tokens": 32000, "temperature": 0.7, + "max_temperature": 1.0, "top_p": 0.95, "stream": true }, @@ -36,6 +38,7 @@ "inference_params": { "max_tokens": 32000, "temperature": 0.7, + "max_temperature": 1.0, "top_p": 0.95, "stream": true }, diff --git a/extensions/engine-management-extension/models/nvidia.json b/extensions/engine-management-extension/models/nvidia.json index dfce9f8bc..cb6f9dec1 100644 --- a/extensions/engine-management-extension/models/nvidia.json +++ b/extensions/engine-management-extension/models/nvidia.json @@ -8,6 +8,7 @@ "inference_params": { "max_tokens": 1024, "temperature": 0.3, + "max_temperature": 1.0, "top_p": 1, "stream": false, "frequency_penalty": 0, diff --git a/extensions/engine-management-extension/models/openai.json b/extensions/engine-management-extension/models/openai.json index 7373118b3..ad3b2562d 100644 --- a/extensions/engine-management-extension/models/openai.json +++ b/extensions/engine-management-extension/models/openai.json @@ -79,11 +79,7 @@ "description": "OpenAI o1 is a new model with complex reasoning", "format": "api", "inference_params": { - "max_tokens": 100000, - "temperature": 1, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0 + "max_tokens": 100000 }, "engine": "openai" }, @@ -96,11 +92,7 @@ "format": "api", "inference_params": { "max_tokens": 32768, - "temperature": 1, - "top_p": 1, - "stream": true, - "frequency_penalty": 0, - "presence_penalty": 0 + "stream": true }, "engine": "openai" }, @@ -113,11 +105,7 @@ "format": "api", "inference_params": { "max_tokens": 65536, - "temperature": 1, - "top_p": 1, - "stream": true, - "frequency_penalty": 0, - "presence_penalty": 0 + "stream": true }, "engine": "openai" } diff --git a/extensions/engine-management-extension/resources/anthropic.json b/extensions/engine-management-extension/resources/anthropic.json index 2b73edcc1..98ac734b8 100644 --- a/extensions/engine-management-extension/resources/anthropic.json +++ b/extensions/engine-management-extension/resources/anthropic.json @@ -10,7 +10,7 @@ "transform_req": { "chat_completions": { "url": "https://api.anthropic.com/v1/messages", - "template": "{ {% for key, value in input_request %} {% if key == \"messages\" %} {% if input_request.messages.0.role == \"system\" %} \"system\": \"{{ input_request.messages.0.content }}\", \"messages\": [{% for message in input_request.messages %} {% if not loop.is_first %} {\"role\": \"{{ message.role }}\", \"content\": \"{{ message.content }}\" } {% if not loop.is_last %},{% endif %} {% endif %} {% endfor %}] {% else %} \"messages\": [{% for message in input_request.messages %} {\"role\": \"{{ message.role}}\", \"content\": \"{{ message.content }}\" } {% if not loop.is_last %},{% endif %} {% endfor %}] {% endif %} {% if not loop.is_last %},{% endif %} {% else if key == \"system\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" 
%}\"{{ key }}\": {{ tojson(value) }} {% if not loop.is_last %},{% endif %} {% endif %} {% endfor %} }" + "template": "{ {% for key, value in input_request %} {% if key == \"messages\" %} {% if input_request.messages.0.role == \"system\" %} \"system\": {{ tojson(input_request.messages.0.content) }}, \"messages\": [{% for message in input_request.messages %} {% if not loop.is_first %} {\"role\": {{ tojson(message.role) }}, \"content\": {% if not message.content or message.content == \"\" %} \".\" {% else %} {{ tojson(message.content) }} {% endif %} } {% if not loop.is_last %},{% endif %} {% endif %} {% endfor %}] {% else %} \"messages\": [{% for message in input_request.messages %} {\"role\": {{ tojson(message.role) }}, \"content\": {% if not message.content or message.content == \"\" %} \".\" {% else %} {{ tojson(message.content) }} {% endif %} } {% if not loop.is_last %},{% endif %} {% endfor %}] {% endif %} {% if not loop.is_last %},{% endif %} {% else if key == \"system\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"metadata\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %}\"{{ key }}\": {{ tojson(value) }} {% if not loop.is_last %},{% endif %} {% endif %} {% endfor %} }" } }, "transform_resp": { diff --git a/extensions/engine-management-extension/src/index.ts b/extensions/engine-management-extension/src/index.ts index f2371883e..e2730cc71 100644 --- a/extensions/engine-management-extension/src/index.ts +++ b/extensions/engine-management-extension/src/index.ts @@ -199,7 +199,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens .post(`${API_URL}/v1/models/add`, { json: { inference_params: { - max_tokens: 8192, + max_tokens: 4096, temperature: 0.7, top_p: 0.95, stream: true, diff --git a/extensions/inference-cortex-extension/bin/version.txt b/extensions/inference-cortex-extension/bin/version.txt index 8a3eae09c..92fe52359 100644 --- a/extensions/inference-cortex-extension/bin/version.txt +++ b/extensions/inference-cortex-extension/bin/version.txt @@ -1 +1 @@ -1.0.10-rc6 +1.0.10-rc7 diff --git a/web/containers/AutoLink/index.tsx b/web/containers/AutoLink/index.tsx index 66c84f7f7..0f10f478a 100644 --- a/web/containers/AutoLink/index.tsx +++ b/web/containers/AutoLink/index.tsx @@ -10,23 +10,25 @@ const AutoLink = ({ text }: Props) => { return ( <> - {text.split(delimiter).map((word) => { - const match = word.match(delimiter) - if (match) { - const url = match[0] - return ( - - {url} - - ) - } - return word - })} + {text && + typeof text === 'string' && + text.split(delimiter).map((word) => { + const match = word.match(delimiter) + if (match) { + const url = match[0] + return ( + + {url} + + ) + } + return word + })} ) } diff --git a/web/utils/componentSettings.ts b/web/utils/componentSettings.ts index 6e55d02e5..8ebcfd7c9 100644 --- a/web/utils/componentSettings.ts +++ b/web/utils/componentSettings.ts @@ -27,6 +27,10 @@ export const getConfigurationsData = ( componentSetting.controllerProps.max || 4096 break + case 'temperature': + componentSetting.controllerProps.max = + selectedModel?.parameters?.max_temperature || 2 + break case 'ctx_len': 
diff --git a/extensions/engine-management-extension/src/index.ts b/extensions/engine-management-extension/src/index.ts
index f2371883e..e2730cc71 100644
--- a/extensions/engine-management-extension/src/index.ts
+++ b/extensions/engine-management-extension/src/index.ts
@@ -199,7 +199,7 @@ export default class JanEngineManagementExtension extends EngineManagementExtens
       .post(`${API_URL}/v1/models/add`, {
         json: {
           inference_params: {
-            max_tokens: 8192,
+            max_tokens: 4096,
             temperature: 0.7,
             top_p: 0.95,
             stream: true,
diff --git a/extensions/inference-cortex-extension/bin/version.txt b/extensions/inference-cortex-extension/bin/version.txt
index 8a3eae09c..92fe52359 100644
--- a/extensions/inference-cortex-extension/bin/version.txt
+++ b/extensions/inference-cortex-extension/bin/version.txt
@@ -1 +1 @@
-1.0.10-rc6
+1.0.10-rc7
diff --git a/web/containers/AutoLink/index.tsx b/web/containers/AutoLink/index.tsx
index 66c84f7f7..0f10f478a 100644
--- a/web/containers/AutoLink/index.tsx
+++ b/web/containers/AutoLink/index.tsx
@@ -10,23 +10,25 @@ const AutoLink = ({ text }: Props) => {

   return (
     <>
-      {text.split(delimiter).map((word) => {
-        const match = word.match(delimiter)
-        if (match) {
-          const url = match[0]
-          return (
-            <a href={url} target="_blank" rel="noreferrer">
-              {url}
-            </a>
-          )
-        }
-        return word
-      })}
+      {text &&
+        typeof text === 'string' &&
+        text.split(delimiter).map((word) => {
+          const match = word.match(delimiter)
+          if (match) {
+            const url = match[0]
+            return (
+              <a href={url} target="_blank" rel="noreferrer">
+                {url}
+              </a>
+            )
+          }
+          return word
+        })}
diff --git a/web/utils/componentSettings.ts b/web/utils/componentSettings.ts
index 6e55d02e5..8ebcfd7c9 100644
--- a/web/utils/componentSettings.ts
+++ b/web/utils/componentSettings.ts
@@ -27,6 +27,10 @@ export const getConfigurationsData = (
           componentSetting.controllerProps.max || 4096
       break
+    case 'temperature':
+      componentSetting.controllerProps.max =
+        selectedModel?.parameters?.max_temperature || 2
+      break
     case 'ctx_len':
       componentSetting.controllerProps.max =
         selectedModel?.settings.ctx_len ||