diff --git a/.gitignore b/.gitignore
index 17d753d6a0bbc64eb30c7a71beeeefa9bc6d37ba..ce1b05f6c4f7b78c38c9991f1c4154e73d13f5f9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -129,3 +129,6 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# OS generated files
+.DS_Store
diff --git a/gpt_index/indices/base.py b/gpt_index/indices/base.py
index a444323f635a22ddfcbae62f97aa09acb3bdf195..a94ac6e4abfc6c5075fa6f5111484180df11635b 100644
--- a/gpt_index/indices/base.py
+++ b/gpt_index/indices/base.py
@@ -14,21 +14,16 @@ from typing import (
     cast,
 )
 
-from gpt_index.indices.data_structs import IndexStruct
+from gpt_index.indices.data_structs import IndexStruct, IndexStructType
 from gpt_index.indices.prompt_helper import PromptHelper
-from gpt_index.indices.query.base import BaseGPTIndexQuery
 from gpt_index.indices.query.query_runner import QueryRunner
+from gpt_index.indices.query.schema import QueryConfig, QueryMode
 from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
 from gpt_index.schema import BaseDocument, DocumentStore
 from gpt_index.utils import llm_token_counter
 
 IS = TypeVar("IS", bound=IndexStruct)
 
-# TODO: remove and consolidate with QueryMode
-DEFAULT_MODE = "default"
-EMBEDDING_MODE = "embedding"
-SUMMARIZE_MODE = "summarize"
-
 
 DOCUMENTS_INPUT = Union[BaseDocument, "BaseGPTIndex"]
 
@@ -157,15 +152,22 @@ class BaseGPTIndex(Generic[IS]):
     def delete(self, document: BaseDocument) -> None:
         """Delete a document."""
 
-    @abstractmethod
-    def _mode_to_query(self, mode: str, **query_kwargs: Any) -> BaseGPTIndexQuery:
-        """Query mode to class."""
+    def _preprocess_query(self, mode: QueryMode, query_kwargs: Dict) -> None:
+        """Preprocess query.
+
+        This allows subclasses to pass in additional query kwargs
+        to query, for instance arguments that are shared between the
+        index and the query class. By default, this does nothing.
+        This also allows subclasses to do validation.
+
+        """
+        pass
 
     def query(
         self,
         query_str: str,
         verbose: bool = False,
-        mode: str = DEFAULT_MODE,
+        mode: str = QueryMode.DEFAULT,
         **query_kwargs: Any
     ) -> str:
         """Answer a query.
@@ -179,8 +181,9 @@ class BaseGPTIndex(Generic[IS]):
 
 
         """
+        mode_enum = QueryMode(mode)
         # TODO: remove _mode_to_query and consolidate with query_runner
-        if mode == "recursive":
+        if mode_enum == QueryMode.RECURSIVE:
             if "query_configs" not in query_kwargs:
                 raise ValueError("query_configs must be provided for recursive mode.")
             query_configs = query_kwargs["query_configs"]
@@ -189,16 +192,25 @@ class BaseGPTIndex(Generic[IS]):
                 self._docstore,
                 query_configs=query_configs,
                 verbose=verbose,
+                recursive=True,
             )
             return query_runner.query(query_str, self._index_struct)
         else:
-            query_obj = self._mode_to_query(mode, **query_kwargs)
-            # set llm_predictor if exists
-            if not query_obj._llm_predictor_set:
-                query_obj.set_llm_predictor(self._llm_predictor)
-            # set prompt_helper if exists
-            query_obj.set_prompt_helper(self._prompt_helper)
-            return query_obj.query(query_str, verbose=verbose)
+            self._preprocess_query(mode_enum, query_kwargs)
+            # TODO: pass in query config directly
+            query_config = QueryConfig(
+                index_struct_type=IndexStructType.from_index_struct(self._index_struct),
+                query_mode=mode_enum,
+                query_kwargs=query_kwargs,
+            ).to_dict()
+            query_runner = QueryRunner(
+                self._llm_predictor,
+                self._docstore,
+                query_configs=[query_config],
+                verbose=verbose,
+                recursive=False,
+            )
+            return query_runner.query(query_str, self._index_struct)
 
     @classmethod
     def load_from_disk(cls, save_path: str, **kwargs: Any) -> "BaseGPTIndex":
diff --git a/gpt_index/indices/data_structs.py b/gpt_index/indices/data_structs.py
index a3ccf7783ae25e8f45806312c83f300e2b35a304..017a783e22a75a39a598594e93a357b96f44d455 100644
--- a/gpt_index/indices/data_structs.py
+++ b/gpt_index/indices/data_structs.py
@@ -143,6 +143,7 @@ class IndexList(IndexStruct):
         return cur_node.index
 
 
+# TODO: this should be specific to FAISS
 @dataclass
 class IndexDict(IndexStruct):
     """A simple dictionary of documents."""
@@ -194,6 +195,7 @@ class IndexStructType(str, Enum):
     TREE = "tree"
     LIST = "list"
     KEYWORD_TABLE = "keyword_table"
+    DICT = "dict"
 
     def get_index_struct_cls(self) -> type:
         """Get index struct class."""
@@ -203,6 +205,8 @@ class IndexStructType(str, Enum):
             return IndexList
         elif self == IndexStructType.KEYWORD_TABLE:
             return KeywordTable
+        elif self == IndexStructType.DICT:
+            return IndexDict
         else:
             raise ValueError("Invalid index struct type.")
 
@@ -215,5 +219,7 @@ class IndexStructType(str, Enum):
             return cls.LIST
         elif isinstance(index_struct, KeywordTable):
             return cls.KEYWORD_TABLE
+        elif isinstance(index_struct, IndexDict):
+            return cls.DICT
         else:
             raise ValueError("Invalid index struct type.")
diff --git a/gpt_index/indices/keyword_table/base.py b/gpt_index/indices/keyword_table/base.py
index 4a3b404cb46a94a9b6e279b7a254f60ac6d0fb7b..f3e70adb0d66cce19e41681bc8ceab8ec2f3a060 100644
--- a/gpt_index/indices/keyword_table/base.py
+++ b/gpt_index/indices/keyword_table/base.py
@@ -11,20 +11,9 @@ existing keywords in the table.
 from abc import abstractmethod
 from typing import Any, Optional, Sequence, Set
 
-from gpt_index.indices.base import (
-    DEFAULT_MODE,
-    DOCUMENTS_INPUT,
-    BaseGPTIndex,
-    BaseGPTIndexQuery,
-)
+from gpt_index.indices.base import DOCUMENTS_INPUT, BaseGPTIndex
 from gpt_index.indices.data_structs import KeywordTable
 from gpt_index.indices.keyword_table.utils import extract_keywords_given_response
-from gpt_index.indices.query.keyword_table.query import (
-    BaseGPTKeywordTableQuery,
-    GPTKeywordTableGPTQuery,
-    GPTKeywordTableRAKEQuery,
-    GPTKeywordTableSimpleQuery,
-)
 from gpt_index.indices.utils import truncate_text
 from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
 from gpt_index.langchain_helpers.text_splitter import TokenTextSplitter
@@ -54,10 +43,6 @@ class BaseGPTKeywordTableIndex(BaseGPTIndex[KeywordTable]):
         keyword_extract_template (Optional[KeywordExtractPrompt]): A Keyword
             Extraction Prompt
             (see :ref:`Prompt-Templates`).
-        max_keywords_per_query (int): The maximum number of keywords to extract
-            per query.
-        max_keywords_per_query (int): The maximum number of keywords to extract
-            per chunk.
 
     """
 
@@ -68,7 +53,6 @@ class BaseGPTKeywordTableIndex(BaseGPTIndex[KeywordTable]):
         documents: Optional[Sequence[DOCUMENTS_INPUT]] = None,
         index_struct: Optional[KeywordTable] = None,
         keyword_extract_template: Optional[KeywordExtractPrompt] = None,
-        max_keywords_per_query: int = 10,
         max_keywords_per_chunk: int = 10,
         llm_predictor: Optional[LLMPredictor] = None,
         **kwargs: Any,
@@ -78,7 +62,6 @@ class BaseGPTKeywordTableIndex(BaseGPTIndex[KeywordTable]):
         self.keyword_extract_template = (
             keyword_extract_template or DEFAULT_KEYWORD_EXTRACT_TEMPLATE
         )
-        self.max_keywords_per_query = max_keywords_per_query
         self.max_keywords_per_chunk = max_keywords_per_chunk
         super().__init__(
             documents=documents,
@@ -90,26 +73,6 @@ class BaseGPTKeywordTableIndex(BaseGPTIndex[KeywordTable]):
             self.keyword_extract_template, 1
         )
 
-    def _mode_to_query(self, mode: str, **query_kwargs: Any) -> BaseGPTIndexQuery:
-        """Query mode to class."""
-        if mode == DEFAULT_MODE:
-            query_kwargs.update(
-                {
-                    "max_keywords_per_query": self.max_keywords_per_query,
-                    "keyword_extract_template": self.keyword_extract_template,
-                }
-            )
-            query: BaseGPTKeywordTableQuery = GPTKeywordTableGPTQuery(
-                self.index_struct, **query_kwargs
-            )
-        elif mode == "simple":
-            query = GPTKeywordTableSimpleQuery(self.index_struct, **query_kwargs)
-        elif mode == "rake":
-            query = GPTKeywordTableRAKEQuery(self.index_struct, **query_kwargs)
-        else:
-            raise ValueError(f"Invalid query mode: {mode}.")
-        return query
-
     @abstractmethod
     def _extract_keywords(self, text: str) -> Set[str]:
         """Extract keywords from text."""
diff --git a/gpt_index/indices/list/base.py b/gpt_index/indices/list/base.py
index e05fa7c3e3fd047fcdfe6ba46c85ef9f16406dfa..338cda0ce7dbb26be630a277e8f37511f2197aef 100644
--- a/gpt_index/indices/list/base.py
+++ b/gpt_index/indices/list/base.py
@@ -7,16 +7,8 @@ in sequence in order to answer a given query.
 
 from typing import Any, Optional, Sequence
 
-from gpt_index.indices.base import (
-    DEFAULT_MODE,
-    DOCUMENTS_INPUT,
-    EMBEDDING_MODE,
-    BaseGPTIndex,
-)
+from gpt_index.indices.base import DOCUMENTS_INPUT, BaseGPTIndex
 from gpt_index.indices.data_structs import IndexList
-from gpt_index.indices.query.base import BaseGPTIndexQuery
-from gpt_index.indices.query.list.embedding_query import GPTListIndexEmbeddingQuery
-from gpt_index.indices.query.list.query import BaseGPTListIndexQuery, GPTListIndexQuery
 from gpt_index.indices.utils import truncate_text
 from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
 from gpt_index.langchain_helpers.text_splitter import TokenTextSplitter
@@ -99,21 +91,6 @@ class GPTListIndex(BaseGPTIndex[IndexList]):
             self._add_document_to_index(index_struct, d, text_splitter)
         return index_struct
 
-    def _mode_to_query(
-        self, mode: str, *query_args: Any, **query_kwargs: Any
-    ) -> BaseGPTIndexQuery:
-        if mode == DEFAULT_MODE:
-            if "text_qa_template" not in query_kwargs:
-                query_kwargs["text_qa_template"] = self.text_qa_template
-            query: BaseGPTListIndexQuery = GPTListIndexQuery(
-                self.index_struct, **query_kwargs
-            )
-        elif mode == EMBEDDING_MODE:
-            query = GPTListIndexEmbeddingQuery(self.index_struct, **query_kwargs)
-        else:
-            raise ValueError(f"Invalid query mode: {mode}.")
-        return query
-
     def _insert(self, document: BaseDocument, **insert_kwargs: Any) -> None:
         """Insert a document."""
         text_chunks = self._text_splitter.split_text(document.get_text())
diff --git a/gpt_index/indices/query/list/embedding_query.py b/gpt_index/indices/query/list/embedding_query.py
index 9d2a81a78d6e125c11718d938f4747a13ccdbee3..ae4f3c7c29fca19ae59567fbbc64a2e9e0f3645c 100644
--- a/gpt_index/indices/query/list/embedding_query.py
+++ b/gpt_index/indices/query/list/embedding_query.py
@@ -1,5 +1,5 @@
 """Embedding query for list index."""
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from gpt_index.embeddings.openai import OpenAIEmbedding
 from gpt_index.indices.data_structs import IndexList, Node
@@ -31,6 +31,7 @@ class GPTListIndexEmbeddingQuery(BaseGPTListIndexQuery):
         keyword: Optional[str] = None,
         similarity_top_k: Optional[int] = 1,
         embed_model: Optional[OpenAIEmbedding] = None,
+        **kwargs: Any,
     ) -> None:
         """Initialize params."""
         super().__init__(
@@ -38,6 +39,7 @@ class GPTListIndexEmbeddingQuery(BaseGPTListIndexQuery):
             text_qa_template=text_qa_template,
             refine_template=refine_template,
             keyword=keyword,
+            **kwargs,
         )
         self._embed_model = embed_model or OpenAIEmbedding()
         self.similarity_top_k = similarity_top_k
diff --git a/gpt_index/indices/query/query_map.py b/gpt_index/indices/query/query_map.py
index 6fb5487151e7f68a406191df01ac8d556e82768c..d784b1aa2f656b7d54780e6c3ca3e44f17c1646c 100644
--- a/gpt_index/indices/query/query_map.py
+++ b/gpt_index/indices/query/query_map.py
@@ -16,8 +16,8 @@ from gpt_index.indices.query.tree.embedding_query import GPTTreeIndexEmbeddingQu
 from gpt_index.indices.query.tree.leaf_query import GPTTreeIndexLeafQuery
 from gpt_index.indices.query.tree.retrieve_query import GPTTreeIndexRetQuery
 from gpt_index.indices.query.tree.summarize_query import GPTTreeIndexSummarizeQuery
+from gpt_index.indices.query.vector_store.faiss import GPTFaissIndexQuery
 
-# TODO: migrate _mode_to_query in indices/base.py to use this file
 MODE_TO_QUERY_MAP_TREE = {
     QueryMode.DEFAULT: GPTTreeIndexLeafQuery,
     QueryMode.RETRIEVE: GPTTreeIndexRetQuery,
@@ -36,6 +36,10 @@ MODE_TO_QUERY_MAP_KEYWORD_TABLE = {
     QueryMode.RAKE: GPTKeywordTableSimpleQuery,
 }
 
+MODE_TO_QUERY_MAP_VECTOR = {
+    QueryMode.DEFAULT: GPTFaissIndexQuery,
+}
+
 
 def get_query_cls(
     index_struct_type: IndexStructType, mode: QueryMode
@@ -47,5 +51,7 @@ def get_query_cls(
         return MODE_TO_QUERY_MAP_LIST[mode]
     elif index_struct_type == IndexStructType.KEYWORD_TABLE:
         return MODE_TO_QUERY_MAP_KEYWORD_TABLE[mode]
+    elif index_struct_type == IndexStructType.DICT:
+        return MODE_TO_QUERY_MAP_VECTOR[mode]
     else:
         raise ValueError(f"Invalid index_struct_type: {index_struct_type}")
diff --git a/gpt_index/indices/query/query_runner.py b/gpt_index/indices/query/query_runner.py
index 67da872bb84341832255a1f06459e8c0f3fc5ba7..e0db4d9e416f8b05d0f4e362d979f5efd2a20a7f 100644
--- a/gpt_index/indices/query/query_runner.py
+++ b/gpt_index/indices/query/query_runner.py
@@ -38,6 +38,7 @@ class QueryRunner(BaseQueryRunner):
         docstore: DocumentStore,
         query_configs: Optional[List[Dict]] = None,
         verbose: bool = False,
+        recursive: bool = False,
     ) -> None:
         """Init params."""
         config_dict: Dict[IndexStructType, QueryConfig] = {}
@@ -52,6 +53,7 @@ class QueryRunner(BaseQueryRunner):
         self._llm_predictor = llm_predictor
         self._docstore = docstore
         self._verbose = verbose
+        self._recursive = recursive
 
     def query(self, query_str: str, index_struct: IndexStruct) -> str:
         """Run query."""
@@ -61,10 +63,12 @@ class QueryRunner(BaseQueryRunner):
         config = self._config_dict[index_struct_type]
         mode = config.query_mode
         query_cls = get_query_cls(index_struct_type, mode)
+        # if recursive, pass self as query_runner to each individual query
+        query_runner = self if self._recursive else None
         query_obj = query_cls(
             index_struct,
             **config.query_kwargs,
-            query_runner=self,
+            query_runner=query_runner,
             docstore=self._docstore,
         )
 
diff --git a/gpt_index/indices/query/schema.py b/gpt_index/indices/query/schema.py
index 75261b004aa2b9cf683dc6b363e96bf71efe1910..c25f6fb08f75a3402a766a164bef8e9342d196af 100644
--- a/gpt_index/indices/query/schema.py
+++ b/gpt_index/indices/query/schema.py
@@ -23,6 +23,9 @@ class QueryMode(str, Enum):
     SIMPLE = "simple"
     RAKE = "rake"
 
+    # recursive queries (composable queries)
+    RECURSIVE = "recursive"
+
 
 @dataclass
 class QueryConfig(DataClassJsonMixin):
diff --git a/gpt_index/indices/query/tree/embedding_query.py b/gpt_index/indices/query/tree/embedding_query.py
index 4c81140a73410d54d914eb0c55ca57ff524cd7b0..3f9e6d3a2976ba3bf131a1dd1bdee7c8eae5fa50 100644
--- a/gpt_index/indices/query/tree/embedding_query.py
+++ b/gpt_index/indices/query/tree/embedding_query.py
@@ -1,6 +1,6 @@
 """Query Tree using embedding similarity between query and node text."""
 
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from gpt_index.embeddings.openai import OpenAIEmbedding
 from gpt_index.indices.data_structs import IndexGraph, Node
@@ -53,15 +53,17 @@ class GPTTreeIndexEmbeddingQuery(GPTTreeIndexLeafQuery):
         refine_template: Optional[RefinePrompt] = None,
         child_branch_factor: int = 1,
         embed_model: Optional[OpenAIEmbedding] = None,
+        **kwargs: Any,
     ) -> None:
         """Initialize params."""
         super().__init__(
             index_struct,
-            query_template,
-            query_template_multiple,
-            text_qa_template,
-            refine_template,
-            child_branch_factor,
+            query_template=query_template,
+            query_template_multiple=query_template_multiple,
+            text_qa_template=text_qa_template,
+            refine_template=refine_template,
+            child_branch_factor=child_branch_factor,
+            **kwargs,
         )
         self._embed_model = embed_model or OpenAIEmbedding()
         self.child_branch_factor = child_branch_factor
diff --git a/gpt_index/indices/tree/base.py b/gpt_index/indices/tree/base.py
index 22c330ee219b2b95613df401290495aa50120f98..8d90da84424e6941d98e21965d6f4ff556f62eb0 100644
--- a/gpt_index/indices/tree/base.py
+++ b/gpt_index/indices/tree/base.py
@@ -2,20 +2,10 @@
 
 from typing import Any, Optional, Sequence
 
-from gpt_index.indices.base import (
-    DEFAULT_MODE,
-    DOCUMENTS_INPUT,
-    EMBEDDING_MODE,
-    SUMMARIZE_MODE,
-    BaseGPTIndex,
-)
+from gpt_index.indices.base import DOCUMENTS_INPUT, BaseGPTIndex
 from gpt_index.indices.common.tree.base import GPTTreeIndexBuilder
 from gpt_index.indices.data_structs import IndexGraph
-from gpt_index.indices.query.base import BaseGPTIndexQuery
-from gpt_index.indices.query.tree.embedding_query import GPTTreeIndexEmbeddingQuery
-from gpt_index.indices.query.tree.leaf_query import GPTTreeIndexLeafQuery
-from gpt_index.indices.query.tree.retrieve_query import GPTTreeIndexRetQuery
-from gpt_index.indices.query.tree.summarize_query import GPTTreeIndexSummarizeQuery
+from gpt_index.indices.query.schema import QueryMode
 from gpt_index.indices.tree.inserter import GPTIndexInserter
 from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
 from gpt_index.prompts.default_prompts import (
@@ -25,12 +15,10 @@ from gpt_index.prompts.default_prompts import (
 from gpt_index.prompts.prompts import SummaryPrompt, TreeInsertPrompt
 from gpt_index.schema import BaseDocument
 
-RETRIEVE_MODE = "retrieve"
-
 REQUIRE_TREE_MODES = {
-    DEFAULT_MODE,
-    EMBEDDING_MODE,
-    RETRIEVE_MODE,
+    QueryMode.DEFAULT,
+    QueryMode.EMBEDDING,
+    QueryMode.RETRIEVE,
 }
 
 
@@ -81,7 +69,7 @@ class GPTTreeIndex(BaseGPTIndex[IndexGraph]):
             **kwargs,
         )
 
-    def _validate_build_tree_required(self, mode: str) -> None:
+    def _validate_build_tree_required(self, mode: QueryMode) -> None:
         """Check if index supports modes that require trees."""
         if mode in REQUIRE_TREE_MODES and not self.build_tree:
             raise ValueError(
@@ -89,22 +77,9 @@ class GPTTreeIndex(BaseGPTIndex[IndexGraph]):
                 f"but mode {mode} requires trees."
             )
 
-    def _mode_to_query(self, mode: str, **query_kwargs: Any) -> BaseGPTIndexQuery:
+    def _preprocess_query(self, mode: QueryMode, query_kwargs: Any) -> None:
         """Query mode to class."""
         self._validate_build_tree_required(mode)
-        if mode == DEFAULT_MODE:
-            query: BaseGPTIndexQuery = GPTTreeIndexLeafQuery(
-                self.index_struct, **query_kwargs
-            )
-        elif mode == RETRIEVE_MODE:
-            query = GPTTreeIndexRetQuery(self.index_struct, **query_kwargs)
-        elif mode == EMBEDDING_MODE:
-            query = GPTTreeIndexEmbeddingQuery(self.index_struct, **query_kwargs)
-        elif mode == SUMMARIZE_MODE:
-            query = GPTTreeIndexSummarizeQuery(self.index_struct, **query_kwargs)
-        else:
-            raise ValueError(f"Invalid query mode: {mode}.")
-        return query
 
     def _build_index_from_documents(
         self, documents: Sequence[BaseDocument], verbose: bool = False
diff --git a/gpt_index/indices/vector_store/faiss.py b/gpt_index/indices/vector_store/faiss.py
index 317613a9d25c945a2bf5b76c4ba6fe8e7036da73..005decd701946a23b11d1029883ce2e9563bbc51 100644
--- a/gpt_index/indices/vector_store/faiss.py
+++ b/gpt_index/indices/vector_store/faiss.py
@@ -9,10 +9,9 @@ from typing import Any, Optional, Sequence, cast
 import numpy as np
 
 from gpt_index.embeddings.openai import OpenAIEmbedding
-from gpt_index.indices.base import DEFAULT_MODE, DOCUMENTS_INPUT, BaseGPTIndex
+from gpt_index.indices.base import DOCUMENTS_INPUT, BaseGPTIndex
 from gpt_index.indices.data_structs import IndexDict
-from gpt_index.indices.query.base import BaseGPTIndexQuery
-from gpt_index.indices.query.vector_store.faiss import GPTFaissIndexQuery
+from gpt_index.indices.query.schema import QueryMode
 from gpt_index.indices.utils import truncate_text
 from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
 from gpt_index.langchain_helpers.text_splitter import TokenTextSplitter
@@ -103,6 +102,11 @@ class GPTFaissIndex(BaseGPTIndex[IndexDict]):
             # add to index
             index_struct.add_text(text_chunk, document.get_doc_id(), text_id=new_id)
 
+    def _preprocess_query(self, mode: QueryMode, query_kwargs: Any) -> None:
+        """Query mode to class."""
+        # pass along faiss_index
+        query_kwargs["faiss_index"] = self._faiss_index
+
     def _build_index_from_documents(
         self, documents: Sequence[BaseDocument], verbose: bool = False
     ) -> IndexDict:
@@ -115,19 +119,6 @@ class GPTFaissIndex(BaseGPTIndex[IndexDict]):
             self._add_document_to_index(index_struct, d, text_splitter)
         return index_struct
 
-    def _mode_to_query(
-        self, mode: str, *query_args: Any, **query_kwargs: Any
-    ) -> BaseGPTIndexQuery:
-        if mode == DEFAULT_MODE:
-            if "text_qa_template" not in query_kwargs:
-                query_kwargs["text_qa_template"] = self.text_qa_template
-            query: GPTFaissIndexQuery = GPTFaissIndexQuery(
-                self.index_struct, faiss_index=self._faiss_index, **query_kwargs
-            )
-        else:
-            raise ValueError(f"Invalid query mode: {mode}.")
-        return query
-
     def _insert(self, document: BaseDocument, **insert_kwargs: Any) -> None:
         """Insert a document."""
         self._add_document_to_index(self._index_struct, document, self._text_splitter)