chore: stop inference event

This commit is contained in:
Louis 2023-12-08 16:38:13 +07:00 committed by hiro
parent ef9dfc9cce
commit b939692187
6 changed files with 32 additions and 95 deletions

View File

@ -18,6 +18,8 @@ export enum EventName {
OnModelStop = "OnModelStop", OnModelStop = "OnModelStop",
/** The `OnModelStopped` event is emitted when a model stopped ok. */ /** The `OnModelStopped` event is emitted when a model stopped ok. */
OnModelStopped = "OnModelStopped", OnModelStopped = "OnModelStopped",
/** The `OnInferenceStopped` event is emitted when an inference is stopped. */
OnInferenceStopped = "OnInferenceStopped",
} }
/** /**

View File

@ -5,26 +5,10 @@ import { BaseExtension } from "../extension";
* Inference extension. Start, stop and inference models. * Inference extension. Start, stop and inference models.
*/ */
export abstract class InferenceExtension extends BaseExtension { export abstract class InferenceExtension extends BaseExtension {
/**
* Initializes the model for the extension.
* @param modelId - The ID of the model to initialize.
*/
abstract initModel(modelId: string, settings?: ModelSettingParams): Promise<void>;
/**
* Stops the model for the extension.
*/
abstract stopModel(): Promise<void>;
/**
* Stops the streaming inference.
*/
abstract stopInference(): Promise<void>;
/** /**
* Processes an inference request. * Processes an inference request.
* @param data - The data for the inference request. * @param data - The data for the inference request.
* @returns The result of the inference request. * @returns The result of the inference request.
*/ */
abstract inferenceRequest(data: MessageRequest): Promise<ThreadMessage>; abstract inference(data: MessageRequest): Promise<ThreadMessage>;
} }

View File

@ -74,41 +74,17 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
events.on(EventName.OnModelStop, (model: Model) => { events.on(EventName.OnModelStop, (model: Model) => {
JanInferenceNitroExtension.handleModelStop(model); JanInferenceNitroExtension.handleModelStop(model);
}); });
events.on(EventName.OnInferenceStopped, () => {
JanInferenceNitroExtension.handleInferenceStopped(this);
});
} }
/** /**
* Stops the model inference. * Stops the model inference.
*/ */
onUnload(): void { onUnload(): void {}
this.stopModel();
}
/**
* Initializes the model with the specified file name.
* @param {string} modelId - The ID of the model to initialize.
* @returns {Promise<void>} A promise that resolves when the model is initialized.
*/
async initModel(
modelId: string,
settings?: ModelSettingParams
): Promise<void> {}
/**
* Stops the model.
* @returns {Promise<void>} A promise that resolves when the model is stopped.
*/
async stopModel(): Promise<void> {
return executeOnMain(MODULE, "killSubprocess");
}
/**
* Stops streaming inference.
* @returns {Promise<void>} A promise that resolves when the streaming is stopped.
*/
async stopInference(): Promise<void> {
this.isCancelled = true;
this.controller?.abort();
}
private async writeDefaultEngineSettings() { private async writeDefaultEngineSettings() {
try { try {
@ -160,12 +136,19 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
} }
} }
private static async handleInferenceStopped(
instance: JanInferenceNitroExtension
) {
instance.isCancelled = true;
instance.controller?.abort();
}
/** /**
* Makes a single response inference request. * Makes a single response inference request.
* @param {MessageRequest} data - The data for the inference request. * @param {MessageRequest} data - The data for the inference request.
* @returns {Promise<any>} A promise that resolves with the inference response. * @returns {Promise<any>} A promise that resolves with the inference response.
*/ */
async inferenceRequest(data: MessageRequest): Promise<ThreadMessage> { async inference(data: MessageRequest): Promise<ThreadMessage> {
const timestamp = Date.now(); const timestamp = Date.now();
const message: ThreadMessage = { const message: ThreadMessage = {
thread_id: data.threadId, thread_id: data.threadId,

View File

@ -71,6 +71,9 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
events.on(EventName.OnModelStop, (model: OpenAIModel) => { events.on(EventName.OnModelStop, (model: OpenAIModel) => {
JanInferenceOpenAIExtension.handleModelStop(model); JanInferenceOpenAIExtension.handleModelStop(model);
}); });
events.on(EventName.OnInferenceStopped, () => {
JanInferenceOpenAIExtension.handleInferenceStopped(this);
});
} }
/** /**
@ -78,18 +81,6 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
*/ */
onUnload(): void {} onUnload(): void {}
/**
* Initializes the model with the specified file name.
* @param {string} modelId - The ID of the model to initialize.
* @returns {Promise<void>} A promise that resolves when the model is initialized.
*/
async initModel(
modelId: string,
settings?: ModelSettingParams
): Promise<void> {
return;
}
static async writeDefaultEngineSettings() { static async writeDefaultEngineSettings() {
try { try {
const engineFile = join( const engineFile = join(
@ -110,27 +101,13 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
console.error(err); console.error(err);
} }
} }
/**
* Stops the model.
* @returns {Promise<void>} A promise that resolves when the model is stopped.
*/
async stopModel(): Promise<void> {}
/**
* Stops streaming inference.
* @returns {Promise<void>} A promise that resolves when the streaming is stopped.
*/
async stopInference(): Promise<void> {
this.isCancelled = true;
this.controller?.abort();
}
/** /**
* Makes a single response inference request. * Makes a single response inference request.
* @param {MessageRequest} data - The data for the inference request. * @param {MessageRequest} data - The data for the inference request.
* @returns {Promise<any>} A promise that resolves with the inference response. * @returns {Promise<any>} A promise that resolves with the inference response.
*/ */
async inferenceRequest(data: MessageRequest): Promise<ThreadMessage> { async inference(data: MessageRequest): Promise<ThreadMessage> {
const timestamp = Date.now(); const timestamp = Date.now();
const message: ThreadMessage = { const message: ThreadMessage = {
thread_id: data.threadId, thread_id: data.threadId,
@ -168,7 +145,6 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
JanInferenceOpenAIExtension.writeDefaultEngineSettings(); JanInferenceOpenAIExtension.writeDefaultEngineSettings();
// Todo: Check model list with API key // Todo: Check model list with API key
events.emit(EventName.OnModelReady, model); events.emit(EventName.OnModelReady, model);
// events.emit(EventName.OnModelFail, model)
} }
} }
@ -179,6 +155,13 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
events.emit(EventName.OnModelStopped, model); events.emit(EventName.OnModelStopped, model);
} }
private static async handleInferenceStopped(
instance: JanInferenceOpenAIExtension
) {
instance.isCancelled = true;
instance.controller?.abort();
}
/** /**
* Handles a new message request by making an inference request and emitting events. * Handles a new message request by making an inference request and emitting events.
* Function registered in event manager, should be static to avoid binding issues. * Function registered in event manager, should be static to avoid binding issues.

View File

@ -30,9 +30,8 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
const { resendChatMessage } = useSendChatMessage() const { resendChatMessage } = useSendChatMessage()
const onStopInferenceClick = async () => { const onStopInferenceClick = async () => {
await extensionManager events.emit(EventName.OnInferenceStopped, {})
.get<InferenceExtension>(ExtensionType.Inference)
?.stopInference()
setTimeout(() => { setTimeout(() => {
events.emit(EventName.OnMessageUpdate, { events.emit(EventName.OnMessageUpdate, {
...message, ...message,

View File

@ -55,23 +55,9 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null
let downloadButton; let downloadButton = (
<Button onClick={() => onDownloadClick()}>Download</Button>
if (model.engine === 'openai') { )
downloadButton = (
<Button onClick={() => onDownloadClick()}>
Use
</Button>
);
} else if (model.engine === 'nitro') {
downloadButton = (
<Button onClick={() => onDownloadClick()}>
{model.metadata.size
? `Download (${toGigabytes(model.metadata.size)})`
: 'Download'}
</Button>
);
}
const onUseModelClick = () => { const onUseModelClick = () => {
startModel(model.id) startModel(model.id)