diff --git a/llama-index-core/llama_index/core/tools/retriever_tool.py b/llama-index-core/llama_index/core/tools/retriever_tool.py index 03e40812ffebbd12737c9859f3122e32a611ee91..44e6050151513e444a0321fa80056acfee7d6257 100644 --- a/llama-index-core/llama_index/core/tools/retriever_tool.py +++ b/llama-index-core/llama_index/core/tools/retriever_tool.py @@ -88,8 +88,6 @@ class RetrieverTool(AsyncBaseTool): for doc in docs: assert isinstance(doc.node, (Node, TextNode)) node_copy = doc.node.model_copy() - node_copy.text_template = "{metadata_str}\n{content}" - node_copy.metadata_template = "{key} = {value}" content += node_copy.get_content(MetadataMode.LLM) + "\n\n" return ToolOutput( content=content, @@ -114,8 +112,6 @@ class RetrieverTool(AsyncBaseTool): for doc in docs: assert isinstance(doc.node, (Node, TextNode)) node_copy = doc.node.model_copy() - node_copy.text_template = "{metadata_str}\n{content}" - node_copy.metadata_template = "{key} = {value}" content += node_copy.get_content(MetadataMode.LLM) + "\n\n" return ToolOutput( content=content, diff --git a/llama-index-core/tests/tools/test_base.py b/llama-index-core/tests/tools/test_base.py index f890eb637c68ca9b6385bcac8da2cf85c2ef70d2..16d2b659058730ab96a928cbac132a7e5d5191c6 100644 --- a/llama-index-core/tests/tools/test_base.py +++ b/llama-index-core/tests/tools/test_base.py @@ -177,7 +177,7 @@ def test_retreiver_tool() -> None: ) output = vs_ret_tool.call("arg1", "arg2", key1="v1", key2="v2") formated_doc = ( - "file_path = /data/personal/essay.md\n" + "file_path: /data/personal/essay.md\n\n" "# title1:Hello world.\n" "This is a test." ) diff --git a/llama-index-core/tests/tools/test_retriever_tool.py b/llama-index-core/tests/tools/test_retriever_tool.py index 215cf823e26ec6e313854bd56dadf578f8c7d2ac..f99461188bb3721d15ea271981eb75a9e9858cce 100644 --- a/llama-index-core/tests/tools/test_retriever_tool.py +++ b/llama-index-core/tests/tools/test_retriever_tool.py @@ -1,22 +1,44 @@ """Test retriever tool.""" + from typing import List, Optional from llama_index.core.base.base_retriever import BaseRetriever from llama_index.core.schema import NodeWithScore, TextNode, QueryBundle from llama_index.core.tools import RetrieverTool from llama_index.core.postprocessor.types import BaseNodePostprocessor +import pytest class MockRetriever(BaseRetriever): """Custom retriever for testing.""" - def _retrieve(self, query_str: str) -> List[NodeWithScore]: + def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Mock retrieval.""" - return [NodeWithScore(node=TextNode(text=f"mock_{query_str}"), score=0.9)] + return [ + NodeWithScore( + node=TextNode( + text=f"mock_{query_bundle}", + text_template="Metadata:\n{metadata_str}\n\nContent:\n{content}", + metadata_template="- {key}: {value}", + metadata={"key": "value"}, + ), + score=0.9, + ) + ] - async def _aretrieve(self, query_str: str) -> List[NodeWithScore]: + async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Mock retrieval.""" - return [NodeWithScore(node=TextNode(text=f"mock_{query_str}"), score=0.9)] + return [ + NodeWithScore( + node=TextNode( + text=f"mock_{query_bundle}", + text_template="Metadata:\n{metadata_str}\n\nContent:\n{content}", + metadata_template="- {key}: {value}", + metadata={"key": "value"}, + ), + score=0.9, + ) + ] class MockPostProcessor(BaseNodePostprocessor): @@ -41,7 +63,10 @@ def test_retriever_tool() -> None: retriever = MockRetriever() retriever_tool = RetrieverTool.from_defaults(retriever=retriever) response_nodes = retriever_tool("hello world") - assert str(response_nodes) == "mock_hello world\n\n\n\n" + assert ( + str(response_nodes) + == "Metadata:\n- key: value\n\nContent:\nmock_hello world\n\n" + ) assert response_nodes.raw_output[0].node.text == "mock_hello world\n\n" # Test node_postprocessors @@ -50,4 +75,32 @@ def test_retriever_tool() -> None: retriever=retriever, node_postprocessors=node_postprocessors ) pr_response_nodes = pr_retriever_tool("hello world") - assert str(pr_response_nodes) == "processed_mock_hello world\n\n\n\n" + assert ( + str(pr_response_nodes) + == "Metadata:\n- key: value\n\nContent:\nprocessed_mock_hello world\n\n" + ) + + +@pytest.mark.asyncio() +async def test_retriever_tool_async() -> None: + """Test retriever tool async call.""" + # Test async retrieval + retriever = MockRetriever() + retriever_tool = RetrieverTool.from_defaults(retriever=retriever) + response_nodes = await retriever_tool.acall("hello world") + assert ( + str(response_nodes) + == "Metadata:\n- key: value\n\nContent:\nmock_hello world\n\n" + ) + assert response_nodes.raw_output[0].node.text == "mock_hello world\n\n" + + # Test node_postprocessors async + node_postprocessors = [MockPostProcessor()] + pr_retriever_tool = RetrieverTool.from_defaults( + retriever=retriever, node_postprocessors=node_postprocessors + ) + pr_response_nodes = await pr_retriever_tool.acall("hello world") + assert ( + str(pr_response_nodes) + == "Metadata:\n- key: value\n\nContent:\nprocessed_mock_hello world\n\n" + )