chore: update convo summary (#378)

* chore: update convo summary

* chore: conversation summary debouncing & retry attempt

* chore: convention
This commit is contained in:
Louis 2023-10-18 13:38:56 +07:00 committed by GitHub
parent ed4ea2e8a6
commit c623a9bb33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 135 additions and 48 deletions

View File

@ -1,4 +1,5 @@
import { EventName, InferenceService, NewMessageRequest, PluginService, core, events, store } from "@janhq/core"; import { EventName, InferenceService, NewMessageRequest, PluginService, core, events, store } from "@janhq/core";
import { Observable } from "rxjs";
const inferenceUrl = "http://localhost:3928/llama/chat_completion"; const inferenceUrl = "http://localhost:3928/llama/chat_completion";
@ -8,11 +9,57 @@ const stopModel = () => {
core.invokePluginFunc(MODULE_PATH, "killSubprocess"); core.invokePluginFunc(MODULE_PATH, "killSubprocess");
}; };
async function handleMessageRequest(data: NewMessageRequest) { function requestInference(recentMessages: any[]): Observable<string> {
// Streams a chat completion from the local inference server (`inferenceUrl`).
// Each emission is the full reply accumulated so far (not just the newest delta);
// the observable completes when the response stream ends and errors if the request
// or the parsing of a streamed line fails.
return new Observable<string>((subscriber) => {
  // Lets the teardown returned below cancel the HTTP request on early unsubscribe.
  const controller = new AbortController();
  const dataPrefix = "data: ";
  let content = "";

  // Parses a single SSE line and, when it carries a payload, emits the reply so far.
  const handleLine = (rawLine: string) => {
    const line = rawLine.trim();
    if (!line.startsWith(dataPrefix) || line.includes("data: [DONE]")) {
      return;
    }
    const data = JSON.parse(line.slice(dataPrefix.length));
    content += data.choices?.[0]?.delta?.content ?? "";
    // Strip the role prefix when it appears at the start of the reply.
    if (content.startsWith("assistant: ")) {
      content = content.replace("assistant: ", "");
    }
    subscriber.next(content);
  };

  fetch(inferenceUrl, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Accept: "text/event-stream",
      "Access-Control-Allow-Origin": "*",
    },
    body: JSON.stringify({
      messages: recentMessages,
      stream: true,
      model: "gpt-3.5-turbo",
      max_tokens: 500,
    }),
    signal: controller.signal,
  })
    .then(async (response) => {
      const reader = response.body?.getReader();
      const decoder = new TextDecoder("utf-8");
      // Text received but not yet terminated by a newline. A network chunk can end in
      // the middle of an SSE line, so only complete lines are handed to the parser.
      let pending = "";
      while (reader) {
        const { done, value } = await reader.read();
        if (done) {
          console.log("SSE stream closed");
          break;
        }
        // `stream: true` keeps a multi-byte character that is split across two chunks
        // intact instead of decoding each half to a replacement character.
        pending += decoder.decode(value, { stream: true });
        const lines = pending.split("\n");
        pending = lines.pop() ?? "";
        for (const line of lines) {
          handleLine(line);
        }
      }
      // Flush the decoder and process a final line that had no trailing newline.
      pending += decoder.decode();
      if (pending) {
        handleLine(pending);
      }
      subscriber.complete();
    })
    // Call through the subscriber: passing `subscriber.error` by reference detaches
    // the method from its receiver.
    .catch((err) => subscriber.error(err));

  return () => controller.abort();
});
}
async function retrieveLastTenMessages(conversationId: string) {
// TODO: Common collections should be able to access via core functions instead of store // TODO: Common collections should be able to access via core functions instead of store
const messageHistory = const messageHistory = (await store.findMany("messages", { conversationId }, [{ createdAt: "asc" }])) ?? [];
(await store.findMany("messages", { conversationId: data.conversationId }, [{ createdAt: "asc" }])) ?? []; return messageHistory
const recentMessages = messageHistory
.filter((e) => e.message !== "" && (e.user === "user" || e.user === "assistant")) .filter((e) => e.message !== "" && (e.user === "user" || e.user === "assistant"))
.slice(-10) .slice(-10)
.map((message) => { .map((message) => {
@ -21,7 +68,10 @@ async function handleMessageRequest(data: NewMessageRequest) {
role: message.user === "user" ? "user" : "assistant", role: message.user === "user" ? "user" : "assistant",
}; };
}); });
}
async function handleMessageRequest(data: NewMessageRequest) {
const recentMessages = await retrieveLastTenMessages(data.conversationId);
const message = { const message = {
...data, ...data,
message: "", message: "",
@ -35,49 +85,45 @@ async function handleMessageRequest(data: NewMessageRequest) {
message._id = id; message._id = id;
events.emit(EventName.OnNewMessageResponse, message); events.emit(EventName.OnNewMessageResponse, message);
const response = await fetch(inferenceUrl, { requestInference(recentMessages).subscribe({
method: "POST", next: (content) => {
headers: { message.message = content;
"Content-Type": "application/json", events.emit(EventName.OnMessageResponseUpdate, message);
Accept: "text/event-stream", },
"Access-Control-Allow-Origi": "*", complete: async () => {
message.message = message.message.trim();
// TODO: Common collections should be able to access via core functions instead of store
await store.updateOne("messages", message._id, message);
},
error: async (err) => {
message.message = message.message.trim() + "\n" + "Error occurred: " + err;
// TODO: Common collections should be able to access via core functions instead of store
await store.updateOne("messages", message._id, message);
}, },
body: JSON.stringify({
messages: recentMessages,
stream: true,
model: "gpt-3.5-turbo",
max_tokens: 500,
}),
}); });
const stream = response.body; }
const decoder = new TextDecoder("utf-8"); async function inferenceRequest(data: NewMessageRequest): Promise<any> {
const reader = stream?.getReader(); const message = {
let answer = ""; ...data,
message: "",
while (true && reader) { user: "assistant",
const { done, value } = await reader.read(); createdAt: new Date().toISOString(),
if (done) { };
console.log("SSE stream closed"); return new Promise(async (resolve, reject) => {
break; const recentMessages = await retrieveLastTenMessages(data.conversationId);
} requestInference([...recentMessages, { role: "user", content: data.message }]).subscribe({
const text = decoder.decode(value); next: (content) => {
const lines = text.trim().split("\n"); message.message = content;
for (const line of lines) { },
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) { complete: async () => {
const data = JSON.parse(line.replace("data: ", "")); resolve(message);
answer += data.choices[0]?.delta?.content ?? ""; },
if (answer.startsWith("assistant: ")) { error: async (err) => {
answer = answer.replace("assistant: ", ""); reject(err);
} },
message.message = answer; });
events.emit(EventName.OnMessageResponseUpdate, message); });
}
}
}
message.message = answer.trim();
// TODO: Common collections should be able to access via core functions instead of store
await store.updateOne("messages", message._id, message);
} }
const registerListener = () => { const registerListener = () => {
@ -92,4 +138,5 @@ export function init({ register }) {
register(PluginService.OnStart, PLUGIN_NAME, onStart); register(PluginService.OnStart, PLUGIN_NAME, onStart);
register(InferenceService.InitModel, initModel.name, initModel); register(InferenceService.InitModel, initModel.name, initModel);
register(InferenceService.StopModel, stopModel.name, stopModel); register(InferenceService.StopModel, stopModel.name, stopModel);
register(InferenceService.InferenceRequest, inferenceRequest.name, inferenceRequest);
} }

View File

@ -26,8 +26,9 @@
"webpack-cli": "^5.1.4" "webpack-cli": "^5.1.4"
}, },
"dependencies": { "dependencies": {
"@janhq/core": "^0.1.1", "@janhq/core": "^0.1.2",
"kill-port-process": "^3.2.0", "kill-port-process": "^3.2.0",
"rxjs": "^7.8.1",
"tcp-port-used": "^1.0.2", "tcp-port-used": "^1.0.2",
"ts-loader": "^9.5.0" "ts-loader": "^9.5.0"
}, },

View File

@ -1,17 +1,44 @@
import { currentPromptAtom } from "@/_helpers/JotaiWrapper"; import { currentPromptAtom } from "@/_helpers/JotaiWrapper";
import { useAtom, useAtomValue, useSetAtom } from "jotai"; import { useAtom, useAtomValue, useSetAtom } from "jotai";
import { DataService, EventName, events } from "@janhq/core"; import { DataService, EventName, InferenceService, events, store } from "@janhq/core";
import { RawMessage, toChatMessage } from "@/_models/ChatMessage"; import { RawMessage, toChatMessage } from "@/_models/ChatMessage";
import { executeSerial } from "@/_services/pluginService"; import { executeSerial } from "@/_services/pluginService";
import { addNewMessageAtom } from "@/_helpers/atoms/ChatMessage.atom"; import { addNewMessageAtom } from "@/_helpers/atoms/ChatMessage.atom";
import { currentConversationAtom } from "@/_helpers/atoms/Conversation.atom"; import { currentConversationAtom, updateConversationAtom } from "@/_helpers/atoms/Conversation.atom";
export default function useSendChatMessage() { export default function useSendChatMessage() {
const currentConvo = useAtomValue(currentConversationAtom); const currentConvo = useAtomValue(currentConversationAtom);
const addNewMessage = useSetAtom(addNewMessageAtom); const addNewMessage = useSetAtom(addNewMessageAtom);
const updateConversation = useSetAtom(updateConversationAtom);
const [currentPrompt, setCurrentPrompt] = useAtom(currentPromptAtom); const [currentPrompt, setCurrentPrompt] = useAtom(currentPromptAtom);
let timeout: any | undefined = undefined;
/**
 * Schedules a request that asks the inference plugin for a short summary of the
 * current conversation and, when a usable one comes back, stores it on the conversation.
 *
 * Nothing is requested when the conversation already has a real (non-placeholder) summary.
 *
 * NOTE(review): `timeout` is a plain `let` declared in the hook body, so it is presumably
 * re-initialised on every render and only debounces calls made within the same render —
 * confirm, and consider keeping the handle in a ref.
 *
 * @param newMessage The message that was just sent. It is used as the template for the
 *   summary request and is not modified.
 */
function updateConvSummary(newMessage: any) {
  if (timeout) {
    clearTimeout(timeout);
  }
  timeout = setTimeout(() => {
    const conv = currentConvo;
    const summary = conv?.summary;
    // A summary is still needed when it is missing/empty or is only a placeholder.
    // "Prompt:" is the placeholder written by sendChatMessage; "User request:" is the
    // prefix this check originally looked for and is kept in case stored conversations
    // still carry it (TODO confirm).
    const needsSummary = !summary || summary.startsWith("Prompt:") || summary.startsWith("User request:");
    if (!needsSummary) {
      return;
    }
    // Request convo summary
    setTimeout(async () => {
      try {
        // Send a copy so the caller's message object is not overwritten with the prompt.
        const summaryRequest = { ...newMessage, message: "summary this conversation in 5 words" };
        const result = await executeSerial(InferenceService.InferenceRequest, summaryRequest);
        // Only accept short answers (at most 7 space-separated words) as a summary.
        if (result?.message && result.message.split(" ").length <= 7 && conv?._id) {
          const updatedConv = {
            ...conv,
            summary: result.message,
          };
          updateConversation(updatedConv);
          await executeSerial(DataService.UpdateConversation, updatedConv);
        }
      } catch (err) {
        // This runs inside a timer callback, so a rejection would otherwise go unhandled.
        console.error("Failed to update conversation summary:", err);
      }
    }, 1000);
  }, 100);
}
const sendChatMessage = async () => { const sendChatMessage = async () => {
setCurrentPrompt(""); setCurrentPrompt("");
const prompt = currentPrompt.trim(); const prompt = currentPrompt.trim();
@ -26,7 +53,19 @@ export default function useSendChatMessage() {
const newChatMessage = toChatMessage(newMessage); const newChatMessage = toChatMessage(newMessage);
addNewMessage(newChatMessage); addNewMessage(newChatMessage);
events.emit(EventName.OnNewMessageRequest, newMessage); events.emit(EventName.OnNewMessageRequest, newMessage);
if (!currentConvo?.summary && currentConvo) {
const updatedConv = {
...currentConvo,
summary: `Prompt: ${prompt}`,
};
updateConversation(updatedConv);
await executeSerial(DataService.UpdateConversation, updatedConv);
}
updateConvSummary(newMessage);
}; };
return { return {

View File

@ -13,7 +13,7 @@
"dependencies": { "dependencies": {
"@headlessui/react": "^1.7.15", "@headlessui/react": "^1.7.15",
"@heroicons/react": "^2.0.18", "@heroicons/react": "^2.0.18",
"@janhq/core": "^0.1.1", "@janhq/core": "^0.1.2",
"@tailwindcss/typography": "^0.5.9", "@tailwindcss/typography": "^0.5.9",
"@types/react": "18.2.15", "@types/react": "18.2.15",
"@types/react-dom": "18.2.7", "@types/react-dom": "18.2.7",