import { requestInference } from './helpers/sse'
import { ulid } from 'ulidx'
import { AIEngine } from './AIEngine'
import {
  ChatCompletionRole,
  ContentType,
  InferenceEvent,
  MessageEvent,
  MessageRequest,
  MessageRequestType,
  MessageStatus,
  Model,
  ModelInfo,
  ThreadContent,
  ThreadMessage,
} from '../../../types'
import { events } from '../../events'
/**
* Base OAI Inference Provider
* Applicable to all OAI compatible inference providers
*/
export abstract class OAIEngine extends AIEngine {
  // URL of the remote inference endpoint
  abstract inferenceUrl: string

  // Controller used to abort an in-flight request
  controller = new AbortController()
  isCancelled = false

  // The currently loaded model instance
  loadedModel: Model | undefined

  // Optional hook to transform the request payload before it is sent
  transformPayload?: Function

  // Optional hook to transform the provider response
  transformResponse?: Function

  /**
   * On extension load, subscribe to events.
   */
  override onLoad() {
    super.onLoad()
    events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
      this.inference(data)
    )
    events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
  }

  /**
   * On extension unload
   */
  override onUnload(): void {}

  /**
   * Inference request
   */
  override async inference(data: MessageRequest) {
    if (!data.model?.id) {
      events.emit(MessageEvent.OnMessageResponse, {
        status: MessageStatus.Error,
        content: [
          {
            type: ContentType.Text,
            text: {
              value: 'No model ID provided',
              annotations: [],
            },
          },
        ],
      })
      return
    }
    const timestamp = Date.now() / 1000
    const message: ThreadMessage = {
      id: ulid(),
      thread_id: data.threadId,
      type: data.type,
      assistant_id: data.assistantId,
      role: ChatCompletionRole.Assistant,
      content: [],
      status: MessageStatus.Pending,
      created_at: timestamp,
      completed_at: timestamp,
      object: 'thread.message',
    }

    if (data.type !== MessageRequestType.Summary) {
      events.emit(MessageEvent.OnMessageResponse, message)
    }
    this.isCancelled = false
    this.controller = new AbortController()

    const model: ModelInfo = {
      ...(this.loadedModel ? this.loadedModel : {}),
      ...data.model,
    }

    const header = await this.headers()
    let requestBody = {
      messages: data.messages ?? [],
      model: model.id,
      stream: true,
      ...model.parameters,
    }
    if (this.transformPayload) {
      requestBody = this.transformPayload(requestBody)
    }
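
    // A subclass can assign `transformPayload` to adapt this OpenAI-style body
    // to a provider with a different request schema. A hypothetical example
    // (field names are illustrative only, not a real provider's API):
    //
    //   this.transformPayload = (body: any) => ({
    //     model: body.model,
    //     input: body.messages,
    //     stream: body.stream,
    //   })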
    requestInference(
      this.inferenceUrl,
      requestBody,
      model,
      this.controller,
      header,
      this.transformResponse
    ).subscribe({
      next: (content: any) => {
        const messageContent: ThreadContent = {
          type: ContentType.Text,
          text: {
            value: content.trim(),
            annotations: [],
          },
        }
        message.content = [messageContent]
        events.emit(MessageEvent.OnMessageUpdate, message)
      },
      complete: async () => {
        message.status = message.content.length
          ? MessageStatus.Ready
          : MessageStatus.Error
        events.emit(MessageEvent.OnMessageUpdate, message)
      },
      error: async (err: any) => {
        if (this.isCancelled || message.content.length) {
          message.status = MessageStatus.Stopped
          events.emit(MessageEvent.OnMessageUpdate, message)
          return
        }
        message.status = MessageStatus.Error
        message.content[0] = {
          type: ContentType.Text,
          text: {
            value:
              typeof err.message === 'string'
                ? err.message
                : (JSON.stringify(err.message) ?? err.detail),
            annotations: [],
          },
        }
        message.error_code = err.code
        events.emit(MessageEvent.OnMessageUpdate, message)
      },
    })
  }

  /**
   * Stops the inference.
   */
  override stopInference() {
    this.isCancelled = true
    this.controller?.abort()
  }

  /**
   * Headers for the inference request
   */
  async headers(): Promise<HeadersInit> {
    return {}
  }
}
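
// Usage sketch (hypothetical, for illustration only): a concrete remote
// provider extends this class and supplies the endpoint and auth headers.
// `AIEngine` may declare additional abstract members a real subclass must
// also implement; only the pieces defined in this file are shown.
//
//   class ExampleProviderEngine extends OAIEngine {
//     inferenceUrl = 'https://api.example.com/v1/chat/completions'
//     private apiKey = '<provider-api-key>'
//
//     override async headers(): Promise<HeadersInit> {
//       return { Authorization: `Bearer ${this.apiKey}` }
//     }
//   }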