From d489a2120f429b9241bffa86eaf8add766f4f9ce Mon Sep 17 00:00:00 2001 From: Yi Ding <yi.s.ding@gmail.com> Date: Thu, 20 Jul 2023 17:15:11 -0700 Subject: [PATCH] allow asQueryEngine and asRetriever to take options --- apps/simple/listIndex.ts | 4 ++- ...eOpenAIStream.ts => handleOpenAIStream.ts} | 2 +- packages/core/src/indices/BaseIndex.ts | 18 +++++++++++- packages/core/src/indices/list/ListIndex.ts | 28 +++++++++++-------- .../vectorStore/VectorIndexRetriever.ts | 12 ++++++-- .../indices/vectorStore/VectorStoreIndex.ts | 28 ++++++++----------- packages/core/src/llm/LLM.ts | 4 +-- .../core/src/tests/CallbackManager.test.ts | 7 ++--- 8 files changed, 64 insertions(+), 39 deletions(-) rename packages/core/src/callbacks/utility/{aHandleOpenAIStream.ts => handleOpenAIStream.ts} (97%) diff --git a/apps/simple/listIndex.ts b/apps/simple/listIndex.ts index 1748496ef..982f43471 100644 --- a/apps/simple/listIndex.ts +++ b/apps/simple/listIndex.ts @@ -18,7 +18,9 @@ async function main() { documents: [document], serviceContext, }); - const queryEngine = index.asQueryEngine(ListRetrieverMode.LLM); + const queryEngine = index.asQueryEngine({ + retriever: index.asRetriever({ mode: ListRetrieverMode.LLM }), + }); const response = await queryEngine.query( "What did the author do growing up?" ); diff --git a/packages/core/src/callbacks/utility/aHandleOpenAIStream.ts b/packages/core/src/callbacks/utility/handleOpenAIStream.ts similarity index 97% rename from packages/core/src/callbacks/utility/aHandleOpenAIStream.ts rename to packages/core/src/callbacks/utility/handleOpenAIStream.ts index b47780608..83b742b1d 100644 --- a/packages/core/src/callbacks/utility/aHandleOpenAIStream.ts +++ b/packages/core/src/callbacks/utility/handleOpenAIStream.ts @@ -2,7 +2,7 @@ import { globalsHelper } from "../../GlobalsHelper"; import { StreamCallbackResponse, Event } from "../CallbackManager"; import { StreamToken } from "../CallbackManager"; -export async function aHandleOpenAIStream({ +export async function handleOpenAIStream({ response, onLLMStream, parentEvent, diff --git a/packages/core/src/indices/BaseIndex.ts b/packages/core/src/indices/BaseIndex.ts index 043f18e12..899ba8163 100644 --- a/packages/core/src/indices/BaseIndex.ts +++ b/packages/core/src/indices/BaseIndex.ts @@ -6,6 +6,8 @@ import { StorageContext } from "../storage/StorageContext"; import { BaseDocumentStore } from "../storage/docStore/types"; import { VectorStore } from "../storage/vectorStore/types"; import { BaseIndexStore } from "../storage/indexStore/types"; +import { BaseQueryEngine } from "../QueryEngine"; +import { ResponseSynthesizer } from "../ResponseSynthesizer"; /** * The underlying structure of each index. @@ -82,7 +84,21 @@ export abstract class BaseIndex<T> { this.indexStruct = init.indexStruct; } - abstract asRetriever(): BaseRetriever; + /** + * Create a new retriever from the index. + * @param retrieverOptions + */ + abstract asRetriever(options?: any): BaseRetriever; + + /** + * Create a new query engine from the index. It will also create a retriever + * and response synthezier if they are not provided. + * @param options you can supply your own custom Retriever and ResponseSynthesizer + */ + abstract asQueryEngine(options?: { + retriever?: BaseRetriever; + responseSynthesizer?: ResponseSynthesizer; + }): BaseQueryEngine; } export interface VectorIndexOptions { diff --git a/packages/core/src/indices/list/ListIndex.ts b/packages/core/src/indices/list/ListIndex.ts index 23d178f7d..048264b3a 100644 --- a/packages/core/src/indices/list/ListIndex.ts +++ b/packages/core/src/indices/list/ListIndex.ts @@ -102,9 +102,9 @@ export class ListIndex extends BaseIndex<IndexList> { return index; } - asRetriever( - mode: ListRetrieverMode = ListRetrieverMode.DEFAULT - ): BaseRetriever { + asRetriever(options?: { mode: ListRetrieverMode }): BaseRetriever { + const { mode = ListRetrieverMode.DEFAULT } = options ?? {}; + switch (mode) { case ListRetrieverMode.DEFAULT: return new ListIndexRetriever(this); @@ -115,21 +115,25 @@ export class ListIndex extends BaseIndex<IndexList> { } } - asQueryEngine( - mode: ListRetrieverMode = ListRetrieverMode.DEFAULT, - responseSynthesizer?: ResponseSynthesizer - ): BaseQueryEngine { - if (_.isNil(responseSynthesizer)) { + asQueryEngine(options?: { + retriever?: BaseRetriever; + responseSynthesizer?: ResponseSynthesizer; + }): BaseQueryEngine { + let { retriever, responseSynthesizer } = options ?? {}; + + if (!retriever) { + retriever = this.asRetriever(); + } + + if (!responseSynthesizer) { let responseBuilder = new CompactAndRefine(this.serviceContext); responseSynthesizer = new ResponseSynthesizer({ serviceContext: this.serviceContext, responseBuilder, }); } - return new RetrieverQueryEngine( - this.asRetriever(mode), - responseSynthesizer - ); + + return new RetrieverQueryEngine(retriever, responseSynthesizer); } static async _buildIndexFromNodes( diff --git a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts index 862481640..eb5e57f23 100644 --- a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts +++ b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts @@ -16,12 +16,20 @@ import { BaseRetriever } from "../../Retriever"; export class VectorIndexRetriever implements BaseRetriever { index: VectorStoreIndex; - similarityTopK = DEFAULT_SIMILARITY_TOP_K; + similarityTopK; private serviceContext: ServiceContext; - constructor(index: VectorStoreIndex) { + constructor({ + index, + similarityTopK, + }: { + index: VectorStoreIndex; + similarityTopK?: number; + }) { this.index = index; this.serviceContext = this.index.serviceContext; + + this.similarityTopK = similarityTopK ?? DEFAULT_SIMILARITY_TOP_K; } async retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]> { diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index e9c9631e7..d3153b358 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -21,6 +21,8 @@ import { VectorIndexConstructorProps, VectorIndexOptions, } from "../BaseIndex"; +import { BaseRetriever } from "../../Retriever"; +import { ResponseSynthesizer } from "../../ResponseSynthesizer"; /** * The VectorStoreIndex, an index that stores the nodes only according to their vector embedings. @@ -156,23 +158,17 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { return index; } - /** - * Get a VectorIndexRetriever for this index. - * - * NOTE: if you want to use a custom retriever you don't have to use this method. - * @returns retriever for the index - */ - asRetriever(): VectorIndexRetriever { - return new VectorIndexRetriever(this); + asRetriever(options?: any): VectorIndexRetriever { + return new VectorIndexRetriever({ index: this, ...options }); } - /** - * Get a retriever query engine for this index. - * - * NOTE: if you are using a custom query engine you don't have to use this method. - * @returns - */ - asQueryEngine(): BaseQueryEngine { - return new RetrieverQueryEngine(this.asRetriever()); + asQueryEngine(options?: { + retriever?: BaseRetriever; + responseSynthesizer?: ResponseSynthesizer; + }): BaseQueryEngine { + let { retriever, responseSynthesizer } = options ?? {}; + + retriever = retriever ?? this.asRetriever(); + return new RetrieverQueryEngine(this.asRetriever(), responseSynthesizer); } } diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index 839deabc6..9caa2cbe6 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -1,5 +1,5 @@ import { CallbackManager, Event } from "../callbacks/CallbackManager"; -import { aHandleOpenAIStream } from "../callbacks/utility/aHandleOpenAIStream"; +import { handleOpenAIStream } from "../callbacks/utility/handleOpenAIStream"; import { ChatCompletionRequestMessageRoleEnum, CreateChatCompletionRequest, @@ -124,7 +124,7 @@ export class OpenAI implements LLM { { responseType: "stream" } ); - const fullResponse = await aHandleOpenAIStream({ + const fullResponse = await handleOpenAIStream({ response, onLLMStream: this.callbackManager.onLLMStream, parentEvent, diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts index 6471067bf..cb82242af 100644 --- a/packages/core/src/tests/CallbackManager.test.ts +++ b/packages/core/src/tests/CallbackManager.test.ts @@ -147,10 +147,9 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { serviceContext: serviceContext, responseBuilder, }); - const queryEngine = listIndex.asQueryEngine( - ListRetrieverMode.DEFAULT, - responseSynthesizer - ); + const queryEngine = listIndex.asQueryEngine({ + responseSynthesizer, + }); const query = "What is the author's name?"; const response = await queryEngine.query(query); expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2"); -- GitLab