From 78a4c9ed20e2b08e01d39487bec356da7c191ab8 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Mon, 26 Feb 2024 13:12:12 -0600
Subject: [PATCH] fix prompt helper init (#11379)

---
 .../core/indices/common/struct_store/base.py      |  6 +++---
 .../llama_index/core/indices/common_tree/base.py  |  7 +++----
 .../llama_index/core/indices/prompt_helper.py     |  1 +
 .../llama_index/core/indices/tree/inserter.py     |  6 +++---
 .../core/indices/tree/select_leaf_retriever.py    |  6 +++---
 .../core/response_synthesizers/base.py            | 10 +++++++---
 .../core/response_synthesizers/factory.py         | 15 +++++++++++----
 .../core/response_synthesizers/generation.py      |  3 +++
 .../core/response_synthesizers/refine.py          |  3 +++
 .../response_synthesizers/simple_summarize.py     |  3 +++
 .../core/response_synthesizers/tree_summarize.py  |  3 +++
 llama-index-core/llama_index/core/settings.py     | 12 ------------
 12 files changed, 43 insertions(+), 32 deletions(-)

diff --git a/llama-index-core/llama_index/core/indices/common/struct_store/base.py b/llama-index-core/llama_index/core/indices/common/struct_store/base.py
index 3333bc0667..d4c4753a99 100644
--- a/llama-index-core/llama_index/core/indices/common/struct_store/base.py
+++ b/llama-index-core/llama_index/core/indices/common/struct_store/base.py
@@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, cast
 
 from llama_index.core.callbacks.schema import CBEventType, EventPayload
 from llama_index.core.data_structs.table import StructDatapoint
+from llama_index.core.indices.prompt_helper import PromptHelper
 from llama_index.core.node_parser.interface import TextSplitter
 from llama_index.core.prompts import BasePromptTemplate
 from llama_index.core.prompts.default_prompt_selectors import (
@@ -26,7 +27,6 @@ from llama_index.core.settings import (
     Settings,
     callback_manager_from_settings_or_context,
     llm_from_settings_or_context,
-    prompt_helper_from_settings_or_context,
 )
 from llama_index.core.utilities.sql_wrapper import SQLDatabase
 from llama_index.core.utils import truncate_text
@@ -67,8 +67,8 @@ class SQLDocumentContextBuilder:
         self._sql_database = sql_database
         self._text_splitter = text_splitter
         self._llm = llm or llm_from_settings_or_context(Settings, service_context)
-        self._prompt_helper = prompt_helper_from_settings_or_context(
-            Settings, service_context
+        self._prompt_helper = Settings._prompt_helper or PromptHelper.from_llm_metadata(
+            self._llm.metadata,
         )
         self._callback_manager = callback_manager_from_settings_or_context(
             Settings, service_context
diff --git a/llama-index-core/llama_index/core/indices/common_tree/base.py b/llama-index-core/llama_index/core/indices/common_tree/base.py
index 18213f1c7a..ba10592256 100644
--- a/llama-index-core/llama_index/core/indices/common_tree/base.py
+++ b/llama-index-core/llama_index/core/indices/common_tree/base.py
@@ -1,6 +1,5 @@
 """Common classes/functions for tree index operations."""
 
-
 import asyncio
 import logging
 from typing import Dict, List, Optional, Sequence, Tuple
@@ -8,6 +7,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
 from llama_index.core.async_utils import run_async_tasks
 from llama_index.core.callbacks.schema import CBEventType, EventPayload
 from llama_index.core.data_structs.data_structs import IndexGraph
+from llama_index.core.indices.prompt_helper import PromptHelper
 from llama_index.core.indices.utils import get_sorted_node_list, truncate_text
 from llama_index.core.llms.llm import LLM
 from llama_index.core.prompts import BasePromptTemplate
@@ -17,7 +17,6 @@ from llama_index.core.settings import (
     Settings,
     callback_manager_from_settings_or_context,
     llm_from_settings_or_context,
-    prompt_helper_from_settings_or_context,
 )
 from llama_index.core.storage.docstore import BaseDocumentStore
 from llama_index.core.storage.docstore.registry import get_default_docstore
@@ -50,8 +49,8 @@ class GPTTreeIndexBuilder:
         self.num_children = num_children
         self.summary_prompt = summary_prompt
         self._llm = llm or llm_from_settings_or_context(Settings, service_context)
-        self._prompt_helper = prompt_helper_from_settings_or_context(
-            Settings, service_context
+        self._prompt_helper = Settings._prompt_helper or PromptHelper.from_llm_metadata(
+            self._llm.metadata,
         )
         self._callback_manager = callback_manager_from_settings_or_context(
             Settings, service_context
diff --git a/llama-index-core/llama_index/core/indices/prompt_helper.py b/llama-index-core/llama_index/core/indices/prompt_helper.py
index 359d194589..6b7f141ed6 100644
--- a/llama-index-core/llama_index/core/indices/prompt_helper.py
+++ b/llama-index-core/llama_index/core/indices/prompt_helper.py
@@ -115,6 +115,7 @@ class PromptHelper(BaseComponent):
 
         """
         context_window = llm_metadata.context_window
+
         if llm_metadata.num_output == -1:
             num_output = DEFAULT_NUM_OUTPUTS
         else:
diff --git a/llama-index-core/llama_index/core/indices/tree/inserter.py b/llama-index-core/llama_index/core/indices/tree/inserter.py
index 58f1d11607..0bb0b54cd4 100644
--- a/llama-index-core/llama_index/core/indices/tree/inserter.py
+++ b/llama-index-core/llama_index/core/indices/tree/inserter.py
@@ -3,6 +3,7 @@
 from typing import Optional, Sequence
 
 from llama_index.core.data_structs.data_structs import IndexGraph
+from llama_index.core.indices.prompt_helper import PromptHelper
 from llama_index.core.indices.tree.utils import get_numbered_text_from_nodes
 from llama_index.core.indices.utils import (
     extract_numbers_given_response,
@@ -19,7 +20,6 @@ from llama_index.core.service_context import ServiceContext
 from llama_index.core.settings import (
     Settings,
     llm_from_settings_or_context,
-    prompt_helper_from_settings_or_context,
 )
 from llama_index.core.storage.docstore import BaseDocumentStore
 from llama_index.core.storage.docstore.registry import get_default_docstore
@@ -46,8 +46,8 @@ class TreeIndexInserter:
         self.insert_prompt = insert_prompt
         self.index_graph = index_graph
         self._llm = llm or llm_from_settings_or_context(Settings, service_context)
-        self._prompt_helper = prompt_helper_from_settings_or_context(
-            Settings, service_context
+        self._prompt_helper = Settings._prompt_helper or PromptHelper.from_llm_metadata(
+            self._llm.metadata,
         )
         self._docstore = docstore or get_default_docstore()
 
diff --git a/llama-index-core/llama_index/core/indices/tree/select_leaf_retriever.py b/llama-index-core/llama_index/core/indices/tree/select_leaf_retriever.py
index 9483de176c..b8dc73d14e 100644
--- a/llama-index-core/llama_index/core/indices/tree/select_leaf_retriever.py
+++ b/llama-index-core/llama_index/core/indices/tree/select_leaf_retriever.py
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, cast
 from llama_index.core.base.base_retriever import BaseRetriever
 from llama_index.core.base.response.schema import Response
 from llama_index.core.callbacks.base import CallbackManager
+from llama_index.core.indices.prompt_helper import PromptHelper
 from llama_index.core.indices.query.schema import QueryBundle
 from llama_index.core.indices.tree.base import TreeIndex
 from llama_index.core.indices.tree.utils import get_numbered_text_from_nodes
@@ -32,7 +33,6 @@ from llama_index.core.schema import (
 from llama_index.core.settings import (
     Settings,
     callback_manager_from_settings_or_context,
-    prompt_helper_from_settings_or_context,
 )
 from llama_index.core.utils import print_text, truncate_text
 
@@ -93,8 +93,8 @@ class TreeSelectLeafRetriever(BaseRetriever):
         self._index_struct = index.index_struct
         self._docstore = index.docstore
         self._service_context = index.service_context
-        self._prompt_helper = prompt_helper_from_settings_or_context(
-            Settings, index.service_context
+        self._prompt_helper = Settings._prompt_helper or PromptHelper.from_llm_metadata(
+            self._llm.metadata,
         )
 
         self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT
diff --git a/llama-index-core/llama_index/core/response_synthesizers/base.py b/llama-index-core/llama_index/core/response_synthesizers/base.py
index 6edc364040..acecef8f11 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/base.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/base.py
@@ -7,6 +7,7 @@ Will support different modes, from 1) stuffing chunks into prompt,
 2) create and refine separately over each chunk, 3) tree summarization.
 
 """
+
 import logging
 from abc import abstractmethod
 from typing import Any, Dict, Generator, List, Optional, Sequence, Union
@@ -41,7 +42,6 @@ from llama_index.core.settings import (
     Settings,
     callback_manager_from_settings_or_context,
     llm_from_settings_or_context,
-    prompt_helper_from_settings_or_context,
 )
 from llama_index.core.types import RESPONSE_TEXT_TYPE
 
@@ -69,8 +69,12 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
             callback_manager
             or callback_manager_from_settings_or_context(Settings, service_context)
         )
-        self._prompt_helper = prompt_helper or prompt_helper_from_settings_or_context(
-            Settings, service_context
+        self._prompt_helper = (
+            prompt_helper
+            or Settings._prompt_helper
+            or PromptHelper.from_llm_metadata(
+                self._llm.metadata,
+            )
         )
 
         self._streaming = streaming
diff --git a/llama-index-core/llama_index/core/response_synthesizers/factory.py b/llama-index-core/llama_index/core/response_synthesizers/factory.py
index f8543667c3..aa191e0b01 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/factory.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/factory.py
@@ -31,7 +31,6 @@ from llama_index.core.settings import (
     Settings,
     callback_manager_from_settings_or_context,
     llm_from_settings_or_context,
-    prompt_helper_from_settings_or_context,
 )
 from llama_index.core.types import BasePydanticProgram
 
@@ -63,9 +62,17 @@ def get_response_synthesizer(
         Settings, service_context
     )
     llm = llm or llm_from_settings_or_context(Settings, service_context)
-    prompt_helper = prompt_helper or prompt_helper_from_settings_or_context(
-        Settings, service_context
-    )
+
+    if service_context is not None:
+        prompt_helper = service_context.prompt_helper
+    else:
+        prompt_helper = (
+            prompt_helper
+            or Settings._prompt_helper
+            or PromptHelper.from_llm_metadata(
+                llm.metadata,
+            )
+        )
 
     if response_mode == ResponseMode.REFINE:
         return Refine(
diff --git a/llama-index-core/llama_index/core/response_synthesizers/generation.py b/llama-index-core/llama_index/core/response_synthesizers/generation.py
index 396cc561b3..e6f176017e 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/generation.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/generation.py
@@ -22,6 +22,9 @@ class Generation(BaseSynthesizer):
         # deprecated
         service_context: Optional[ServiceContext] = None,
     ) -> None:
+        if service_context is not None:
+            prompt_helper = service_context.prompt_helper
+
         super().__init__(
             llm=llm,
             callback_manager=callback_manager,
diff --git a/llama-index-core/llama_index/core/response_synthesizers/refine.py b/llama-index-core/llama_index/core/response_synthesizers/refine.py
index b162923661..8dd0fcc509 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/refine.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/refine.py
@@ -107,6 +107,9 @@ class Refine(BaseSynthesizer):
         # deprecated
         service_context: Optional[ServiceContext] = None,
     ) -> None:
+        if service_context is not None:
+            prompt_helper = service_context.prompt_helper
+
         super().__init__(
             llm=llm,
             callback_manager=callback_manager,
diff --git a/llama-index-core/llama_index/core/response_synthesizers/simple_summarize.py b/llama-index-core/llama_index/core/response_synthesizers/simple_summarize.py
index e0c81be5f3..152459df1c 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/simple_summarize.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/simple_summarize.py
@@ -24,6 +24,9 @@ class SimpleSummarize(BaseSynthesizer):
         # deprecated
         service_context: Optional[ServiceContext] = None,
     ) -> None:
+        if service_context is not None:
+            prompt_helper = service_context.prompt_helper
+
         super().__init__(
             llm=llm,
             callback_manager=callback_manager,
diff --git a/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py b/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py
index 066163a51f..356e5dfee0 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py
@@ -41,6 +41,9 @@ class TreeSummarize(BaseSynthesizer):
         # deprecated
         service_context: Optional[ServiceContext] = None,
     ) -> None:
+        if service_context is not None:
+            prompt_helper = service_context.prompt_helper
+
         super().__init__(
             llm=llm,
             callback_manager=callback_manager,
diff --git a/llama-index-core/llama_index/core/settings.py b/llama-index-core/llama_index/core/settings.py
index d7875e566f..c280f8a88e 100644
--- a/llama-index-core/llama_index/core/settings.py
+++ b/llama-index-core/llama_index/core/settings.py
@@ -197,8 +197,6 @@ class _Settings:
         """Set the text splitter."""
         self.node_parser = text_splitter
 
-    # ---- Prompt helper ----
-
     @property
     def prompt_helper(self) -> PromptHelper:
         """Get the prompt helper."""
@@ -296,16 +294,6 @@ def node_parser_from_settings_or_context(
     return settings.node_parser
 
 
-def prompt_helper_from_settings_or_context(
-    settings: _Settings, context: Optional["ServiceContext"]
-) -> PromptHelper:
-    """Get settings from either settings or context."""
-    if context is not None:
-        return context.prompt_helper
-
-    return settings.prompt_helper
-
-
 def transformations_from_settings_or_context(
     settings: _Settings, context: Optional["ServiceContext"]
 ) -> List[TransformComponent]:
-- 
GitLab