Skip to content
Snippets Groups Projects
Unverified Commit 0393b081 authored by Jerry Liu's avatar Jerry Liu Committed by GitHub
Browse files

add conditional links to query pipeline (#10520)

parent e5b163da
No related branches found
No related tags found
No related merge requests found
"""Pipeline schema."""
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Union, cast, get_args
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
Set,
Union,
cast,
get_args,
)
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.callbacks.base import CallbackManager
......@@ -306,6 +317,33 @@ class Link(BaseModel):
super().__init__(src=src, dest=dest, src_key=src_key, dest_key=dest_key)
class ConditionalLinks(BaseModel):
    """Fan-out from one source component to one of several destinations.

    ``fn`` is called on the source component's output to decide which
    destination receives it; ``cond_dest_dict`` maps each possible decision
    value to a destination spec (the pipeline reads the ``"dest"`` and
    ``"dest_key"`` entries of each spec when the links are added).
    """

    class Config:
        # `fn` is an arbitrary callable, which pydantic cannot validate natively.
        arbitrary_types_allowed = True

    src: str = Field(..., description="Source component name")
    fn: Callable = Field(
        ..., description="Function to determine which destination to go to"
    )
    cond_dest_dict: Dict[str, Any] = Field(
        ..., description="dictionary of value to destination component name"
    )

    def __init__(
        self,
        src: str,
        fn: Callable,
        cond_dest_dict: Dict[str, Any],
    ) -> None:
        """Build the link group; the explicit signature allows positional args."""
        field_values = {"src": src, "fn": fn, "cond_dest_dict": cond_dest_dict}
        super().__init__(**field_values)
# accept both QueryComponent and ChainableMixin as inputs to query pipeline
# ChainableMixin modules will be converted to components via `as_query_component`
QUERY_COMPONENT_TYPE = Union[QueryComponent, ChainableMixin]
# a link handed to the pipeline is either a plain `Link` (one source, one
# destination) or a `ConditionalLinks` (one source, several candidate destinations)
LINK_TYPE = Union[Link, ConditionalLinks]
......@@ -2,7 +2,18 @@
import json
import uuid
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast, get_args
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
get_args,
)
import networkx
......@@ -11,8 +22,10 @@ from llama_index.bridge.pydantic import Field
from llama_index.callbacks import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.query_pipeline.query_component import (
LINK_TYPE,
QUERY_COMPONENT_TYPE,
ChainableMixin,
ConditionalLinks,
InputKeys,
Link,
OutputKeys,
......@@ -21,8 +34,17 @@ from llama_index.core.query_pipeline.query_component import (
from llama_index.utils import print_text
def get_single_output(
    output_dict: Dict[str, Any],
) -> Any:
    """Return the only value stored in ``output_dict``.

    Raises:
        ValueError: if ``output_dict`` does not contain exactly one entry.
    """
    values = list(output_dict.values())
    if len(values) == 1:
        return values[0]
    raise ValueError("Output dict must have exactly one key.")
def add_output_to_module_inputs(
src_key: str,
src_key: Optional[str],
dest_key: str,
output_dict: Dict[str, Any],
module: QueryComponent,
......@@ -111,6 +133,16 @@ class QueryPipeline(QueryComponent):
# Components registered in the pipeline, keyed by module name.
module_dict: Dict[str, QueryComponent] = Field(
    default_factory=dict, description="The modules in the pipeline."
)
# Routing function per source module; filled in by `add_conditional_links`.
# The function's result is unpacked as `(conditional_val, new_output)` when the
# source module's output is processed.
conditional_dict: Dict[str, Callable] = Field(
    default_factory=dict,
    description=(
        "Mapping from module key to conditional function. "
        "Specifically used in the case of conditional links. "
        "The conditional function should take in the output of the source module "
        "and return a tuple of (value, output): the value is mapped to the "
        "relevant destination module in the conditional link, and the output is "
        "passed to that module."
    ),
)
# Graph of modules (nodes) and links (edges); edge attributes carry the
# src_key / dest_key / conditional metadata of each link.
dag: networkx.MultiDiGraph = Field(
    default_factory=networkx.MultiDiGraph, description="The DAG of the pipeline."
)
......@@ -185,11 +217,16 @@ class QueryPipeline(QueryComponent):
def add_links(
    self,
    links: List[LINK_TYPE],
) -> None:
    """Add links to the pipeline.

    Each entry is dispatched on its type: a ``Link`` is forwarded to
    ``add_link`` and a ``ConditionalLinks`` to ``add_conditional_links``,
    in both cases by expanding the model's fields as keyword arguments.

    Args:
        links: the links to register, in order.

    Raises:
        ValueError: if an entry is neither a ``Link`` nor a ``ConditionalLinks``.
    """
    for link in links:
        # Dispatch strictly by type; a ConditionalLinks must never reach
        # `add_link`, whose keyword arguments it does not provide.
        if isinstance(link, Link):
            self.add_link(**link.dict())
        elif isinstance(link, ConditionalLinks):
            self.add_conditional_links(**link.dict())
        else:
            raise ValueError("Link must be of type `Link` or `ConditionalLinks`.")
def add_modules(self, module_dict: Dict[str, QUERY_COMPONENT_TYPE]) -> None:
"""Add modules to the pipeline."""
......@@ -222,6 +259,24 @@ class QueryPipeline(QueryComponent):
raise ValueError(f"Module {src} does not exist in pipeline.")
self.dag.add_edge(src, dest, src_key=src_key, dest_key=dest_key)
def add_conditional_links(
    self,
    src: str,
    fn: Callable,
    cond_dest_dict: Dict[str, Any],
) -> None:
    """Register a conditional fan-out from ``src``.

    Stores ``fn`` as the routing function for ``src`` and adds one DAG edge
    per entry of ``cond_dest_dict``. Each entry maps a condition value to a
    spec whose ``"dest"`` names the destination module and whose
    ``"dest_key"`` names the destination input; the condition value is
    recorded on the edge under the ``conditional`` attribute.

    Raises:
        ValueError: if ``src`` is not a module of this pipeline.
    """
    if src not in self.module_dict:
        raise ValueError(f"Module {src} does not exist in pipeline.")

    self.conditional_dict[src] = fn
    for cond_value in cond_dest_dict:
        target = cond_dest_dict[cond_value]
        dest_name = target["dest"]
        edge_attrs = {"dest_key": target["dest_key"], "conditional": cond_value}
        self.dag.add_edge(src, dest_name, **edge_attrs)
def get_root_keys(self) -> List[str]:
"""Get root keys."""
return self._get_root_keys()
......@@ -419,28 +474,83 @@ class QueryPipeline(QueryComponent):
def _process_component_output(
self,
queue: List[str],
output_dict: Dict[str, Any],
module_key: str,
all_module_inputs: Dict[str, Dict[str, Any]],
result_outputs: Dict[str, Any],
) -> None:
) -> List[str]:
"""Process component output."""
new_queue = queue.copy()
# if there's no more edges, add result to output
if module_key in self._get_leaf_keys():
result_outputs[module_key] = output_dict
else:
for _, dest, attr in self.dag.edges(module_key, data=True):
edge_module = self.module_dict[dest]
# first, process conditional edges. find the conditional edge
# that matches, and then remove all other edges from queue
# after, process regular edges
conditional_val: Optional[Any] = None
if module_key in self.conditional_dict:
# NOTE: we assume that the output of the module is a single key
single_output = get_single_output(output_dict)
# the conditional_val determines which edge to take
# new_output is the output that will be passed to the next module
conditional_val, new_output = self.conditional_dict[module_key](
single_output
)
edge_list = list(self.dag.edges(module_key, data=True))
# get conditional edge list
conditional_edge_list = [
(src, dest, attr)
for src, dest, attr in edge_list
if "conditional" in attr
]
# in conditional edge list, find matches
if len(conditional_edge_list) > 0:
match = next(
iter(
[
(src, dest, attr)
for src, dest, attr in conditional_edge_list
if conditional_val == attr["conditional"]
]
)
)
non_matches = [x for x in conditional_edge_list if x != match]
if len(non_matches) != len(conditional_edge_list) - 1:
raise ValueError("Multiple conditional matches found or None.")
# remove all non-matches from queue
for non_match in non_matches:
new_queue.remove(non_match[1])
# add match to module inputs
add_output_to_module_inputs(
None, # no src_key for conditional link
match[2].get("dest_key"),
{"output": new_output},
self.module_dict[match[1]],
all_module_inputs[match[1]],
)
# everything not in conditional_edge_list is regular
regular_edge_list = [
(src, dest, attr)
for src, dest, attr in edge_list
if "conditional" not in attr
]
for _, dest, attr in regular_edge_list:
# if conditional link, check if it should be added
# add input to module_deps_inputs
add_output_to_module_inputs(
attr.get("src_key"),
attr.get("dest_key"),
output_dict,
edge_module,
self.module_dict[dest],
all_module_inputs[dest],
)
return new_queue
def _run_multi(self, module_input_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Run the pipeline for multiple roots.
......@@ -474,8 +584,8 @@ class QueryPipeline(QueryComponent):
output_dict = module.run_component(**module_input)
# get new nodes and is_leaf
self._process_component_output(
output_dict, module_key, all_module_inputs, result_outputs
queue = self._process_component_output(
queue, output_dict, module_key, all_module_inputs, result_outputs
)
return result_outputs
......@@ -541,8 +651,8 @@ class QueryPipeline(QueryComponent):
for output_dict, module_key in zip(output_dicts, popped_nodes):
# get new nodes and is_leaf
self._process_component_output(
output_dict, module_key, all_module_inputs, result_outputs
queue = self._process_component_output(
queue, output_dict, module_key, all_module_inputs, result_outputs
)
return result_outputs
......
......@@ -3,9 +3,10 @@
from typing import Any, Dict
import pytest
from llama_index.core.query_pipeline.components import InputComponent
from llama_index.core.query_pipeline.components import FnComponent, InputComponent
from llama_index.core.query_pipeline.query_component import (
ChainableMixin,
ConditionalLinks,
InputKeys,
Link,
OutputKeys,
......@@ -356,3 +357,47 @@ def test_query_pipeline_chain_str() -> None:
p.add_chain(["a", "b", "c"])
output = p.run(inp1=1, inp2=3)
assert output == 11
def test_query_pipeline_conditional_edges() -> None:
    """Check that a conditional link feeds only the selected destination."""

    def choose_fn(input: int) -> Dict:
        """Tag the input with the branch name it should be routed to."""
        toggle = "true" if input == 1 else "false"
        return {"toggle": toggle, "input": input}

    def route(x: Dict) -> tuple:
        """Split the tagged dict into (branch name, payload)."""
        return x["toggle"], x["input"]

    modules = {
        "input": InputComponent(),
        "fn": FnComponent(fn=choose_fn),
        "a": QueryComponent1(),
        "b": QueryComponent2(),
    }
    p = QueryPipeline(modules=modules)

    branches = {
        "true": {"dest": "a", "dest_key": "input2"},
        "false": {"dest": "b", "dest_key": "input2"},
    }
    links = [
        Link("input", "fn", src_key="inp1", dest_key="input"),
        Link("input", "a", src_key="inp2", dest_key="input1"),
        Link("input", "b", src_key="inp2", dest_key="input1"),
        ConditionalLinks("fn", route, branches),
    ]
    p.add_links(links)

    # inp1 == 1 selects the "true" branch, so the result comes from "a"
    assert p.run(inp1=1, inp2=3) == 4
    # any other inp1 selects the "false" branch, so the result comes from "b"
    assert p.run(inp1=2, inp2=3) == "3:2"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment