diff --git a/.gitignore b/.gitignore index 6e00effac74049916dcf0f677c3ccb7cd2d6f4f4..9143ac0b34a2636b326147a3a5f460bdd1f25115 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,6 @@ dmypy.json # Pyre type checker .pyre/ + +# Jetbrains +.idea \ No newline at end of file diff --git a/clearNotebookMetadata.sh b/clearNotebookMetadata.sh new file mode 100644 index 0000000000000000000000000000000000000000..53dd091504721feeb1a5fb0bafe49fc909344bfd --- /dev/null +++ b/clearNotebookMetadata.sh @@ -0,0 +1,5 @@ +# Clears the kernelspec metadata from all example notebooks (requires jq and sponge) +for file in examples/**/*.ipynb +do + jq 'del(.metadata.kernelspec)' "$file" | sponge "$file" +done diff --git a/gpt_index/indices/keyword_table/base.py b/gpt_index/indices/keyword_table/base.py index 007758d7c93c004d22afc58fcbb6a7687de3d99b..32f0d1f1b6f3986771566cee14af88e6e8479620 100644 --- a/gpt_index/indices/keyword_table/base.py +++ b/gpt_index/indices/keyword_table/base.py @@ -141,5 +141,5 @@ class GPTKeywordTableIndex(BaseGPTKeywordTableIndex): self.keyword_extract_template, text=text, ) - keywords = extract_keywords_given_response(response) + keywords = extract_keywords_given_response(response, start_token="KEYWORDS:") return keywords diff --git a/gpt_index/indices/keyword_table/utils.py b/gpt_index/indices/keyword_table/utils.py index c8fb139c4a3c116f33108fbaeb9922ce0ff5d40e..4afd893ac94d8ef62581ff3fdcf0ca7800e773ac 100644 --- a/gpt_index/indices/keyword_table/utils.py +++ b/gpt_index/indices/keyword_table/utils.py @@ -43,17 +43,23 @@ def rake_extract_keywords( return set(keywords) -def extract_keywords_given_response(response: str, lowercase: bool = True) -> Set[str]: +def extract_keywords_given_response( + response: str, lowercase: bool = True, start_token: str = "" +) -> Set[str]: """Extract keywords given the GPT-generated response. Used by keyword table indices. - + Parses <start_token> <word1>, <word2>, ... into [word1, word2, ...] 
+ If response doesn't start with <start_token>, it is parsed unchanged. """ results = [] + response = response.strip() # Strip surrounding whitespace, including newlines. + + if response.startswith(start_token): + response = response[len(start_token) :] + keywords = response.split(",") for k in keywords: - if "KEYWORD" in k: - continue rk = k if lowercase: rk = rk.lower() diff --git a/gpt_index/indices/query/keyword_table/query.py b/gpt_index/indices/query/keyword_table/query.py index 992f340f615ba6b50006f5ddbef441176191b0ac..403b89b036edde19a38d7c1f17e83f0b202f996e 100644 --- a/gpt_index/indices/query/keyword_table/query.py +++ b/gpt_index/indices/query/keyword_table/query.py @@ -118,7 +118,7 @@ class GPTKeywordTableGPTQuery(BaseGPTKeywordTableQuery): max_keywords=self.max_keywords_per_query, question=query_str, ) - keywords = extract_keywords_given_response(response) + keywords = extract_keywords_given_response(response, start_token="KEYWORDS:") return list(keywords) diff --git a/tests/indices/keyword_table/test_utils.py b/tests/indices/keyword_table/test_utils.py index 91f2ac36ecd0393e2838bc0bbb9791a9b9d92608..a565c5ede9769e63647e38e89a7ebe3be397c9c9 100644 --- a/tests/indices/keyword_table/test_utils.py +++ b/tests/indices/keyword_table/test_utils.py @@ -17,3 +17,22 @@ def test_expand_tokens_with_subtokens() -> None: "world", "bye", } + + +def test_extract_keywords_with_start_delimiter() -> None: + """Test extract keywords with start delimiter.""" + response = "KEYWORDS: foo, bar, foobar" + keywords = extract_keywords_given_response(response, start_token="KEYWORDS:") + assert keywords == { + "foo", + "bar", + "foobar", + } + + response = "TOKENS: foo, bar, foobar" + keywords = extract_keywords_given_response(response, start_token="TOKENS:") + assert keywords == { + "foo", + "bar", + "foobar", + }