diff --git a/apps/simple/listIndex.ts b/apps/simple/listIndex.ts index 1748496ef38d95446db331527d0dcbc962f552ef..982f434717dc28141316603122aa5b70ee8c9f07 100644 --- a/apps/simple/listIndex.ts +++ b/apps/simple/listIndex.ts @@ -18,7 +18,9 @@ async function main() { documents: [document], serviceContext, }); - const queryEngine = index.asQueryEngine(ListRetrieverMode.LLM); + const queryEngine = index.asQueryEngine({ + retriever: index.asRetriever({ mode: ListRetrieverMode.LLM }), + }); const response = await queryEngine.query( "What did the author do growing up?" ); diff --git a/packages/core/src/callbacks/utility/aHandleOpenAIStream.ts b/packages/core/src/callbacks/utility/handleOpenAIStream.ts similarity index 97% rename from packages/core/src/callbacks/utility/aHandleOpenAIStream.ts rename to packages/core/src/callbacks/utility/handleOpenAIStream.ts index b477806086ff09c75f55f4eb3e71ac33ba3c82ed..83b742b1d4bb5cb79e4d49f8a76e17ca33a00485 100644 --- a/packages/core/src/callbacks/utility/aHandleOpenAIStream.ts +++ b/packages/core/src/callbacks/utility/handleOpenAIStream.ts @@ -2,7 +2,7 @@ import { globalsHelper } from "../../GlobalsHelper"; import { StreamCallbackResponse, Event } from "../CallbackManager"; import { StreamToken } from "../CallbackManager"; -export async function aHandleOpenAIStream({ +export async function handleOpenAIStream({ response, onLLMStream, parentEvent, diff --git a/packages/core/src/indices/BaseIndex.ts b/packages/core/src/indices/BaseIndex.ts index 043f18e129b14c8cb4323fd627e2953cff0a3191..899ba8163ee44d1886d39a8c724be9761845fab6 100644 --- a/packages/core/src/indices/BaseIndex.ts +++ b/packages/core/src/indices/BaseIndex.ts @@ -6,6 +6,8 @@ import { StorageContext } from "../storage/StorageContext"; import { BaseDocumentStore } from "../storage/docStore/types"; import { VectorStore } from "../storage/vectorStore/types"; import { BaseIndexStore } from "../storage/indexStore/types"; +import { BaseQueryEngine } from "../QueryEngine"; +import { ResponseSynthesizer } from "../ResponseSynthesizer"; /** * The underlying structure of each index. @@ -82,7 +84,21 @@ export abstract class BaseIndex<T> { this.indexStruct = init.indexStruct; } - abstract asRetriever(): BaseRetriever; + /** + * Create a new retriever from the index. + * @param retrieverOptions + */ + abstract asRetriever(options?: any): BaseRetriever; + + /** + * Create a new query engine from the index. It will also create a retriever + * and response synthezier if they are not provided. + * @param options you can supply your own custom Retriever and ResponseSynthesizer + */ + abstract asQueryEngine(options?: { + retriever?: BaseRetriever; + responseSynthesizer?: ResponseSynthesizer; + }): BaseQueryEngine; } export interface VectorIndexOptions { diff --git a/packages/core/src/indices/list/ListIndex.ts b/packages/core/src/indices/list/ListIndex.ts index 23d178f7dab682cd1b3515e4ad2cc2b21877f5d3..048264b3ae73c0d98b8c2affaae5adb804da57c9 100644 --- a/packages/core/src/indices/list/ListIndex.ts +++ b/packages/core/src/indices/list/ListIndex.ts @@ -102,9 +102,9 @@ export class ListIndex extends BaseIndex<IndexList> { return index; } - asRetriever( - mode: ListRetrieverMode = ListRetrieverMode.DEFAULT - ): BaseRetriever { + asRetriever(options?: { mode: ListRetrieverMode }): BaseRetriever { + const { mode = ListRetrieverMode.DEFAULT } = options ?? {}; + switch (mode) { case ListRetrieverMode.DEFAULT: return new ListIndexRetriever(this); @@ -115,21 +115,25 @@ export class ListIndex extends BaseIndex<IndexList> { } } - asQueryEngine( - mode: ListRetrieverMode = ListRetrieverMode.DEFAULT, - responseSynthesizer?: ResponseSynthesizer - ): BaseQueryEngine { - if (_.isNil(responseSynthesizer)) { + asQueryEngine(options?: { + retriever?: BaseRetriever; + responseSynthesizer?: ResponseSynthesizer; + }): BaseQueryEngine { + let { retriever, responseSynthesizer } = options ?? {}; + + if (!retriever) { + retriever = this.asRetriever(); + } + + if (!responseSynthesizer) { let responseBuilder = new CompactAndRefine(this.serviceContext); responseSynthesizer = new ResponseSynthesizer({ serviceContext: this.serviceContext, responseBuilder, }); } - return new RetrieverQueryEngine( - this.asRetriever(mode), - responseSynthesizer - ); + + return new RetrieverQueryEngine(retriever, responseSynthesizer); } static async _buildIndexFromNodes( diff --git a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts index 862481640b1d8e6b509adaa7e1ce165e1998248d..eb5e57f23e91c0616437433bdc38f9e434cbc1e4 100644 --- a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts +++ b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts @@ -16,12 +16,20 @@ import { BaseRetriever } from "../../Retriever"; export class VectorIndexRetriever implements BaseRetriever { index: VectorStoreIndex; - similarityTopK = DEFAULT_SIMILARITY_TOP_K; + similarityTopK; private serviceContext: ServiceContext; - constructor(index: VectorStoreIndex) { + constructor({ + index, + similarityTopK, + }: { + index: VectorStoreIndex; + similarityTopK?: number; + }) { this.index = index; this.serviceContext = this.index.serviceContext; + + this.similarityTopK = similarityTopK ?? DEFAULT_SIMILARITY_TOP_K; } async retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]> { diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index e9c9631e7eeeb60eece7d22c120c13db10382031..d3153b358a7d3061c5fe4f0cc42282868da1ed4c 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -21,6 +21,8 @@ import { VectorIndexConstructorProps, VectorIndexOptions, } from "../BaseIndex"; +import { BaseRetriever } from "../../Retriever"; +import { ResponseSynthesizer } from "../../ResponseSynthesizer"; /** * The VectorStoreIndex, an index that stores the nodes only according to their vector embedings. @@ -156,23 +158,17 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { return index; } - /** - * Get a VectorIndexRetriever for this index. - * - * NOTE: if you want to use a custom retriever you don't have to use this method. - * @returns retriever for the index - */ - asRetriever(): VectorIndexRetriever { - return new VectorIndexRetriever(this); + asRetriever(options?: any): VectorIndexRetriever { + return new VectorIndexRetriever({ index: this, ...options }); } - /** - * Get a retriever query engine for this index. - * - * NOTE: if you are using a custom query engine you don't have to use this method. - * @returns - */ - asQueryEngine(): BaseQueryEngine { - return new RetrieverQueryEngine(this.asRetriever()); + asQueryEngine(options?: { + retriever?: BaseRetriever; + responseSynthesizer?: ResponseSynthesizer; + }): BaseQueryEngine { + let { retriever, responseSynthesizer } = options ?? {}; + + retriever = retriever ?? this.asRetriever(); + return new RetrieverQueryEngine(this.asRetriever(), responseSynthesizer); } } diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index 839deabc61dc200ef29466d70a517bdc1a654a2f..9caa2cbe60bfadb6acacdde1cc4c43fa7d9b7e4f 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -1,5 +1,5 @@ import { CallbackManager, Event } from "../callbacks/CallbackManager"; -import { aHandleOpenAIStream } from "../callbacks/utility/aHandleOpenAIStream"; +import { handleOpenAIStream } from "../callbacks/utility/handleOpenAIStream"; import { ChatCompletionRequestMessageRoleEnum, CreateChatCompletionRequest, @@ -124,7 +124,7 @@ export class OpenAI implements LLM { { responseType: "stream" } ); - const fullResponse = await aHandleOpenAIStream({ + const fullResponse = await handleOpenAIStream({ response, onLLMStream: this.callbackManager.onLLMStream, parentEvent, diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts index 6471067bfc21472e26666ae125b3a4d9c9581ad8..cb82242af038d0a12413f68ed1caa7f0c45dadc4 100644 --- a/packages/core/src/tests/CallbackManager.test.ts +++ b/packages/core/src/tests/CallbackManager.test.ts @@ -147,10 +147,9 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => { serviceContext: serviceContext, responseBuilder, }); - const queryEngine = listIndex.asQueryEngine( - ListRetrieverMode.DEFAULT, - responseSynthesizer - ); + const queryEngine = listIndex.asQueryEngine({ + responseSynthesizer, + }); const query = "What is the author's name?"; const response = await queryEngine.query(query); expect(response.toString()).toBe("MOCK_TOKEN_1-MOCK_TOKEN_2");