From 01df13f4a37e8cc29497ce41b1abed0ab8283f9b Mon Sep 17 00:00:00 2001 From: Ryan Nguyen <96593302+xpbowler@users.noreply.github.com> Date: Thu, 16 May 2024 22:33:59 -0400 Subject: [PATCH] Add batch mode to QueryPipeline (#13203) --- .../querying/pipeline/usage_pattern.md | 22 ++ .../llama_index/core/query_pipeline/query.py | 197 ++++++++++++++++-- .../tests/query_pipeline/test_query.py | 56 +++++ 3 files changed, 256 insertions(+), 19 deletions(-) diff --git a/docs/docs/module_guides/querying/pipeline/usage_pattern.md b/docs/docs/module_guides/querying/pipeline/usage_pattern.md index d2341821a..bd2aac0f6 100644 --- a/docs/docs/module_guides/querying/pipeline/usage_pattern.md +++ b/docs/docs/module_guides/querying/pipeline/usage_pattern.md @@ -113,6 +113,28 @@ p = QueryPipeline(chain=[prompt_tmpl, llm, summarizer_c]) p.run(topic="YC") ``` +### Batch Input + +If you wish to run the pipeline for several rounds of single/multi-inputs, set `batch=True` in the function call - supported by `run`, `arun`, `run_multi`, and `arun_multi`. Pass in a list of individual single/multi-inputs you would like to run. `batch` mode will return a list of responses in the same order as the inputs. 
+
+Example for single-input/single-output: `p.run(field=[in1: Any, in2: Any], batch=True)` --> `[out1: Any, out2: Any]`
+
+```python
+output = p.run(topic=["YC", "RAG", "LlamaIndex"], batch=True)
+# output is [ResponseYC, ResponseRAG, ResponseLlamaIndex]
+print(output)
+```
+
+Example for multi-input/multi-output: `p.run_multi({"root_node": {"field": [in1: Any, in2: Any]}}, batch=True)` --> `{"output_node": {"field": [out1: Any, out2: Any]}}`
+
+```python
+output_dict = p.run_multi({"llm": {"topic": ["YC", "RAG", "LlamaIndex"]}}, batch=True)
+print(output_dict)
+
+# output dict is {"summarizer": {"output": [ResponseYC, ResponseRAG, ResponseLlamaIndex]}}
+```
+
+
 ### Intermediate outputs
 
 If you wish to obtain the intermediate outputs of modules in QueryPipeline, you can use `run_with_intermediates` or `run_multi_with_intermediates` for single-input and multi-input, respectively.
diff --git a/llama-index-core/llama_index/core/query_pipeline/query.py b/llama-index-core/llama_index/core/query_pipeline/query.py
index 2acf88ea8..2f58549c7 100644
--- a/llama-index-core/llama_index/core/query_pipeline/query.py
+++ b/llama-index-core/llama_index/core/query_pipeline/query.py
@@ -17,7 +17,7 @@ from typing import (
 
 import networkx
 
-from llama_index.core.async_utils import run_jobs
+from llama_index.core.async_utils import asyncio_run, run_jobs
 from llama_index.core.bridge.pydantic import Field
 from llama_index.core.callbacks import CallbackManager
 from llama_index.core.callbacks.schema import CBEventType, EventPayload
@@ -302,6 +302,7 @@ class QueryPipeline(QueryComponent):
         *args: Any,
         return_values_direct: bool = True,
         callback_manager: Optional[CallbackManager] = None,
+        batch: bool = False,
         **kwargs: Any,
     ) -> Any:
         """Run the pipeline."""
@@ -321,8 +322,10 @@ class QueryPipeline(QueryComponent):
             *args,
             return_values_direct=return_values_direct,
             show_intermediates=False,
+            batch=batch,
             **kwargs,
         )
+
         return outputs
 
     def run_with_intermediates(
@@ -330,9 +333,13 @@ class 
QueryPipeline(QueryComponent): *args: Any, return_values_direct: bool = True, callback_manager: Optional[CallbackManager] = None, + batch: Optional[bool] = None, **kwargs: Any, ) -> Tuple[Any, Dict[str, ComponentIntermediates]]: """Run the pipeline.""" + if batch is not None: + raise ValueError("Batch is not supported for run_with_intermediates.") + # first set callback manager callback_manager = callback_manager or self.callback_manager self.set_callback_manager(callback_manager) @@ -352,10 +359,27 @@ class QueryPipeline(QueryComponent): **kwargs, ) + def merge_dicts(self, d1, d2): + """Merge two dictionaries recursively, combining values of the same key into a list.""" + merged = {} + for key in set(d1).union(d2): + if key in d1 and key in d2: + if isinstance(d1[key], dict) and isinstance(d2[key], dict): + merged[key] = self.merge_dicts(d1[key], d2[key]) + else: + merged[key] = ( + [d1[key]] if not isinstance(d1[key], list) else d1[key] + ) + merged[key].append(d2[key]) + else: + merged[key] = d1.get(key, d2.get(key)) + return merged + def run_multi( self, module_input_dict: Dict[str, Any], callback_manager: Optional[CallbackManager] = None, + batch: bool = False, ) -> Dict[str, Any]: """Run the pipeline for multiple roots.""" callback_manager = callback_manager or self.callback_manager @@ -365,8 +389,41 @@ class QueryPipeline(QueryComponent): CBEventType.QUERY, payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)}, ) as query_event: - outputs, _ = self._run_multi(module_input_dict) - return outputs + if batch: + outputs = {} + + batch_lengths = { + len(values) + for subdict in module_input_dict.values() + for values in subdict.values() + } + + if len(batch_lengths) != 1: + raise ValueError("Length of batch inputs must be the same.") + + batch_size = next(iter(batch_lengths)) + + # List individual outputs from batch multi input. 
+ inputs = [ + { + key: { + inner_key: inner_val[i] + for inner_key, inner_val in value.items() + } + for key, value in module_input_dict.items() + } + for i in range(batch_size) + ] + jobs = [self._arun_multi(input) for input in inputs] + results = asyncio_run(run_jobs(jobs, workers=len(jobs))) + + for result in results: + outputs = self.merge_dicts(outputs, result[0]) + + return outputs + else: + outputs, _ = self._run_multi(module_input_dict) + return outputs def run_multi_with_intermediates( self, @@ -388,6 +445,7 @@ class QueryPipeline(QueryComponent): *args: Any, return_values_direct: bool = True, callback_manager: Optional[CallbackManager] = None, + batch: bool = False, **kwargs: Any, ) -> Any: """Run the pipeline.""" @@ -406,8 +464,10 @@ class QueryPipeline(QueryComponent): *args, return_values_direct=return_values_direct, show_intermediates=False, + batch=batch, **kwargs, ) + return outputs async def arun_with_intermediates( @@ -415,9 +475,13 @@ class QueryPipeline(QueryComponent): *args: Any, return_values_direct: bool = True, callback_manager: Optional[CallbackManager] = None, + batch: Optional[bool] = None, **kwargs: Any, ) -> Tuple[Any, Dict[str, ComponentIntermediates]]: """Run the pipeline.""" + if batch is not None: + raise ValueError("Batch is not supported for run_with_intermediates.") + # first set callback manager callback_manager = callback_manager or self.callback_manager self.set_callback_manager(callback_manager) @@ -440,6 +504,7 @@ class QueryPipeline(QueryComponent): self, module_input_dict: Dict[str, Any], callback_manager: Optional[CallbackManager] = None, + batch: bool = False, ) -> Dict[str, Any]: """Run the pipeline for multiple roots.""" callback_manager = callback_manager or self.callback_manager @@ -449,8 +514,42 @@ class QueryPipeline(QueryComponent): CBEventType.QUERY, payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)}, ) as query_event: - outputs, _ = await self._arun_multi(module_input_dict) - return outputs + if 
batch: + outputs = {} + + batch_lengths = { + len(values) + for subdict in module_input_dict.values() + for values in subdict.values() + } + + if len(batch_lengths) != 1: + raise ValueError("Length of batch inputs must be the same.") + + batch_size = next(iter(batch_lengths)) + + # List individual outputs from batch multi input. + inputs = [ + { + key: { + inner_key: inner_val[i] + for inner_key, inner_val in value.items() + } + for key, value in module_input_dict.items() + } + for i in range(batch_size) + ] + + jobs = [self._arun_multi(input) for input in inputs] + results = await run_jobs(jobs, workers=len(jobs)) + + for result in results: + outputs = self.merge_dicts(outputs, result[0]) + + return outputs + else: + outputs, _ = await self._arun_multi(module_input_dict) + return outputs async def arun_multi_with_intermediates( self, @@ -530,6 +629,7 @@ class QueryPipeline(QueryComponent): *args: Any, return_values_direct: bool = True, show_intermediates: bool = False, + batch: bool = False, **kwargs: Any, ) -> Tuple[Any, Dict[str, ComponentIntermediates]]: """Run the pipeline. 
@@ -541,20 +641,50 @@ class QueryPipeline(QueryComponent): """ root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs) - result_outputs, intermediates = self._run_multi( - {root_key: kwargs}, show_intermediates=show_intermediates - ) + if batch: + result_outputs = [] + intermediates = [] - return ( - self._get_single_result_output(result_outputs, return_values_direct), - intermediates, - ) + if len({len(value) for value in kwargs.values()}) != 1: + raise ValueError("Length of batch inputs must be the same.") + + # List of individual inputs from batch input + kwargs = [ + dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values()) + ] + + jobs = [ + self._arun_multi( + {root_key: kwarg}, show_intermediates=show_intermediates + ) + for kwarg in kwargs + ] + + results = asyncio_run(run_jobs(jobs, workers=len(jobs))) + + for result in results: + result_outputs.append( + self._get_single_result_output(result[0], return_values_direct) + ) + intermediates.append(result[1]) + + return result_outputs, intermediates + else: + result_outputs, intermediates = self._run_multi( + {root_key: kwargs}, show_intermediates=show_intermediates + ) + + return ( + self._get_single_result_output(result_outputs, return_values_direct), + intermediates, + ) async def _arun( self, *args: Any, return_values_direct: bool = True, show_intermediates: bool = False, + batch: bool = False, **kwargs: Any, ) -> Tuple[Any, Dict[str, ComponentIntermediates]]: """Run the pipeline. 
@@ -566,14 +696,43 @@ class QueryPipeline(QueryComponent): """ root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs) - result_outputs, intermediates = await self._arun_multi( - {root_key: kwargs}, show_intermediates=show_intermediates - ) + if batch: + result_outputs = [] + intermediates = [] - return ( - self._get_single_result_output(result_outputs, return_values_direct), - intermediates, - ) + if len({len(value) for value in kwargs.values()}) != 1: + raise ValueError("Length of batch inputs must be the same.") + + # List of individual inputs from batch input + kwargs = [ + dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values()) + ] + + jobs = [ + self._arun_multi( + {root_key: kwarg}, show_intermediates=show_intermediates + ) + for kwarg in kwargs + ] + + results = await run_jobs(jobs, workers=len(jobs)) + + for result in results: + result_outputs.append( + self._get_single_result_output(result[0], return_values_direct) + ) + intermediates.append(result[1]) + + return result_outputs, intermediates + else: + result_outputs, intermediates = await self._arun_multi( + {root_key: kwargs}, show_intermediates=show_intermediates + ) + + return ( + self._get_single_result_output(result_outputs, return_values_direct), + intermediates, + ) def _validate_inputs(self, module_input_dict: Dict[str, Any]) -> None: root_keys = self._get_root_keys() diff --git a/llama-index-core/tests/query_pipeline/test_query.py b/llama-index-core/tests/query_pipeline/test_query.py index f0209c5d4..b39ff8c42 100644 --- a/llama-index-core/tests/query_pipeline/test_query.py +++ b/llama-index-core/tests/query_pipeline/test_query.py @@ -233,6 +233,27 @@ def test_query_pipeline_multi() -> None: assert output == {"qc2": {"output": "3:7"}} +def test_query_pipeline_multi_batch() -> None: + """Test query pipeline.""" + # try run run_multi + # link both qc1_0 and qc1_1 to qc2 + qc1_0 = QueryComponent1() + qc1_1 = QueryComponent1() + qc2 = QueryComponent2() + p = QueryPipeline() + 
p.add_modules({"qc1_0": qc1_0, "qc1_1": qc1_1, "qc2": qc2}) + p.add_link("qc1_0", "qc2", dest_key="input1") + p.add_link("qc1_1", "qc2", dest_key="input2") + output = p.run_multi( + { + "qc1_0": {"input1": [1, 5], "input2": [2, 1]}, + "qc1_1": {"input1": [3, 7], "input2": [4, 2]}, + }, + batch=True, + ) + assert output == {"qc2": {"output": ["3:7", "6:9"]}} + + def test_query_pipeline_multi_intermediate_output() -> None: """Test query pipeline showing intermediate outputs.""" # try run run_multi_with_intermediates @@ -298,6 +319,9 @@ async def test_query_pipeline_async() -> None: output = await p.arun(inp1=1, inp2=2) assert output == "3:1" + output = await p.arun(inp1=[1, 2], inp2=[2, 3], batch=True) + assert output == ["3:1", "5:2"] + # try run run_multi # link both qc1_0 and qc1_1 to qc2 qc1_0 = QueryComponent1() @@ -312,6 +336,15 @@ async def test_query_pipeline_async() -> None: ) assert output == {"qc2": {"output": "3:7"}} + output = await p.arun_multi( + { + "qc1_0": {"input1": [1, 5], "input2": [2, 1]}, + "qc1_1": {"input1": [3, 7], "input2": [4, 2]}, + }, + batch=True, + ) + assert output == {"qc2": {"output": ["3:7", "6:9"]}} + def test_query_pipeline_init() -> None: """Test query pipeline init params.""" @@ -387,6 +420,29 @@ def test_query_pipeline_chain_str() -> None: assert output == 11 +def test_query_pipeline_batch_chain_str() -> None: + """Test add_chain with only module strings.""" + p = QueryPipeline( + modules={ + "input": InputComponent(), + "a": QueryComponent3(), + "b": QueryComponent3(), + "c": QueryComponent3(), + "d": QueryComponent1(), + } + ) + p.add_links( + [ + Link("input", "a", src_key="inp1", dest_key="input"), + Link("input", "d", src_key="inp2", dest_key="input2"), + Link("c", "d", dest_key="input1"), + ] + ) + p.add_chain(["a", "b", "c"]) + output = p.run(inp1=[1, 5], inp2=[3, 4], batch=True) + assert output == [11, 44] + + def test_query_pipeline_chain_str_intermediate_output() -> None: """Test add_chain with only module strings, 
showing intermediate outputs.""" p = QueryPipeline( -- GitLab