fix: retrieval always ask for api key

This commit is contained in:
Louis 2024-01-29 22:44:13 +07:00
parent 00a109d46b
commit 12ebf272d6
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
5 changed files with 19 additions and 19 deletions

View File

@ -12,12 +12,11 @@ export class Retrieval {
public chunkOverlap?: number = 0; public chunkOverlap?: number = 0;
private retriever: any; private retriever: any;
private embeddingModel: any = undefined; private embeddingModel?: OpenAIEmbeddings = undefined;
private textSplitter?: RecursiveCharacterTextSplitter; private textSplitter?: RecursiveCharacterTextSplitter;
constructor(chunkSize: number = 4000, chunkOverlap: number = 200) { constructor(chunkSize: number = 4000, chunkOverlap: number = 200) {
this.updateTextSplitter(chunkSize, chunkOverlap); this.updateTextSplitter(chunkSize, chunkOverlap);
this.embeddingModel = new OpenAIEmbeddings({});
} }
public updateTextSplitter(chunkSize: number, chunkOverlap: number): void { public updateTextSplitter(chunkSize: number, chunkOverlap: number): void {
@ -36,7 +35,7 @@ export class Retrieval {
if (engine === "nitro") { if (engine === "nitro") {
this.embeddingModel = new OpenAIEmbeddings( this.embeddingModel = new OpenAIEmbeddings(
{ openAIApiKey: "nitro-embedding" }, { openAIApiKey: "nitro-embedding" },
{ basePath: "http://127.0.0.1:3928/v1" }, { basePath: "http://127.0.0.1:3928/v1" }
); );
} else { } else {
// Fallback to OpenAI Settings // Fallback to OpenAI Settings
@ -50,11 +49,12 @@ export class Retrieval {
public ingestAgentKnowledge = async ( public ingestAgentKnowledge = async (
filePath: string, filePath: string,
memoryPath: string, memoryPath: string
): Promise<any> => { ): Promise<any> => {
const loader = new PDFLoader(filePath, { const loader = new PDFLoader(filePath, {
splitPages: true, splitPages: true,
}); });
if (!this.embeddingModel) return Promise.reject();
const doc = await loader.load(); const doc = await loader.load();
const docs = await this.textSplitter!.splitDocuments(doc); const docs = await this.textSplitter!.splitDocuments(doc);
const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel); const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel);
@ -62,6 +62,7 @@ export class Retrieval {
}; };
public loadRetrievalAgent = async (memoryPath: string): Promise<void> => { public loadRetrievalAgent = async (memoryPath: string): Promise<void> => {
if (!this.embeddingModel) return Promise.reject();
const vectorStore = await HNSWLib.load(memoryPath, this.embeddingModel); const vectorStore = await HNSWLib.load(memoryPath, this.embeddingModel);
this.retriever = vectorStore.asRetriever(2); this.retriever = vectorStore.asRetriever(2);
return Promise.resolve(); return Promise.resolve();

View File

@ -119,19 +119,20 @@ export default class JSONConversationalExtension extends ConversationalExtension
if (!(await fs.existsSync(threadDirPath))) if (!(await fs.existsSync(threadDirPath)))
await fs.mkdirSync(threadDirPath) await fs.mkdirSync(threadDirPath)
if (message.content[0].type === 'image') { if (message.content[0]?.type === 'image') {
const filesPath = await joinPath([threadDirPath, 'files']) const filesPath = await joinPath([threadDirPath, 'files'])
if (!(await fs.existsSync(filesPath))) await fs.mkdirSync(filesPath) if (!(await fs.existsSync(filesPath))) await fs.mkdirSync(filesPath)
const imagePath = await joinPath([filesPath, `${message.id}.png`]) const imagePath = await joinPath([filesPath, `${message.id}.png`])
const base64 = message.content[0].text.annotations[0] const base64 = message.content[0].text.annotations[0]
await this.storeImage(base64, imagePath) await this.storeImage(base64, imagePath)
// if (fs.existsSync(imagePath)) { if ((await fs.existsSync(imagePath)) && message.content?.length) {
// message.content[0].text.annotations[0] = imagePath // Use file path instead of blob
// } message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.png`
}
} }
if (message.content[0].type === 'pdf') { if (message.content[0]?.type === 'pdf') {
const filesPath = await joinPath([threadDirPath, 'files']) const filesPath = await joinPath([threadDirPath, 'files'])
if (!(await fs.existsSync(filesPath))) await fs.mkdirSync(filesPath) if (!(await fs.existsSync(filesPath))) await fs.mkdirSync(filesPath)
@ -139,7 +140,7 @@ export default class JSONConversationalExtension extends ConversationalExtension
const blob = message.content[0].text.annotations[0] const blob = message.content[0].text.annotations[0]
await this.storeFile(blob, filePath) await this.storeFile(blob, filePath)
if (await fs.existsSync(filePath)) { if ((await fs.existsSync(filePath)) && message.content?.length) {
// Use file path instead of blob // Use file path instead of blob
message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.pdf` message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.pdf`
} }

View File

@ -100,6 +100,8 @@ export default function EventHandler({ children }: { children: ReactNode }) {
message.status message.status
) )
if (message.status === MessageStatus.Pending) { if (message.status === MessageStatus.Pending) {
if (message.content.length)
updateThreadWaiting(message.thread_id, false)
return return
} }
// Mark the thread as not waiting for response // Mark the thread as not waiting for response

View File

@ -98,14 +98,7 @@ const ChatBody: React.FC = () => {
</div> </div>
))} ))}
{activeModel && {activeModel && isGeneratingResponse && <GenerateResponse />}
(isGeneratingResponse ||
(messages.length &&
messages[messages.length - 1].status ===
MessageStatus.Pending &&
!messages[messages.length - 1].content.length)) && (
<GenerateResponse />
)}
</ScrollToBottom> </ScrollToBottom>
)} )}
</Fragment> </Fragment>

View File

@ -4,6 +4,7 @@ import {
ThreadMessage, ThreadMessage,
ChatCompletionRole, ChatCompletionRole,
ConversationalExtension, ConversationalExtension,
ContentType,
} from '@janhq/core' } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai' import { useAtomValue, useSetAtom } from 'jotai'
import { RefreshCcw, CopyIcon, Trash2Icon, CheckIcon } from 'lucide-react' import { RefreshCcw, CopyIcon, Trash2Icon, CheckIcon } from 'lucide-react'
@ -53,7 +54,9 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
<div className={twMerge('flex flex-row items-center')}> <div className={twMerge('flex flex-row items-center')}>
<div className="flex overflow-hidden rounded-md border border-border bg-background/20"> <div className="flex overflow-hidden rounded-md border border-border bg-background/20">
{message.id === messages[messages.length - 1]?.id && {message.id === messages[messages.length - 1]?.id &&
messages[messages.length - 1].status !== MessageStatus.Error && ( messages[messages.length - 1].status !== MessageStatus.Error &&
messages[messages.length - 1].content[0]?.type !==
ContentType.Pdf && (
<div <div
className="cursor-pointer border-r border-border px-2 py-2 hover:bg-background/80" className="cursor-pointer border-r border-border px-2 py-2 hover:bg-background/80"
onClick={onRegenerateClick} onClick={onRegenerateClick}