fix: retrieval always ask for api key

This commit is contained in:
Louis 2024-01-29 22:44:13 +07:00
parent 00a109d46b
commit 12ebf272d6
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
5 changed files with 19 additions and 19 deletions

View File

@ -12,12 +12,11 @@ export class Retrieval {
public chunkOverlap?: number = 0;
private retriever: any;
private embeddingModel: any = undefined;
private embeddingModel?: OpenAIEmbeddings = undefined;
private textSplitter?: RecursiveCharacterTextSplitter;
constructor(chunkSize: number = 4000, chunkOverlap: number = 200) {
this.updateTextSplitter(chunkSize, chunkOverlap);
this.embeddingModel = new OpenAIEmbeddings({});
}
public updateTextSplitter(chunkSize: number, chunkOverlap: number): void {
@ -36,7 +35,7 @@ export class Retrieval {
if (engine === "nitro") {
this.embeddingModel = new OpenAIEmbeddings(
{ openAIApiKey: "nitro-embedding" },
{ basePath: "http://127.0.0.1:3928/v1" },
{ basePath: "http://127.0.0.1:3928/v1" }
);
} else {
// Fallback to OpenAI Settings
@ -50,11 +49,12 @@ export class Retrieval {
public ingestAgentKnowledge = async (
filePath: string,
memoryPath: string,
memoryPath: string
): Promise<any> => {
const loader = new PDFLoader(filePath, {
splitPages: true,
});
if (!this.embeddingModel) return Promise.reject();
const doc = await loader.load();
const docs = await this.textSplitter!.splitDocuments(doc);
const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel);
@ -62,6 +62,7 @@ export class Retrieval {
};
public loadRetrievalAgent = async (memoryPath: string): Promise<void> => {
if (!this.embeddingModel) return Promise.reject();
const vectorStore = await HNSWLib.load(memoryPath, this.embeddingModel);
this.retriever = vectorStore.asRetriever(2);
return Promise.resolve();

View File

@ -119,19 +119,20 @@ export default class JSONConversationalExtension extends ConversationalExtension
if (!(await fs.existsSync(threadDirPath)))
await fs.mkdirSync(threadDirPath)
if (message.content[0].type === 'image') {
if (message.content[0]?.type === 'image') {
const filesPath = await joinPath([threadDirPath, 'files'])
if (!(await fs.existsSync(filesPath))) await fs.mkdirSync(filesPath)
const imagePath = await joinPath([filesPath, `${message.id}.png`])
const base64 = message.content[0].text.annotations[0]
await this.storeImage(base64, imagePath)
// if (fs.existsSync(imagePath)) {
// message.content[0].text.annotations[0] = imagePath
// }
if ((await fs.existsSync(imagePath)) && message.content?.length) {
// Use file path instead of blob
message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.png`
}
}
if (message.content[0].type === 'pdf') {
if (message.content[0]?.type === 'pdf') {
const filesPath = await joinPath([threadDirPath, 'files'])
if (!(await fs.existsSync(filesPath))) await fs.mkdirSync(filesPath)
@ -139,7 +140,7 @@ export default class JSONConversationalExtension extends ConversationalExtension
const blob = message.content[0].text.annotations[0]
await this.storeFile(blob, filePath)
if (await fs.existsSync(filePath)) {
if ((await fs.existsSync(filePath)) && message.content?.length) {
// Use file path instead of blob
message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.pdf`
}

View File

@ -100,6 +100,8 @@ export default function EventHandler({ children }: { children: ReactNode }) {
message.status
)
if (message.status === MessageStatus.Pending) {
if (message.content.length)
updateThreadWaiting(message.thread_id, false)
return
}
// Mark the thread as not waiting for response

View File

@ -98,14 +98,7 @@ const ChatBody: React.FC = () => {
</div>
))}
{activeModel &&
(isGeneratingResponse ||
(messages.length &&
messages[messages.length - 1].status ===
MessageStatus.Pending &&
!messages[messages.length - 1].content.length)) && (
<GenerateResponse />
)}
{activeModel && isGeneratingResponse && <GenerateResponse />}
</ScrollToBottom>
)}
</Fragment>

View File

@ -4,6 +4,7 @@ import {
ThreadMessage,
ChatCompletionRole,
ConversationalExtension,
ContentType,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { RefreshCcw, CopyIcon, Trash2Icon, CheckIcon } from 'lucide-react'
@ -53,7 +54,9 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
<div className={twMerge('flex flex-row items-center')}>
<div className="flex overflow-hidden rounded-md border border-border bg-background/20">
{message.id === messages[messages.length - 1]?.id &&
messages[messages.length - 1].status !== MessageStatus.Error && (
messages[messages.length - 1].status !== MessageStatus.Error &&
messages[messages.length - 1].content[0]?.type !==
ContentType.Pdf && (
<div
className="cursor-pointer border-r border-border px-2 py-2 hover:bg-background/80"
onClick={onRegenerateClick}