From e157ebb7e459a9990a86dc5a518cc18a765a109a Mon Sep 17 00:00:00 2001
From: GICodeWarrior <GICodeWarrior@gmail.com>
Date: Sun, 16 Feb 2025 17:38:07 -0800
Subject: [PATCH] Retain return type from @dispatcher.span (#17817)

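`Dispatcher.span` previously annotated the decorated callable as returning
`Any`, so every method wrapped with `@dispatcher.span` (for example
`LLM.structured_predict`) lost its return type for static type checkers.
The decorator is now generic over a `TypeVar`, returning
`Callable[..., _R]` for a `Callable[..., _R]` input, and downstream
signatures are tightened from `BaseModel` to the `Model` type variable
accordingly. Asserts such as `isinstance(response, StreamingResponse)`
are added at call sites where the narrowed annotations expose unions the
surrounding code already assumes.

A minimal sketch of the effect (illustrative only; `Song` and `make_song`
are hypothetical names, not part of this patch):

    from pydantic import BaseModel

    from llama_index.core.instrumentation import get_dispatcher

    dispatcher = get_dispatcher(__name__)


    class Song(BaseModel):
        title: str


    @dispatcher.span
    def make_song(title: str) -> Song:
        """Build a Song; the span decorator records span start/end events."""
        return Song(title=title)


    song = make_song("Up")
    # Before this patch: `song` was inferred as `Any`.
    # After: `song` is inferred as `Song`, so attribute access is type-checked.
    print(song.title)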
---
 .../core/chat_engine/condense_plus_context.py |  2 +
 .../core/evaluation/retrieval/evaluator.py    |  4 +-
 .../core/extractors/metadata_extractors.py    | 12 +--
 .../core/indices/common_tree/base.py          |  3 +-
 .../core/instrumentation/dispatcher.py        |  5 +-
 llama-index-core/llama_index/core/llms/llm.py | 29 +++---
 .../core/output_parsers/pydantic.py           |  6 +-
 .../core/program/function_program.py          | 90 ++++++++++---------
 .../llama_index/core/program/llm_program.py   | 17 ++--
 .../llama_index/core/program/utils.py         | 12 +--
 .../query_engine/retry_source_query_engine.py |  7 +-
 .../query_engine/sql_join_query_engine.py     |  4 +-
 .../llama_index/core/query_pipeline/query.py  |  3 +-
 .../response_synthesizers/tree_summarize.py   | 13 ++-
 .../llama_index/core/tools/retriever_tool.py  | 10 ++-
 llama-index-core/llama_index/core/types.py    | 16 +++-
 .../pyproject.toml                            |  4 +-
 .../tests/test_node_parser_docling.py         | 10 ++-
 18 files changed, 145 insertions(+), 102 deletions(-)

diff --git a/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py b/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
index 8177ccb7df..86e93bcf8b 100644
--- a/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
+++ b/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
@@ -347,6 +347,7 @@ class CondensePlusContextChatEngine(BaseChatEngine):
         )
 
         response = synthesizer.synthesize(message, context_nodes)
+        assert isinstance(response, StreamingResponse)
 
         def wrapped_gen(response: StreamingResponse) -> ChatResponseGen:
             full_response = ""
@@ -405,6 +406,7 @@ class CondensePlusContextChatEngine(BaseChatEngine):
         )
 
         response = await synthesizer.asynthesize(message, context_nodes)
+        assert isinstance(response, AsyncStreamingResponse)
 
         async def wrapped_gen(response: AsyncStreamingResponse) -> ChatResponseAsyncGen:
             full_response = ""
diff --git a/llama-index-core/llama_index/core/evaluation/retrieval/evaluator.py b/llama-index-core/llama_index/core/evaluation/retrieval/evaluator.py
index fd686df156..24e7793f84 100644
--- a/llama-index-core/llama_index/core/evaluation/retrieval/evaluator.py
+++ b/llama-index-core/llama_index/core/evaluation/retrieval/evaluator.py
@@ -45,7 +45,7 @@ class RetrieverEvaluator(BaseRetrievalEvaluator):
 
         return (
             [node.node.node_id for node in retrieved_nodes],
-            [node.node.text for node in retrieved_nodes],
+            [node.text for node in retrieved_nodes],
         )
 
 
@@ -84,7 +84,7 @@ class MultiModalRetrieverEvaluator(BaseRetrievalEvaluator):
             node = scored_node.node
             if isinstance(node, ImageNode):
                 image_nodes.append(node)
-            if node.text:
+            if isinstance(node, TextNode):
                 text_nodes.append(node)
 
         if mode == "text":
diff --git a/llama-index-core/llama_index/core/extractors/metadata_extractors.py b/llama-index-core/llama_index/core/extractors/metadata_extractors.py
index f5ef8f70ad..e416411658 100644
--- a/llama-index-core/llama_index/core/extractors/metadata_extractors.py
+++ b/llama-index-core/llama_index/core/extractors/metadata_extractors.py
@@ -20,7 +20,7 @@ disambiguate the document or subsection from other similar documents or subsecti
 (similar with contrastive learning)
 """
 
-from typing import Any, Callable, Dict, List, Optional, Sequence, cast
+from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, cast
 
 from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs
 from llama_index.core.bridge.pydantic import (
@@ -33,7 +33,7 @@ from llama_index.core.llms.llm import LLM
 from llama_index.core.prompts import PromptTemplate
 from llama_index.core.schema import BaseNode, TextNode
 from llama_index.core.settings import Settings
-from llama_index.core.types import BasePydanticProgram
+from llama_index.core.types import BasePydanticProgram, Model
 
 DEFAULT_TITLE_NODE_TEMPLATE = """\
 Context: {context_str}. Give a title that summarizes all of \
@@ -462,7 +462,7 @@ Given the contextual information, extract out a {class_name} object.\
 """
 
 
-class PydanticProgramExtractor(BaseExtractor):
+class PydanticProgramExtractor(BaseExtractor, Generic[Model]):
     """Pydantic program extractor.
 
     Uses an LLM to extract out a Pydantic object. Return attributes of that object
@@ -470,7 +470,7 @@ class PydanticProgramExtractor(BaseExtractor):
 
     """
 
-    program: SerializeAsAny[BasePydanticProgram] = Field(
+    program: SerializeAsAny[BasePydanticProgram[Model]] = Field(
         ..., description="Pydantic program to extract."
     )
     input_key: str = Field(
@@ -500,7 +500,9 @@ class PydanticProgramExtractor(BaseExtractor):
         )
 
         ret_object = await self.program.acall(**{self.input_key: extract_str})
-        return ret_object.dict()
+        assert not isinstance(ret_object, list)
+
+        return ret_object.model_dump()
 
     async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
         """Extract pydantic program."""
diff --git a/llama-index-core/llama_index/core/indices/common_tree/base.py b/llama-index-core/llama_index/core/indices/common_tree/base.py
index 258f4996a6..0cb075b1ac 100644
--- a/llama-index-core/llama_index/core/indices/common_tree/base.py
+++ b/llama-index-core/llama_index/core/indices/common_tree/base.py
@@ -219,8 +219,7 @@ class GPTTreeIndexBuilder:
                 self._llm.apredict(self.summary_prompt, context_str=text_chunk)
                 for text_chunk in text_chunks_progress
             ]
-            outputs: List[Tuple[str, str]] = await asyncio.gather(*tasks)
-            summaries = [output[0] for output in outputs]
+            summaries = await asyncio.gather(*tasks)
 
             event.on_end(payload={"summaries": summaries, "level": level})
 
diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py
index 6ad38eafb1..a00dcbb41d 100644
--- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py
+++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py
@@ -2,7 +2,7 @@ import asyncio
 from functools import partial
 from contextlib import contextmanager
 from contextvars import Context, ContextVar, Token, copy_context
-from typing import Any, Callable, Generator, List, Optional, Dict, Protocol
+from typing import Any, Callable, Generator, List, Optional, Dict, Protocol, TypeVar
 import inspect
 import logging
 import uuid
@@ -26,6 +26,7 @@ _logger = logging.getLogger(__name__)
 active_instrument_tags: ContextVar[Dict[str, Any]] = ContextVar(
     "instrument_tags", default={}
 )
+_R = TypeVar("_R")
 
 
 @contextmanager
@@ -239,7 +240,7 @@ class Dispatcher(BaseModel):
             else:
                 c = c.parent
 
-    def span(self, func: Callable) -> Any:
+    def span(self, func: Callable[..., _R]) -> Callable[..., _R]:
         # The `span` decorator should be idempotent.
         try:
             if hasattr(func, DISPATCHER_SPAN_DECORATED_ATTR):
diff --git a/llama-index-core/llama_index/core/llms/llm.py b/llama-index-core/llama_index/core/llms/llm.py
index 1b9d139c52..3feee7a0cf 100644
--- a/llama-index-core/llama_index/core/llms/llm.py
+++ b/llama-index-core/llama_index/core/llms/llm.py
@@ -73,6 +73,7 @@ dispatcher = instrument.get_dispatcher(__name__)
 
 if TYPE_CHECKING:
     from llama_index.core.chat_engine.types import AgentChatResponse
+    from llama_index.core.program.utils import FlexibleModel
     from llama_index.core.tools.types import BaseTool
     from llama_index.core.llms.structured_llm import StructuredLLM
 
@@ -322,11 +323,11 @@ class LLM(BaseLLM):
     @dispatcher.span
     def structured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> BaseModel:
+    ) -> Model:
         r"""Structured predict.
 
         Args:
@@ -372,17 +373,19 @@ class LLM(BaseLLM):
         )
 
         result = program(llm_kwargs=llm_kwargs, **prompt_args)
+        assert not isinstance(result, list)
+
         dispatcher.event(LLMStructuredPredictEndEvent(output=result))
         return result
 
     @dispatcher.span
     async def astructured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> BaseModel:
+    ) -> Model:
         r"""Async Structured predict.
 
         Args:
@@ -429,17 +432,19 @@ class LLM(BaseLLM):
         )
 
         result = await program.acall(llm_kwargs=llm_kwargs, **prompt_args)
+        assert not isinstance(result, list)
+
         dispatcher.event(LLMStructuredPredictEndEvent(output=result))
         return result
 
     @dispatcher.span
     def stream_structured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> Generator[Union[Model, List[Model]], None, None]:
+    ) -> Generator[Union[Model, "FlexibleModel"], None, None]:
         r"""Stream Structured predict.
 
         Args:
@@ -489,6 +494,7 @@ class LLM(BaseLLM):
         result = program.stream_call(llm_kwargs=llm_kwargs, **prompt_args)
         for r in result:
             dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
+            assert not isinstance(r, list)
             yield r
 
         dispatcher.event(LLMStructuredPredictEndEvent(output=r))
@@ -496,11 +502,11 @@ class LLM(BaseLLM):
     @dispatcher.span
     async def astream_structured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> AsyncGenerator[Union[Model, List[Model]], None]:
+    ) -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
         r"""Async Stream Structured predict.
 
         Args:
@@ -534,8 +540,10 @@ class LLM(BaseLLM):
             ```
         """
 
-        async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:
-            from llama_index.core.program.utils import get_program_for_llm
+        async def gen() -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
+            from llama_index.core.program.utils import (
+                get_program_for_llm,
+            )
 
             dispatcher.event(
                 LLMStructuredPredictStartEvent(
@@ -552,6 +560,7 @@ class LLM(BaseLLM):
             result = await program.astream_call(llm_kwargs=llm_kwargs, **prompt_args)
             async for r in result:
                 dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
+                assert not isinstance(r, list)
                 yield r
 
             dispatcher.event(LLMStructuredPredictEndEvent(output=r))
diff --git a/llama-index-core/llama_index/core/output_parsers/pydantic.py b/llama-index-core/llama_index/core/output_parsers/pydantic.py
index faf98ee351..4abe147f73 100644
--- a/llama-index-core/llama_index/core/output_parsers/pydantic.py
+++ b/llama-index-core/llama_index/core/output_parsers/pydantic.py
@@ -1,7 +1,7 @@
 """Pydantic output parser."""
 
 import json
-from typing import Any, List, Optional, Type
+from typing import Any, Generic, List, Optional, Type
 
 from llama_index.core.output_parsers.base import ChainableOutputParser
 from llama_index.core.output_parsers.utils import extract_json_str
@@ -15,7 +15,7 @@ Output a valid JSON object but do not repeat the schema.
 """
 
 
-class PydanticOutputParser(ChainableOutputParser):
+class PydanticOutputParser(ChainableOutputParser, Generic[Model]):
     """Pydantic Output Parser.
 
     Args:
@@ -36,7 +36,7 @@ class PydanticOutputParser(ChainableOutputParser):
 
     @property
     def output_cls(self) -> Type[Model]:
-        return self._output_cls  # type: ignore
+        return self._output_cls
 
     @property
     def format_string(self) -> str:
diff --git a/llama-index-core/llama_index/core/program/function_program.py b/llama-index-core/llama_index/core/program/function_program.py
index e4d071d67d..4d4e2e5b47 100644
--- a/llama-index-core/llama_index/core/program/function_program.py
+++ b/llama-index-core/llama_index/core/program/function_program.py
@@ -14,7 +14,6 @@ from typing import (
 )
 
 from llama_index.core.bridge.pydantic import (
-    BaseModel,
     ValidationError,
 )
 from llama_index.core.base.llms.types import ChatResponse
@@ -26,6 +25,7 @@ from llama_index.core.types import BasePydanticProgram, Model
 from llama_index.core.tools.function_tool import FunctionTool
 from llama_index.core.chat_engine.types import AgentChatResponse
 from llama_index.core.program.utils import (
+    FlexibleModel,
     process_streaming_objects,
     num_valid_fields,
 )
@@ -33,24 +33,6 @@ from llama_index.core.program.utils import (
 _logger = logging.getLogger(__name__)
 
 
-def _parse_tool_outputs(
-    agent_response: AgentChatResponse,
-    allow_parallel_tool_calls: bool = False,
-) -> Union[BaseModel, List[BaseModel]]:
-    """Parse tool outputs."""
-    outputs = [cast(BaseModel, s.raw_output) for s in agent_response.sources]
-    if allow_parallel_tool_calls:
-        return outputs
-    else:
-        if len(outputs) > 1:
-            _logger.warning(
-                "Multiple outputs found, returning first one. "
-                "If you want to return all outputs, set output_multiple=True."
-            )
-
-        return outputs[0]
-
-
 def get_function_tool(output_cls: Type[Model]) -> FunctionTool:
     """Get function tool."""
     schema = output_cls.model_json_schema()
@@ -58,7 +40,7 @@ def get_function_tool(output_cls: Type[Model]) -> FunctionTool:
 
     # NOTE: this does not specify the schema in the function signature,
     # so instead we'll directly provide it in the fn_schema in the ToolMetadata
-    def model_fn(**kwargs: Any) -> BaseModel:
+    def model_fn(**kwargs: Any) -> Model:
         """Model function."""
         return output_cls(**kwargs)
 
@@ -70,7 +52,7 @@ def get_function_tool(output_cls: Type[Model]) -> FunctionTool:
     )
 
 
-class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
+class FunctionCallingProgram(BasePydanticProgram[Model]):
     """Function Calling Program.
 
     Uses function calling LLMs to obtain a structured output.
@@ -122,7 +104,7 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
             prompt = PromptTemplate(prompt_template_str)
 
         return cls(
-            output_cls=output_cls,  # type: ignore
+            output_cls=output_cls,
             llm=llm,  # type: ignore
             prompt=cast(PromptTemplate, prompt),
             tool_choice=tool_choice,
@@ -131,7 +113,7 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
         )
 
     @property
-    def output_cls(self) -> Type[BaseModel]:
+    def output_cls(self) -> Type[Model]:
         return self._output_cls
 
     @property
@@ -147,7 +129,7 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
         *args: Any,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> BaseModel:
+    ) -> Union[Model, List[Model]]:
         llm_kwargs = llm_kwargs or {}
         tool = get_function_tool(self._output_cls)
 
@@ -161,17 +143,17 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
             allow_parallel_tool_calls=self._allow_parallel_tool_calls,
             **llm_kwargs,
         )
-        return _parse_tool_outputs(
+        return self._parse_tool_outputs(
             agent_response,
             allow_parallel_tool_calls=self._allow_parallel_tool_calls,
-        )  # type: ignore
+        )
 
     async def acall(
         self,
         *args: Any,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> BaseModel:
+    ) -> Union[Model, List[Model]]:
         llm_kwargs = llm_kwargs or {}
         tool = get_function_tool(self._output_cls)
 
@@ -182,16 +164,34 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
             allow_parallel_tool_calls=self._allow_parallel_tool_calls,
             **llm_kwargs,
         )
-        return _parse_tool_outputs(
+        return self._parse_tool_outputs(
             agent_response,
             allow_parallel_tool_calls=self._allow_parallel_tool_calls,
-        )  # type: ignore
+        )
+
+    def _parse_tool_outputs(
+        self,
+        agent_response: AgentChatResponse,
+        allow_parallel_tool_calls: bool = False,
+    ) -> Union[Model, List[Model]]:
+        """Parse tool outputs."""
+        outputs = [cast(Model, s.raw_output) for s in agent_response.sources]
+        if allow_parallel_tool_calls:
+            return outputs
+        else:
+            if len(outputs) > 1:
+                _logger.warning(
+                    "Multiple outputs found, returning first one. "
+                    "If you want to return all outputs, set output_multiple=True."
+                )
+
+            return outputs[0]
 
     def _process_objects(
         self,
         chat_response: ChatResponse,
-        output_cls: Type[BaseModel],
-        cur_objects: Optional[List[BaseModel]] = None,
+        output_cls: Type[Model],
+        cur_objects: Optional[List[Model]] = None,
     ) -> Union[Model, List[Model]]:
         """Process stream."""
         tool_calls = self._llm.get_tool_calls_from_response(
@@ -202,7 +202,7 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
         # TODO: change
         if len(tool_calls) == 0:
             # if no tool calls, return single blank output_class
-            return output_cls()  # type: ignore
+            return output_cls()
 
         tool_fn_args = [call.tool_kwargs for call in tool_calls]
         objects = [
@@ -222,22 +222,24 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
                 new_obj = self._output_cls.model_validate(obj.model_dump())
             except ValidationError as e:
                 _logger.warning(f"Failed to parse object: {e}")
-                new_obj = obj  # type: ignore
+                new_obj = obj
             new_cur_objects.append(new_obj)
 
         if self._allow_parallel_tool_calls:
-            return new_cur_objects  # type: ignore
+            return new_cur_objects
         else:
             if len(new_cur_objects) > 1:
                 _logger.warning(
                     "Multiple outputs found, returning first one. "
                     "If you want to return all outputs, set output_multiple=True."
                 )
-            return new_cur_objects[0]  # type: ignore
+            return new_cur_objects[0]
 
-    def stream_call(  # type: ignore
+    def stream_call(
         self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
-    ) -> Generator[Union[Model, List[Model]], None, None]:
+    ) -> Generator[
+        Union[Model, List[Model], FlexibleModel, List[FlexibleModel]], None, None
+    ]:
         """Stream object.
 
         Returns a generator returning partials of the same object
@@ -273,14 +275,16 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
                     llm=self._llm,
                 )
                 cur_objects = objects if isinstance(objects, list) else [objects]
-                yield objects  # type: ignore
+                yield objects
             except Exception as e:
                 _logger.warning(f"Failed to parse streaming response: {e}")
                 continue
 
-    async def astream_call(  # type: ignore
+    async def astream_call(
         self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
-    ) -> AsyncGenerator[Union[Model, List[Model]], None]:
+    ) -> AsyncGenerator[
+        Union[Model, List[Model], FlexibleModel, List[FlexibleModel]], None
+    ]:
         """Stream objects.
 
         Returns a generator returning partials of the same object
@@ -302,7 +306,9 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
             **(llm_kwargs or {}),
         )
 
-        async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:
+        async def gen() -> AsyncGenerator[
+            Union[Model, List[Model], FlexibleModel, List[FlexibleModel]], None
+        ]:
             cur_objects = None
             async for partial_resp in chat_response_gen:
                 try:
@@ -315,7 +321,7 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
                         llm=self._llm,
                     )
                     cur_objects = objects if isinstance(objects, list) else [objects]
-                    yield objects  # type: ignore
+                    yield objects
                 except Exception as e:
                     _logger.warning(f"Failed to parse streaming response: {e}")
                     continue
diff --git a/llama-index-core/llama_index/core/program/llm_program.py b/llama-index-core/llama_index/core/program/llm_program.py
index dc73b801f7..2e2b7e5284 100644
--- a/llama-index-core/llama_index/core/program/llm_program.py
+++ b/llama-index-core/llama_index/core/program/llm_program.py
@@ -1,14 +1,13 @@
 from typing import Any, Dict, Optional, Type, cast
 
-from llama_index.core.bridge.pydantic import BaseModel
 from llama_index.core.llms.llm import LLM
 from llama_index.core.output_parsers.pydantic import PydanticOutputParser
 from llama_index.core.prompts.base import BasePromptTemplate, PromptTemplate
 from llama_index.core.settings import Settings
-from llama_index.core.types import BaseOutputParser, BasePydanticProgram
+from llama_index.core.types import BaseOutputParser, BasePydanticProgram, Model
 
 
-class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]):
+class LLMTextCompletionProgram(BasePydanticProgram[Model]):
     """
     LLM Text Completion Program.
 
@@ -19,7 +18,7 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]):
     def __init__(
         self,
         output_parser: BaseOutputParser,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: BasePromptTemplate,
         llm: LLM,
         verbose: bool = False,
@@ -36,13 +35,13 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]):
     def from_defaults(
         cls,
         output_parser: Optional[BaseOutputParser] = None,
-        output_cls: Optional[Type[BaseModel]] = None,
+        output_cls: Optional[Type[Model]] = None,
         prompt_template_str: Optional[str] = None,
         prompt: Optional[BasePromptTemplate] = None,
         llm: Optional[LLM] = None,
         verbose: bool = False,
         **kwargs: Any,
-    ) -> "LLMTextCompletionProgram":
+    ) -> "LLMTextCompletionProgram[Model]":
         llm = llm or Settings.llm
         if prompt is None and prompt_template_str is None:
             raise ValueError("Must provide either prompt or prompt_template_str.")
@@ -69,7 +68,7 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]):
         )
 
     @property
-    def output_cls(self) -> Type[BaseModel]:
+    def output_cls(self) -> Type[Model]:
         return self._output_cls
 
     @property
@@ -85,7 +84,7 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]):
         llm_kwargs: Optional[Dict[str, Any]] = None,
         *args: Any,
         **kwargs: Any,
-    ) -> BaseModel:
+    ) -> Model:
         llm_kwargs = llm_kwargs or {}
         if self._llm.metadata.is_chat_model:
             messages = self._prompt.format_messages(llm=self._llm, **kwargs)
@@ -112,7 +111,7 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]):
         llm_kwargs: Optional[Dict[str, Any]] = None,
         *args: Any,
         **kwargs: Any,
-    ) -> BaseModel:
+    ) -> Model:
         llm_kwargs = llm_kwargs or {}
         if self._llm.metadata.is_chat_model:
             messages = self._prompt.format_messages(llm=self._llm, **kwargs)
diff --git a/llama-index-core/llama_index/core/program/utils.py b/llama-index-core/llama_index/core/program/utils.py
index 6a3a136331..8a4125027b 100644
--- a/llama-index-core/llama_index/core/program/utils.py
+++ b/llama-index-core/llama_index/core/program/utils.py
@@ -13,7 +13,7 @@ from llama_index.core.llms.llm import LLM, ToolSelection
 from llama_index.core.llms.function_calling import FunctionCallingLLM
 from llama_index.core.output_parsers.pydantic import PydanticOutputParser
 from llama_index.core.prompts.base import BasePromptTemplate
-from llama_index.core.types import BasePydanticProgram, PydanticProgramMode
+from llama_index.core.types import BasePydanticProgram, Model, PydanticProgramMode
 from llama_index.core.base.llms.types import ChatResponse
 
 _logger = logging.getLogger(__name__)
@@ -55,12 +55,12 @@ def create_list_model(base_cls: Type[BaseModel]) -> Type[BaseModel]:
 
 
 def get_program_for_llm(
-    output_cls: Type[BaseModel],
+    output_cls: Type[Model],
     prompt: BasePromptTemplate,
     llm: LLM,
     pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
     **kwargs: Any,
-) -> BasePydanticProgram:
+) -> BasePydanticProgram[Model]:
     """Get a program based on the compatible LLM."""
     if pydantic_program_mode == PydanticProgramMode.DEFAULT:
         if llm.metadata.is_function_calling_model:
@@ -161,12 +161,12 @@ def _repair_incomplete_json(json_str: str) -> str:
 
 def process_streaming_objects(
     chat_response: ChatResponse,
-    output_cls: Type[BaseModel],
-    cur_objects: Optional[Sequence[BaseModel]] = None,
+    output_cls: Type[Model],
+    cur_objects: Optional[Sequence[Model]] = None,
     allow_parallel_tool_calls: bool = False,
     flexible_mode: bool = True,
     llm: Optional[FunctionCallingLLM] = None,
-) -> Union[BaseModel, List[BaseModel]]:
+) -> Union[Model, List[Model], FlexibleModel, List[FlexibleModel]]:
     """Process streaming response into structured objects.
 
     Args:
diff --git a/llama-index-core/llama_index/core/query_engine/retry_source_query_engine.py b/llama-index-core/llama_index/core/query_engine/retry_source_query_engine.py
index bf29bcdca3..db59b02fd6 100644
--- a/llama-index-core/llama_index/core/query_engine/retry_source_query_engine.py
+++ b/llama-index-core/llama_index/core/query_engine/retry_source_query_engine.py
@@ -2,7 +2,11 @@ import logging
 from typing import Optional
 
 from llama_index.core.base.base_query_engine import BaseQueryEngine
-from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
+from llama_index.core.base.response.schema import (
+    AsyncStreamingResponse,
+    RESPONSE_TYPE,
+    Response,
+)
 from llama_index.core.callbacks.base import CallbackManager
 from llama_index.core.evaluation import BaseEvaluator
 from llama_index.core.indices.list.base import SummaryIndex
@@ -41,6 +45,7 @@ class RetrySourceQueryEngine(BaseQueryEngine):
 
     def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
         response = self._query_engine._query(query_bundle)
+        assert not isinstance(response, AsyncStreamingResponse)
         if self.max_retries <= 0:
             return response
         typed_response = (
diff --git a/llama-index-core/llama_index/core/query_engine/sql_join_query_engine.py b/llama-index-core/llama_index/core/query_engine/sql_join_query_engine.py
index 896eaea7c3..01db685086 100644
--- a/llama-index-core/llama_index/core/query_engine/sql_join_query_engine.py
+++ b/llama-index-core/llama_index/core/query_engine/sql_join_query_engine.py
@@ -283,7 +283,7 @@ class SQLJoinQueryEngine(BaseQueryEngine):
         logger.info(f"> query engine response: {other_response}")
 
         if self._streaming:
-            response_str = self._llm.stream(
+            response_gen = self._llm.stream(
                 self._sql_join_synthesis_prompt,
                 query_str=query_bundle.query_str,
                 sql_query_str=sql_query,
@@ -298,7 +298,7 @@ class SQLJoinQueryEngine(BaseQueryEngine):
             }
             source_nodes = other_response.source_nodes
             return StreamingResponse(
-                response_str,
+                response_gen,
                 metadata=response_metadata,
                 source_nodes=source_nodes,
             )
diff --git a/llama-index-core/llama_index/core/query_pipeline/query.py b/llama-index-core/llama_index/core/query_pipeline/query.py
index 3ba8b0563c..646ca96386 100644
--- a/llama-index-core/llama_index/core/query_pipeline/query.py
+++ b/llama-index-core/llama_index/core/query_pipeline/query.py
@@ -658,9 +658,10 @@ class QueryPipeline(QueryComponent):
                 CBEventType.QUERY,
                 payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)},
             ) as query_event:
-                return await self._arun_multi(
+                outputs, _ = await self._arun_multi(
                     module_input_dict, show_intermediates=True
                 )
+                return outputs
 
     def _get_root_key_and_kwargs(
         self, *args: Any, **kwargs: Any
diff --git a/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py b/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py
index efda62b2d9..05f40806f6 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py
@@ -102,7 +102,7 @@ class TreeSummarize(BaseSynthesizer):
         else:
             # summarize each chunk
             if self._output_cls is None:
-                tasks = [
+                str_tasks = [
                     self._llm.apredict(
                         summary_template,
                         context_str=text_chunk,
@@ -110,8 +110,9 @@ class TreeSummarize(BaseSynthesizer):
                     )
                     for text_chunk in text_chunks
                 ]
+                summaries = await asyncio.gather(*str_tasks)
             else:
-                tasks = [
+                model_tasks = [
                     self._llm.astructured_predict(
                         self._output_cls,
                         summary_template,
@@ -120,12 +121,8 @@ class TreeSummarize(BaseSynthesizer):
                     )
                     for text_chunk in text_chunks
                 ]
-
-            summary_responses = await asyncio.gather(*tasks)
-            if self._output_cls is not None:
-                summaries = [summary.model_dump_json() for summary in summary_responses]
-            else:
-                summaries = summary_responses
+                summary_models = await asyncio.gather(*model_tasks)
+                summaries = [summary.model_dump_json() for summary in summary_models]
 
             # recursively summarize the summaries
             return await self.aget_response(
diff --git a/llama-index-core/llama_index/core/tools/retriever_tool.py b/llama-index-core/llama_index/core/tools/retriever_tool.py
index 509b935cdc..03e40812ff 100644
--- a/llama-index-core/llama_index/core/tools/retriever_tool.py
+++ b/llama-index-core/llama_index/core/tools/retriever_tool.py
@@ -7,7 +7,13 @@ from llama_index.core.base.base_retriever import BaseRetriever
 
 if TYPE_CHECKING:
     from llama_index.core.langchain_helpers.agents.tools import LlamaIndexTool
-from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
+from llama_index.core.schema import (
+    MetadataMode,
+    Node,
+    NodeWithScore,
+    QueryBundle,
+    TextNode,
+)
 from llama_index.core.tools.types import AsyncBaseTool, ToolMetadata, ToolOutput
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 
@@ -80,6 +86,7 @@ class RetrieverTool(AsyncBaseTool):
         docs = self._apply_node_postprocessors(docs, QueryBundle(query_str))
         content = ""
         for doc in docs:
+            assert isinstance(doc.node, (Node, TextNode))
             node_copy = doc.node.model_copy()
             node_copy.text_template = "{metadata_str}\n{content}"
             node_copy.metadata_template = "{key} = {value}"
@@ -105,6 +112,7 @@ class RetrieverTool(AsyncBaseTool):
         content = ""
         docs = self._apply_node_postprocessors(docs, QueryBundle(query_str))
         for doc in docs:
+            assert isinstance(doc.node, (Node, TextNode))
             node_copy = doc.node.model_copy()
             node_copy.text_template = "{metadata_str}\n{content}"
             node_copy.metadata_template = "{key} = {value}"
diff --git a/llama-index-core/llama_index/core/types.py b/llama-index-core/llama_index/core/types.py
index 7d0fe2ece5..52e4b8dabc 100644
--- a/llama-index-core/llama_index/core/types.py
+++ b/llama-index-core/llama_index/core/types.py
@@ -12,6 +12,7 @@ from typing import (
     Generic,
     List,
     Optional,
+    TYPE_CHECKING,
     Tuple,
     Type,
     TypeVar,
@@ -33,6 +34,9 @@ TokenGen = Generator[str, None, None]
 TokenAsyncGen = AsyncGenerator[str, None]
 RESPONSE_TEXT_TYPE = Union[BaseModel, str, TokenGen, TokenAsyncGen]
 
+if TYPE_CHECKING:
+    from llama_index.core.program.utils import FlexibleModel
+
 
 # TODO: move into a `core` folder
 # NOTE: this is necessary to make it compatible with pydantic
@@ -108,20 +112,24 @@ class BasePydanticProgram(DispatcherSpanMixin, ABC, Generic[Model]):
         pass
 
     @abstractmethod
-    def __call__(self, *args: Any, **kwargs: Any) -> Model:
+    def __call__(self, *args: Any, **kwargs: Any) -> Union[Model, List[Model]]:
         pass
 
-    async def acall(self, *args: Any, **kwargs: Any) -> Model:
+    async def acall(self, *args: Any, **kwargs: Any) -> Union[Model, List[Model]]:
         return self(*args, **kwargs)
 
     def stream_call(
         self, *args: Any, **kwargs: Any
-    ) -> Generator[Union[Model, List[Model]], None, None]:
+    ) -> Generator[
+        Union[Model, List[Model], "FlexibleModel", List["FlexibleModel"]], None, None
+    ]:
         raise NotImplementedError("stream_call is not supported by default.")
 
     async def astream_call(
         self, *args: Any, **kwargs: Any
-    ) -> AsyncGenerator[Union[Model, List[Model]], None]:
+    ) -> AsyncGenerator[
+        Union[Model, List[Model], "FlexibleModel", List["FlexibleModel"]], None
+    ]:
         raise NotImplementedError("astream_call is not supported by default.")
 
 
diff --git a/llama-index-integrations/node_parser/llama-index-node-parser-docling/pyproject.toml b/llama-index-integrations/node_parser/llama-index-node-parser-docling/pyproject.toml
index 3b2d83c53b..7e38beddb4 100644
--- a/llama-index-integrations/node_parser/llama-index-node-parser-docling/pyproject.toml
+++ b/llama-index-integrations/node_parser/llama-index-node-parser-docling/pyproject.toml
@@ -30,12 +30,12 @@ license = "MIT"
 name = "llama-index-node-parser-docling"
 packages = [{include = "llama_index/"}]
 readme = "README.md"
-version = "0.3.0"
+version = "0.3.1"
 
 [tool.poetry.dependencies]
 python = "^3.10"
 llama-index-core = "^0.12.0"
-docling-core = "^2.2.0"
+docling-core = "^2.18.0"
 
 [tool.poetry.group.dev]
 
diff --git a/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py b/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py
index 9d7468e2ee..ef4ff7a351 100644
--- a/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py
+++ b/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py
@@ -34,6 +34,7 @@ out_get_nodes = {
                 "version": "1.0.0",
                 "doc_items": [
                     {
+                        "content_layer": "body",
                         "self_ref": "#/texts/0",
                         "parent": {"$ref": "#/body"},
                         "children": [],
@@ -75,6 +76,7 @@ out_get_nodes = {
                         "version": "1.0.0",
                         "doc_items": [
                             {
+                                "content_layer": "body",
                                 "self_ref": "#/texts/1",
                                 "parent": {"$ref": "#/body"},
                                 "children": [],
@@ -88,7 +90,7 @@ out_get_nodes = {
                             "filename": "sample.html",
                         },
                     },
-                    "hash": "0a8df027ead9e42831f12f8aa680afe5138436ecd58c32a6289212bc4d0a644a",
+                    "hash": "a6c3f2701d8f99dfe60b8fcfa3602c83412e92a234c747ea9b9554bf1894d484",
                     "class_name": "RelatedNodeInfo",
                 },
             },
@@ -110,6 +112,7 @@ out_get_nodes = {
                 "version": "1.0.0",
                 "doc_items": [
                     {
+                        "content_layer": "body",
                         "self_ref": "#/texts/1",
                         "parent": {"$ref": "#/body"},
                         "children": [],
@@ -151,6 +154,7 @@ out_get_nodes = {
                         "version": "1.0.0",
                         "doc_items": [
                             {
+                                "content_layer": "body",
                                 "self_ref": "#/texts/0",
                                 "parent": {"$ref": "#/body"},
                                 "children": [],
@@ -164,7 +168,7 @@ out_get_nodes = {
                             "filename": "sample.html",
                         },
                     },
-                    "hash": "fbfaa945f53349cff0ee00b81a8d3926ca76874fdaf3eac7888f41c5f6a74f0c",
+                    "hash": "0cc445f97c4273fde805655dea71ac576eca74a8210b8589d43fd96a8f79100a",
                     "class_name": "RelatedNodeInfo",
                 },
             },
@@ -192,6 +196,7 @@ out_parse_nodes = {
                 "version": "1.0.0",
                 "doc_items": [
                     {
+                        "content_layer": "body",
                         "self_ref": "#/texts/0",
                         "parent": {"$ref": "#/body"},
                         "children": [],
@@ -244,6 +249,7 @@ out_parse_nodes = {
                 "version": "1.0.0",
                 "doc_items": [
                     {
+                        "content_layer": "body",
                         "self_ref": "#/texts/1",
                         "parent": {"$ref": "#/body"},
                         "children": [],
-- 
GitLab