From 0393b081f3aed854e0a628f49b8e51f8da7906ef Mon Sep 17 00:00:00 2001
From: Jerry Liu <jerryjliu98@gmail.com>
Date: Thu, 8 Feb 2024 17:38:28 -0800
Subject: [PATCH] add conditional links to query pipeline  (#10520)

---
 .../core/query_pipeline/query_component.py    |  40 +++++-
 llama_index/query_pipeline/query.py           | 134 ++++++++++++++++--
 tests/query_pipeline/test_query.py            |  47 +++++-
 3 files changed, 207 insertions(+), 14 deletions(-)

diff --git a/llama_index/core/query_pipeline/query_component.py b/llama_index/core/query_pipeline/query_component.py
index 9b5c0306c..474b1e518 100644
--- a/llama_index/core/query_pipeline/query_component.py
+++ b/llama_index/core/query_pipeline/query_component.py
@@ -1,7 +1,18 @@
 """Pipeline schema."""
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generator, List, Optional, Set, Union, cast, get_args
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Set,
+    Union,
+    cast,
+    get_args,
+)
 
 from llama_index.bridge.pydantic import BaseModel, Field
 from llama_index.callbacks.base import CallbackManager
@@ -306,6 +317,33 @@ class Link(BaseModel):
         super().__init__(src=src, dest=dest, src_key=src_key, dest_key=dest_key)
 
 
+class ConditionalLinks(BaseModel):
+    """Conditional Links between source and multiple destinations.
+
+    `fn` receives the source module's single output and returns a tuple
+    `(condition value, output to forward)`. `cond_dest_dict` maps each
+    condition value to `{"dest": <module name>, "dest_key": <input key>}`.
+    """
+
+    src: str = Field(..., description="Source component name")
+    fn: Callable = Field(
+        ..., description="Function mapping source output to (condition value, output)"
+    )
+    cond_dest_dict: Dict[str, Any] = Field(
+        ..., description="Map of condition value to dict with `dest` and `dest_key`"
+    )
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __init__(self, src: str, fn: Callable, cond_dest_dict: Dict[str, Any]) -> None:
+        """Init params."""
+        # NOTE: This is to enable positional args.
+        super().__init__(src=src, fn=fn, cond_dest_dict=cond_dest_dict)
+
+
 # accept both QueryComponent and ChainableMixin as inputs to query pipeline
 # ChainableMixin modules will be converted to components via `as_query_component`
 QUERY_COMPONENT_TYPE = Union[QueryComponent, ChainableMixin]
+
+LINK_TYPE = Union[Link, ConditionalLinks]
diff --git a/llama_index/query_pipeline/query.py b/llama_index/query_pipeline/query.py
index 6f5835f8e..ab40550f1 100644
--- a/llama_index/query_pipeline/query.py
+++ b/llama_index/query_pipeline/query.py
@@ -2,7 +2,18 @@
 
 import json
 import uuid
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast, get_args
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    cast,
+    get_args,
+)
 
 import networkx
 
@@ -11,8 +22,10 @@ from llama_index.bridge.pydantic import Field
 from llama_index.callbacks import CallbackManager
 from llama_index.callbacks.schema import CBEventType, EventPayload
 from llama_index.core.query_pipeline.query_component import (
+    LINK_TYPE,
     QUERY_COMPONENT_TYPE,
     ChainableMixin,
+    ConditionalLinks,
     InputKeys,
     Link,
     OutputKeys,
@@ -21,8 +34,17 @@ from llama_index.core.query_pipeline.query_component import (
 from llama_index.utils import print_text
 
 
+def get_single_output(output_dict: Dict[str, Any]) -> Any:
+    """Return the sole value of `output_dict`; raise ValueError otherwise."""
+    values = list(output_dict.values())
+    if len(values) == 1:
+        return values[0]
+    # zero or several keys: the caller cannot pick an output unambiguously
+    raise ValueError("Output dict must have exactly one key.")
+
+
 def add_output_to_module_inputs(
-    src_key: str,
+    src_key: Optional[str],
     dest_key: str,
     output_dict: Dict[str, Any],
     module: QueryComponent,
@@ -111,6 +133,16 @@ class QueryPipeline(QueryComponent):
     module_dict: Dict[str, QueryComponent] = Field(
         default_factory=dict, description="The modules in the pipeline."
     )
+    conditional_dict: Dict[str, Callable] = Field(
+        default_factory=dict,
+        description=(
+            "Mapping from module key to conditional function. "
+            "Specifically used in the case of conditional links. "
+            "The conditional function should take in the output of the source module and "
+            "return a `(value, output)` tuple: the value is mapped to the relevant "
+            "destination module in the conditional link; the output is passed to it."
+        ),
+    )
     dag: networkx.MultiDiGraph = Field(
         default_factory=networkx.MultiDiGraph, description="The DAG of the pipeline."
     )
@@ -185,11 +217,16 @@ class QueryPipeline(QueryComponent):
 
     def add_links(
         self,
-        links: List[Link],
+        links: List[LINK_TYPE],
     ) -> None:
         """Add links to the pipeline."""
         for link in links:
-            self.add_link(**link.dict())
+            if isinstance(link, ConditionalLinks):
+                self.add_conditional_links(**link.dict())
+            elif isinstance(link, Link):
+                self.add_link(**link.dict())
+            else:
+                raise ValueError("Link must be of type `Link` or `ConditionalLinks`.")
 
     def add_modules(self, module_dict: Dict[str, QUERY_COMPONENT_TYPE]) -> None:
         """Add modules to the pipeline."""
@@ -222,6 +259,24 @@ class QueryPipeline(QueryComponent):
             raise ValueError(f"Module {src} does not exist in pipeline.")
         self.dag.add_edge(src, dest, src_key=src_key, dest_key=dest_key)
 
+    def add_conditional_links(
+        self,
+        src: str,
+        fn: Callable,
+        cond_dest_dict: Dict[str, Any],
+    ) -> None:
+        """Register `fn` for `src` and add one conditional edge per destination."""
+        if src not in self.module_dict:
+            raise ValueError(f"Module {src} does not exist in pipeline.")
+        self.conditional_dict[src] = fn
+        for conditional, dest_dict in cond_dest_dict.items():
+            self.dag.add_edge(
+                src,
+                dest_dict["dest"],
+                dest_key=dest_dict["dest_key"],
+                conditional=conditional,
+            )
+
     def get_root_keys(self) -> List[str]:
         """Get root keys."""
         return self._get_root_keys()
@@ -419,28 +474,83 @@ class QueryPipeline(QueryComponent):
 
     def _process_component_output(
         self,
+        queue: List[str],
         output_dict: Dict[str, Any],
         module_key: str,
         all_module_inputs: Dict[str, Dict[str, Any]],
         result_outputs: Dict[str, Any],
-    ) -> None:
+    ) -> List[str]:
         """Process component output."""
+        # Routes `output_dict` from `module_key` to its downstream modules,
+        # filling `all_module_inputs` / `result_outputs` in place. Returns an
+        # updated copy of `queue`; raises ValueError unless exactly one
+        # conditional edge matches the conditional function's value.
+        new_queue = queue.copy()
         # if there's no more edges, add result to output
         if module_key in self._get_leaf_keys():
             result_outputs[module_key] = output_dict
         else:
-            for _, dest, attr in self.dag.edges(module_key, data=True):
-                edge_module = self.module_dict[dest]
+            # first, process conditional edges. find the single conditional
+            # edge that matches, and remove the non-matching destinations
+            # from the queue. after, process regular edges
+
+            conditional_val: Optional[Any] = None
+            if module_key in self.conditional_dict:
+                # NOTE: we assume that the output of the module is a single key
+                single_output = get_single_output(output_dict)
+                # the conditional_val determines which edge to take
+                # new_output is the output that will be passed to the next module
+                conditional_val, new_output = self.conditional_dict[module_key](
+                    single_output
+                )
+            edge_list = list(self.dag.edges(module_key, data=True))
+            # get conditional edge list
+            conditional_edge_list = [e for e in edge_list if "conditional" in e[2]]
+            # in conditional edge list, find matches
+            if len(conditional_edge_list) > 0:
+                matches = [
+                    edge
+                    for edge in conditional_edge_list
+                    if conditional_val == edge[2]["conditional"]
+                ]
+                # exactly one edge must match; zero or several is an error
+                if len(matches) != 1:
+                    raise ValueError("Multiple conditional matches found or None.")
+                match = matches[0]
+                # drop the destinations of non-matching edges from the queue.
+                # skip the matched dest (several conditions may share it) and
+                # dests that are no longer queued
+                for _, other_dest, _ in conditional_edge_list:
+                    if other_dest != match[1] and other_dest in new_queue:
+                        new_queue.remove(other_dest)
+                # add match to module inputs
+                add_output_to_module_inputs(
+                    None,  # no src_key for conditional link
+                    match[2].get("dest_key"),
+                    {"output": new_output},
+                    self.module_dict[match[1]],
+                    all_module_inputs[match[1]],
+                )
 
+            # everything not in conditional_edge_list is regular
+            regular_edge_list = [
+                (src, dest, attr)
+                for src, dest, attr in edge_list
+                if "conditional" not in attr
+            ]
+            for _, dest, attr in regular_edge_list:
+                # regular links always forward the module's own output_dict
                 # add input to module_deps_inputs
                 add_output_to_module_inputs(
                     attr.get("src_key"),
                     attr.get("dest_key"),
                     output_dict,
-                    edge_module,
+                    self.module_dict[dest],
                     all_module_inputs[dest],
                 )
 
+        return new_queue
+
     def _run_multi(self, module_input_dict: Dict[str, Any]) -> Dict[str, Any]:
         """Run the pipeline for multiple roots.
 
@@ -474,8 +584,8 @@ class QueryPipeline(QueryComponent):
             output_dict = module.run_component(**module_input)
 
             # get new nodes and is_leaf
-            self._process_component_output(
-                output_dict, module_key, all_module_inputs, result_outputs
+            queue = self._process_component_output(
+                queue, output_dict, module_key, all_module_inputs, result_outputs
             )
 
         return result_outputs
@@ -541,8 +651,8 @@ class QueryPipeline(QueryComponent):
 
             for output_dict, module_key in zip(output_dicts, popped_nodes):
                 # get new nodes and is_leaf
-                self._process_component_output(
-                    output_dict, module_key, all_module_inputs, result_outputs
+                queue = self._process_component_output(
+                    queue, output_dict, module_key, all_module_inputs, result_outputs
                 )
 
         return result_outputs
diff --git a/tests/query_pipeline/test_query.py b/tests/query_pipeline/test_query.py
index 01cd129b6..b006df828 100644
--- a/tests/query_pipeline/test_query.py
+++ b/tests/query_pipeline/test_query.py
@@ -3,9 +3,10 @@
 from typing import Any, Dict
 
 import pytest
-from llama_index.core.query_pipeline.components import InputComponent
+from llama_index.core.query_pipeline.components import FnComponent, InputComponent
 from llama_index.core.query_pipeline.query_component import (
     ChainableMixin,
+    ConditionalLinks,
     InputKeys,
     Link,
     OutputKeys,
@@ -356,3 +357,47 @@ def test_query_pipeline_chain_str() -> None:
     p.add_chain(["a", "b", "c"])
     output = p.run(inp1=1, inp2=3)
     assert output == 11
+
+
+def test_query_pipeline_conditional_edges() -> None:
+    """Test that a ConditionalLinks routes output to the matching destination."""
+
+    def choose_fn(input: int) -> Dict:
+        """Return the toggle ("true" iff input == 1) together with the input."""
+        if input == 1:
+            toggle = "true"
+        else:
+            toggle = "false"
+        return {"toggle": toggle, "input": input}
+
+    p = QueryPipeline(
+        modules={
+            "input": InputComponent(),
+            "fn": FnComponent(fn=choose_fn),
+            "a": QueryComponent1(),
+            "b": QueryComponent2(),
+        },
+    )
+
+    p.add_links(
+        [
+            Link("input", "fn", src_key="inp1", dest_key="input"),
+            Link("input", "a", src_key="inp2", dest_key="input1"),
+            Link("input", "b", src_key="inp2", dest_key="input1"),
+            ConditionalLinks(
+                "fn",
+                lambda x: (x["toggle"], x["input"]),
+                {
+                    "true": {"dest": "a", "dest_key": "input2"},
+                    "false": {"dest": "b", "dest_key": "input2"},
+                },
+            ),
+        ]
+    )
+    output = p.run(inp1=1, inp2=3)
+    # inp1 == 1 -> toggle "true" -> routed to "a"
+    assert output == 4
+
+    output = p.run(inp1=2, inp2=3)
+    # inp1 != 1 -> toggle "false" -> routed to "b"
+    assert output == "3:2"
-- 
GitLab