diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml
index 110ed5f33daea889abafe8fcfcf052eb47a50e40..4c86ee3bb4070091ed4568e8e31a032cb91bf183 100644
--- a/.github/workflows/unit_test.yml
+++ b/.github/workflows/unit_test.yml
@@ -13,7 +13,7 @@ jobs:
       # You can use PyPy versions in python-version.
       # For example, pypy-2.7 and pypy-3.8
       matrix:
-        python-version: ["3.9"]
+        python-version: ["3.9", "3.8"]
     steps:
       - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 91d3113047d1e96be0f7720becda736d8e7cca7a..3170150b55fd587d7c76f0eb39382e751f0ccd48 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -1,9 +1,11 @@
 version: 2
 sphinx:
   configuration: docs/conf.py
+build:
+  image: testing
 formats: all
 python:
-  version: 3.8
+  version: 3.9
   install:
     - requirements: docs/requirements.txt
     - method: pip
diff --git a/README.md b/README.md
index 657f7fcf241105c2b098ea247bd41d0eeee6ba89..1807662177a9fa8bc736a47509637eac5fcbd9da 100644
--- a/README.md
+++ b/README.md
@@ -64,7 +64,7 @@ index.query("<question_text>?", child_branch_factor=1)
 
 ## 🔧 Dependencies
 
-The main third-party package requirements are `transformers`, `openai`, and `langchain`.
+The main third-party package requirements are `tiktoken`, `openai`, and `langchain`.
 
 All requirements should be contained within the `setup.py` file. To run the package locally without building the wheel, simply do `pip install -r requirements.txt`.
 
diff --git a/gpt_index/indices/base.py b/gpt_index/indices/base.py
index 915fddbd12c301be024e5530cbd5a72e8a049b72..7bc5e49289ab45ddb3dc89c32c60d29359d1a628 100644
--- a/gpt_index/indices/base.py
+++ b/gpt_index/indices/base.py
@@ -20,6 +20,7 @@ from gpt_index.indices.query.base import BaseGPTIndexQuery
 from gpt_index.indices.query.query_runner import QueryRunner
 from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
 from gpt_index.schema import BaseDocument, DocumentStore
+from gpt_index.utils import llm_token_counter
 
 IS = TypeVar("IS", bound=IndexStruct)
 
@@ -131,6 +132,7 @@ class BaseGPTIndex(Generic[IS]):
     def _insert(self, document: BaseDocument, **insert_kwargs: Any) -> None:
         """Insert a document."""
 
+    @llm_token_counter("insert")
     def insert(self, document: DOCUMENTS_INPUT, **insert_kwargs: Any) -> None:
         """Insert a document."""
         processed_doc = self._process_documents([document], self._docstore)[0]
@@ -145,6 +147,7 @@ class BaseGPTIndex(Generic[IS]):
     def _mode_to_query(self, mode: str, **query_kwargs: Any) -> BaseGPTIndexQuery:
         """Query mode to class."""
 
+    @llm_token_counter("query")
     def query(
         self,
         query_str: str,
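The two `llm_token_counter` hooks above wrap the public `insert` and `query` entry points, so every top-level call reports the net tokens it consumed. A minimal, self-contained sketch of the decorator mechanics (the `FakePredictor`/`FakeIndex` names are illustrative only, not part of the library):

```python
from typing import Any, Callable


def token_counter(method_name: str) -> Callable:
    """Print the net tokens a wrapped method consumed via `_llm_predictor`."""

    def wrap(f: Callable) -> Callable:
        def wrapped(self: Any, *args: Any, **kwargs: Any) -> Any:
            start = self._llm_predictor.total_tokens_used
            result = f(self, *args, **kwargs)
            used = self._llm_predictor.total_tokens_used - start
            print(f"> [{method_name}] Total token usage: {used} tokens")
            return result

        return wrapped

    return wrap


class FakePredictor:
    """Stand-in for LLMPredictor; only the counter matters here."""

    def __init__(self) -> None:
        self.total_tokens_used = 0


class FakeIndex:
    def __init__(self) -> None:
        self._llm_predictor = FakePredictor()

    @token_counter("query")
    def query(self, query_str: str) -> str:
        self._llm_predictor.total_tokens_used += 38  # pretend the LLM call cost 38
        return "answer"


FakeIndex().query("What is?")  # prints: > [query] Total token usage: 38 tokens
```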
diff --git a/gpt_index/indices/keyword_table/base.py b/gpt_index/indices/keyword_table/base.py
index 70f1556dc85f0fd7449176371c5286c8d5dfabbc..3d87b3d91d76033139e8ed81c1b885804066725d 100644
--- a/gpt_index/indices/keyword_table/base.py
+++ b/gpt_index/indices/keyword_table/base.py
@@ -34,6 +34,7 @@ from gpt_index.prompts.default_prompts import (
     DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE,
 )
 from gpt_index.schema import BaseDocument
+from gpt_index.utils import llm_token_counter
 
 DQKET = DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE
 
@@ -133,6 +134,7 @@ class BaseGPTKeywordTableIndex(BaseGPTIndex[KeywordTable]):
         )
         print(f"> Keywords: {keywords}")
 
+    @llm_token_counter("build_index_from_documents")
     def build_index_from_documents(
         self, documents: Sequence[BaseDocument]
     ) -> KeywordTable:
diff --git a/gpt_index/indices/prompt_helper.py b/gpt_index/indices/prompt_helper.py
index 8aa538d551cdcbab20c4444d8d6ddab61913f241..e1720ff5a734306986b23d5fe4c8245d8e32084f 100644
--- a/gpt_index/indices/prompt_helper.py
+++ b/gpt_index/indices/prompt_helper.py
@@ -23,7 +23,7 @@ class PromptHelper:
         num_output: int = NUM_OUTPUTS,
         max_chunk_overlap: int = MAX_CHUNK_OVERLAP,
         embedding_limit: Optional[int] = None,
-        tokenizer: Optional[Callable] = None,
+        tokenizer: Optional[Callable[[str], List]] = None,
     ) -> None:
         """Init params."""
         self.max_input_size = max_input_size
@@ -46,7 +46,7 @@ class PromptHelper:
 
         """
         prompt_tokens = self._tokenizer(prompt_text)
-        num_prompt_tokens = len(prompt_tokens["input_ids"])
+        num_prompt_tokens = len(prompt_tokens)
 
         # NOTE: if embedding limit is specified, then chunk_size must not be larger than
         # embedding_limit
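With this change the tokenizer contract becomes a plain callable from text to a token list, rather than a HuggingFace-style mapping with an `"input_ids"` key; both tiktoken's `encode` and a wrapped `GPT2TokenizerFast` satisfy it. A minimal sketch of the new contract, using a whitespace tokenizer as a stand-in:

```python
from typing import List


def whitespace_tokenizer(text: str) -> List[str]:
    """Illustrative stand-in for tiktoken's encode / a wrapped GPT2TokenizerFast."""
    return text.split(" ")


# PromptHelper now assumes only that len(tokenizer(text)) gives the token count:
assert len(whitespace_tokenizer("hello world foo bar")) == 4
```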
diff --git a/gpt_index/indices/query/keyword_table/query.py b/gpt_index/indices/query/keyword_table/query.py
index 52b00270bc1195fb5a0e3a250875702a0c5fc146..1158632a25ceba244a4ad0d744a5a7071d64d1a8 100644
--- a/gpt_index/indices/query/keyword_table/query.py
+++ b/gpt_index/indices/query/keyword_table/query.py
@@ -18,6 +18,7 @@ from gpt_index.prompts.default_prompts import (
     DEFAULT_REFINE_PROMPT,
     DEFAULT_TEXT_QA_PROMPT,
 )
+from gpt_index.utils import llm_token_counter
 
 DQKET = DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE
 
@@ -67,6 +68,7 @@ class BaseGPTKeywordTableQuery(BaseGPTIndexQuery[KeywordTable]):
     def _get_keywords(self, query_str: str, verbose: bool = False) -> List[str]:
         """Extract keywords."""
 
+    @llm_token_counter("query")
     def query(self, query_str: str, verbose: bool = False) -> str:
         """Answer a query."""
         print(f"> Starting query: {query_str}")
diff --git a/gpt_index/indices/query/tree/leaf_query.py b/gpt_index/indices/query/tree/leaf_query.py
index 80531d38a08ccf2498617fe0521f14fe5bca02ad..12f7d094f252508bcbe5a9b92919b7a3aaf458f5 100644
--- a/gpt_index/indices/query/tree/leaf_query.py
+++ b/gpt_index/indices/query/tree/leaf_query.py
@@ -12,6 +12,7 @@ from gpt_index.prompts.default_prompts import (
     DEFAULT_REFINE_PROMPT,
     DEFAULT_TEXT_QA_PROMPT,
 )
+from gpt_index.utils import llm_token_counter
 
 
 class GPTTreeIndexLeafQuery(BaseGPTIndexQuery[IndexGraph]):
@@ -186,6 +187,7 @@ class GPTTreeIndexLeafQuery(BaseGPTIndexQuery[IndexGraph]):
             # result_response should not be None
             return cast(str, result_response)
 
+    @llm_token_counter("query")
     def query(self, query_str: str, verbose: bool = False) -> str:
         """Answer a query."""
         print(f"> Starting query: {query_str}")
diff --git a/gpt_index/indices/tree/base.py b/gpt_index/indices/tree/base.py
index a946c4e8313ae8c7d898dd5c0a46d1da290164b7..1254b28c6e86079fe665031534fff2c7202f61a8 100644
--- a/gpt_index/indices/tree/base.py
+++ b/gpt_index/indices/tree/base.py
@@ -23,6 +23,7 @@ from gpt_index.prompts.default_prompts import (
     DEFAULT_SUMMARY_PROMPT,
 )
 from gpt_index.schema import BaseDocument
+from gpt_index.utils import llm_token_counter
 
 RETRIEVE_MODE = "retrieve"
 
@@ -182,6 +183,7 @@ class GPTTreeIndex(BaseGPTIndex[IndexGraph]):
             raise ValueError(f"Invalid query mode: {mode}.")
         return query
 
+    @llm_token_counter("build_index_from_documents")
     def build_index_from_documents(
         self, documents: Sequence[BaseDocument]
     ) -> IndexGraph:
diff --git a/gpt_index/indices/tree/inserter.py b/gpt_index/indices/tree/inserter.py
index 5c3c632c67e7c71a050ad13bf3ffa4cb3a6e1729..a8bd06e52c2184785d9ada8b2a725c0086fcce68 100644
--- a/gpt_index/indices/tree/inserter.py
+++ b/gpt_index/indices/tree/inserter.py
@@ -12,6 +12,7 @@ from gpt_index.prompts.default_prompts import (
     DEFAULT_SUMMARY_PROMPT,
 )
 from gpt_index.schema import BaseDocument
+from gpt_index.utils import llm_token_counter
 
 
 class GPTIndexInserter:
@@ -155,6 +156,7 @@ class GPTIndexInserter:
             parent_node.text = new_summary
 
+    @llm_token_counter("insert")
     def insert(self, doc: BaseDocument) -> None:
         """Insert into index_graph."""
         text_chunks = self._text_splitter.split_text(doc.get_text())
diff --git a/gpt_index/langchain_helpers/chain_wrapper.py b/gpt_index/langchain_helpers/chain_wrapper.py
index 390dff4c632cbe7e054f2cb1cbdea571017cc893..12e9bdfe8581663dc4480783f915ad06a0906475 100644
--- a/gpt_index/langchain_helpers/chain_wrapper.py
+++ b/gpt_index/langchain_helpers/chain_wrapper.py
@@ -6,6 +6,7 @@ from langchain import LLMChain, OpenAI
 from langchain.llms.base import LLM
 
 from gpt_index.prompts.base import Prompt
+from gpt_index.utils import globals_helper
 
 
 class LLMPredictor:
@@ -26,6 +27,7 @@
     def __init__(self, llm: Optional[LLM] = None) -> None:
         """Initialize params."""
         self._llm = llm or OpenAI(temperature=0, model_name="text-davinci-002")
+        self._total_tokens_used = 0
 
     def predict(self, prompt: Prompt, **prompt_args: Any) -> Tuple[str, str]:
         """Predict the answer to a query.
@@ -39,6 +41,24 @@
 
         """
         llm_chain = LLMChain(prompt=prompt, llm=self._llm)
+        # Note: we don't pass formatted_prompt to llm_chain.predict because
+        # langchain does the same formatting under the hood
         formatted_prompt = prompt.format(**prompt_args)
         full_prompt_args = prompt.get_full_format_args(prompt_args)
-        return llm_chain.predict(**full_prompt_args), formatted_prompt
+        llm_prediction = llm_chain.predict(**full_prompt_args)
+
+        # We assume that the value of formatted_prompt is exactly what is
+        # eventually sent to OpenAI, or whatever LLM is used downstream
+        prompt_tokens_count = self._count_tokens(formatted_prompt)
+        prediction_tokens_count = self._count_tokens(llm_prediction)
+        self._total_tokens_used += prompt_tokens_count + prediction_tokens_count
+        return llm_prediction, formatted_prompt
+
+    @property
+    def total_tokens_used(self) -> int:
+        """Get the total tokens used so far."""
+        return self._total_tokens_used
+
+    def _count_tokens(self, text: str) -> int:
+        tokens = globals_helper.tokenizer(text)
+        return len(tokens)
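`LLMPredictor.predict` now charges both directions of every LLM exchange against a running counter, under the stated assumption that `formatted_prompt` is exactly what is sent downstream. A standalone sketch of that accounting, with a whitespace tokenizer standing in for the real one (all values illustrative):

```python
from typing import Callable, List

tokenizer: Callable[[str], List] = lambda text: text.split(" ")

total_tokens_used = 0
formatted_prompt = "Context information is below. What is?"  # illustrative prompt
llm_prediction = "hello world"  # illustrative completion

# One predict() call adds prompt tokens plus completion tokens:
total_tokens_used += len(tokenizer(formatted_prompt)) + len(tokenizer(llm_prediction))
assert total_tokens_used == 8  # 6 prompt tokens + 2 completion tokens
```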
diff --git a/gpt_index/langchain_helpers/text_splitter.py b/gpt_index/langchain_helpers/text_splitter.py
index ee41e4ac4b1a9fd504fac22b36509ece864700fa..746b9579b37edc480f5ee24d03e4272199d1d72f 100644
--- a/gpt_index/langchain_helpers/text_splitter.py
+++ b/gpt_index/langchain_helpers/text_splitter.py
@@ -39,13 +39,20 @@ class TokenTextSplitter(TextSplitter):
         current_doc: List[str] = []
         total = 0
         for d in splits:
-            num_tokens = len(self.tokenizer(d)["input_ids"])
+            num_tokens = len(self.tokenizer(d))
+            # If the total tokens in current_doc exceeds the chunk size:
+            # 1. Update the docs list
             if total + num_tokens > self._chunk_size:
                 docs.append(self._separator.join(current_doc))
+                # 2. Shrink the current_doc (from the front) until it gets smaller
+                # than the overlap size
                 while total > self._chunk_overlap:
                     cur_tokens = self.tokenizer(current_doc[0])
-                    total -= len(cur_tokens["input_ids"])
+                    total -= len(cur_tokens)
                     current_doc = current_doc[1:]
+                # 3. From here we can continue to build up the current_doc again
+            # Build up the current_doc with term d, and update the total counter
+            # with the number of tokens in d, wrt self.tokenizer
             current_doc.append(d)
             total += num_tokens
         docs.append(self._separator.join(current_doc))
@@ -62,7 +69,7 @@ class TokenTextSplitter(TextSplitter):
         current_doc: List[str] = []
         total = 0
         for d in splits:
-            num_tokens = len(self.tokenizer(d)["input_ids"])
+            num_tokens = len(self.tokenizer(d))
             if total + num_tokens > self._chunk_size:
                 break
             current_doc.append(d)
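The splitter keeps a sliding window: once a chunk would exceed `chunk_size` it flushes the chunk, then drops terms from the front until at most `chunk_overlap` tokens carry over, so consecutive chunks share an overlapping span. A self-contained sketch of the same loop over whitespace tokens (`split` here is illustrative, not the library function):

```python
from typing import List


def split(splits: List[str], chunk_size: int, chunk_overlap: int) -> List[str]:
    """Sliding-window splitter mirroring TokenTextSplitter.split_text."""
    docs: List[str] = []
    current: List[str] = []
    total = 0
    for d in splits:
        num_tokens = len(d.split(" "))
        if total + num_tokens > chunk_size:
            docs.append(" ".join(current))
            # drop terms from the front until the carried-over tail fits the overlap
            while total > chunk_overlap:
                total -= len(current[0].split(" "))
                current = current[1:]
        current.append(d)
        total += num_tokens
    docs.append(" ".join(current))
    return docs


print(split(["a b", "c d", "e f", "g h"], chunk_size=4, chunk_overlap=2))
# ['a b c d', 'c d e f', 'e f g h'] -- each chunk shares 2 tokens with the next
```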
diff --git a/gpt_index/utils.py b/gpt_index/utils.py
index dfd084745642d4af1c18bae430bf705e37b89287..c526d0154cae1a8214ccc8db649e496581638d7e 100644
--- a/gpt_index/utils.py
+++ b/gpt_index/utils.py
@@ -1,7 +1,8 @@
 """General utils functions."""
 
+import sys
 import uuid
-from typing import List, Optional, Set
+from typing import Any, Callable, List, Optional, Set
 
 import nltk
 from transformers import GPT2TokenizerFast
@@ -15,14 +16,32 @@ class GlobalsHelper:
 
     """
 
-    _tokenizer: Optional[GPT2TokenizerFast] = None
+    _tokenizer: Optional[Callable[[str], List]] = None
     _stopwords: Optional[List[str]] = None
 
     @property
-    def tokenizer(self) -> GPT2TokenizerFast:
+    def tokenizer(self) -> Callable[[str], List]:
         """Get tokenizer."""
         if self._tokenizer is None:
-            self._tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+            # if python version >= 3.9, then use tiktoken
+            # else use GPT2TokenizerFast
+            if sys.version_info >= (3, 9):
+                tiktoken_import_err = (
+                    "`tiktoken` package not found, please run `pip install tiktoken`"
+                )
+                try:
+                    import tiktoken
+                except ImportError:
+                    raise ValueError(tiktoken_import_err)
+                enc = tiktoken.get_encoding("gpt2")
+                self._tokenizer = enc.encode
+            else:
+                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+
+                def tokenizer_fn(text: str) -> List:
+                    return tokenizer(text)["input_ids"]
+
+                self._tokenizer = tokenizer_fn
         return self._tokenizer
 
     @property
@@ -50,3 +69,41 @@ def get_new_id(d: Set) -> str:
         if new_id not in d:
             break
     return new_id
+
+
+def llm_token_counter(method_name_str: str) -> Callable:
+    """
+    Use this as a decorator for methods in index/query classes that make calls to LLMs.
+
+    At the moment, this decorator can only be used on class instance methods with a
+    `_llm_predictor` attribute.
+
+    Do not use this on abstract methods.
+
+    For example, consider the class below:
+    .. code-block:: python
+        class GPTTreeIndexBuilder:
+        ...
+        @llm_token_counter("build_from_text")
+        def build_from_text(self, documents: Sequence[BaseDocument]) -> IndexGraph:
+            ...
+
+    If you run `build_from_text()`, it will print the output in the form below:
+
+    ```
+    [build_from_text] Total token usage: <some-number> tokens
+    ```
+
+    """
+
+    def wrap(f: Callable) -> Callable:
+        def wrapped_llm_predict(_self: Any, *args: Any, **kwargs: Any) -> Any:
+            start_token_ct = _self._llm_predictor.total_tokens_used
+            f_return_val = f(_self, *args, **kwargs)
+            net_tokens = _self._llm_predictor.total_tokens_used - start_token_ct
+            print(f"> [{method_name_str}] Total token usage: {net_tokens} tokens")
+
+            return f_return_val
+
+        return wrapped_llm_predict
+
+    return wrap
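On Python >= 3.9 the tokenizer is tiktoken's `encode` for the `gpt2` encoding; on older interpreters it falls back to `GPT2TokenizerFast`, wrapped so both expose the same `(str) -> List` shape. A quick check of the tiktoken branch (assumes Python >= 3.9 with `tiktoken` installed; the 4-token count matches `tests/test_utils.py` below):

```python
# Mirrors the tiktoken branch of GlobalsHelper.tokenizer above.
import tiktoken

enc = tiktoken.get_encoding("gpt2")
print(len(enc.encode("hello world foo bar")))  # 4 -- same count as the GPT2TokenizerFast branch
```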
diff --git a/setup.py b/setup.py
index 940c404e9c04a1b791a1ea3fadfdfbc5ed730305..9f90886147a9b137dac4237333b1f8510e77380e 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,5 @@
 """Set up the package."""
+import sys
 from pathlib import Path
 
 from setuptools import find_packages, setup
@@ -9,23 +10,29 @@ with open(Path(__file__).absolute().parents[0] / "gpt_index" / "VERSION") as _f:
 with open("README.md", "r") as f:
     long_description = f.read()
 
+install_requires = [
+    "langchain",
+    "openai",
+    "dataclasses_json",
+    "transformers",
+    "nltk",
+    # for openAI embeddings
+    "matplotlib",
+    "plotly",
+    "scipy",
+    "scikit-learn",
+]
+
+# NOTE: if python version >= 3.9, install tiktoken
+if sys.version_info >= (3, 9):
+    install_requires.extend(["tiktoken"])
+
 setup(
     name="gpt_index",
     version=__version__,
     packages=find_packages(),
     description="Building an index of GPT summaries.",
-    install_requires=[
-        "langchain",
-        "openai",
-        "dataclasses_json",
-        "transformers",
-        "nltk",
-        # for openAI embeddings
-        "matplotlib",
-        "plotly",
-        "scipy",
-        "scikit-learn",
-    ],
+    install_requires=install_requires,
     long_description=long_description,
     license="MIT",
     url="https://github.com/jerryjliu/gpt_index",
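The `sys.version_info` check runs at build/install time on the installing machine. An alternative sketch (not what this diff does) is a PEP 508 environment marker, which records the same condition in the package metadata instead of in setup.py logic:

```python
# Hypothetical alternative: express the version gate declaratively.
install_requires = [
    "langchain",
    "openai",
    'tiktoken; python_version >= "3.9"',  # PEP 508 environment marker
]
```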
diff --git a/tests/indices/embedding/test_base.py b/tests/indices/embedding/test_base.py
index 2a4e6cf7ea6b70d2c2f2eda3960564aef733bcb7..fb51de0eff019701c521961b78d1d61468f3406d 100644
--- a/tests/indices/embedding/test_base.py
+++ b/tests/indices/embedding/test_base.py
@@ -10,10 +10,13 @@ from gpt_index.embeddings.openai import OpenAIEmbedding
 from gpt_index.indices.data_structs import Node
 from gpt_index.indices.query.tree.embedding_query import GPTTreeIndexEmbeddingQuery
 from gpt_index.indices.tree.base import GPTTreeIndex
-from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
+from gpt_index.langchain_helpers.chain_wrapper import LLMChain, LLMPredictor
 from gpt_index.langchain_helpers.text_splitter import TokenTextSplitter
 from gpt_index.schema import Document
-from tests.mock_utils.mock_predict import mock_openai_llm_predict
+from tests.mock_utils.mock_predict import (
+    mock_llmchain_predict,
+    mock_llmpredictor_predict,
+)
 from tests.mock_utils.mock_prompts import (
     MOCK_INSERT_PROMPT,
     MOCK_QUERY_PROMPT,
@@ -80,8 +83,9 @@ def _get_node_text_embedding_similarities(
 
 
 @patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline)
+@patch.object(LLMPredictor, "total_tokens_used", return_value=0)
 @patch.object(LLMPredictor, "__init__", return_value=None)
-@patch.object(LLMPredictor, "predict", side_effect=mock_openai_llm_predict)
+@patch.object(LLMPredictor, "predict", side_effect=mock_llmpredictor_predict)
 @patch.object(
     GPTTreeIndexEmbeddingQuery,
     "_get_query_text_embedding_similarities",
@@ -91,6 +95,7 @@ def test_embedding_query(
     _mock_similarity: Any,
     _mock_predict: Any,
     _mock_init: Any,
+    _mock_total_tokens_used: Any,
     _mock_split_text: Any,
     struct_kwargs: Dict,
     documents: List[Document],
@@ -103,3 +108,43 @@
     query_str = "What is?"
     response = tree.query(query_str, mode="embedding", **query_kwargs)
     assert response == ("What is?:Hello world.")
+
+
+@patch.object(LLMChain, "predict", side_effect=mock_llmchain_predict)
+@patch("gpt_index.langchain_helpers.chain_wrapper.OpenAI")
+@patch.object(LLMChain, "__init__", return_value=None)
+@patch.object(
+    GPTTreeIndexEmbeddingQuery,
+    "_get_query_text_embedding_similarities",
+    side_effect=_get_node_text_embedding_similarities,
+)
+def test_query_and_count_tokens(
+    _mock_similarity: Any,
+    _mock_llmchain: Any,
+    _mock_init: Any,
+    _mock_predict: Any,
+    struct_kwargs: Dict,
+    documents: List[Document],
+) -> None:
+    """Test query and count tokens."""
+    index_kwargs, query_kwargs = struct_kwargs
+    # mock_prompts.MOCK_SUMMARY_PROMPT_TMPL adds a "\n" to the document text
+    document_token_count = 24
+    llmchain_mock_resp_token_count = 10
+    # build the tree
+    tree = GPTTreeIndex(documents, **index_kwargs)
+    assert (
+        tree._llm_predictor.total_tokens_used
+        == document_token_count + llmchain_mock_resp_token_count
+    )
+
+    # test embedding query
+    start_token_ct = tree._llm_predictor.total_tokens_used
+    query_str = "What is?"
+    # From MOCK_TEXT_QA_PROMPT, the prompt is 28 tokens in total
+    query_prompt_token_count = 28
+    tree.query(query_str, mode="embedding", **query_kwargs)
+    assert (
+        tree._llm_predictor.total_tokens_used - start_token_ct
+        == query_prompt_token_count + llmchain_mock_resp_token_count
+    )
diff --git a/tests/indices/keyword_table/test_base.py b/tests/indices/keyword_table/test_base.py
index ced54b2714149a4c6c308d2d7309761bd1af089b..7ad6f031bb04a9825686dcd6eb48574a572bda40 100644
--- a/tests/indices/keyword_table/test_base.py
+++ b/tests/indices/keyword_table/test_base.py
@@ -27,13 +27,17 @@ def documents() -> List[Document]:
 
 
 @patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline)
+@patch.object(LLMPredictor, "total_tokens_used", return_value=0)
 @patch.object(LLMPredictor, "__init__", return_value=None)
 @patch(
     "gpt_index.indices.keyword_table.simple_base.simple_extract_keywords",
     mock_extract_keywords,
 )
 def test_build_table(
-    _mock_init: Any, _mock_predict: Any, documents: List[Document]
+    _mock_init: Any,
+    _mock_total_tokens_used: Any,
+    _mock_split_text: Any,
+    documents: List[Document],
 ) -> None:
     """Test build table."""
     # test simple keyword table
@@ -65,8 +69,14 @@
     "gpt_index.indices.keyword_table.simple_base.simple_extract_keywords",
     mock_extract_keywords,
 )
+@patch.object(LLMPredictor, "total_tokens_used", return_value=0)
 @patch.object(LLMPredictor, "__init__", return_value=None)
-def test_insert(_mock_init: Any, _mock_predict: Any, documents: List[Document]) -> None:
+def test_insert(
+    _mock_init: Any,
+    _mock_total_tokens_used: Any,
+    _mock_split_text: Any,
+    documents: List[Document],
+) -> None:
     """Test insert."""
     table = GPTSimpleKeywordTableIndex([])
     assert len(table.index_struct.table.keys()) == 0
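A note on the reshuffled test signatures above and below: stacked `@patch` decorators apply bottom-up, and the mock arguments are passed in that same bottom-up order, which is why each new `_mock_total_tokens_used` parameter sits exactly where its `@patch` line sits counted from the bottom. A minimal demonstration of that `unittest.mock` behavior:

```python
from typing import Any
from unittest.mock import patch


class Thing:
    def first(self) -> None:
        ...

    def second(self) -> None:
        ...


@patch.object(Thing, "second")  # applied last -> arrives as the second argument
@patch.object(Thing, "first")  # applied first -> arrives as the first argument
def check(mock_first: Any, mock_second: Any) -> None:
    assert mock_first is not mock_second


check()
```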
"total_tokens_used", return_value=0) @patch.object(LLMPredictor, "__init__", return_value=None) def test_recursive_query_list_tree( _mock_init: Any, + _mock_total_tokens_used: Any, _mock_predict: Any, _mock_split_text: Any, documents: List[Document], @@ -137,10 +142,12 @@ def test_recursive_query_list_tree( @patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline) -@patch.object(LLMPredictor, "predict", side_effect=mock_openai_llm_predict) +@patch.object(LLMPredictor, "predict", side_effect=mock_llmpredictor_predict) +@patch.object(LLMPredictor, "total_tokens_used", return_value=0) @patch.object(LLMPredictor, "__init__", return_value=None) def test_recursive_query_tree_list( _mock_init: Any, + _mock_total_tokens_used: Any, _mock_predict: Any, _mock_split_text: Any, documents: List[Document], @@ -175,10 +182,12 @@ def test_recursive_query_tree_list( @patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline) -@patch.object(LLMPredictor, "predict", side_effect=mock_openai_llm_predict) +@patch.object(LLMPredictor, "predict", side_effect=mock_llmpredictor_predict) +@patch.object(LLMPredictor, "total_tokens_used", return_value=0) @patch.object(LLMPredictor, "__init__", return_value=None) def test_recursive_query_table_list( _mock_init: Any, + _mock_total_tokens_used: Any, _mock_predict: Any, _mock_split_text: Any, documents: List[Document], @@ -210,10 +219,12 @@ def test_recursive_query_table_list( @patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline) -@patch.object(LLMPredictor, "predict", side_effect=mock_openai_llm_predict) +@patch.object(LLMPredictor, "predict", side_effect=mock_llmpredictor_predict) +@patch.object(LLMPredictor, "total_tokens_used", return_value=0) @patch.object(LLMPredictor, "__init__", return_value=None) def test_recursive_query_list_table( _mock_init: Any, + _mock_total_tokens_used: Any, _mock_predict: Any, _mock_split_text: Any, documents: List[Document], @@ -245,3 +256,53 @@ def test_recursive_query_list_table( query_str = "Cat?" 
diff --git a/tests/indices/test_prompt_helper.py b/tests/indices/test_prompt_helper.py
index c6fd45c394b9970f7577283312be797a13f426e1..f9b7fb963ca1641b1939352e2de992dc095a56ea 100644
--- a/tests/indices/test_prompt_helper.py
+++ b/tests/indices/test_prompt_helper.py
@@ -1,15 +1,15 @@
 """Test PromptHelper."""
-from typing import Dict, List
+from typing import List
 
 from gpt_index.indices.data_structs import Node
 from gpt_index.indices.prompt_helper import PromptHelper
 from gpt_index.prompts.base import Prompt
 
 
-def mock_tokenizer(text: str) -> Dict[str, List[str]]:
+def mock_tokenizer(text: str) -> List[str]:
     """Mock tokenizer."""
     tokens = text.split(" ")
-    return {"input_ids": tokens}
+    return tokens
 
 
 def test_get_chunk_size() -> None:
diff --git a/tests/indices/tree/test_base.py b/tests/indices/tree/test_base.py
index eaddd34cbc354c9c4d8ec5c0148043c8f3ebe637..a00201482b91d89129ae4cda9241c477b781b0cc 100644
--- a/tests/indices/tree/test_base.py
+++ b/tests/indices/tree/test_base.py
@@ -7,10 +7,13 @@ import pytest
 
 from gpt_index.indices.data_structs import IndexGraph, Node
 from gpt_index.indices.tree.base import GPTTreeIndex
-from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
+from gpt_index.langchain_helpers.chain_wrapper import LLMChain, LLMPredictor
 from gpt_index.langchain_helpers.text_splitter import TokenTextSplitter
 from gpt_index.schema import Document
-from tests.mock_utils.mock_predict import mock_openai_llm_predict
+from tests.mock_utils.mock_predict import (
+    mock_llmchain_predict,
+    mock_llmpredictor_predict,
+)
 from tests.mock_utils.mock_prompts import (
     MOCK_INSERT_PROMPT,
     MOCK_QUERY_PROMPT,
@@ -67,11 +70,13 @@ def _get_left_or_right_node(
 
 
 @patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline)
-@patch.object(LLMPredictor, "predict", side_effect=mock_openai_llm_predict)
+@patch.object(LLMPredictor, "total_tokens_used", return_value=0)
+@patch.object(LLMPredictor, "predict", side_effect=mock_llmpredictor_predict)
 @patch.object(LLMPredictor, "__init__", return_value=None)
 def test_build_tree(
     _mock_init: Any,
     _mock_predict: Any,
+    _mock_total_tokens_used: Any,
     _mock_split_text: Any,
     documents: List[Document],
     struct_kwargs: Dict,
@@ -92,11 +97,13 @@
 
 
 @patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline)
-@patch.object(LLMPredictor, "predict", side_effect=mock_openai_llm_predict)
+@patch.object(LLMPredictor, "total_tokens_used", return_value=0)
+@patch.object(LLMPredictor, "predict", side_effect=mock_llmpredictor_predict)
 @patch.object(LLMPredictor, "__init__", return_value=None)
 def test_build_tree_multiple(
     _mock_init: Any,
     _mock_predict: Any,
+    _mock_total_tokens_used: Any,
     _mock_split_text: Any,
     documents: List[Document],
     struct_kwargs: Dict,
@@ -117,11 +124,13 @@
 
 
 @patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline)
-@patch.object(LLMPredictor, "predict", side_effect=mock_openai_llm_predict)
+@patch.object(LLMPredictor, "total_tokens_used", return_value=0)
+@patch.object(LLMPredictor, "predict", side_effect=mock_llmpredictor_predict)
 @patch.object(LLMPredictor, "__init__", return_value=None)
 def test_query(
     _mock_init: Any,
     _mock_predict: Any,
+    _mock_total_tokens_used: Any,
     _mock_split_text: Any,
     documents: List[Document],
     struct_kwargs: Dict,
@@ -137,11 +146,13 @@
 
 
 @patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline)
-@patch.object(LLMPredictor, "predict", side_effect=mock_openai_llm_predict)
+@patch.object(LLMPredictor, "total_tokens_used", return_value=0)
+@patch.object(LLMPredictor, "predict", side_effect=mock_llmpredictor_predict)
 @patch.object(LLMPredictor, "__init__", return_value=None)
 def test_insert(
     _mock_init: Any,
     _mock_predict: Any,
+    _mock_total_tokens_used: Any,
     _mock_split_text: Any,
     documents: List[Document],
     struct_kwargs: Dict,
@@ -188,3 +199,26 @@
     assert len(tree.index_struct.all_nodes) == 1
     assert tree.index_struct.all_nodes[0].text == "This is a new doc."
     assert tree.index_struct.all_nodes[0].ref_doc_id == "new_doc_test"
+
+
+@patch.object(LLMChain, "predict", side_effect=mock_llmchain_predict)
+@patch("gpt_index.langchain_helpers.chain_wrapper.OpenAI")
+@patch.object(LLMChain, "__init__", return_value=None)
+def test_build_and_count_tokens(
+    _mock_init: Any,
+    _mock_llmchain: Any,
+    _mock_predict: Any,
+    documents: List[Document],
+    struct_kwargs: Dict,
+) -> None:
+    """Test build and count tokens."""
+    index_kwargs, _ = struct_kwargs
+    # mock_prompts.MOCK_SUMMARY_PROMPT_TMPL adds a "\n" to the document text
+    # and the document is 23 tokens
+    document_token_count = 24
+    llmchain_mock_resp_token_count = 10
+    tree = GPTTreeIndex(documents, **index_kwargs)
+    assert (
+        tree._llm_predictor.total_tokens_used
+        == document_token_count + llmchain_mock_resp_token_count
+    )
diff --git a/tests/mock_utils/mock_predict.py b/tests/mock_utils/mock_predict.py
index c1498273e93902f713744e038d44a74f4fe5d882..56a4e0a07c55494b083a067cbd11b9e88b583496 100644
--- a/tests/mock_utils/mock_predict.py
+++ b/tests/mock_utils/mock_predict.py
@@ -60,8 +60,8 @@ def _mock_query_keyword_extract(prompt_args: Dict) -> str:
     return mock_extract_keywords_response(prompt_args["question"])
 
 
-def mock_openai_llm_predict(prompt: Prompt, **prompt_args: Any) -> Tuple[str, str]:
-    """Mock OpenAI LLM predict.
+def mock_llmpredictor_predict(prompt: Prompt, **prompt_args: Any) -> Tuple[str, str]:
+    """Mock predict method of LLMPredictor.
 
     Depending on the prompt, return response.
 
@@ -85,3 +85,8 @@ def mock_openai_llm_predict(prompt: Prompt, **prompt_args: Any) -> Tuple[str, st
         raise ValueError("Invalid prompt to use with mocks.")
 
     return response, formatted_prompt
+
+
+def mock_llmchain_predict(**full_prompt_args: Any) -> str:
+    """Mock LLMChain predict with a generic response."""
+    return "generic response from LLMChain.predict()"
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2396d1c2c8f4ecd29085d4ddefbfdfcc30481231
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,14 @@
+"""Test utils."""
+
+from gpt_index.utils import globals_helper
+
+
+def test_tokenizer() -> None:
+    """Make sure tokenizer works.
+
+    NOTE: we use a different tokenizer for python >= 3.9.
+
+    """
+    text = "hello world foo bar"
+    tokenizer = globals_helper.tokenizer
+    assert len(tokenizer(text)) == 4
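Taken together, the diff gives every index build, insert, and query a token-usage printout plus a running `total_tokens_used` counter on the predictor. An end-to-end sketch of what a user would see, assuming the README-style usage, a `data/` directory of documents, and an OpenAI API key (exact counts vary with the inputs):

```python
from gpt_index import GPTTreeIndex, SimpleDirectoryReader

documents = SimpleDirectoryReader("data").load_data()
index = GPTTreeIndex(documents)
# prints: > [build_index_from_documents] Total token usage: <n> tokens
index.query("<question_text>?")
# prints: > [query] Total token usage: <n> tokens
print(index._llm_predictor.total_tokens_used)  # running total across both calls
```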