From e157ebb7e459a9990a86dc5a518cc18a765a109a Mon Sep 17 00:00:00 2001 From: GICodeWarrior <GICodeWarrior@gmail.com> Date: Sun, 16 Feb 2025 17:38:07 -0800 Subject: [PATCH] Retain return type from @dispatcher.span (#17817) --- .../core/chat_engine/condense_plus_context.py | 2 + .../core/evaluation/retrieval/evaluator.py | 4 +- .../core/extractors/metadata_extractors.py | 12 +-- .../core/indices/common_tree/base.py | 3 +- .../core/instrumentation/dispatcher.py | 5 +- llama-index-core/llama_index/core/llms/llm.py | 29 +++--- .../core/output_parsers/pydantic.py | 6 +- .../core/program/function_program.py | 90 ++++++++++--------- .../llama_index/core/program/llm_program.py | 17 ++-- .../llama_index/core/program/utils.py | 12 +-- .../query_engine/retry_source_query_engine.py | 7 +- .../query_engine/sql_join_query_engine.py | 4 +- .../llama_index/core/query_pipeline/query.py | 3 +- .../response_synthesizers/tree_summarize.py | 13 ++- .../llama_index/core/tools/retriever_tool.py | 10 ++- llama-index-core/llama_index/core/types.py | 16 +++- .../pyproject.toml | 4 +- .../tests/test_node_parser_docling.py | 10 ++- 18 files changed, 145 insertions(+), 102 deletions(-) diff --git a/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py b/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py index 8177ccb7df..86e93bcf8b 100644 --- a/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py +++ b/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py @@ -347,6 +347,7 @@ class CondensePlusContextChatEngine(BaseChatEngine): ) response = synthesizer.synthesize(message, context_nodes) + assert isinstance(response, StreamingResponse) def wrapped_gen(response: StreamingResponse) -> ChatResponseGen: full_response = "" @@ -405,6 +406,7 @@ class CondensePlusContextChatEngine(BaseChatEngine): ) response = await synthesizer.asynthesize(message, context_nodes) + assert isinstance(response, AsyncStreamingResponse) async def wrapped_gen(response: AsyncStreamingResponse) -> ChatResponseAsyncGen: full_response = "" diff --git a/llama-index-core/llama_index/core/evaluation/retrieval/evaluator.py b/llama-index-core/llama_index/core/evaluation/retrieval/evaluator.py index fd686df156..24e7793f84 100644 --- a/llama-index-core/llama_index/core/evaluation/retrieval/evaluator.py +++ b/llama-index-core/llama_index/core/evaluation/retrieval/evaluator.py @@ -45,7 +45,7 @@ class RetrieverEvaluator(BaseRetrievalEvaluator): return ( [node.node.node_id for node in retrieved_nodes], - [node.node.text for node in retrieved_nodes], + [node.text for node in retrieved_nodes], ) @@ -84,7 +84,7 @@ class MultiModalRetrieverEvaluator(BaseRetrievalEvaluator): node = scored_node.node if isinstance(node, ImageNode): image_nodes.append(node) - if node.text: + if isinstance(node, TextNode): text_nodes.append(node) if mode == "text": diff --git a/llama-index-core/llama_index/core/extractors/metadata_extractors.py b/llama-index-core/llama_index/core/extractors/metadata_extractors.py index f5ef8f70ad..e416411658 100644 --- a/llama-index-core/llama_index/core/extractors/metadata_extractors.py +++ b/llama-index-core/llama_index/core/extractors/metadata_extractors.py @@ -20,7 +20,7 @@ disambiguate the document or subsection from other similar documents or subsecti (similar with contrastive learning) """ -from typing import Any, Callable, Dict, List, Optional, Sequence, cast +from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, cast from llama_index.core.async_utils import 
DEFAULT_NUM_WORKERS, run_jobs from llama_index.core.bridge.pydantic import ( @@ -33,7 +33,7 @@ from llama_index.core.llms.llm import LLM from llama_index.core.prompts import PromptTemplate from llama_index.core.schema import BaseNode, TextNode from llama_index.core.settings import Settings -from llama_index.core.types import BasePydanticProgram +from llama_index.core.types import BasePydanticProgram, Model DEFAULT_TITLE_NODE_TEMPLATE = """\ Context: {context_str}. Give a title that summarizes all of \ @@ -462,7 +462,7 @@ Given the contextual information, extract out a {class_name} object.\ """ -class PydanticProgramExtractor(BaseExtractor): +class PydanticProgramExtractor(BaseExtractor, Generic[Model]): """Pydantic program extractor. Uses an LLM to extract out a Pydantic object. Return attributes of that object @@ -470,7 +470,7 @@ class PydanticProgramExtractor(BaseExtractor): """ - program: SerializeAsAny[BasePydanticProgram] = Field( + program: SerializeAsAny[BasePydanticProgram[Model]] = Field( ..., description="Pydantic program to extract." ) input_key: str = Field( @@ -500,7 +500,9 @@ class PydanticProgramExtractor(BaseExtractor): ) ret_object = await self.program.acall(**{self.input_key: extract_str}) - return ret_object.dict() + assert not isinstance(ret_object, list) + + return ret_object.model_dump() async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]: """Extract pydantic program.""" diff --git a/llama-index-core/llama_index/core/indices/common_tree/base.py b/llama-index-core/llama_index/core/indices/common_tree/base.py index 258f4996a6..0cb075b1ac 100644 --- a/llama-index-core/llama_index/core/indices/common_tree/base.py +++ b/llama-index-core/llama_index/core/indices/common_tree/base.py @@ -219,8 +219,7 @@ class GPTTreeIndexBuilder: self._llm.apredict(self.summary_prompt, context_str=text_chunk) for text_chunk in text_chunks_progress ] - outputs: List[Tuple[str, str]] = await asyncio.gather(*tasks) - summaries = [output[0] for output in outputs] + summaries = await asyncio.gather(*tasks) event.on_end(payload={"summaries": summaries, "level": level}) diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py index 6ad38eafb1..a00dcbb41d 100644 --- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py +++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py @@ -2,7 +2,7 @@ import asyncio from functools import partial from contextlib import contextmanager from contextvars import Context, ContextVar, Token, copy_context -from typing import Any, Callable, Generator, List, Optional, Dict, Protocol +from typing import Any, Callable, Generator, List, Optional, Dict, Protocol, TypeVar import inspect import logging import uuid @@ -26,6 +26,7 @@ _logger = logging.getLogger(__name__) active_instrument_tags: ContextVar[Dict[str, Any]] = ContextVar( "instrument_tags", default={} ) +_R = TypeVar("_R") @contextmanager @@ -239,7 +240,7 @@ class Dispatcher(BaseModel): else: c = c.parent - def span(self, func: Callable) -> Any: + def span(self, func: Callable[..., _R]) -> Callable[..., _R]: # The `span` decorator should be idempotent. 
try: if hasattr(func, DISPATCHER_SPAN_DECORATED_ATTR): diff --git a/llama-index-core/llama_index/core/llms/llm.py b/llama-index-core/llama_index/core/llms/llm.py index 1b9d139c52..3feee7a0cf 100644 --- a/llama-index-core/llama_index/core/llms/llm.py +++ b/llama-index-core/llama_index/core/llms/llm.py @@ -73,6 +73,7 @@ dispatcher = instrument.get_dispatcher(__name__) if TYPE_CHECKING: from llama_index.core.chat_engine.types import AgentChatResponse + from llama_index.core.program.utils import FlexibleModel from llama_index.core.tools.types import BaseTool from llama_index.core.llms.structured_llm import StructuredLLM @@ -322,11 +323,11 @@ class LLM(BaseLLM): @dispatcher.span def structured_predict( self, - output_cls: Type[BaseModel], + output_cls: Type[Model], prompt: PromptTemplate, llm_kwargs: Optional[Dict[str, Any]] = None, **prompt_args: Any, - ) -> BaseModel: + ) -> Model: r"""Structured predict. Args: @@ -372,17 +373,19 @@ class LLM(BaseLLM): ) result = program(llm_kwargs=llm_kwargs, **prompt_args) + assert not isinstance(result, list) + dispatcher.event(LLMStructuredPredictEndEvent(output=result)) return result @dispatcher.span async def astructured_predict( self, - output_cls: Type[BaseModel], + output_cls: Type[Model], prompt: PromptTemplate, llm_kwargs: Optional[Dict[str, Any]] = None, **prompt_args: Any, - ) -> BaseModel: + ) -> Model: r"""Async Structured predict. Args: @@ -429,17 +432,19 @@ class LLM(BaseLLM): ) result = await program.acall(llm_kwargs=llm_kwargs, **prompt_args) + assert not isinstance(result, list) + dispatcher.event(LLMStructuredPredictEndEvent(output=result)) return result @dispatcher.span def stream_structured_predict( self, - output_cls: Type[BaseModel], + output_cls: Type[Model], prompt: PromptTemplate, llm_kwargs: Optional[Dict[str, Any]] = None, **prompt_args: Any, - ) -> Generator[Union[Model, List[Model]], None, None]: + ) -> Generator[Union[Model, "FlexibleModel"], None, None]: r"""Stream Structured predict. Args: @@ -489,6 +494,7 @@ class LLM(BaseLLM): result = program.stream_call(llm_kwargs=llm_kwargs, **prompt_args) for r in result: dispatcher.event(LLMStructuredPredictInProgressEvent(output=r)) + assert not isinstance(r, list) yield r dispatcher.event(LLMStructuredPredictEndEvent(output=r)) @@ -496,11 +502,11 @@ class LLM(BaseLLM): @dispatcher.span async def astream_structured_predict( self, - output_cls: Type[BaseModel], + output_cls: Type[Model], prompt: PromptTemplate, llm_kwargs: Optional[Dict[str, Any]] = None, **prompt_args: Any, - ) -> AsyncGenerator[Union[Model, List[Model]], None]: + ) -> AsyncGenerator[Union[Model, "FlexibleModel"], None]: r"""Async Stream Structured predict. 
Args: @@ -534,8 +540,10 @@ class LLM(BaseLLM): ``` """ - async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]: - from llama_index.core.program.utils import get_program_for_llm + async def gen() -> AsyncGenerator[Union[Model, "FlexibleModel"], None]: + from llama_index.core.program.utils import ( + get_program_for_llm, + ) dispatcher.event( LLMStructuredPredictStartEvent( @@ -552,6 +560,7 @@ class LLM(BaseLLM): result = await program.astream_call(llm_kwargs=llm_kwargs, **prompt_args) async for r in result: dispatcher.event(LLMStructuredPredictInProgressEvent(output=r)) + assert not isinstance(r, list) yield r dispatcher.event(LLMStructuredPredictEndEvent(output=r)) diff --git a/llama-index-core/llama_index/core/output_parsers/pydantic.py b/llama-index-core/llama_index/core/output_parsers/pydantic.py index faf98ee351..4abe147f73 100644 --- a/llama-index-core/llama_index/core/output_parsers/pydantic.py +++ b/llama-index-core/llama_index/core/output_parsers/pydantic.py @@ -1,7 +1,7 @@ """Pydantic output parser.""" import json -from typing import Any, List, Optional, Type +from typing import Any, Generic, List, Optional, Type from llama_index.core.output_parsers.base import ChainableOutputParser from llama_index.core.output_parsers.utils import extract_json_str @@ -15,7 +15,7 @@ Output a valid JSON object but do not repeat the schema. """ -class PydanticOutputParser(ChainableOutputParser): +class PydanticOutputParser(ChainableOutputParser, Generic[Model]): """Pydantic Output Parser. Args: @@ -36,7 +36,7 @@ class PydanticOutputParser(ChainableOutputParser): @property def output_cls(self) -> Type[Model]: - return self._output_cls # type: ignore + return self._output_cls @property def format_string(self) -> str: diff --git a/llama-index-core/llama_index/core/program/function_program.py b/llama-index-core/llama_index/core/program/function_program.py index e4d071d67d..4d4e2e5b47 100644 --- a/llama-index-core/llama_index/core/program/function_program.py +++ b/llama-index-core/llama_index/core/program/function_program.py @@ -14,7 +14,6 @@ from typing import ( ) from llama_index.core.bridge.pydantic import ( - BaseModel, ValidationError, ) from llama_index.core.base.llms.types import ChatResponse @@ -26,6 +25,7 @@ from llama_index.core.types import BasePydanticProgram, Model from llama_index.core.tools.function_tool import FunctionTool from llama_index.core.chat_engine.types import AgentChatResponse from llama_index.core.program.utils import ( + FlexibleModel, process_streaming_objects, num_valid_fields, ) @@ -33,24 +33,6 @@ from llama_index.core.program.utils import ( _logger = logging.getLogger(__name__) -def _parse_tool_outputs( - agent_response: AgentChatResponse, - allow_parallel_tool_calls: bool = False, -) -> Union[BaseModel, List[BaseModel]]: - """Parse tool outputs.""" - outputs = [cast(BaseModel, s.raw_output) for s in agent_response.sources] - if allow_parallel_tool_calls: - return outputs - else: - if len(outputs) > 1: - _logger.warning( - "Multiple outputs found, returning first one. " - "If you want to return all outputs, set output_multiple=True." 
- ) - - return outputs[0] - - def get_function_tool(output_cls: Type[Model]) -> FunctionTool: """Get function tool.""" schema = output_cls.model_json_schema() @@ -58,7 +40,7 @@ def get_function_tool(output_cls: Type[Model]) -> FunctionTool: # NOTE: this does not specify the schema in the function signature, # so instead we'll directly provide it in the fn_schema in the ToolMetadata - def model_fn(**kwargs: Any) -> BaseModel: + def model_fn(**kwargs: Any) -> Model: """Model function.""" return output_cls(**kwargs) @@ -70,7 +52,7 @@ def get_function_tool(output_cls: Type[Model]) -> FunctionTool: ) -class FunctionCallingProgram(BasePydanticProgram[BaseModel]): +class FunctionCallingProgram(BasePydanticProgram[Model]): """Function Calling Program. Uses function calling LLMs to obtain a structured output. @@ -122,7 +104,7 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]): prompt = PromptTemplate(prompt_template_str) return cls( - output_cls=output_cls, # type: ignore + output_cls=output_cls, llm=llm, # type: ignore prompt=cast(PromptTemplate, prompt), tool_choice=tool_choice, @@ -131,7 +113,7 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]): ) @property - def output_cls(self) -> Type[BaseModel]: + def output_cls(self) -> Type[Model]: return self._output_cls @property @@ -147,7 +129,7 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]): *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> BaseModel: + ) -> Union[Model, List[Model]]: llm_kwargs = llm_kwargs or {} tool = get_function_tool(self._output_cls) @@ -161,17 +143,17 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]): allow_parallel_tool_calls=self._allow_parallel_tool_calls, **llm_kwargs, ) - return _parse_tool_outputs( + return self._parse_tool_outputs( agent_response, allow_parallel_tool_calls=self._allow_parallel_tool_calls, - ) # type: ignore + ) async def acall( self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> BaseModel: + ) -> Union[Model, List[Model]]: llm_kwargs = llm_kwargs or {} tool = get_function_tool(self._output_cls) @@ -182,16 +164,34 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]): allow_parallel_tool_calls=self._allow_parallel_tool_calls, **llm_kwargs, ) - return _parse_tool_outputs( + return self._parse_tool_outputs( agent_response, allow_parallel_tool_calls=self._allow_parallel_tool_calls, - ) # type: ignore + ) + + def _parse_tool_outputs( + self, + agent_response: AgentChatResponse, + allow_parallel_tool_calls: bool = False, + ) -> Union[Model, List[Model]]: + """Parse tool outputs.""" + outputs = [cast(Model, s.raw_output) for s in agent_response.sources] + if allow_parallel_tool_calls: + return outputs + else: + if len(outputs) > 1: + _logger.warning( + "Multiple outputs found, returning first one. " + "If you want to return all outputs, set output_multiple=True." 
+ ) + + return outputs[0] def _process_objects( self, chat_response: ChatResponse, - output_cls: Type[BaseModel], - cur_objects: Optional[List[BaseModel]] = None, + output_cls: Type[Model], + cur_objects: Optional[List[Model]] = None, ) -> Union[Model, List[Model]]: """Process stream.""" tool_calls = self._llm.get_tool_calls_from_response( @@ -202,7 +202,7 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]): # TODO: change if len(tool_calls) == 0: # if no tool calls, return single blank output_class - return output_cls() # type: ignore + return output_cls() tool_fn_args = [call.tool_kwargs for call in tool_calls] objects = [ @@ -222,22 +222,24 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]): new_obj = self._output_cls.model_validate(obj.model_dump()) except ValidationError as e: _logger.warning(f"Failed to parse object: {e}") - new_obj = obj # type: ignore + new_obj = obj new_cur_objects.append(new_obj) if self._allow_parallel_tool_calls: - return new_cur_objects # type: ignore + return new_cur_objects else: if len(new_cur_objects) > 1: _logger.warning( "Multiple outputs found, returning first one. " "If you want to return all outputs, set output_multiple=True." ) - return new_cur_objects[0] # type: ignore + return new_cur_objects[0] - def stream_call( # type: ignore + def stream_call( self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> Generator[Union[Model, List[Model]], None, None]: + ) -> Generator[ + Union[Model, List[Model], FlexibleModel, List[FlexibleModel]], None, None + ]: """Stream object. Returns a generator returning partials of the same object @@ -273,14 +275,16 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]): llm=self._llm, ) cur_objects = objects if isinstance(objects, list) else [objects] - yield objects # type: ignore + yield objects except Exception as e: _logger.warning(f"Failed to parse streaming response: {e}") continue - async def astream_call( # type: ignore + async def astream_call( self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> AsyncGenerator[Union[Model, List[Model]], None]: + ) -> AsyncGenerator[ + Union[Model, List[Model], FlexibleModel, List[FlexibleModel]], None + ]: """Stream objects. 
Returns a generator returning partials of the same object @@ -302,7 +306,9 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]): **(llm_kwargs or {}), ) - async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]: + async def gen() -> AsyncGenerator[ + Union[Model, List[Model], FlexibleModel, List[FlexibleModel]], None + ]: cur_objects = None async for partial_resp in chat_response_gen: try: @@ -315,7 +321,7 @@ class FunctionCallingProgram(BasePydanticProgram[BaseModel]): llm=self._llm, ) cur_objects = objects if isinstance(objects, list) else [objects] - yield objects # type: ignore + yield objects except Exception as e: _logger.warning(f"Failed to parse streaming response: {e}") continue diff --git a/llama-index-core/llama_index/core/program/llm_program.py b/llama-index-core/llama_index/core/program/llm_program.py index dc73b801f7..2e2b7e5284 100644 --- a/llama-index-core/llama_index/core/program/llm_program.py +++ b/llama-index-core/llama_index/core/program/llm_program.py @@ -1,14 +1,13 @@ from typing import Any, Dict, Optional, Type, cast -from llama_index.core.bridge.pydantic import BaseModel from llama_index.core.llms.llm import LLM from llama_index.core.output_parsers.pydantic import PydanticOutputParser from llama_index.core.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.core.settings import Settings -from llama_index.core.types import BaseOutputParser, BasePydanticProgram +from llama_index.core.types import BaseOutputParser, BasePydanticProgram, Model -class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]): +class LLMTextCompletionProgram(BasePydanticProgram[Model]): """ LLM Text Completion Program. @@ -19,7 +18,7 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]): def __init__( self, output_parser: BaseOutputParser, - output_cls: Type[BaseModel], + output_cls: Type[Model], prompt: BasePromptTemplate, llm: LLM, verbose: bool = False, @@ -36,13 +35,13 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]): def from_defaults( cls, output_parser: Optional[BaseOutputParser] = None, - output_cls: Optional[Type[BaseModel]] = None, + output_cls: Optional[Type[Model]] = None, prompt_template_str: Optional[str] = None, prompt: Optional[BasePromptTemplate] = None, llm: Optional[LLM] = None, verbose: bool = False, **kwargs: Any, - ) -> "LLMTextCompletionProgram": + ) -> "LLMTextCompletionProgram[Model]": llm = llm or Settings.llm if prompt is None and prompt_template_str is None: raise ValueError("Must provide either prompt or prompt_template_str.") @@ -69,7 +68,7 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]): ) @property - def output_cls(self) -> Type[BaseModel]: + def output_cls(self) -> Type[Model]: return self._output_cls @property @@ -85,7 +84,7 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]): llm_kwargs: Optional[Dict[str, Any]] = None, *args: Any, **kwargs: Any, - ) -> BaseModel: + ) -> Model: llm_kwargs = llm_kwargs or {} if self._llm.metadata.is_chat_model: messages = self._prompt.format_messages(llm=self._llm, **kwargs) @@ -112,7 +111,7 @@ class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]): llm_kwargs: Optional[Dict[str, Any]] = None, *args: Any, **kwargs: Any, - ) -> BaseModel: + ) -> Model: llm_kwargs = llm_kwargs or {} if self._llm.metadata.is_chat_model: messages = self._prompt.format_messages(llm=self._llm, **kwargs) diff --git a/llama-index-core/llama_index/core/program/utils.py b/llama-index-core/llama_index/core/program/utils.py index 
6a3a136331..8a4125027b 100644 --- a/llama-index-core/llama_index/core/program/utils.py +++ b/llama-index-core/llama_index/core/program/utils.py @@ -13,7 +13,7 @@ from llama_index.core.llms.llm import LLM, ToolSelection from llama_index.core.llms.function_calling import FunctionCallingLLM from llama_index.core.output_parsers.pydantic import PydanticOutputParser from llama_index.core.prompts.base import BasePromptTemplate -from llama_index.core.types import BasePydanticProgram, PydanticProgramMode +from llama_index.core.types import BasePydanticProgram, Model, PydanticProgramMode from llama_index.core.base.llms.types import ChatResponse _logger = logging.getLogger(__name__) @@ -55,12 +55,12 @@ def create_list_model(base_cls: Type[BaseModel]) -> Type[BaseModel]: def get_program_for_llm( - output_cls: Type[BaseModel], + output_cls: Type[Model], prompt: BasePromptTemplate, llm: LLM, pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, **kwargs: Any, -) -> BasePydanticProgram: +) -> BasePydanticProgram[Model]: """Get a program based on the compatible LLM.""" if pydantic_program_mode == PydanticProgramMode.DEFAULT: if llm.metadata.is_function_calling_model: @@ -161,12 +161,12 @@ def _repair_incomplete_json(json_str: str) -> str: def process_streaming_objects( chat_response: ChatResponse, - output_cls: Type[BaseModel], - cur_objects: Optional[Sequence[BaseModel]] = None, + output_cls: Type[Model], + cur_objects: Optional[Sequence[Model]] = None, allow_parallel_tool_calls: bool = False, flexible_mode: bool = True, llm: Optional[FunctionCallingLLM] = None, -) -> Union[BaseModel, List[BaseModel]]: +) -> Union[Model, List[Model], FlexibleModel, List[FlexibleModel]]: """Process streaming response into structured objects. Args: diff --git a/llama-index-core/llama_index/core/query_engine/retry_source_query_engine.py b/llama-index-core/llama_index/core/query_engine/retry_source_query_engine.py index bf29bcdca3..db59b02fd6 100644 --- a/llama-index-core/llama_index/core/query_engine/retry_source_query_engine.py +++ b/llama-index-core/llama_index/core/query_engine/retry_source_query_engine.py @@ -2,7 +2,11 @@ import logging from typing import Optional from llama_index.core.base.base_query_engine import BaseQueryEngine -from llama_index.core.base.response.schema import RESPONSE_TYPE, Response +from llama_index.core.base.response.schema import ( + AsyncStreamingResponse, + RESPONSE_TYPE, + Response, +) from llama_index.core.callbacks.base import CallbackManager from llama_index.core.evaluation import BaseEvaluator from llama_index.core.indices.list.base import SummaryIndex @@ -41,6 +45,7 @@ class RetrySourceQueryEngine(BaseQueryEngine): def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: response = self._query_engine._query(query_bundle) + assert not isinstance(response, AsyncStreamingResponse) if self.max_retries <= 0: return response typed_response = ( diff --git a/llama-index-core/llama_index/core/query_engine/sql_join_query_engine.py b/llama-index-core/llama_index/core/query_engine/sql_join_query_engine.py index 896eaea7c3..01db685086 100644 --- a/llama-index-core/llama_index/core/query_engine/sql_join_query_engine.py +++ b/llama-index-core/llama_index/core/query_engine/sql_join_query_engine.py @@ -283,7 +283,7 @@ class SQLJoinQueryEngine(BaseQueryEngine): logger.info(f"> query engine response: {other_response}") if self._streaming: - response_str = self._llm.stream( + response_gen = self._llm.stream( self._sql_join_synthesis_prompt, query_str=query_bundle.query_str, 
sql_query_str=sql_query, @@ -298,7 +298,7 @@ class SQLJoinQueryEngine(BaseQueryEngine): } source_nodes = other_response.source_nodes return StreamingResponse( - response_str, + response_gen, metadata=response_metadata, source_nodes=source_nodes, ) diff --git a/llama-index-core/llama_index/core/query_pipeline/query.py b/llama-index-core/llama_index/core/query_pipeline/query.py index 3ba8b0563c..646ca96386 100644 --- a/llama-index-core/llama_index/core/query_pipeline/query.py +++ b/llama-index-core/llama_index/core/query_pipeline/query.py @@ -658,9 +658,10 @@ class QueryPipeline(QueryComponent): CBEventType.QUERY, payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)}, ) as query_event: - return await self._arun_multi( + outputs, _ = await self._arun_multi( module_input_dict, show_intermediates=True ) + return outputs def _get_root_key_and_kwargs( self, *args: Any, **kwargs: Any diff --git a/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py b/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py index efda62b2d9..05f40806f6 100644 --- a/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py +++ b/llama-index-core/llama_index/core/response_synthesizers/tree_summarize.py @@ -102,7 +102,7 @@ class TreeSummarize(BaseSynthesizer): else: # summarize each chunk if self._output_cls is None: - tasks = [ + str_tasks = [ self._llm.apredict( summary_template, context_str=text_chunk, @@ -110,8 +110,9 @@ class TreeSummarize(BaseSynthesizer): ) for text_chunk in text_chunks ] + summaries = await asyncio.gather(*str_tasks) else: - tasks = [ + model_tasks = [ self._llm.astructured_predict( self._output_cls, summary_template, @@ -120,12 +121,8 @@ class TreeSummarize(BaseSynthesizer): ) for text_chunk in text_chunks ] - - summary_responses = await asyncio.gather(*tasks) - if self._output_cls is not None: - summaries = [summary.model_dump_json() for summary in summary_responses] - else: - summaries = summary_responses + summary_models = await asyncio.gather(*model_tasks) + summaries = [summary.model_dump_json() for summary in summary_models] # recursively summarize the summaries return await self.aget_response( diff --git a/llama-index-core/llama_index/core/tools/retriever_tool.py b/llama-index-core/llama_index/core/tools/retriever_tool.py index 509b935cdc..03e40812ff 100644 --- a/llama-index-core/llama_index/core/tools/retriever_tool.py +++ b/llama-index-core/llama_index/core/tools/retriever_tool.py @@ -7,7 +7,13 @@ from llama_index.core.base.base_retriever import BaseRetriever if TYPE_CHECKING: from llama_index.core.langchain_helpers.agents.tools import LlamaIndexTool -from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle +from llama_index.core.schema import ( + MetadataMode, + Node, + NodeWithScore, + QueryBundle, + TextNode, +) from llama_index.core.tools.types import AsyncBaseTool, ToolMetadata, ToolOutput from llama_index.core.postprocessor.types import BaseNodePostprocessor @@ -80,6 +86,7 @@ class RetrieverTool(AsyncBaseTool): docs = self._apply_node_postprocessors(docs, QueryBundle(query_str)) content = "" for doc in docs: + assert isinstance(doc.node, (Node, TextNode)) node_copy = doc.node.model_copy() node_copy.text_template = "{metadata_str}\n{content}" node_copy.metadata_template = "{key} = {value}" @@ -105,6 +112,7 @@ class RetrieverTool(AsyncBaseTool): content = "" docs = self._apply_node_postprocessors(docs, QueryBundle(query_str)) for doc in docs: + assert isinstance(doc.node, (Node, TextNode)) 
node_copy = doc.node.model_copy() node_copy.text_template = "{metadata_str}\n{content}" node_copy.metadata_template = "{key} = {value}" diff --git a/llama-index-core/llama_index/core/types.py b/llama-index-core/llama_index/core/types.py index 7d0fe2ece5..52e4b8dabc 100644 --- a/llama-index-core/llama_index/core/types.py +++ b/llama-index-core/llama_index/core/types.py @@ -12,6 +12,7 @@ from typing import ( Generic, List, Optional, + TYPE_CHECKING, Tuple, Type, TypeVar, @@ -33,6 +34,9 @@ TokenGen = Generator[str, None, None] TokenAsyncGen = AsyncGenerator[str, None] RESPONSE_TEXT_TYPE = Union[BaseModel, str, TokenGen, TokenAsyncGen] +if TYPE_CHECKING: + from llama_index.core.program.utils import FlexibleModel + # TODO: move into a `core` folder # NOTE: this is necessary to make it compatible with pydantic @@ -108,20 +112,24 @@ class BasePydanticProgram(DispatcherSpanMixin, ABC, Generic[Model]): pass @abstractmethod - def __call__(self, *args: Any, **kwargs: Any) -> Model: + def __call__(self, *args: Any, **kwargs: Any) -> Union[Model, List[Model]]: pass - async def acall(self, *args: Any, **kwargs: Any) -> Model: + async def acall(self, *args: Any, **kwargs: Any) -> Union[Model, List[Model]]: return self(*args, **kwargs) def stream_call( self, *args: Any, **kwargs: Any - ) -> Generator[Union[Model, List[Model]], None, None]: + ) -> Generator[ + Union[Model, List[Model], "FlexibleModel", List["FlexibleModel"]], None, None + ]: raise NotImplementedError("stream_call is not supported by default.") async def astream_call( self, *args: Any, **kwargs: Any - ) -> AsyncGenerator[Union[Model, List[Model]], None]: + ) -> AsyncGenerator[ + Union[Model, List[Model], "FlexibleModel", List["FlexibleModel"]], None + ]: raise NotImplementedError("astream_call is not supported by default.") diff --git a/llama-index-integrations/node_parser/llama-index-node-parser-docling/pyproject.toml b/llama-index-integrations/node_parser/llama-index-node-parser-docling/pyproject.toml index 3b2d83c53b..7e38beddb4 100644 --- a/llama-index-integrations/node_parser/llama-index-node-parser-docling/pyproject.toml +++ b/llama-index-integrations/node_parser/llama-index-node-parser-docling/pyproject.toml @@ -30,12 +30,12 @@ license = "MIT" name = "llama-index-node-parser-docling" packages = [{include = "llama_index/"}] readme = "README.md" -version = "0.3.0" +version = "0.3.1" [tool.poetry.dependencies] python = "^3.10" llama-index-core = "^0.12.0" -docling-core = "^2.2.0" +docling-core = "^2.18.0" [tool.poetry.group.dev] diff --git a/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py b/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py index 9d7468e2ee..ef4ff7a351 100644 --- a/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py +++ b/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py @@ -34,6 +34,7 @@ out_get_nodes = { "version": "1.0.0", "doc_items": [ { + "content_layer": "body", "self_ref": "#/texts/0", "parent": {"$ref": "#/body"}, "children": [], @@ -75,6 +76,7 @@ out_get_nodes = { "version": "1.0.0", "doc_items": [ { + "content_layer": "body", "self_ref": "#/texts/1", "parent": {"$ref": "#/body"}, "children": [], @@ -88,7 +90,7 @@ out_get_nodes = { "filename": "sample.html", }, }, - "hash": "0a8df027ead9e42831f12f8aa680afe5138436ecd58c32a6289212bc4d0a644a", + "hash": 
"a6c3f2701d8f99dfe60b8fcfa3602c83412e92a234c747ea9b9554bf1894d484", "class_name": "RelatedNodeInfo", }, }, @@ -110,6 +112,7 @@ out_get_nodes = { "version": "1.0.0", "doc_items": [ { + "content_layer": "body", "self_ref": "#/texts/1", "parent": {"$ref": "#/body"}, "children": [], @@ -151,6 +154,7 @@ out_get_nodes = { "version": "1.0.0", "doc_items": [ { + "content_layer": "body", "self_ref": "#/texts/0", "parent": {"$ref": "#/body"}, "children": [], @@ -164,7 +168,7 @@ out_get_nodes = { "filename": "sample.html", }, }, - "hash": "fbfaa945f53349cff0ee00b81a8d3926ca76874fdaf3eac7888f41c5f6a74f0c", + "hash": "0cc445f97c4273fde805655dea71ac576eca74a8210b8589d43fd96a8f79100a", "class_name": "RelatedNodeInfo", }, }, @@ -192,6 +196,7 @@ out_parse_nodes = { "version": "1.0.0", "doc_items": [ { + "content_layer": "body", "self_ref": "#/texts/0", "parent": {"$ref": "#/body"}, "children": [], @@ -244,6 +249,7 @@ out_parse_nodes = { "version": "1.0.0", "doc_items": [ { + "content_layer": "body", "self_ref": "#/texts/1", "parent": {"$ref": "#/body"}, "children": [], -- GitLab