feat: Groq Inference Extension (#2263)

* feat: Groq Inference Extension

* Add Groq supported models

* Fix folder typo

* Add Groq options to interface and new API Key saving, tested working

* Fix linting
This commit is contained in:
Meta Spartan 2024-03-17 17:40:20 -06:00 committed by GitHub
parent 6c81d83d30
commit 0348aa3321
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 631 additions and 9 deletions

View File

@ -118,7 +118,7 @@ const exec = async (command: string): Promise<string> => {
}
export const getEngineConfiguration = async (engineId: string) => {
if (engineId !== 'openai') {
if (engineId !== 'openai' && engineId !== 'groq') {
return undefined
}
const directoryPath = join(getJanDataFolderPath(), 'engines')

View File

@ -18,6 +18,7 @@ export type ModelInfo = {
export enum InferenceEngine {
nitro = 'nitro',
openai = 'openai',
groq = 'groq',
triton_trtllm = 'triton_trtllm',
nitro_tensorrt_llm = 'nitro-tensorrt-llm',

View File

@ -0,0 +1,78 @@
# Jan inference plugin
Created using Jan app example
# Create a Jan Plugin using Typescript
Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀
## Create Your Own Plugin
To create your own plugin, you can use this repository as a template! Just follow the below instructions:
1. Click the Use this template button at the top of the repository
2. Select Create a new repository
3. Select an owner and name for your new repository
4. Click Create repository
5. Clone your new repository
## Initial Setup
After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin.
> [!NOTE]
>
> You'll need to have a reasonably modern version of
> [Node.js](https://nodejs.org) handy. If you are using a version manager like
> [`nodenv`](https://github.com/nodenv/nodenv) or
> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the
> root of your repository to install the version specified in
> [`package.json`](./package.json). Otherwise, 20.x or later should work!
1. :hammer_and_wrench: Install the dependencies
```bash
npm install
```
1. :building_construction: Package the TypeScript for distribution
```bash
npm run bundle
```
1. :white_check_mark: Check your artifact
There will be a tgz file in your plugin directory now
## Update the Plugin Metadata
The [`package.json`](package.json) file defines metadata about your plugin, such as
plugin name, main entry, description and version.
When you copy this repository, update `package.json` with the name, description for your plugin.
## Update the Plugin Code
The [`src/`](./src/) directory is the heart of your plugin! This contains the
source code that will be run when your plugin extension functions are invoked. You can replace the
contents of this directory with your own code.
There are a few things to keep in mind when writing your plugin code:
- Most Jan Plugin Extension functions are processed asynchronously.
In `index.ts`, you will see that the extension function will return a `Promise<any>`.
```typescript
import { core } from "@janhq/core";
function onStart(): Promise<any> {
return core.invokePluginFunc(MODULE_PATH, "run", 0);
}
```
For more information about the Jan Plugin Core module, see the
[documentation](https://github.com/janhq/jan/blob/main/core/README.md).
So, what are you waiting for? Go ahead and start customizing your plugin!

View File

@ -0,0 +1,41 @@
{
"name": "@janhq/inference-groq-extension",
"version": "1.0.0",
"description": "This extension enables fast Groq chat completion API calls",
"main": "dist/index.js",
"module": "dist/module.js",
"author": "Carsen Klock & Jan",
"license": "AGPL-3.0",
"scripts": {
"build": "tsc -b . && webpack --config webpack.config.js",
"build:publish": "rimraf *.tgz --glob && npm run build && npm pack && cpx *.tgz ../../pre-install"
},
"exports": {
".": "./dist/index.js",
"./main": "./dist/module.js"
},
"devDependencies": {
"cpx": "^1.5.0",
"rimraf": "^3.0.2",
"webpack": "^5.88.2",
"webpack-cli": "^5.1.4",
"ts-loader": "^9.5.0"
},
"dependencies": {
"@janhq/core": "file:../../core",
"fetch-retry": "^5.0.6",
"path-browserify": "^1.0.1",
"ulid": "^2.3.0"
},
"engines": {
"node": ">=18.0.0"
},
"files": [
"dist/*",
"package.json",
"README.md"
],
"bundleDependencies": [
"fetch-retry"
]
}

View File

@ -0,0 +1,16 @@
declare const MODULE: string
declare const GROQ_DOMAIN: string
declare interface EngineSettings {
full_url?: string
api_key?: string
}
enum GroqChatCompletionModelName {
'mixtral-8x7b-32768' = 'mixtral-8x7b-32768',
'llama2-70b-4096' = 'llama2-70b-4096',
}
declare type GroqModel = Omit<Model, 'id'> & {
id: GroqChatCompletionModelName
}

View File

@ -0,0 +1,83 @@
import { ErrorCode } from '@janhq/core'
import { Observable } from 'rxjs'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
* @param engine - The engine settings to use for the inference.
* @param model - The model to use for the inference.
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
recentMessages: any[],
engine: EngineSettings,
model: GroqModel,
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
// let model_id: string = model.id
const requestBody = JSON.stringify({
messages: recentMessages,
stream: true,
model: model.id,
...model.parameters,
})
fetch(`${engine.full_url}`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': model.parameters.stream
? 'text/event-stream'
: 'application/json',
'Access-Control-Allow-Origin': '*',
'Authorization': `Bearer ${engine.api_key}`,
// 'api-key': `${engine.api_key}`,
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
if (!response.ok) {
const data = await response.json()
const error = {
message: data.error?.message ?? 'An error occurred.',
code: data.error?.code ?? ErrorCode.Unknown,
}
subscriber.error(error)
subscriber.complete()
return
}
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')
} else {
const stream = response.body
const decoder = new TextDecoder('utf-8')
const reader = stream?.getReader()
let content = ''
while (true && reader) {
const { done, value } = await reader.read()
if (done) {
break
}
const text = decoder.decode(value)
const lines = text.trim().split('\n')
for (const line of lines) {
if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
const data = JSON.parse(line.replace('data: ', ''))
content += data.choices[0]?.delta?.content ?? ''
if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '')
}
subscriber.next(content)
}
}
}
}
subscriber.complete()
})
.catch((err) => subscriber.error(err))
})
}

View File

@ -0,0 +1,224 @@
/**
* @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
* @version 1.0.0
* @module inference-groq-extension/src/index
*/
import {
ChatCompletionRole,
ContentType,
MessageRequest,
MessageStatus,
ThreadContent,
ThreadMessage,
events,
fs,
InferenceEngine,
BaseExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
InferenceEvent,
AppConfigurationEventName,
joinPath,
} from '@janhq/core'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulid'
import { join } from 'path'
/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class JanInferenceGroqExtension extends BaseExtension {
private static readonly _engineDir = 'file://engines'
private static readonly _engineMetadataFileName = 'groq.json'
private static _currentModel: GroqModel
private static _engineSettings: EngineSettings = {
full_url: 'https://api.groq.com/openai/v1/chat/completions',
api_key: 'gsk-<your key here>',
}
controller = new AbortController()
isCancelled = false
/**
* Subscribes to events emitted by the @janhq/core package.
*/
async onLoad() {
if (!(await fs.existsSync(JanInferenceGroqExtension._engineDir))) {
await fs
.mkdirSync(JanInferenceGroqExtension._engineDir)
.catch((err) => console.debug(err))
}
JanInferenceGroqExtension.writeDefaultEngineSettings()
// Events subscription
events.on(MessageEvent.OnMessageSent, (data) =>
JanInferenceGroqExtension.handleMessageRequest(data, this)
)
events.on(ModelEvent.OnModelInit, (model: GroqModel) => {
JanInferenceGroqExtension.handleModelInit(model)
})
events.on(ModelEvent.OnModelStop, (model: GroqModel) => {
JanInferenceGroqExtension.handleModelStop(model)
})
events.on(InferenceEvent.OnInferenceStopped, () => {
JanInferenceGroqExtension.handleInferenceStopped(this)
})
const settingsFilePath = await joinPath([
JanInferenceGroqExtension._engineDir,
JanInferenceGroqExtension._engineMetadataFileName,
])
events.on(
AppConfigurationEventName.OnConfigurationUpdate,
(settingsKey: string) => {
// Update settings on changes
if (settingsKey === settingsFilePath)
JanInferenceGroqExtension.writeDefaultEngineSettings()
}
)
}
/**
* Stops the model inference.
*/
onUnload(): void {}
static async writeDefaultEngineSettings() {
try {
const engineFile = join(
JanInferenceGroqExtension._engineDir,
JanInferenceGroqExtension._engineMetadataFileName
)
if (await fs.existsSync(engineFile)) {
const engine = await fs.readFileSync(engineFile, 'utf-8')
JanInferenceGroqExtension._engineSettings =
typeof engine === 'object' ? engine : JSON.parse(engine)
} else {
await fs.writeFileSync(
engineFile,
JSON.stringify(JanInferenceGroqExtension._engineSettings, null, 2)
)
}
} catch (err) {
console.error(err)
}
}
private static async handleModelInit(model: GroqModel) {
if (model.engine !== InferenceEngine.groq) {
return
} else {
JanInferenceGroqExtension._currentModel = model
JanInferenceGroqExtension.writeDefaultEngineSettings()
// Todo: Check model list with API key
events.emit(ModelEvent.OnModelReady, model)
}
}
private static async handleModelStop(model: GroqModel) {
if (model.engine !== 'groq') {
return
}
events.emit(ModelEvent.OnModelStopped, model)
}
private static async handleInferenceStopped(
instance: JanInferenceGroqExtension
) {
instance.isCancelled = true
instance.controller?.abort()
}
/**
* Handles a new message request by making an inference request and emitting events.
* Function registered in event manager, should be static to avoid binding issues.
* Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request.
*/
private static async handleMessageRequest(
data: MessageRequest,
instance: JanInferenceGroqExtension
) {
if (data.model.engine !== 'groq') {
return
}
const timestamp = Date.now()
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created: timestamp,
updated: timestamp,
object: 'thread.message',
}
if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message)
}
instance.isCancelled = false
instance.controller = new AbortController()
requestInference(
data?.messages ?? [],
this._engineSettings,
{
...JanInferenceGroqExtension._currentModel,
parameters: data.model.parameters,
},
instance.controller
).subscribe({
next: (content) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
}
message.content = [messageContent]
events.emit(MessageEvent.OnMessageUpdate, message)
},
complete: async () => {
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
},
error: async (err) => {
if (instance.isCancelled || message.content.length > 0) {
message.status = MessageStatus.Stopped
events.emit(MessageEvent.OnMessageUpdate, message)
return
}
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: 'An error occurred. ' + err.message,
annotations: [],
},
}
message.content = [messageContent]
message.status = MessageStatus.Error
message.error_code = err.code
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
}
}

View File

@ -0,0 +1,14 @@
{
"compilerOptions": {
"target": "es2016",
"module": "ES6",
"moduleResolution": "node",
"outDir": "./dist",
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"strict": false,
"skipLibCheck": true,
"rootDir": "./src"
},
"include": ["./src"]
}

View File

@ -0,0 +1,39 @@
const path = require('path')
const webpack = require('webpack')
const packageJson = require('./package.json')
module.exports = {
experiments: { outputModule: true },
entry: './src/index.ts', // Adjust the entry point to match your project's main file
mode: 'production',
module: {
rules: [
{
test: /\.tsx?$/,
use: 'ts-loader',
exclude: /node_modules/,
},
],
},
plugins: [
new webpack.DefinePlugin({
MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
GROQ_DOMAIN: JSON.stringify('api.groq.com'),
}),
],
output: {
filename: 'index.js', // Adjust the output file name as needed
path: path.resolve(__dirname, 'dist'),
library: { type: 'module' }, // Specify ESM output format
},
resolve: {
extensions: ['.ts', '.js'],
fallback: {
path: require.resolve('path-browserify'),
},
},
optimization: {
minimize: false,
},
// Add loaders and other configuration as needed for your project
}

View File

@ -0,0 +1,27 @@
{
"sources": [
{
"url": "https://groq.com"
}
],
"id": "llama2-70b-4096",
"object": "model",
"name": "Groq Llama 2 70b",
"version": "1.0",
"description": "Groq Llama 2 70b with supercharged speed!",
"format": "api",
"settings": {},
"parameters": {
"max_tokens": 4096,
"temperature": 0.7,
"top_p": 1,
"stop": null,
"stream": true
},
"metadata": {
"author": "Meta",
"tags": ["General", "Big Context Length"]
},
"engine": "groq"
}

View File

@ -0,0 +1,27 @@
{
"sources": [
{
"url": "https://groq.com"
}
],
"id": "mixtral-8x7b-32768",
"object": "model",
"name": "Groq Mixtral 8x7b Instruct",
"version": "1.0",
"description": "Groq Mixtral 8x7b Instruct is Mixtral with supercharged speed!",
"format": "api",
"settings": {},
"parameters": {
"max_tokens": 4096,
"temperature": 0.7,
"top_p": 1,
"stop": null,
"stream": true
},
"metadata": {
"author": "Mistral",
"tags": ["General", "Big Context Length"]
},
"engine": "groq"
}

View File

@ -78,7 +78,9 @@ const DropdownListSidebar = ({
(model) => model.engine !== InferenceEngine.openai
)
const remoteModel = downloadedModels.filter(
(model) => model.engine === InferenceEngine.openai
(model) =>
model.engine === InferenceEngine.openai ||
model.engine === InferenceEngine.groq
)
const modelOptions = isTabActive === 0 ? localModel : remoteModel

View File

@ -19,16 +19,49 @@ const OpenAiKeyInput: React.FC = () => {
>(undefined)
const { readOpenAISettings, saveOpenAISettings } = useEngineSettings()
const [groqSettings, setGroqSettings] = useState<
{ api_key: string } | undefined
>(undefined)
const { readGroqSettings, saveGroqSettings } = useEngineSettings()
useEffect(() => {
readOpenAISettings().then((settings) => {
setOpenAISettings(settings)
})
}, [readOpenAISettings])
if (!selectedModel || selectedModel.engine !== InferenceEngine.openai) {
useEffect(() => {
readGroqSettings().then((settings) => {
setGroqSettings(settings)
})
}, [readGroqSettings])
if (
!selectedModel ||
(selectedModel.engine !== InferenceEngine.openai &&
selectedModel.engine !== InferenceEngine.groq)
) {
return null
}
const getCurrentApiKey = () => {
if (selectedModel.engine === InferenceEngine.openai) {
return openAISettings?.api_key
} else if (selectedModel.engine === InferenceEngine.groq) {
return groqSettings?.api_key
}
return '' // Default return value
}
const handleApiKeyChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const newApiKey = e.target.value
if (selectedModel.engine === InferenceEngine.openai) {
saveOpenAISettings({ apiKey: newApiKey })
} else if (selectedModel.engine === InferenceEngine.groq) {
saveGroqSettings({ apiKey: newApiKey })
}
}
return (
<div className="my-4">
<label
@ -40,11 +73,9 @@ const OpenAiKeyInput: React.FC = () => {
<Input
disabled={serverEnabled}
id="assistant-instructions"
placeholder="Enter your API_KEY"
defaultValue={openAISettings?.api_key}
onChange={(e) => {
saveOpenAISettings({ apiKey: e.target.value })
}}
placeholder={getCurrentApiKey()}
defaultValue={getCurrentApiKey()}
onChange={handleApiKeyChange}
/>
</div>
)

View File

@ -36,5 +36,43 @@ export const useEngineSettings = () => {
settingFilePath
)
}
return { readOpenAISettings, saveOpenAISettings }
const readGroqSettings = useCallback(async () => {
if (!(await fs.existsSync(await joinPath(['file://engines', 'groq.json']))))
return {}
const settings = await fs.readFileSync(
await joinPath(['file://engines', 'groq.json']),
'utf-8'
)
if (settings) {
return typeof settings === 'object' ? settings : JSON.parse(settings)
}
return {}
}, [])
const saveGroqSettings = async ({
apiKey,
}: {
apiKey: string | undefined
}) => {
const settings = await readGroqSettings()
const settingFilePath = await joinPath(['file://engines', 'groq.json'])
settings.api_key = apiKey
await fs.writeFileSync(settingFilePath, JSON.stringify(settings))
// Sec: Don't attach the settings data to the event
events.emit(
AppConfigurationEventName.OnConfigurationUpdate,
settingFilePath
)
}
return {
readOpenAISettings,
saveOpenAISettings,
readGroqSettings,
saveGroqSettings,
}
}

View File

@ -49,6 +49,7 @@ export default function RowModel(props: RowModelProps) {
const isRemoteModel =
props.data.engine === InferenceEngine.openai ||
props.data.engine === InferenceEngine.groq ||
props.data.engine === InferenceEngine.triton_trtllm
const onModelActionClick = (modelId: string) => {