diff --git a/CHANGELOG.md b/CHANGELOG.md index 9374ca82920dd5775d5bd687f589881196622b9e..79ac934c7b610ae957327fcad25429afa46052ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Unreleased +### New Features + +- Added callback manager to each retriever (#8871) + ### Bug Fixes / Nits - Fixed bug in formatting chat prompt templates when estimating chunk sizes (#9025) diff --git a/experimental/colbert_index/retriever.py b/experimental/colbert_index/retriever.py index 6473f23a2da646a867a19d1db654f1724a645502..199dfa7860959fb98456093fefb037880484e0eb 100644 --- a/experimental/colbert_index/retriever.py +++ b/experimental/colbert_index/retriever.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import BaseRetriever from llama_index.schema import NodeWithScore, QueryBundle @@ -28,19 +29,19 @@ class ColbertRetriever(BaseRetriever): filters: Optional[MetadataFilters] = None, node_ids: Optional[List[str]] = None, doc_ids: Optional[List[str]] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" self._index = index self._service_context = self._index.service_context self._docstore = self._index.docstore - self._similarity_top_k = similarity_top_k self._node_ids = node_ids self._doc_ids = doc_ids self._filters = filters - self._kwargs: Dict[str, Any] = kwargs.get("colbert_kwargs", {}) + super().__init__(callback_manager) def _retrieve( self, diff --git a/llama_index/core/base_retriever.py b/llama_index/core/base_retriever.py index baf22e316004884510621362230ab8a688300d76..2655d1788cdb8cdb5069c6e8d18a11761ba8d3ea 100644 --- a/llama_index/core/base_retriever.py +++ b/llama_index/core/base_retriever.py @@ -2,14 +2,20 @@ from abc import abstractmethod from typing import List, Optional +from llama_index.callbacks.base import CallbackManager +from llama_index.callbacks.schema import CBEventType, EventPayload +from llama_index.indices.query.schema import QueryBundle, QueryType +from llama_index.indices.service_context import ServiceContext from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType -from llama_index.schema import NodeWithScore, QueryBundle, QueryType -from llama_index.service_context import ServiceContext +from llama_index.schema import NodeWithScore class BaseRetriever(PromptMixin): """Base retriever.""" + def __init__(self, callback_manager: Optional[CallbackManager]) -> None: + self.callback_manager = callback_manager or CallbackManager([]) + def _get_prompts(self) -> PromptDictType: """Get prompts.""" return {} @@ -30,13 +36,35 @@ class BaseRetriever(PromptMixin): """ if isinstance(str_or_query_bundle, str): - str_or_query_bundle = QueryBundle(str_or_query_bundle) - return self._retrieve(str_or_query_bundle) + query_bundle = QueryBundle(str_or_query_bundle) + else: + query_bundle = str_or_query_bundle + with self.callback_manager.as_trace("query"): + with self.callback_manager.event( + CBEventType.RETRIEVE, + payload={EventPayload.QUERY_STR: query_bundle.query_str}, + ) as retrieve_event: + nodes = self._retrieve(query_bundle) + retrieve_event.on_end( + payload={EventPayload.NODES: nodes}, + ) + return nodes async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]: if isinstance(str_or_query_bundle, str): - str_or_query_bundle = QueryBundle(str_or_query_bundle) - return await self._aretrieve(str_or_query_bundle) + query_bundle = QueryBundle(str_or_query_bundle) + else: + query_bundle = str_or_query_bundle + with self.callback_manager.as_trace("query"): + with self.callback_manager.event( + CBEventType.RETRIEVE, + payload={EventPayload.QUERY_STR: query_bundle.query_str}, + ) as retrieve_event: + nodes = await self._aretrieve(query_bundle) + retrieve_event.on_end( + payload={EventPayload.NODES: nodes}, + ) + return nodes @abstractmethod def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: diff --git a/llama_index/indices/document_summary/retrievers.py b/llama_index/indices/document_summary/retrievers.py index 5dc3e47d11c7e71800a0f1f7b786f00c06446644..5c1752216599a24dbe8834c910f468c3c67c9e23 100644 --- a/llama_index/indices/document_summary/retrievers.py +++ b/llama_index/indices/document_summary/retrievers.py @@ -7,6 +7,7 @@ This module contains retrievers for document summary indices. import logging from typing import Any, Callable, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.document_summary.base import DocumentSummaryIndex from llama_index.indices.utils import ( @@ -46,6 +47,7 @@ class DocumentSummaryIndexLLMRetriever(BaseRetriever): format_node_batch_fn: Optional[Callable] = None, parse_choice_select_answer_fn: Optional[Callable] = None, service_context: Optional[ServiceContext] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: self._index = index @@ -61,6 +63,7 @@ class DocumentSummaryIndexLLMRetriever(BaseRetriever): parse_choice_select_answer_fn or default_parse_choice_select_answer_fn ) self._service_context = service_context or index.service_context + super().__init__(callback_manager) def _retrieve( self, @@ -115,7 +118,11 @@ class DocumentSummaryIndexEmbeddingRetriever(BaseRetriever): """ def __init__( - self, index: DocumentSummaryIndex, similarity_top_k: int = 1, **kwargs: Any + self, + index: DocumentSummaryIndex, + similarity_top_k: int = 1, + callback_manager: Optional[CallbackManager] = None, + **kwargs: Any, ) -> None: """Init params.""" self._index = index @@ -123,8 +130,8 @@ class DocumentSummaryIndexEmbeddingRetriever(BaseRetriever): self._service_context = self._index.service_context self._docstore = self._index.docstore self._index_struct = self._index.index_struct - self._similarity_top_k = similarity_top_k + super().__init__(callback_manager) def _retrieve( self, diff --git a/llama_index/indices/empty/retrievers.py b/llama_index/indices/empty/retrievers.py index f5aeed4baa6c7e64fe62caf2046b27591c86a263..e79532bc574f07c20194b6aff7883ab5d8a470a8 100644 --- a/llama_index/indices/empty/retrievers.py +++ b/llama_index/indices/empty/retrievers.py @@ -1,6 +1,7 @@ """Default query for EmptyIndex.""" from typing import Any, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.empty.base import EmptyIndex from llama_index.prompts import BasePromptTemplate @@ -23,11 +24,13 @@ class EmptyIndexRetriever(BaseRetriever): self, index: EmptyIndex, input_prompt: Optional[BasePromptTemplate] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" self._index = index self._input_prompt = input_prompt or DEFAULT_SIMPLE_INPUT_PROMPT + super().__init__(callback_manager) def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Retrieve relevant nodes.""" diff --git a/llama_index/indices/keyword_table/retrievers.py b/llama_index/indices/keyword_table/retrievers.py index 83f43e732daccee4c577ed3e3aa8e91e6438e4f9..af216c3550312b7ea7f38a965290f81d1b2f8eb0 100644 --- a/llama_index/indices/keyword_table/retrievers.py +++ b/llama_index/indices/keyword_table/retrievers.py @@ -4,6 +4,7 @@ from abc import abstractmethod from collections import defaultdict from typing import Any, Dict, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.keyword_table.base import BaseKeywordTableIndex from llama_index.indices.keyword_table.utils import ( @@ -52,6 +53,7 @@ class BaseKeywordTableRetriever(BaseRetriever): query_keyword_extract_template: Optional[BasePromptTemplate] = None, max_keywords_per_query: int = 10, num_chunks_per_query: int = 10, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -66,6 +68,7 @@ class BaseKeywordTableRetriever(BaseRetriever): keyword_extract_template or DEFAULT_KEYWORD_EXTRACT_TEMPLATE ) self.query_keyword_extract_template = query_keyword_extract_template or DQKET + super().__init__(callback_manager) @abstractmethod def _get_keywords(self, query_str: str) -> List[str]: diff --git a/llama_index/indices/knowledge_graph/retrievers.py b/llama_index/indices/knowledge_graph/retrievers.py index 7a6490c67f9b772b721d8127fd9e0893a7b414f7..72fe5efe5683620883a3af8436233162c460646f 100644 --- a/llama_index/indices/knowledge_graph/retrievers.py +++ b/llama_index/indices/knowledge_graph/retrievers.py @@ -4,6 +4,7 @@ from collections import defaultdict from enum import Enum from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.keyword_table.utils import extract_keywords_given_response from llama_index.indices.knowledge_graph.base import KnowledgeGraphIndex @@ -89,6 +90,7 @@ class KGTableRetriever(BaseRetriever): graph_store_query_depth: int = 2, use_global_node_triplets: bool = False, max_knowledge_sequence: int = REL_TEXT_LIMIT, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -118,6 +120,7 @@ class KGTableRetriever(BaseRetriever): except Exception as e: logger.warning(f"Failed to get graph schema: {e}") self._graph_schema = "" + super().__init__(callback_manager) def _get_keywords(self, query_str: str) -> List[str]: """Extract keywords.""" @@ -413,6 +416,7 @@ class KnowledgeGraphRAGRetriever(BaseRetriever): graph_traversal_depth: int = 2, max_knowledge_sequence: int = REL_TEXT_LIMIT, verbose: bool = False, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize the retriever.""" @@ -485,6 +489,7 @@ class KnowledgeGraphRAGRetriever(BaseRetriever): except Exception as e: logger.warning(f"Failed to get graph schema: {e}") self._graph_schema = "" + super().__init__(callback_manager) def _process_entities( self, diff --git a/llama_index/indices/list/retrievers.py b/llama_index/indices/list/retrievers.py index 1ac8c614b038466132a515b90030402c2526eb43..c3f6dd6e963129936f426d7556c7759586f5de73 100644 --- a/llama_index/indices/list/retrievers.py +++ b/llama_index/indices/list/retrievers.py @@ -2,6 +2,7 @@ import logging from typing import Any, Callable, List, Optional, Tuple +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.list.base import SummaryIndex from llama_index.indices.query.embedding_utils import get_top_k_embeddings @@ -27,8 +28,14 @@ class SummaryIndexRetriever(BaseRetriever): """ - def __init__(self, index: SummaryIndex, **kwargs: Any) -> None: + def __init__( + self, + index: SummaryIndex, + callback_manager: Optional[CallbackManager] = None, + **kwargs: Any, + ) -> None: self._index = index + super().__init__(callback_manager) def _retrieve( self, @@ -58,10 +65,12 @@ class SummaryIndexEmbeddingRetriever(BaseRetriever): self, index: SummaryIndex, similarity_top_k: Optional[int] = 1, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: self._index = index self._similarity_top_k = similarity_top_k + super().__init__(callback_manager) def _retrieve( self, @@ -141,6 +150,7 @@ class SummaryIndexLLMRetriever(BaseRetriever): format_node_batch_fn: Optional[Callable] = None, parse_choice_select_answer_fn: Optional[Callable] = None, service_context: Optional[ServiceContext] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: self._index = index @@ -155,6 +165,7 @@ class SummaryIndexLLMRetriever(BaseRetriever): parse_choice_select_answer_fn or default_parse_choice_select_answer_fn ) self._service_context = service_context or index.service_context + super().__init__(callback_manager) def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Retrieve nodes.""" diff --git a/llama_index/indices/managed/vectara/retriever.py b/llama_index/indices/managed/vectara/retriever.py index 3bc97583d3273e59b32d44f804ab0127f79f5b7a..1a7e57db4d263e1cce19b9907341d9905b0ba854 100644 --- a/llama_index/indices/managed/vectara/retriever.py +++ b/llama_index/indices/managed/vectara/retriever.py @@ -4,8 +4,9 @@ An index that that is built on top of Vectara. import json import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import BaseRetriever from llama_index.indices.managed.types import ManagedIndexQueryMode @@ -50,6 +51,7 @@ class VectaraRetriever(BaseRetriever): n_sentences_before: int = 2, n_sentences_after: int = 2, filter: str = "", + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -67,6 +69,7 @@ class VectaraRetriever(BaseRetriever): self._mmr_diversity_bias = kwargs.get("mmr_diversity_bias", 0.3) else: self._mmr = False + super().__init__(callback_manager) def _get_post_headers(self) -> dict: """Returns headers that should be attached to each post request.""" diff --git a/llama_index/indices/multi_modal/retriever.py b/llama_index/indices/multi_modal/retriever.py index faeecf5aba77b9abe2bd428d605a6e43ba324061..99e26744f2918678641eee0f7abdca8eef298f55 100644 --- a/llama_index/indices/multi_modal/retriever.py +++ b/llama_index/indices/multi_modal/retriever.py @@ -3,6 +3,7 @@ import asyncio from typing import Any, Dict, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import ( MultiModalRetriever, @@ -49,6 +50,7 @@ class MultiModalVectorIndexRetriever(MultiModalRetriever): node_ids: Optional[List[str]] = None, doc_ids: Optional[List[str]] = None, sparse_top_k: Optional[int] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -73,6 +75,7 @@ class MultiModalVectorIndexRetriever(MultiModalRetriever): self._sparse_top_k = sparse_top_k self._kwargs: Dict[str, Any] = kwargs.get("vector_store_kwargs", {}) + self.callback_manager = callback_manager or CallbackManager([]) @property def similarity_top_k(self) -> int: diff --git a/llama_index/indices/struct_store/sql_retriever.py b/llama_index/indices/struct_store/sql_retriever.py index 2b3b17e17f54974ee32e9ad1c9851570efeede4e..6ae50a0f5b736ad88eaf16357aac9b5d2e8625b7 100644 --- a/llama_index/indices/struct_store/sql_retriever.py +++ b/llama_index/indices/struct_store/sql_retriever.py @@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from sqlalchemy import Table +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.embeddings.base import BaseEmbedding from llama_index.objects.base import ObjectRetriever @@ -39,11 +40,13 @@ class SQLRetriever(BaseRetriever): self, sql_database: SQLDatabase, return_raw: bool = True, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" self._sql_database = sql_database self._return_raw = return_raw + super().__init__(callback_manager) def _format_node_results( self, results: List[List[Any]], col_keys: List[str] @@ -124,7 +127,10 @@ class DefaultSQLParser(BaseSQLParser): class PGVectorSQLParser(BaseSQLParser): """PGVector SQL Parser.""" - def __init__(self, embed_model: BaseEmbedding) -> None: + def __init__( + self, + embed_model: BaseEmbedding, + ) -> None: """Initialize params.""" self._embed_model = embed_model @@ -180,6 +186,7 @@ class NLSQLRetriever(BaseRetriever, PromptMixin): service_context: Optional[ServiceContext] = None, return_raw: bool = True, handle_sql_errors: bool = True, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -194,6 +201,7 @@ class NLSQLRetriever(BaseRetriever, PromptMixin): self._sql_parser_mode = sql_parser_mode self._sql_parser = self._load_sql_parser(sql_parser_mode, self._service_context) self._handle_sql_errors = handle_sql_errors + super().__init__(callback_manager) def _get_prompts(self) -> Dict[str, Any]: """Get prompts.""" diff --git a/llama_index/indices/tree/all_leaf_retriever.py b/llama_index/indices/tree/all_leaf_retriever.py index 76d693ff1de95fd8b41c2406186ae3debcf746cc..db831f073c1776b8e3bb8e0445d0ea0e366d08d3 100644 --- a/llama_index/indices/tree/all_leaf_retriever.py +++ b/llama_index/indices/tree/all_leaf_retriever.py @@ -1,8 +1,9 @@ """Summarize query.""" import logging -from typing import Any, List, cast +from typing import Any, List, Optional, cast +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import IndexGraph from llama_index.indices.tree.base import TreeIndex @@ -27,10 +28,16 @@ class TreeAllLeafRetriever(BaseRetriever): """ - def __init__(self, index: TreeIndex, **kwargs: Any) -> None: + def __init__( + self, + index: TreeIndex, + callback_manager: Optional[CallbackManager] = None, + **kwargs: Any, + ) -> None: self._index = index self._index_struct = index.index_struct self._docstore = index.docstore + super().__init__(callback_manager) def _retrieve( self, diff --git a/llama_index/indices/tree/select_leaf_retriever.py b/llama_index/indices/tree/select_leaf_retriever.py index 0ef5858b7cbd3daff552f39522a9f36e8a157ce2..8606f614991be2ac2839da03b6d8b93602e7a6aa 100644 --- a/llama_index/indices/tree/select_leaf_retriever.py +++ b/llama_index/indices/tree/select_leaf_retriever.py @@ -3,7 +3,9 @@ import logging from typing import Any, Dict, List, Optional, cast +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever +from llama_index.indices.query.schema import QueryBundle from llama_index.indices.tree.base import TreeIndex from llama_index.indices.tree.utils import get_numbered_text_from_nodes from llama_index.indices.utils import ( @@ -70,6 +72,7 @@ class TreeSelectLeafRetriever(BaseRetriever): query_template_multiple: Optional[BasePromptTemplate] = None, child_branch_factor: int = 1, verbose: bool = False, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ): self._index = index @@ -85,6 +88,7 @@ class TreeSelectLeafRetriever(BaseRetriever): ) self.child_branch_factor = child_branch_factor self._verbose = verbose + super().__init__(callback_manager) def _query_with_selected_node( self, diff --git a/llama_index/indices/tree/tree_root_retriever.py b/llama_index/indices/tree/tree_root_retriever.py index a79e33420017c318ae8d4b73b0cbc8333d45e827..58b2ef0cb97a4be845a49659a7dd4f0f1eb8b55f 100644 --- a/llama_index/indices/tree/tree_root_retriever.py +++ b/llama_index/indices/tree/tree_root_retriever.py @@ -1,8 +1,10 @@ """Retrieve query.""" import logging -from typing import Any, List +from typing import Any, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever +from llama_index.indices.query.schema import QueryBundle from llama_index.indices.tree.base import TreeIndex from llama_index.indices.utils import get_sorted_node_list from llama_index.schema import NodeWithScore, QueryBundle @@ -20,10 +22,16 @@ class TreeRootRetriever(BaseRetriever): attempt to parse information down the graph in order to synthesize an answer. """ - def __init__(self, index: TreeIndex, **kwargs: Any) -> None: + def __init__( + self, + index: TreeIndex, + callback_manager: Optional[CallbackManager] = None, + **kwargs: Any, + ) -> None: self._index = index self._index_struct = index.index_struct self._docstore = index.docstore + super().__init__(callback_manager) def _retrieve( self, diff --git a/llama_index/indices/vector_store/base.py b/llama_index/indices/vector_store/base.py index c8b143e89536d461e481016cb6797407966a1fe8..5a1b158e532134e9daa173cc07029b42aaf22e22 100644 --- a/llama_index/indices/vector_store/base.py +++ b/llama_index/indices/vector_store/base.py @@ -83,6 +83,7 @@ class VectorStoreIndex(BaseIndex[IndexDict]): return VectorIndexRetriever( self, node_ids=list(self.index_struct.nodes_dict.values()), + callback_manager=self._service_context.callback_manager, **kwargs, ) diff --git a/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py b/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py index c2e661b6d725a04ead62454405f474ca01fb9261..4e25b5ca866271b4f596e5dd63362b7f0193c47c 100644 --- a/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py +++ b/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py @@ -1,6 +1,7 @@ import logging from typing import Any, List, Optional, cast +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import BaseRetriever from llama_index.indices.vector_store.base import VectorStoreIndex @@ -58,6 +59,7 @@ class VectorIndexAutoRetriever(BaseRetriever): max_top_k: int = 10, similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K, vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: self._index = index @@ -79,6 +81,7 @@ class VectorIndexAutoRetriever(BaseRetriever): self._similarity_top_k = similarity_top_k self._vector_store_query_mode = vector_store_query_mode self._kwargs = kwargs + super().__init__(callback_manager) def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: # prepare input diff --git a/llama_index/indices/vector_store/retrievers/retriever.py b/llama_index/indices/vector_store/retrievers/retriever.py index 64c886913f014b73f06a5a455558d8622d92860b..2be7db0d51cb5ebb3c6f452c2265d5b6b25608c7 100644 --- a/llama_index/indices/vector_store/retrievers/retriever.py +++ b/llama_index/indices/vector_store/retrievers/retriever.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import IndexDict @@ -44,6 +45,7 @@ class VectorIndexRetriever(BaseRetriever): node_ids: Optional[List[str]] = None, doc_ids: Optional[List[str]] = None, sparse_top_k: Optional[int] = None, + callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: """Initialize params.""" @@ -59,8 +61,8 @@ class VectorIndexRetriever(BaseRetriever): self._doc_ids = doc_ids self._filters = filters self._sparse_top_k = sparse_top_k - self._kwargs: Dict[str, Any] = kwargs.get("vector_store_kwargs", {}) + super().__init__(callback_manager) @property def similarity_top_k(self) -> int: @@ -94,7 +96,6 @@ class VectorIndexRetriever(BaseRetriever): query_bundle.embedding_strs ) ) - return await self._aget_nodes_with_embeddings(query_bundle) def _build_vector_store_query( diff --git a/llama_index/query_engine/retriever_query_engine.py b/llama_index/query_engine/retriever_query_engine.py index 937663b38add16d8757f669ab0c67e69305a3626..1fa0355e7e209c3bbaafb981d238908e1ae5f363 100644 --- a/llama_index/query_engine/retriever_query_engine.py +++ b/llama_index/query_engine/retriever_query_engine.py @@ -167,16 +167,7 @@ class RetrieverQueryEngine(BaseQueryEngine): with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str} ) as query_event: - with self.callback_manager.event( - CBEventType.RETRIEVE, - payload={EventPayload.QUERY_STR: query_bundle.query_str}, - ) as retrieve_event: - nodes = self.retrieve(query_bundle) - - retrieve_event.on_end( - payload={EventPayload.NODES: nodes}, - ) - + nodes = self.retrieve(query_bundle) response = self._response_synthesizer.synthesize( query=query_bundle, nodes=nodes, @@ -191,15 +182,7 @@ class RetrieverQueryEngine(BaseQueryEngine): with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str} ) as query_event: - with self.callback_manager.event( - CBEventType.RETRIEVE, - payload={EventPayload.QUERY_STR: query_bundle.query_str}, - ) as retrieve_event: - nodes = await self.aretrieve(query_bundle) - - retrieve_event.on_end( - payload={EventPayload.NODES: nodes}, - ) + nodes = await self.aretrieve(query_bundle) response = await self._response_synthesizer.asynthesize( query=query_bundle, diff --git a/llama_index/retrievers/auto_merging_retriever.py b/llama_index/retrievers/auto_merging_retriever.py index 365df24d769ddc6f2bde8f111902d122a973ee27..f27d4284c58e429297067aca2f31d0fba09f4e5e 100644 --- a/llama_index/retrievers/auto_merging_retriever.py +++ b/llama_index/retrievers/auto_merging_retriever.py @@ -2,9 +2,11 @@ import logging from collections import defaultdict -from typing import Dict, List, Tuple, cast +from typing import Dict, List, Optional, Tuple, cast +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever +from llama_index.indices.query.schema import QueryBundle from llama_index.indices.utils import truncate_text from llama_index.indices.vector_store.retrievers.retriever import VectorIndexRetriever from llama_index.schema import BaseNode, NodeWithScore, QueryBundle @@ -27,12 +29,14 @@ class AutoMergingRetriever(BaseRetriever): storage_context: StorageContext, simple_ratio_thresh: float = 0.5, verbose: bool = False, + callback_manager: Optional[CallbackManager] = None, ) -> None: """Init params.""" self._vector_retriever = vector_retriever self._storage_context = storage_context self._simple_ratio_thresh = simple_ratio_thresh self._verbose = verbose + super().__init__(callback_manager) def _get_parents_and_merge( self, nodes: List[NodeWithScore] diff --git a/llama_index/retrievers/bm25_retriever.py b/llama_index/retrievers/bm25_retriever.py index bc59da465a2ffca412617e4506feddac5b9d82f0..fd7a83af2d1e317c3970e88c0f0ef46a26d08ebb 100644 --- a/llama_index/retrievers/bm25_retriever.py +++ b/llama_index/retrievers/bm25_retriever.py @@ -1,6 +1,7 @@ import logging from typing import Callable, List, Optional, cast +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core import BaseRetriever from llama_index.indices.vector_store.base import VectorStoreIndex @@ -17,6 +18,7 @@ class BM25Retriever(BaseRetriever): nodes: List[BaseNode], tokenizer: Optional[Callable[[str], List[str]]], similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K, + callback_manager: Optional[CallbackManager] = None, ) -> None: try: from rank_bm25 import BM25Okapi @@ -27,8 +29,8 @@ class BM25Retriever(BaseRetriever): self._tokenizer = tokenizer or (lambda x: x.split(" ")) self._similarity_top_k = similarity_top_k self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes] - self.bm25 = BM25Okapi(self._corpus) + super().__init__(callback_manager) @classmethod def from_defaults( diff --git a/llama_index/retrievers/fusion_retriever.py b/llama_index/retrievers/fusion_retriever.py index 51603ef63c67b50466de428c03cd2c26af3ea549..01c86fadbc2c25ce7e844500b5bef2d43f9b7605 100644 --- a/llama_index/retrievers/fusion_retriever.py +++ b/llama_index/retrievers/fusion_retriever.py @@ -3,6 +3,7 @@ from enum import Enum from typing import Dict, List, Optional, Tuple from llama_index.async_utils import run_async_tasks +from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.llms.utils import LLMType, resolve_llm from llama_index.retrievers import BaseRetriever @@ -35,6 +36,7 @@ class QueryFusionRetriever(BaseRetriever): num_queries: int = 4, use_async: bool = True, verbose: bool = False, + callback_manager: Optional[CallbackManager] = None, ) -> None: self.num_queries = num_queries self.query_gen_prompt = query_gen_prompt or QUERY_GEN_PROMPT @@ -45,6 +47,7 @@ class QueryFusionRetriever(BaseRetriever): self._retrievers = retrievers self._llm = resolve_llm(llm) + super().__init__(callback_manager) def _get_queries(self, original_query: str) -> List[str]: prompt_str = self.query_gen_prompt.format( diff --git a/llama_index/retrievers/recursive_retriever.py b/llama_index/retrievers/recursive_retriever.py index e4e031d8404eb52f79218a1bd60526f76b6e173a..bc5817b1f155496bff73e3526c491727c0b072e1 100644 --- a/llama_index/retrievers/recursive_retriever.py +++ b/llama_index/retrievers/recursive_retriever.py @@ -49,7 +49,7 @@ class RecursiveRetriever(BaseRetriever): self._retriever_dict = retriever_dict self._query_engine_dict = query_engine_dict or {} self._node_dict = node_dict or {} - self.callback_manager = callback_manager or CallbackManager([]) + super().__init__(callback_manager) # make sure keys don't overlap if set(self._retriever_dict.keys()) & set(self._query_engine_dict.keys()): @@ -57,7 +57,7 @@ class RecursiveRetriever(BaseRetriever): self._query_response_tmpl = query_response_tmpl or DEFAULT_QUERY_RESPONSE_TMPL self._verbose = verbose - super().__init__() + super().__init__(callback_manager) def _query_retrieved_nodes( self, query_bundle: QueryBundle, nodes_with_score: List[NodeWithScore] diff --git a/llama_index/retrievers/transform_retriever.py b/llama_index/retrievers/transform_retriever.py index 0dd08176ccf3ac57fcbb38bef94a3678dfba87c8..f200f751005bd3f7c0eddf002ecf2df4d2e6ad4c 100644 --- a/llama_index/retrievers/transform_retriever.py +++ b/llama_index/retrievers/transform_retriever.py @@ -1,5 +1,6 @@ from typing import List, Optional +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever from llama_index.indices.query.query_transform.base import BaseQueryTransform from llama_index.prompts.mixin import PromptMixinType @@ -19,10 +20,12 @@ class TransformRetriever(BaseRetriever): retriever: BaseRetriever, query_transform: BaseQueryTransform, transform_metadata: Optional[dict] = None, + callback_manager: Optional[CallbackManager] = None, ) -> None: self._retriever = retriever self._query_transform = query_transform self._transform_metadata = transform_metadata + super().__init__(callback_manager) def _get_prompt_modules(self) -> PromptMixinType: """Get prompt sub-modules.""" diff --git a/llama_index/retrievers/you_retriever.py b/llama_index/retrievers/you_retriever.py index a964863a185ebc79221f442a47cb3559642f5497..df042b6ce1a0798f210d5d960357f5f4dcd7065f 100644 --- a/llama_index/retrievers/you_retriever.py +++ b/llama_index/retrievers/you_retriever.py @@ -6,7 +6,9 @@ from typing import List, Optional import requests +from llama_index.callbacks.base import CallbackManager from llama_index.core import BaseRetriever +from llama_index.indices.query.schema import QueryBundle from llama_index.schema import NodeWithScore, QueryBundle, TextNode logger = logging.getLogger(__name__) @@ -15,9 +17,14 @@ logger = logging.getLogger(__name__) class YouRetriever(BaseRetriever): """You retriever.""" - def __init__(self, api_key: Optional[str] = None) -> None: + def __init__( + self, + api_key: Optional[str] = None, + callback_manager: Optional[CallbackManager] = None, + ) -> None: """Init params.""" self._api_key = api_key or os.environ["YOU_API_KEY"] + super().__init__(callback_manager) def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Retrieve."""