feat: Add triton trtllm for engine for remote models
This commit is contained in:
parent
44d4368e55
commit
f2688771ff
@ -174,7 +174,7 @@ export type ThreadState = {
|
|||||||
enum InferenceEngine {
|
enum InferenceEngine {
|
||||||
nitro = "nitro",
|
nitro = "nitro",
|
||||||
openai = "openai",
|
openai = "openai",
|
||||||
nvidia_triton = "nvidia_triton",
|
triton_trtllm = "triton_trtllm",
|
||||||
hf_endpoint = "hf_endpoint",
|
hf_endpoint = "hf_endpoint",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
78
extensions/inference-triton-trtllm-extension/README.md
Normal file
78
extensions/inference-triton-trtllm-extension/README.md
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# Jan inference plugin

Created using Jan app example

# Create a Jan Plugin using Typescript

Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀

## Create Your Own Plugin

To create your own plugin, you can use this repository as a template! Just follow the below instructions:

1. Click the Use this template button at the top of the repository
2. Select Create a new repository
3. Select an owner and name for your new repository
4. Click Create repository
5. Clone your new repository

## Initial Setup

After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin.

> [!NOTE]
>
> You'll need to have a reasonably modern version of
> [Node.js](https://nodejs.org) handy. If you are using a version manager like
> [`nodenv`](https://github.com/nodenv/nodenv) or
> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the
> root of your repository to install the version specified in
> [`package.json`](./package.json). Otherwise, 20.x or later should work!

1. :hammer_and_wrench: Install the dependencies

   ```bash
   npm install
   ```

1. :building_construction: Package the TypeScript for distribution

   ```bash
   npm run bundle
   ```

1. :white_check_mark: Check your artifact

   There will be a tgz file in your plugin directory now

## Update the Plugin Metadata

The [`package.json`](package.json) file defines metadata about your plugin, such as
plugin name, main entry, description and version.

When you copy this repository, update `package.json` with the name, description for your plugin.

## Update the Plugin Code

The [`src/`](./src/) directory is the heart of your plugin! This contains the
source code that will be run when your plugin extension functions are invoked. You can replace the
contents of this directory with your own code.

There are a few things to keep in mind when writing your plugin code:

- Most Jan Plugin Extension functions are processed asynchronously.
  In `index.ts`, you will see that the extension function will return a `Promise<any>`.

  ```typescript
  import { core } from "@janhq/core";

  function onStart(): Promise<any> {
    return core.invokePluginFunc(MODULE_PATH, "run", 0);
  }
  ```

For more information about the Jan Plugin Core module, see the
[documentation](https://github.com/janhq/jan/blob/main/core/README.md).

So, what are you waiting for? Go ahead and start customizing your plugin!
|
||||||
|
|
||||||
41
extensions/inference-triton-trtllm-extension/package.json
Normal file
41
extensions/inference-triton-trtllm-extension/package.json
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
{
  "name": "@janhq/inference-triton-trt-llm-extension",
  "version": "1.0.0",
  "description": "Inference Engine for NVIDIA Triton with TensorRT-LLM Extension integration on Jan extension framework",
  "main": "dist/index.js",
  "module": "dist/module.js",
  "author": "Jan <service@jan.ai>",
  "license": "AGPL-3.0",
  "scripts": {
    "build": "tsc -b . && webpack --config webpack.config.js",
    "build:publish": "rimraf *.tgz --glob && npm run build && npm pack && cpx *.tgz ../../electron/pre-install"
  },
  "exports": {
    ".": "./dist/index.js",
    "./main": "./dist/module.js"
  },
  "devDependencies": {
    "cpx": "^1.5.0",
    "rimraf": "^3.0.2",
    "webpack": "^5.88.2",
    "webpack-cli": "^5.1.4"
  },
  "dependencies": {
    "@janhq/core": "file:../../core",
    "fetch-retry": "^5.0.6",
    "path-browserify": "^1.0.1",
    "ts-loader": "^9.5.0",
    "ulid": "^2.3.0"
  },
  "engines": {
    "node": ">=18.0.0"
  },
  "files": [
    "dist/*",
    "package.json",
    "README.md"
  ],
  "bundleDependencies": [
    "fetch-retry"
  ]
}
|
||||||
7
extensions/inference-triton-trtllm-extension/src/@types/global.d.ts
vendored
Normal file
7
extensions/inference-triton-trtllm-extension/src/@types/global.d.ts
vendored
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
import { Model } from "@janhq/core";
|
||||||
|
|
||||||
|
declare const MODULE: string;
|
||||||
|
|
||||||
|
declare interface EngineSettings {
|
||||||
|
base_url?: string;
|
||||||
|
}
|
||||||
@ -0,0 +1,63 @@
|
|||||||
|
import { Observable } from "rxjs";
|
||||||
|
import { EngineSettings } from "../@types/global";
|
||||||
|
import { Model } from "@janhq/core";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sends a request to the inference server to generate a response based on the recent messages.
|
||||||
|
* @param recentMessages - An array of recent messages to use as context for the inference.
|
||||||
|
* @param engine - The engine settings to use for the inference.
|
||||||
|
* @param model - The model to use for the inference.
|
||||||
|
* @returns An Observable that emits the generated response as a string.
|
||||||
|
*/
|
||||||
|
export function requestInference(
|
||||||
|
recentMessages: any[],
|
||||||
|
engine: EngineSettings,
|
||||||
|
model: Model,
|
||||||
|
controller?: AbortController
|
||||||
|
): Observable<string> {
|
||||||
|
return new Observable((subscriber) => {
|
||||||
|
const text_input = recentMessages.map((message) => message.text).join("\n");
|
||||||
|
const requestBody = JSON.stringify({
|
||||||
|
text_input: text_input,
|
||||||
|
max_tokens: 4096,
|
||||||
|
temperature: 0,
|
||||||
|
bad_words: "",
|
||||||
|
stop_words: "[DONE]",
|
||||||
|
stream: true
|
||||||
|
});
|
||||||
|
fetch(`${engine.base_url}/v2/models/ensemble/generate_stream`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
Accept: "text/event-stream",
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
},
|
||||||
|
body: requestBody,
|
||||||
|
signal: controller?.signal,
|
||||||
|
})
|
||||||
|
.then(async (response) => {
|
||||||
|
const stream = response.body;
|
||||||
|
const decoder = new TextDecoder("utf-8");
|
||||||
|
const reader = stream?.getReader();
|
||||||
|
let content = "";
|
||||||
|
|
||||||
|
while (true && reader) {
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
const text = decoder.decode(value);
|
||||||
|
const lines = text.trim().split("\n");
|
||||||
|
for (const line of lines) {
|
||||||
|
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
|
||||||
|
const data = JSON.parse(line.replace("data: ", ""));
|
||||||
|
content += data.choices[0]?.delta?.content ?? "";
|
||||||
|
subscriber.next(content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
subscriber.complete();
|
||||||
|
})
|
||||||
|
.catch((err) => subscriber.error(err));
|
||||||
|
});
|
||||||
|
}
|
||||||
235
extensions/inference-triton-trtllm-extension/src/index.ts
Normal file
235
extensions/inference-triton-trtllm-extension/src/index.ts
Normal file
@ -0,0 +1,235 @@
|
|||||||
|
/**
|
||||||
|
* @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package.
|
||||||
|
* The class provides methods for initializing and stopping a model, and for making inference requests.
|
||||||
|
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
||||||
|
* @version 1.0.0
|
||||||
|
* @module inference-nvidia-triton-trt-llm-extension/src/index
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {
|
||||||
|
ChatCompletionRole,
|
||||||
|
ContentType,
|
||||||
|
EventName,
|
||||||
|
MessageRequest,
|
||||||
|
MessageStatus,
|
||||||
|
ModelSettingParams,
|
||||||
|
ExtensionType,
|
||||||
|
ThreadContent,
|
||||||
|
ThreadMessage,
|
||||||
|
events,
|
||||||
|
fs,
|
||||||
|
Model,
|
||||||
|
} from "@janhq/core";
|
||||||
|
import { InferenceExtension } from "@janhq/core";
|
||||||
|
import { requestInference } from "./helpers/sse";
|
||||||
|
import { ulid } from "ulid";
|
||||||
|
import { join } from "path";
|
||||||
|
import { EngineSettings } from "./@types/global";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A class that implements the InferenceExtension interface from the @janhq/core package.
|
||||||
|
* The class provides methods for initializing and stopping a model, and for making inference requests.
|
||||||
|
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
||||||
|
*/
|
||||||
|
export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension {
|
||||||
|
private static readonly _homeDir = 'engines'
|
||||||
|
private static readonly _engineMetadataFileName = 'triton_trtllm.json'
|
||||||
|
|
||||||
|
static _currentModel: Model;
|
||||||
|
|
||||||
|
static _engineSettings: EngineSettings = {
|
||||||
|
"base_url": "",
|
||||||
|
};
|
||||||
|
|
||||||
|
controller = new AbortController();
|
||||||
|
isCancelled = false;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the type of the extension.
|
||||||
|
* @returns {ExtensionType} The type of the extension.
|
||||||
|
*/
|
||||||
|
// TODO: To fix
|
||||||
|
type(): ExtensionType {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Subscribes to events emitted by the @janhq/core package.
|
||||||
|
*/
|
||||||
|
onLoad(): void {
|
||||||
|
fs.mkdir(JanInferenceTritonTrtLLMExtension._homeDir)
|
||||||
|
JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings()
|
||||||
|
|
||||||
|
// Events subscription
|
||||||
|
events.on(EventName.OnMessageSent, (data) =>
|
||||||
|
JanInferenceTritonTrtLLMExtension.handleMessageRequest(data, this)
|
||||||
|
);
|
||||||
|
|
||||||
|
events.on(EventName.OnModelInit, (model: Model) => {
|
||||||
|
JanInferenceTritonTrtLLMExtension.handleModelInit(model);
|
||||||
|
});
|
||||||
|
|
||||||
|
events.on(EventName.OnModelStop, (model: Model) => {
|
||||||
|
JanInferenceTritonTrtLLMExtension.handleModelStop(model);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stops the model inference.
|
||||||
|
*/
|
||||||
|
onUnload(): void {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the model with the specified file name.
|
||||||
|
* @param {string} modelId - The ID of the model to initialize.
|
||||||
|
* @returns {Promise<void>} A promise that resolves when the model is initialized.
|
||||||
|
*/
|
||||||
|
async initModel(
|
||||||
|
modelId: string,
|
||||||
|
settings?: ModelSettingParams
|
||||||
|
): Promise<void> {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
static async writeDefaultEngineSettings() {
|
||||||
|
try {
|
||||||
|
const engine_json = join(JanInferenceTritonTrtLLMExtension._homeDir, JanInferenceTritonTrtLLMExtension._engineMetadataFileName)
|
||||||
|
if (await fs.exists(engine_json)) {
|
||||||
|
JanInferenceTritonTrtLLMExtension._engineSettings = JSON.parse(await fs.readFile(engine_json))
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
await fs.writeFile(engine_json, JSON.stringify(JanInferenceTritonTrtLLMExtension._engineSettings, null, 2))
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
console.error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Stops the model.
|
||||||
|
* @returns {Promise<void>} A promise that resolves when the model is stopped.
|
||||||
|
*/
|
||||||
|
async stopModel(): Promise<void> {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stops streaming inference.
|
||||||
|
* @returns {Promise<void>} A promise that resolves when the streaming is stopped.
|
||||||
|
*/
|
||||||
|
async stopInference(): Promise<void> {
|
||||||
|
this.isCancelled = true;
|
||||||
|
this.controller?.abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Makes a single response inference request.
|
||||||
|
* @param {MessageRequest} data - The data for the inference request.
|
||||||
|
* @returns {Promise<any>} A promise that resolves with the inference response.
|
||||||
|
*/
|
||||||
|
async inference(data: MessageRequest): Promise<ThreadMessage> {
|
||||||
|
const timestamp = Date.now();
|
||||||
|
const message: ThreadMessage = {
|
||||||
|
thread_id: data.threadId,
|
||||||
|
created: timestamp,
|
||||||
|
updated: timestamp,
|
||||||
|
status: MessageStatus.Ready,
|
||||||
|
id: "",
|
||||||
|
role: ChatCompletionRole.Assistant,
|
||||||
|
object: "thread.message",
|
||||||
|
content: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
return new Promise(async (resolve, reject) => {
|
||||||
|
requestInference(data.messages ?? [],
|
||||||
|
JanInferenceTritonTrtLLMExtension._engineSettings,
|
||||||
|
JanInferenceTritonTrtLLMExtension._currentModel)
|
||||||
|
.subscribe({
|
||||||
|
next: (_content) => {},
|
||||||
|
complete: async () => {
|
||||||
|
resolve(message);
|
||||||
|
},
|
||||||
|
error: async (err) => {
|
||||||
|
reject(err);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private static async handleModelInit(model: Model) {
|
||||||
|
if (model.engine !== 'triton_trtllm') { return }
|
||||||
|
else {
|
||||||
|
JanInferenceTritonTrtLLMExtension._currentModel = model
|
||||||
|
JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings()
|
||||||
|
// Todo: Check model list with API key
|
||||||
|
events.emit(EventName.OnModelReady, model)
|
||||||
|
// events.emit(EventName.OnModelFail, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static async handleModelStop(model: Model) {
|
||||||
|
if (model.engine !== 'triton_trtllm') { return }
|
||||||
|
events.emit(EventName.OnModelStopped, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles a new message request by making an inference request and emitting events.
|
||||||
|
* Function registered in event manager, should be static to avoid binding issues.
|
||||||
|
* Pass instance as a reference.
|
||||||
|
* @param {MessageRequest} data - The data for the new message request.
|
||||||
|
*/
|
||||||
|
private static async handleMessageRequest(
|
||||||
|
data: MessageRequest,
|
||||||
|
instance: JanInferenceTritonTrtLLMExtension
|
||||||
|
) {
|
||||||
|
if (data.model.engine !== 'triton_trtllm') { return }
|
||||||
|
|
||||||
|
const timestamp = Date.now();
|
||||||
|
const message: ThreadMessage = {
|
||||||
|
id: ulid(),
|
||||||
|
thread_id: data.threadId,
|
||||||
|
assistant_id: data.assistantId,
|
||||||
|
role: ChatCompletionRole.Assistant,
|
||||||
|
content: [],
|
||||||
|
status: MessageStatus.Pending,
|
||||||
|
created: timestamp,
|
||||||
|
updated: timestamp,
|
||||||
|
object: "thread.message",
|
||||||
|
};
|
||||||
|
events.emit(EventName.OnMessageResponse, message);
|
||||||
|
|
||||||
|
instance.isCancelled = false;
|
||||||
|
instance.controller = new AbortController();
|
||||||
|
|
||||||
|
requestInference(
|
||||||
|
data?.messages ?? [],
|
||||||
|
this._engineSettings,
|
||||||
|
JanInferenceTritonTrtLLMExtension._currentModel,
|
||||||
|
instance.controller
|
||||||
|
).subscribe({
|
||||||
|
next: (content) => {
|
||||||
|
const messageContent: ThreadContent = {
|
||||||
|
type: ContentType.Text,
|
||||||
|
text: {
|
||||||
|
value: content.trim(),
|
||||||
|
annotations: [],
|
||||||
|
},
|
||||||
|
};
|
||||||
|
message.content = [messageContent];
|
||||||
|
events.emit(EventName.OnMessageUpdate, message);
|
||||||
|
},
|
||||||
|
complete: async () => {
|
||||||
|
message.status = MessageStatus.Ready;
|
||||||
|
events.emit(EventName.OnMessageUpdate, message);
|
||||||
|
},
|
||||||
|
error: async (err) => {
|
||||||
|
const messageContent: ThreadContent = {
|
||||||
|
type: ContentType.Text,
|
||||||
|
text: {
|
||||||
|
value: "Error occurred: " + err.message,
|
||||||
|
annotations: [],
|
||||||
|
},
|
||||||
|
};
|
||||||
|
message.content = [messageContent];
|
||||||
|
message.status = MessageStatus.Ready;
|
||||||
|
events.emit(EventName.OnMessageUpdate, message);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
15
extensions/inference-triton-trtllm-extension/tsconfig.json
Normal file
15
extensions/inference-triton-trtllm-extension/tsconfig.json
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
{
  "compilerOptions": {
    "target": "es2016",
    "module": "ES6",
    "moduleResolution": "node",
    "outDir": "./dist",
    "esModuleInterop": true,
    "forceConsistentCasingInFileNames": true,
    "strict": false,
    "skipLibCheck": true,
    "rootDir": "./src"
  },
  "include": ["./src"]
}
|
||||||
@ -0,0 +1,38 @@
|
|||||||
|
const path = require("path");
|
||||||
|
const webpack = require("webpack");
|
||||||
|
const packageJson = require("./package.json");
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
experiments: { outputModule: true },
|
||||||
|
entry: "./src/index.ts", // Adjust the entry point to match your project's main file
|
||||||
|
mode: "production",
|
||||||
|
module: {
|
||||||
|
rules: [
|
||||||
|
{
|
||||||
|
test: /\.tsx?$/,
|
||||||
|
use: "ts-loader",
|
||||||
|
exclude: /node_modules/,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
plugins: [
|
||||||
|
new webpack.DefinePlugin({
|
||||||
|
MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
output: {
|
||||||
|
filename: "index.js", // Adjust the output file name as needed
|
||||||
|
path: path.resolve(__dirname, "dist"),
|
||||||
|
library: { type: "module" }, // Specify ESM output format
|
||||||
|
},
|
||||||
|
resolve: {
|
||||||
|
extensions: [".ts", ".js"],
|
||||||
|
fallback: {
|
||||||
|
path: require.resolve("path-browserify"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
optimization: {
|
||||||
|
minimize: false,
|
||||||
|
},
|
||||||
|
// Add loaders and other configuration as needed for your project
|
||||||
|
};
|
||||||
Loading…
x
Reference in New Issue
Block a user