diff --git a/CHANGELOG.md b/CHANGELOG.md index 867854cce6ef72513d28ec05e8dc3395aa9a9a45..833a30e2a8fcb5f32f8edf95b912bfbe8619860b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### New Features - Added `kwargs` to `ComposableGraph` for the underlying query engines (#6990) - Validate openai key on init (#6940) +- Added async embeddings and async RetrieverQueryEngine (#6587) ### Bug Fixes / Nits - Fix achat memory initialization for data agents (#7000) diff --git a/llama_index/embeddings/base.py b/llama_index/embeddings/base.py index 8492c8fe5669a5d84372821fe6364ac3fc39602b..87d7e4a48ed0c9f9806e3100de0ca3947d492484 100644 --- a/llama_index/embeddings/base.py +++ b/llama_index/embeddings/base.py @@ -72,6 +72,10 @@ class BaseEmbedding: def _get_query_embedding(self, query: str) -> List[float]: """Get query embedding.""" + @abstractmethod + async def _aget_query_embedding(self, query: str) -> List[float]: + """Get query embedding asynchronously.""" + def get_query_embedding(self, query: str) -> List[float]: """Get query embedding.""" event_id = self.callback_manager.on_event_start(CBEventType.EMBEDDING) @@ -88,6 +92,19 @@ class BaseEmbedding: ) return query_embedding + async def aget_query_embedding(self, query: str) -> List[float]: + """Get query embedding.""" + event_id = self.callback_manager.on_event_start(CBEventType.EMBEDDING) + query_embedding = await self._aget_query_embedding(query) + query_tokens_count = len(self._tokenizer(query)) + self._total_tokens_used += query_tokens_count + self.callback_manager.on_event_end( + CBEventType.EMBEDDING, + payload={EventPayload.CHUNKS: [query]}, + event_id=event_id, + ) + return query_embedding + def get_agg_embedding_from_queries( self, queries: List[str], @@ -98,6 +115,16 @@ class BaseEmbedding: agg_fn = agg_fn or mean_agg return agg_fn(query_embeddings) + async def aget_agg_embedding_from_queries( + self, + queries: List[str], + agg_fn: Optional[Callable[..., List[float]]] = None, + ) -> List[float]: + 
"""Get aggregated embedding from multiple queries.""" + query_embeddings = [await self.aget_query_embedding(query) for query in queries] + agg_fn = agg_fn or mean_agg + return agg_fn(query_embeddings) + @abstractmethod def _get_text_embedding(self, text: str) -> List[float]: """Get text embedding.""" diff --git a/llama_index/embeddings/google.py b/llama_index/embeddings/google.py index c092deae28eb983fad418b9397dc172feea09337..26ec3da81dc1ffe857b5464f3352e24f3790011d 100644 --- a/llama_index/embeddings/google.py +++ b/llama_index/embeddings/google.py @@ -2,7 +2,6 @@ from typing import List, Optional - from llama_index.embeddings.base import BaseEmbedding # Google Universal Sentence Encode v5 @@ -28,6 +27,11 @@ class GoogleUnivSentEncoderEmbedding(BaseEmbedding): """Get query embedding.""" return self._get_embedding(query) + # TODO: use proper async methods + async def _aget_text_embedding(self, query: str) -> List[float]: + """Get text embedding.""" + return self._get_embedding(query) + def _get_text_embedding(self, text: str) -> List[float]: """Get text embedding.""" return self._get_embedding(text) diff --git a/llama_index/embeddings/langchain.py b/llama_index/embeddings/langchain.py index e5b743fc51fae7381d007c8187a807b51f1acff8..1a0f304ed9eb27b67e5b782e0aab7ea7203857a7 100644 --- a/llama_index/embeddings/langchain.py +++ b/llama_index/embeddings/langchain.py @@ -4,7 +4,6 @@ from typing import Any, List from llama_index.bridge.langchain import Embeddings as LCEmbeddings - from llama_index.embeddings.base import BaseEmbedding @@ -25,6 +24,13 @@ class LangchainEmbedding(BaseEmbedding): """Get query embedding.""" return self._langchain_embedding.embed_query(query) + async def _aget_query_embedding(self, query: str) -> List[float]: + return await self._langchain_embedding.aembed_query(query) + + async def _aget_text_embedding(self, text: str) -> List[float]: + embeds = await self._langchain_embedding.aembed_documents([text]) + return embeds[0] + def 
_get_text_embedding(self, text: str) -> List[float]: """Get text embedding.""" return self._langchain_embedding.embed_documents([text])[0] diff --git a/llama_index/embeddings/openai.py b/llama_index/embeddings/openai.py index c2415cef37e5adefd39a37c22dd1c131e097d384..c033b496894b2c01bd46f3c52704a6bc55eae144 100644 --- a/llama_index/embeddings/openai.py +++ b/llama_index/embeddings/openai.py @@ -263,6 +263,15 @@ class OpenAIEmbedding(BaseEmbedding): **self.openai_kwargs, ) + async def _aget_query_embedding(self, query: str) -> List[float]: + """The asynchronous version of _get_query_embedding.""" + return await aget_embedding( + query, + engine=self.query_engine, + deployment_id=self.deployment_name, + **self.openai_kwargs, + ) + def _get_text_embedding(self, text: str) -> List[float]: """Get text embedding.""" return get_embedding( diff --git a/llama_index/indices/base_retriever.py b/llama_index/indices/base_retriever.py index 0eae300b62dd2aa109e70d78b86b75a6ad7d0b9b..5444a6e68ff4ea79582a7da8f51d0493550fde13 100644 --- a/llama_index/indices/base_retriever.py +++ b/llama_index/indices/base_retriever.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod from typing import List, Optional -from llama_index.schema import NodeWithScore from llama_index.indices.query.schema import QueryBundle, QueryType from llama_index.indices.service_context import ServiceContext +from llama_index.schema import NodeWithScore class BaseRetriever(ABC): @@ -21,6 +21,12 @@ class BaseRetriever(ABC): str_or_query_bundle = QueryBundle(str_or_query_bundle) return self._retrieve(str_or_query_bundle) + async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]: + if isinstance(str_or_query_bundle, str): + str_or_query_bundle = QueryBundle(str_or_query_bundle) + nodes = await self._aretrieve(str_or_query_bundle) + return nodes + @abstractmethod def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Retrieve nodes given query. 
@@ -30,6 +36,16 @@ class BaseRetriever(ABC): """ pass + # TODO: make this abstract + # @abstractmethod + async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + """Asynchronously retrieve nodes given query. + + Implemented by the user. + + """ + return [] + def get_service_context(self) -> Optional[ServiceContext]: """Attempts to resolve a service context. Short-circuits at self.service_context, self._service_context, diff --git a/llama_index/indices/vector_store/retrievers/retriever.py b/llama_index/indices/vector_store/retrievers/retriever.py index 889f3b329aa9c89bfd05fc89bec4305432b10688..530a2e05e642a2597b31f1f393748ead3b7b826d 100644 --- a/llama_index/indices/vector_store/retrievers/retriever.py +++ b/llama_index/indices/vector_store/retrievers/retriever.py @@ -71,13 +71,29 @@ class VectorIndexRetriever(BaseRetriever): query_bundle.embedding_strs ) ) + return self._get_nodes_with_embeddings(query_bundle) + async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + if self._vector_store.is_embedding_query: + if query_bundle.embedding is None: + embed_model = self._service_context.embed_model + query_bundle.embedding = ( + await embed_model.aget_agg_embedding_from_queries( + query_bundle.embedding_strs + ) + ) + + return self._get_nodes_with_embeddings(query_bundle) + + def _get_nodes_with_embeddings( + self, query_bundle_with_embeddings: QueryBundle + ) -> List[NodeWithScore]: query = VectorStoreQuery( - query_embedding=query_bundle.embedding, + query_embedding=query_bundle_with_embeddings.embedding, similarity_top_k=self._similarity_top_k, node_ids=self._node_ids, doc_ids=self._doc_ids, - query_str=query_bundle.query_str, + query_str=query_bundle_with_embeddings.query_str, mode=self._vector_store_query_mode, alpha=self._alpha, filters=self._filters, diff --git a/llama_index/query_engine/retriever_query_engine.py b/llama_index/query_engine/retriever_query_engine.py index
f02de2ce74a6be4e3644f320848e0ed5852bf20a..430684070b0625fd53c164ad1c9b53b5fdc59fe5 100644 --- a/llama_index/query_engine/retriever_query_engine.py +++ b/llama_index/query_engine/retriever_query_engine.py @@ -104,13 +104,22 @@ class RetrieverQueryEngine(BaseQueryEngine): node_postprocessors=node_postprocessors, ) + def _apply_node_postprocessors( + self, nodes: List[NodeWithScore] + ) -> List[NodeWithScore]: + for node_postprocessor in self._node_postprocessors: + nodes = node_postprocessor.postprocess_nodes(nodes) + return nodes + def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: nodes = self._retriever.retrieve(query_bundle) + nodes = self._apply_node_postprocessors(nodes) - for node_postprocessor in self._node_postprocessors: - nodes = node_postprocessor.postprocess_nodes( - nodes, query_bundle=query_bundle - ) + return nodes + + async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + nodes = await self._retriever.aretrieve(query_bundle) + nodes = self._apply_node_postprocessors(nodes) return nodes @@ -171,7 +180,9 @@ class RetrieverQueryEngine(BaseQueryEngine): ) retrieve_id = self.callback_manager.on_event_start(CBEventType.RETRIEVE) - nodes = self.retrieve(query_bundle) + + nodes = await self.aretrieve(query_bundle) + self.callback_manager.on_event_end( CBEventType.RETRIEVE, payload={EventPayload.NODES: nodes}, diff --git a/llama_index/token_counter/mock_embed_model.py b/llama_index/token_counter/mock_embed_model.py index 2f13b031f5179ebb3a591406e5e6a6ef5e2769de..cfe3db816ef2e135b79fd7ac763e2d4ad39d13ab 100644 --- a/llama_index/token_counter/mock_embed_model.py +++ b/llama_index/token_counter/mock_embed_model.py @@ -20,10 +20,19 @@ class MockEmbedding(BaseEmbedding): super().__init__(*args, **kwargs) self.embed_dim = embed_dim + def _get_vector(self) -> List[float]: + return [0.5] * self.embed_dim + + async def _aget_text_embedding(self, text: str) -> List[float]: + return self._get_vector() + + async def 
_aget_query_embedding(self, query: str) -> List[float]: + return self._get_vector() + def _get_query_embedding(self, query: str) -> List[float]: """Get query embedding.""" - return [0.5] * self.embed_dim + return self._get_vector() def _get_text_embedding(self, text: str) -> List[float]: """Get text embedding.""" - return [0.5] * self.embed_dim + return self._get_vector() diff --git a/tests/indices/knowledge_graph/test_base.py b/tests/indices/knowledge_graph/test_base.py index 37ec9a7a240574dcbc7f2bb5f1bab3bfad94c2f4..f7e373f91ae502a9659e6bf6de5d54d89b5476c9 100644 --- a/tests/indices/knowledge_graph/test_base.py +++ b/tests/indices/knowledge_graph/test_base.py @@ -17,6 +17,23 @@ from tests.mock_utils.mock_prompts import ( class MockEmbedding(BaseEmbedding): + async def _aget_query_embedding(self, query: str) -> List[float]: + del query + return [0, 0, 1, 0, 0] + + async def _aget_text_embedding(self, text: str) -> List[float]: + # assume dimensions are 4 + if text == "('foo', 'is', 'bar')": + return [1, 0, 0, 0] + elif text == "('hello', 'is not', 'world')": + return [0, 1, 0, 0] + elif text == "('Jane', 'is mother of', 'Bob')": + return [0, 0, 1, 0] + elif text == "foo": + return [0, 0, 0, 1] + else: + raise ValueError("Invalid text for `mock_get_text_embedding`.") + def _get_text_embedding(self, text: str) -> List[float]: """Mock get text embedding.""" # assume dimensions are 4 diff --git a/tests/indices/query/test_compose_vector.py b/tests/indices/query/test_compose_vector.py index 35a5d353e6dc3fcc3a76b3f3bd56a0b26a97c855..f652efa260eca98cafc73da24da927cca5ed8f4f 100644 --- a/tests/indices/query/test_compose_vector.py +++ b/tests/indices/query/test_compose_vector.py @@ -17,6 +17,37 @@ from tests.mock_utils.mock_prompts import MOCK_QUERY_KEYWORD_EXTRACT_PROMPT class MockEmbedding(BaseEmbedding): + async def _aget_query_embedding(self, query: str) -> List[float]: + if query == "Foo?": + return [0, 0, 1, 0, 0] + elif query == "Orange?": + return [0, 1, 0, 0, 0] + 
elif query == "Cat?": + return [0, 0, 0, 1, 0] + else: + raise ValueError("Invalid query for `_get_query_embedding`.") + + async def _aget_text_embedding(self, text: str) -> List[float]: + # assume dimensions are 5 + if text == "Hello world.": + return [1, 0, 0, 0, 0] + elif text == "This is a test.": + return [0, 1, 0, 0, 0] + elif text == "This is another test.": + return [0, 0, 1, 0, 0] + elif text == "This is a test v2.": + return [0, 0, 0, 1, 0] + elif text == "foo bar": + return [0, 0, 1, 0, 0] + elif text == "apple orange": + return [0, 1, 0, 0, 0] + elif text == "toronto london": + return [1, 0, 0, 0, 0] + elif text == "cat dog": + return [0, 0, 0, 1, 0] + else: + raise ValueError("Invalid text for `mock_get_text_embedding`.") + def _get_query_embedding(self, query: str) -> List[float]: """Mock get query embedding.""" if query == "Foo?": diff --git a/tests/indices/query/test_query_bundle.py b/tests/indices/query/test_query_bundle.py index 54d36dd425de72627446af97cb5df74fa86dc4e0..ea2901a45300f4d850075a7941fc11809b367c3e 100644 --- a/tests/indices/query/test_query_bundle.py +++ b/tests/indices/query/test_query_bundle.py @@ -26,6 +26,25 @@ def documents() -> List[Document]: class MockEmbedding(BaseEmbedding): + async def _aget_query_embedding(self, query: str) -> List[float]: + text_embed_map: Dict[str, List[float]] = { + "It is what it is.": [1.0, 0.0, 0.0, 0.0, 0.0], + "The meaning of life": [0.0, 1.0, 0.0, 0.0, 0.0], + } + + return text_embed_map[query] + + async def _aget_text_embedding(self, text: str) -> List[float]: + text_embed_map: Dict[str, List[float]] = { + "Correct.": [0.5, 0.5, 0.0, 0.0, 0.0], + "Hello world.": [1.0, 0.0, 0.0, 0.0, 0.0], + "This is a test.": [0.0, 1.0, 0.0, 0.0, 0.0], + "This is another test.": [0.0, 0.0, 1.0, 0.0, 0.0], + "This is a test v2.": [0.0, 0.0, 0.0, 1.0, 0.0], + } + + return text_embed_map[text] + def _get_text_embedding(self, text: str) -> List[float]: """Get node text embedding.""" text_embed_map: Dict[str, 
List[float]] = { diff --git a/tests/indices/vector_store/mock_services.py b/tests/indices/vector_store/mock_services.py index 4d4b26dc6aabdea444b4dd6f445ce695315111cc..dd570d0d5c3fc97c94452c6324fd40c0ed5c5b9b 100644 --- a/tests/indices/vector_store/mock_services.py +++ b/tests/indices/vector_store/mock_services.py @@ -4,6 +4,30 @@ from llama_index.embeddings.base import BaseEmbedding class MockEmbedding(BaseEmbedding): + async def _aget_query_embedding(self, query: str) -> List[float]: + del query + return [0, 0, 1, 0, 0] + + async def _aget_text_embedding(self, text: str) -> List[float]: + # assume dimensions are 5 + if text == "Hello world.": + return [1, 0, 0, 0, 0] + elif text == "This is a test.": + return [0, 1, 0, 0, 0] + elif text == "This is another test.": + return [0, 0, 1, 0, 0] + elif text == "This is a test v2.": + return [0, 0, 0, 1, 0] + elif text == "This is a test v3.": + return [0, 0, 0, 0, 1] + elif text == "This is bar test.": + return [0, 0, 1, 0, 0] + elif text == "Hello world backup.": + # this is used when "Hello world." is deleted. 
+ return [1, 0, 0, 0, 0] + else: + return [0, 0, 0, 0, 0] + def _get_query_embedding(self, query: str) -> List[float]: del query # Unused return [0, 0, 1, 0, 0] diff --git a/tests/playground/test_base.py b/tests/playground/test_base.py index 83ff82440aa10c9f2b6eb3b70bb9b738b32fe7c1..b16355f1b23cbf5dc6e98cf0bcf6ab67de329a5c 100644 --- a/tests/playground/test_base.py +++ b/tests/playground/test_base.py @@ -14,6 +14,21 @@ from llama_index.schema import Document class MockEmbedding(BaseEmbedding): + async def _aget_query_embedding(self, query: str) -> List[float]: + del query + return [0, 0, 1, 0, 0] + + async def _aget_text_embedding(self, text: str) -> List[float]: + # assume dimensions are 5 + if text == "They're taking the Hobbits to Isengard!": + return [1, 0, 0, 0, 0] + elif text == "I can't carry it for you.": + return [0, 1, 0, 0, 0] + elif text == "But I can carry you!": + return [0, 0, 1, 0, 0] + else: + raise ValueError("Invalid text for `mock_get_text_embedding`.") + def _get_text_embedding(self, text: str) -> List[float]: """Mock get text embedding.""" # assume dimensions are 5