diff --git a/docs/examples/pipeline/query_pipeline_routing.ipynb b/docs/examples/pipeline/query_pipeline_routing.ipynb index ee5a58351652a2805067a66617a7762ef5d298e6..ba93c0b264dd4552fdf6f1be821695bcd08cb6d2 100644 --- a/docs/examples/pipeline/query_pipeline_routing.ipynb +++ b/docs/examples/pipeline/query_pipeline_routing.ipynb @@ -11,7 +11,11 @@ "\n", "Routing lets us dynamically choose underlying query pipelines to use given the query and a set of choices.\n", "\n", - "We offer this as an out-of-the-box abstraction in our [Router Query Engine](https://docs.llamaindex.ai/en/stable/examples/query_engine/RouterQueryEngine.html) guide. Here we show you how to compose a similar pipeline using our Query Pipeline syntax - this allows you to not only define query engines but easily stitch it into a chain/DAG with other modules across the compute graph." + "We offer this as an out-of-the-box abstraction in our [Router Query Engine](https://docs.llamaindex.ai/en/stable/examples/query_engine/RouterQueryEngine.html) guide. Here we show you how to compose a similar pipeline using our Query Pipeline syntax - this allows you to not only define query engines but easily stitch it into a chain/DAG with other modules across the compute graph.\n", + "\n", + "We show this in two ways:\n", + "1. **Using a Router Component**: This is a Component that is composed on top of other query pipelines, and selects them based on a condition.\n", + "2. **Using Conditional Edges**: You can make the edges in a graph \"conditional\", meaning that they are only picked if certain conditions are met." ] }, { @@ -34,16 +38,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "--2024-01-10 12:31:00-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt\n", - "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...\n", - "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n", + "--2024-02-10 00:31:34-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt\n", + "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8002::154, 2606:50c0:8003::154, 2606:50c0:8001::154, ...\n", + "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8002::154|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 75042 (73K) [text/plain]\n", "Saving to: ‘pg_essay.txt’\n", "\n", "pg_essay.txt 100%[===================>] 73.28K --.-KB/s in 0.01s \n", "\n", - "2024-01-10 12:31:00 (6.32 MB/s) - ‘pg_essay.txt’ saved [75042/75042]\n", + "2024-02-10 00:31:34 (6.78 MB/s) - ‘pg_essay.txt’ saved [75042/75042]\n", "\n" ] } @@ -65,20 +69,12 @@ "documents = reader.load_data()" ] }, - { - "cell_type": "markdown", - "id": "6c1d5ff8-ae04-4ea3-bbe0-2c097af71efd", - "metadata": {}, - "source": [ - "## Setup Query Pipeline with Routing" - ] - }, { "cell_type": "markdown", "id": "63caf998-0a88-4c50-b6a4-2a0c412bde5b", "metadata": {}, "source": [ - "### Define Modules\n", + "## Define Modules\n", "\n", "We define llm, vector index, summary index, and prompt templates." ] @@ -127,6 +123,7 @@ "\n", "# define vector retriever\n", "vector_index = VectorStoreIndex.from_documents(documents)\n", + "vector_retriever = vector_index.as_retriever(similarity_top_k=2)\n", "vector_query_engine = vector_index.as_query_engine(similarity_top_k=2)\n", "\n", "# define summary query prompts + retrievers\n", @@ -140,12 +137,23 @@ "only some pieces of context (or none) maybe be relevant.\n", "\"\"\"\n", "summary_qrewrite_prompt = PromptTemplate(summary_qrewrite_str)\n", + "summary_retriever = summary_index.as_retriever()\n", "summary_query_engine = summary_index.as_query_engine()\n", "\n", "# define selector\n", "selector = LLMSingleSelector.from_defaults()" ] }, + { + "cell_type": "markdown", + "id": "6c1d5ff8-ae04-4ea3-bbe0-2c097af71efd", + "metadata": {}, + "source": [ + "## Setup Query Pipeline with `RouterComponent`\n", + "\n", + "In the first approach, we show you how to setup a query pipeline with a `RouterComponent`. The `RouterComponent` specifically takes in one of our `Selector` modules, and given a set of choices (with string descriptions), chooses the relevant choice and calls the relevant sub-component." + ] + }, { "cell_type": "markdown", "id": "7a87a439-88e6-4130-b28f-45268330d3e4", @@ -191,7 +199,7 @@ "id": "bda05274-09c5-4b56-b2ba-57f445346e73", "metadata": {}, "source": [ - "## Try out Queries" + "### Try out Queries" ] }, { @@ -253,6 +261,298 @@ "response = qp.run(\"What is a summary of this document?\")\n", "print(str(response))" ] + }, + { + "cell_type": "markdown", + "id": "2c1a7118-875a-4d4f-996c-fb0bab0c9b55", + "metadata": {}, + "source": [ + "## Setup Query Pipeline with Conditional Links\n", + "\n", + "In the example below we should you how to build our query pipeline with conditional links to route between our summary query engine and router query engine depending on the user query." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2c6e758-1b6f-484f-952e-c89cb62549d1", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.query_pipeline import (\n", + " QueryPipeline,\n", + " InputComponent,\n", + " Link,\n", + " FnComponent,\n", + ")\n", + "from llama_index.selectors import LLMSingleSelector\n", + "from typing import Dict" + ] + }, + { + "cell_type": "markdown", + "id": "924f49f7-624c-4f93-87f0-1474e1b2f999", + "metadata": {}, + "source": [ + "We first initialize our `LLMSingleSelector` component. Given a set of choices and a user query it will output a selection indicating the choice it picks.\n", + "\n", + "Note that the `LLMSingleSelector` can be directly used in a query pipeline. However here we wrap it in a `FnComponent` so that we can return the output as a dictionary of both the selected index and the original user query (this will help when we define our conditional link)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d179b914-12b6-4cdd-89b2-9128714135c1", + "metadata": {}, + "outputs": [], + "source": [ + "choices = [\n", + " \"This tool answers specific questions about the document (not summary questions across the document)\",\n", + " \"This tool answers summary questions about the document (not specific questions)\",\n", + "]\n", + "\n", + "\n", + "def select_choice(query: str) -> Dict:\n", + " selector = LLMSingleSelector.from_defaults()\n", + " output = selector.select(choices, query)\n", + " return {\"query\": query, \"index\": str(output.ind)}" + ] + }, + { + "cell_type": "markdown", + "id": "932f0386-7dc3-4072-8777-3216613b4a51", + "metadata": {}, + "source": [ + "We now initialize our Query Pipeline with the modules: input, selector, vector/summary query engine." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2bb637ac-8aed-4b7c-ab33-57af8dd8c369", + "metadata": {}, + "outputs": [], + "source": [ + "qp = QueryPipeline(\n", + " modules={\n", + " \"input\": InputComponent(),\n", + " \"selector\": FnComponent(fn=select_choice),\n", + " \"vector_retriever\": vector_retriever,\n", + " \"summary_retriever\": summary_retriever,\n", + " \"summarizer\": summarizer,\n", + " },\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c1dc8ab4-9c25-4e03-a9b0-fd901e9d6679", + "metadata": {}, + "source": [ + "We now define our links. Input --> selector is standard. What's more interesting here is our conditional link. \n", + "\n", + "We input our selector component as the source. \n", + "\n", + "We then input a function that produces two outputs, the first being the condition variable and the second being the child component.\n", + "\n", + "Lastly, we define a dictionary mapping each condition variable value to the component (which is represented as a dictionary with \"dest\" and \"dest_key\"). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57aa6bcf-f42b-4902-ab96-fe50afcae371", + "metadata": {}, + "outputs": [], + "source": [ + "qp.add_link(\"input\", \"selector\")\n", + "qp.add_link(\n", + " \"selector\",\n", + " \"vector_retriever\",\n", + " condition_fn=lambda x: x[\"index\"] == \"0\",\n", + " input_fn=lambda x: x[\"query\"],\n", + ")\n", + "qp.add_link(\n", + " \"selector\",\n", + " \"summary_retriever\",\n", + " condition_fn=lambda x: x[\"index\"] == \"1\",\n", + " input_fn=lambda x: x[\"query\"],\n", + ")\n", + "qp.add_link(\"vector_retriever\", \"summarizer\", dest_key=\"nodes\")\n", + "qp.add_link(\"summary_retriever\", \"summarizer\", dest_key=\"nodes\")\n", + "qp.add_link(\"input\", \"summarizer\", dest_key=\"query_str\")" + ] + }, + { + "cell_type": "markdown", + "id": "fa2305e7-3a4b-4e3d-84a5-f5469f096f55", + "metadata": {}, + "source": [ + "### Visualize\n", + "\n", + "The benefit of conditional links is that " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63381dfa-4338-47d5-b0d3-12964391f8e1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rag_dag.html\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " <iframe\n", + " width=\"100%\"\n", + " height=\"600px\"\n", + " src=\"rag_dag.html\"\n", + " frameborder=\"0\"\n", + " allowfullscreen\n", + " \n", + " ></iframe>\n", + " " + ], + "text/plain": [ + "<IPython.lib.display.IFrame at 0x2993a3ca0>" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "## create graph\n", + "from pyvis.network import Network\n", + "\n", + "net = Network(notebook=True, cdn_resources=\"in_line\", directed=True)\n", + "net.from_nx(qp.clean_dag)\n", + "net.show(\"rag_dag.html\")" + ] + }, + { + "cell_type": "markdown", + "id": "884d8696-0579-4b1d-be97-e95a26f26911", + "metadata": {}, + "source": [ + "### Try out Queries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a7c47e77-137e-4913-914f-b80f3ccfd30a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1;3;38;2;155;135;227m> Running module input with input: \n", + "input: What did the author do during his time in YC?\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module selector with input: \n", + "query: What did the author do during his time in YC?\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module vector_retriever with input: \n", + "input: What did the author do during his time in YC?\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module summarizer with input: \n", + "query_str: What did the author do during his time in YC?\n", + "nodes: [NodeWithScore(node=TextNode(id_='e7937e10-61ca-435a-a6a7-4232cdd7bed8', embedding=None, metadata={'file_path': 'pg_essay.txt', 'file_name': 'pg_essay.txt', 'file_type': 'text/plain', 'file_size': 750...\n", + "\n", + "\u001b[0mDuring his time in YC, the author worked on various tasks related to running the program. He selected and helped founders, dealt with disputes between cofounders, figured out when people were lying, and fought with people who maltreated the startups. Additionally, he wrote all of YC's internal software and worked on other projects such as writing essays and managing Hacker News. The author worked hard and wanted YC to be successful, so he dedicated a significant amount of time and effort to his responsibilities.\n" + ] + } + ], + "source": [ + "# compare with sync method\n", + "response = qp.run(input=\"What did the author do during his time in YC?\")\n", + "print(str(response))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a05f4c1-aaad-44ff-8b67-3e7a29afea01", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(response.source_nodes)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e3f9cbc-9bbe-40af-a06e-b49e7a7b80b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1;3;38;2;155;135;227m> Running module input with input: \n", + "input: What is a summary of this document?\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module selector with input: \n", + "query: What is a summary of this document?\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module summary_retriever with input: \n", + "input: What is a summary of this document?\n", + "\n", + "\u001b[0m\u001b[1;3;38;2;155;135;227m> Running module summarizer with input: \n", + "query_str: What is a summary of this document?\n", + "nodes: [NodeWithScore(node=TextNode(id_='f6d0a4c7-3806-4169-8e87-ba7870335312', embedding=None, metadata={'file_path': 'pg_essay.txt', 'file_name': 'pg_essay.txt', 'file_type': 'text/plain', 'file_size': 750...\n", + "\n", + "\u001b[0mThe document provides a personal narrative of the author's journey from studying art in Florence to working at a software company in the US. It discusses their experiences at the Accademia di Belli Arti in Florence, their time working at Interleaf, their decision to drop out of art school and pursue painting and writing books on Lisp, and their failed attempt to start an online art gallery business. The document also covers the founding and early years of Viaweb, a company that aimed to build online stores, and the author's experiences after selling the company, including their struggles with painting and finding a sense of purpose. It concludes with the author's decision to start Y Combinator, an investment firm, and their exploration of Lisp programming language and the development of a new Lisp language called Bel.\n" + ] + } + ], + "source": [ + "response = qp.run(input=\"What is a summary of this document?\")\n", + "print(str(response))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1942e935-e390-41fe-abaf-f21c1282b83d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "21" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(response.source_nodes)" + ] } ], "metadata": { diff --git a/llama_index/core/query_pipeline/query_component.py b/llama_index/core/query_pipeline/query_component.py index 474b1e5180c404579386886355a232de40dd2078..6644431714ef5c1ef7bbbe5af7594bb76ec5b4d9 100644 --- a/llama_index/core/query_pipeline/query_component.py +++ b/llama_index/core/query_pipeline/query_component.py @@ -305,45 +305,34 @@ class Link(BaseModel): default=None, description="Destination component input key" ) - def __init__( - self, - src: str, - dest: str, - src_key: Optional[str] = None, - dest_key: Optional[str] = None, - ) -> None: - """Init params.""" - # NOTE: This is to enable positional args. - super().__init__(src=src, dest=dest, src_key=src_key, dest_key=dest_key) - - -class ConditionalLinks(BaseModel): - """Conditional Links between source and multiple destinations.""" - - src: str = Field(..., description="Source component name") - fn: Callable = Field( - ..., description="Function to determine which destination to go to" + condition_fn: Optional[Callable] = Field( + default=None, description="Condition to determine if link should be followed" ) - cond_dest_dict: Dict[str, Any] = Field( - ..., description="dictionary of value to destination component name" + input_fn: Optional[Callable] = Field( + default=None, description="Input to destination component" ) - class Config: - arbitrary_types_allowed = True - def __init__( self, src: str, - fn: Callable, - cond_dest_dict: Dict[str, Any], + dest: str, + src_key: Optional[str] = None, + dest_key: Optional[str] = None, + condition_fn: Optional[Callable] = None, + input_fn: Optional[Callable] = None, ) -> None: """Init params.""" # NOTE: This is to enable positional args. - super().__init__(src=src, fn=fn, cond_dest_dict=cond_dest_dict) + super().__init__( + src=src, + dest=dest, + src_key=src_key, + dest_key=dest_key, + condition_fn=condition_fn, + input_fn=input_fn, + ) # accept both QueryComponent and ChainableMixin as inputs to query pipeline # ChainableMixin modules will be converted to components via `as_query_component` QUERY_COMPONENT_TYPE = Union[QueryComponent, ChainableMixin] - -LINK_TYPE = Union[Link, ConditionalLinks] diff --git a/llama_index/query_pipeline/query.py b/llama_index/query_pipeline/query.py index ab40550f16a319730d9798df413a6283b6e59a1e..bd5b8a4d6dbdea745933f3e6e2d7a2882d666639 100644 --- a/llama_index/query_pipeline/query.py +++ b/llama_index/query_pipeline/query.py @@ -22,10 +22,8 @@ from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.core.query_pipeline.query_component import ( - LINK_TYPE, QUERY_COMPONENT_TYPE, ChainableMixin, - ConditionalLinks, InputKeys, Link, OutputKeys, @@ -34,22 +32,10 @@ from llama_index.core.query_pipeline.query_component import ( from llama_index.utils import print_text -def get_single_output( - output_dict: Dict[str, Any], -) -> Any: - """Get single output.""" - if len(output_dict) != 1: - raise ValueError("Output dict must have exactly one key.") - return next(iter(output_dict.values())) - - -def add_output_to_module_inputs( +def get_output( src_key: Optional[str], - dest_key: str, output_dict: Dict[str, Any], - module: QueryComponent, - module_inputs: Dict[str, Any], -) -> None: +) -> Any: """Add input to module deps inputs.""" # get relevant output from link if src_key is None: @@ -59,7 +45,16 @@ def add_output_to_module_inputs( output = next(iter(output_dict.values())) else: output = output_dict[src_key] + return output + +def add_output_to_module_inputs( + dest_key: str, + output: Any, + module: QueryComponent, + module_inputs: Dict[str, Any], +) -> None: + """Add input to module deps inputs.""" # now attach output to relevant input key for module if dest_key is None: free_keys = module.free_req_input_keys @@ -116,6 +111,27 @@ def print_debug_input_multi( print_text(output + "\n", color="llama_lavender") +# Function to clean non-serializable attributes and return a copy of the graph +# https://stackoverflow.com/questions/23268421/networkx-how-to-access-attributes-of-objects-as-nodes +def clean_graph_attributes_copy(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph: + # Create a deep copy of the graph to preserve the original + graph_copy = graph.copy() + + # Iterate over nodes and clean attributes + for node, attributes in graph_copy.nodes(data=True): + for key, value in list(attributes.items()): + if callable(value): # Checks if the value is a function + del attributes[key] # Remove the attribute if it's non-serializable + + # Similarly, you can extend this to clean edge attributes if necessary + for u, v, attributes in graph_copy.edges(data=True): + for key, value in list(attributes.items()): + if callable(value): # Checks if the value is a function + del attributes[key] # Remove the attribute if it's non-serializable + + return graph_copy + + CHAIN_COMPONENT_TYPE = Union[QUERY_COMPONENT_TYPE, str] @@ -133,16 +149,6 @@ class QueryPipeline(QueryComponent): module_dict: Dict[str, QueryComponent] = Field( default_factory=dict, description="The modules in the pipeline." ) - conditional_dict: Dict[str, Callable] = Field( - default_factory=dict, - description=( - "Mapping from module key to conditional function." - "Specifically used in the case of conditional links. " - "The conditional function should take in the output of the source module and " - "return a value that will then be mapped to the relevant destination module " - "in the conditional link." - ), - ) dag: networkx.MultiDiGraph = Field( default_factory=networkx.MultiDiGraph, description="The DAG of the pipeline." ) @@ -217,14 +223,12 @@ class QueryPipeline(QueryComponent): def add_links( self, - links: List[LINK_TYPE], + links: List[Link], ) -> None: """Add links to the pipeline.""" for link in links: if isinstance(link, Link): self.add_link(**link.dict()) - elif isinstance(link, ConditionalLinks): - self.add_conditional_links(**link.dict()) else: raise ValueError("Link must be of type `Link` or `ConditionalLinks`.") @@ -253,29 +257,20 @@ class QueryPipeline(QueryComponent): dest: str, src_key: Optional[str] = None, dest_key: Optional[str] = None, + condition_fn: Optional[Callable] = None, + input_fn: Optional[Callable] = None, ) -> None: """Add a link between two modules.""" if src not in self.module_dict: raise ValueError(f"Module {src} does not exist in pipeline.") - self.dag.add_edge(src, dest, src_key=src_key, dest_key=dest_key) - - def add_conditional_links( - self, - src: str, - fn: Callable, - cond_dest_dict: Dict[str, Any], - ) -> None: - """Add conditional links.""" - if src not in self.module_dict: - raise ValueError(f"Module {src} does not exist in pipeline.") - self.conditional_dict[src] = fn - for conditional, dest_dict in cond_dest_dict.items(): - self.dag.add_edge( - src, - dest_dict["dest"], - dest_key=dest_dict["dest_key"], - conditional=conditional, - ) + self.dag.add_edge( + src, + dest, + src_key=src_key, + dest_key=dest_key, + condition_fn=condition_fn, + input_fn=input_fn, + ) def get_root_keys(self) -> List[str]: """Get root keys.""" @@ -486,68 +481,33 @@ class QueryPipeline(QueryComponent): if module_key in self._get_leaf_keys(): result_outputs[module_key] = output_dict else: - # first, process conditional edges. find the conditional edge - # that matches, and then remove all other edges from queue - # after, process regular edges - - conditional_val: Optional[Any] = None - if module_key in self.conditional_dict: - # NOTE: we assume that the output of the module is a single key - single_output = get_single_output(output_dict) - # the conditional_val determines which edge to take - # new_output is the output that will be passed to the next module - conditional_val, new_output = self.conditional_dict[module_key]( - single_output - ) edge_list = list(self.dag.edges(module_key, data=True)) - # get conditional edge list - conditional_edge_list = [ - (src, dest, attr) - for src, dest, attr in edge_list - if "conditional" in attr - ] - # in conditional edge list, find matches - if len(conditional_edge_list) > 0: - match = next( - iter( - [ - (src, dest, attr) - for src, dest, attr in conditional_edge_list - if conditional_val == attr["conditional"] - ] - ) - ) - non_matches = [x for x in conditional_edge_list if x != match] - if len(non_matches) != len(conditional_edge_list) - 1: - raise ValueError("Multiple conditional matches found or None.") - # remove all non-matches from queue - for non_match in non_matches: - new_queue.remove(non_match[1]) - # add match to module inputs - add_output_to_module_inputs( - None, # no src_key for conditional link - match[2].get("dest_key"), - {"output": new_output}, - self.module_dict[match[1]], - all_module_inputs[match[1]], - ) - # everything not in conditional_edge_list is regular - regular_edge_list = [ - (src, dest, attr) - for src, dest, attr in edge_list - if "conditional" not in attr - ] - for _, dest, attr in regular_edge_list: - # if conditional link, check if it should be added - # add input to module_deps_inputs - add_output_to_module_inputs( - attr.get("src_key"), - attr.get("dest_key"), - output_dict, - self.module_dict[dest], - all_module_inputs[dest], - ) + for _, dest, attr in edge_list: + output = get_output(attr.get("src_key"), output_dict) + + # if input_fn is not None, use it to modify the input + if attr["input_fn"] is not None: + dest_output = attr["input_fn"](output) + else: + dest_output = output + + add_edge = True + if attr["condition_fn"] is not None: + conditional_val = attr["condition_fn"](output) + if not conditional_val: + add_edge = False + + if add_edge: + add_output_to_module_inputs( + attr.get("dest_key"), + dest_output, + self.module_dict[dest], + all_module_inputs[dest], + ) + else: + # remove dest from queue + new_queue.remove(dest) return new_queue @@ -705,3 +665,8 @@ class QueryPipeline(QueryComponent): def sub_query_components(self) -> List[QueryComponent]: """Sub query components.""" return list(self.module_dict.values()) + + @property + def clean_dag(self) -> networkx.DiGraph: + """Clean dag.""" + return clean_graph_attributes_copy(self.dag) diff --git a/tests/query_pipeline/test_query.py b/tests/query_pipeline/test_query.py index b006df828829f2efcc2e758271edbbaf102d8dc6..9cd599ee67a85f8704fc626ee49940eb1881ff0d 100644 --- a/tests/query_pipeline/test_query.py +++ b/tests/query_pipeline/test_query.py @@ -6,7 +6,6 @@ import pytest from llama_index.core.query_pipeline.components import FnComponent, InputComponent from llama_index.core.query_pipeline.query_component import ( ChainableMixin, - ConditionalLinks, InputKeys, Link, OutputKeys, @@ -384,13 +383,19 @@ def test_query_pipeline_conditional_edges() -> None: Link("input", "fn", src_key="inp1", dest_key="input"), Link("input", "a", src_key="inp2", dest_key="input1"), Link("input", "b", src_key="inp2", dest_key="input1"), - ConditionalLinks( + Link( "fn", - lambda x: (x["toggle"], x["input"]), - { - "true": {"dest": "a", "dest_key": "input2"}, - "false": {"dest": "b", "dest_key": "input2"}, - }, + "a", + dest_key="input2", + condition_fn=lambda x: x["toggle"] == "true", + input_fn=lambda x: x["input"], + ), + Link( + "fn", + "b", + dest_key="input2", + condition_fn=lambda x: x["toggle"] == "false", + input_fn=lambda x: x["input"], ), ] )