From d9bcf4df925b8bfe3181835f7b37c8205fd06d4e Mon Sep 17 00:00:00 2001
From: Louis de Courcel <einsenhorn@gmail.com>
Date: Wed, 27 Sep 2023 01:14:19 +0200
Subject: [PATCH] impr: add fromVectorStore method

---
 packages/core/src/QueryEngine.ts              | 11 ++++++++--
 packages/core/src/ResponseSynthesizer.ts      | 13 +++++++-----
 packages/core/src/Retriever.ts                |  8 +++++--
 .../vectorStore/VectorIndexRetriever.ts       | 13 ++++++++----
 .../indices/vectorStore/VectorStoreIndex.ts   | 21 +++++++++++++++++++
 5 files changed, 53 insertions(+), 13 deletions(-)

diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts
index 500f97356..daad1e6b3 100644
--- a/packages/core/src/QueryEngine.ts
+++ b/packages/core/src/QueryEngine.ts
@@ -1,4 +1,5 @@
 import { v4 as uuidv4 } from "uuid";
+import { Event } from "./callbacks/CallbackManager";
 import { NodeWithScore, TextNode } from "./Node";
 import {
   BaseQuestionGenerator,
@@ -10,7 +11,6 @@ import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
 import { BaseRetriever } from "./Retriever";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
 import { QueryEngineTool, ToolMetadata } from "./Tool";
-import { Event } from "./callbacks/CallbackManager";
 
 /**
  * A query engine is a question answerer that can use one or more steps.
@@ -30,16 +30,19 @@ export interface BaseQueryEngine {
 export class RetrieverQueryEngine implements BaseQueryEngine {
   retriever: BaseRetriever;
   responseSynthesizer: ResponseSynthesizer;
+  preFilters?: unknown;
 
   constructor(
     retriever: BaseRetriever,
     responseSynthesizer?: ResponseSynthesizer,
+    preFilters?: unknown,
   ) {
     this.retriever = retriever;
     const serviceContext: ServiceContext | undefined =
       this.retriever.getServiceContext();
     this.responseSynthesizer =
       responseSynthesizer || new ResponseSynthesizer({ serviceContext });
+    this.preFilters = preFilters;
   }
 
   async query(query: string, parentEvent?: Event) {
@@ -48,7 +51,11 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
       type: "wrapper",
       tags: ["final"],
     };
-    const nodes = await this.retriever.retrieve(query, _parentEvent);
+    const nodes = await this.retriever.retrieve(
+      query,
+      _parentEvent,
+      this.preFilters,
+    );
     return this.responseSynthesizer.synthesize(query, nodes, _parentEvent);
   }
 }
diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts
index 912c02516..f3151a6dc 100644
--- a/packages/core/src/ResponseSynthesizer.ts
+++ b/packages/core/src/ResponseSynthesizer.ts
@@ -1,18 +1,18 @@
+import { Event } from "./callbacks/CallbackManager";
+import { LLM } from "./llm/LLM";
 import { MetadataMode, NodeWithScore } from "./Node";
 import {
+  defaultRefinePrompt,
+  defaultTextQaPrompt,
+  defaultTreeSummarizePrompt,
   RefinePrompt,
   SimplePrompt,
   TextQaPrompt,
   TreeSummarizePrompt,
-  defaultRefinePrompt,
-  defaultTextQaPrompt,
-  defaultTreeSummarizePrompt,
 } from "./Prompt";
 import { getBiggestPrompt } from "./PromptHelper";
 import { Response } from "./Response";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
-import { Event } from "./callbacks/CallbackManager";
-import { LLM } from "./llm/LLM";
 
 /**
  * Response modes of the response synthesizer
@@ -231,6 +231,7 @@ export class TreeSummarize implements BaseResponseBuilder {
       throw new Error("Must have at least one text chunk");
     }
 
+    // Should we send the query here too?
     const packedTextChunks = this.serviceContext.promptHelper.repack(
       this.summaryTemplate,
       textChunks,
@@ -241,6 +242,7 @@ export class TreeSummarize implements BaseResponseBuilder {
         await this.serviceContext.llm.complete(
           this.summaryTemplate({
             context: packedTextChunks[0],
+            query,
           }),
           parentEvent,
         )
@@ -251,6 +253,7 @@ export class TreeSummarize implements BaseResponseBuilder {
           this.serviceContext.llm.complete(
             this.summaryTemplate({
               context: chunk,
+              query,
             }),
             parentEvent,
           ),
diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts
index 303d0fa54..6b0f1024d 100644
--- a/packages/core/src/Retriever.ts
+++ b/packages/core/src/Retriever.ts
@@ -1,11 +1,15 @@
+import { Event } from "./callbacks/CallbackManager";
 import { NodeWithScore } from "./Node";
 import { ServiceContext } from "./ServiceContext";
-import { Event } from "./callbacks/CallbackManager";
 
 /**
  * Retrievers retrieve the nodes that most closely match our query in similarity.
  */
 export interface BaseRetriever {
-  retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]>;
+  retrieve(
+    query: string,
+    parentEvent?: Event,
+    preFilters?: unknown,
+  ): Promise<NodeWithScore[]>;
   getServiceContext(): ServiceContext;
 }
diff --git a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts
index d2bc4dbb3..8e3bff927 100644
--- a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts
+++ b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts
@@ -1,9 +1,9 @@
+import { Event } from "../../callbacks/CallbackManager";
+import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
 import { globalsHelper } from "../../GlobalsHelper";
 import { NodeWithScore } from "../../Node";
 import { BaseRetriever } from "../../Retriever";
 import { ServiceContext } from "../../ServiceContext";
-import { Event } from "../../callbacks/CallbackManager";
-import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
 import {
   VectorStoreQuery,
   VectorStoreQueryMode,
@@ -32,7 +32,7 @@ export class VectorIndexRetriever implements BaseRetriever {
     this.similarityTopK = similarityTopK ?? DEFAULT_SIMILARITY_TOP_K;
   }
 
-  async retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]> {
+  async retrieve(query: string, parentEvent?: Event, preFilters?: unknown): Promise<NodeWithScore[]> {
     const queryEmbedding =
       await this.serviceContext.embedModel.getQueryEmbedding(query);
 
@@ -41,10 +41,15 @@ export class VectorIndexRetriever implements BaseRetriever {
       mode: VectorStoreQueryMode.DEFAULT,
       similarityTopK: this.similarityTopK,
     };
-    const result = await this.index.vectorStore.query(q);
+    const result = await this.index.vectorStore.query(q, preFilters);
 
     let nodesWithScores: NodeWithScore[] = [];
     for (let i = 0; i < result.ids.length; i++) {
+      const nodeFromResult = result.nodes?.[i];
+      if (!this.index.indexStruct.nodesDict[result.ids[i]] && nodeFromResult) {
+        this.index.indexStruct.nodesDict[result.ids[i]] = nodeFromResult;
+      }
+
       const node = this.index.indexStruct.nodesDict[result.ids[i]];
       nodesWithScores.push({
         node: node,
diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
index 305f8e744..d2452962d 100644
--- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
+++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
@@ -219,6 +219,27 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     return index;
   }
 
+  static async fromVectorStore(
+    vectorStore: VectorStore,
+    serviceContext: ServiceContext,
+  ) {
+    if (!vectorStore.storesText) {
+      throw new Error(
+        "Cannot initialize from a vector store that does not store text",
+      );
+    }
+
+    const storageContext = await storageContextFromDefaults({ vectorStore });
+
+    const index = await VectorStoreIndex.init({
+      nodes: [],
+      storageContext,
+      serviceContext,
+    });
+
+    return index;
+  }
+
   asRetriever(options?: any): VectorIndexRetriever {
     return new VectorIndexRetriever({ index: this, ...options });
   }
-- 
GitLab