From ae6d2c287086249a997f95162e0650f510ee4611 Mon Sep 17 00:00:00 2001
From: Zach Zhao <thezachzhao@gmail.com>
Date: Mon, 16 Jan 2023 22:14:52 -0500
Subject: [PATCH] Fix and extend keyword parsing (#238)

Co-authored-by: Jerry Liu <jerry@robustintelligence.com>
---
 .gitignore                                    |  3 +++
 clearNotebookMetadata.sh                      |  5 +++++
 gpt_index/indices/keyword_table/base.py       |  2 +-
 gpt_index/indices/keyword_table/utils.py      | 14 ++++++++++----
 .../indices/query/keyword_table/query.py      |  2 +-
 tests/indices/keyword_table/test_utils.py     | 19 +++++++++++++++++++
 6 files changed, 39 insertions(+), 6 deletions(-)
 create mode 100644 clearNotebookMetadata.sh

diff --git a/.gitignore b/.gitignore
index 6e00effac7..9143ac0b34 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,3 +130,6 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# JetBrains
+.idea
\ No newline at end of file
diff --git a/clearNotebookMetadata.sh b/clearNotebookMetadata.sh
new file mode 100644
index 0000000000..53dd091504
--- /dev/null
+++ b/clearNotebookMetadata.sh
@@ -0,0 +1,5 @@
+# Clears the kernelspec metadata from all example notebooks
+for file in examples/**/*.ipynb
+do
+   jq 'del(.metadata.kernelspec)' "$file" | sponge "$file"
+done
diff --git a/gpt_index/indices/keyword_table/base.py b/gpt_index/indices/keyword_table/base.py
index 007758d7c9..32f0d1f1b6 100644
--- a/gpt_index/indices/keyword_table/base.py
+++ b/gpt_index/indices/keyword_table/base.py
@@ -141,5 +141,5 @@ class GPTKeywordTableIndex(BaseGPTKeywordTableIndex):
             self.keyword_extract_template,
             text=text,
         )
-        keywords = extract_keywords_given_response(response)
+        keywords = extract_keywords_given_response(response, start_token="KEYWORDS:")
         return keywords
diff --git a/gpt_index/indices/keyword_table/utils.py b/gpt_index/indices/keyword_table/utils.py
index c8fb139c4a..4afd893ac9 100644
--- a/gpt_index/indices/keyword_table/utils.py
+++ b/gpt_index/indices/keyword_table/utils.py
@@ -43,17 +43,23 @@ def rake_extract_keywords(
         return set(keywords)
 
 
-def extract_keywords_given_response(response: str, lowercase: bool = True) -> Set[str]:
+def extract_keywords_given_response(
+    response: str, lowercase: bool = True, start_token: str = ""
+) -> Set[str]:
     """Extract keywords given the GPT-generated response.
 
     Used by keyword table indices.
-
+    Parses `<start_token> <word1>, <word2>, ...` into [word1, word2, ...]
+    If the response doesn't start with <start_token>, nothing is stripped.
     """
     results = []
+    response = response.strip()  # Strip newlines from responses.
+
+    if response.startswith(start_token):
+        response = response[len(start_token) :]
+
     keywords = response.split(",")
     for k in keywords:
-        if "KEYWORD" in k:
-            continue
         rk = k
         if lowercase:
             rk = rk.lower()
diff --git a/gpt_index/indices/query/keyword_table/query.py b/gpt_index/indices/query/keyword_table/query.py
index 992f340f61..403b89b036 100644
--- a/gpt_index/indices/query/keyword_table/query.py
+++ b/gpt_index/indices/query/keyword_table/query.py
@@ -118,7 +118,7 @@ class GPTKeywordTableGPTQuery(BaseGPTKeywordTableQuery):
             max_keywords=self.max_keywords_per_query,
             question=query_str,
         )
-        keywords = extract_keywords_given_response(response)
+        keywords = extract_keywords_given_response(response, start_token="KEYWORDS:")
         return list(keywords)
 
 
diff --git a/tests/indices/keyword_table/test_utils.py b/tests/indices/keyword_table/test_utils.py
index 91f2ac36ec..a565c5ede9 100644
--- a/tests/indices/keyword_table/test_utils.py
+++ b/tests/indices/keyword_table/test_utils.py
@@ -17,3 +17,22 @@ def test_expand_tokens_with_subtokens() -> None:
         "world",
         "bye",
     }
+
+
+def test_extract_keywords_with_start_delimiter() -> None:
+    """Test extract keywords with start delimiter."""
+    response = "KEYWORDS: foo, bar, foobar"
+    keywords = extract_keywords_given_response(response, start_token="KEYWORDS:")
+    assert keywords == {
+        "foo",
+        "bar",
+        "foobar",
+    }
+
+    response = "TOKENS: foo, bar, foobar"
+    keywords = extract_keywords_given_response(response, start_token="TOKENS:")
+    assert keywords == {
+        "foo",
+        "bar",
+        "foobar",
+    }
-- 
GitLab