parent 36c1306390
commit acd3be3a2a

extensions/inference-openai-extension/jest.config.js (new file, 9 lines)
@@ -0,0 +1,9 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
  preset: 'ts-jest',
  testEnvironment: 'node',
  transform: {
    'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
  },
  transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
}
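The config above uses the ts-jest preset for the extension's own TypeScript and, through the explicit transform entry plus the negative lookahead in transformIgnorePatterns, also transpiles @janhq/core from node_modules, which Jest skips by default. A rough TypeScript sketch of how that pattern selects files; shouldTransform is an illustrative helper, not Jest's actual API, and Jest's real matching is more involved:

// Illustrative only: paths matching an ignore pattern are NOT transformed,
// and the negative lookahead exempts @janhq/core from that rule.
const ignorePatterns = ['node_modules/(?!@janhq/core/.*)']

const shouldTransform = (filePath: string): boolean =>
  !ignorePatterns.some((pattern) => new RegExp(pattern).test(filePath))

console.log(shouldTransform('node_modules/@janhq/core/dist/index.js')) // true  -> handed to ts-jest
console.log(shouldTransform('node_modules/some-other-dep/index.js'))   // false -> left untouched

The new test spec below relies on this so it can pull real helpers from @janhq/core/node while mocking the engine class.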
@@ -0,0 +1,54 @@
/**
 * @jest-environment jsdom
 */
jest.mock('@janhq/core', () => ({
  ...jest.requireActual('@janhq/core/node'),
  RemoteOAIEngine: jest.fn().mockImplementation(() => ({
    onLoad: jest.fn(),
    registerSettings: jest.fn(),
    registerModels: jest.fn(),
    getSetting: jest.fn(),
    onSettingUpdate: jest.fn(),
  })),
}))
import JanInferenceOpenAIExtension, { Settings } from '.'

describe('JanInferenceOpenAIExtension', () => {
  let extension: JanInferenceOpenAIExtension

  beforeEach(() => {
    // @ts-ignore
    extension = new JanInferenceOpenAIExtension()
  })

  it('should initialize with settings and models', async () => {
    await extension.onLoad()
    // Assuming there are some default SETTINGS and MODELS being registered
    expect(extension.apiKey).toBe(undefined)
    expect(extension.inferenceUrl).toBe('')
  })

  it('should transform the payload for preview models', () => {
    const payload: any = {
      max_tokens: 100,
      model: 'o1-mini',
      // Add other required properties...
    }

    const transformedPayload = extension.transformPayload(payload)
    expect(transformedPayload.max_completion_tokens).toBe(payload.max_tokens)
    expect(transformedPayload).not.toHaveProperty('max_tokens')
    expect(transformedPayload).toHaveProperty('max_completion_tokens')
  })

  it('should not transform the payload for non-preview models', () => {
    const payload: any = {
      max_tokens: 100,
      model: 'non-preview-model',
      // Add other required properties...
    }

    const transformedPayload = extension.transformPayload(payload)
    expect(transformedPayload).toEqual(payload)
  })
})
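The jest.mock call replaces RemoteOAIEngine with a stub, so the spec exercises only the extension's own logic rather than the real engine. Under the same setup, an additional case covering the other preview model could look like this (a hypothetical extra test, not part of the commit):

  it('should also transform the payload for o1-preview', () => {
    const payload: any = { max_tokens: 256, model: 'o1-preview' }

    const transformed = extension.transformPayload(payload)
    expect(transformed.max_completion_tokens).toBe(256)
    expect(transformed).not.toHaveProperty('max_tokens')
  })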
extensions/inference-openai-extension/src/index.ts

@@ -6,16 +6,17 @@
 * @module inference-openai-extension/src/index
 */

-import { RemoteOAIEngine, SettingComponentProps } from '@janhq/core'
+import { ModelRuntimeParams, PayloadType, RemoteOAIEngine } from '@janhq/core'

declare const SETTINGS: Array<any>
declare const MODELS: Array<any>

-enum Settings {
+export enum Settings {
  apiKey = 'openai-api-key',
  chatCompletionsEndPoint = 'chat-completions-endpoint',
}

+type OpenAIPayloadType = PayloadType &
+  ModelRuntimeParams & { max_completion_tokens: number }
/**
 * A class that implements the InferenceExtension interface from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -24,6 +25,7 @@ enum Settings {
export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
  inferenceUrl: string = ''
  provider: string = 'openai'
+  previewModels = ['o1-mini', 'o1-preview']

  override async onLoad(): Promise<void> {
    super.onLoad()
@@ -63,4 +65,24 @@ export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
      }
    }
  }

+  /**
+   * Transform the payload before sending it to the inference endpoint.
+   * The new preview models such as o1-mini and o1-preview replace max_tokens with the max_completion_tokens parameter.
+   * Other models do not.
+   * @param payload
+   * @returns
+   */
+  transformPayload = (payload: OpenAIPayloadType): OpenAIPayloadType => {
+    // Transform the payload for preview models
+    if (this.previewModels.includes(payload.model)) {
+      const { max_tokens, ...params } = payload
+      return {
+        ...params,
+        max_completion_tokens: max_tokens,
+      }
+    }
+    // Pass through for non-preview models
+    return payload
+  }
}
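In short, transformPayload renames max_tokens to max_completion_tokens for the models listed in previewModels and returns every other payload unchanged. A minimal usage sketch, with the payloads trimmed to the relevant fields (a real request carries more properties):

// Continuing with the `extension` instance created in the spec above:
extension.transformPayload({ model: 'o1-mini', max_tokens: 100 } as any)
// -> { model: 'o1-mini', max_completion_tokens: 100 }   (max_tokens is dropped)

extension.transformPayload({ model: 'non-preview-model', max_tokens: 100 } as any)
// -> returned as-is, max_tokens stays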
extensions/inference-openai-extension/tsconfig.json

@@ -10,5 +10,6 @@
    "skipLibCheck": true,
    "rootDir": "./src"
  },
-  "include": ["./src"]
+  "include": ["./src"],
+  "exclude": ["**/*.test.ts"]
}