Skip to content
Snippets Groups Projects
Unverified Commit 01df13f4 authored by Ryan Nguyen's avatar Ryan Nguyen Committed by GitHub
Browse files

Add batch mode to QueryPipeline (#13203)

parent e76d702a
No related branches found
No related tags found
No related merge requests found
...@@ -113,6 +113,28 @@ p = QueryPipeline(chain=[prompt_tmpl, llm, summarizer_c]) ...@@ -113,6 +113,28 @@ p = QueryPipeline(chain=[prompt_tmpl, llm, summarizer_c])
p.run(topic="YC") p.run(topic="YC")
``` ```
### Batch Input
If you wish to run the pipeline for several rounds of single/multi-inputs, set `batch=True` in the function call - supported by `run`, `arun`, `run_multi`, and `arun_multi`. Pass in a list of individual single/multi-inputs you would like to run. `batch` mode will return a list of responses in the same order as the inputs.
Example for single-input/single-output: `p.run(field=[in1, in2], batch=True)` --> `[out1, out2]`
```python
output = p.run(topic=["YC", "RAG", "LlamaIndex"], batch=True)
# output is [ResponseYC, ResponseRAG, ResponseLlamaIndex]
print(output)
```
Example for multi-input/multi-output: `p.run_multi({"root_node": {"field": [in1, in2]}}, batch=True)` --> `{"output_node": {"field": [out1, out2]}}`
```python
output_dict = p.run_multi({"llm": {"topic": ["YC", "RAG", "LlamaIndex"]}}, batch=True)
print(output_dict)
# output dict is {"summarizer": {"output": [ResponseYC, ResponseRAG, ResponseLlamaIndex]}}
```
### Intermediate outputs ### Intermediate outputs
If you wish to obtain the intermediate outputs of modules in QueryPipeline, you can use `run_with_intermediates` or `run_multi_with_intermediates` for single-input and multi-input, respectively. If you wish to obtain the intermediate outputs of modules in QueryPipeline, you can use `run_with_intermediates` or `run_multi_with_intermediates` for single-input and multi-input, respectively.
......
...@@ -17,7 +17,7 @@ from typing import ( ...@@ -17,7 +17,7 @@ from typing import (
import networkx import networkx
from llama_index.core.async_utils import run_jobs from llama_index.core.async_utils import asyncio_run, run_jobs
from llama_index.core.bridge.pydantic import Field from llama_index.core.bridge.pydantic import Field
from llama_index.core.callbacks import CallbackManager from llama_index.core.callbacks import CallbackManager
from llama_index.core.callbacks.schema import CBEventType, EventPayload from llama_index.core.callbacks.schema import CBEventType, EventPayload
...@@ -302,6 +302,7 @@ class QueryPipeline(QueryComponent): ...@@ -302,6 +302,7 @@ class QueryPipeline(QueryComponent):
*args: Any, *args: Any,
return_values_direct: bool = True, return_values_direct: bool = True,
callback_manager: Optional[CallbackManager] = None, callback_manager: Optional[CallbackManager] = None,
batch: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the pipeline.""" """Run the pipeline."""
...@@ -321,8 +322,10 @@ class QueryPipeline(QueryComponent): ...@@ -321,8 +322,10 @@ class QueryPipeline(QueryComponent):
*args, *args,
return_values_direct=return_values_direct, return_values_direct=return_values_direct,
show_intermediates=False, show_intermediates=False,
batch=batch,
**kwargs, **kwargs,
) )
return outputs return outputs
def run_with_intermediates( def run_with_intermediates(
...@@ -330,9 +333,13 @@ class QueryPipeline(QueryComponent): ...@@ -330,9 +333,13 @@ class QueryPipeline(QueryComponent):
*args: Any, *args: Any,
return_values_direct: bool = True, return_values_direct: bool = True,
callback_manager: Optional[CallbackManager] = None, callback_manager: Optional[CallbackManager] = None,
batch: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> Tuple[Any, Dict[str, ComponentIntermediates]]: ) -> Tuple[Any, Dict[str, ComponentIntermediates]]:
"""Run the pipeline.""" """Run the pipeline."""
if batch is not None:
raise ValueError("Batch is not supported for run_with_intermediates.")
# first set callback manager # first set callback manager
callback_manager = callback_manager or self.callback_manager callback_manager = callback_manager or self.callback_manager
self.set_callback_manager(callback_manager) self.set_callback_manager(callback_manager)
...@@ -352,10 +359,27 @@ class QueryPipeline(QueryComponent): ...@@ -352,10 +359,27 @@ class QueryPipeline(QueryComponent):
**kwargs, **kwargs,
) )
def merge_dicts(self, d1, d2):
    """Merge two dictionaries recursively, combining values of the same key into a list.

    Nested dicts that share a key are merged recursively. For any other key
    collision, the value from ``d1`` is promoted to a list (copied if it is
    already a list, so neither input is mutated) and the value from ``d2``
    is appended, preserving input order. Keys present in only one input are
    carried over as-is.

    Args:
        d1: Accumulated dict (e.g. merged outputs of earlier batch runs).
        d2: Dict to fold in (e.g. outputs of the next batch run).

    Returns:
        A new dict; ``d1`` and ``d2`` are left unmodified.
    """

    def _merge(a, b):
        merged = {}
        for key in set(a).union(b):
            if key in a and key in b:
                if isinstance(a[key], dict) and isinstance(b[key], dict):
                    merged[key] = _merge(a[key], b[key])
                else:
                    # Copy an existing list rather than aliasing it, so that
                    # appending b's value never mutates the caller's data.
                    combined = (
                        list(a[key]) if isinstance(a[key], list) else [a[key]]
                    )
                    combined.append(b[key])
                    merged[key] = combined
            else:
                merged[key] = a.get(key, b.get(key))
        return merged

    return _merge(d1, d2)
def run_multi( def run_multi(
self, self,
module_input_dict: Dict[str, Any], module_input_dict: Dict[str, Any],
callback_manager: Optional[CallbackManager] = None, callback_manager: Optional[CallbackManager] = None,
batch: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Run the pipeline for multiple roots.""" """Run the pipeline for multiple roots."""
callback_manager = callback_manager or self.callback_manager callback_manager = callback_manager or self.callback_manager
...@@ -365,8 +389,41 @@ class QueryPipeline(QueryComponent): ...@@ -365,8 +389,41 @@ class QueryPipeline(QueryComponent):
CBEventType.QUERY, CBEventType.QUERY,
payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)}, payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)},
) as query_event: ) as query_event:
outputs, _ = self._run_multi(module_input_dict) if batch:
return outputs outputs = {}
batch_lengths = {
len(values)
for subdict in module_input_dict.values()
for values in subdict.values()
}
if len(batch_lengths) != 1:
raise ValueError("Length of batch inputs must be the same.")
batch_size = next(iter(batch_lengths))
# List individual outputs from batch multi input.
inputs = [
{
key: {
inner_key: inner_val[i]
for inner_key, inner_val in value.items()
}
for key, value in module_input_dict.items()
}
for i in range(batch_size)
]
jobs = [self._arun_multi(input) for input in inputs]
results = asyncio_run(run_jobs(jobs, workers=len(jobs)))
for result in results:
outputs = self.merge_dicts(outputs, result[0])
return outputs
else:
outputs, _ = self._run_multi(module_input_dict)
return outputs
def run_multi_with_intermediates( def run_multi_with_intermediates(
self, self,
...@@ -388,6 +445,7 @@ class QueryPipeline(QueryComponent): ...@@ -388,6 +445,7 @@ class QueryPipeline(QueryComponent):
*args: Any, *args: Any,
return_values_direct: bool = True, return_values_direct: bool = True,
callback_manager: Optional[CallbackManager] = None, callback_manager: Optional[CallbackManager] = None,
batch: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run the pipeline.""" """Run the pipeline."""
...@@ -406,8 +464,10 @@ class QueryPipeline(QueryComponent): ...@@ -406,8 +464,10 @@ class QueryPipeline(QueryComponent):
*args, *args,
return_values_direct=return_values_direct, return_values_direct=return_values_direct,
show_intermediates=False, show_intermediates=False,
batch=batch,
**kwargs, **kwargs,
) )
return outputs return outputs
async def arun_with_intermediates( async def arun_with_intermediates(
...@@ -415,9 +475,13 @@ class QueryPipeline(QueryComponent): ...@@ -415,9 +475,13 @@ class QueryPipeline(QueryComponent):
*args: Any, *args: Any,
return_values_direct: bool = True, return_values_direct: bool = True,
callback_manager: Optional[CallbackManager] = None, callback_manager: Optional[CallbackManager] = None,
batch: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> Tuple[Any, Dict[str, ComponentIntermediates]]: ) -> Tuple[Any, Dict[str, ComponentIntermediates]]:
"""Run the pipeline.""" """Run the pipeline."""
if batch is not None:
raise ValueError("Batch is not supported for run_with_intermediates.")
# first set callback manager # first set callback manager
callback_manager = callback_manager or self.callback_manager callback_manager = callback_manager or self.callback_manager
self.set_callback_manager(callback_manager) self.set_callback_manager(callback_manager)
...@@ -440,6 +504,7 @@ class QueryPipeline(QueryComponent): ...@@ -440,6 +504,7 @@ class QueryPipeline(QueryComponent):
self, self,
module_input_dict: Dict[str, Any], module_input_dict: Dict[str, Any],
callback_manager: Optional[CallbackManager] = None, callback_manager: Optional[CallbackManager] = None,
batch: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Run the pipeline for multiple roots.""" """Run the pipeline for multiple roots."""
callback_manager = callback_manager or self.callback_manager callback_manager = callback_manager or self.callback_manager
...@@ -449,8 +514,42 @@ class QueryPipeline(QueryComponent): ...@@ -449,8 +514,42 @@ class QueryPipeline(QueryComponent):
CBEventType.QUERY, CBEventType.QUERY,
payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)}, payload={EventPayload.QUERY_STR: json.dumps(module_input_dict)},
) as query_event: ) as query_event:
outputs, _ = await self._arun_multi(module_input_dict) if batch:
return outputs outputs = {}
batch_lengths = {
len(values)
for subdict in module_input_dict.values()
for values in subdict.values()
}
if len(batch_lengths) != 1:
raise ValueError("Length of batch inputs must be the same.")
batch_size = next(iter(batch_lengths))
# List individual outputs from batch multi input.
inputs = [
{
key: {
inner_key: inner_val[i]
for inner_key, inner_val in value.items()
}
for key, value in module_input_dict.items()
}
for i in range(batch_size)
]
jobs = [self._arun_multi(input) for input in inputs]
results = await run_jobs(jobs, workers=len(jobs))
for result in results:
outputs = self.merge_dicts(outputs, result[0])
return outputs
else:
outputs, _ = await self._arun_multi(module_input_dict)
return outputs
async def arun_multi_with_intermediates( async def arun_multi_with_intermediates(
self, self,
...@@ -530,6 +629,7 @@ class QueryPipeline(QueryComponent): ...@@ -530,6 +629,7 @@ class QueryPipeline(QueryComponent):
*args: Any, *args: Any,
return_values_direct: bool = True, return_values_direct: bool = True,
show_intermediates: bool = False, show_intermediates: bool = False,
batch: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Tuple[Any, Dict[str, ComponentIntermediates]]: ) -> Tuple[Any, Dict[str, ComponentIntermediates]]:
"""Run the pipeline. """Run the pipeline.
...@@ -541,20 +641,50 @@ class QueryPipeline(QueryComponent): ...@@ -541,20 +641,50 @@ class QueryPipeline(QueryComponent):
""" """
root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs) root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs)
result_outputs, intermediates = self._run_multi( if batch:
{root_key: kwargs}, show_intermediates=show_intermediates result_outputs = []
) intermediates = []
return ( if len({len(value) for value in kwargs.values()}) != 1:
self._get_single_result_output(result_outputs, return_values_direct), raise ValueError("Length of batch inputs must be the same.")
intermediates,
) # List of individual inputs from batch input
kwargs = [
dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())
]
jobs = [
self._arun_multi(
{root_key: kwarg}, show_intermediates=show_intermediates
)
for kwarg in kwargs
]
results = asyncio_run(run_jobs(jobs, workers=len(jobs)))
for result in results:
result_outputs.append(
self._get_single_result_output(result[0], return_values_direct)
)
intermediates.append(result[1])
return result_outputs, intermediates
else:
result_outputs, intermediates = self._run_multi(
{root_key: kwargs}, show_intermediates=show_intermediates
)
return (
self._get_single_result_output(result_outputs, return_values_direct),
intermediates,
)
async def _arun( async def _arun(
self, self,
*args: Any, *args: Any,
return_values_direct: bool = True, return_values_direct: bool = True,
show_intermediates: bool = False, show_intermediates: bool = False,
batch: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Tuple[Any, Dict[str, ComponentIntermediates]]: ) -> Tuple[Any, Dict[str, ComponentIntermediates]]:
"""Run the pipeline. """Run the pipeline.
...@@ -566,14 +696,43 @@ class QueryPipeline(QueryComponent): ...@@ -566,14 +696,43 @@ class QueryPipeline(QueryComponent):
""" """
root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs) root_key, kwargs = self._get_root_key_and_kwargs(*args, **kwargs)
result_outputs, intermediates = await self._arun_multi( if batch:
{root_key: kwargs}, show_intermediates=show_intermediates result_outputs = []
) intermediates = []
return ( if len({len(value) for value in kwargs.values()}) != 1:
self._get_single_result_output(result_outputs, return_values_direct), raise ValueError("Length of batch inputs must be the same.")
intermediates,
) # List of individual inputs from batch input
kwargs = [
dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())
]
jobs = [
self._arun_multi(
{root_key: kwarg}, show_intermediates=show_intermediates
)
for kwarg in kwargs
]
results = await run_jobs(jobs, workers=len(jobs))
for result in results:
result_outputs.append(
self._get_single_result_output(result[0], return_values_direct)
)
intermediates.append(result[1])
return result_outputs, intermediates
else:
result_outputs, intermediates = await self._arun_multi(
{root_key: kwargs}, show_intermediates=show_intermediates
)
return (
self._get_single_result_output(result_outputs, return_values_direct),
intermediates,
)
def _validate_inputs(self, module_input_dict: Dict[str, Any]) -> None: def _validate_inputs(self, module_input_dict: Dict[str, Any]) -> None:
root_keys = self._get_root_keys() root_keys = self._get_root_keys()
......
...@@ -233,6 +233,27 @@ def test_query_pipeline_multi() -> None: ...@@ -233,6 +233,27 @@ def test_query_pipeline_multi() -> None:
assert output == {"qc2": {"output": "3:7"}} assert output == {"qc2": {"output": "3:7"}}
def test_query_pipeline_multi_batch() -> None:
    """Test query pipeline run_multi in batch mode."""
    # Wire two parallel QueryComponent1 roots into one QueryComponent2 sink,
    # then push two rounds of multi-inputs through run_multi at once.
    source_a = QueryComponent1()
    source_b = QueryComponent1()
    sink = QueryComponent2()
    pipeline = QueryPipeline()
    pipeline.add_modules({"qc1_0": source_a, "qc1_1": source_b, "qc2": sink})
    pipeline.add_link("qc1_0", "qc2", dest_key="input1")
    pipeline.add_link("qc1_1", "qc2", dest_key="input2")
    batch_inputs = {
        "qc1_0": {"input1": [1, 5], "input2": [2, 1]},
        "qc1_1": {"input1": [3, 7], "input2": [4, 2]},
    }
    result = pipeline.run_multi(batch_inputs, batch=True)
    assert result == {"qc2": {"output": ["3:7", "6:9"]}}
def test_query_pipeline_multi_intermediate_output() -> None: def test_query_pipeline_multi_intermediate_output() -> None:
"""Test query pipeline showing intermediate outputs.""" """Test query pipeline showing intermediate outputs."""
# try run run_multi_with_intermediates # try run run_multi_with_intermediates
...@@ -298,6 +319,9 @@ async def test_query_pipeline_async() -> None: ...@@ -298,6 +319,9 @@ async def test_query_pipeline_async() -> None:
output = await p.arun(inp1=1, inp2=2) output = await p.arun(inp1=1, inp2=2)
assert output == "3:1" assert output == "3:1"
output = await p.arun(inp1=[1, 2], inp2=[2, 3], batch=True)
assert output == ["3:1", "5:2"]
# try run run_multi # try run run_multi
# link both qc1_0 and qc1_1 to qc2 # link both qc1_0 and qc1_1 to qc2
qc1_0 = QueryComponent1() qc1_0 = QueryComponent1()
...@@ -312,6 +336,15 @@ async def test_query_pipeline_async() -> None: ...@@ -312,6 +336,15 @@ async def test_query_pipeline_async() -> None:
) )
assert output == {"qc2": {"output": "3:7"}} assert output == {"qc2": {"output": "3:7"}}
output = await p.arun_multi(
{
"qc1_0": {"input1": [1, 5], "input2": [2, 1]},
"qc1_1": {"input1": [3, 7], "input2": [4, 2]},
},
batch=True,
)
assert output == {"qc2": {"output": ["3:7", "6:9"]}}
def test_query_pipeline_init() -> None: def test_query_pipeline_init() -> None:
"""Test query pipeline init params.""" """Test query pipeline init params."""
...@@ -387,6 +420,29 @@ def test_query_pipeline_chain_str() -> None: ...@@ -387,6 +420,29 @@ def test_query_pipeline_chain_str() -> None:
assert output == 11 assert output == 11
def test_query_pipeline_batch_chain_str() -> None:
    """Test add_chain with only module strings, run in batch mode."""
    modules = {
        "input": InputComponent(),
        "a": QueryComponent3(),
        "b": QueryComponent3(),
        "c": QueryComponent3(),
        "d": QueryComponent1(),
    }
    pipeline = QueryPipeline(modules=modules)
    pipeline.add_links(
        [
            Link("input", "a", src_key="inp1", dest_key="input"),
            Link("input", "d", src_key="inp2", dest_key="input2"),
            Link("c", "d", dest_key="input1"),
        ]
    )
    pipeline.add_chain(["a", "b", "c"])
    batch_output = pipeline.run(inp1=[1, 5], inp2=[3, 4], batch=True)
    assert batch_output == [11, 44]
def test_query_pipeline_chain_str_intermediate_output() -> None: def test_query_pipeline_chain_str_intermediate_output() -> None:
"""Test add_chain with only module strings, showing intermediate outputs.""" """Test add_chain with only module strings, showing intermediate outputs."""
p = QueryPipeline( p = QueryPipeline(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment