From 3bfb467ae4a0b5de38e7a3d82fb58bb7887ec579 Mon Sep 17 00:00:00 2001 From: Azure Wang <azurewtl@hotmail.com> Date: Tue, 21 Nov 2023 00:45:18 +0800 Subject: [PATCH] Refactor `callback_manager` for `VectorIndexRetriever` (#8871) --- CHANGELOG.md | 4 ++ experimental/colbert_index/retriever.py | 5 ++- llama_index/core/base_retriever.py | 40 ++++++++++++++++--- .../indices/document_summary/retrievers.py | 11 ++++- llama_index/indices/empty/retrievers.py | 3 ++ .../indices/keyword_table/retrievers.py | 3 ++ .../indices/knowledge_graph/retrievers.py | 5 +++ llama_index/indices/list/retrievers.py | 13 +++++- .../indices/managed/vectara/retriever.py | 5 ++- llama_index/indices/multi_modal/retriever.py | 3 ++ .../indices/struct_store/sql_retriever.py | 10 ++++- .../indices/tree/all_leaf_retriever.py | 11 ++++- .../indices/tree/select_leaf_retriever.py | 4 ++ .../indices/tree/tree_root_retriever.py | 12 +++++- llama_index/indices/vector_store/base.py | 1 + .../auto_retriever/auto_retriever.py | 3 ++ .../vector_store/retrievers/retriever.py | 5 ++- .../query_engine/retriever_query_engine.py | 21 +--------- .../retrievers/auto_merging_retriever.py | 6 ++- llama_index/retrievers/bm25_retriever.py | 4 +- llama_index/retrievers/fusion_retriever.py | 3 ++ llama_index/retrievers/recursive_retriever.py | 4 +- llama_index/retrievers/transform_retriever.py | 3 ++ llama_index/retrievers/you_retriever.py | 9 ++++- 24 files changed, 145 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9374ca8292..79ac934c7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Unreleased +### New Features + +- Added callback manager to each retriever (#8871) + ### Bug Fixes / Nits - Fixed bug in formatting chat prompt templates when estimating chunk sizes (#9025) diff --git a/experimental/colbert_index/retriever.py b/experimental/colbert_index/retriever.py index 6473f23a2d..199dfa7860 100644 --- a/experimental/colbert_index/retriever.py +++ b/experimental/colbert_index/retriever.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import BaseRetriever from llama_index.schema import NodeWithScore, QueryBundle @@ -28,19 +29,19 @@ class ColbertRetriever(BaseRetriever): filters: Optional[MetadataFilters] = None, node_ids: Optional[List[str]] = None, doc_ids: Optional[List[str]] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" self._index = index self._service_context = self._index.service_context self._docstore = self._index.docstore - self._similarity_top_k = similarity_top_k self._node_ids = node_ids self._doc_ids = doc_ids self._filters = filters - self._kwargs: Dict[str, Any] = kwargs.get("colbert_kwargs", {}) + super().__init__(callback_manager) def _retrieve( self, diff --git a/llama_index/core/base_retriever.py b/llama_index/core/base_retriever.py index baf22e3160..2655d1788c 100644 --- a/llama_index/core/base_retriever.py +++ b/llama_index/core/base_retriever.py @@ -2,14 +2,20 @@ from abc import abstractmethod from typing import List, Optional +from llama_index.callbacks.base import CallbackManager +from llama_index.callbacks.schema import CBEventType, EventPayload +from llama_index.indices.query.schema import QueryBundle, QueryType +from llama_index.indices.service_context import ServiceContext from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType -from llama_index.schema import NodeWithScore, QueryBundle, QueryType -from llama_index.service_context import ServiceContext +from llama_index.schema import NodeWithScore class BaseRetriever(PromptMixin): """Base retriever.""" + def __init__(self, callback_manager: Optional[CallbackManager]) -> None: + self.callback_manager = callback_manager or CallbackManager([]) + def _get_prompts(self) -> PromptDictType: """Get prompts.""" return {} @@ -30,13 +36,35 @@ class BaseRetriever(PromptMixin): """ if isinstance(str_or_query_bundle, str): - str_or_query_bundle = QueryBundle(str_or_query_bundle) - return self._retrieve(str_or_query_bundle) + query_bundle = QueryBundle(str_or_query_bundle) + else: + query_bundle = str_or_query_bundle + with self.callback_manager.as_trace("query"): + with self.callback_manager.event( + CBEventType.RETRIEVE, + payload={EventPayload.QUERY_STR: query_bundle.query_str}, + ) as retrieve_event: + nodes = self._retrieve(query_bundle) + retrieve_event.on_end( + payload={EventPayload.NODES: nodes}, + ) + return nodes async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]: if isinstance(str_or_query_bundle, str): - str_or_query_bundle = QueryBundle(str_or_query_bundle) - return await self._aretrieve(str_or_query_bundle) + query_bundle = QueryBundle(str_or_query_bundle) + else: + query_bundle = str_or_query_bundle + with self.callback_manager.as_trace("query"): + with self.callback_manager.event( + CBEventType.RETRIEVE, + payload={EventPayload.QUERY_STR: query_bundle.query_str}, + ) as retrieve_event: + nodes = await self._aretrieve(query_bundle) + retrieve_event.on_end( + payload={EventPayload.NODES: nodes}, + ) + return nodes @abstractmethod def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: diff --git a/llama_index/indices/document_summary/retrievers.py b/llama_index/indices/document_summary/retrievers.py index 5dc3e47d11..5c17522165 100644 --- a/llama_index/indices/document_summary/retrievers.py +++ b/llama_index/indices/document_summary/retrievers.py @@ -7,6 +7,7 @@ This module contains retrievers for document summary indices. import logging from typing import Any, Callable, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.document_summary.base import DocumentSummaryIndex from llama_index.indices.utils import ( @@ -46,6 +47,7 @@ class DocumentSummaryIndexLLMRetriever(BaseRetriever): format_node_batch_fn: Optional[Callable] = None, parse_choice_select_answer_fn: Optional[Callable] = None, service_context: Optional[ServiceContext] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: self._index = index @@ -61,6 +63,7 @@ class DocumentSummaryIndexLLMRetriever(BaseRetriever): parse_choice_select_answer_fn or default_parse_choice_select_answer_fn ) self._service_context = service_context or index.service_context + super().__init__(callback_manager) def _retrieve( self, @@ -115,7 +118,11 @@ class DocumentSummaryIndexEmbeddingRetriever(BaseRetriever): """ def __init__( - self, index: DocumentSummaryIndex, similarity_top_k: int = 1, **kwargs: Any + self, + index: DocumentSummaryIndex, + similarity_top_k: int = 1, + callback_manager: Optional[CallbackManager] = None, + **kwargs: Any, ) -> None: """Init params.""" self._index = index @@ -123,8 +130,8 @@ class DocumentSummaryIndexEmbeddingRetriever(BaseRetriever): self._service_context = self._index.service_context self._docstore = self._index.docstore self._index_struct = self._index.index_struct - self._similarity_top_k = similarity_top_k + super().__init__(callback_manager) def _retrieve( self, diff --git a/llama_index/indices/empty/retrievers.py b/llama_index/indices/empty/retrievers.py index f5aeed4baa..e79532bc57 100644 --- a/llama_index/indices/empty/retrievers.py +++ b/llama_index/indices/empty/retrievers.py @@ -1,6 +1,7 @@ """Default query for EmptyIndex.""" from typing import Any, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.empty.base import EmptyIndex from llama_index.prompts import BasePromptTemplate @@ -23,11 +24,13 @@ class EmptyIndexRetriever(BaseRetriever): self, index: EmptyIndex, input_prompt: Optional[BasePromptTemplate] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" self._index = index self._input_prompt = input_prompt or DEFAULT_SIMPLE_INPUT_PROMPT + super().__init__(callback_manager) def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Retrieve relevant nodes.""" diff --git a/llama_index/indices/keyword_table/retrievers.py b/llama_index/indices/keyword_table/retrievers.py index 83f43e732d..af216c3550 100644 --- a/llama_index/indices/keyword_table/retrievers.py +++ b/llama_index/indices/keyword_table/retrievers.py @@ -4,6 +4,7 @@ from abc import abstractmethod from collections import defaultdict from typing import Any, Dict, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.keyword_table.base import BaseKeywordTableIndex from llama_index.indices.keyword_table.utils import ( @@ -52,6 +53,7 @@ class BaseKeywordTableRetriever(BaseRetriever): query_keyword_extract_template: Optional[BasePromptTemplate] = None, max_keywords_per_query: int = 10, num_chunks_per_query: int = 10, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -66,6 +68,7 @@ class BaseKeywordTableRetriever(BaseRetriever): keyword_extract_template or DEFAULT_KEYWORD_EXTRACT_TEMPLATE ) self.query_keyword_extract_template = query_keyword_extract_template or DQKET + super().__init__(callback_manager) @abstractmethod def _get_keywords(self, query_str: str) -> List[str]: diff --git a/llama_index/indices/knowledge_graph/retrievers.py b/llama_index/indices/knowledge_graph/retrievers.py index 7a6490c67f..72fe5efe56 100644 --- a/llama_index/indices/knowledge_graph/retrievers.py +++ b/llama_index/indices/knowledge_graph/retrievers.py @@ -4,6 +4,7 @@ from collections import defaultdict from enum import Enum from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.keyword_table.utils import extract_keywords_given_response from llama_index.indices.knowledge_graph.base import KnowledgeGraphIndex @@ -89,6 +90,7 @@ class KGTableRetriever(BaseRetriever): graph_store_query_depth: int = 2, use_global_node_triplets: bool = False, max_knowledge_sequence: int = REL_TEXT_LIMIT, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -118,6 +120,7 @@ class KGTableRetriever(BaseRetriever): except Exception as e: logger.warning(f"Failed to get graph schema: {e}") self._graph_schema = "" + super().__init__(callback_manager) def _get_keywords(self, query_str: str) -> List[str]: """Extract keywords.""" @@ -413,6 +416,7 @@ class KnowledgeGraphRAGRetriever(BaseRetriever): graph_traversal_depth: int = 2, max_knowledge_sequence: int = REL_TEXT_LIMIT, verbose: bool = False, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize the retriever.""" @@ -485,6 +489,7 @@ class KnowledgeGraphRAGRetriever(BaseRetriever): except Exception as e: logger.warning(f"Failed to get graph schema: {e}") self._graph_schema = "" + super().__init__(callback_manager) def _process_entities( self, diff --git a/llama_index/indices/list/retrievers.py b/llama_index/indices/list/retrievers.py index 1ac8c614b0..c3f6dd6e96 100644 --- a/llama_index/indices/list/retrievers.py +++ b/llama_index/indices/list/retrievers.py @@ -2,6 +2,7 @@ import logging from typing import Any, Callable, List, Optional, Tuple +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.list.base import SummaryIndex from llama_index.indices.query.embedding_utils import get_top_k_embeddings @@ -27,8 +28,14 @@ class SummaryIndexRetriever(BaseRetriever): """ - def __init__(self, index: SummaryIndex, **kwargs: Any) -> None: + def __init__( + self, + index: SummaryIndex, + callback_manager: Optional[CallbackManager] = None, + **kwargs: Any, + ) -> None: self._index = index + super().__init__(callback_manager) def _retrieve( self, @@ -58,10 +65,12 @@ class SummaryIndexEmbeddingRetriever(BaseRetriever): self, index: SummaryIndex, similarity_top_k: Optional[int] = 1, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: self._index = index self._similarity_top_k = similarity_top_k + super().__init__(callback_manager) def _retrieve( self, @@ -141,6 +150,7 @@ class SummaryIndexLLMRetriever(BaseRetriever): format_node_batch_fn: Optional[Callable] = None, parse_choice_select_answer_fn: Optional[Callable] = None, service_context: Optional[ServiceContext] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: self._index = index @@ -155,6 +165,7 @@ class SummaryIndexLLMRetriever(BaseRetriever): parse_choice_select_answer_fn or default_parse_choice_select_answer_fn ) self._service_context = service_context or index.service_context + super().__init__(callback_manager) def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Retrieve nodes.""" diff --git a/llama_index/indices/managed/vectara/retriever.py b/llama_index/indices/managed/vectara/retriever.py index 3bc97583d3..1a7e57db4d 100644 --- a/llama_index/indices/managed/vectara/retriever.py +++ b/llama_index/indices/managed/vectara/retriever.py @@ -4,8 +4,9 @@ An index that that is built on top of Vectara. import json import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import BaseRetriever from llama_index.indices.managed.types import ManagedIndexQueryMode @@ -50,6 +51,7 @@ class VectaraRetriever(BaseRetriever): n_sentences_before: int = 2, n_sentences_after: int = 2, filter: str = "", + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -67,6 +69,7 @@ class VectaraRetriever(BaseRetriever): self._mmr_diversity_bias = kwargs.get("mmr_diversity_bias", 0.3) else: self._mmr = False + super().__init__(callback_manager) def _get_post_headers(self) -> dict: """Returns headers that should be attached to each post request.""" diff --git a/llama_index/indices/multi_modal/retriever.py b/llama_index/indices/multi_modal/retriever.py index faeecf5aba..99e26744f2 100644 --- a/llama_index/indices/multi_modal/retriever.py +++ b/llama_index/indices/multi_modal/retriever.py @@ -3,6 +3,7 @@ import asyncio from typing import Any, Dict, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import ( MultiModalRetriever, @@ -49,6 +50,7 @@ class MultiModalVectorIndexRetriever(MultiModalRetriever): node_ids: Optional[List[str]] = None, doc_ids: Optional[List[str]] = None, sparse_top_k: Optional[int] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -73,6 +75,7 @@ class MultiModalVectorIndexRetriever(MultiModalRetriever): self._sparse_top_k = sparse_top_k self._kwargs: Dict[str, Any] = kwargs.get("vector_store_kwargs", {}) + self.callback_manager = callback_manager or CallbackManager([]) @property def similarity_top_k(self) -> int: diff --git a/llama_index/indices/struct_store/sql_retriever.py b/llama_index/indices/struct_store/sql_retriever.py index 2b3b17e17f..6ae50a0f5b 100644 --- a/llama_index/indices/struct_store/sql_retriever.py +++ b/llama_index/indices/struct_store/sql_retriever.py @@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from sqlalchemy import Table +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.embeddings.base import BaseEmbedding from llama_index.objects.base import ObjectRetriever @@ -39,11 +40,13 @@ class SQLRetriever(BaseRetriever): self, sql_database: SQLDatabase, return_raw: bool = True, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" self._sql_database = sql_database self._return_raw = return_raw + super().__init__(callback_manager) def _format_node_results( self, results: List[List[Any]], col_keys: List[str] @@ -124,7 +127,10 @@ class DefaultSQLParser(BaseSQLParser): class PGVectorSQLParser(BaseSQLParser): """PGVector SQL Parser.""" - def __init__(self, embed_model: BaseEmbedding) -> None: + def __init__( + self, + embed_model: BaseEmbedding, + ) -> None: """Initialize params.""" self._embed_model = embed_model @@ -180,6 +186,7 @@ class NLSQLRetriever(BaseRetriever, PromptMixin): service_context: Optional[ServiceContext] = None, return_raw: bool = True, handle_sql_errors: bool = True, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -194,6 +201,7 @@ class NLSQLRetriever(BaseRetriever, PromptMixin): self._sql_parser_mode = sql_parser_mode self._sql_parser = self._load_sql_parser(sql_parser_mode, self._service_context) self._handle_sql_errors = handle_sql_errors + super().__init__(callback_manager) def _get_prompts(self) -> Dict[str, Any]: """Get prompts.""" diff --git a/llama_index/indices/tree/all_leaf_retriever.py b/llama_index/indices/tree/all_leaf_retriever.py index 76d693ff1d..db831f073c 100644 --- a/llama_index/indices/tree/all_leaf_retriever.py +++ b/llama_index/indices/tree/all_leaf_retriever.py @@ -1,8 +1,9 @@ """Summarize query.""" import logging -from typing import Any, List, cast +from typing import Any, List, Optional, cast +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import IndexGraph from llama_index.indices.tree.base import TreeIndex @@ -27,10 +28,16 @@ class TreeAllLeafRetriever(BaseRetriever): """ - def __init__(self, index: TreeIndex, **kwargs: Any) -> None: + def __init__( + self, + index: TreeIndex, + callback_manager: Optional[CallbackManager] = None, + **kwargs: Any, + ) -> None: self._index = index self._index_struct = index.index_struct self._docstore = index.docstore + super().__init__(callback_manager) def _retrieve( self, diff --git a/llama_index/indices/tree/select_leaf_retriever.py b/llama_index/indices/tree/select_leaf_retriever.py index 0ef5858b7c..8606f61499 100644 --- a/llama_index/indices/tree/select_leaf_retriever.py +++ b/llama_index/indices/tree/select_leaf_retriever.py @@ -3,7 +3,9 @@ import logging from typing import Any, Dict, List, Optional, cast +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever +from llama_index.indices.query.schema import QueryBundle from llama_index.indices.tree.base import TreeIndex from llama_index.indices.tree.utils import get_numbered_text_from_nodes from llama_index.indices.utils import ( @@ -70,6 +72,7 @@ class TreeSelectLeafRetriever(BaseRetriever): query_template_multiple: Optional[BasePromptTemplate] = None, child_branch_factor: int = 1, verbose: bool = False, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ): self._index = index @@ -85,6 +88,7 @@ class TreeSelectLeafRetriever(BaseRetriever): ) self.child_branch_factor = child_branch_factor self._verbose = verbose + super().__init__(callback_manager) def _query_with_selected_node( self, diff --git a/llama_index/indices/tree/tree_root_retriever.py b/llama_index/indices/tree/tree_root_retriever.py index a79e334200..58b2ef0cb9 100644 --- a/llama_index/indices/tree/tree_root_retriever.py +++ b/llama_index/indices/tree/tree_root_retriever.py @@ -1,8 +1,10 @@ """Retrieve query.""" import logging -from typing import Any, List +from typing import Any, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever +from llama_index.indices.query.schema import QueryBundle from llama_index.indices.tree.base import TreeIndex from llama_index.indices.utils import get_sorted_node_list from llama_index.schema import NodeWithScore, QueryBundle @@ -20,10 +22,16 @@ class TreeRootRetriever(BaseRetriever): attempt to parse information down the graph in order to synthesize an answer. """ - def __init__(self, index: TreeIndex, **kwargs: Any) -> None: + def __init__( + self, + index: TreeIndex, + callback_manager: Optional[CallbackManager] = None, + **kwargs: Any, + ) -> None: self._index = index self._index_struct = index.index_struct self._docstore = index.docstore + super().__init__(callback_manager) def _retrieve( self, diff --git a/llama_index/indices/vector_store/base.py b/llama_index/indices/vector_store/base.py index c8b143e895..5a1b158e53 100644 --- a/llama_index/indices/vector_store/base.py +++ b/llama_index/indices/vector_store/base.py @@ -83,6 +83,7 @@ class VectorStoreIndex(BaseIndex[IndexDict]): return VectorIndexRetriever( self, node_ids=list(self.index_struct.nodes_dict.values()), + callback_manager=self._service_context.callback_manager, **kwargs, ) diff --git a/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py b/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py index c2e661b6d7..4e25b5ca86 100644 --- a/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py +++ b/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py @@ -1,6 +1,7 @@ import logging from typing import Any, List, Optional, cast +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import BaseRetriever from llama_index.indices.vector_store.base import VectorStoreIndex @@ -58,6 +59,7 @@ class VectorIndexAutoRetriever(BaseRetriever): max_top_k: int = 10, similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K, vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: self._index = index @@ -79,6 +81,7 @@ class VectorIndexAutoRetriever(BaseRetriever): self._similarity_top_k = similarity_top_k self._vector_store_query_mode = vector_store_query_mode self._kwargs = kwargs + super().__init__(callback_manager) def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: # prepare input diff --git a/llama_index/indices/vector_store/retrievers/retriever.py b/llama_index/indices/vector_store/retrievers/retriever.py index 64c886913f..2be7db0d51 100644 --- a/llama_index/indices/vector_store/retrievers/retriever.py +++ b/llama_index/indices/vector_store/retrievers/retriever.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import IndexDict @@ -44,6 +45,7 @@ class VectorIndexRetriever(BaseRetriever): node_ids: Optional[List[str]] = None, doc_ids: Optional[List[str]] = None, sparse_top_k: Optional[int] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -59,8 +61,8 @@ class VectorIndexRetriever(BaseRetriever): self._doc_ids = doc_ids self._filters = filters self._sparse_top_k = sparse_top_k - self._kwargs: Dict[str, Any] = kwargs.get("vector_store_kwargs", {}) + super().__init__(callback_manager) @property def similarity_top_k(self) -> int: @@ -94,7 +96,6 @@ class VectorIndexRetriever(BaseRetriever): query_bundle.embedding_strs ) ) - return await self._aget_nodes_with_embeddings(query_bundle) def _build_vector_store_query( diff --git a/llama_index/query_engine/retriever_query_engine.py b/llama_index/query_engine/retriever_query_engine.py index 937663b38a..1fa0355e7e 100644 --- a/llama_index/query_engine/retriever_query_engine.py +++ b/llama_index/query_engine/retriever_query_engine.py @@ -167,16 +167,7 @@ class RetrieverQueryEngine(BaseQueryEngine): with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str} ) as query_event: - with self.callback_manager.event( - CBEventType.RETRIEVE, - payload={EventPayload.QUERY_STR: query_bundle.query_str}, - ) as retrieve_event: - nodes = self.retrieve(query_bundle) - - retrieve_event.on_end( - payload={EventPayload.NODES: nodes}, - ) - + nodes = self.retrieve(query_bundle) response = self._response_synthesizer.synthesize( query=query_bundle, nodes=nodes, @@ -191,15 +182,7 @@ class RetrieverQueryEngine(BaseQueryEngine): with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str} ) as query_event: - with self.callback_manager.event( - CBEventType.RETRIEVE, - payload={EventPayload.QUERY_STR: query_bundle.query_str}, - ) as retrieve_event: - nodes = await self.aretrieve(query_bundle) - - retrieve_event.on_end( - payload={EventPayload.NODES: nodes}, - ) + nodes = await self.aretrieve(query_bundle) response = await self._response_synthesizer.asynthesize( query=query_bundle, diff --git a/llama_index/retrievers/auto_merging_retriever.py b/llama_index/retrievers/auto_merging_retriever.py index 365df24d76..f27d4284c5 100644 --- a/llama_index/retrievers/auto_merging_retriever.py +++ b/llama_index/retrievers/auto_merging_retriever.py @@ -2,9 +2,11 @@ import logging from collections import defaultdict -from typing import Dict, List, Tuple, cast +from typing import Dict, List, Optional, Tuple, cast +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever +from llama_index.indices.query.schema import QueryBundle from llama_index.indices.utils import truncate_text from llama_index.indices.vector_store.retrievers.retriever import VectorIndexRetriever from llama_index.schema import BaseNode, NodeWithScore, QueryBundle @@ -27,12 +29,14 @@ class AutoMergingRetriever(BaseRetriever): storage_context: StorageContext, simple_ratio_thresh: float = 0.5, verbose: bool = False, + callback_manager: Optional[CallbackManager] = None, ) -> None: """Init params.""" self._vector_retriever = vector_retriever self._storage_context = storage_context self._simple_ratio_thresh = simple_ratio_thresh self._verbose = verbose + super().__init__(callback_manager) def _get_parents_and_merge( self, nodes: List[NodeWithScore] diff --git a/llama_index/retrievers/bm25_retriever.py b/llama_index/retrievers/bm25_retriever.py index bc59da465a..fd7a83af2d 100644 --- a/llama_index/retrievers/bm25_retriever.py +++ b/llama_index/retrievers/bm25_retriever.py @@ -1,6 +1,7 @@ import logging from typing import Callable, List, Optional, cast +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import BaseRetriever from llama_index.indices.vector_store.base import VectorStoreIndex @@ -17,6 +18,7 @@ class BM25Retriever(BaseRetriever): nodes: List[BaseNode], tokenizer: Optional[Callable[[str], List[str]]], similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K, + callback_manager: Optional[CallbackManager] = None, ) -> None: try: from rank_bm25 import BM25Okapi @@ -27,8 +29,8 @@ class BM25Retriever(BaseRetriever): self._tokenizer = tokenizer or (lambda x: x.split(" ")) self._similarity_top_k = similarity_top_k self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes] - self.bm25 = BM25Okapi(self._corpus) + super().__init__(callback_manager) @classmethod def from_defaults( diff --git a/llama_index/retrievers/fusion_retriever.py b/llama_index/retrievers/fusion_retriever.py index 51603ef63c..01c86fadbc 100644 --- a/llama_index/retrievers/fusion_retriever.py +++ b/llama_index/retrievers/fusion_retriever.py @@ -3,6 +3,7 @@ from enum import Enum from typing import Dict, List, Optional, Tuple from llama_index.async_utils import run_async_tasks +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.llms.utils import LLMType, resolve_llm from llama_index.retrievers import BaseRetriever @@ -35,6 +36,7 @@ class QueryFusionRetriever(BaseRetriever): num_queries: int = 4, use_async: bool = True, verbose: bool = False, + callback_manager: Optional[CallbackManager] = None, ) -> None: self.num_queries = num_queries self.query_gen_prompt = query_gen_prompt or QUERY_GEN_PROMPT @@ -45,6 +47,7 @@ class QueryFusionRetriever(BaseRetriever): self._retrievers = retrievers self._llm = resolve_llm(llm) + super().__init__(callback_manager) def _get_queries(self, original_query: str) -> List[str]: prompt_str = self.query_gen_prompt.format( diff --git a/llama_index/retrievers/recursive_retriever.py b/llama_index/retrievers/recursive_retriever.py index e4e031d840..bc5817b1f1 100644 --- a/llama_index/retrievers/recursive_retriever.py +++ b/llama_index/retrievers/recursive_retriever.py @@ -49,7 +49,7 @@ class RecursiveRetriever(BaseRetriever): self._retriever_dict = retriever_dict self._query_engine_dict = query_engine_dict or {} self._node_dict = node_dict or {} - self.callback_manager = callback_manager or CallbackManager([]) + super().__init__(callback_manager) # make sure keys don't overlap if set(self._retriever_dict.keys()) & set(self._query_engine_dict.keys()): @@ -57,7 +57,7 @@ class RecursiveRetriever(BaseRetriever): self._query_response_tmpl = query_response_tmpl or DEFAULT_QUERY_RESPONSE_TMPL self._verbose = verbose - super().__init__() + super().__init__(callback_manager) def _query_retrieved_nodes( self, query_bundle: QueryBundle, nodes_with_score: List[NodeWithScore] diff --git a/llama_index/retrievers/transform_retriever.py b/llama_index/retrievers/transform_retriever.py index 0dd08176cc..f200f75100 100644 --- a/llama_index/retrievers/transform_retriever.py +++ b/llama_index/retrievers/transform_retriever.py @@ -1,5 +1,6 @@ from typing import List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.query.query_transform.base import BaseQueryTransform from llama_index.prompts.mixin import PromptMixinType @@ -19,10 +20,12 @@ class TransformRetriever(BaseRetriever): retriever: BaseRetriever, query_transform: BaseQueryTransform, transform_metadata: Optional[dict] = None, + callback_manager: Optional[CallbackManager] = None, ) -> None: self._retriever = retriever self._query_transform = query_transform self._transform_metadata = transform_metadata + super().__init__(callback_manager) def _get_prompt_modules(self) -> PromptMixinType: """Get prompt sub-modules.""" diff --git a/llama_index/retrievers/you_retriever.py b/llama_index/retrievers/you_retriever.py index a964863a18..df042b6ce1 100644 --- a/llama_index/retrievers/you_retriever.py +++ b/llama_index/retrievers/you_retriever.py @@ -6,7 +6,9 @@ from typing import List, Optional import requests +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever +from llama_index.indices.query.schema import QueryBundle from llama_index.schema import NodeWithScore, QueryBundle, TextNode logger = logging.getLogger(__name__) @@ -15,9 +17,14 @@ logger = logging.getLogger(__name__) class YouRetriever(BaseRetriever): """You retriever.""" - def __init__(self, api_key: Optional[str] = None) -> None: + def __init__( + self, + api_key: Optional[str] = None, + callback_manager: Optional[CallbackManager] = None, + ) -> None: """Init params.""" self._api_key = api_key or os.environ["YOU_API_KEY"] + super().__init__(callback_manager) def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Retrieve.""" -- GitLab