feat: Add triton trtllm engine for remote models

hiro 2023-12-12 08:54:43 +07:00
parent 44d4368e55
commit f2688771ff
8 changed files with 478 additions and 1 deletion

View File

@ -174,7 +174,7 @@ export type ThreadState = {
enum InferenceEngine {
  nitro = "nitro",
  openai = "openai",
- nvidia_triton = "nvidia_triton",
+ triton_trtllm = "triton_trtllm",
  hf_endpoint = "hf_endpoint",
}
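For context, the handlers added in this commit dispatch on this enum's string value via `model.engine`. A minimal sketch (the id and object shape are placeholders, not from the commit) of a model entry that would be routed to the new engine:

```typescript
// Illustrative only: the extension below compares model.engine against
// the string "triton_trtllm", so a model shaped like this is handled by
// the Triton TensorRT-LLM engine rather than nitro or openai.
const remoteModel = {
  id: "llama2-70b-trtllm", // placeholder id
  engine: "triton_trtllm",
};
```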

View File

@ -0,0 +1,78 @@
# Jan inference plugin
Created using the Jan app example.
# Create a Jan Plugin using TypeScript
Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀
## Create Your Own Plugin
To create your own plugin, you can use this repository as a template! Just follow the instructions below:
1. Click the Use this template button at the top of the repository
2. Select Create a new repository
3. Select an owner and name for your new repository
4. Click Create repository
5. Clone your new repository
## Initial Setup
After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin.
> [!NOTE]
>
> You'll need to have a reasonably modern version of
> [Node.js](https://nodejs.org) handy. If you are using a version manager like
> [`nodenv`](https://github.com/nodenv/nodenv) or
> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the
> root of your repository to install the version specified in
> [`package.json`](./package.json). Otherwise, 20.x or later should work!
1. :hammer_and_wrench: Install the dependencies

   ```bash
   npm install
   ```

1. :building_construction: Package the TypeScript for distribution

   ```bash
   npm run bundle
   ```

1. :white_check_mark: Check your artifact

   A `.tgz` file will now be present in your plugin directory.
## Update the Plugin Metadata
The [`package.json`](package.json) file defines metadata about your plugin, such as the
plugin name, main entry point, description, and version.
When you copy this repository, update `package.json` with the name and description of your plugin.
## Update the Plugin Code
The [`src/`](./src/) directory is the heart of your plugin! This contains the
source code that will be run when your plugin extension functions are invoked. You can replace the
contents of this directory with your own code.
There are a few things to keep in mind when writing your plugin code:
- Most Jan Plugin Extension functions are processed asynchronously.
In `index.ts`, you will see that the extension function returns a `Promise<any>`.
```typescript
import { core } from "@janhq/core";

function onStart(): Promise<any> {
  // Invoke the plugin module's exported `run` function through the core bridge
  return core.invokePluginFunc(MODULE_PATH, "run", 0);
}
```
For more information about the Jan Plugin Core module, see the
[documentation](https://github.com/janhq/jan/blob/main/core/README.md).
So, what are you waiting for? Go ahead and start customizing your plugin!

View File

@ -0,0 +1,41 @@
{
  "name": "@janhq/inference-triton-trt-llm-extension",
  "version": "1.0.0",
  "description": "Inference Engine for NVIDIA Triton with TensorRT-LLM Extension integration on Jan extension framework",
  "main": "dist/index.js",
  "module": "dist/module.js",
  "author": "Jan <service@jan.ai>",
  "license": "AGPL-3.0",
  "scripts": {
    "build": "tsc -b . && webpack --config webpack.config.js",
    "build:publish": "rimraf *.tgz --glob && npm run build && npm pack && cpx *.tgz ../../electron/pre-install"
  },
  "exports": {
    ".": "./dist/index.js",
    "./main": "./dist/module.js"
  },
  "devDependencies": {
    "cpx": "^1.5.0",
    "rimraf": "^3.0.2",
    "webpack": "^5.88.2",
    "webpack-cli": "^5.1.4"
  },
  "dependencies": {
    "@janhq/core": "file:../../core",
    "fetch-retry": "^5.0.6",
    "path-browserify": "^1.0.1",
    "ts-loader": "^9.5.0",
    "ulid": "^2.3.0"
  },
  "engines": {
    "node": ">=18.0.0"
  },
  "files": [
    "dist/*",
    "package.json",
    "README.md"
  ],
  "bundleDependencies": [
    "fetch-retry"
  ]
}

View File

@ -0,0 +1,7 @@
import { Model } from "@janhq/core";

declare const MODULE: string;

declare interface EngineSettings {
  base_url?: string;
}
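For reference, `writeDefaultEngineSettings()` in `src/index.ts` persists an object of this shape to `engines/triton_trtllm.json`. A minimal sketch of populated settings, mirroring the import style used in `sse.ts`; the URL is an assumption for illustration, since the commit defaults `base_url` to an empty string:

```typescript
import { EngineSettings } from "./@types/global";

// Assumed example: point the extension at a local Triton server.
// The commit's persisted default is { "base_url": "" }.
const settings: EngineSettings = {
  base_url: "http://localhost:8000",
};
```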

View File

@ -0,0 +1,63 @@
import { Observable } from "rxjs";
import { EngineSettings } from "../@types/global";
import { Model } from "@janhq/core";
/**
 * Sends a request to the inference server to generate a response based on the recent messages.
 * @param recentMessages - An array of recent messages to use as context for the inference.
 * @param engine - The engine settings to use for the inference.
 * @param model - The model to use for the inference.
 * @returns An Observable that emits the generated response as a string.
 */
export function requestInference(
  recentMessages: any[],
  engine: EngineSettings,
  model: Model,
  controller?: AbortController
): Observable<string> {
  return new Observable((subscriber) => {
    const text_input = recentMessages.map((message) => message.text).join("\n");
    const requestBody = JSON.stringify({
      text_input: text_input,
      max_tokens: 4096,
      temperature: 0,
      bad_words: "",
      stop_words: "[DONE]",
      stream: true,
    });
    fetch(`${engine.base_url}/v2/models/ensemble/generate_stream`, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Accept: "text/event-stream",
        "Access-Control-Allow-Origin": "*",
      },
      body: requestBody,
      signal: controller?.signal,
    })
      .then(async (response) => {
        const stream = response.body;
        const decoder = new TextDecoder("utf-8");
        const reader = stream?.getReader();
        let content = "";
        while (reader) {
          const { done, value } = await reader.read();
          if (done) {
            break;
          }
          // Decode with stream: true so multi-byte characters split across
          // chunks are handled correctly.
          const text = decoder.decode(value, { stream: true });
          const lines = text.trim().split("\n");
          for (const line of lines) {
            if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
              const data = JSON.parse(line.replace("data: ", ""));
              content += data.choices[0]?.delta?.content ?? "";
              subscriber.next(content);
            }
          }
        }
        subscriber.complete();
      })
      .catch((err) => subscriber.error(err));
  });
}
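A minimal usage sketch (not part of the commit) showing how a caller consumes the stream; note that `subscriber.next` emits the accumulated text so far, not a delta. The import path assumes the caller sits in `src/`, and the server address is an assumption:

```typescript
import { Model } from "@janhq/core";
import { requestInference } from "./helpers/sse";

declare const model: Model; // assume a model configured for triton_trtllm

const controller = new AbortController();
requestInference(
  [{ text: "Hello, Triton." }],          // messages expose a `text` field
  { base_url: "http://localhost:8000" }, // assumed Triton endpoint
  model,
  controller
).subscribe({
  next: (text) => console.log(text),     // cumulative response so far
  complete: () => console.log("[done]"),
  error: (err) => console.error(err),
});
// controller.abort() cancels the underlying fetch.
```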

View File

@ -0,0 +1,235 @@
/**
 * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 * @version 1.0.0
 * @module inference-nvidia-triton-trt-llm-extension/src/index
 */
import {
  ChatCompletionRole,
  ContentType,
  EventName,
  MessageRequest,
  MessageStatus,
  ModelSettingParams,
  ExtensionType,
  ThreadContent,
  ThreadMessage,
  events,
  fs,
  Model,
} from "@janhq/core";
import { InferenceExtension } from "@janhq/core";
import { requestInference } from "./helpers/sse";
import { ulid } from "ulid";
import { join } from "path";
import { EngineSettings } from "./@types/global";

/**
 * A class that implements the InferenceExtension interface from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 */
export default class JanInferenceTritonTrtLLMExtension
  implements InferenceExtension
{
  private static readonly _homeDir = "engines";
  private static readonly _engineMetadataFileName = "triton_trtllm.json";

  static _currentModel: Model;
  static _engineSettings: EngineSettings = {
    base_url: "",
  };

  controller = new AbortController();
  isCancelled = false;

  /**
   * Returns the type of the extension.
   * @returns {ExtensionType} The type of the extension.
   */
  // TODO: To fix
  type(): ExtensionType {
    return undefined;
  }

  /**
   * Subscribes to events emitted by the @janhq/core package.
   */
  onLoad(): void {
    fs.mkdir(JanInferenceTritonTrtLLMExtension._homeDir);
    JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings();

    // Events subscription
    events.on(EventName.OnMessageSent, (data) =>
      JanInferenceTritonTrtLLMExtension.handleMessageRequest(data, this)
    );
    events.on(EventName.OnModelInit, (model: Model) => {
      JanInferenceTritonTrtLLMExtension.handleModelInit(model);
    });
    events.on(EventName.OnModelStop, (model: Model) => {
      JanInferenceTritonTrtLLMExtension.handleModelStop(model);
    });
  }

  /**
   * Stops the model inference.
   */
  onUnload(): void {}

  /**
   * Initializes the model with the specified file name.
   * @param {string} modelId - The ID of the model to initialize.
   * @returns {Promise<void>} A promise that resolves when the model is initialized.
   */
  async initModel(
    modelId: string,
    settings?: ModelSettingParams
  ): Promise<void> {
    return;
  }

  static async writeDefaultEngineSettings() {
    try {
      const engine_json = join(
        JanInferenceTritonTrtLLMExtension._homeDir,
        JanInferenceTritonTrtLLMExtension._engineMetadataFileName
      );
      if (await fs.exists(engine_json)) {
        JanInferenceTritonTrtLLMExtension._engineSettings = JSON.parse(
          await fs.readFile(engine_json)
        );
      } else {
        await fs.writeFile(
          engine_json,
          JSON.stringify(JanInferenceTritonTrtLLMExtension._engineSettings, null, 2)
        );
      }
    } catch (err) {
      console.error(err);
    }
  }

  /**
   * Stops the model.
   * @returns {Promise<void>} A promise that resolves when the model is stopped.
   */
  async stopModel(): Promise<void> {}

  /**
   * Stops streaming inference.
   * @returns {Promise<void>} A promise that resolves when the streaming is stopped.
   */
  async stopInference(): Promise<void> {
    this.isCancelled = true;
    this.controller?.abort();
  }

  /**
   * Makes a single response inference request.
   * @param {MessageRequest} data - The data for the inference request.
   * @returns {Promise<ThreadMessage>} A promise that resolves with the inference response.
   */
  async inference(data: MessageRequest): Promise<ThreadMessage> {
    const timestamp = Date.now();
    const message: ThreadMessage = {
      thread_id: data.threadId,
      created: timestamp,
      updated: timestamp,
      status: MessageStatus.Ready,
      id: "",
      role: ChatCompletionRole.Assistant,
      object: "thread.message",
      content: [],
    };

    return new Promise(async (resolve, reject) => {
      requestInference(
        data.messages ?? [],
        JanInferenceTritonTrtLLMExtension._engineSettings,
        JanInferenceTritonTrtLLMExtension._currentModel
      ).subscribe({
        next: (_content) => {},
        complete: async () => {
          resolve(message);
        },
        error: async (err) => {
          reject(err);
        },
      });
    });
  }

  private static async handleModelInit(model: Model) {
    if (model.engine !== "triton_trtllm") {
      return;
    }
    JanInferenceTritonTrtLLMExtension._currentModel = model;
    JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings();
    // Todo: Check model list with API key
    events.emit(EventName.OnModelReady, model);
    // events.emit(EventName.OnModelFail, model)
  }

  private static async handleModelStop(model: Model) {
    if (model.engine !== "triton_trtllm") {
      return;
    }
    events.emit(EventName.OnModelStopped, model);
  }

  /**
   * Handles a new message request by making an inference request and emitting events.
   * Function registered in event manager, should be static to avoid binding issues.
   * Pass instance as a reference.
   * @param {MessageRequest} data - The data for the new message request.
   */
  private static async handleMessageRequest(
    data: MessageRequest,
    instance: JanInferenceTritonTrtLLMExtension
  ) {
    if (data.model.engine !== "triton_trtllm") {
      return;
    }

    const timestamp = Date.now();
    const message: ThreadMessage = {
      id: ulid(),
      thread_id: data.threadId,
      assistant_id: data.assistantId,
      role: ChatCompletionRole.Assistant,
      content: [],
      status: MessageStatus.Pending,
      created: timestamp,
      updated: timestamp,
      object: "thread.message",
    };
    events.emit(EventName.OnMessageResponse, message);

    instance.isCancelled = false;
    instance.controller = new AbortController();

    requestInference(
      data?.messages ?? [],
      this._engineSettings,
      JanInferenceTritonTrtLLMExtension._currentModel,
      instance.controller
    ).subscribe({
      next: (content) => {
        const messageContent: ThreadContent = {
          type: ContentType.Text,
          text: {
            value: content.trim(),
            annotations: [],
          },
        };
        message.content = [messageContent];
        events.emit(EventName.OnMessageUpdate, message);
      },
      complete: async () => {
        message.status = MessageStatus.Ready;
        events.emit(EventName.OnMessageUpdate, message);
      },
      error: async (err) => {
        const messageContent: ThreadContent = {
          type: ContentType.Text,
          text: {
            value: "Error occurred: " + err.message,
            annotations: [],
          },
        };
        message.content = [messageContent];
        message.status = MessageStatus.Ready;
        events.emit(EventName.OnMessageUpdate, message);
      },
    });
  }
}
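For orientation, a hedged sketch of how the rest of the app reaches this extension through the event bus; the payload shape below is inferred from the fields `handleMessageRequest` reads and is otherwise illustrative:

```typescript
import { events, EventName } from "@janhq/core";

// Emitting OnMessageSent with a triton_trtllm model triggers
// handleMessageRequest above, which streams OnMessageUpdate events back.
events.emit(EventName.OnMessageSent, {
  threadId: "thread-1",
  assistantId: "assistant-1",
  messages: [{ text: "Hello" }],
  model: { engine: "triton_trtllm" },
} as any); // cast: abbreviated sketch of a MessageRequest
```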

View File

@ -0,0 +1,15 @@
{
  "compilerOptions": {
    "target": "es2016",
    "module": "ES6",
    "moduleResolution": "node",
    "outDir": "./dist",
    "esModuleInterop": true,
    "forceConsistentCasingInFileNames": true,
    "strict": false,
    "skipLibCheck": true,
    "rootDir": "./src"
  },
  "include": ["./src"]
}

View File

@ -0,0 +1,38 @@
const path = require("path");
const webpack = require("webpack");
const packageJson = require("./package.json");

module.exports = {
  experiments: { outputModule: true },
  entry: "./src/index.ts", // Adjust the entry point to match your project's main file
  mode: "production",
  module: {
    rules: [
      {
        test: /\.tsx?$/,
        use: "ts-loader",
        exclude: /node_modules/,
      },
    ],
  },
  plugins: [
    new webpack.DefinePlugin({
      MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
    }),
  ],
  output: {
    filename: "index.js", // Adjust the output file name as needed
    path: path.resolve(__dirname, "dist"),
    library: { type: "module" }, // Specify ESM output format
  },
  resolve: {
    extensions: [".ts", ".js"],
    fallback: {
      path: require.resolve("path-browserify"),
    },
  },
  optimization: {
    minimize: false,
  },
  // Add loaders and other configuration as needed for your project
};
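One connection worth spelling out: `DefinePlugin` substitutes the bare identifier `MODULE` at build time with the literal `"<name>/<module>"` taken from `package.json`, and the `declare const MODULE: string` in `src/@types/global.d.ts` is what lets TypeScript compile against it. A small sketch of the effect (the logged value is an assumption based on this package's fields):

```typescript
declare const MODULE: string; // replaced by webpack's DefinePlugin at build time

// After bundling, this logs something like
// "@janhq/inference-triton-trt-llm-extension/dist/module.js"
console.log(`module identifier: ${MODULE}`);
```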