add time weighted retrieval (#2908)
* add time-weighted retrieval * add missing configuration for timeWeightedVectorStore * resolve conflict * add missing configuration for timeWeightedVectorStore * resolve conflict * fix linting issues * fix build failure caused by useTimeWeightedRetriever being required in AssistantTool * update web packages to comply with the new structure --------- Co-authored-by: thu <thu@treehouse.finance>
This commit is contained in:
parent
8077eb5cf6
commit
08c60a70c2
@ -6,6 +6,7 @@
|
||||
/**
 * Configuration entry for a single tool attached to an assistant.
 *
 * - `type`: string identifying the tool; the only value visible in this
 *   change set is `'retrieval'`.
 * - `enabled`: whether the tool is switched on.
 * - `useTimeWeightedRetriever`: optional flag for the time-weighted retrieval
 *   path. It is optional so existing tool objects without it still type-check;
 *   the retrieval tool call sites in this change fall back to `false` when it
 *   is missing.
 * - `settings`: untyped, tool-specific settings. The default retrieval tool in
 *   this change carries keys such as `top_k` and `chunk_size`.
 *   NOTE(review): typed as `any`; the full schema is not visible here.
 */
export type AssistantTool = {
|
||||
type: string
|
||||
enabled: boolean
|
||||
useTimeWeightedRetriever?: boolean
|
||||
settings: any
|
||||
}
|
||||
|
||||
|
||||
@ -126,6 +126,7 @@ export default class JanAssistantExtension extends AssistantExtension {
|
||||
{
|
||||
type: 'retrieval',
|
||||
enabled: false,
|
||||
useTimeWeightedRetriever: false,
|
||||
settings: {
|
||||
top_k: 2,
|
||||
chunk_size: 1024,
|
||||
|
||||
@ -11,13 +11,14 @@ export function toolRetrievalUpdateTextSplitter(
|
||||
export async function toolRetrievalIngestNewDocument(
|
||||
file: string,
|
||||
model: string,
|
||||
engine: string
|
||||
engine: string,
|
||||
useTimeWeighted: boolean
|
||||
) {
|
||||
const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file))
|
||||
const threadPath = path.dirname(filePath.replace('files', ''))
|
||||
retrieval.updateEmbeddingEngine(model, engine)
|
||||
return retrieval
|
||||
.ingestAgentKnowledge(filePath, `${threadPath}/memory`)
|
||||
.ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted)
|
||||
.catch((err) => {
|
||||
console.error(err)
|
||||
})
|
||||
@ -33,8 +34,11 @@ export async function toolRetrievalLoadThreadMemory(threadId: string) {
|
||||
})
|
||||
}
|
||||
|
||||
export async function toolRetrievalQueryResult(query: string) {
|
||||
return retrieval.generateResult(query).catch((err) => {
|
||||
/**
 * Runs a retrieval query through the module-level `retrieval` object.
 *
 * @param query - Text handed to `retrieval.generateResult`.
 * @param useTimeWeighted - Forwarded unchanged to `retrieval.generateResult`;
 *   defaults to `false` when the caller omits it.
 * @returns Whatever `generateResult` resolves to. If that promise rejects, the
 *   error is written with `console.error` and the returned promise resolves
 *   without a value instead of rejecting.
 */
export async function toolRetrievalQueryResult(
  query: string,
  useTimeWeighted: boolean = false
) {
  // Best-effort: a failed lookup is logged, not propagated to the caller.
  const reportFailure = (err: unknown): void => {
    console.error(err)
  }

  const pendingResult = retrieval.generateResult(query, useTimeWeighted)
  return pendingResult.catch(reportFailure)
}
|
||||
|
||||
@ -2,11 +2,16 @@ import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'
|
||||
import { formatDocumentsAsString } from 'langchain/util/document'
|
||||
import { PDFLoader } from 'langchain/document_loaders/fs/pdf'
|
||||
|
||||
import { TimeWeightedVectorStoreRetriever } from 'langchain/retrievers/time_weighted'
|
||||
import { MemoryVectorStore } from 'langchain/vectorstores/memory'
|
||||
|
||||
import { HNSWLib } from 'langchain/vectorstores/hnswlib'
|
||||
|
||||
import { OpenAIEmbeddings } from 'langchain/embeddings/openai'
|
||||
import { readEmbeddingEngine } from './engine'
|
||||
|
||||
import path from 'path'
|
||||
|
||||
export class Retrieval {
|
||||
public chunkSize: number = 100
|
||||
public chunkOverlap?: number = 0
|
||||
@ -15,8 +20,25 @@ export class Retrieval {
|
||||
private embeddingModel?: OpenAIEmbeddings = undefined
|
||||
private textSplitter?: RecursiveCharacterTextSplitter
|
||||
|
||||
// to support time-weighted retrieval
|
||||
private timeWeightedVectorStore: MemoryVectorStore
|
||||
private timeWeightedretriever: any | TimeWeightedVectorStoreRetriever
|
||||
|
||||
/**
 * Sets up the text splitter and the in-memory store/retriever pair used for
 * time-weighted retrieval.
 *
 * @param chunkSize - Passed to `updateTextSplitter`; defaults to 4000.
 * @param chunkOverlap - Passed to `updateTextSplitter`; defaults to 200.
 */
constructor(chunkSize: number = 4000, chunkOverlap: number = 200) {
  this.updateTextSplitter(chunkSize, chunkOverlap)

  // Initial embeddings client pointing at the local endpoint. Elsewhere in
  // this class the store's `embeddings` field is reassigned to
  // `this.embeddingModel` when the embedding engine is updated.
  const initialEmbeddings = new OpenAIEmbeddings(
    { openAIApiKey: 'nitro-embedding' },
    { basePath: 'http://127.0.0.1:3928/v1' }
  )
  const memoryStore = new MemoryVectorStore(initialEmbeddings)
  this.timeWeightedVectorStore = memoryStore

  // The retriever is built on the same store instance and starts with an
  // empty memory stream.
  this.timeWeightedretriever = new TimeWeightedVectorStoreRetriever({
    vectorStore: memoryStore,
    memoryStream: [],
    searchKwargs: 2,
  })
}
|
||||
|
||||
public updateTextSplitter(chunkSize: number, chunkOverlap: number): void {
|
||||
@ -44,11 +66,15 @@ export class Retrieval {
|
||||
openAIApiKey: settings.api_key,
|
||||
})
|
||||
}
|
||||
|
||||
// update time-weighted embedding model
|
||||
this.timeWeightedVectorStore.embeddings = this.embeddingModel
|
||||
}
|
||||
|
||||
public ingestAgentKnowledge = async (
|
||||
filePath: string,
|
||||
memoryPath: string
|
||||
memoryPath: string,
|
||||
useTimeWeighted: boolean
|
||||
): Promise<any> => {
|
||||
const loader = new PDFLoader(filePath, {
|
||||
splitPages: true,
|
||||
@ -57,6 +83,13 @@ export class Retrieval {
|
||||
const doc = await loader.load()
|
||||
const docs = await this.textSplitter!.splitDocuments(doc)
|
||||
const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel)
|
||||
|
||||
// add documents with metadata by using the time-weighted retriever in order to support time-weighted retrieval
|
||||
if (useTimeWeighted && this.timeWeightedretriever) {
|
||||
await (
|
||||
this.timeWeightedretriever as TimeWeightedVectorStoreRetriever
|
||||
).addDocuments(docs)
|
||||
}
|
||||
return vectorStore.save(memoryPath)
|
||||
}
|
||||
|
||||
@ -67,10 +100,25 @@ export class Retrieval {
|
||||
return Promise.resolve()
|
||||
}
|
||||
|
||||
public generateResult = async (query: string): Promise<string> => {
|
||||
public generateResult = async (
|
||||
query: string,
|
||||
useTimeWeighted: boolean
|
||||
): Promise<string> => {
|
||||
if (useTimeWeighted) {
|
||||
if (!this.timeWeightedretriever) {
|
||||
return Promise.resolve(' ')
|
||||
}
|
||||
// use invoke because getRelevantDocuments is deprecated
|
||||
const relevantDocs = await this.timeWeightedretriever.invoke(query)
|
||||
const serializedDoc = formatDocumentsAsString(relevantDocs)
|
||||
return Promise.resolve(serializedDoc)
|
||||
}
|
||||
|
||||
if (!this.retriever) {
|
||||
return Promise.resolve(' ')
|
||||
}
|
||||
|
||||
// should use invoke(query) because getRelevantDocuments is deprecated
|
||||
const relevantDocs = await this.retriever.getRelevantDocuments(query)
|
||||
const serializedDoc = formatDocumentsAsString(relevantDocs)
|
||||
return Promise.resolve(serializedDoc)
|
||||
|
||||
@ -37,7 +37,8 @@ export class RetrievalTool extends InferenceTool {
|
||||
'toolRetrievalIngestNewDocument',
|
||||
docFile,
|
||||
data.model?.id,
|
||||
data.model?.engine
|
||||
data.model?.engine,
|
||||
tool?.useTimeWeightedRetriever ?? false
|
||||
)
|
||||
} else {
|
||||
return Promise.resolve(data)
|
||||
@ -78,7 +79,8 @@ export class RetrievalTool extends InferenceTool {
|
||||
const retrievalResult = await executeOnMain(
|
||||
NODE,
|
||||
'toolRetrievalQueryResult',
|
||||
prompt
|
||||
prompt,
|
||||
tool?.useTimeWeightedRetriever ?? false
|
||||
)
|
||||
console.debug('toolRetrievalQueryResult', retrievalResult)
|
||||
|
||||
|
||||
@ -66,6 +66,32 @@ const Tools = () => {
|
||||
[activeThread, updateThreadMetadata]
|
||||
)
|
||||
|
||||
/**
 * Handler for the time-weighted retrieval switch.
 *
 * Does nothing when there is no active thread. Otherwise it writes thread
 * metadata in which the first assistant's `tools` array is replaced by a
 * single retrieval tool entry: `enabled` is set to `true`,
 * `useTimeWeightedRetriever` takes the switch value, and `settings` is carried
 * over from the current first tool (or `{}` when there is none).
 */
const onTimeWeightedRetrieverSwitchUpdate = useCallback(
  (enabled: boolean) => {
    if (!activeThread) return

    const firstAssistant = activeThread.assistants[0]
    // Keep whatever settings the current first tool already has.
    const carriedSettings = firstAssistant.tools?.[0]?.settings ?? {}

    const retrievalTool = {
      type: 'retrieval',
      enabled: true,
      useTimeWeightedRetriever: enabled,
      settings: carriedSettings,
    }

    const nextThread = {
      ...activeThread,
      assistants: [{ ...firstAssistant, tools: [retrievalTool] }],
    }
    updateThreadMetadata(nextThread)
  },
  [activeThread, updateThreadMetadata]
)
|
||||
|
||||
if (!experimentalFeature) return null
|
||||
|
||||
return (
|
||||
@ -143,6 +169,46 @@ const Tools = () => {
|
||||
className="inline-block font-medium"
|
||||
>
|
||||
Vector Database
|
||||
<Tooltip
|
||||
trigger={
|
||||
<InfoIcon
|
||||
size={16}
|
||||
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
|
||||
/>
|
||||
}
|
||||
content="Vector Database is crucial for efficient storage
|
||||
and retrieval of embeddings. Consider your
|
||||
specific task, available resources, and language
|
||||
requirements. Experiment to find the best fit for
|
||||
your specific use case."
|
||||
/>
|
||||
</label>
|
||||
<div className="ml-auto flex items-center justify-between">
|
||||
<Switch
|
||||
name="use-time-weighted-retriever"
|
||||
className="mr-2"
|
||||
checked={
|
||||
activeThread?.assistants[0].tools[0]
|
||||
.useTimeWeightedRetriever || false
|
||||
}
|
||||
onChange={(e) =>
|
||||
onTimeWeightedRetrieverSwitchUpdate(e.target.checked)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="w-full">
|
||||
<Input value="HNSWLib" disabled readOnly />
|
||||
</div>
|
||||
</div>
|
||||
<div className="mb-4">
|
||||
<div className="mb-2 flex items-center">
|
||||
<label
|
||||
id="use-time-weighted-retriever"
|
||||
className="inline-block font-medium"
|
||||
>
|
||||
Time-Weighted Retrieval?
|
||||
</label>
|
||||
<Tooltip
|
||||
trigger={
|
||||
@ -151,17 +217,13 @@ const Tools = () => {
|
||||
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
|
||||
/>
|
||||
}
|
||||
content="Vector Database is crucial for efficient storage
|
||||
and retrieval of embeddings. Consider your
|
||||
specific task, available resources, and language
|
||||
requirements. Experiment to find the best fit for
|
||||
your specific use case."
|
||||
content="Time-Weighted Retriever looks at how similar
|
||||
they are and how new they are. It compares
|
||||
documents based on their meaning like usual, but
|
||||
also considers when they were added to give
|
||||
newer ones more importance."
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="w-full">
|
||||
<Input value="HNSWLib" disabled readOnly />
|
||||
</div>
|
||||
</div>
|
||||
<AssistantSetting
|
||||
componentData={componentDataAssistantSetting}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user