feat: update inference service to use nitro
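Replace the in-process node-llama-cpp session with a spawned Nitro subprocess: initModel now writes the model path into nitro's config/config.json and launches the bundled binary, the renderer streams completions from http://localhost:8080/llama/chat_completion via SSE, and a new updateMessage path (plugin function, SQL statement, and coreService enum entry) persists the streamed answer.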

Louis 2023-09-29 14:09:07 +07:00 committed by Louis
parent 9b14b743d1
commit 87678e799f
19 changed files with 2565 additions and 237 deletions

.gitignore vendored

@@ -17,4 +17,5 @@ electron/renderer
# Nitro
electron/core/plugins/nitro-plugin/nitro
electron/core/plugins/nitro-plugin/uploads
electron/core/plugins/nitro-plugin/*.log
*.log


@@ -92,7 +92,7 @@ const createConversation = (conversation: any) =>
resolve(res);
});
} else {
resolve("-");
resolve(undefined);
}
});
const createMessage = (message: any) =>
@@ -100,9 +100,24 @@ const createMessage = (message: any) =>
if (window && window.electronAPI) {
window.electronAPI
.invokePluginFunc(MODULE_PATH, "storeMessage", message)
.then((res: any) => resolve(res));
.then((res: any) => {
resolve(res);
});
} else {
resolve("-");
resolve(undefined);
}
});
const updateMessage = (message: any) =>
new Promise((resolve) => {
if (window && window.electronAPI) {
window.electronAPI
.invokePluginFunc(MODULE_PATH, "updateMessage", message)
.then((res: any) => {
resolve(res);
});
} else {
resolve(undefined);
}
});
@@ -128,6 +143,7 @@ export function init({ register }: { register: any }) {
setupDb();
register("getConversations", "getConv", getConversations, 1);
register("createConversation", "insertConv", createConversation);
register("updateMessage", "updateMessage", updateMessage);
register("deleteConversation", "deleteConv", deleteConversation);
register("createMessage", "insertMessage", createMessage);
register("getConversationMessages", "getMessages", getConversationMessages);


@@ -241,7 +241,7 @@ function getConversations() {
);
db.all(
"SELECT * FROM conversations ORDER BY created_at DESC",
"SELECT * FROM conversations ORDER BY updated_at DESC",
(err: any, row: any) => {
res(row);
}
@@ -249,7 +249,7 @@
db.close();
});
}
function storeConversation(conversation: any) {
function storeConversation(conversation: any): Promise<number | undefined> {
return new Promise((res) => {
const db = new sqlite3.Database(
path.join(app.getPath("userData"), "jan.db")
@@ -284,7 +284,7 @@
});
}
function storeMessage(message: any) {
function storeMessage(message: any): Promise<number | undefined> {
return new Promise((res) => {
const db = new sqlite3.Database(
path.join(app.getPath("userData"), "jan.db")
@@ -299,7 +299,7 @@
message.conversation_id,
message.user,
message.message,
(err: any) => {
function (err: any) {
if (err) {
// Handle the insertion error here
console.error(err.message);
@@ -318,6 +318,24 @@ function storeMessage(message: any) {
db.close();
});
}
function updateMessage(message: any): Promise<number | undefined> {
return new Promise((res) => {
const db = new sqlite3.Database(
path.join(app.getPath("userData"), "jan.db")
);
db.serialize(() => {
const stmt = db.prepare(
"UPDATE messages SET message = ?, updated_at = ? WHERE id = ?"
);
stmt.run(message.message, message.updated_at, message.id);
stmt.finalize();
res(message.id);
});
db.close();
});
}
function deleteConversation(id: any) {
return new Promise((res) => {
@@ -347,7 +365,7 @@ function getConversationMessages(conversation_id: any) {
path.join(app.getPath("userData"), "jan.db")
);
const query = `SELECT * FROM messages WHERE conversation_id = ${conversation_id} ORDER BY created_at DESC`;
const query = `SELECT * FROM messages WHERE conversation_id = ${conversation_id} ORDER BY id DESC`;
db.all(query, (err: Error, row: any) => {
res(row);
});
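Note that both the old and new forms of this query interpolate conversation_id straight into the SQL string. A parameterized variant (an editorial sketch, not part of this commit) would close off injection:

db.all(
  "SELECT * FROM messages WHERE conversation_id = ? ORDER BY id DESC",
  [conversation_id],
  (err: Error, rows: any) => {
    res(rows);
  }
);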
@@ -361,6 +379,7 @@ module.exports = {
deleteConversation,
storeConversation,
storeMessage,
updateMessage,
getConversationMessages,
storeModel,
updateFinishedDownloadAt,


@@ -1,14 +1,5 @@
const MODULE_PATH = "inference-plugin/dist/module.js";
const prompt = async (prompt) =>
new Promise(async (resolve) => {
if (window.electronAPI) {
window.electronAPI
.invokePluginFunc(MODULE_PATH, "prompt", prompt)
.then((res) => resolve(res));
}
});
const initModel = async (product) =>
new Promise(async (resolve) => {
if (window.electronAPI) {
@@ -18,8 +9,17 @@ const initModel = async (product) =>
}
});
const dispose = async () =>
new Promise(async (resolve) => {
if (window.electronAPI) {
window.electronAPI
.invokePluginFunc(MODULE_PATH, "killSubprocess")
.then((res) => resolve(res));
}
});
// Register all the above functions and objects with the relevant extension points
export function init({ register }) {
register("initModel", "initModel", initModel);
register("prompt", "prompt", prompt);
register("dispose", "dispose", dispose);
}


@@ -1,8 +1,9 @@
const path = require("path");
const { app, dialog } = require("electron");
const _importDynamic = new Function("modulePath", "return import(modulePath)");
const { spawn } = require("child_process");
const fs = require("fs");
let llamaSession = null;
let subprocess = null;
async function initModel(product) {
// fileName fallback
@@ -18,38 +19,64 @@ async function initModel(product) {
return;
}
console.info(`Initializing model: ${product.name}..`);
_importDynamic("../node_modules/node-llama-cpp/dist/index.js")
.then(({ LlamaContext, LlamaChatSession, LlamaModel }) => {
const modelPath = path.join(app.getPath("userData"), product.fileName);
const model = new LlamaModel({ modelPath });
const context = new LlamaContext({ model });
llamaSession = new LlamaChatSession({ context });
console.info(`Init model ${product.name} successfully!`);
})
.catch(async (e) => {
console.error(e);
await dialog.showMessageBox({
message: "Failed to import LLM module",
});
});
if (subprocess) {
console.error(
"A subprocess is already running. Attempt to kill then reinit."
);
killSubprocess();
}
let binaryFolder = `${__dirname}/nitro`; // Current directory by default
// Read the existing config
const configFilePath = `${binaryFolder}/config/config.json`;
let config = {};
if (fs.existsSync(configFilePath)) {
const rawData = fs.readFileSync(configFilePath, "utf-8");
config = JSON.parse(rawData);
}
// Update the llama_model_path
if (!config.custom_config) {
config.custom_config = {};
}
const modelPath = path.join(app.getPath("userData"), product.fileName);
config.custom_config.llama_model_path = modelPath;
// Write the updated config back to the file
fs.writeFileSync(configFilePath, JSON.stringify(config, null, 4));
// Execute the binary
subprocess = spawn(`${binaryFolder}/nitro`, [configFilePath]);
// Handle subprocess output
subprocess.stdout.on("data", (data) => {
console.log(`stdout: ${data}`);
});
subprocess.stderr.on("data", (data) => {
console.error(`stderr: ${data}`);
});
subprocess.on("close", (code) => {
console.log(`child process exited with code ${code}`);
subprocess = null;
});
}
async function prompt(prompt) {
if (!llamaSession) {
await dialog.showMessageBox({
message: "Model not initialized",
});
return;
}
console.log("prompt: ", prompt);
const response = await llamaSession.prompt(prompt);
console.log("response: ", response);
return response;
}
function killSubprocess() {
if (subprocess) {
subprocess.kill();
subprocess = null;
console.log("Subprocess terminated.");
} else {
console.error("No subprocess is currently running.");
}
module.exports = {
initModel,
prompt,
killSubprocess,
};
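End to end, the renderer now initializes the model through the plugin bridge and then talks HTTP to the local Nitro server. A rough sketch (port and endpoint taken from useSendChatMessage later in this diff; product stands in for a model descriptor with a fileName):

import { executeSerial } from "@/_services/pluginService";
import { InfereceService } from "../../shared/coreService";

// Spawn the nitro binary with the chosen model, then request a streamed completion.
await executeSerial(InfereceService.INIT_MODEL, product);
const response = await fetch("http://localhost:8080/llama/chat_completion", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    messages: [{ role: "user", content: "Hello" }],
    stream: true,
  }),
});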


@@ -0,0 +1 @@
{"custom_config": {"llama_model_path":"","ctx_len":2048,"ngl":100}}

File diff suppressed because it is too large.

Binary file not shown.


@@ -10,7 +10,7 @@
],
"scripts": {
"build": "webpack --config webpack.config.js",
"build:package": "rimraf ./*.tgz && npm run build && cpx \"module.js\" \"dist\" && npm pack",
"build:package": "rimraf ./*.tgz && npm run build && cpx \"module.js\" \"dist\" && rm -rf dist/nitro && cp -r nitro dist/nitro && npm pack",
"build:publish": "yarn build:package && cpx *.tgz ../../pre-install"
},
"devDependencies": {


@@ -1,25 +0,0 @@
const MODULE_PATH = "nitro-plugin/dist/module.js";
const installModel = async (product) =>
new Promise(async (resolve) => {
if (window.electronAPI) {
window.electronAPI
.invokePluginFunc(MODULE_PATH, "installModel", product)
.then((res) => resolve(res));
}
});
const uninstallModel = async (product) =>
new Promise(async (resolve) => {
if (window.electronAPI) {
window.electronAPI
.invokePluginFunc(MODULE_PATH, "uninstallModel", product)
.then((res) => resolve(res));
}
});
// Register all the above functions and objects with the relevant extension points
export function init({ register }) {
register("installModel", "installModel", installModel);
register("uninstallModel", "uninstallModel", uninstallModel);
}


@@ -1,87 +0,0 @@
const { spawn } = require('child_process');
const fs = require('fs');
class NitroPlugin {
constructor() {
this.subprocess = null;
this.binaryFolder = `${__dirname}/nitro`; // Current directory by default
this.config = {};
}
/**
* Install a model by writing a JSON file and executing a binary.
* @param {string} modelPath - Path to the model.
*/
installModel(modelPath) {
// Check if there's an existing subprocess
if (this.subprocess) {
console.error('A subprocess is already running. Please uninstall the current model first.');
return;
}
// Read the existing config
const configFilePath = `${this.binaryFolder}/config/config.json`;
let config = {};
if (fs.existsSync(configFilePath)) {
const rawData = fs.readFileSync(configFilePath, 'utf-8');
config = JSON.parse(rawData);
}
// Update the llama_model_path
if (!config.custom_config) {
config.custom_config = {};
}
config.custom_config.llama_model_path = modelPath;
// Write the updated config back to the file
fs.writeFileSync(configFilePath, JSON.stringify(config, null, 4));
// Execute the binary
this.subprocess = spawn(`${this.binaryFolder}/nitro`, [configFilePath]);
// Handle subprocess output
this.subprocess.stdout.on('data', (data) => {
console.log(`stdout: ${data}`);
});
this.subprocess.stderr.on('data', (data) => {
console.error(`stderr: ${data}`);
});
this.subprocess.on('close', (code) => {
console.log(`child process exited with code ${code}`);
this.subprocess = null;
});
}
/**
* Uninstall the model by killing the subprocess.
*/
uninstallModel() {
if (this.subprocess) {
this.subprocess.kill();
this.subprocess = null;
console.log('Subprocess terminated.');
} else {
console.error('No subprocess is currently running.');
}
}
}
const test = async () => {
const nitro = new NitroPlugin();
nitro.installModel('/Users/nam/Documents/janai/code/jan/models/llama-2-7b.Q4_K_S.gguf');
// nitro.uninstallModel();
}
test()
// Export the functions
// module.exports = {
// NitroPlugin,
// installModel: (modelPath) => {
// nitro.installModel(modelPath);
// },
// uninstallModel: () => {
// nitro.uninstallModel();
// }
// };


@@ -1,36 +0,0 @@
{
"name": "nitro-plugin",
"version": "0.0.1",
"description": "",
"main": "dist/index.js",
"author": "Hiro",
"license": "MIT",
"activationPoints": [
"init"
],
"scripts": {
"build": "webpack --config webpack.config.js",
"build:package": "rimraf ./*.tgz && npm run build && cpx \"module.js\" \"dist\" && npm pack",
"build:publish": "yarn build:package && cpx *.tgz ../../pre-install"
},
"devDependencies": {
"cpx": "^1.5.0",
"rimraf": "^3.0.2",
"webpack": "^5.88.2",
"webpack-cli": "^5.1.4"
},
"bundledDependencies": [
"electron-is-dev"
],
"dependencies": {
"electron-is-dev": "^2.0.0"
},
"engines": {
"node": ">=18.0.0"
},
"files": [
"dist/*",
"package.json",
"README.md"
]
}


@@ -1,25 +0,0 @@
const path = require("path");
module.exports = {
experiments: { outputModule: true },
entry: "./index.js", // Adjust the entry point to match your project's main file
mode: "production",
module: {
rules: [
{
test: /\.tsx?$/,
use: "ts-loader",
exclude: /node_modules/,
},
],
},
output: {
filename: "index.js", // Adjust the output file name as needed
path: path.resolve(__dirname, "dist"),
library: { type: "module" }, // Specify ESM output format
},
resolve: {
extensions: [".js"],
},
// Add loaders and other configuration as needed for your project
};


@@ -24,13 +24,14 @@ autoUpdater.autoInstallOnAppQuit = true;
const createMainWindow = () => {
mainWindow = new BrowserWindow({
width: electronScreen.getPrimaryDisplay().workArea.width,
height: electronScreen.getPrimaryDisplay().workArea.height,
width: 1200,
height: 800,
show: false,
backgroundColor: "white",
webPreferences: {
nodeIntegration: true,
preload: join(__dirname, "preload.js"),
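// Editorial note: presumably disabled so the renderer can call the local Nitro
// server cross-origin; be aware this turns off Electron's same-origin checks.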
webSecurity: false,
},
});


@@ -3,7 +3,7 @@ import SimpleImageMessage from "../SimpleImageMessage";
import SimpleTextMessage from "../SimpleTextMessage";
import { ChatMessage, MessageType } from "@/_models/ChatMessage";
import StreamTextMessage from "../StreamTextMessage";
import { useAtom } from "jotai";
import { useAtomValue } from "jotai";
import { currentStreamingMessageAtom } from "@/_helpers/JotaiWrapper";
export default function renderChatMessage({
@@ -17,7 +17,7 @@
status,
}: ChatMessage): React.ReactNode {
// eslint-disable-next-line react-hooks/rules-of-hooks
const [message, _] = useAtom(currentStreamingMessageAtom);
const message = useAtomValue(currentStreamingMessageAtom);
switch (messageType) {
case MessageType.ImageWithText:
return (


@@ -3,7 +3,7 @@ import { displayDate } from "@/_utils/datetime";
import { TextCode } from "../TextCode";
import { getMessageCode } from "@/_utils/message";
import Image from "next/image";
import { useAtom } from "jotai";
import { useAtomValue } from "jotai";
import { currentStreamingMessageAtom } from "@/_helpers/JotaiWrapper";
type Props = {
@@ -21,7 +21,7 @@ const StreamTextMessage: React.FC<Props> = ({
avatarUrl = "",
text = "",
}) => {
const [message, _] = useAtom(currentStreamingMessageAtom);
const message = useAtomValue(currentStreamingMessageAtom);
return message?.text && message?.text?.length > 0 ? (
<div className="flex items-start gap-2 ml-3">


@@ -1,19 +1,28 @@
import {
addNewMessageAtom,
currentChatMessagesAtom,
currentConversationAtom,
currentPromptAtom,
currentStreamingMessageAtom,
showingTyping,
updateMessageAtom,
} from "@/_helpers/JotaiWrapper";
import { useAtom, useAtomValue, useSetAtom } from "jotai";
import { DataService, InfereceService } from "../../shared/coreService";
import { RawMessage, toChatMessage } from "@/_models/ChatMessage";
import { DataService } from "../../shared/coreService";
import {
MessageSenderType,
RawMessage,
toChatMessage,
} from "@/_models/ChatMessage";
import { executeSerial } from "@/_services/pluginService";
export default function useSendChatMessage() {
const currentConvo = useAtomValue(currentConversationAtom);
const updateStreamMessage = useSetAtom(currentStreamingMessageAtom);
const addNewMessage = useSetAtom(addNewMessageAtom);
const updateMessage = useSetAtom(updateMessageAtom);
const chatMessagesHistory = useAtomValue(currentChatMessagesAtom);
const [currentPrompt, setCurrentPrompt] = useAtom(currentPromptAtom);
const [, setIsTyping] = useAtom(showingTyping);
const sendChatMessage = async () => {
@@ -29,18 +38,91 @@ export default function useSendChatMessage() {
const id = await executeSerial(DataService.CREATE_MESSAGE, newMessage);
newMessage.id = id;
addNewMessage(await toChatMessage(newMessage));
const resp = await executeSerial(InfereceService.PROMPT, prompt);
const newChatMessage = await toChatMessage(newMessage);
addNewMessage(newChatMessage);
const recentMessages = [...chatMessagesHistory, newChatMessage]
.slice(-10)
.map((message) => {
return {
content: message.text,
role:
message.messageSenderType === MessageSenderType.User
? "user"
: "assistant",
};
});
const response = await fetch(
"http://localhost:8080/llama/chat_completion",
{
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "text/event-stream",
"Access-Control-Allow-Origi": "*",
},
body: JSON.stringify({
messages: recentMessages,
stream: true,
model: "gpt-3.5-turbo",
max_tokens: 500,
}),
}
);
const stream = response.body;
const decoder = new TextDecoder("utf-8");
const reader = stream?.getReader();
let answer = "";
// Cache received response
const newResponse: RawMessage = {
conversation_id: parseInt(currentConvo?.id ?? "0") ?? 0,
message: resp,
message: answer,
user: "assistant",
created_at: new Date().toISOString(),
};
const respId = await executeSerial(DataService.CREATE_MESSAGE, newResponse);
newResponse.id = respId;
addNewMessage(await toChatMessage(newResponse));
const responseChatMessage = await toChatMessage(newResponse);
addNewMessage(responseChatMessage);
while (true && reader) {
const { done, value } = await reader.read();
if (done) {
console.log("SSE stream closed");
break;
}
const text = decoder.decode(value);
const lines = text.trim().split("\n");
for (const line of lines) {
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
setIsTyping(false);
const data = JSON.parse(line.replace("data: ", ""));
answer += data.choices[0]?.delta?.content ?? "";
if (answer.startsWith("assistant: ")) {
answer = answer.replace("assistant: ", "").trim();
}
updateStreamMessage({
...responseChatMessage,
text: answer,
});
updateMessage(
responseChatMessage.id,
responseChatMessage.conversationId,
answer
);
}
}
}
await executeSerial(DataService.UPDATE_MESSAGE, {
...newResponse,
message: answer,
updated_at: new Date()
.toISOString()
.replace("T", " ")
.replace(/\.\d+Z$/, ""),
});
setIsTyping(false);
};
return {

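For reference, each line consumed by the streaming loop above is an OpenAI-style chat-completion chunk; the code reads only choices[0].delta.content. Illustrative payloads (content values hypothetical):

data: {"choices":[{"delta":{"content":"Hel"}}]}
data: {"choices":[{"delta":{"content":"lo"}}]}
data: [DONE]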

@@ -14,7 +14,7 @@ export const isCorePluginInstalled = () => {
if (!extensionPoints.get(DataService.GET_CONVERSATIONS)) {
return false;
}
if (!extensionPoints.get(InfereceService.PROMPT)) {
if (!extensionPoints.get(InfereceService.INIT_MODEL)) {
return false;
}
if (!extensionPoints.get(ModelManagementService.GET_DOWNLOADED_MODELS)) {
@@ -33,7 +33,7 @@ export const setupBasePlugins = async () => {
if (
!extensionPoints.get(DataService.GET_CONVERSATIONS) ||
!extensionPoints.get(InfereceService.PROMPT) ||
!extensionPoints.get(InfereceService.INIT_MODEL) ||
!extensionPoints.get(ModelManagementService.GET_DOWNLOADED_MODELS)
) {
const installed = await plugins.install(basePlugins);


@@ -11,6 +11,7 @@ export enum DataService {
CREATE_CONVERSATION = "createConversation",
DELETE_CONVERSATION = "deleteConversation",
CREATE_MESSAGE = "createMessage",
UPDATE_MESSAGE = "updateMessage",
GET_CONVERSATION_MESSAGES = "getConversationMessages",
STORE_MODEL = "storeModel",