fix: retrieval stuck at generating response (#1988)

This commit is contained in:
Louis 2024-02-11 08:27:26 +07:00 committed by GitHub
parent d371120595
commit 0db1763c2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 82 additions and 51 deletions

View File

@ -25,25 +25,11 @@ import { migrateExtensions } from './utils/migration'
import { cleanUpAndQuit } from './utils/clean' import { cleanUpAndQuit } from './utils/clean'
import { setupExtensions } from './utils/extension' import { setupExtensions } from './utils/extension'
import { setupCore } from './utils/setup' import { setupCore } from './utils/setup'
import { setupReactDevTool } from './utils/dev'
app app
.whenReady() .whenReady()
.then(async () => { .then(setupReactDevTool)
if (!app.isPackaged) {
// Which means you're running from source code
const { default: installExtension, REACT_DEVELOPER_TOOLS } = await import(
'electron-devtools-installer'
) // Don't use import on top level, since the installer package is dev-only
try {
const name = installExtension(REACT_DEVELOPER_TOOLS)
console.log(`Added Extension: ${name}`)
} catch (err) {
console.log('An error occurred while installing devtools:')
console.error(err)
// Only log the error and don't throw it because it's not critical
}
}
})
.then(setupCore) .then(setupCore)
.then(createUserSpace) .then(createUserSpace)
.then(migrateExtensions) .then(migrateExtensions)

18
electron/utils/dev.ts Normal file
View File

@ -0,0 +1,18 @@
import { app } from 'electron'
/**
 * Installs the React Developer Tools extension when the app is running from
 * source code (i.e. `app.isPackaged` is false). Does nothing in packaged builds.
 *
 * Installing devtools is best-effort: any failure while loading the installer
 * package or installing the extension is logged and swallowed, so the returned
 * promise does not reject because of devtools and the rest of the startup
 * chain can continue.
 *
 * @returns A promise that resolves once the installation attempt has finished.
 */
export const setupReactDevTool = async (): Promise<void> => {
  // Packaged builds never install devtools
  if (app.isPackaged) return

  try {
    // Don't use import on top level, since the installer package is dev-only.
    // The import lives inside the try block so that a missing package is
    // treated like any other non-critical devtools failure.
    const { default: installExtension, REACT_DEVELOPER_TOOLS } = await import(
      'electron-devtools-installer'
    )
    // installExtension is asynchronous: await it so the resolved extension
    // name is logged (not a pending Promise) and so a rejected installation
    // is caught below instead of surfacing as an unhandled rejection.
    const name = await installExtension(REACT_DEVELOPER_TOOLS)
    console.log(`Added Extension: ${name}`)
  } catch (err) {
    console.log('An error occurred while installing devtools:')
    console.error(err)
    // Only log the error and don't throw it because it's not critical
  }
}

View File

@ -1,6 +1,6 @@
{ {
"name": "@janhq/assistant-extension", "name": "@janhq/assistant-extension",
"version": "1.0.0", "version": "1.0.1",
"description": "This extension enables assistants, including Jan, a default assistant that can call all downloaded models", "description": "This extension enables assistants, including Jan, a default assistant that can call all downloaded models",
"main": "dist/index.js", "main": "dist/index.js",
"node": "dist/node/index.js", "node": "dist/node/index.js",

View File

@ -14,6 +14,7 @@ import {
export default class JanAssistantExtension extends AssistantExtension { export default class JanAssistantExtension extends AssistantExtension {
private static readonly _homeDir = "file://assistants"; private static readonly _homeDir = "file://assistants";
private static readonly _threadDir = "file://threads";
controller = new AbortController(); controller = new AbortController();
isCancelled = false; isCancelled = false;
@ -64,6 +65,8 @@ export default class JanAssistantExtension extends AssistantExtension {
if ( if (
data.model?.engine !== InferenceEngine.tool_retrieval_enabled || data.model?.engine !== InferenceEngine.tool_retrieval_enabled ||
!data.messages || !data.messages ||
// TODO: Since the engine is defined, it's unsafe to assume that assistant tools are defined
// That could lead to an issue where the thread gets stuck at generating response
!data.thread?.assistants[0]?.tools !data.thread?.assistants[0]?.tools
) { ) {
return; return;
@ -71,11 +74,12 @@ export default class JanAssistantExtension extends AssistantExtension {
const latestMessage = data.messages[data.messages.length - 1]; const latestMessage = data.messages[data.messages.length - 1];
// Ingest the document if needed // 1. Ingest the document if needed
if ( if (
latestMessage && latestMessage &&
latestMessage.content && latestMessage.content &&
typeof latestMessage.content !== "string" typeof latestMessage.content !== "string" &&
latestMessage.content.length > 1
) { ) {
const docFile = latestMessage.content[1]?.doc_url?.url; const docFile = latestMessage.content[1]?.doc_url?.url;
if (docFile) { if (docFile) {
@ -86,9 +90,29 @@ export default class JanAssistantExtension extends AssistantExtension {
data.model?.proxyEngine data.model?.proxyEngine
); );
} }
} else if (
// Check whether we need to ingest document or not
// Otherwise wrong context will be sent
!(await fs.existsSync(
await joinPath([
JanAssistantExtension._threadDir,
data.threadId,
"memory",
])
))
) {
// No document ingested, reroute the result to inference engine
const output = {
...data,
model: {
...data.model,
engine: data.model.proxyEngine,
},
};
events.emit(MessageEvent.OnMessageSent, output);
return;
} }
// 2. Load agent on thread changed
// Load agent on thread changed
if (instance.retrievalThreadId !== data.threadId) { if (instance.retrievalThreadId !== data.threadId) {
await executeOnMain(NODE, "toolRetrievalLoadThreadMemory", data.threadId); await executeOnMain(NODE, "toolRetrievalLoadThreadMemory", data.threadId);
@ -103,22 +127,22 @@ export default class JanAssistantExtension extends AssistantExtension {
); );
} }
// 3. Using the retrieval template with the result and query
if (latestMessage.content) { if (latestMessage.content) {
const prompt = const prompt =
typeof latestMessage.content === "string" typeof latestMessage.content === "string"
? latestMessage.content ? latestMessage.content
: latestMessage.content[0].text; : latestMessage.content[0].text;
// Retrieve the result // Retrieve the result
console.debug("toolRetrievalQuery", latestMessage.content);
const retrievalResult = await executeOnMain( const retrievalResult = await executeOnMain(
NODE, NODE,
"toolRetrievalQueryResult", "toolRetrievalQueryResult",
prompt prompt
); );
console.debug("toolRetrievalQueryResult", retrievalResult);
// Update the message content // Update message content
// Using the retrieval template with the result and query if (data.thread?.assistants[0]?.tools && retrievalResult)
if (data.thread?.assistants[0].tools)
data.messages[data.messages.length - 1].content = data.messages[data.messages.length - 1].content =
data.thread.assistants[0].tools[0].settings?.retrieval_template data.thread.assistants[0].tools[0].settings?.retrieval_template
?.replace("{CONTEXT}", retrievalResult) ?.replace("{CONTEXT}", retrievalResult)
@ -140,7 +164,7 @@ export default class JanAssistantExtension extends AssistantExtension {
return message; return message;
}); });
// Reroute the result to inference engine // 4. Reroute the result to inference engine
const output = { const output = {
...data, ...data,
model: { model: {

View File

@ -1,39 +1,39 @@
import { getJanDataFolderPath, normalizeFilePath } from "@janhq/core/node"; import { getJanDataFolderPath, normalizeFilePath } from "@janhq/core/node";
import { Retrieval } from "./tools/retrieval"; import { retrieval } from "./tools/retrieval";
import path from "path"; import path from "path";
const retrieval = new Retrieval(); export function toolRetrievalUpdateTextSplitter(
export async function toolRetrievalUpdateTextSplitter(
chunkSize: number, chunkSize: number,
chunkOverlap: number, chunkOverlap: number
) { ) {
retrieval.updateTextSplitter(chunkSize, chunkOverlap); retrieval.updateTextSplitter(chunkSize, chunkOverlap);
return Promise.resolve();
} }
export async function toolRetrievalIngestNewDocument( export async function toolRetrievalIngestNewDocument(
file: string, file: string,
engine: string, engine: string
) { ) {
const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file)); const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file));
const threadPath = path.dirname(filePath.replace("files", "")); const threadPath = path.dirname(filePath.replace("files", ""));
retrieval.updateEmbeddingEngine(engine); retrieval.updateEmbeddingEngine(engine);
await retrieval.ingestAgentKnowledge(filePath, `${threadPath}/memory`); return retrieval
return Promise.resolve(); .ingestAgentKnowledge(filePath, `${threadPath}/memory`)
.catch((err) => {
console.error(err);
});
} }
export async function toolRetrievalLoadThreadMemory(threadId: string) { export async function toolRetrievalLoadThreadMemory(threadId: string) {
try { return retrieval
await retrieval.loadRetrievalAgent( .loadRetrievalAgent(
path.join(getJanDataFolderPath(), "threads", threadId, "memory"), path.join(getJanDataFolderPath(), "threads", threadId, "memory")
); )
return Promise.resolve(); .catch((err) => {
} catch (err) { console.error(err);
console.debug(err); });
}
} }
export async function toolRetrievalQueryResult(query: string) { export async function toolRetrievalQueryResult(query: string) {
const res = await retrieval.generateResult(query); return retrieval.generateResult(query).catch((err) => {
return Promise.resolve(res); console.error(err);
});
} }

View File

@ -35,6 +35,7 @@ export class Retrieval {
if (engine === "nitro") { if (engine === "nitro") {
this.embeddingModel = new OpenAIEmbeddings( this.embeddingModel = new OpenAIEmbeddings(
{ openAIApiKey: "nitro-embedding" }, { openAIApiKey: "nitro-embedding" },
// TODO: Raw settings
{ basePath: "http://127.0.0.1:3928/v1" }, { basePath: "http://127.0.0.1:3928/v1" },
); );
} else { } else {
@ -75,3 +76,5 @@ export class Retrieval {
return Promise.resolve(serializedDoc); return Promise.resolve(serializedDoc);
}; };
} }
export const retrieval = new Retrieval();