add time weighted retrieval (#2908)

* add time weighted retrieval

* add missing configuration for timeWeightedVectorStore

* resolving conflict

* add missing configuration for timeWeightedVectorStore

* resolving conflict

* fix linting issues

* fix build failed due to requirement for useTimeWeightedRetriever in AssistantTool

* update web packages complying the new structure

---------

Co-authored-by: thu <thu@treehouse.finance>
This commit is contained in:
Nathan 2024-06-20 16:34:50 +07:00 committed by Louis
parent 8077eb5cf6
commit 08c60a70c2
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
6 changed files with 135 additions and 17 deletions

View File

@ -6,6 +6,7 @@
export type AssistantTool = {
type: string
enabled: boolean
useTimeWeightedRetriever?: boolean
settings: any
}

View File

@ -126,6 +126,7 @@ export default class JanAssistantExtension extends AssistantExtension {
{
type: 'retrieval',
enabled: false,
useTimeWeightedRetriever: false,
settings: {
top_k: 2,
chunk_size: 1024,

View File

@ -11,13 +11,14 @@ export function toolRetrievalUpdateTextSplitter(
export async function toolRetrievalIngestNewDocument(
file: string,
model: string,
engine: string
engine: string,
useTimeWeighted: boolean
) {
const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file))
const threadPath = path.dirname(filePath.replace('files', ''))
retrieval.updateEmbeddingEngine(model, engine)
return retrieval
.ingestAgentKnowledge(filePath, `${threadPath}/memory`)
.ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted)
.catch((err) => {
console.error(err)
})
@ -33,8 +34,11 @@ export async function toolRetrievalLoadThreadMemory(threadId: string) {
})
}
export async function toolRetrievalQueryResult(query: string) {
return retrieval.generateResult(query).catch((err) => {
export async function toolRetrievalQueryResult(
query: string,
useTimeWeighted: boolean = false
) {
return retrieval.generateResult(query, useTimeWeighted).catch((err) => {
console.error(err)
})
}

View File

@ -2,11 +2,16 @@ import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'
import { formatDocumentsAsString } from 'langchain/util/document'
import { PDFLoader } from 'langchain/document_loaders/fs/pdf'
import { TimeWeightedVectorStoreRetriever } from 'langchain/retrievers/time_weighted'
import { MemoryVectorStore } from 'langchain/vectorstores/memory'
import { HNSWLib } from 'langchain/vectorstores/hnswlib'
import { OpenAIEmbeddings } from 'langchain/embeddings/openai'
import { readEmbeddingEngine } from './engine'
import path from 'path'
export class Retrieval {
public chunkSize: number = 100
public chunkOverlap?: number = 0
@ -15,8 +20,25 @@ export class Retrieval {
private embeddingModel?: OpenAIEmbeddings = undefined
private textSplitter?: RecursiveCharacterTextSplitter
// to support time-weighted retrieval
private timeWeightedVectorStore: MemoryVectorStore
private timeWeightedretriever: any | TimeWeightedVectorStoreRetriever
constructor(chunkSize: number = 4000, chunkOverlap: number = 200) {
this.updateTextSplitter(chunkSize, chunkOverlap)
// declare time-weighted retriever and storage
this.timeWeightedVectorStore = new MemoryVectorStore(
new OpenAIEmbeddings(
{ openAIApiKey: 'nitro-embedding' },
{ basePath: 'http://127.0.0.1:3928/v1' }
)
)
this.timeWeightedretriever = new TimeWeightedVectorStoreRetriever({
vectorStore: this.timeWeightedVectorStore,
memoryStream: [],
searchKwargs: 2,
})
}
public updateTextSplitter(chunkSize: number, chunkOverlap: number): void {
@ -44,11 +66,15 @@ export class Retrieval {
openAIApiKey: settings.api_key,
})
}
// update time-weighted embedding model
this.timeWeightedVectorStore.embeddings = this.embeddingModel
}
public ingestAgentKnowledge = async (
filePath: string,
memoryPath: string
memoryPath: string,
useTimeWeighted: boolean
): Promise<any> => {
const loader = new PDFLoader(filePath, {
splitPages: true,
@ -57,6 +83,13 @@ export class Retrieval {
const doc = await loader.load()
const docs = await this.textSplitter!.splitDocuments(doc)
const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel)
// add documents with metadata by using the time-weighted retriever in order to support time-weighted retrieval
if (useTimeWeighted && this.timeWeightedretriever) {
await (
this.timeWeightedretriever as TimeWeightedVectorStoreRetriever
).addDocuments(docs)
}
return vectorStore.save(memoryPath)
}
@ -67,10 +100,25 @@ export class Retrieval {
return Promise.resolve()
}
public generateResult = async (query: string): Promise<string> => {
public generateResult = async (
query: string,
useTimeWeighted: boolean
): Promise<string> => {
if (useTimeWeighted) {
if (!this.timeWeightedretriever) {
return Promise.resolve(' ')
}
// use invoke because getRelevantDocuments is deprecated
const relevantDocs = await this.timeWeightedretriever.invoke(query)
const serializedDoc = formatDocumentsAsString(relevantDocs)
return Promise.resolve(serializedDoc)
}
if (!this.retriever) {
return Promise.resolve(' ')
}
// should use invoke(query) because getRelevantDocuments is deprecated
const relevantDocs = await this.retriever.getRelevantDocuments(query)
const serializedDoc = formatDocumentsAsString(relevantDocs)
return Promise.resolve(serializedDoc)

View File

@ -37,7 +37,8 @@ export class RetrievalTool extends InferenceTool {
'toolRetrievalIngestNewDocument',
docFile,
data.model?.id,
data.model?.engine
data.model?.engine,
tool?.useTimeWeightedRetriever ?? false
)
} else {
return Promise.resolve(data)
@ -78,7 +79,8 @@ export class RetrievalTool extends InferenceTool {
const retrievalResult = await executeOnMain(
NODE,
'toolRetrievalQueryResult',
prompt
prompt,
tool?.useTimeWeightedRetriever ?? false
)
console.debug('toolRetrievalQueryResult', retrievalResult)

View File

@ -66,6 +66,32 @@ const Tools = () => {
[activeThread, updateThreadMetadata]
)
const onTimeWeightedRetrieverSwitchUpdate = useCallback(
(enabled: boolean) => {
if (!activeThread) return
updateThreadMetadata({
...activeThread,
assistants: [
{
...activeThread.assistants[0],
tools: [
{
type: 'retrieval',
enabled: true,
useTimeWeightedRetriever: enabled,
settings:
(activeThread.assistants[0].tools &&
activeThread.assistants[0].tools[0]?.settings) ??
{},
},
],
},
],
})
},
[activeThread, updateThreadMetadata]
)
if (!experimentalFeature) return null
return (
@ -143,6 +169,46 @@ const Tools = () => {
className="inline-block font-medium"
>
Vector Database
<Tooltip
trigger={
<InfoIcon
size={16}
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
/>
}
content="Vector Database is crucial for efficient storage
and retrieval of embeddings. Consider your
specific task, available resources, and language
requirements. Experiment to find the best fit for
your specific use case."
/>
</label>
<div className="ml-auto flex items-center justify-between">
<Switch
name="use-time-weighted-retriever"
className="mr-2"
checked={
activeThread?.assistants[0].tools[0]
.useTimeWeightedRetriever || false
}
onChange={(e) =>
onTimeWeightedRetrieverSwitchUpdate(e.target.checked)
}
/>
</div>
</div>
<div className="w-full">
<Input value="HNSWLib" disabled readOnly />
</div>
</div>
<div className="mb-4">
<div className="mb-2 flex items-center">
<label
id="use-time-weighted-retriever"
className="inline-block font-medium"
>
Time-Weighted Retrieval?
</label>
<Tooltip
trigger={
@ -151,17 +217,13 @@ const Tools = () => {
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
/>
}
content="Vector Database is crucial for efficient storage
and retrieval of embeddings. Consider your
specific task, available resources, and language
requirements. Experiment to find the best fit for
your specific use case."
content="Time-Weighted Retriever looks at how similar
they are and how new they are. It compares
documents based on their meaning like usual, but
also considers when they were added to give
newer ones more importance."
/>
</div>
<div className="w-full">
<Input value="HNSWLib" disabled readOnly />
</div>
</div>
<AssistantSetting
componentData={componentDataAssistantSetting}