diff --git a/examples/paul_graham_essay/TestEssay.ipynb b/examples/paul_graham_essay/TestEssay.ipynb index a7dc994344cd31c19d9c6aef328ffb9d23172e59..f07f7a6a31c57fbe5116091a62632a757b9b186b 100644 --- a/examples/paul_graham_essay/TestEssay.ipynb +++ b/examples/paul_graham_essay/TestEssay.ipynb @@ -41,16 +41,27 @@ "from IPython.display import Markdown, display" ] }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1c297fd3-3424-41d8-9d0d-25fe6310ab62", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "documents = SimpleDirectoryReader('data').load_data()" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "1298bbb4-c99e-431e-93ef-eb32c0a2fc2a", + "id": "370fd08f-56ff-4c24-b0c4-c93116a6d482", "metadata": { "tags": [] }, "outputs": [], "source": [ - "documents = SimpleDirectoryReader('data').load_data()\n", "index = GPTTreeIndex(documents)" ] }, @@ -189,12 +200,84 @@ "display(Markdown(f\"<b>{response}</b>\"))" ] }, + { + "cell_type": "markdown", + "id": "3c572726-bb95-49c3-a762-d966de59ee5f", + "metadata": {}, + "source": [ + "#### [Demo] Build Tree Index during Query-Time" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "255fb052-1ff6-4f27-881f-28d4790e9520", + "metadata": {}, + "outputs": [], + "source": [ + "documents = SimpleDirectoryReader('data').load_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "85371256-292c-473e-9485-7de5c1997a59", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> [build_index_from_documents] Total token usage: 0 tokens\n" + ] + } + ], + "source": [ + "index_light = GPTTreeIndex(documents, build_tree=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "77b0acb3-5593-4f00-8eef-315a031fedc2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> Starting query: What did the author do after his time at Y Combinator?\n", + "> Building index from nodes: 5 chunks\n", + "0/57\n", + "10/57\n", + "20/57\n", + "30/57\n", + "40/57\n", + "50/57\n", + "> [query] Total token usage: 18200 tokens\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\nThe author went back to painting.'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index_light.query(\"What did the author do after his time at Y Combinator?\", mode=\"summarize\")" + ] + }, { "cell_type": "markdown", "id": "f9773497-9aa6-4a16-884a-cd882e63d012", "metadata": {}, "source": [ - "#### [Demo] Build Tree Index with query_str, directly retrieve answer from root node" + "#### [Demo] Build Tree Index with a custom Summary Prompt, directly retrieve answer from root node" ] }, { diff --git a/gpt_index/indices/base.py b/gpt_index/indices/base.py index 6b51a01da32332fc70aff25afd8e65ac577fd573..a444323f635a22ddfcbae62f97aa09acb3bdf195 100644 --- a/gpt_index/indices/base.py +++ b/gpt_index/indices/base.py @@ -24,8 +24,10 @@ from gpt_index.utils import llm_token_counter IS = TypeVar("IS", bound=IndexStruct) +# TODO: remove and consolidate with QueryMode DEFAULT_MODE = "default" EMBEDDING_MODE = "embedding" +SUMMARIZE_MODE = "summarize" DOCUMENTS_INPUT = Union[BaseDocument, "BaseGPTIndex"] diff --git a/gpt_index/indices/common/__init__.py b/gpt_index/indices/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d4640565ae2765d9ca96a509dc9809217f62f2f --- /dev/null +++ b/gpt_index/indices/common/__init__.py @@ -0,0 +1 @@ +"""Init 
file.""" diff --git a/gpt_index/indices/common/tree/__init__.py b/gpt_index/indices/common/tree/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d4640565ae2765d9ca96a509dc9809217f62f2f --- /dev/null +++ b/gpt_index/indices/common/tree/__init__.py @@ -0,0 +1 @@ +"""Init file.""" diff --git a/gpt_index/indices/common/tree/base.py b/gpt_index/indices/common/tree/base.py new file mode 100644 index 0000000000000000000000000000000000000000..07f389b13184441b6843b9a12426d6da9044fdeb --- /dev/null +++ b/gpt_index/indices/common/tree/base.py @@ -0,0 +1,119 @@ +"""Common classes/functions for tree index operations.""" + + +from typing import Dict, Optional, Sequence + +from gpt_index.indices.data_structs import IndexGraph, Node +from gpt_index.indices.prompt_helper import PromptHelper +from gpt_index.indices.utils import get_sorted_node_list, truncate_text +from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor +from gpt_index.prompts.prompts import SummaryPrompt +from gpt_index.schema import BaseDocument + + +class GPTTreeIndexBuilder: + """GPT tree index builder. + + Helper class to build the tree-structured index, + or to synthesize an answer. + + """ + + def __init__( + self, + num_children: int, + summary_prompt: SummaryPrompt, + llm_predictor: Optional[LLMPredictor], + prompt_helper: Optional[PromptHelper], + ) -> None: + """Initialize with params.""" + if num_children < 2: + raise ValueError("Invalid number of children.") + self.num_children = num_children + self.summary_prompt = summary_prompt + self._prompt_helper = prompt_helper or PromptHelper() + self._text_splitter = self._prompt_helper.get_text_splitter_given_prompt( + self.summary_prompt, self.num_children + ) + self._llm_predictor = llm_predictor or LLMPredictor() + + def _get_nodes_from_document( + self, start_idx: int, document: BaseDocument + ) -> Dict[int, Node]: + """Add document to index.""" + text_chunks = self._text_splitter.split_text(document.get_text()) + doc_nodes = { + (start_idx + i): Node( + text=t, index=(start_idx + i), ref_doc_id=document.get_doc_id() + ) + for i, t in enumerate(text_chunks) + } + return doc_nodes + + def build_from_text( + self, + documents: Sequence[BaseDocument], + build_tree: bool = True, + verbose: bool = False, + ) -> IndexGraph: + """Build from text. 
+ + Returns: + IndexGraph: graph object consisting of all_nodes, root_nodes + + """ + all_nodes: Dict[int, Node] = {} + for d in documents: + all_nodes.update(self._get_nodes_from_document(len(all_nodes), d)) + + if build_tree: + # instantiate all_nodes from initial text chunks + root_nodes = self.build_index_from_nodes( + all_nodes, all_nodes, verbose=verbose + ) + else: + # if build_tree is False, then don't surface any root nodes + root_nodes = {} + return IndexGraph(all_nodes=all_nodes, root_nodes=root_nodes) + + def build_index_from_nodes( + self, + cur_nodes: Dict[int, Node], + all_nodes: Dict[int, Node], + verbose: bool = False, + ) -> Dict[int, Node]: + """Consolidate chunks recursively, in a bottom-up fashion.""" + cur_node_list = get_sorted_node_list(cur_nodes) + cur_index = len(all_nodes) + new_node_dict = {} + print( + f"> Building index from nodes: {len(cur_nodes) // self.num_children} chunks" + ) + for i in range(0, len(cur_node_list), self.num_children): + print(f"{i}/{len(cur_nodes)}") + cur_nodes_chunk = cur_node_list[i : i + self.num_children] + text_chunk = self._prompt_helper.get_text_from_nodes( + cur_nodes_chunk, prompt=self.summary_prompt + ) + + new_summary, _ = self._llm_predictor.predict( + self.summary_prompt, context_str=text_chunk + ) + + if verbose: + fmt_summary = truncate_text(new_summary, 50) + print(f"> {i}/{len(cur_nodes)}, summary: {fmt_summary}") + new_node = Node( + text=new_summary, + index=cur_index, + child_indices={n.index for n in cur_nodes_chunk}, + ) + new_node_dict[cur_index] = new_node + cur_index += 1 + + all_nodes.update(new_node_dict) + + if len(new_node_dict) <= self.num_children: + return new_node_dict + else: + return self.build_index_from_nodes(new_node_dict, all_nodes) diff --git a/gpt_index/indices/query/base.py b/gpt_index/indices/query/base.py index 92a8d4a697f4ed7c88f608d445da9ed6a6b943db..ba72f947dc600a872ad7694188bd844ccc34919e 100644 --- a/gpt_index/indices/query/base.py +++ b/gpt_index/indices/query/base.py @@ -9,7 +9,7 @@ from gpt_index.indices.prompt_helper import PromptHelper from gpt_index.indices.response_utils import give_response, refine_response from gpt_index.indices.utils import truncate_text from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor -from gpt_index.prompts.base import Prompt +from gpt_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt from gpt_index.schema import DocumentStore from gpt_index.utils import llm_token_counter @@ -63,8 +63,8 @@ class BaseGPTIndexQuery(Generic[IS]): self, query_str: str, node: Node, - text_qa_template: Prompt, - refine_template: Prompt, + text_qa_template: QuestionAnswerPrompt, + refine_template: RefinePrompt, response: Optional[str] = None, verbose: bool = False, level: Optional[int] = None, diff --git a/gpt_index/indices/query/query_map.py b/gpt_index/indices/query/query_map.py index f129c34b0f5d370136e31e5ec9b8aa8fc84f0517..6fb5487151e7f68a406191df01ac8d556e82768c 100644 --- a/gpt_index/indices/query/query_map.py +++ b/gpt_index/indices/query/query_map.py @@ -15,12 +15,14 @@ from gpt_index.indices.query.schema import QueryMode from gpt_index.indices.query.tree.embedding_query import GPTTreeIndexEmbeddingQuery from gpt_index.indices.query.tree.leaf_query import GPTTreeIndexLeafQuery from gpt_index.indices.query.tree.retrieve_query import GPTTreeIndexRetQuery +from gpt_index.indices.query.tree.summarize_query import GPTTreeIndexSummarizeQuery # TODO: migrate _mode_to_query in indices/base.py to use this file MODE_TO_QUERY_MAP_TREE = { 
QueryMode.DEFAULT: GPTTreeIndexLeafQuery, QueryMode.RETRIEVE: GPTTreeIndexRetQuery, QueryMode.EMBEDDING: GPTTreeIndexEmbeddingQuery, + QueryMode.SUMMARIZE: GPTTreeIndexSummarizeQuery, } MODE_TO_QUERY_MAP_LIST = { diff --git a/gpt_index/indices/query/schema.py b/gpt_index/indices/query/schema.py index aa04fb5e076aeb32016b498468b5bf837eeda8d0..75261b004aa2b9cf683dc6b363e96bf71efe1910 100644 --- a/gpt_index/indices/query/schema.py +++ b/gpt_index/indices/query/schema.py @@ -16,6 +16,9 @@ class QueryMode(str, Enum): RETRIEVE = "retrieve" EMBEDDING = "embedding" + # to hierarchically summarize using the tree + SUMMARIZE = "summarize" + # for keyword extractor SIMPLE = "simple" RAKE = "rake" diff --git a/gpt_index/indices/query/tree/summarize_query.py b/gpt_index/indices/query/tree/summarize_query.py new file mode 100644 index 0000000000000000000000000000000000000000..261c49996ba8d4444534d949dd62bd1fffda728b --- /dev/null +++ b/gpt_index/indices/query/tree/summarize_query.py @@ -0,0 +1,75 @@ +"""Summarize query.""" + + +from typing import Any, Optional + +from gpt_index.indices.common.tree.base import GPTTreeIndexBuilder +from gpt_index.indices.data_structs import IndexGraph +from gpt_index.indices.query.base import BaseGPTIndexQuery +from gpt_index.indices.response_utils.response import give_response +from gpt_index.indices.utils import get_sorted_node_list +from gpt_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT +from gpt_index.prompts.prompts import QuestionAnswerPrompt, SummaryPrompt + + +class GPTTreeIndexSummarizeQuery(BaseGPTIndexQuery[IndexGraph]): + """GPT Tree Index summarize query. + + This class builds a query-specific tree from leaf nodes to return a response. + Using this query mode means that the tree index doesn't need to be built + when initialized, since we rebuild the tree for each query. + + .. code-block:: python + + response = index.query("<query_str>", mode="summarize") + + Args: + text_qa_template (Optional[QuestionAnswerPrompt]): Question-Answer Prompt + (see :ref:`Prompt-Templates`). 
+ + """ + + def __init__( + self, + index_struct: IndexGraph, + text_qa_template: Optional[QuestionAnswerPrompt] = None, + num_children: int = 10, + **kwargs: Any, + ) -> None: + """Initialize params.""" + super().__init__(index_struct, **kwargs) + self.text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT + self.num_children = num_children + + def _query(self, query_str: str, verbose: bool = False) -> str: + """Answer a query.""" + print(f"> Starting query: {query_str}") + + # use prompt composability to build a summary prompt + text_qa_template = self.text_qa_template.partial_format(query_str=query_str) + summary_template = SummaryPrompt.from_prompt(text_qa_template) + + index_builder = GPTTreeIndexBuilder( + self.num_children, + summary_template, + self._llm_predictor, + self._prompt_helper, + ) + all_nodes = self._index_struct.all_nodes.copy() + root_nodes = index_builder.build_index_from_nodes( + all_nodes, all_nodes, verbose=verbose + ) + + node_list = get_sorted_node_list(root_nodes) + node_text = self._prompt_helper.get_text_from_nodes( + node_list, prompt=self.text_qa_template + ) + response = give_response( + self._prompt_helper, + self._llm_predictor, + query_str, + node_text, + text_qa_template=self.text_qa_template, + verbose=verbose, + ) + return response diff --git a/gpt_index/indices/response_utils/response.py b/gpt_index/indices/response_utils/response.py index a10f712d2dd721694e60f3d47d7286452f866587..55eddf5aec1298b5b7ed2a95c8dbab53428ebecc 100644 --- a/gpt_index/indices/response_utils/response.py +++ b/gpt_index/indices/response_utils/response.py @@ -3,11 +3,11 @@ from gpt_index.indices.prompt_helper import PromptHelper from gpt_index.indices.utils import truncate_text from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor -from gpt_index.prompts.base import Prompt from gpt_index.prompts.default_prompts import ( DEFAULT_REFINE_PROMPT, DEFAULT_TEXT_QA_PROMPT, ) +from gpt_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt def refine_response( @@ -16,7 +16,7 @@ def refine_response( response: str, query_str: str, text_chunk: str, - refine_template: Prompt = DEFAULT_REFINE_PROMPT, + refine_template: RefinePrompt = DEFAULT_REFINE_PROMPT, verbose: bool = False, ) -> str: """Refine response.""" @@ -27,12 +27,12 @@ def refine_response( refine_template, 1 ) text_chunks = refine_text_splitter.split_text(text_chunk) - for text_chunk in text_chunks: + for cur_text_chunk in text_chunks: response, _ = llm_predictor.predict( refine_template, query_str=query_str, existing_answer=response, - context_msg=text_chunk, + context_msg=cur_text_chunk, ) if verbose: print(f"> Refined response: {response}") @@ -44,27 +44,28 @@ def give_response( llm_predictor: LLMPredictor, query_str: str, text_chunk: str, - text_qa_template: Prompt = DEFAULT_TEXT_QA_PROMPT, - refine_template: Prompt = DEFAULT_REFINE_PROMPT, + text_qa_template: QuestionAnswerPrompt = DEFAULT_TEXT_QA_PROMPT, + refine_template: RefinePrompt = DEFAULT_REFINE_PROMPT, verbose: bool = False, ) -> str: """Give response given a query and a corresponding text chunk.""" qa_text_splitter = prompt_helper.get_text_splitter_given_prompt(text_qa_template, 1) text_chunks = qa_text_splitter.split_text(text_chunk) response = None - for text_chunk in text_chunks: + for cur_text_chunk in text_chunks: if response is None: response, _ = llm_predictor.predict( - text_qa_template, query_str=query_str, context_str=text_chunk + text_qa_template, query_str=query_str, context_str=cur_text_chunk ) if verbose: print(f"> 
Initial response: {response}") else: response = refine_response( + prompt_helper, llm_predictor, response, query_str, - text_chunk, + cur_text_chunk, refine_template=refine_template, verbose=verbose, ) diff --git a/gpt_index/indices/tree/base.py b/gpt_index/indices/tree/base.py index 6967b0a1101ade14222a8725e7a061b024baf92e..22c330ee219b2b95613df401290495aa50120f98 100644 --- a/gpt_index/indices/tree/base.py +++ b/gpt_index/indices/tree/base.py @@ -1,21 +1,22 @@ """Tree-based index.""" -from typing import Any, Dict, Optional, Sequence +from typing import Any, Optional, Sequence from gpt_index.indices.base import ( DEFAULT_MODE, DOCUMENTS_INPUT, EMBEDDING_MODE, + SUMMARIZE_MODE, BaseGPTIndex, ) -from gpt_index.indices.data_structs import IndexGraph, Node -from gpt_index.indices.prompt_helper import PromptHelper +from gpt_index.indices.common.tree.base import GPTTreeIndexBuilder +from gpt_index.indices.data_structs import IndexGraph from gpt_index.indices.query.base import BaseGPTIndexQuery from gpt_index.indices.query.tree.embedding_query import GPTTreeIndexEmbeddingQuery from gpt_index.indices.query.tree.leaf_query import GPTTreeIndexLeafQuery from gpt_index.indices.query.tree.retrieve_query import GPTTreeIndexRetQuery +from gpt_index.indices.query.tree.summarize_query import GPTTreeIndexSummarizeQuery from gpt_index.indices.tree.inserter import GPTIndexInserter -from gpt_index.indices.utils import get_sorted_node_list, truncate_text from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor from gpt_index.prompts.default_prompts import ( DEFAULT_INSERT_PROMPT, @@ -26,103 +27,11 @@ from gpt_index.schema import BaseDocument RETRIEVE_MODE = "retrieve" - -class GPTTreeIndexBuilder: - """GPT tree index builder. - - Helper class to build the tree-structured index. - - """ - - def __init__( - self, - num_children: int, - summary_prompt: SummaryPrompt, - llm_predictor: Optional[LLMPredictor], - prompt_helper: Optional[PromptHelper], - ) -> None: - """Initialize with params.""" - if num_children < 2: - raise ValueError("Invalid number of children.") - self.num_children = num_children - self.summary_prompt = summary_prompt - self._prompt_helper = prompt_helper or PromptHelper() - self._text_splitter = self._prompt_helper.get_text_splitter_given_prompt( - self.summary_prompt, self.num_children - ) - self._llm_predictor = llm_predictor or LLMPredictor() - - def _get_nodes_from_document( - self, start_idx: int, document: BaseDocument - ) -> Dict[int, Node]: - """Add document to index.""" - text_chunks = self._text_splitter.split_text(document.get_text()) - doc_nodes = { - (start_idx + i): Node( - text=t, index=(start_idx + i), ref_doc_id=document.get_doc_id() - ) - for i, t in enumerate(text_chunks) - } - return doc_nodes - - def build_from_text( - self, documents: Sequence[BaseDocument], verbose: bool = False - ) -> IndexGraph: - """Build from text. 
- - Returns: - IndexGraph: graph object consisting of all_nodes, root_nodes - - """ - all_nodes: Dict[int, Node] = {} - for d in documents: - all_nodes.update(self._get_nodes_from_document(len(all_nodes), d)) - - # instantiate all_nodes from initial text chunks - root_nodes = self._build_index_from_nodes(all_nodes, all_nodes, verbose=verbose) - return IndexGraph(all_nodes=all_nodes, root_nodes=root_nodes) - - def _build_index_from_nodes( - self, - cur_nodes: Dict[int, Node], - all_nodes: Dict[int, Node], - verbose: bool = False, - ) -> Dict[int, Node]: - """Consolidates chunks recursively, in a bottoms-up fashion.""" - cur_node_list = get_sorted_node_list(cur_nodes) - cur_index = len(all_nodes) - new_node_dict = {} - print( - f"> Building index from nodes: {len(cur_nodes) // self.num_children} chunks" - ) - for i in range(0, len(cur_node_list), self.num_children): - print(f"{i}/{len(cur_nodes)}") - cur_nodes_chunk = cur_node_list[i : i + self.num_children] - text_chunk = self._prompt_helper.get_text_from_nodes( - cur_nodes_chunk, prompt=self.summary_prompt - ) - - new_summary, _ = self._llm_predictor.predict( - self.summary_prompt, text=text_chunk - ) - - if verbose: - fmt_summary = truncate_text(new_summary, 50) - print(f"> {i}/{len(cur_nodes)}, summary: {fmt_summary}") - new_node = Node( - text=new_summary, - index=cur_index, - child_indices={n.index for n in cur_nodes_chunk}, - ) - new_node_dict[cur_index] = new_node - cur_index += 1 - - all_nodes.update(new_node_dict) - - if len(new_node_dict) <= self.num_children: - return new_node_dict - else: - return self._build_index_from_nodes(new_node_dict, all_nodes) +REQUIRE_TREE_MODES = { + DEFAULT_MODE, + EMBEDDING_MODE, + RETRIEVE_MODE, +} class GPTTreeIndex(BaseGPTIndex[IndexGraph]): @@ -141,6 +50,8 @@ class GPTTreeIndex(BaseGPTIndex[IndexGraph]): (see :ref:`Prompt-Templates`). insert_prompt (Optional[TreeInsertPrompt]): An Tree Insertion Prompt (see :ref:`Prompt-Templates`). + num_children (int): The number of children each node should have. + build_tree (bool): Whether to build the tree during index construction. """ @@ -154,6 +65,7 @@ class GPTTreeIndex(BaseGPTIndex[IndexGraph]): insert_prompt: Optional[TreeInsertPrompt] = None, num_children: int = 10, llm_predictor: Optional[LLMPredictor] = None, + build_tree: bool = True, **kwargs: Any, ) -> None: """Initialize params.""" @@ -161,6 +73,7 @@ class GPTTreeIndex(BaseGPTIndex[IndexGraph]): self.num_children = num_children self.summary_template = summary_template or DEFAULT_SUMMARY_PROMPT self.insert_prompt: TreeInsertPrompt = insert_prompt or DEFAULT_INSERT_PROMPT + self.build_tree = build_tree super().__init__( documents=documents, index_struct=index_struct, @@ -168,8 +81,17 @@ class GPTTreeIndex(BaseGPTIndex[IndexGraph]): **kwargs, ) + def _validate_build_tree_required(self, mode: str) -> None: + """Check if index supports modes that require trees.""" + if mode in REQUIRE_TREE_MODES and not self.build_tree: + raise ValueError( + "Index was constructed without building trees, " + f"but mode {mode} requires trees." 
+ ) + def _mode_to_query(self, mode: str, **query_kwargs: Any) -> BaseGPTIndexQuery: """Query mode to class.""" + self._validate_build_tree_required(mode) if mode == DEFAULT_MODE: query: BaseGPTIndexQuery = GPTTreeIndexLeafQuery( self.index_struct, **query_kwargs @@ -178,6 +100,8 @@ class GPTTreeIndex(BaseGPTIndex[IndexGraph]): query = GPTTreeIndexRetQuery(self.index_struct, **query_kwargs) elif mode == EMBEDDING_MODE: query = GPTTreeIndexEmbeddingQuery(self.index_struct, **query_kwargs) + elif mode == SUMMARIZE_MODE: + query = GPTTreeIndexSummarizeQuery(self.index_struct, **query_kwargs) else: raise ValueError(f"Invalid query mode: {mode}.") return query @@ -193,7 +117,9 @@ class GPTTreeIndex(BaseGPTIndex[IndexGraph]): self._llm_predictor, self._prompt_helper, ) - index_graph = index_builder.build_from_text(documents, verbose=verbose) + index_graph = index_builder.build_from_text( + documents, build_tree=self.build_tree, verbose=verbose + ) return index_graph def _insert(self, document: BaseDocument, **insert_kwargs: Any) -> None: diff --git a/gpt_index/indices/tree/inserter.py b/gpt_index/indices/tree/inserter.py index 5c3c632c67e7c71a050ad13bf3ffa4cb3a6e1729..d10177ef27552b52ab91c6cbd8d0a2e8fe972b6a 100644 --- a/gpt_index/indices/tree/inserter.py +++ b/gpt_index/indices/tree/inserter.py @@ -72,7 +72,7 @@ class GPTIndexInserter: half1, prompt=self.summary_prompt ) summary1, _ = self._llm_predictor.predict( - self.summary_prompt, text=text_chunk1 + self.summary_prompt, context_str=text_chunk1 ) node1 = Node( text=summary1, @@ -84,7 +84,7 @@ class GPTIndexInserter: half2, prompt=self.summary_prompt ) summary2, _ = self._llm_predictor.predict( - self.summary_prompt, text=text_chunk2 + self.summary_prompt, context_str=text_chunk2 ) node2 = Node( text=summary2, @@ -150,7 +150,7 @@ class GPTIndexInserter: cur_graph_node_list, prompt=self.summary_prompt ) new_summary, _ = self._llm_predictor.predict( - self.summary_prompt, text=text_chunk + self.summary_prompt, context_str=text_chunk ) parent_node.text = new_summary diff --git a/gpt_index/prompts/default_prompts.py b/gpt_index/prompts/default_prompts.py index 4aee940e85ea5e8016f3213af11a64333cec14cc..118fb5ea37e1ece1f79693760b32b98a9f7de5c0 100644 --- a/gpt_index/prompts/default_prompts.py +++ b/gpt_index/prompts/default_prompts.py @@ -21,7 +21,7 @@ DEFAULT_SUMMARY_PROMPT_TMPL = ( "Try to include as many key details as possible.\n" "\n" "\n" - "{text}\n" + "{context_str}\n" "\n" "\n" 'SUMMARY:"""\n' diff --git a/gpt_index/prompts/prompts.py b/gpt_index/prompts/prompts.py index 1611de447d171fea11fcecd0ca42d566046f4fe7..aac9e299a17f0db3649d68b119b0b263b1d18f8b 100644 --- a/gpt_index/prompts/prompts.py +++ b/gpt_index/prompts/prompts.py @@ -8,9 +8,9 @@ from gpt_index.prompts.prompt_type import PromptType class SummaryPrompt(Prompt): """Summary prompt. - Prompt to summarize the provided `text`. + Prompt to summarize the provided `context_str`. - Required template variables: `text` + Required template variables: `context_str` Args: template (str): Template for the prompt. 
@@ -19,7 +19,7 @@ class SummaryPrompt(Prompt): """ prompt_type: PromptType = PromptType.SUMMARY - input_variables: List[str] = ["text"] + input_variables: List[str] = ["context_str"] class TreeInsertPrompt(Prompt): diff --git a/tests/indices/test_response.py b/tests/indices/test_response.py new file mode 100644 index 0000000000000000000000000000000000000000..271c9df86a8efc4c32252e87c77a9c65e8632473 --- /dev/null +++ b/tests/indices/test_response.py @@ -0,0 +1,62 @@ +"""Test response utils.""" + +from typing import Any, List + +import pytest + +from gpt_index.indices.prompt_helper import PromptHelper +from gpt_index.indices.response_utils.response import give_response +from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor +from gpt_index.schema import Document +from tests.mock_utils.mock_decorator import patch_common +from tests.mock_utils.mock_prompts import MOCK_REFINE_PROMPT, MOCK_TEXT_QA_PROMPT + + +@pytest.fixture +def documents() -> List[Document]: + """Get documents.""" + # NOTE: one document for now + doc_text = ( + "Hello world.\n" + "This is a test.\n" + "This is another test.\n" + "This is a test v2." + ) + return [Document(doc_text)] + + +@patch_common +def test_give_response( + _mock_init: Any, + _mock_predict: Any, + _mock_total_tokens_used: Any, + _mock_split_text: Any, + documents: List[Document], +) -> None: + """Test give response.""" + prompt_helper = PromptHelper() + llm_predictor = LLMPredictor() + query_str = "What is?" + + # test single line + response = give_response( + prompt_helper, + llm_predictor, + query_str, + "This is a single line.", + text_qa_template=MOCK_TEXT_QA_PROMPT, + refine_template=MOCK_REFINE_PROMPT, + ) + + assert response == "What is?:This is a single line." + + # test multiple lines + response = give_response( + prompt_helper, + llm_predictor, + query_str, + documents[0].get_text(), + text_qa_template=MOCK_TEXT_QA_PROMPT, + refine_template=MOCK_REFINE_PROMPT, + ) + assert response == "What is?:Hello world." diff --git a/tests/indices/tree/test_base.py b/tests/indices/tree/test_base.py index 1fcfce4fa197a328f5dd557cdcdd547666f5756a..871bc801da738d61f71756ea558274389178c672 100644 --- a/tests/indices/tree/test_base.py +++ b/tests/indices/tree/test_base.py @@ -132,6 +132,37 @@ def test_query( assert response == ("What is?:Hello world.") +@patch_common +def test_summarize_query( + _mock_init: Any, + _mock_predict: Any, + _mock_total_tokens_used: Any, + _mock_split_text: Any, + documents: List[Document], + struct_kwargs: Dict, +) -> None: + """Test summarize query.""" + # create tree index without building tree + index_kwargs, orig_query_kwargs = struct_kwargs + index_kwargs = index_kwargs.copy() + index_kwargs.update({"build_tree": False}) + tree = GPTTreeIndex(documents, **index_kwargs) + + # test summarize query + query_str = "What is?" 
+ query_kwargs: Dict[str, Any] = { + "text_qa_template": MOCK_TEXT_QA_PROMPT, + "num_children": 2, + } + # TODO: fix unit test later + response = tree.query(query_str, mode="summarize", **query_kwargs) + assert response == ("What is?:Hello world.") + + # test that default query fails + with pytest.raises(ValueError): + tree.query(query_str, mode="default", **orig_query_kwargs) + + @patch_common def test_insert( _mock_init: Any, diff --git a/tests/mock_utils/mock_predict.py b/tests/mock_utils/mock_predict.py index a91b0549b4aaeccbeff2e08f92923f8bf2d65c68..679f5bd707acc835c8417278d8b7ad861a3027c3 100644 --- a/tests/mock_utils/mock_predict.py +++ b/tests/mock_utils/mock_predict.py @@ -9,7 +9,7 @@ from gpt_index.token_predictor.utils import mock_extract_keywords_response def _mock_summary_predict(prompt_args: Dict) -> str: """Mock summary predict.""" - return prompt_args["text"] + return prompt_args["context_str"] def _mock_insert_predict() -> str: diff --git a/tests/mock_utils/mock_prompts.py b/tests/mock_utils/mock_prompts.py index 4163747a13bb3d69557c149cb06c74c337ca75f5..8d64751faacbc4d9dbc7573670346c749e53d5d2 100644 --- a/tests/mock_utils/mock_prompts.py +++ b/tests/mock_utils/mock_prompts.py @@ -10,7 +10,7 @@ from gpt_index.prompts.prompts import ( TreeSelectPrompt, ) -MOCK_SUMMARY_PROMPT_TMPL = "{text}\n" +MOCK_SUMMARY_PROMPT_TMPL = "{context_str}\n" MOCK_SUMMARY_PROMPT = SummaryPrompt(MOCK_SUMMARY_PROMPT_TMPL) MOCK_INSERT_PROMPT_TMPL = "{num_chunks}\n{context_list}{new_chunk_text}\n"
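Taken together, the changes above let a `GPTTreeIndex` defer tree construction to query time. Below is a minimal usage sketch, assuming the package re-exports `GPTTreeIndex` and `SimpleDirectoryReader` at the top level (as the notebook usage suggests) and that a `data/` directory of documents exists:

```python
# Sketch of the new query-time tree building path, under the assumptions above.
from gpt_index import GPTTreeIndex, SimpleDirectoryReader

documents = SimpleDirectoryReader("data").load_data()

# build_tree=False skips the bottom-up summarization pass at construction
# time, so no LLM tokens are spent until the first query.
index_light = GPTTreeIndex(documents, build_tree=False)

# mode="summarize" rebuilds a tree for this specific query (via
# GPTTreeIndexBuilder.build_index_from_nodes) and synthesizes the answer
# from the resulting root nodes.
response = index_light.query(
    "What did the author do after his time at Y Combinator?",
    mode="summarize",
)
print(response)

# Modes that require a pre-built tree now fail fast on such an index:
# index_light.query("...", mode="default") raises ValueError.
```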
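Note also the breaking prompt change: `SummaryPrompt` now declares `input_variables = ["context_str"]` instead of `["text"]`, which is what allows the summarize query to convert a partially formatted `QuestionAnswerPrompt` into a `SummaryPrompt` via `from_prompt`. A hedged migration sketch for custom summary prompts follows; the template wording is illustrative, not the library default:

```python
from gpt_index.prompts.prompts import SummaryPrompt

# Custom summary templates must expose {context_str} rather than {text}.
CUSTOM_SUMMARY_TMPL = (
    "Summarize the following text, keeping as many key details as possible:\n"
    "{context_str}\n"  # was "{text}" before this change
    "SUMMARY: "
)
custom_summary_prompt = SummaryPrompt(CUSTOM_SUMMARY_TMPL)

# The custom prompt can then be passed at index construction time:
# index = GPTTreeIndex(documents, summary_template=custom_summary_prompt)
```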