Skip to content
Snippets Groups Projects
Unverified Commit 6d5c445e authored by Jerry Liu's avatar Jerry Liu Committed by GitHub
Browse files

[experimental] "train" classifier on titanic with GPT-3 (#59)

parent 5a66cef4
No related branches found
No related tags found
No related merge requests found
# 🧪 Experimental
This section is for experiments, cool ideas, and more!
Code here lives outside the base package. If a project is sufficiently interesting and validated, then we will move it into the core abstractions.
\ No newline at end of file
%% Cell type:code id:f445c1d1-acb9-431e-a7ff-50c41f064359 tags:
``` python
from utils import (
get_train_str,
get_train_and_eval_data,
get_eval_preds,
train_prompt
)
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
```
%% Output
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
[nltk_data] Downloading package stopwords to
[nltk_data] /Users/jerryliu/nltk_data...
[nltk_data] Package stopwords is already up-to-date!
%% Cell type:code id:cf3cbd90-d5e1-4c30-a3bc-8b39fbd85d70 tags:
``` python
# load up the titanic data
train_df, train_labels, eval_df, eval_labels = get_train_and_eval_data('data/train.csv')
```
%% Cell type:markdown id:fa2634f9-cb33-4f1e-81f9-3a3b285e2580 tags:
## Few-shot Prompting with GPT-3 for Titanic Dataset
In this section, we can show how we can prompt GPT-3 on its own (without using GPT Index) to attain ~80% accuracy on Titanic!
We can do this by simply providing a few example inputs. Or we can simply provide no example inputs at all (zero-shot). Both achieve the same results.
%% Cell type:code id:d0698fd2-1361-49ae-8c17-8124e9b932a4 tags:
``` python
# first demonstrate the prompt template
print(train_prompt.template)
```
%% Output
The following structured data is provided in "Feature Name":"Feature Value" format.
Each datapoint describes a passenger on the Titanic.
The task is to decide whether the passenger survived.
Some example datapoints are given below:
-------------------
{train_str}
-------------------
Given this, predict whether the following passenger survived. Return answer as a number between 0 or 1.
{eval_str}
Survived:
%% Cell type:code id:4b39e2e7-be07-42f8-a27a-3419e84cfb2c tags:
``` python
# Get "training" prompt string
train_n = 10
eval_n = 40
train_str = get_train_str(train_df, train_labels, train_n=train_n)
print(f"Example datapoints in `train_str`: \n{train_str}")
```
%% Output
Example datapoints in `train_str`:
This is the Data:
Age:28.0
Embarked:S
Fare:7.8958
Parch:0
Pclass:3
Sex:male
SibSp:0
This is the correct answer:
Survived: 0
This is the Data:
Age:17.0
Embarked:S
Fare:7.925
Parch:2
Pclass:3
Sex:female
SibSp:4
This is the correct answer:
Survived: 1
This is the Data:
Age:30.0
Embarked:S
Fare:16.1
Parch:0
Pclass:3
Sex:male
SibSp:1
This is the correct answer:
Survived: 0
This is the Data:
Age:22.0
Embarked:S
Fare:7.25
Parch:0
Pclass:3
Sex:male
SibSp:0
This is the correct answer:
Survived: 0
This is the Data:
Age:45.0
Embarked:S
Fare:13.5
Parch:0
Pclass:2
Sex:female
SibSp:0
This is the correct answer:
Survived: 1
This is the Data:
Age:25.0
Embarked:S
Fare:0.0
Parch:0
Pclass:3
Sex:male
SibSp:0
This is the correct answer:
Survived: 1
This is the Data:
Age:18.0
Embarked:S
Fare:20.2125
Parch:1
Pclass:3
Sex:male
SibSp:1
This is the correct answer:
Survived: 0
This is the Data:
Age:33.0
Embarked:S
Fare:9.5
Parch:0
Pclass:3
Sex:male
SibSp:0
This is the correct answer:
Survived: 0
This is the Data:
Age:24.0
Embarked:S
Fare:65.0
Parch:2
Pclass:2
Sex:female
SibSp:1
This is the correct answer:
Survived: 1
This is the Data:
Age:26.0
Embarked:S
Fare:7.925
Parch:0
Pclass:3
Sex:female
SibSp:0
This is the correct answer:
Survived: 1
%% Cell type:markdown id:819a06f7-3171-4edb-b90c-0a3eae308a04 tags:
#### Do evaluation with the training prompt string
%% Cell type:code id:4a7f2202-518c-41a3-80ab-1e98bbcca903 tags:
``` python
from sklearn.metrics import accuracy_score
import numpy as np
eval_preds = get_eval_preds(train_prompt, train_str, eval_df, n=eval_n)
eval_label_chunk = eval_labels[:eval_n]
```
%% Cell type:code id:64323a4d-6eea-4e40-9eac-b2deed60192b tags:
``` python
acc = accuracy_score(eval_label_chunk, np.array(eval_preds).round())
print(f'ACCURACY: {acc}')
```
%% Output
ACCURACY: 0.8
%% Cell type:markdown id:11790d28-8f34-42dd-b11f-6aad21fd5f46 tags:
#### Do evaluation with no training prompt string!
%% Cell type:code id:aaf993e5-c363-4f18-a28f-09761e49cb6d tags:
``` python
from sklearn.metrics import accuracy_score
import numpy as np
eval_preds_null = get_eval_preds(train_prompt, "", eval_df, n=eval_n)
eval_label_chunk = eval_labels[:eval_n]
```
%% Cell type:code id:c3b8bcd5-5972-4ce5-9aa1-57460cdde199 tags:
``` python
acc_null = accuracy_score(eval_label_chunk, np.array(eval_preds_null).round())
print(f'ACCURACY: {acc_null}')
```
%% Output
ACCURACY: 0.8
%% Cell type:markdown id:8f0a5e4b-e627-4b47-a807-939813596594 tags:
## Extending with GPT List Index
%% Cell type:markdown id:42a1ca28-96e9-4cd2-bd48-0673917ad057 tags:
#### Build Index
%% Cell type:code id:6c59b030-855d-4e27-89c3-74c972d1bf19 tags:
``` python
from gpt_index import GPTListIndex
from gpt_index.schema import Document
```
%% Cell type:code id:8f9556de-e323-4318-bb71-cff75bf8c3c1 tags:
``` python
index = GPTListIndex([])
```
%% Cell type:code id:e27720fc-af36-40fd-8c55-41485248aa9f tags:
``` python
# insertion into index
batch_size = 40
num_train_chunks = 5
for i in range(num_train_chunks):
print(f"Inserting chunk: {i}/{num_train_chunks}")
start_idx = i*batch_size
end_idx = (i+1)*batch_size
train_batch = train_df.iloc[start_idx:end_idx+batch_size]
labels_batch = train_labels.iloc[start_idx:end_idx+batch_size]
all_train_str = get_train_str(train_batch, labels_batch, train_n=batch_size)
index.insert(Document(all_train_str))
```
%% Cell type:markdown id:e78db088-6649-44db-b52a-766316713b96 tags:
#### Query Index
%% Cell type:code id:9cb90564-1de2-412f-8318-d5280855004e tags:
``` python
from utils import query_str, qa_data_prompt, refine_prompt
```
%% Cell type:code id:77c1ae36-e0af-47bc-a656-4971af699755 tags:
``` python
query_str
```
%% Output
'Which is the relationship between these features and predicting survival?'
%% Cell type:code id:c403710f-d4b3-4287-94f5-e275ea19b476 tags:
``` python
response = index.query(
query_str,
text_qa_template=qa_data_prompt,
refine_template=refine_prompt,
)
```
%% Output
> Starting query: Which is the relationship between these features and predicting survival?
%% Cell type:code id:d2545ab1-980a-4fbd-8add-7ef957801644 tags:
``` python
print(response)
```
%% Output
There is no definitive answer to this question, as the relationship between the features and predicting survival will vary depending on the data. However, some possible relationships include: age (younger passengers are more likely to survive), sex (females are more likely to survive), fare (passengers who paid more for their ticket are more likely to survive), and pclass (passengers in first or second class are more likely to survive).
%% Cell type:markdown id:d0d7d260-2283-49f6-ac40-35c7071cc54d tags:
#### Get Predictions and Evaluate
%% Cell type:code id:e7b98057-957c-48ef-be85-59ff9813d201 tags:
``` python
# get eval preds
from utils import train_prompt_with_context
train_str = response
print(train_prompt_with_context.template)
print(f'\n\n`train_str`: {train_str}')
```
%% Output
The following structured data is provided in "Feature Name":"Feature Value" format.
Each datapoint describes a passenger on the Titanic.
The task is to decide whether the passenger survived.
We discovered the following relationship between features and survival:
-------------------
{train_str}
-------------------
Given this, predict whether the following passenger survived.
Return answer as a number between 0 or 1.
{eval_str}
Survived:
`train_str`:
There is no definitive answer to this question, as the relationship between the features and predicting survival will vary depending on the data. However, some possible relationships include: age (younger passengers are more likely to survive), sex (females are more likely to survive), fare (passengers who paid more for their ticket are more likely to survive), and pclass (passengers in first or second class are more likely to survive).
%% Cell type:code id:659c6a3f-1c5d-4314-87dc-908e76d50e4a tags:
``` python
# do evaluation
from sklearn.metrics import accuracy_score
import numpy as np
eval_n = 40
eval_preds = get_eval_preds(train_prompt_with_context, train_str, eval_df, n=eval_n)
```
%% Cell type:code id:7424e7d3-2576-42bc-b626-cf8088265004 tags:
``` python
eval_label_chunk = eval_labels[:eval_n]
acc = accuracy_score(eval_label_chunk, np.array(eval_preds).round())
print(f'ACCURACY: {acc}')
```
%% Output
ACCURACY: 0.85
%% Cell type:code id:e010b497-eeed-4142-a8ac-f5545e85fcc2 tags:
``` python
```
This diff is collapsed.
"""Helper functions for Titanic GPT-3 experiments."""
# form prompt, run GPT
import re
from typing import List, Optional, Tuple
import pandas as pd
from sklearn.model_selection import train_test_split
from gpt_index.indices.utils import extract_numbers_given_response
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.prompts.base import Prompt
def get_train_and_eval_data(
csv_path: str,
) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]:
"""Get train and eval data."""
df = pd.read_csv(csv_path)
label_col = "Survived"
cols_to_drop = ["PassengerId", "Ticket", "Name", "Cabin"]
df = df.drop(cols_to_drop, axis=1)
labels = df.pop(label_col)
train_df, eval_df, train_labels, eval_labels = train_test_split(
df, labels, test_size=0.25, random_state=0
)
return train_df, train_labels, eval_df, eval_labels
def get_sorted_dict_str(d: dict) -> str:
"""Get sorted dict string."""
keys = sorted(list(d.keys()))
return "\n".join([f"{k}:{d[k]}" for k in keys])
def get_label_str(labels: pd.Series, i: int) -> str:
"""Get label string."""
return f"{labels.name}: {labels.iloc[i]}"
def get_train_str(
train_df: pd.DataFrame, train_labels: pd.Series, train_n: int = 10
) -> str:
"""Get train str."""
dict_list = train_df.to_dict("records")[:train_n]
item_list = []
for i, d in enumerate(dict_list):
dict_str = get_sorted_dict_str(d)
label_str = get_label_str(train_labels, i)
item_str = (
f"This is the Data:\n{dict_str}\nThis is the correct answer:\n{label_str}"
)
item_list.append(item_str)
return "\n\n".join(item_list)
def extract_float_given_response(response: str, n: int = 1) -> Optional[float]:
"""Extract number given the GPT-generated response.
Used by tree-structured indices.
"""
numbers = re.findall(r"\d+\.\d+", response)
if len(numbers) == 0:
# if no floats, try extracting ints, and convert to float
new_numbers = extract_numbers_given_response(response, n=n)
if new_numbers is None:
return None
else:
return float(numbers[0])
else:
return float(numbers[0])
def get_eval_preds(
train_prompt: Prompt, train_str: str, eval_df: pd.DataFrame, n: int = 20
) -> List:
"""Get eval preds."""
llm_predictor = LLMPredictor()
eval_preds = []
for i in range(n):
eval_str = get_sorted_dict_str(eval_df.iloc[i].to_dict())
response, _ = llm_predictor.predict(
train_prompt, train_str=train_str, eval_str=eval_str
)
pred = extract_float_given_response(response)
print(f"Getting preds: {i}/{n}: {pred}")
if pred is None:
# something went wrong, impute a 0.5
eval_preds.append(0.5)
else:
eval_preds.append(pred)
return eval_preds
# default train prompt
train_prompt_str = (
"The following structured data is provided in "
'"Feature Name":"Feature Value" format.\n'
"Each datapoint describes a passenger on the Titanic.\n"
"The task is to decide whether the passenger survived.\n"
"Some example datapoints are given below: \n"
"-------------------\n"
"{train_str}\n"
"-------------------\n"
"Given this, predict whether the following passenger survived. "
"Return answer as a number between 0 or 1. \n"
"{eval_str}\n"
"Survived: "
)
train_prompt = Prompt(
input_variables=["train_str", "eval_str"], template=train_prompt_str
)
# prompt to summarize the data
query_str = "Which is the relationship between these features and predicting survival?"
qa_data_str = (
"The following structured data is provided in "
'"Feature Name":"Feature Value" format.\n'
"Each datapoint describes a passenger on the Titanic.\n"
"The task is to decide whether the passenger survived.\n"
"Some example datapoints are given below: \n"
"-------------------\n"
"{context_str}\n"
"-------------------\n"
"Given this, answer the question: {query_str}"
)
qa_data_prompt = Prompt(
input_variables=["context_str", "query_str"], template=qa_data_str
)
# prompt to refine the answer
refine_str = (
"The original question is as follows: {query_str}\n"
"We have provided an existing answer: {existing_answer}\n"
"The following structured data is provided in "
'"Feature Name":"Feature Value" format.\n'
"Each datapoint describes a passenger on the Titanic.\n"
"The task is to decide whether the passenger survived.\n"
"We have the opportunity to refine the existing answer"
"(only if needed) with some more datapoints below.\n"
"------------\n"
"{context_msg}\n"
"------------\n"
"Given the new context, refine the original answer to better "
"answer the question. "
"If the context isn't useful, return the original answer."
)
refine_prompt = Prompt(
input_variables=["query_str", "existing_answer", "context_msg"],
template=refine_str,
)
# train prompt with refined context
train_prompt_with_context_str = (
"The following structured data is provided in "
'"Feature Name":"Feature Value" format.\n'
"Each datapoint describes a passenger on the Titanic.\n"
"The task is to decide whether the passenger survived.\n"
"We discovered the following relationship between features and survival:\n"
"-------------------\n"
"{train_str}\n"
"-------------------\n"
"Given this, predict whether the following passenger survived. \n"
"Return answer as a number between 0 or 1. \n"
"{eval_str}\n"
"Survived: "
)
train_prompt_with_context = Prompt(
input_variables=["train_str", "eval_str"], template=train_prompt_with_context_str
)
...@@ -58,11 +58,8 @@ class GPTListIndex(BaseGPTIndex[IndexList]): ...@@ -58,11 +58,8 @@ class GPTListIndex(BaseGPTIndex[IndexList]):
def _mode_to_query(self, mode: str, **query_kwargs: Any) -> BaseGPTIndexQuery: def _mode_to_query(self, mode: str, **query_kwargs: Any) -> BaseGPTIndexQuery:
if mode == DEFAULT_MODE: if mode == DEFAULT_MODE:
query_kwargs.update( if "text_qa_template" not in query_kwargs:
{ query_kwargs["text_qa_template"] = self.text_qa_template
"text_qa_template": self.text_qa_template,
}
)
query = GPTListIndexQuery(self.index_struct, **query_kwargs) query = GPTListIndexQuery(self.index_struct, **query_kwargs)
else: else:
raise ValueError(f"Invalid query mode: {mode}.") raise ValueError(f"Invalid query mode: {mode}.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment