diff --git a/plugins/inference-plugin/index.ts b/plugins/inference-plugin/index.ts
index ca2869650..287c52824 100644
--- a/plugins/inference-plugin/index.ts
+++ b/plugins/inference-plugin/index.ts
@@ -1,4 +1,5 @@
 import { EventName, InferenceService, NewMessageRequest, PluginService, core, events, store } from "@janhq/core";
+import { Observable } from "rxjs";
 
 const inferenceUrl = "http://localhost:3928/llama/chat_completion";
 
@@ -8,11 +9,57 @@ const stopModel = () => {
   core.invokePluginFunc(MODULE_PATH, "killSubprocess");
 };
 
-async function handleMessageRequest(data: NewMessageRequest) {
+function requestInference(recentMessages: any[]): Observable<string> {
+  return new Observable((subscriber) => {
+    fetch(inferenceUrl, {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+        Accept: "text/event-stream",
+        "Access-Control-Allow-Origin": "*",
+      },
+      body: JSON.stringify({
+        messages: recentMessages,
+        stream: true,
+        model: "gpt-3.5-turbo",
+        max_tokens: 500,
+      }),
+    })
+      .then(async (response) => {
+        const stream = response.body;
+        const decoder = new TextDecoder("utf-8");
+        const reader = stream?.getReader();
+        let content = "";
+
+        while (true && reader) {
+          const { done, value } = await reader.read();
+          if (done) {
+            console.log("SSE stream closed");
+            break;
+          }
+          const text = decoder.decode(value);
+          const lines = text.trim().split("\n");
+          for (const line of lines) {
+            if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
+              const data = JSON.parse(line.replace("data: ", ""));
+              content += data.choices[0]?.delta?.content ?? "";
+              if (content.startsWith("assistant: ")) {
+                content = content.replace("assistant: ", "");
+              }
+              subscriber.next(content);
+            }
+          }
+        }
+        subscriber.complete();
+      })
+      .catch(subscriber.error);
+  });
+}
+
+async function retrieveLastTenMessages(conversationId: string) {
   // TODO: Common collections should be able to access via core functions instead of store
-  const messageHistory =
-    (await store.findMany("messages", { conversationId: data.conversationId }, [{ createdAt: "asc" }])) ?? [];
-  const recentMessages = messageHistory
+  const messageHistory = (await store.findMany("messages", { conversationId }, [{ createdAt: "asc" }])) ?? [];
+  return messageHistory
     .filter((e) => e.message !== "" && (e.user === "user" || e.user === "assistant"))
     .slice(-10)
     .map((message) => {
@@ -21,7 +68,10 @@ async function handleMessageRequest(data: NewMessageRequest) {
         role: message.user === "user" ? "user" : "assistant",
       };
     });
+}
 
+async function handleMessageRequest(data: NewMessageRequest) {
+  const recentMessages = await retrieveLastTenMessages(data.conversationId);
   const message = {
     ...data,
     message: "",
@@ -35,49 +85,45 @@ async function handleMessageRequest(data: NewMessageRequest) {
   message._id = id;
   events.emit(EventName.OnNewMessageResponse, message);
 
-  const response = await fetch(inferenceUrl, {
-    method: "POST",
-    headers: {
-      "Content-Type": "application/json",
-      Accept: "text/event-stream",
-      "Access-Control-Allow-Origi": "*",
+  requestInference(recentMessages).subscribe({
+    next: (content) => {
+      message.message = content;
+      events.emit(EventName.OnMessageResponseUpdate, message);
+    },
+    complete: async () => {
+      message.message = message.message.trim();
+      // TODO: Common collections should be able to access via core functions instead of store
+      await store.updateOne("messages", message._id, message);
+    },
+    error: async (err) => {
+      message.message = message.message.trim() + "\n" + "Error occurred: " + err;
+      // TODO: Common collections should be able to access via core functions instead of store
+      await store.updateOne("messages", message._id, message);
     },
-    body: JSON.stringify({
-      messages: recentMessages,
-      stream: true,
-      model: "gpt-3.5-turbo",
-      max_tokens: 500,
-    }),
   });
-  const stream = response.body;
+}
-
-  const decoder = new TextDecoder("utf-8");
-  const reader = stream?.getReader();
-  let answer = "";
-
-  while (true && reader) {
-    const { done, value } = await reader.read();
-    if (done) {
-      console.log("SSE stream closed");
-      break;
-    }
-    const text = decoder.decode(value);
-    const lines = text.trim().split("\n");
-    for (const line of lines) {
-      if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
-        const data = JSON.parse(line.replace("data: ", ""));
-        answer += data.choices[0]?.delta?.content ?? "";
-        if (answer.startsWith("assistant: ")) {
-          answer = answer.replace("assistant: ", "");
-        }
-        message.message = answer;
-        events.emit(EventName.OnMessageResponseUpdate, message);
-      }
-    }
-  }
-  message.message = answer.trim();
-  // TODO: Common collections should be able to access via core functions instead of store
-  await store.updateOne("messages", message._id, message);
+
+async function inferenceRequest(data: NewMessageRequest): Promise<any> {
+  const message = {
+    ...data,
+    message: "",
+    user: "assistant",
+    createdAt: new Date().toISOString(),
+  };
+  return new Promise(async (resolve, reject) => {
+    const recentMessages = await retrieveLastTenMessages(data.conversationId);
+    requestInference([...recentMessages, { role: "user", content: data.message }]).subscribe({
+      next: (content) => {
+        message.message = content;
+      },
+      complete: async () => {
+        resolve(message);
+      },
+      error: async (err) => {
+        reject(err);
+      },
+    });
+  });
 }
 
 const registerListener = () => {
@@ -92,4 +138,5 @@ export function init({ register }) {
   register(PluginService.OnStart, PLUGIN_NAME, onStart);
   register(InferenceService.InitModel, initModel.name, initModel);
   register(InferenceService.StopModel, stopModel.name, stopModel);
+  register(InferenceService.InferenceRequest, inferenceRequest.name, inferenceRequest);
 }
diff --git a/plugins/inference-plugin/package.json b/plugins/inference-plugin/package.json
index bfc554ef7..9f547bf50 100644
--- a/plugins/inference-plugin/package.json
+++ b/plugins/inference-plugin/package.json
@@ -26,8 +26,9 @@
     "webpack-cli": "^5.1.4"
   },
   "dependencies": {
-    "@janhq/core": "^0.1.1",
+    "@janhq/core": "^0.1.2",
     "kill-port-process": "^3.2.0",
+    "rxjs": "^7.8.1",
     "tcp-port-used": "^1.0.2",
     "ts-loader": "^9.5.0"
   },
diff --git a/web/app/_hooks/useSendChatMessage.ts b/web/app/_hooks/useSendChatMessage.ts
index 88a900a0d..769b1158e 100644
--- a/web/app/_hooks/useSendChatMessage.ts
+++ b/web/app/_hooks/useSendChatMessage.ts
@@ -1,17 +1,44 @@
 import { currentPromptAtom } from "@/_helpers/JotaiWrapper";
 import { useAtom, useAtomValue, useSetAtom } from "jotai";
-import { DataService, EventName, events } from "@janhq/core";
+import { DataService, EventName, InferenceService, events, store } from "@janhq/core";
 import { RawMessage, toChatMessage } from "@/_models/ChatMessage";
 import { executeSerial } from "@/_services/pluginService";
 import { addNewMessageAtom } from "@/_helpers/atoms/ChatMessage.atom";
-import { currentConversationAtom } from "@/_helpers/atoms/Conversation.atom";
+import { currentConversationAtom, updateConversationAtom } from "@/_helpers/atoms/Conversation.atom";
 
 export default function useSendChatMessage() {
   const currentConvo = useAtomValue(currentConversationAtom);
   const addNewMessage = useSetAtom(addNewMessageAtom);
+  const updateConversation = useSetAtom(updateConversationAtom);
   const [currentPrompt, setCurrentPrompt] = useAtom(currentPromptAtom);
 
+  let timeout: any | undefined = undefined;
+
+  function updateConvSummary(newMessage: any) {
+    if (timeout) {
+      clearTimeout(timeout);
+    }
+    timeout = setTimeout(() => {
+      const conv = currentConvo;
+      if (!currentConvo?.summary || currentConvo.summary === "" || currentConvo.summary.startsWith("User request:")) {
+        // Request convo summary
+        setTimeout(async () => {
+          newMessage.message = "summary this conversation in 5 words";
+          const result = await executeSerial(InferenceService.InferenceRequest, newMessage);
+          if (result?.message && result.message.split(" ").length <= 7 && conv?._id) {
+            const updatedConv = {
+              ...conv,
+              summary: result.message,
+            };
+            updateConversation(updatedConv);
+            await executeSerial(DataService.UpdateConversation, updatedConv);
+          }
+        }, 1000);
+      }
+    }, 100);
+  }
+
   const sendChatMessage = async () => {
     setCurrentPrompt("");
     const prompt = currentPrompt.trim();
@@ -26,7 +53,19 @@
     const newChatMessage = toChatMessage(newMessage);
     addNewMessage(newChatMessage);
+
     events.emit(EventName.OnNewMessageRequest, newMessage);
+
+    if (!currentConvo?.summary && currentConvo) {
+      const updatedConv = {
+        ...currentConvo,
+        summary: `Prompt: ${prompt}`,
+      };
+      updateConversation(updatedConv);
+      await executeSerial(DataService.UpdateConversation, updatedConv);
+    }
+
+    updateConvSummary(newMessage);
   };
 
   return {
diff --git a/web/package.json b/web/package.json
index ec05df5d8..42126f34b 100644
--- a/web/package.json
+++ b/web/package.json
@@ -13,7 +13,7 @@
   "dependencies": {
     "@headlessui/react": "^1.7.15",
     "@heroicons/react": "^2.0.18",
-    "@janhq/core": "^0.1.1",
+    "@janhq/core": "^0.1.2",
     "@tailwindcss/typography": "^0.5.9",
     "@types/react": "18.2.15",
     "@types/react-dom": "18.2.7",