feat: cohere remote API extension (#2785)

* fix core

* add cohere extension

* add cohere response customizable

* nitpicking

* use transformResponse

* Update extensions/inference-cohere-extension/src/index.ts

Co-authored-by: Louis <louis@jan.ai>

* use prettier

* Update extensions/inference-cohere-extension/src/index.ts

Co-authored-by: Louis <louis@jan.ai>

* pass requestBody as object

* transformPayload as a property

* This is not correct. CHATBOT is an equivalent role to assistant.
system message should be used with the preamble parameter and should not be included in the chat_history

---------

Co-authored-by: Jack Tri Le <Jack>
Co-authored-by: Louis <louis@jan.ai>
This commit is contained in:
Inchoker 2024-04-24 18:16:57 +07:00 committed by GitHub
parent 3c294d6a48
commit 96abd533c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 368 additions and 14 deletions

View File

@ -31,6 +31,12 @@ export abstract class OAIEngine extends AIEngine {
// The loaded model instance
loadedModel: Model | undefined
// Transform the payload
transformPayload?: Function
// Transform the response
transformResponse?: Function
/**
* On extension load, subscribe to events.
*/
@ -78,13 +84,23 @@ export abstract class OAIEngine extends AIEngine {
}
const header = await this.headers()
let requestBody = {
messages: data.messages ?? [],
model: model.id,
stream: true,
...model.parameters,
}
if (this.transformPayload) {
requestBody = this.transformPayload(requestBody)
}
requestInference(
this.inferenceUrl,
data.messages ?? [],
requestBody,
model,
this.controller,
header
header,
this.transformResponse
).subscribe({
next: (content: any) => {
const messageContent: ThreadContent = {

View File

@ -7,21 +7,16 @@ import { ErrorCode, ModelRuntimeParams } from '../../../../types'
*/
export function requestInference(
inferenceUrl: string,
recentMessages: any[],
requestBody: any,
model: {
id: string
parameters: ModelRuntimeParams
},
controller?: AbortController,
headers?: HeadersInit
headers?: HeadersInit,
transformResponse?: Function
): Observable<string> {
return new Observable((subscriber) => {
const requestBody = JSON.stringify({
messages: recentMessages,
model: model.id,
stream: true,
...model.parameters,
})
fetch(inferenceUrl, {
method: 'POST',
headers: {
@ -30,17 +25,17 @@ export function requestInference(
'Accept': model.parameters.stream ? 'text/event-stream' : 'application/json',
...headers,
},
body: requestBody,
body: JSON.stringify(requestBody),
signal: controller?.signal,
})
.then(async (response) => {
if (!response.ok) {
const data = await response.json()
let errorCode = ErrorCode.Unknown;
let errorCode = ErrorCode.Unknown
if (data.error) {
errorCode = data.error.code ?? data.error.type ?? ErrorCode.Unknown
} else if (response.status === 401) {
errorCode = ErrorCode.InvalidApiKey;
errorCode = ErrorCode.InvalidApiKey
}
const error = {
message: data.error?.message ?? 'Error occurred.',
@ -52,7 +47,11 @@ export function requestInference(
}
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')
if (transformResponse) {
subscriber.next(transformResponse(data))
} else {
subscriber.next(data.choices[0]?.message?.content ?? '')
}
} else {
const stream = response.body
const decoder = new TextDecoder('utf-8')

View File

@ -1,3 +1,5 @@
import { ChatCompletionMessage } from '../inference'
/**
* Native Route APIs
* @description Enum of all the routes exposed by the app
@ -154,3 +156,8 @@ export const APIEvents = [
...Object.values(DownloadEvent),
...Object.values(LocalImportModelEvent),
]
export type PayloadType = {
messages: ChatCompletionMessage[]
model: string
stream: Boolean
}

View File

@ -0,0 +1,79 @@
# Cohere Engine Extension
Created using Jan extension example
# Create a Jan Extension using Typescript
Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀
## Create Your Own Extension
To create your own extension, you can use this repository as a template! Just follow the below instructions:
1. Click the Use this template button at the top of the repository
2. Select Create a new repository
3. Select an owner and name for your new repository
4. Click Create repository
5. Clone your new repository
## Initial Setup
After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension.
> [!NOTE]
>
> You'll need to have a reasonably modern version of
> [Node.js](https://nodejs.org) handy. If you are using a version manager like
> [`nodenv`](https://github.com/nodenv/nodenv) or
> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the
> root of your repository to install the version specified in
> [`package.json`](./package.json). Otherwise, 20.x or later should work!
1. :hammer_and_wrench: Install the dependencies
```bash
npm install
```
1. :building_construction: Package the TypeScript for distribution
```bash
npm run bundle
```
1. :white_check_mark: Check your artifact
There will be a tgz file in your extension directory now
## Update the Extension Metadata
The [`package.json`](package.json) file defines metadata about your extension, such as
extension name, main entry, description and version.
When you copy this repository, update `package.json` with the name, description for your extension.
## Update the Extension Code
The [`src/`](./src/) directory is the heart of your extension! This contains the
source code that will be run when your extension functions are invoked. You can replace the
contents of this directory with your own code.
There are a few things to keep in mind when writing your extension code:
- Most Jan Extension functions are processed asynchronously.
In `index.ts`, you will see that the extension function will return a `Promise<any>`.
```typescript
import { events, MessageEvent, MessageRequest } from '@janhq/core'
function onStart(): Promise<any> {
return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
this.inference(data)
)
}
```
For more information about the Jan Extension Core module, see the
[documentation](https://github.com/janhq/jan/blob/main/core/README.md).
So, what are you waiting for? Go ahead and start customizing your extension!

View File

@ -0,0 +1,43 @@
{
"name": "@janhq/inference-cohere-extension",
"productName": "Cohere Inference Engine",
"version": "1.0.0",
"description": "This extension enables Cohere chat completion API calls",
"main": "dist/index.js",
"module": "dist/module.js",
"engine": "cohere",
"author": "Jan <service@jan.ai>",
"license": "AGPL-3.0",
"scripts": {
"build": "tsc -b . && webpack --config webpack.config.js",
"build:publish": "rimraf *.tgz --glob && yarn build && npm pack && cpx *.tgz ../../pre-install",
"sync:core": "cd ../.. && yarn build:core && cd extensions && rm yarn.lock && cd inference-cohere-extension && yarn && yarn build:publish"
},
"exports": {
".": "./dist/index.js",
"./main": "./dist/module.js"
},
"devDependencies": {
"cpx": "^1.5.0",
"rimraf": "^3.0.2",
"webpack": "^5.88.2",
"webpack-cli": "^5.1.4",
"ts-loader": "^9.5.0"
},
"dependencies": {
"@janhq/core": "file:../../core",
"fetch-retry": "^5.0.6",
"ulidx": "^2.3.0"
},
"engines": {
"node": ">=18.0.0"
},
"files": [
"dist/*",
"package.json",
"README.md"
],
"bundleDependencies": [
"fetch-retry"
]
}

View File

@ -0,0 +1,26 @@
[
{
"sources": [
{
"url": "https://cohere.com"
}
],
"id": "command-r-plus",
"object": "model",
"name": "Command R+",
"version": "1.0",
"description": "Command R+ is an instruction-following conversational model that performs language tasks at a higher quality, more reliably, and with a longer context than previous models. It is best suited for complex RAG workflows and multi-step tool use.",
"format": "api",
"settings": {},
"parameters": {
"max_tokens": 128000,
"temperature": 0.7,
"stream": false
},
"metadata": {
"author": "Cohere",
"tags": ["General", "Big Context Length"]
},
"engine": "cohere"
}
]

View File

@ -0,0 +1,23 @@
[
{
"key": "chat-completions-endpoint",
"title": "Chat Completions Endpoint",
"description": "The endpoint to use for chat completions. See the [Cohere API documentation](https://docs.cohere.com/reference/chat) for more information.",
"controllerType": "input",
"controllerProps": {
"placeholder": "https://api.cohere.ai/v1/chat",
"value": "https://api.cohere.ai/v1/chat"
}
},
{
"key": "cohere-api-key",
"title": "API Key",
"description": "The Cohere API uses API keys for authentication. Visit your [API Keys](https://platform.openai.com/account/api-keys) page to retrieve the API key you'll use in your requests.",
"controllerType": "input",
"controllerProps": {
"placeholder": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"value": "",
"type": "password"
}
}
]

View File

@ -0,0 +1,110 @@
/**
* @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
* @version 1.0.0
* @module inference-openai-extension/src/index
*/
import { RemoteOAIEngine } from '@janhq/core'
import { PayloadType } from '@janhq/core'
import { ChatCompletionRole } from '@janhq/core'
declare const SETTINGS: Array<any>
declare const MODELS: Array<any>
enum Settings {
apiKey = 'cohere-api-key',
chatCompletionsEndPoint = 'chat-completions-endpoint',
}
enum RoleType {
user = 'USER',
chatbot = 'CHATBOT',
system = 'SYSTEM',
}
type CoherePayloadType = {
chat_history?: Array<{ role: RoleType; message: string }>
message?: string,
preamble?: string,
}
/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class JanInferenceCohereExtension extends RemoteOAIEngine {
inferenceUrl: string = ''
provider: string = 'cohere'
override async onLoad(): Promise<void> {
super.onLoad()
// Register Settings
this.registerSettings(SETTINGS)
this.registerModels(MODELS)
this.apiKey = await this.getSetting<string>(Settings.apiKey, '')
this.inferenceUrl = await this.getSetting<string>(
Settings.chatCompletionsEndPoint,
''
)
if (this.inferenceUrl.length === 0) {
SETTINGS.forEach((setting) => {
if (setting.key === Settings.chatCompletionsEndPoint) {
this.inferenceUrl = setting.controllerProps.value as string
}
})
}
}
onSettingUpdate<T>(key: string, value: T): void {
if (key === Settings.apiKey) {
this.apiKey = value as string
} else if (key === Settings.chatCompletionsEndPoint) {
if (typeof value !== 'string') return
if (value.trim().length === 0) {
SETTINGS.forEach((setting) => {
if (setting.key === Settings.chatCompletionsEndPoint) {
this.inferenceUrl = setting.controllerProps.value as string
}
})
} else {
this.inferenceUrl = value
}
}
}
transformPayload = (payload: PayloadType): CoherePayloadType => {
if (payload.messages.length === 0) {
return {}
}
const convertedData:CoherePayloadType = {
chat_history: [],
message: '',
}
payload.messages.forEach((item, index) => {
// Assign the message of the last item to the `message` property
if (index === payload.messages.length - 1) {
convertedData.message = item.content as string
return
}
if (item.role === ChatCompletionRole.User) {
convertedData.chat_history.push({ role: RoleType.user, message: item.content as string})
} else if (item.role === ChatCompletionRole.Assistant) {
convertedData.chat_history.push({
role: RoleType.chatbot,
message: item.content as string,
})
} else if (item.role === ChatCompletionRole.System) {
convertedData.preamble = item.content as string
}
})
return convertedData
}
transformResponse = (data: any) => data.text
}

View File

@ -0,0 +1,14 @@
{
"compilerOptions": {
"target": "es2016",
"module": "ES6",
"moduleResolution": "node",
"outDir": "./dist",
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"strict": false,
"skipLibCheck": true,
"rootDir": "./src"
},
"include": ["./src"]
}

View File

@ -0,0 +1,37 @@
const webpack = require('webpack')
const packageJson = require('./package.json')
const settingJson = require('./resources/settings.json')
const modelsJson = require('./resources/models.json')
module.exports = {
experiments: { outputModule: true },
entry: './src/index.ts', // Adjust the entry point to match your project's main file
mode: 'production',
module: {
rules: [
{
test: /\.tsx?$/,
use: 'ts-loader',
exclude: /node_modules/,
},
],
},
plugins: [
new webpack.DefinePlugin({
MODELS: JSON.stringify(modelsJson),
SETTINGS: JSON.stringify(settingJson),
ENGINE: JSON.stringify(packageJson.engine),
}),
],
output: {
filename: 'index.js', // Adjust the output file name as needed
library: { type: 'module' }, // Specify ESM output format
},
resolve: {
extensions: ['.ts', '.js'],
},
optimization: {
minimize: false,
},
// Add loaders and other configuration as needed for your project
}