Commit 2e62c698 authored by Jerry Liu, committed by GitHub

Add initial version of GPTIndex (#1)


* implementation of feature

* add setup files

* added readme and example

* add license

Co-authored-by: Jerry Liu <jerry@robustintelligence.com>
parent 446a7cf1
LICENSE 0 → 100644
The MIT License
Copyright (c) Jerry Liu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
# Testing
README.md 0 → 100644
# GPT Tree Index
A tree-based index of text data, built using GPT-3 and traversed using GPT-3 to answer queries.
## Overview
GPT-3 is a phenomenal piece of technology: at its core, it takes in text input and generates text output. It is a very simple but general paradigm, and GPT-3 (especially its latest iterations) executes it amazingly well. It can perform many tasks in a zero-shot setting, from sentiment analysis to categorization to question answering.
However, one fundamental limitation of GPT-3 is its context size. The most sophisticated model, Davinci, has a combined input+completion limit of 4096 tokens. This is large, but not infinite. As a result, the ability to feed "knowledge" to GPT-3 is restricted to this prompt size plus the model weights; by default the weights encode world knowledge through the training process, but they can also be finetuned for custom tasks (which can be very expensive).
But what if GPT-3 could have access to a potentially much larger database of knowledge for use in, say, question-answering tasks? That's where the **GPT Tree Index** comes in. The GPT Tree Index first takes a large dataset of unprocessed text data as input. It then builds up a tree index in a bottom-up fashion: each parent node summarizes its children using a general **summarization prompt**, so each intermediate node contains summary text describing the components below it. Once the index is built, it can be saved to disk and loaded for future use.
Then, say the user wants to use GPT-3 to answer a question. Using a **query prompt template**, GPT-3 recursively performs tree traversal in a top-down fashion to answer the question. For example, at the very beginning GPT-3 is tasked with selecting, among the *n* top-level nodes, the one that best answers a provided query, by outputting a number as in a multiple-choice problem. The GPT Tree Index then uses that number to select the corresponding node, and the process repeats recursively among the children nodes until a leaf node is reached.
The high-level intent of this project is to be a design exercise testing the capability of GPT-3 as a general-purpose processor. A somewhat handwavy analogy is that a CPU has limited memory of its own but has access to a wider base of stored knowledge (e.g. in RAM, and then on disk) in order to achieve a broader goal. We are making one step in this direction with the GPT Index, by having GPT build its index and traverse its index through repeated processing.
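To make the structure concrete, here is a toy instance of the `Node` and `IndexGraph` dataclasses defined in `gpt_db_retrieve/index.py` below; the summary text is a hand-written stand-in for what GPT-3 would actually generate:
```python
from gpt_db_retrieve.index import Node, IndexGraph

# two leaf nodes holding raw text chunks (indices 0 and 1, no children)
leaves = [
    Node("NYC is served by three major airports: JFK, LaGuardia, and Newark.", 0, []),
    Node("Brooklyn is known for its brownstone rowhouses.", 1, []),
]
# one root node whose text summarizes its children
root = Node("Facts about NYC transit and architecture.", 2, child_indices=[0, 1])
graph = IndexGraph(all_nodes=leaves + [root], root_nodes=[root])
```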
## Example Usage
An example is provided in `examples/test_wiki/TestNYC.ipynb`. To build the index, do something like:
```python
from gpt_db_retrieve.index import GPTIndex
index = GPTIndex.from_input_dir('data')
```
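`from_input_dir` also accepts an `index_builder` argument, so the branching factor or summarization prompt can be swapped out; a minimal sketch using the `GPTIndexBuilder` defined in the source:
```python
from gpt_db_retrieve.index import GPTIndex, GPTIndexBuilder

# smaller branching factor -> deeper tree, fewer summaries per LLM call
builder = GPTIndexBuilder(num_children=5)
index = GPTIndex.from_input_dir('data', index_builder=builder)
```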
To save to disk and load from disk, do
```python
# save to disk
index.save_to_disk('index.json')
# load from disk
index = GPTIndex.load_from_disk('index.json')
```
To query,
```python
index.query("<question_text>?")
```
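The `query` method also takes a `verbose` flag that prints each formatted prompt and GPT's node selection during traversal, which is useful for debugging wrong-branch cases:
```python
index.query("<question_text>?", verbose=True)
```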
## Additional Thoughts / FAQ
**How is this better than an embeddings-based approach / other state-of-the-art QA and retrieval methods?**
The intent is not to compete against existing methods. A simpler embedding-based technique could be to just encode each chunk as an embedding and do a simple question-document embedding look-up to retrieve the result. This project is a simple exercise to test how GPT can organize and lookup information.
**Why build a tree? Why not just incrementally go through each chunk?**
Algorithmically speaking, $O(\log N)$ is better than $O(N)$.
More broadly, building a tree helps us to test GPT's capabilities in modeling information in a hierarchy. It seems to me that our brains organize information in a similar way (citation needed). We can use this design to test how GPT can use its own hierarchy to answer questions.
Practically speaking, it is much cheaper to do so and I want to limit my monthly spending (see below for costs).
**This work is very similar to X paper/project.**
Please let me know! I am not up-to-date on the latest NLP ArXiv papers or Github projects. I will give the appropriate references/credit below.
**Does this actually work?**
Kind of. It works for simple queries, such as the prompt provided for the NYC Wikipedia data above ("What are the three main airports?"). Sometimes it fails in frustrating ways, where the correct node to choose given the query is obvious but GPT still picks another node for some unforeseen reason (for instance, given a query prompt on "What are the main ethnicities within NYC?", GPT-3 somehow picks a node which summarizes the architecture within Brooklyn). Some of this can be fixed with prompt tuning; this is an active area of work!
**How much does this cost to run?**
We currently use the Davinci model for good results. Unfortunately, Davinci is quite expensive. The cost of building the tree is roughly
$cN\log(N)\frac{p}{1000}$, where $p=4096$ is the prompt limit and $c$ is the cost per 1000 tokens (\$0.02, as listed on the [pricing page](https://openai.com/api/pricing/)). The cost of querying the tree is roughly
$c\log(N)\frac{p}{1000}$.
For the NYC example, this equates to roughly \$0.40 per query.
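As a sanity check, a small sketch that plugs numbers into the formulas above (the leaf count `N` here is a made-up placeholder; substitute the chunk count of your own index):
```python
import math

c = 0.02   # Davinci cost per 1000 tokens
p = 4096   # prompt limit in tokens
N = 50     # hypothetical number of leaf chunks

# natural log; the base only changes the constant factor
build_cost = c * N * math.log(N) * p / 1000
query_cost = c * math.log(N) * p / 1000
print(f"build: ~${build_cost:.2f}, per query: ~${query_cost:.2f}")
```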
## Dependencies
The main third-party package requirements are `transformers`, `openai`, and `langchain`.
All requirements should be contained within the `setup.py` file. To run the package locally without building the wheel, simply do `pip install -r requirements.txt`.
## Future Directions
- Add ability to insert/delete.
- Add ability to more easily customize prompts.
- Add different index structures beyond trees.
- Add ability for GPT itself to reason about connections between nodes.
examples/test_wiki/TestNYC.ipynb 0 → 100644
%% Cell type:code id:b4b4387b-413e-4016-ba1e-88b3d9410a38 tags:
``` python
# fetch "New York City" page from Wikipedia
from pathlib import Path
import requests
response = requests.get(
    'https://en.wikipedia.org/w/api.php',
    params={
        'action': 'query',
        'format': 'json',
        'titles': 'New York City',
        'prop': 'extracts',
        # 'exintro': True,
        'explaintext': True,
    }
).json()
page = next(iter(response['query']['pages'].values()))
nyc_text = page['extract']
data_path = Path('data')
if not data_path.exists():
    Path.mkdir(data_path)
with open('data/nyc_text.txt', 'w') as fp:
    fp.write(nyc_text)
```
%% Cell type:code id:f1a9eb90-335c-4214-8bb6-fd1edbe3ccbd tags:
``` python
# My OpenAI Key
import os
os.environ['OPENAI_API_KEY'] = "<your OpenAI API key>"
```
%% Cell type:code id:8d0b2364-4806-4656-81e7-3f6e4b910b5b tags:
``` python
from gpt_db_retrieve.index import GPTIndex
```
%% Cell type:code id:1298bbb4-c99e-431e-93ef-eb32c0a2fc2a tags:
``` python
index = GPTIndex.from_input_dir('data')
```
%% Cell type:code id:0b4fe9b6-5762-4e86-b51e-aac45d3ecdb1 tags:
``` python
index.save_to_disk('index.json')
```
%% Cell type:code id:5eec265d-211b-4f26-b05b-5b4e7072bc6e tags:
``` python
# try loading
new_index = GPTIndex.load_from_disk('index.json')
```
%% Cell type:code id:68c9ebfe-b1b6-4f4e-9278-174346de8c90 tags:
``` python
new_index.query("What are the three main airports within New York City?")
```
gpt_db_retrieve/VERSION 0 → 100644
0.0.1
"""Utilities for loading data from files."""
from pathlib import Path
class SimpleDirectoryReader:
"""Utilities for loading data from a directory."""
def __init__(self, input_dir: Path) -> None:
self.input_dir = input_dir
input_files = list(input_dir.iterdir())
for input_file in input_files:
if not input_file.is_file():
raise ValueError(f"Expected {input_file} to be a file.")
self.input_files = input_files
def load_data(self) -> str:
"""Loads data from the input directory."""
data = ""
for input_file in self.input_files:
with open(input_file, "r") as f:
data += f.read()
data += "\n"
return data
\ No newline at end of file
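A minimal usage sketch for the reader above, assuming a `data/` directory of plain-text files as in the notebook example:
```python
from pathlib import Path
from gpt_db_retrieve.file_reader import SimpleDirectoryReader

reader = SimpleDirectoryReader(Path('data'))
text = reader.load_data()  # all files concatenated, newline-separated
```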
"""Core abstractions for building an index of GPT data."""
from dataclasses import dataclass
from dataclasses_json import DataClassJsonMixin, Undefined, dataclass_json
from pathlib import Path
from gpt_db_retrieve.file_reader import SimpleDirectoryReader
from langchain.text_splitter import CharacterTextSplitter
from langchain import OpenAI, Prompt, LLMChain
from gpt_db_retrieve.prompts import DEFAULT_SUMMARY_PROMPT, DEFAULT_QUERY_PROMPT, DEFAULT_TEXT_QA_PROMPT
from gpt_db_retrieve.utils import get_chunk_size_given_prompt, extract_number_given_response
from gpt_db_retrieve.text_splitter import TokenTextSplitter
from typing import List
import json
MAX_CHUNK_SIZE = 3900
MAX_CHUNK_OVERLAP = 200
NUM_OUTPUTS = 256
@dataclass
class Node(DataClassJsonMixin):
"""A node in the GPT index."""
text: str
index: int
child_indices: List[int]
@dataclass
class IndexGraph(DataClassJsonMixin):
all_nodes: List[Node]
root_nodes: List[Node]
def _get_text_from_nodes(nodes: List[Node]) -> str:
"""Get text from nodes."""
text = ""
for node in nodes:
text += node.text
text += "\n"
return text
def _get_numbered_text_from_nodes(nodes: List[Node]) -> str:
"""Get text from nodes in the format of a numbered list."""
text = ""
number = 1
for node in nodes:
text += f"({number}) {' '.join(node.text.splitlines())}"
text += "\n\n"
number += 1
return text
class GPTIndexBuilder:
"""GPT Index builder."""
def __init__(
self,
num_children: int = 10,
summary_prompt: str = DEFAULT_SUMMARY_PROMPT
) -> None:
"""Initialize with params."""
self.num_children = num_children
# instantiate LLM
summary_prompt_obj = Prompt(template=summary_prompt, input_variables=["text"])
llm = OpenAI(temperature=0)
self.llm_chain = LLMChain(prompt=summary_prompt_obj, llm=llm)
chunk_size = get_chunk_size_given_prompt(
summary_prompt.format(text=""), MAX_CHUNK_SIZE, num_children, NUM_OUTPUTS
)
self.text_splitter = TokenTextSplitter(
separator=" ",
chunk_size=chunk_size,
chunk_overlap=MAX_CHUNK_OVERLAP // num_children
)
def build_from_text(self, text: str) -> IndexGraph:
"""Build from text.
Returns:
IndexGraph: graph object consisting of all_nodes, root_nodes
"""
text_chunks = self.text_splitter.split_text(text)
# instantiate all_nodes from initial text chunks
all_nodes = [Node(t, i, []) for i, t in enumerate(text_chunks)]
root_nodes = self._build_index_from_nodes(all_nodes, all_nodes)
return IndexGraph(all_nodes, root_nodes)
def _build_index_from_nodes(
self, cur_nodes: List[Node], all_nodes: List[Node]
) -> List[Node]:
"""Consolidates chunks recursively, in a bottoms-up fashion."""
cur_index = len(all_nodes)
new_node_list = []
print(f'building index from nodes: {len(cur_nodes) // self.num_children} chunks')
for i in range(0, len(cur_nodes), self.num_children):
print(f'{i}/{len(cur_nodes)}')
cur_nodes_chunk = cur_nodes[i:i+self.num_children]
text_chunk = _get_text_from_nodes(cur_nodes_chunk)
new_summary = self.llm_chain.predict(text=text_chunk)
print(f'{i}/{len(cur_nodes)}, summary: {new_summary}')
new_node = Node(new_summary, cur_index, [n.index for n in cur_nodes_chunk])
new_node_list.append(new_node)
cur_index += 1
all_nodes.extend(new_node_list)
if len(new_node_list) <= self.num_children:
return new_node_list
else:
return self._build_index_from_nodes(new_node_list, all_nodes)
@dataclass
class GPTIndex(DataClassJsonMixin):
"""GPT Index."""
graph: IndexGraph
query_template: str = DEFAULT_QUERY_PROMPT
text_qa_template: str = DEFAULT_TEXT_QA_PROMPT
def _query(self, cur_nodes: List[Node], query_str: str, verbose: bool = False) -> str:
"""Answer a query recursively."""
query_prompt = Prompt(
template=self.query_template,
input_variables=["num_chunks", "context_list", "query_str"]
)
llm = OpenAI(temperature=0)
llm_chain = LLMChain(prompt=query_prompt, llm=llm)
response = llm_chain.predict(
query_str=query_str,
num_chunks=len(cur_nodes),
context_list=_get_numbered_text_from_nodes(cur_nodes)
)
if verbose:
formatted_query = self.query_template.format(
num_chunks=len(cur_nodes),
query_str=query_str,
context_list=_get_numbered_text_from_nodes(cur_nodes)
)
print(f'==============')
print(f'current prompt template: {formatted_query}')
print(f'cur query response: {response}')
number = extract_number_given_response(response)
if number is None:
print(f"Could not retrieve response - no numbers present")
# just join text from current nodes as response
return _get_text_from_nodes(cur_nodes)
elif number > len(cur_nodes):
print(f'Invalid response: {response} - number {number} out of range')
return response
# number is 1-indexed, so subtract 1
selected_node = cur_nodes[number-1]
if len(selected_node.child_indices) == 0:
answer_prompt = Prompt(
template=self.text_qa_template,
input_variables=["context_str", "query_str"]
)
llm_chain = LLMChain(prompt=answer_prompt, llm=llm)
response = llm_chain.predict(
context_str=selected_node.text,
query_str=query_str
)
return response
else:
return self._query(
[self.graph.all_nodes[i] for i in selected_node.child_indices],
query_str
)
def query(self, query_str: str, verbose: bool = False) -> str:
"""Answer a query."""
if verbose:
print('Starting query: {query_str}')
return self._query(self.graph.root_nodes, query_str, verbose=verbose).strip()
@classmethod
def from_input_dir(
cls,
input_dir: str,
index_builder: GPTIndexBuilder = GPTIndexBuilder()
) -> "GPTIndex":
"""Builds an index from an input directory.
Uses the default index builder.
"""
input_dir = Path(input_dir)
# instantiate file reader
reader = SimpleDirectoryReader(input_dir)
text_data = reader.load_data()
# Use index builder
index_graph = index_builder.build_from_text(text_data)
return cls(index_graph)
@classmethod
def load_from_disk(cls, save_path: str) -> None:
"""Load from disk."""
with open(save_path, "r") as f:
return cls.from_dict(json.load(f))
def save_to_disk(self, save_path: str) -> None:
"""Safe to file."""
with open(save_path, "w") as f:
json.dump(self.to_dict(), f)
if __name__ == "__main__":
print('hello world')
\ No newline at end of file
"""Set of default prompts."""
DEFAULT_SUMMARY_PROMPT = (
"Write a concise summary of the following:\n"
"\n"
"\n"
"{text}\n"
"\n"
"\n"
"CONCISE SUMMARY:\"\"\"\n"
)
DEFAULT_QUERY_PROMPT = (
"Context information is below. It is provided in a numbered list (1 to {num_chunks}),"
"where each item in the list corresponds to a summary.\n"
"---------------------\n"
"{context_list}"
"---------------------\n"
"Given the context information, answer the question: {query_str}\n"
"The answer should be the number corresponding to the "
"summary that is most relevant to the question.\n"
)
DEFAULT_TEXT_QA_PROMPT = (
"Context information is below. "
"---------------------\n"
"{context_str}"
"---------------------\n"
"Given the context information, answer the question: {query_str}\n"
)
\ No newline at end of file
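For reference, a quick illustration of how these templates get filled in; the index code performs the same substitution via `langchain`'s `Prompt`, but plain `str.format` shows the resulting text:
```python
from gpt_db_retrieve.prompts import DEFAULT_TEXT_QA_PROMPT

prompt = DEFAULT_TEXT_QA_PROMPT.format(
    context_str="NYC is served by JFK, LaGuardia, and Newark.\n",
    query_str="What are the three main airports?",
)
print(prompt)
```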
gpt_db_retrieve/text_splitter.py 0 → 100644
from typing import List

from langchain.text_splitter import TextSplitter
from transformers import GPT2TokenizerFast


class TokenTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at word tokens."""

    def __init__(
        self, separator: str = " ", chunk_size: int = 4000, chunk_overlap: int = 200
    ):
        """Initialize with parameters."""
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._separator = separator
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = text.split(self._separator)
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            if total > self._chunk_size:
                docs.append(self._separator.join(current_doc))
                # drop pieces from the front until only the overlap remains
                while total > self._chunk_overlap:
                    cur_tokens = self.tokenizer(current_doc[0])
                    total -= len(cur_tokens["input_ids"])
                    current_doc = current_doc[1:]
            current_doc.append(d)
            num_tokens = len(self.tokenizer(d)["input_ids"])
            total += num_tokens
        docs.append(self._separator.join(current_doc))
        return docs
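A quick sketch of the splitter in isolation; the chunk sizes here are arbitrary illustrative values, counted in GPT-2 tokens:
```python
from gpt_db_retrieve.text_splitter import TokenTextSplitter

splitter = TokenTextSplitter(separator=" ", chunk_size=50, chunk_overlap=10)
chunks = splitter.split_text("some long document text " * 100)
print(len(chunks), "chunks")
```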
"""Utils file."""
from transformers import GPT2TokenizerFast
from typing import Optional
import re
def get_chunk_size_given_prompt(
prompt: str, max_input_size: int, num_chunks: int, num_output: int
) -> int:
"""Get chunk size making sure we can also fit the prompt in."""
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
prompt_tokens = tokenizer(prompt)
num_prompt_tokens = len(prompt_tokens["input_ids"])
return (max_input_size - num_prompt_tokens - num_output) // num_chunks
def extract_number_given_response(response: str) -> Optional[int]:
"""Extract number given the GPT-generated response."""
numbers = re.findall(r'\d+', response)
if len(numbers) == 0:
return None
else:
return int(numbers[0])
\ No newline at end of file
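Two quick examples of the extraction helper above; note that the first number found wins, which is one reason prompt tuning matters for the multiple-choice query step:
```python
from gpt_db_retrieve.utils import extract_number_given_response

print(extract_number_given_response("The answer is (3)."))  # 3
print(extract_number_given_response("None of the above."))  # None
```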
requirements.txt 0 → 100644
-e .
setup.py 0 → 100644
"""Set up the package."""
from pathlib import Path
from setuptools import find_packages, setup
with open(Path(__file__).absolute().parents[0] / "gpt_db_retrieve" / "VERSION") as _f:
__version__ = _f.read().strip()
with open("README.md", "r") as f:
long_description = f.read()
setup(
name="gpt_db_retrieve",
version=__version__,
packages=find_packages(),
description="Building an index of GPT summaries.",
install_requires=["langchain", "openai", "dataclasses_json", "transformers"],
long_description=long_description,
license="MIT",
url="https://github.com/jerryjliu/gpt_db_retrieve",
include_package_data=True,
long_description_content_type="text/markdown",
)