feat: Add triton trtllm engine for remote models
commit f2688771ff (parent 44d4368e55)
@@ -174,7 +174,7 @@ export type ThreadState = {
enum InferenceEngine {
  nitro = "nitro",
  openai = "openai",
  nvidia_triton = "nvidia_triton",
  triton_trtllm = "triton_trtllm",
  hf_endpoint = "hf_endpoint",
}
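A model opts into this engine through its `engine` field — the same value that `handleModelInit` in `index.ts` below checks against. A hypothetical model definition sketch (field names other than `engine` are illustrative):

```json
{
  "id": "llama2-trtllm",
  "engine": "triton_trtllm"
}
```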
extensions/inference-triton-trtllm-extension/README.md — new file, 78 lines
@@ -0,0 +1,78 @@
# Jan inference plugin

Created using Jan app example

# Create a Jan Plugin using Typescript

Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀

## Create Your Own Plugin

To create your own plugin, you can use this repository as a template! Just follow the instructions below:

1. Click the Use this template button at the top of the repository
2. Select Create a new repository
3. Select an owner and name for your new repository
4. Click Create repository
5. Clone your new repository

## Initial Setup

After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin.

> [!NOTE]
>
> You'll need to have a reasonably modern version of
> [Node.js](https://nodejs.org) handy. If you are using a version manager like
> [`nodenv`](https://github.com/nodenv/nodenv) or
> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the
> root of your repository to install the version specified in
> [`package.json`](./package.json). Otherwise, 20.x or later should work!

1. :hammer_and_wrench: Install the dependencies

   ```bash
   npm install
   ```

1. :building_construction: Package the TypeScript for distribution

   ```bash
   npm run bundle
   ```

1. :white_check_mark: Check your artifact

   There will be a tgz file in your plugin directory now

## Update the Plugin Metadata

The [`package.json`](package.json) file defines metadata about your plugin, such as
plugin name, main entry, description and version.

When you copy this repository, update `package.json` with the name and description for your plugin.

## Update the Plugin Code

The [`src/`](./src/) directory is the heart of your plugin! This contains the
source code that will be run when your plugin extension functions are invoked. You can replace the
contents of this directory with your own code.

There are a few things to keep in mind when writing your plugin code:

- Most Jan Plugin Extension functions are processed asynchronously.
  In `index.ts`, you will see that the extension function will return a `Promise<any>`.

  ```typescript
  import { core } from "@janhq/core";

  function onStart(): Promise<any> {
    return core.invokePluginFunc(MODULE_PATH, "run", 0);
  }
  ```

  For more information about the Jan Plugin Core module, see the
  [documentation](https://github.com/janhq/jan/blob/main/core/README.md).

So, what are you waiting for? Go ahead and start customizing your plugin!
extensions/inference-triton-trtllm-extension/package.json — new file, 41 lines
@@ -0,0 +1,41 @@
{
  "name": "@janhq/inference-triton-trt-llm-extension",
  "version": "1.0.0",
  "description": "Inference Engine for NVIDIA Triton with TensorRT-LLM Extension integration on Jan extension framework",
  "main": "dist/index.js",
  "module": "dist/module.js",
  "author": "Jan <service@jan.ai>",
  "license": "AGPL-3.0",
  "scripts": {
    "build": "tsc -b . && webpack --config webpack.config.js",
    "build:publish": "rimraf *.tgz --glob && npm run build && npm pack && cpx *.tgz ../../electron/pre-install"
  },
  "exports": {
    ".": "./dist/index.js",
    "./main": "./dist/module.js"
  },
  "devDependencies": {
    "cpx": "^1.5.0",
    "rimraf": "^3.0.2",
    "webpack": "^5.88.2",
    "webpack-cli": "^5.1.4"
  },
  "dependencies": {
    "@janhq/core": "file:../../core",
    "fetch-retry": "^5.0.6",
    "path-browserify": "^1.0.1",
    "ts-loader": "^9.5.0",
    "ulid": "^2.3.0"
  },
  "engines": {
    "node": ">=18.0.0"
  },
  "files": [
    "dist/*",
    "package.json",
    "README.md"
  ],
  "bundleDependencies": [
    "fetch-retry"
  ]
}
extensions/inference-triton-trtllm-extension/src/@types/global.d.ts — new file, 7 lines (vendored)
@@ -0,0 +1,7 @@
import { Model } from "@janhq/core";

// Injected at build time by webpack's DefinePlugin (see webpack.config.js below).
declare const MODULE: string;

declare interface EngineSettings {
  base_url?: string;
}
extensions/inference-triton-trtllm-extension/src/helpers/sse.ts — new file, 63 lines
@@ -0,0 +1,63 @@
import { Observable } from "rxjs";
import { EngineSettings } from "../@types/global";
import { Model } from "@janhq/core";

/**
 * Sends a request to the inference server to generate a response based on the recent messages.
 * @param recentMessages - An array of recent messages to use as context for the inference.
 * @param engine - The engine settings to use for the inference.
 * @param model - The model to use for the inference.
 * @returns An Observable that emits the generated response as a string.
 */
export function requestInference(
  recentMessages: any[],
  engine: EngineSettings,
  model: Model,
  controller?: AbortController
): Observable<string> {
  return new Observable((subscriber) => {
    const text_input = recentMessages.map((message) => message.text).join("\n");
    const requestBody = JSON.stringify({
      text_input: text_input,
      max_tokens: 4096,
      temperature: 0,
      bad_words: "",
      stop_words: "[DONE]",
      stream: true
    });
    fetch(`${engine.base_url}/v2/models/ensemble/generate_stream`, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Accept: "text/event-stream",
        "Access-Control-Allow-Origin": "*",
      },
      body: requestBody,
      signal: controller?.signal,
    })
      .then(async (response) => {
        const stream = response.body;
        const decoder = new TextDecoder("utf-8");
        const reader = stream?.getReader();
        let content = "";

        while (reader) {
          const { done, value } = await reader.read();
          if (done) {
            break;
          }
          const text = decoder.decode(value);
          const lines = text.trim().split("\n");
          for (const line of lines) {
            // Parses OpenAI-style SSE lines: `data: {...choices[0].delta.content}`.
            if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
              const data = JSON.parse(line.replace("data: ", ""));
              content += data.choices[0]?.delta?.content ?? "";
              subscriber.next(content);
            }
          }
        }
        subscriber.complete();
      })
      .catch((err) => subscriber.error(err));
  });
}
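For reference, the parsing loop above only consumes OpenAI-style SSE lines, so an accepted chunk would look like this (illustrative payload):

```
data: {"choices":[{"delta":{"content":"Hello"}}]}
```

Note that if the server instead emits Triton's native `generate_stream` shape (e.g. a `text_output` field), this parser would skip those lines.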
extensions/inference-triton-trtllm-extension/src/index.ts — new file, 235 lines
@@ -0,0 +1,235 @@
/**
 * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 * @version 1.0.0
 * @module inference-nvidia-triton-trt-llm-extension/src/index
 */

import {
  ChatCompletionRole,
  ContentType,
  EventName,
  MessageRequest,
  MessageStatus,
  ModelSettingParams,
  ExtensionType,
  ThreadContent,
  ThreadMessage,
  events,
  fs,
  Model,
} from "@janhq/core";
import { InferenceExtension } from "@janhq/core";
import { requestInference } from "./helpers/sse";
import { ulid } from "ulid";
import { join } from "path";
import { EngineSettings } from "./@types/global";

/**
 * A class that implements the InferenceExtension interface from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 */
export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension {
  private static readonly _homeDir = 'engines'
  private static readonly _engineMetadataFileName = 'triton_trtllm.json'

  static _currentModel: Model;

  static _engineSettings: EngineSettings = {
    "base_url": "",
  };

  controller = new AbortController();
  isCancelled = false;

  /**
   * Returns the type of the extension.
   * @returns {ExtensionType} The type of the extension.
   */
  // TODO: To fix
  type(): ExtensionType {
    return undefined;
  }

  /**
   * Subscribes to events emitted by the @janhq/core package.
   */
  onLoad(): void {
    fs.mkdir(JanInferenceTritonTrtLLMExtension._homeDir)
    JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings()

    // Events subscription
    events.on(EventName.OnMessageSent, (data) =>
      JanInferenceTritonTrtLLMExtension.handleMessageRequest(data, this)
    );

    events.on(EventName.OnModelInit, (model: Model) => {
      JanInferenceTritonTrtLLMExtension.handleModelInit(model);
    });

    events.on(EventName.OnModelStop, (model: Model) => {
      JanInferenceTritonTrtLLMExtension.handleModelStop(model);
    });
  }

  /**
   * Stops the model inference.
   */
  onUnload(): void {}

  /**
   * Initializes the model with the specified file name.
   * @param {string} modelId - The ID of the model to initialize.
   * @returns {Promise<void>} A promise that resolves when the model is initialized.
   */
  async initModel(
    modelId: string,
    settings?: ModelSettingParams
  ): Promise<void> {
    return
  }

  static async writeDefaultEngineSettings() {
    try {
      const engine_json = join(JanInferenceTritonTrtLLMExtension._homeDir, JanInferenceTritonTrtLLMExtension._engineMetadataFileName)
      if (await fs.exists(engine_json)) {
        JanInferenceTritonTrtLLMExtension._engineSettings = JSON.parse(await fs.readFile(engine_json))
      }
      else {
        await fs.writeFile(engine_json, JSON.stringify(JanInferenceTritonTrtLLMExtension._engineSettings, null, 2))
      }
    } catch (err) {
      console.error(err)
    }
  }

  /**
   * Stops the model.
   * @returns {Promise<void>} A promise that resolves when the model is stopped.
   */
  async stopModel(): Promise<void> {}

  /**
   * Stops streaming inference.
   * @returns {Promise<void>} A promise that resolves when the streaming is stopped.
   */
  async stopInference(): Promise<void> {
    this.isCancelled = true;
    this.controller?.abort();
  }

  /**
   * Makes a single response inference request.
   * @param {MessageRequest} data - The data for the inference request.
   * @returns {Promise<any>} A promise that resolves with the inference response.
   */
  async inference(data: MessageRequest): Promise<ThreadMessage> {
    const timestamp = Date.now();
    const message: ThreadMessage = {
      thread_id: data.threadId,
      created: timestamp,
      updated: timestamp,
      status: MessageStatus.Ready,
      id: "",
      role: ChatCompletionRole.Assistant,
      object: "thread.message",
      content: [],
    };

    return new Promise(async (resolve, reject) => {
      requestInference(data.messages ?? [],
        JanInferenceTritonTrtLLMExtension._engineSettings,
        JanInferenceTritonTrtLLMExtension._currentModel)
        .subscribe({
          next: (_content) => {},
          complete: async () => {
            resolve(message);
          },
          error: async (err) => {
            reject(err);
          },
        });
    });
  }

  private static async handleModelInit(model: Model) {
    if (model.engine !== 'triton_trtllm') { return }
    else {
      JanInferenceTritonTrtLLMExtension._currentModel = model
      JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings()
      // Todo: Check model list with API key
      events.emit(EventName.OnModelReady, model)
      // events.emit(EventName.OnModelFail, model)
    }
  }

  private static async handleModelStop(model: Model) {
    if (model.engine !== 'triton_trtllm') { return }
    events.emit(EventName.OnModelStopped, model)
  }

  /**
   * Handles a new message request by making an inference request and emitting events.
   * Function registered in event manager, should be static to avoid binding issues.
   * Pass instance as a reference.
   * @param {MessageRequest} data - The data for the new message request.
   */
  private static async handleMessageRequest(
    data: MessageRequest,
    instance: JanInferenceTritonTrtLLMExtension
  ) {
    if (data.model.engine !== 'triton_trtllm') { return }

    const timestamp = Date.now();
    const message: ThreadMessage = {
      id: ulid(),
      thread_id: data.threadId,
      assistant_id: data.assistantId,
      role: ChatCompletionRole.Assistant,
      content: [],
      status: MessageStatus.Pending,
      created: timestamp,
      updated: timestamp,
      object: "thread.message",
    };
    events.emit(EventName.OnMessageResponse, message);

    instance.isCancelled = false;
    instance.controller = new AbortController();

    requestInference(
      data?.messages ?? [],
      this._engineSettings,
      JanInferenceTritonTrtLLMExtension._currentModel,
      instance.controller
    ).subscribe({
      next: (content) => {
        const messageContent: ThreadContent = {
          type: ContentType.Text,
          text: {
            value: content.trim(),
            annotations: [],
          },
        };
        message.content = [messageContent];
        events.emit(EventName.OnMessageUpdate, message);
      },
      complete: async () => {
        message.status = MessageStatus.Ready;
        events.emit(EventName.OnMessageUpdate, message);
      },
      error: async (err) => {
        const messageContent: ThreadContent = {
          type: ContentType.Text,
          text: {
            value: "Error occurred: " + err.message,
            annotations: [],
          },
        };
        message.content = [messageContent];
        message.status = MessageStatus.Ready;
        events.emit(EventName.OnMessageUpdate, message);
      },
    });
  }
}
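For reference, the first run of `writeDefaultEngineSettings` above serializes `_engineSettings` into `engines/triton_trtllm.json`, so the default file is:

```json
{
  "base_url": ""
}
```

On subsequent loads the file is read back instead, so setting `base_url` there is how the extension gets pointed at a running Triton endpoint.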
extensions/inference-triton-trtllm-extension/tsconfig.json — new file, 15 lines
@@ -0,0 +1,15 @@
{
  "compilerOptions": {
    "target": "es2016",
    "module": "ES6",
    "moduleResolution": "node",

    "outDir": "./dist",
    "esModuleInterop": true,
    "forceConsistentCasingInFileNames": true,
    "strict": false,
    "skipLibCheck": true,
    "rootDir": "./src"
  },
  "include": ["./src"]
}
extensions/inference-triton-trtllm-extension/webpack.config.js — new file, 38 lines
@@ -0,0 +1,38 @@
const path = require("path");
const webpack = require("webpack");
const packageJson = require("./package.json");

module.exports = {
  experiments: { outputModule: true },
  entry: "./src/index.ts", // Adjust the entry point to match your project's main file
  mode: "production",
  module: {
    rules: [
      {
        test: /\.tsx?$/,
        use: "ts-loader",
        exclude: /node_modules/,
      },
    ],
  },
  plugins: [
    new webpack.DefinePlugin({
      MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
    }),
  ],
  output: {
    filename: "index.js", // Adjust the output file name as needed
    path: path.resolve(__dirname, "dist"),
    library: { type: "module" }, // Specify ESM output format
  },
  resolve: {
    extensions: [".ts", ".js"],
    fallback: {
      path: require.resolve("path-browserify"),
    },
  },
  optimization: {
    minimize: false,
  },
  // Add loaders and other configuration as needed for your project
};
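With the `package.json` above, `DefinePlugin` resolves `MODULE` to the string `"@janhq/inference-triton-trt-llm-extension/dist/module.js"`, matching the `declare const MODULE: string` in `global.d.ts`.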