From 01df13f4a37e8cc29497ce41b1abed0ab8283f9b Mon Sep 17 00:00:00 2001
From: Ryan Nguyen <96593302+xpbowler@users.noreply.github.com>
Date: Thu, 16 May 2024 22:33:59 -0400
Subject: [PATCH] Add batch mode to QueryPipeline (#13203)

---
 .../querying/pipeline/usage_pattern.md        |  22 ++
 .../llama_index/core/query_pipeline/query.py  | 197 ++++++++++++++++--
 .../tests/query_pipeline/test_query.py        |  56 +++++
 3 files changed, 256 insertions(+), 19 deletions(-)

diff --git a/docs/docs/module_guides/querying/pipeline/usage_pattern.md b/docs/docs/module_guides/querying/pipeline/usage_pattern.md
index d2341821a..bd2aac0f6 100644
--- a/docs/docs/module_guides/querying/pipeline/usage_pattern.md
+++ b/docs/docs/module_guides/querying/pipeline/usage_pattern.md
@@ -113,6 +113,28 @@ p = QueryPipeline(chain=[prompt_tmpl, llm, summarizer_c])
 p.run(topic="YC")
 ```
 
+### Batch Input
+
+If you wish to run the pipeline for several rounds of single/multi-inputs, set `batch=True` in the function call - supported by `run`, `arun`, `run_multi`, and `arun_multi`. Pass in a list of individual single/multi-inputs you would like to run. `batch` mode will return a list of responses in the same order as the inputs.
+
+Example for single-input/single-output: `p.run(field=[in1: Any, in2: Any], batch=True)` --> `[out1: Any, out2: Any]`
+
+```python
+output = p.run(topic=["YC", "RAG", "LlamaIndex"], batch=True)
+# output is [ResponseYC, ResponseRAG, ResponseLlamaIndex]
+print(output)
+```
+
+Example for multi-input/multi-output: `p.run_multi({"root_node": {"field": [in1: Any, in2: Any]}}, batch=True)` --> `{"output_node": {"field": [out1: Any, out2: Any]}}`
+
+```python
+output_dict = p.run_multi({"llm": {"topic": ["YC", "RAG", "LlamaIndex"]}}, batch=True)
+print(output_dict)
+
+# output dict is {"summarizer": {"output": [ResponseYC, ResponseRAG, ResponseLlamaIndex]}}
+```
+
+
 ### Intermediate outputs
 
 If you wish to obtain the intermediate outputs of modules in QueryPipeline, you can use `run_with_intermediates` or `run_multi_with_intermediates` for single-input and multi-input, respectively.
diff --git a/llama-index-core/llama_index/core/query_pipeline/query.py b/llama-index-core/llama_index/core/query_pipeline/query.py
index 2acf88ea8..2f58549c7 100644
--- a/llama-index-core/llama_index/core/query_pipeline/query.py
+++ b/llama-index-core/llama_index/core/query_pipeline/query.py
@@ -17,7 +17,7 @@ from typing import (
 
 import networkx
 
-from llama_index.core.async_utils import run_jobs
+from llama_index.core.async_utils import asyncio_run, run_jobs
 from llama_index.core.bridge.pydantic import Field
 from llama_index.core.callbacks import CallbackManager
 from llama_index.core.callbacks.schema import CBEventType, EventPayload
@@ -302,6 +302,7 @@ class QueryPipeline(QueryComponent):
         *args: Any,
         return_values_direct: bool = True,
         callback_manager: Optional[CallbackManager] = None,
+        batch: bool = False,
         **kwargs: Any,
     ) -> Any:
         """Run the pipeline."""
@@ -321,8 +322,10 @@ class QueryPipeline(QueryComponent):
                     *args,
                     return_values_direct=return_values_direct,
                     show_intermediates=False,
+                    batch=batch,
                     **kwargs,
                 )
+
                 return outputs
 
     def run_with_intermediates(
@@ -330,9 +333,13 @@ class QueryPipeline(QueryComponent):
         *args: Any,
         return_values_direct: bool = True,
         callback_manager: Optional[CallbackManager] = None,
+        batch: Optional[bool] = None,
         **kwargs: Any,
     ) -> Tuple[Any, Dict[str, ComponentIntermediates]]:
         """Run the pipeline."""
+        if batch is not None:
+            raise ValueError("Batch is not supported for run_with_intermediates.")
+
         # first set callback manager
         callback_manager = callback_manager or self.callback_manager
         self.set_callback_manager(callback_manager)
@@ -352,10 +359,27 @@ class QueryPipeline(QueryComponent):
                     **kwargs,
                 )
 
+    def merge_dicts(self, d1, d2):
+        """Merge two dictionaries recursively, combining values of the same key into a list."""
+        merged = {}
+        for key in set(d1).union(d2):
+            if key in d1 and key in d2:
+                if isinstance(d1[key], dict) and isinstance(d2[key], dict):
+                    merged[key] = self.merge_dicts(d1[key], d2[key])
+                else:
+                    merged[key] = (
+                        [d1[key]] if not isinstance(d1[key], list) else d1[key]
+                    )
+                    merged[key].append(d2[key])
+            else:
+                merged[key] = d1.get(key, d2.get(key))
+        return merged
+
     def run_multi(
         self,
         module_input_dict: Dict[str, Any],
         callback_manager: Optional[CallbackManager] = None,
+        batch: bool = False,
     ) -> Dict[str, Any]:
         """Run the pipeline for multiple roots."""
         callback_manager = callback_manager or self.callback_manager
@@ -365,8 +389,41 @@ class QueryPipeline(QueryComponent):
                 CBEventType.QUERY,
                 payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)},
             ) as query_event:
-                outputs, _ = self._run_multi(module_input_dict)
-                return outputs
+                if batch:
+                    outputs = {}
+
+                    batch_lengths = {
+                        len(values)
+                        for subdict in module_input_dict.values()
+                        for values in subdict.values()
+                    }
+
+                    if len(batch_lengths) != 1:
+                        raise ValueError("Length of batch inputs must be the same.")
+
+                    batch_size = next(iter(batch_lengths))
+
+                    # List individual outputs from batch multi input.
+                    inputs = [
+                        {
+                            key: {
+                                inner_key: inner_val[i]
+                                for inner_key, inner_val in value.items()
+                            }
+                            for key, value in module_input_dict.items()
+                        }
+                        for i in range(batch_size)
+                    ]
+                    jobs = [self._arun_multi(input) for input in inputs]
+                    results = asyncio_run(run_jobs(jobs, workers=len(jobs)))
+
+                    for result in results:
+                        outputs = self.merge_dicts(outputs, result[0])
+
+                    return outputs
+                else:
+                    outputs, _ = self._run_multi(module_input_dict)
+                    return outputs
 
     def run_multi_with_intermediates(
         self,
@@ -388,6 +445,7 @@ class QueryPipeline(QueryComponent):
         *args: Any,
         return_values_direct: bool = True,
         callback_manager: Optional[CallbackManager] = None,
+        batch: bool = False,
         **kwargs: Any,
     ) -> Any:
         """Run the pipeline."""
@@ -406,8 +464,10 @@ class QueryPipeline(QueryComponent):
                     *args,
                     return_values_direct=return_values_direct,
                     show_intermediates=False,
+                    batch=batch,
                     **kwargs,
                 )
+
                 return outputs
 
     async def arun_with_intermediates(
@@ -415,9 +475,13 @@ class QueryPipeline(QueryComponent):
         *args: Any,
         return_values_direct: bool = True,
         callback_manager: Optional[CallbackManager] = None,
+        batch: Optional[bool] = None,
         **kwargs: Any,
     ) -> Tuple[Any, Dict[str, ComponentIntermediates]]:
         """Run the pipeline."""
+        if batch is not None:
+            raise ValueError("Batch is not supported for run_with_intermediates.")
+
         # first set callback manager
         callback_manager = callback_manager or self.callback_manager
         self.set_callback_manager(callback_manager)
@@ -440,6 +504,7 @@ class QueryPipeline(QueryComponent):
         self,
         module_input_dict: Dict[str, Any],
         callback_manager: Optional[CallbackManager] = None,
+        batch: bool = False,
     ) -> Dict[str, Any]:
         """Run the pipeline for multiple roots."""
         callback_manager = callback_manager or self.callback_manager
@@ -449,8 +514,42 @@ class QueryPipeline(QueryComponent):
                 CBEventType.QUERY,
                 payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)},
             ) as query_event:
-                outputs, _ = await self._arun_multi(module_input_dict)
-                return outputs
+                if batch:
+                    outputs = {}
+
+                    batch_lengths = {
+                        len(values)
+                        for subdict in module_input_dict.values()
+                        for values in subdict.values()
+                    }
+
+                    if len(batch_lengths) != 1:
+                        raise ValueError("Length of batch inputs must be the same.")
+
+                    batch_size = next(iter(batch_lengths))
+
+                    # List individual outputs from batch multi input.
+                    inputs = [
+                        {
+                            key: {
+                                inner_key: inner_val[i]
+                                for inner_key, inner_val in value.items()
+                            }
+                            for key, value in module_input_dict.items()
+                        }
+                        for i in range(batch_size)
+                    ]
+
+                    jobs = [self._arun_multi(input) for input in inputs]
+                    results = await run_jobs(jobs, workers=len(jobs))
+
+                    for result in results:
+                        outputs = self.merge_dicts(outputs, result[0])
+
+                    return outputs
+                else:
+                    outputs, _ = await self._arun_multi(module_input_dict)
+                    return outputs
 
     async def arun_multi_with_intermediates(
         self,
@@ -530,6 +629,7 @@ class QueryPipeline(QueryComponent):
         *args: Any,
         return_values_direct: bool = True,
         show_intermediates: bool = False,
+        batch: bool = False,
         **kwargs: Any,
     ) -> Tuple[Any, Dict[str, ComponentIntermediates]]:
         """Run the pipeline.
@@ -541,20 +641,50 @@ class QueryPipeline(QueryComponent):
         """
         root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs)
 
-        result_outputs, intermediates = self._run_multi(
-            {root_key: kwargs}, show_intermediates=show_intermediates
-        )
+        if batch:
+            result_outputs = []
+            intermediates = []
 
-        return (
-            self._get_single_result_output(result_outputs, return_values_direct),
-            intermediates,
-        )
+            if len({len(value) for value in kwargs.values()}) != 1:
+                raise ValueError("Length of batch inputs must be the same.")
+
+            # List of individual inputs from batch input
+            kwargs = [
+                dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())
+            ]
+
+            jobs = [
+                self._arun_multi(
+                    {root_key: kwarg}, show_intermediates=show_intermediates
+                )
+                for kwarg in kwargs
+            ]
+
+            results = asyncio_run(run_jobs(jobs, workers=len(jobs)))
+
+            for result in results:
+                result_outputs.append(
+                    self._get_single_result_output(result[0], return_values_direct)
+                )
+                intermediates.append(result[1])
+
+            return result_outputs, intermediates
+        else:
+            result_outputs, intermediates = self._run_multi(
+                {root_key: kwargs}, show_intermediates=show_intermediates
+            )
+
+            return (
+                self._get_single_result_output(result_outputs, return_values_direct),
+                intermediates,
+            )
 
     async def _arun(
         self,
         *args: Any,
         return_values_direct: bool = True,
         show_intermediates: bool = False,
+        batch: bool = False,
         **kwargs: Any,
     ) -> Tuple[Any, Dict[str, ComponentIntermediates]]:
         """Run the pipeline.
@@ -566,14 +696,43 @@ class QueryPipeline(QueryComponent):
         """
         root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs)
 
-        result_outputs, intermediates = await self._arun_multi(
-            {root_key: kwargs}, show_intermediates=show_intermediates
-        )
+        if batch:
+            result_outputs = []
+            intermediates = []
 
-        return (
-            self._get_single_result_output(result_outputs, return_values_direct),
-            intermediates,
-        )
+            if len({len(value) for value in kwargs.values()}) != 1:
+                raise ValueError("Length of batch inputs must be the same.")
+
+            # List of individual inputs from batch input
+            kwargs = [
+                dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())
+            ]
+
+            jobs = [
+                self._arun_multi(
+                    {root_key: kwarg}, show_intermediates=show_intermediates
+                )
+                for kwarg in kwargs
+            ]
+
+            results = await run_jobs(jobs, workers=len(jobs))
+
+            for result in results:
+                result_outputs.append(
+                    self._get_single_result_output(result[0], return_values_direct)
+                )
+                intermediates.append(result[1])
+
+            return result_outputs, intermediates
+        else:
+            result_outputs, intermediates = await self._arun_multi(
+                {root_key: kwargs}, show_intermediates=show_intermediates
+            )
+
+            return (
+                self._get_single_result_output(result_outputs, return_values_direct),
+                intermediates,
+            )
 
     def _validate_inputs(self, module_input_dict: Dict[str, Any]) -> None:
         root_keys = self._get_root_keys()
diff --git a/llama-index-core/tests/query_pipeline/test_query.py b/llama-index-core/tests/query_pipeline/test_query.py
index f0209c5d4..b39ff8c42 100644
--- a/llama-index-core/tests/query_pipeline/test_query.py
+++ b/llama-index-core/tests/query_pipeline/test_query.py
@@ -233,6 +233,27 @@ def test_query_pipeline_multi() -> None:
     assert output == {"qc2": {"output": "3:7"}}
 
 
+def test_query_pipeline_multi_batch() -> None:
+    """Test query pipeline."""
+    # try run run_multi
+    # link both qc1_0 and qc1_1 to qc2
+    qc1_0 = QueryComponent1()
+    qc1_1 = QueryComponent1()
+    qc2 = QueryComponent2()
+    p = QueryPipeline()
+    p.add_modules({"qc1_0": qc1_0, "qc1_1": qc1_1, "qc2": qc2})
+    p.add_link("qc1_0", "qc2", dest_key="input1")
+    p.add_link("qc1_1", "qc2", dest_key="input2")
+    output = p.run_multi(
+        {
+            "qc1_0": {"input1": [1, 5], "input2": [2, 1]},
+            "qc1_1": {"input1": [3, 7], "input2": [4, 2]},
+        },
+        batch=True,
+    )
+    assert output == {"qc2": {"output": ["3:7", "6:9"]}}
+
+
 def test_query_pipeline_multi_intermediate_output() -> None:
     """Test query pipeline showing intermediate outputs."""
     # try run run_multi_with_intermediates
@@ -298,6 +319,9 @@ async def test_query_pipeline_async() -> None:
     output = await p.arun(inp1=1, inp2=2)
     assert output == "3:1"
 
+    output = await p.arun(inp1=[1, 2], inp2=[2, 3], batch=True)
+    assert output == ["3:1", "5:2"]
+
     # try run run_multi
     # link both qc1_0 and qc1_1 to qc2
     qc1_0 = QueryComponent1()
@@ -312,6 +336,15 @@ async def test_query_pipeline_async() -> None:
     )
     assert output == {"qc2": {"output": "3:7"}}
 
+    output = await p.arun_multi(
+        {
+            "qc1_0": {"input1": [1, 5], "input2": [2, 1]},
+            "qc1_1": {"input1": [3, 7], "input2": [4, 2]},
+        },
+        batch=True,
+    )
+    assert output == {"qc2": {"output": ["3:7", "6:9"]}}
+
 
 def test_query_pipeline_init() -> None:
     """Test query pipeline init params."""
@@ -387,6 +420,29 @@ def test_query_pipeline_chain_str() -> None:
     assert output == 11
 
 
+def test_query_pipeline_batch_chain_str() -> None:
+    """Test add_chain with only module strings."""
+    p = QueryPipeline(
+        modules={
+            "input": InputComponent(),
+            "a": QueryComponent3(),
+            "b": QueryComponent3(),
+            "c": QueryComponent3(),
+            "d": QueryComponent1(),
+        }
+    )
+    p.add_links(
+        [
+            Link("input", "a", src_key="inp1", dest_key="input"),
+            Link("input", "d", src_key="inp2", dest_key="input2"),
+            Link("c", "d", dest_key="input1"),
+        ]
+    )
+    p.add_chain(["a", "b", "c"])
+    output = p.run(inp1=[1, 5], inp2=[3, 4], batch=True)
+    assert output == [11, 44]
+
+
 def test_query_pipeline_chain_str_intermediate_output() -> None:
     """Test add_chain with only module strings, showing intermediate outputs."""
     p = QueryPipeline(
-- 
GitLab