From d9bcf4df925b8bfe3181835f7b37c8205fd06d4e Mon Sep 17 00:00:00 2001 From: Louis de Courcel <einsenhorn@gmail.com> Date: Wed, 27 Sep 2023 01:14:19 +0200 Subject: [PATCH] impr: add fromVectorStore method --- packages/core/src/QueryEngine.ts | 11 ++++++++-- packages/core/src/ResponseSynthesizer.ts | 13 +++++++----- packages/core/src/Retriever.ts | 8 +++++-- .../vectorStore/VectorIndexRetriever.ts | 13 ++++++++---- .../indices/vectorStore/VectorStoreIndex.ts | 21 +++++++++++++++++++ 5 files changed, 53 insertions(+), 13 deletions(-) diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts index 500f97356..daad1e6b3 100644 --- a/packages/core/src/QueryEngine.ts +++ b/packages/core/src/QueryEngine.ts @@ -1,4 +1,5 @@ import { v4 as uuidv4 } from "uuid"; +import { Event } from "./callbacks/CallbackManager"; import { NodeWithScore, TextNode } from "./Node"; import { BaseQuestionGenerator, @@ -10,7 +11,6 @@ import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer"; import { BaseRetriever } from "./Retriever"; import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; import { QueryEngineTool, ToolMetadata } from "./Tool"; -import { Event } from "./callbacks/CallbackManager"; /** * A query engine is a question answerer that can use one or more steps. @@ -30,16 +30,19 @@ export interface BaseQueryEngine { export class RetrieverQueryEngine implements BaseQueryEngine { retriever: BaseRetriever; responseSynthesizer: ResponseSynthesizer; + preFilters?: unknown; constructor( retriever: BaseRetriever, responseSynthesizer?: ResponseSynthesizer, + preFilters?: unknown, ) { this.retriever = retriever; const serviceContext: ServiceContext | undefined = this.retriever.getServiceContext(); this.responseSynthesizer = responseSynthesizer || new ResponseSynthesizer({ serviceContext }); + this.preFilters = preFilters; } async query(query: string, parentEvent?: Event) { @@ -48,7 +51,11 @@ export class RetrieverQueryEngine implements BaseQueryEngine { type: "wrapper", tags: ["final"], }; - const nodes = await this.retriever.retrieve(query, _parentEvent); + const nodes = await this.retriever.retrieve( + query, + _parentEvent, + this.preFilters, + ); return this.responseSynthesizer.synthesize(query, nodes, _parentEvent); } } diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts index 912c02516..f3151a6dc 100644 --- a/packages/core/src/ResponseSynthesizer.ts +++ b/packages/core/src/ResponseSynthesizer.ts @@ -1,18 +1,18 @@ +import { Event } from "./callbacks/CallbackManager"; +import { LLM } from "./llm/LLM"; import { MetadataMode, NodeWithScore } from "./Node"; import { + defaultRefinePrompt, + defaultTextQaPrompt, + defaultTreeSummarizePrompt, RefinePrompt, SimplePrompt, TextQaPrompt, TreeSummarizePrompt, - defaultRefinePrompt, - defaultTextQaPrompt, - defaultTreeSummarizePrompt, } from "./Prompt"; import { getBiggestPrompt } from "./PromptHelper"; import { Response } from "./Response"; import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; -import { Event } from "./callbacks/CallbackManager"; -import { LLM } from "./llm/LLM"; /** * Response modes of the response synthesizer @@ -231,6 +231,7 @@ export class TreeSummarize implements BaseResponseBuilder { throw new Error("Must have at least one text chunk"); } + // Should we send the query here too? const packedTextChunks = this.serviceContext.promptHelper.repack( this.summaryTemplate, textChunks, @@ -241,6 +242,7 @@ export class TreeSummarize implements BaseResponseBuilder { await this.serviceContext.llm.complete( this.summaryTemplate({ context: packedTextChunks[0], + query, }), parentEvent, ) @@ -251,6 +253,7 @@ export class TreeSummarize implements BaseResponseBuilder { this.serviceContext.llm.complete( this.summaryTemplate({ context: chunk, + query, }), parentEvent, ), diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts index 303d0fa54..6b0f1024d 100644 --- a/packages/core/src/Retriever.ts +++ b/packages/core/src/Retriever.ts @@ -1,11 +1,15 @@ +import { Event } from "./callbacks/CallbackManager"; import { NodeWithScore } from "./Node"; import { ServiceContext } from "./ServiceContext"; -import { Event } from "./callbacks/CallbackManager"; /** * Retrievers retrieve the nodes that most closely match our query in similarity. */ export interface BaseRetriever { - retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]>; + retrieve( + query: string, + parentEvent?: Event, + preFilters?: unknown, + ): Promise<NodeWithScore[]>; getServiceContext(): ServiceContext; } diff --git a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts index d2bc4dbb3..8e3bff927 100644 --- a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts +++ b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts @@ -1,9 +1,9 @@ +import { Event } from "../../callbacks/CallbackManager"; +import { DEFAULT_SIMILARITY_TOP_K } from "../../constants"; import { globalsHelper } from "../../GlobalsHelper"; import { NodeWithScore } from "../../Node"; import { BaseRetriever } from "../../Retriever"; import { ServiceContext } from "../../ServiceContext"; -import { Event } from "../../callbacks/CallbackManager"; -import { DEFAULT_SIMILARITY_TOP_K } from "../../constants"; import { VectorStoreQuery, VectorStoreQueryMode, @@ -32,7 +32,7 @@ export class VectorIndexRetriever implements BaseRetriever { this.similarityTopK = similarityTopK ?? DEFAULT_SIMILARITY_TOP_K; } - async retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]> { + async retrieve(query: string, parentEvent?: Event, preFilters?: unknown): Promise<NodeWithScore[]> { const queryEmbedding = await this.serviceContext.embedModel.getQueryEmbedding(query); @@ -41,10 +41,15 @@ export class VectorIndexRetriever implements BaseRetriever { mode: VectorStoreQueryMode.DEFAULT, similarityTopK: this.similarityTopK, }; - const result = await this.index.vectorStore.query(q); + const result = await this.index.vectorStore.query(q, preFilters); let nodesWithScores: NodeWithScore[] = []; for (let i = 0; i < result.ids.length; i++) { + const nodeFromResult = result.nodes?.[i]; + if (!this.index.indexStruct.nodesDict[result.ids[i]] && nodeFromResult) { + this.index.indexStruct.nodesDict[result.ids[i]] = nodeFromResult; + } + const node = this.index.indexStruct.nodesDict[result.ids[i]]; nodesWithScores.push({ node: node, diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index 305f8e744..d2452962d 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -219,6 +219,27 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { return index; } + static async fromVectorStore( + vectorStore: VectorStore, + serviceContext: ServiceContext, + ) { + if (!vectorStore.storesText) { + throw new Error( + "Cannot initialize from a vector store that does not store text", + ); + } + + const storageContext = await storageContextFromDefaults({ vectorStore }); + + const index = await VectorStoreIndex.init({ + nodes: [], + storageContext, + serviceContext, + }); + + return index; + } + asRetriever(options?: any): VectorIndexRetriever { return new VectorIndexRetriever({ index: this, ...options }); } -- GitLab