diff --git a/extensions/inference-openai-extension/jest.config.js b/extensions/inference-openai-extension/jest.config.js
new file mode 100644
index 000000000..3e32adceb
--- /dev/null
+++ b/extensions/inference-openai-extension/jest.config.js
@@ -0,0 +1,9 @@
+/** @type {import('ts-jest').JestConfigWithTsJest} */
+module.exports = {
+  preset: 'ts-jest',
+  testEnvironment: 'node',
+  transform: {
+    'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
+  },
+  transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
+}
diff --git a/extensions/inference-openai-extension/src/OpenAIExtension.test.ts b/extensions/inference-openai-extension/src/OpenAIExtension.test.ts
new file mode 100644
index 000000000..4d46bc007
--- /dev/null
+++ b/extensions/inference-openai-extension/src/OpenAIExtension.test.ts
@@ -0,0 +1,54 @@
+/**
+ * @jest-environment jsdom
+ */
+jest.mock('@janhq/core', () => ({
+  ...jest.requireActual('@janhq/core/node'),
+  RemoteOAIEngine: jest.fn().mockImplementation(() => ({
+    onLoad: jest.fn(),
+    registerSettings: jest.fn(),
+    registerModels: jest.fn(),
+    getSetting: jest.fn(),
+    onSettingUpdate: jest.fn(),
+  })),
+}))
+import JanInferenceOpenAIExtension, { Settings } from '.'
+
+describe('JanInferenceOpenAIExtension', () => {
+  let extension: JanInferenceOpenAIExtension
+
+  beforeEach(() => {
+    // @ts-ignore
+    extension = new JanInferenceOpenAIExtension()
+  })
+
+  it('should initialize with settings and models', async () => {
+    await extension.onLoad()
+    // Assuming there are some default SETTINGS and MODELS being registered
+    expect(extension.apiKey).toBe(undefined)
+    expect(extension.inferenceUrl).toBe('')
+  })
+
+  it('should transform the payload for preview models', () => {
+    const payload: any = {
+      max_tokens: 100,
+      model: 'o1-mini',
+      // Add other required properties...
+    }
+
+    const transformedPayload = extension.transformPayload(payload)
+    expect(transformedPayload.max_completion_tokens).toBe(payload.max_tokens)
+    expect(transformedPayload).not.toHaveProperty('max_tokens')
+    expect(transformedPayload).toHaveProperty('max_completion_tokens')
+  })
+
+  it('should not transform the payload for non-preview models', () => {
+    const payload: any = {
+      max_tokens: 100,
+      model: 'non-preview-model',
+      // Add other required properties...
+    }
+
+    const transformedPayload = extension.transformPayload(payload)
+    expect(transformedPayload).toEqual(payload)
+  })
+})
diff --git a/extensions/inference-openai-extension/src/index.ts b/extensions/inference-openai-extension/src/index.ts
index 60446ccce..44c243adf 100644
--- a/extensions/inference-openai-extension/src/index.ts
+++ b/extensions/inference-openai-extension/src/index.ts
@@ -6,16 +6,17 @@
  * @module inference-openai-extension/src/index
  */
 
-import { RemoteOAIEngine, SettingComponentProps } from '@janhq/core'
+import { ModelRuntimeParams, PayloadType, RemoteOAIEngine } from '@janhq/core'
 
 declare const SETTINGS: Array<any>
 declare const MODELS: Array<any>
 
-enum Settings {
+export enum Settings {
   apiKey = 'openai-api-key',
   chatCompletionsEndPoint = 'chat-completions-endpoint',
 }
-
+type OpenAIPayloadType = PayloadType &
+  ModelRuntimeParams & { max_completion_tokens: number }
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
  * The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -24,6 +25,7 @@ enum Settings {
 export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
   inferenceUrl: string = ''
   provider: string = 'openai'
+  previewModels = ['o1-mini', 'o1-preview']
 
   override async onLoad(): Promise<void> {
     super.onLoad()
@@ -63,4 +65,24 @@ export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
       }
     }
   }
+
+  /**
+   * Transform the payload before sending it to the inference endpoint.
+   * New preview models such as o1-mini and o1-preview replace the
+   * max_tokens parameter with max_completion_tokens; other models do not.
+   * @param payload
+   * @returns
+   */
+  transformPayload = (payload: OpenAIPayloadType): OpenAIPayloadType => {
+    // Transform the payload for preview models
+    if (this.previewModels.includes(payload.model)) {
+      const { max_tokens, ...params } = payload
+      return {
+        ...params,
+        max_completion_tokens: max_tokens,
+      }
+    }
+    // Pass through for non-preview models
+    return payload
+  }
 }
diff --git a/extensions/inference-openai-extension/tsconfig.json b/extensions/inference-openai-extension/tsconfig.json
index 2477d58ce..6db951c9e 100644
--- a/extensions/inference-openai-extension/tsconfig.json
+++ b/extensions/inference-openai-extension/tsconfig.json
@@ -10,5 +10,6 @@
     "skipLibCheck": true,
     "rootDir": "./src"
   },
-  "include": ["./src"]
+  "include": ["./src"],
+  "exclude": ["**/*.test.ts"]
 }
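
Review note: the heart of this change is renaming `max_tokens` to `max_completion_tokens` for OpenAI's o1 preview models, which reject the former parameter. The sketch below is a minimal, standalone illustration of that logic for anyone who wants to try it outside the extension; `Payload` and `PREVIEW_MODELS` are local stand-ins for `OpenAIPayloadType` and the class's `previewModels` field, not `@janhq/core` APIs.

```ts
// Local stand-ins: Payload approximates OpenAIPayloadType, and
// PREVIEW_MODELS mirrors the previewModels field added in this diff.
type Payload = {
  model: string
  max_tokens?: number
  max_completion_tokens?: number
}

const PREVIEW_MODELS = ['o1-mini', 'o1-preview']

function transformPayload(payload: Payload): Payload {
  // Non-preview models pass through untouched.
  if (!PREVIEW_MODELS.includes(payload.model)) return payload
  // Preview models: drop max_tokens and carry its value over
  // as max_completion_tokens.
  const { max_tokens, ...rest } = payload
  return { ...rest, max_completion_tokens: max_tokens }
}

console.log(transformPayload({ model: 'o1-mini', max_tokens: 100 }))
// -> { model: 'o1-mini', max_completion_tokens: 100 }
console.log(transformPayload({ model: 'non-preview-model', max_tokens: 100 }))
// -> { model: 'non-preview-model', max_tokens: 100 }
```

One note on jest.config.js for reviewers: Jest skips `node_modules` during transformation by default, so the negative lookahead in `transformIgnorePatterns` (`node_modules/(?!@janhq/core/.*)`) carves `@janhq/core` out of that exclusion while the matching `transform` entry runs it through `ts-jest`; presumably the package ships sources the test runtime cannot consume without transpilation, since both entries target it explicitly.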