add time weighted retrieval (#2908)
* add time weighted retrieval * add missing configuration for timeWeightedVectorStore * resolving conflict * add missing configuration for timeWeightedVectorStore * resolving conflict * fix linting issues * fix build failure due to requirement for useTimeWeightedRetriever in AssistantTool * update web packages to comply with the new structure --------- Co-authored-by: thu <thu@treehouse.finance>
This commit is contained in:
parent
8077eb5cf6
commit
08c60a70c2
@ -6,6 +6,7 @@
|
|||||||
export type AssistantTool = {
|
export type AssistantTool = {
|
||||||
type: string
|
type: string
|
||||||
enabled: boolean
|
enabled: boolean
|
||||||
|
useTimeWeightedRetriever?: boolean
|
||||||
settings: any
|
settings: any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -126,6 +126,7 @@ export default class JanAssistantExtension extends AssistantExtension {
|
|||||||
{
|
{
|
||||||
type: 'retrieval',
|
type: 'retrieval',
|
||||||
enabled: false,
|
enabled: false,
|
||||||
|
useTimeWeightedRetriever: false,
|
||||||
settings: {
|
settings: {
|
||||||
top_k: 2,
|
top_k: 2,
|
||||||
chunk_size: 1024,
|
chunk_size: 1024,
|
||||||
|
|||||||
@ -11,13 +11,14 @@ export function toolRetrievalUpdateTextSplitter(
|
|||||||
export async function toolRetrievalIngestNewDocument(
|
export async function toolRetrievalIngestNewDocument(
|
||||||
file: string,
|
file: string,
|
||||||
model: string,
|
model: string,
|
||||||
engine: string
|
engine: string,
|
||||||
|
useTimeWeighted: boolean
|
||||||
) {
|
) {
|
||||||
const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file))
|
const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file))
|
||||||
const threadPath = path.dirname(filePath.replace('files', ''))
|
const threadPath = path.dirname(filePath.replace('files', ''))
|
||||||
retrieval.updateEmbeddingEngine(model, engine)
|
retrieval.updateEmbeddingEngine(model, engine)
|
||||||
return retrieval
|
return retrieval
|
||||||
.ingestAgentKnowledge(filePath, `${threadPath}/memory`)
|
.ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted)
|
||||||
.catch((err) => {
|
.catch((err) => {
|
||||||
console.error(err)
|
console.error(err)
|
||||||
})
|
})
|
||||||
@ -33,8 +34,11 @@ export async function toolRetrievalLoadThreadMemory(threadId: string) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function toolRetrievalQueryResult(query: string) {
|
export async function toolRetrievalQueryResult(
|
||||||
return retrieval.generateResult(query).catch((err) => {
|
query: string,
|
||||||
|
useTimeWeighted: boolean = false
|
||||||
|
) {
|
||||||
|
return retrieval.generateResult(query, useTimeWeighted).catch((err) => {
|
||||||
console.error(err)
|
console.error(err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,11 +2,16 @@ import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'
|
|||||||
import { formatDocumentsAsString } from 'langchain/util/document'
|
import { formatDocumentsAsString } from 'langchain/util/document'
|
||||||
import { PDFLoader } from 'langchain/document_loaders/fs/pdf'
|
import { PDFLoader } from 'langchain/document_loaders/fs/pdf'
|
||||||
|
|
||||||
|
import { TimeWeightedVectorStoreRetriever } from 'langchain/retrievers/time_weighted'
|
||||||
|
import { MemoryVectorStore } from 'langchain/vectorstores/memory'
|
||||||
|
|
||||||
import { HNSWLib } from 'langchain/vectorstores/hnswlib'
|
import { HNSWLib } from 'langchain/vectorstores/hnswlib'
|
||||||
|
|
||||||
import { OpenAIEmbeddings } from 'langchain/embeddings/openai'
|
import { OpenAIEmbeddings } from 'langchain/embeddings/openai'
|
||||||
import { readEmbeddingEngine } from './engine'
|
import { readEmbeddingEngine } from './engine'
|
||||||
|
|
||||||
|
import path from 'path'
|
||||||
|
|
||||||
export class Retrieval {
|
export class Retrieval {
|
||||||
public chunkSize: number = 100
|
public chunkSize: number = 100
|
||||||
public chunkOverlap?: number = 0
|
public chunkOverlap?: number = 0
|
||||||
@ -15,8 +20,25 @@ export class Retrieval {
|
|||||||
private embeddingModel?: OpenAIEmbeddings = undefined
|
private embeddingModel?: OpenAIEmbeddings = undefined
|
||||||
private textSplitter?: RecursiveCharacterTextSplitter
|
private textSplitter?: RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
// to support time-weighted retrieval
|
||||||
|
private timeWeightedVectorStore: MemoryVectorStore
|
||||||
|
private timeWeightedretriever: any | TimeWeightedVectorStoreRetriever
|
||||||
|
|
||||||
constructor(chunkSize: number = 4000, chunkOverlap: number = 200) {
|
constructor(chunkSize: number = 4000, chunkOverlap: number = 200) {
|
||||||
this.updateTextSplitter(chunkSize, chunkOverlap)
|
this.updateTextSplitter(chunkSize, chunkOverlap)
|
||||||
|
|
||||||
|
// declare time-weighted retriever and storage
|
||||||
|
this.timeWeightedVectorStore = new MemoryVectorStore(
|
||||||
|
new OpenAIEmbeddings(
|
||||||
|
{ openAIApiKey: 'nitro-embedding' },
|
||||||
|
{ basePath: 'http://127.0.0.1:3928/v1' }
|
||||||
|
)
|
||||||
|
)
|
||||||
|
this.timeWeightedretriever = new TimeWeightedVectorStoreRetriever({
|
||||||
|
vectorStore: this.timeWeightedVectorStore,
|
||||||
|
memoryStream: [],
|
||||||
|
searchKwargs: 2,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
public updateTextSplitter(chunkSize: number, chunkOverlap: number): void {
|
public updateTextSplitter(chunkSize: number, chunkOverlap: number): void {
|
||||||
@ -44,11 +66,15 @@ export class Retrieval {
|
|||||||
openAIApiKey: settings.api_key,
|
openAIApiKey: settings.api_key,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update time-weighted embedding model
|
||||||
|
this.timeWeightedVectorStore.embeddings = this.embeddingModel
|
||||||
}
|
}
|
||||||
|
|
||||||
public ingestAgentKnowledge = async (
|
public ingestAgentKnowledge = async (
|
||||||
filePath: string,
|
filePath: string,
|
||||||
memoryPath: string
|
memoryPath: string,
|
||||||
|
useTimeWeighted: boolean
|
||||||
): Promise<any> => {
|
): Promise<any> => {
|
||||||
const loader = new PDFLoader(filePath, {
|
const loader = new PDFLoader(filePath, {
|
||||||
splitPages: true,
|
splitPages: true,
|
||||||
@ -57,6 +83,13 @@ export class Retrieval {
|
|||||||
const doc = await loader.load()
|
const doc = await loader.load()
|
||||||
const docs = await this.textSplitter!.splitDocuments(doc)
|
const docs = await this.textSplitter!.splitDocuments(doc)
|
||||||
const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel)
|
const vectorStore = await HNSWLib.fromDocuments(docs, this.embeddingModel)
|
||||||
|
|
||||||
|
// add documents with metadata by using the time-weighted retriever in order to support time-weighted retrieval
|
||||||
|
if (useTimeWeighted && this.timeWeightedretriever) {
|
||||||
|
await (
|
||||||
|
this.timeWeightedretriever as TimeWeightedVectorStoreRetriever
|
||||||
|
).addDocuments(docs)
|
||||||
|
}
|
||||||
return vectorStore.save(memoryPath)
|
return vectorStore.save(memoryPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,10 +100,25 @@ export class Retrieval {
|
|||||||
return Promise.resolve()
|
return Promise.resolve()
|
||||||
}
|
}
|
||||||
|
|
||||||
public generateResult = async (query: string): Promise<string> => {
|
public generateResult = async (
|
||||||
|
query: string,
|
||||||
|
useTimeWeighted: boolean
|
||||||
|
): Promise<string> => {
|
||||||
|
if (useTimeWeighted) {
|
||||||
|
if (!this.timeWeightedretriever) {
|
||||||
|
return Promise.resolve(' ')
|
||||||
|
}
|
||||||
|
// use invoke because getRelevantDocuments is deprecated
|
||||||
|
const relevantDocs = await this.timeWeightedretriever.invoke(query)
|
||||||
|
const serializedDoc = formatDocumentsAsString(relevantDocs)
|
||||||
|
return Promise.resolve(serializedDoc)
|
||||||
|
}
|
||||||
|
|
||||||
if (!this.retriever) {
|
if (!this.retriever) {
|
||||||
return Promise.resolve(' ')
|
return Promise.resolve(' ')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// should use invoke(query) because getRelevantDocuments is deprecated
|
||||||
const relevantDocs = await this.retriever.getRelevantDocuments(query)
|
const relevantDocs = await this.retriever.getRelevantDocuments(query)
|
||||||
const serializedDoc = formatDocumentsAsString(relevantDocs)
|
const serializedDoc = formatDocumentsAsString(relevantDocs)
|
||||||
return Promise.resolve(serializedDoc)
|
return Promise.resolve(serializedDoc)
|
||||||
|
|||||||
@ -37,7 +37,8 @@ export class RetrievalTool extends InferenceTool {
|
|||||||
'toolRetrievalIngestNewDocument',
|
'toolRetrievalIngestNewDocument',
|
||||||
docFile,
|
docFile,
|
||||||
data.model?.id,
|
data.model?.id,
|
||||||
data.model?.engine
|
data.model?.engine,
|
||||||
|
tool?.useTimeWeightedRetriever ?? false
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
return Promise.resolve(data)
|
return Promise.resolve(data)
|
||||||
@ -78,7 +79,8 @@ export class RetrievalTool extends InferenceTool {
|
|||||||
const retrievalResult = await executeOnMain(
|
const retrievalResult = await executeOnMain(
|
||||||
NODE,
|
NODE,
|
||||||
'toolRetrievalQueryResult',
|
'toolRetrievalQueryResult',
|
||||||
prompt
|
prompt,
|
||||||
|
tool?.useTimeWeightedRetriever ?? false
|
||||||
)
|
)
|
||||||
console.debug('toolRetrievalQueryResult', retrievalResult)
|
console.debug('toolRetrievalQueryResult', retrievalResult)
|
||||||
|
|
||||||
|
|||||||
@ -66,6 +66,32 @@ const Tools = () => {
|
|||||||
[activeThread, updateThreadMetadata]
|
[activeThread, updateThreadMetadata]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const onTimeWeightedRetrieverSwitchUpdate = useCallback(
|
||||||
|
(enabled: boolean) => {
|
||||||
|
if (!activeThread) return
|
||||||
|
updateThreadMetadata({
|
||||||
|
...activeThread,
|
||||||
|
assistants: [
|
||||||
|
{
|
||||||
|
...activeThread.assistants[0],
|
||||||
|
tools: [
|
||||||
|
{
|
||||||
|
type: 'retrieval',
|
||||||
|
enabled: true,
|
||||||
|
useTimeWeightedRetriever: enabled,
|
||||||
|
settings:
|
||||||
|
(activeThread.assistants[0].tools &&
|
||||||
|
activeThread.assistants[0].tools[0]?.settings) ??
|
||||||
|
{},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
},
|
||||||
|
[activeThread, updateThreadMetadata]
|
||||||
|
)
|
||||||
|
|
||||||
if (!experimentalFeature) return null
|
if (!experimentalFeature) return null
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -143,6 +169,46 @@ const Tools = () => {
|
|||||||
className="inline-block font-medium"
|
className="inline-block font-medium"
|
||||||
>
|
>
|
||||||
Vector Database
|
Vector Database
|
||||||
|
<Tooltip
|
||||||
|
trigger={
|
||||||
|
<InfoIcon
|
||||||
|
size={16}
|
||||||
|
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
content="Vector Database is crucial for efficient storage
|
||||||
|
and retrieval of embeddings. Consider your
|
||||||
|
specific task, available resources, and language
|
||||||
|
requirements. Experiment to find the best fit for
|
||||||
|
your specific use case."
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
<div className="ml-auto flex items-center justify-between">
|
||||||
|
<Switch
|
||||||
|
name="use-time-weighted-retriever"
|
||||||
|
className="mr-2"
|
||||||
|
checked={
|
||||||
|
activeThread?.assistants[0].tools[0]
|
||||||
|
.useTimeWeightedRetriever || false
|
||||||
|
}
|
||||||
|
onChange={(e) =>
|
||||||
|
onTimeWeightedRetrieverSwitchUpdate(e.target.checked)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="w-full">
|
||||||
|
<Input value="HNSWLib" disabled readOnly />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="mb-4">
|
||||||
|
<div className="mb-2 flex items-center">
|
||||||
|
<label
|
||||||
|
id="use-time-weighted-retriever"
|
||||||
|
className="inline-block font-medium"
|
||||||
|
>
|
||||||
|
Time-Weighted Retrieval?
|
||||||
</label>
|
</label>
|
||||||
<Tooltip
|
<Tooltip
|
||||||
trigger={
|
trigger={
|
||||||
@ -151,17 +217,13 @@ const Tools = () => {
|
|||||||
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
|
className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]"
|
||||||
/>
|
/>
|
||||||
}
|
}
|
||||||
content="Vector Database is crucial for efficient storage
|
content="Time-Weighted Retriever looks at how similar
|
||||||
and retrieval of embeddings. Consider your
|
they are and how new they are. It compares
|
||||||
specific task, available resources, and language
|
documents based on their meaning like usual, but
|
||||||
requirements. Experiment to find the best fit for
|
also considers when they were added to give
|
||||||
your specific use case."
|
newer ones more importance."
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="w-full">
|
|
||||||
<Input value="HNSWLib" disabled readOnly />
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
<AssistantSetting
|
<AssistantSetting
|
||||||
componentData={componentDataAssistantSetting}
|
componentData={componentDataAssistantSetting}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user