fix: #3698 - o1 preview models do not work with max_tokens (#3728)

Louis 2024-09-24 16:35:08 +07:00 committed by GitHub
parent 36c1306390
commit acd3be3a2a
4 changed files with 90 additions and 4 deletions


@@ -0,0 +1,9 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
  preset: 'ts-jest',
  testEnvironment: 'node',
  transform: {
    'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
  },
  transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
}
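
This config is what lets the test file below import @janhq/core directly: Jest skips node_modules by default, so the transform entry plus the negative-lookahead transformIgnorePatterns opt that one package back into ts-jest compilation (presumably because @janhq/core ships untranspiled TypeScript sources). A minimal sketch of the import that would otherwise fail to parse inside a test:

// Without the transform above, Jest would try to execute this import
// untranspiled and choke on the package's TypeScript syntax:
import { RemoteOAIEngine } from '@janhq/core'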


@@ -0,0 +1,54 @@
/**
* @jest-environment jsdom
*/
jest.mock('@janhq/core', () => ({
  ...jest.requireActual('@janhq/core/node'),
  RemoteOAIEngine: jest.fn().mockImplementation(() => ({
    onLoad: jest.fn(),
    registerSettings: jest.fn(),
    registerModels: jest.fn(),
    getSetting: jest.fn(),
    onSettingUpdate: jest.fn(),
  })),
}))

import JanInferenceOpenAIExtension, { Settings } from '.'

describe('JanInferenceOpenAIExtension', () => {
  let extension: JanInferenceOpenAIExtension

  beforeEach(() => {
    // @ts-ignore
    extension = new JanInferenceOpenAIExtension()
  })

  it('should initialize with settings and models', async () => {
    await extension.onLoad()
    // Assuming there are some default SETTINGS and MODELS being registered
    expect(extension.apiKey).toBe(undefined)
    expect(extension.inferenceUrl).toBe('')
  })

  it('should transform the payload for preview models', () => {
    const payload: any = {
      max_tokens: 100,
      model: 'o1-mini',
      // Add other required properties...
    }
    const transformedPayload = extension.transformPayload(payload)
    expect(transformedPayload.max_completion_tokens).toBe(payload.max_tokens)
    expect(transformedPayload).not.toHaveProperty('max_tokens')
    expect(transformedPayload).toHaveProperty('max_completion_tokens')
  })

  it('should not transform the payload for non-preview models', () => {
    const payload: any = {
      max_tokens: 100,
      model: 'non-preview-model',
      // Add other required properties...
    }
    const transformedPayload = extension.transformPayload(payload)
    expect(transformedPayload).toEqual(payload)
  })
})
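
One detail worth noting: the jest.mock('@janhq/core', ...) call takes effect even though module imports normally run first, because Jest hoists jest.mock calls to the top of the compiled module. The mocked RemoteOAIEngine base class is therefore already in place when index.ts is evaluated, so new JanInferenceOpenAIExtension() never touches the real engine.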


@@ -6,16 +6,17 @@
* @module inference-openai-extension/src/index
*/
-import { RemoteOAIEngine, SettingComponentProps } from '@janhq/core'
+import { ModelRuntimeParams, PayloadType, RemoteOAIEngine } from '@janhq/core'

declare const SETTINGS: Array<any>
declare const MODELS: Array<any>

-enum Settings {
+export enum Settings {
  apiKey = 'openai-api-key',
  chatCompletionsEndPoint = 'chat-completions-endpoint',
}

+type OpenAIPayloadType = PayloadType &
+  ModelRuntimeParams & { max_completion_tokens: number }

/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -24,6 +25,7 @@ enum Settings {
export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
  inferenceUrl: string = ''
  provider: string = 'openai'
+  previewModels = ['o1-mini', 'o1-preview']

  override async onLoad(): Promise<void> {
    super.onLoad()
@@ -63,4 +65,24 @@ export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
      }
    }
  }

+  /**
+   * Transform the payload before sending it to the inference endpoint.
+   * The new preview models, such as o1-mini and o1-preview, replace the
+   * max_tokens parameter with max_completion_tokens; other models do not.
+   * @param payload
+   * @returns
+   */
+  transformPayload = (payload: OpenAIPayloadType): OpenAIPayloadType => {
+    // Transform the payload for preview models
+    if (this.previewModels.includes(payload.model)) {
+      const { max_tokens, ...params } = payload
+      return {
+        ...params,
+        max_completion_tokens: max_tokens,
+      }
+    }
+    // Pass through for non-preview models
+    return payload
+  }
}
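
A quick sketch (not part of the diff) of what the transform does at the call site; the casts and the gpt-4o model name are only illustrative:

const extension = new JanInferenceOpenAIExtension()
extension.transformPayload({ model: 'o1-mini', max_tokens: 100 } as any)
// => { model: 'o1-mini', max_completion_tokens: 100 }
extension.transformPayload({ model: 'gpt-4o', max_tokens: 100 } as any)
// => { model: 'gpt-4o', max_tokens: 100 }  (returned unchanged)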


@@ -10,5 +10,6 @@
    "skipLibCheck": true,
    "rootDir": "./src"
  },
-  "include": ["./src"]
+  "include": ["./src"],
+  "exclude": ["**/*.test.ts"]
}
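
Excluding **/*.test.ts presumably keeps the new test file out of the extension's compiled output. The tests still run fine, because ts-jest compiles whatever files Jest hands it directly rather than going through the tsconfig include/exclude lists.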