diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..0fb583996cdc52a8429eff37f8fb6002a96f2e11 --- /dev/null +++ b/.flake8 @@ -0,0 +1,12 @@ +[flake8] +exclude = + .venv + __pycache__ + notebooks + .ipynb_checkpoints +# Recommend matching the black line length (default 88), +# rather than using the flake8 default of 79: +max-line-length = 88 +extend-ignore = + # See https://github.com/PyCQA/pycodestyle/issues/373 + E203, diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..17d753d6a0bbc64eb30c7a71beeeefa9bc6d37ba --- /dev/null +++ b/.gitignore @@ -0,0 +1,131 @@ +.vscode/ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..346be9bb3350e54aba5e148b28345c525e1df2d0 --- /dev/null +++ b/Makefile @@ -0,0 +1,11 @@ +.PHONY: format lint + +format: + black . + isort . + +lint: + mypy . + black . --check + isort . --check + flake8 . \ No newline at end of file diff --git a/gpt_index/__init__.py b/gpt_index/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..c7f52fec30622c48bcb140c269b96c7b8231b1fe 100644 --- a/gpt_index/__init__.py +++ b/gpt_index/__init__.py @@ -0,0 +1 @@ +"""Init file of GPT Index.""" diff --git a/gpt_index/file_reader.py b/gpt_index/file_reader.py index ca369de980352b0ca63b435d85c3c749c8e4724c..23e303af8d183013b38d9ab82ade4e907501caad 100644 --- a/gpt_index/file_reader.py +++ b/gpt_index/file_reader.py @@ -5,7 +5,9 @@ from pathlib import Path class SimpleDirectoryReader: """Utilities for loading data from a directory.""" + def __init__(self, input_dir: Path) -> None: + """Initialize with parameters.""" self.input_dir = input_dir input_files = list(input_dir.iterdir()) for input_file in input_files: @@ -14,10 +16,10 @@ class SimpleDirectoryReader: self.input_files = input_files def load_data(self) -> str: - """Loads data from the input directory.""" + """Load data from the input directory.""" data = "" for input_file in self.input_files: with open(input_file, "r") as f: data += f.read() data += "\n" - return data \ No newline at end of file + return data diff --git a/gpt_index/index.py b/gpt_index/index.py index 5ce559655f664f5a8b8850404bc3f29c355746c4..5553244a434ebe515c0e76ce2cddc56b05e3b09d 100644 --- a/gpt_index/index.py +++ b/gpt_index/index.py @@ -1,44 +1,28 @@ """Core abstractions for building an index of GPT data.""" +import json from dataclasses import dataclass -from dataclasses_json import DataClassJsonMixin, Undefined, dataclass_json from pathlib import Path -from gpt_index.file_reader import SimpleDirectoryReader -from langchain.text_splitter import CharacterTextSplitter -from langchain import OpenAI, Prompt, LLMChain -from gpt_index.prompts import DEFAULT_SUMMARY_PROMPT, DEFAULT_QUERY_PROMPT, DEFAULT_TEXT_QA_PROMPT -from gpt_index.utils import get_chunk_size_given_prompt, extract_number_given_response -from gpt_index.text_splitter import TokenTextSplitter +from typing import Dict, List -from typing import List, Dict, Set -import json +from dataclasses_json import DataClassJsonMixin +from langchain import LLMChain, OpenAI, Prompt +from gpt_index.file_reader import SimpleDirectoryReader +from gpt_index.prompts import ( + DEFAULT_QUERY_PROMPT, + DEFAULT_SUMMARY_PROMPT, + DEFAULT_TEXT_QA_PROMPT, +) +from gpt_index.schema import IndexGraph, Node +from gpt_index.text_splitter import TokenTextSplitter +from gpt_index.utils import extract_number_given_response, get_chunk_size_given_prompt MAX_CHUNK_SIZE = 3900 MAX_CHUNK_OVERLAP = 200 NUM_OUTPUTS = 256 - -@dataclass -class Node(DataClassJsonMixin): - """A node in the GPT index.""" - - text: str - index: int - child_indices: Set[int] - - -@dataclass -class IndexGraph(DataClassJsonMixin): - all_nodes: Dict[int, Node] - root_nodes: Dict[int, Node] - - @property - def size(self): - return len(self.all_nodes) - - def _get_sorted_node_list(node_dict: Dict[int, Node]) -> List[Node]: sorted_indices = sorted(node_dict.keys()) return [node_dict[index] for index in sorted_indices] @@ -68,9 +52,7 @@ class GPTIndexBuilder: """GPT Index builder.""" def __init__( - self, - num_children: int = 10, - summary_prompt: str = DEFAULT_SUMMARY_PROMPT + self, num_children: int = 10, summary_prompt: str = DEFAULT_SUMMARY_PROMPT ) -> None: """Initialize with params.""" self.num_children = num_children @@ -83,8 +65,8 @@ class GPTIndexBuilder: ) self.text_splitter = TokenTextSplitter( separator=" ", - chunk_size=chunk_size, - chunk_overlap=MAX_CHUNK_OVERLAP // num_children + chunk_size=chunk_size, + chunk_overlap=MAX_CHUNK_OVERLAP // num_children, ) def build_from_text(self, text: str) -> IndexGraph: @@ -92,7 +74,7 @@ class GPTIndexBuilder: Returns: IndexGraph: graph object consisting of all_nodes, root_nodes - + """ text_chunks = self.text_splitter.split_text(text) @@ -108,18 +90,18 @@ class GPTIndexBuilder: cur_node_list = _get_sorted_node_list(cur_nodes) cur_index = len(all_nodes) new_node_dict = {} - print(f'> Building index from nodes: {len(cur_nodes) // self.num_children} chunks') + print( + f"> Building index from nodes: {len(cur_nodes) // self.num_children} chunks" + ) for i in range(0, len(cur_node_list), self.num_children): - print(f'{i}/{len(cur_nodes)}') - cur_nodes_chunk = cur_node_list[i:i+self.num_children] + print(f"{i}/{len(cur_nodes)}") + cur_nodes_chunk = cur_node_list[i : i + self.num_children] text_chunk = _get_text_from_nodes(cur_nodes_chunk) new_summary = self.llm_chain.predict(text=text_chunk) - print(f'> {i}/{len(cur_nodes)}, summary: {new_summary}') - new_node = Node( - new_summary, cur_index, {n.index for n in cur_nodes_chunk} - ) + print(f"> {i}/{len(cur_nodes)}, summary: {new_summary}") + new_node = Node(new_summary, cur_index, {n.index for n in cur_nodes_chunk}) new_node_dict[cur_index] = new_node cur_index += 1 @@ -139,89 +121,87 @@ class GPTIndex(DataClassJsonMixin): query_template: str = DEFAULT_QUERY_PROMPT text_qa_template: str = DEFAULT_TEXT_QA_PROMPT - def _query(self, cur_nodes: Dict[int, Node], query_str: str, verbose: bool = False) -> str: + def _query( + self, cur_nodes: Dict[int, Node], query_str: str, verbose: bool = False + ) -> str: """Answer a query recursively.""" cur_node_list = _get_sorted_node_list(cur_nodes) query_prompt = Prompt( - template=self.query_template, - input_variables=["num_chunks", "context_list", "query_str"] + template=self.query_template, + input_variables=["num_chunks", "context_list", "query_str"], ) llm = OpenAI(temperature=0) - llm_chain = LLMChain(prompt=query_prompt, llm=llm) + llm_chain = LLMChain(prompt=query_prompt, llm=llm) response = llm_chain.predict( query_str=query_str, - num_chunks=len(cur_node_list), - context_list=_get_numbered_text_from_nodes(cur_node_list) + num_chunks=len(cur_node_list), + context_list=_get_numbered_text_from_nodes(cur_node_list), ) - + if verbose: formatted_query = self.query_template.format( num_chunks=len(cur_node_list), query_str=query_str, - context_list=_get_numbered_text_from_nodes(cur_node_list) + context_list=_get_numbered_text_from_nodes(cur_node_list), ) - print(f'==============') - print(f'> current prompt template: {formatted_query}') + print("==============") + print(f"> current prompt template: {formatted_query}") number = extract_number_given_response(response) if number is None: if verbose: - print(f"> Could not retrieve response - no numbers present") + print("> Could not retrieve response - no numbers present") # just join text from current nodes as response return response elif number > len(cur_node_list): if verbose: - print(f'> Invalid response: {response} - number {number} out of range') + print(f"> Invalid response: {response} - number {number} out of range") return response # number is 1-indexed, so subtract 1 - selected_node = cur_node_list[number-1] + selected_node = cur_node_list[number - 1] print(f"> Selected node: {response}") print(f"> Node Summary text: {' '.join(selected_node.text.splitlines())}") if len(selected_node.child_indices) == 0: answer_prompt = Prompt( - template=self.text_qa_template, - input_variables=["context_str", "query_str"] + template=self.text_qa_template, + input_variables=["context_str", "query_str"], ) - llm_chain = LLMChain(prompt=answer_prompt, llm=llm) + llm_chain = LLMChain(prompt=answer_prompt, llm=llm) response = llm_chain.predict( - context_str=selected_node.text, - query_str=query_str + context_str=selected_node.text, query_str=query_str ) if verbose: formatted_answer_prompt = self.text_qa_template.format( - context_str=selected_node.text, - query_str=query_str + context_str=selected_node.text, query_str=query_str ) - print('==============') - print(f'> final answer prompt: {formatted_answer_prompt}') + print("==============") + print(f"> final answer prompt: {formatted_answer_prompt}") return response else: return self._query( - {i: self.graph.all_nodes[i] for i in selected_node.child_indices}, + {i: self.graph.all_nodes[i] for i in selected_node.child_indices}, query_str, - verbose=verbose + verbose=verbose, ) def query(self, query_str: str, verbose: bool = False) -> str: """Answer a query.""" - print(f'> Starting query: {query_str}') + print(f"> Starting query: {query_str}") return self._query(self.graph.root_nodes, query_str, verbose=verbose).strip() - + @classmethod def from_input_dir( - cls, - input_dir: str, - index_builder: GPTIndexBuilder = GPTIndexBuilder() + cls, input_dir: str, index_builder: GPTIndexBuilder = GPTIndexBuilder() ) -> "GPTIndex": - """Builds an index from an input directory. + """Build an index from an input directory. Uses the default index builder. - + """ - input_dir = Path(input_dir) + input_d = Path(input_dir) # instantiate file reader - reader = SimpleDirectoryReader(input_dir) + reader = SimpleDirectoryReader(input_d) text_data = reader.load_data() # Use index builder @@ -229,7 +209,7 @@ class GPTIndex(DataClassJsonMixin): return cls(index_graph) @classmethod - def load_from_disk(cls, save_path: str) -> None: + def load_from_disk(cls, save_path: str) -> "GPTIndex": """Load from disk.""" with open(save_path, "r") as f: return cls(graph=IndexGraph.from_dict(json.load(f))) @@ -238,6 +218,3 @@ class GPTIndex(DataClassJsonMixin): """Safe to file.""" with open(save_path, "w") as f: json.dump(self.graph.to_dict(), f) - - - \ No newline at end of file diff --git a/gpt_index/prompts.py b/gpt_index/prompts.py index 4104c9ade69f958db0aa6ec54f5480f534954c30..9cd1fc17a8a326d9523ef059758d00d9bd49d955 100644 --- a/gpt_index/prompts.py +++ b/gpt_index/prompts.py @@ -1,32 +1,35 @@ """Set of default prompts.""" DEFAULT_SUMMARY_PROMPT = ( - "Write a summary of the following. Try to use only the information provided. " + "Write a summary of the following. Try to use only the " + "information provided. " "Try to include as many key details as possible.\n" "\n" "\n" "{text}\n" "\n" "\n" - "SUMMARY:\"\"\"\n" + 'SUMMARY:"""\n' ) # # single choice DEFAULT_QUERY_PROMPT = ( - "Some choices are given below. It is provided in a numbered list (1 to {num_chunks})," + "Some choices are given below. It is provided in a numbered list " + "(1 to {num_chunks})," "where each item in the list corresponds to a summary.\n" "---------------------\n" "{context_list}" "\n---------------------\n" - "Using only the choices above and not prior knowledge, return the choice that " - "is most relevant to the question: '{query_str}'\n" + "Using only the choices above and not prior knowledge, return " + "the choice that is most relevant to the question: '{query_str}'\n" "Provide choice in the following format: 'ANSWER: <number>' and explain why " "this summary was selected in relation to the question.\n" ) # multiple choice # DEFAULT_QUERY_PROMPT = ( -# "Some choices are given below. It is provided in a numbered list (1 to {num_chunks})," +# "Some choices are given below. It is provided in a numbered " +# "list (1 to {num_chunks}), " # "where each item in the list corresponds to a summary.\n" # "---------------------\n" # "{context_list}" @@ -44,5 +47,6 @@ DEFAULT_TEXT_QA_PROMPT = ( "---------------------\n" "{context_str}" "\n---------------------\n" - "Given the context information and not prior knowledge, answer the question: {query_str}\n" -) \ No newline at end of file + "Given the context information and not prior knowledge, " + "answer the question: {query_str}\n" +) diff --git a/gpt_index/schema.py b/gpt_index/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..c44d91bb17485422d08946884abd65b5d98afeef --- /dev/null +++ b/gpt_index/schema.py @@ -0,0 +1,27 @@ +"""Base schema for data structures.""" +from dataclasses import dataclass +from typing import Dict, Set + +from dataclasses_json import DataClassJsonMixin + + +@dataclass +class Node(DataClassJsonMixin): + """A node in the GPT tree index.""" + + text: str + index: int + child_indices: Set[int] + + +@dataclass +class IndexGraph(DataClassJsonMixin): + """A graph representing the tree-structured index.""" + + all_nodes: Dict[int, Node] + root_nodes: Dict[int, Node] + + @property + def size(self) -> int: + """Get the size of the graph.""" + return len(self.all_nodes) diff --git a/gpt_index/text_splitter.py b/gpt_index/text_splitter.py index f3257338dd93c9d91f2513953ea1f8feb1549103..3098863b052dd9a18250f41414c5069dd95cf550 100644 --- a/gpt_index/text_splitter.py +++ b/gpt_index/text_splitter.py @@ -1,6 +1,7 @@ -from langchain.text_splitter import TextSplitter +"""Text splitter implementations.""" from typing import List +from langchain.text_splitter import TextSplitter from transformers import GPT2TokenizerFast @@ -38,7 +39,7 @@ class TokenTextSplitter(TextSplitter): total -= len(cur_tokens["input_ids"]) current_doc = current_doc[1:] current_doc.append(d) - num_tokens = len(self.tokenizer(d)['input_ids']) + num_tokens = len(self.tokenizer(d)["input_ids"]) total += num_tokens docs.append(self._separator.join(current_doc)) - return docs \ No newline at end of file + return docs diff --git a/gpt_index/utils.py b/gpt_index/utils.py index f15251f6bb8de89689a777cf9ef9ba15c765e4e0..6d6e3b16dfc90521c891058ad26297e64cb627eb 100644 --- a/gpt_index/utils.py +++ b/gpt_index/utils.py @@ -1,8 +1,9 @@ """Utils file.""" -from transformers import GPT2TokenizerFast -from typing import Optional import re +from typing import Optional + +from transformers import GPT2TokenizerFast def get_chunk_size_given_prompt( @@ -18,8 +19,8 @@ def get_chunk_size_given_prompt( def extract_number_given_response(response: str) -> Optional[int]: """Extract number given the GPT-generated response.""" - numbers = re.findall(r'\d+', response) + numbers = re.findall(r"\d+", response) if len(numbers) == 0: return None else: - return int(numbers[0]) \ No newline at end of file + return int(numbers[0]) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..619e1edb280ec4eef47408153bc38e022da274e1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[tool.isort] +profile = "black" + +[tool.mypy] +ignore_missing_imports = "True" +disallow_untyped_defs = "True" +exclude = ["notebooks"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ecf975e2fa63a1383916f0c755e0db3e0b3d62b2..a211e87c815a32dc1d0100d5375f0796bc06a3f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,8 @@ --e . \ No newline at end of file +-e . +# linting +black +isort +mypy +# flake8 +# flake8-docstrings +pylint