From 0393b081f3aed854e0a628f49b8e51f8da7906ef Mon Sep 17 00:00:00 2001 From: Jerry Liu <jerryjliu98@gmail.com> Date: Thu, 8 Feb 2024 17:38:28 -0800 Subject: [PATCH] add conditional links to query pipeline (#10520) --- .../core/query_pipeline/query_component.py | 40 +++++- llama_index/query_pipeline/query.py | 134 ++++++++++++++++-- tests/query_pipeline/test_query.py | 47 +++++- 3 files changed, 207 insertions(+), 14 deletions(-) diff --git a/llama_index/core/query_pipeline/query_component.py b/llama_index/core/query_pipeline/query_component.py index 9b5c0306c..474b1e518 100644 --- a/llama_index/core/query_pipeline/query_component.py +++ b/llama_index/core/query_pipeline/query_component.py @@ -1,7 +1,18 @@ """Pipeline schema.""" from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Set, Union, cast, get_args +from typing import ( + Any, + Callable, + Dict, + Generator, + List, + Optional, + Set, + Union, + cast, + get_args, +) from llama_index.bridge.pydantic import BaseModel, Field from llama_index.callbacks.base import CallbackManager @@ -306,6 +317,33 @@ class Link(BaseModel): super().__init__(src=src, dest=dest, src_key=src_key, dest_key=dest_key) +class ConditionalLinks(BaseModel): + """Conditional Links between source and multiple destinations.""" + + src: str = Field(..., description="Source component name") + fn: Callable = Field( + ..., description="Function to determine which destination to go to" + ) + cond_dest_dict: Dict[str, Any] = Field( + ..., description="dictionary of value to destination component name" + ) + + class Config: + arbitrary_types_allowed = True + + def __init__( + self, + src: str, + fn: Callable, + cond_dest_dict: Dict[str, Any], + ) -> None: + """Init params.""" + # NOTE: This is to enable positional args. + super().__init__(src=src, fn=fn, cond_dest_dict=cond_dest_dict) + + # accept both QueryComponent and ChainableMixin as inputs to query pipeline # ChainableMixin modules will be converted to components via `as_query_component` QUERY_COMPONENT_TYPE = Union[QueryComponent, ChainableMixin] + +LINK_TYPE = Union[Link, ConditionalLinks] diff --git a/llama_index/query_pipeline/query.py b/llama_index/query_pipeline/query.py index 6f5835f8e..ab40550f1 100644 --- a/llama_index/query_pipeline/query.py +++ b/llama_index/query_pipeline/query.py @@ -2,7 +2,18 @@ import json import uuid -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast, get_args +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + cast, + get_args, +) import networkx @@ -11,8 +22,10 @@ from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.core.query_pipeline.query_component import ( + LINK_TYPE, QUERY_COMPONENT_TYPE, ChainableMixin, + ConditionalLinks, InputKeys, Link, OutputKeys, @@ -21,8 +34,17 @@ from llama_index.core.query_pipeline.query_component import ( from llama_index.utils import print_text +def get_single_output( + output_dict: Dict[str, Any], +) -> Any: + """Get single output.""" + if len(output_dict) != 1: + raise ValueError("Output dict must have exactly one key.") + return next(iter(output_dict.values())) + + def add_output_to_module_inputs( - src_key: str, + src_key: Optional[str], dest_key: str, output_dict: Dict[str, Any], module: QueryComponent, @@ -111,6 +133,16 @@ class QueryPipeline(QueryComponent): module_dict: Dict[str, QueryComponent] = Field( default_factory=dict, description="The modules in the pipeline." ) + conditional_dict: Dict[str, Callable] = Field( + default_factory=dict, + description=( + "Mapping from module key to conditional function." + "Specifically used in the case of conditional links. " + "The conditional function should take in the output of the source module and " + "return a value that will then be mapped to the relevant destination module " + "in the conditional link." + ), + ) dag: networkx.MultiDiGraph = Field( default_factory=networkx.MultiDiGraph, description="The DAG of the pipeline." ) @@ -185,11 +217,16 @@ class QueryPipeline(QueryComponent): def add_links( self, - links: List[Link], + links: List[LINK_TYPE], ) -> None: """Add links to the pipeline.""" for link in links: - self.add_link(**link.dict()) + if isinstance(link, Link): + self.add_link(**link.dict()) + elif isinstance(link, ConditionalLinks): + self.add_conditional_links(**link.dict()) + else: + raise ValueError("Link must be of type `Link` or `ConditionalLinks`.") def add_modules(self, module_dict: Dict[str, QUERY_COMPONENT_TYPE]) -> None: """Add modules to the pipeline.""" @@ -222,6 +259,24 @@ class QueryPipeline(QueryComponent): raise ValueError(f"Module {src} does not exist in pipeline.") self.dag.add_edge(src, dest, src_key=src_key, dest_key=dest_key) + def add_conditional_links( + self, + src: str, + fn: Callable, + cond_dest_dict: Dict[str, Any], + ) -> None: + """Add conditional links.""" + if src not in self.module_dict: + raise ValueError(f"Module {src} does not exist in pipeline.") + self.conditional_dict[src] = fn + for conditional, dest_dict in cond_dest_dict.items(): + self.dag.add_edge( + src, + dest_dict["dest"], + dest_key=dest_dict["dest_key"], + conditional=conditional, + ) + def get_root_keys(self) -> List[str]: """Get root keys.""" return self._get_root_keys() @@ -419,28 +474,83 @@ class QueryPipeline(QueryComponent): def _process_component_output( self, + queue: List[str], output_dict: Dict[str, Any], module_key: str, all_module_inputs: Dict[str, Dict[str, Any]], result_outputs: Dict[str, Any], - ) -> None: + ) -> List[str]: """Process component output.""" + new_queue = queue.copy() # if there's no more edges, add result to output if module_key in self._get_leaf_keys(): result_outputs[module_key] = output_dict else: - for _, dest, attr in self.dag.edges(module_key, data=True): - edge_module = self.module_dict[dest] + # first, process conditional edges. find the conditional edge + # that matches, and then remove all other edges from queue + # after, process regular edges + + conditional_val: Optional[Any] = None + if module_key in self.conditional_dict: + # NOTE: we assume that the output of the module is a single key + single_output = get_single_output(output_dict) + # the conditional_val determines which edge to take + # new_output is the output that will be passed to the next module + conditional_val, new_output = self.conditional_dict[module_key]( + single_output + ) + edge_list = list(self.dag.edges(module_key, data=True)) + # get conditional edge list + conditional_edge_list = [ + (src, dest, attr) + for src, dest, attr in edge_list + if "conditional" in attr + ] + # in conditional edge list, find matches + if len(conditional_edge_list) > 0: + match = next( + iter( + [ + (src, dest, attr) + for src, dest, attr in conditional_edge_list + if conditional_val == attr["conditional"] + ] + ) + ) + non_matches = [x for x in conditional_edge_list if x != match] + if len(non_matches) != len(conditional_edge_list) - 1: + raise ValueError("Multiple conditional matches found or None.") + # remove all non-matches from queue + for non_match in non_matches: + new_queue.remove(non_match[1]) + # add match to module inputs + add_output_to_module_inputs( + None, # no src_key for conditional link + match[2].get("dest_key"), + {"output": new_output}, + self.module_dict[match[1]], + all_module_inputs[match[1]], + ) + # everything not in conditional_edge_list is regular + regular_edge_list = [ + (src, dest, attr) + for src, dest, attr in edge_list + if "conditional" not in attr + ] + for _, dest, attr in regular_edge_list: + # if conditional link, check if it should be added # add input to module_deps_inputs add_output_to_module_inputs( attr.get("src_key"), attr.get("dest_key"), output_dict, - edge_module, + self.module_dict[dest], all_module_inputs[dest], ) + return new_queue + def _run_multi(self, module_input_dict: Dict[str, Any]) -> Dict[str, Any]: """Run the pipeline for multiple roots. @@ -474,8 +584,8 @@ class QueryPipeline(QueryComponent): output_dict = module.run_component(**module_input) # get new nodes and is_leaf - self._process_component_output( - output_dict, module_key, all_module_inputs, result_outputs + queue = self._process_component_output( + queue, output_dict, module_key, all_module_inputs, result_outputs ) return result_outputs @@ -541,8 +651,8 @@ class QueryPipeline(QueryComponent): for output_dict, module_key in zip(output_dicts, popped_nodes): # get new nodes and is_leaf - self._process_component_output( - output_dict, module_key, all_module_inputs, result_outputs + queue = self._process_component_output( + queue, output_dict, module_key, all_module_inputs, result_outputs ) return result_outputs diff --git a/tests/query_pipeline/test_query.py b/tests/query_pipeline/test_query.py index 01cd129b6..b006df828 100644 --- a/tests/query_pipeline/test_query.py +++ b/tests/query_pipeline/test_query.py @@ -3,9 +3,10 @@ from typing import Any, Dict import pytest -from llama_index.core.query_pipeline.components import InputComponent +from llama_index.core.query_pipeline.components import FnComponent, InputComponent from llama_index.core.query_pipeline.query_component import ( ChainableMixin, + ConditionalLinks, InputKeys, Link, OutputKeys, @@ -356,3 +357,47 @@ def test_query_pipeline_chain_str() -> None: p.add_chain(["a", "b", "c"]) output = p.run(inp1=1, inp2=3) assert output == 11 + + +def test_query_pipeline_conditional_edges() -> None: + """Test conditional edges.""" + + def choose_fn(input: int) -> Dict: + """Choose.""" + if input == 1: + toggle = "true" + else: + toggle = "false" + return {"toggle": toggle, "input": input} + + p = QueryPipeline( + modules={ + "input": InputComponent(), + "fn": FnComponent(fn=choose_fn), + "a": QueryComponent1(), + "b": QueryComponent2(), + }, + ) + + p.add_links( + [ + Link("input", "fn", src_key="inp1", dest_key="input"), + Link("input", "a", src_key="inp2", dest_key="input1"), + Link("input", "b", src_key="inp2", dest_key="input1"), + ConditionalLinks( + "fn", + lambda x: (x["toggle"], x["input"]), + { + "true": {"dest": "a", "dest_key": "input2"}, + "false": {"dest": "b", "dest_key": "input2"}, + }, + ), + ] + ) + output = p.run(inp1=1, inp2=3) + # should go to a + assert output == 4 + + output = p.run(inp1=2, inp2=3) + # should go to b + assert output == "3:2" -- GitLab