From 5671177d480ce178a278856bc27c785b69ceed57 Mon Sep 17 00:00:00 2001
From: Jithin James <jamesjithin97@gmail.com>
Date: Mon, 24 Jul 2023 09:33:48 +0530
Subject: [PATCH] feat: async retriever (embeddings) (#6587)

---
 CHANGELOG.md                                  |  1 +
 llama_index/embeddings/base.py                | 27 ++++++++++++++++
 llama_index/embeddings/google.py              |  6 +++-
 llama_index/embeddings/langchain.py           |  8 ++++-
 llama_index/embeddings/openai.py              |  9 ++++++
 llama_index/indices/base_retriever.py         | 18 ++++++++++-
 .../vector_store/retrievers/retriever.py      | 20 ++++++++++--
 .../query_engine/retriever_query_engine.py    | 21 ++++++++++---
 llama_index/token_counter/mock_embed_model.py | 13 ++++++--
 tests/indices/knowledge_graph/test_base.py    | 17 ++++++++++
 tests/indices/query/test_compose_vector.py    | 31 +++++++++++++++++++
 tests/indices/query/test_query_bundle.py      | 19 ++++++++++++
 tests/indices/vector_store/mock_services.py   | 24 ++++++++++++++
 tests/playground/test_base.py                 | 15 +++++++++
 14 files changed, 217 insertions(+), 12 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 867854cce6..833a30e2a8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@
 ### New Features
 - Added `kwargs` to `ComposableGraph` for the underlying query engines (#6990)
 - Validate openai key on init (#6940)
+- Added async embeddings and async RetrieverQueryEngine (#6587)
 
 ### Bug Fixes / Nits
 - Fix achat memory initialization for data agents (#7000)
diff --git a/llama_index/embeddings/base.py b/llama_index/embeddings/base.py
index 8492c8fe56..87d7e4a48e 100644
--- a/llama_index/embeddings/base.py
+++ b/llama_index/embeddings/base.py
@@ -72,6 +72,10 @@ class BaseEmbedding:
     def _get_query_embedding(self, query: str) -> List[float]:
         """Get query embedding."""
 
+    @abstractmethod
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        """Get query embedding asynchronously."""
+
     def get_query_embedding(self, query: str) -> List[float]:
         """Get query embedding."""
         event_id = self.callback_manager.on_event_start(CBEventType.EMBEDDING)
@@ -88,6 +92,19 @@ class BaseEmbedding:
         )
         return query_embedding
 
+    async def aget_query_embedding(self, query: str) -> List[float]:
+        """Get query embedding asynchronously."""
+        event_id = self.callback_manager.on_event_start(CBEventType.EMBEDDING)
+        query_embedding = await self._aget_query_embedding(query)
+        query_tokens_count = len(self._tokenizer(query))
+        self._total_tokens_used += query_tokens_count
+        self.callback_manager.on_event_end(
+            CBEventType.EMBEDDING,
+            payload={EventPayload.CHUNKS: [query]},
+            event_id=event_id,
+        )
+        return query_embedding
+
     def get_agg_embedding_from_queries(
         self,
         queries: List[str],
@@ -98,6 +115,16 @@ class BaseEmbedding:
         agg_fn = agg_fn or mean_agg
         return agg_fn(query_embeddings)
 
+    async def aget_agg_embedding_from_queries(
+        self,
+        queries: List[str],
+        agg_fn: Optional[Callable[..., List[float]]] = None,
+    ) -> List[float]:
+        """Get aggregated embedding from multiple queries."""
+        query_embeddings = [await self.aget_query_embedding(query) for query in queries]
+        agg_fn = agg_fn or mean_agg
+        return agg_fn(query_embeddings)
+
     @abstractmethod
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
diff --git a/llama_index/embeddings/google.py b/llama_index/embeddings/google.py
index c092deae28..26ec3da81d 100644
--- a/llama_index/embeddings/google.py
+++ b/llama_index/embeddings/google.py
@@ -2,7 +2,6 @@
 
 from typing import List, Optional
 
-
 from llama_index.embeddings.base import BaseEmbedding
 
 # Google Universal Sentence Encode v5
@@ -28,6 +27,11 @@ class GoogleUnivSentEncoderEmbedding(BaseEmbedding):
         """Get query embedding."""
         return self._get_embedding(query)
 
+    # TODO: use proper async methods
+    async def _aget_text_embedding(self, query: str) -> List[float]:
+        """Get text embedding."""
+        return self._get_embedding(query)
+
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
         return self._get_embedding(text)
diff --git a/llama_index/embeddings/langchain.py b/llama_index/embeddings/langchain.py
index e5b743fc51..1a0f304ed9 100644
--- a/llama_index/embeddings/langchain.py
+++ b/llama_index/embeddings/langchain.py
@@ -4,7 +4,6 @@
 from typing import Any, List
 
 from llama_index.bridge.langchain import Embeddings as LCEmbeddings
-
 from llama_index.embeddings.base import BaseEmbedding
 
 
@@ -25,6 +24,13 @@ class LangchainEmbedding(BaseEmbedding):
         """Get query embedding."""
         return self._langchain_embedding.embed_query(query)
 
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        return await self._langchain_embedding.aembed_query(query)
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        embeds = await self._langchain_embedding.aembed_documents([text])
+        return embeds[0]
+
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
         return self._langchain_embedding.embed_documents([text])[0]
diff --git a/llama_index/embeddings/openai.py b/llama_index/embeddings/openai.py
index c2415cef37..c033b49689 100644
--- a/llama_index/embeddings/openai.py
+++ b/llama_index/embeddings/openai.py
@@ -263,6 +263,15 @@ class OpenAIEmbedding(BaseEmbedding):
             **self.openai_kwargs,
         )
 
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        """The asynchronous version of _get_query_embedding."""
+        return await aget_embedding(
+            query,
+            engine=self.query_engine,
+            deployment_id=self.deployment_name,
+            **self.openai_kwargs,
+        )
+
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
         return get_embedding(
diff --git a/llama_index/indices/base_retriever.py b/llama_index/indices/base_retriever.py
index 0eae300b62..5444a6e68f 100644
--- a/llama_index/indices/base_retriever.py
+++ b/llama_index/indices/base_retriever.py
@@ -1,9 +1,9 @@
 from abc import ABC, abstractmethod
 from typing import List, Optional
 
-from llama_index.schema import NodeWithScore
 from llama_index.indices.query.schema import QueryBundle, QueryType
 from llama_index.indices.service_context import ServiceContext
+from llama_index.schema import NodeWithScore
 
 
 class BaseRetriever(ABC):
@@ -21,6 +21,12 @@ class BaseRetriever(ABC):
             str_or_query_bundle = QueryBundle(str_or_query_bundle)
         return self._retrieve(str_or_query_bundle)
 
+    async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
+        if isinstance(str_or_query_bundle, str):
+            str_or_query_bundle = QueryBundle(str_or_query_bundle)
+        nodes = await self._aretrieve(str_or_query_bundle)
+        return nodes
+
     @abstractmethod
     def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
         """Retrieve nodes given query.
@@ -30,6 +36,16 @@ class BaseRetriever(ABC):
         """
         pass
 
+    # TODO: make this abstract
+    # @abstractmethod
+    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        """Asynchronously retrieve nodes given query.
+
+        Implemented by the user.
+
+        """
+        return []
+
     def get_service_context(self) -> Optional[ServiceContext]:
         """Attempts to resolve a service context.
         Short-circuits at self.service_context, self._service_context,
diff --git a/llama_index/indices/vector_store/retrievers/retriever.py b/llama_index/indices/vector_store/retrievers/retriever.py
index 889f3b329a..530a2e05e6 100644
--- a/llama_index/indices/vector_store/retrievers/retriever.py
+++ b/llama_index/indices/vector_store/retrievers/retriever.py
@@ -71,13 +71,29 @@ class VectorIndexRetriever(BaseRetriever):
                         query_bundle.embedding_strs
                     )
                 )
+        return self._get_nodes_with_embeddings(query_bundle)
 
+    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        if self._vector_store.is_embedding_query:
+            if query_bundle.embedding is None:
+                embed_model = self._service_context.embed_model
+                query_bundle.embedding = (
+                    await embed_model.aget_agg_embedding_from_queries(
+                        query_bundle.embedding_strs
+                    )
+                )
+
+        return self._get_nodes_with_embeddings(query_bundle)
+
+    def _get_nodes_with_embeddings(
+        self, query_bundle_with_embeddings: QueryBundle
+    ) -> List[NodeWithScore]:
         query = VectorStoreQuery(
-            query_embedding=query_bundle.embedding,
+            query_embedding=query_bundle_with_embeddings.embedding,
             similarity_top_k=self._similarity_top_k,
             node_ids=self._node_ids,
             doc_ids=self._doc_ids,
-            query_str=query_bundle.query_str,
+            query_str=query_bundle_with_embeddings.query_str,
             mode=self._vector_store_query_mode,
             alpha=self._alpha,
             filters=self._filters,
diff --git a/llama_index/query_engine/retriever_query_engine.py b/llama_index/query_engine/retriever_query_engine.py
index f02de2ce74..430684070b 100644
--- a/llama_index/query_engine/retriever_query_engine.py
+++ b/llama_index/query_engine/retriever_query_engine.py
@@ -104,13 +104,22 @@ class RetrieverQueryEngine(BaseQueryEngine):
             node_postprocessors=node_postprocessors,
         )
 
+    def _apply_node_postprocessors(
+        self, nodes: List[NodeWithScore]
+    ) -> List[NodeWithScore]:
+        for node_postprocessor in self._node_postprocessors:
+            nodes = node_postprocessor.postprocess_nodes(nodes)
+        return nodes
+
     def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
         nodes = self._retriever.retrieve(query_bundle)
+        nodes = self._apply_node_postprocessors(nodes)
 
-        for node_postprocessor in self._node_postprocessors:
-            nodes = node_postprocessor.postprocess_nodes(
-                nodes, query_bundle=query_bundle
-            )
+        return nodes
+
+    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        nodes = await self._retriever.aretrieve(query_bundle)
+        nodes = self._apply_node_postprocessors(nodes)
 
         return nodes
 
@@ -171,7 +180,9 @@ class RetrieverQueryEngine(BaseQueryEngine):
         )
 
         retrieve_id = self.callback_manager.on_event_start(CBEventType.RETRIEVE)
-        nodes = self.retrieve(query_bundle)
+
+        nodes = await self.aretrieve(query_bundle)
+
         self.callback_manager.on_event_end(
             CBEventType.RETRIEVE,
             payload={EventPayload.NODES: nodes},
diff --git a/llama_index/token_counter/mock_embed_model.py b/llama_index/token_counter/mock_embed_model.py
index 2f13b031f5..cfe3db816e 100644
--- a/llama_index/token_counter/mock_embed_model.py
+++ b/llama_index/token_counter/mock_embed_model.py
@@ -20,10 +20,19 @@ class MockEmbedding(BaseEmbedding):
         super().__init__(*args, **kwargs)
         self.embed_dim = embed_dim
 
+    def _get_vector(self) -> List[float]:
+        return [0.5] * self.embed_dim
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        return self._get_vector()
+
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        return self._get_vector()
+
     def _get_query_embedding(self, query: str) -> List[float]:
         """Get query embedding."""
-        return [0.5] * self.embed_dim
+        return self._get_vector()
 
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
-        return [0.5] * self.embed_dim
+        return self._get_vector()
diff --git a/tests/indices/knowledge_graph/test_base.py b/tests/indices/knowledge_graph/test_base.py
index 37ec9a7a24..f7e373f91a 100644
--- a/tests/indices/knowledge_graph/test_base.py
+++ b/tests/indices/knowledge_graph/test_base.py
@@ -17,6 +17,23 @@ from tests.mock_utils.mock_prompts import (
 
 
 class MockEmbedding(BaseEmbedding):
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        del query
+        return [0, 0, 1, 0, 0]
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        # assume dimensions are 4
+        if text == "('foo', 'is', 'bar')":
+            return [1, 0, 0, 0]
+        elif text == "('hello', 'is not', 'world')":
+            return [0, 1, 0, 0]
+        elif text == "('Jane', 'is mother of', 'Bob')":
+            return [0, 0, 1, 0]
+        elif text == "foo":
+            return [0, 0, 0, 1]
+        else:
+            raise ValueError("Invalid text for `mock_get_text_embedding`.")
+
     def _get_text_embedding(self, text: str) -> List[float]:
         """Mock get text embedding."""
         # assume dimensions are 4
diff --git a/tests/indices/query/test_compose_vector.py b/tests/indices/query/test_compose_vector.py
index 35a5d353e6..f652efa260 100644
--- a/tests/indices/query/test_compose_vector.py
+++ b/tests/indices/query/test_compose_vector.py
@@ -17,6 +17,37 @@ from tests.mock_utils.mock_prompts import MOCK_QUERY_KEYWORD_EXTRACT_PROMPT
 
 
 class MockEmbedding(BaseEmbedding):
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        if query == "Foo?":
+            return [0, 0, 1, 0, 0]
+        elif query == "Orange?":
+            return [0, 1, 0, 0, 0]
+        elif query == "Cat?":
+            return [0, 0, 0, 1, 0]
+        else:
+            raise ValueError("Invalid query for `_get_query_embedding`.")
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        # assume dimensions are 5
+        if text == "Hello world.":
+            return [1, 0, 0, 0, 0]
+        elif text == "This is a test.":
+            return [0, 1, 0, 0, 0]
+        elif text == "This is another test.":
+            return [0, 0, 1, 0, 0]
+        elif text == "This is a test v2.":
+            return [0, 0, 0, 1, 0]
+        elif text == "foo bar":
+            return [0, 0, 1, 0, 0]
+        elif text == "apple orange":
+            return [0, 1, 0, 0, 0]
+        elif text == "toronto london":
+            return [1, 0, 0, 0, 0]
+        elif text == "cat dog":
+            return [0, 0, 0, 1, 0]
+        else:
+            raise ValueError("Invalid text for `mock_get_text_embedding`.")
+
     def _get_query_embedding(self, query: str) -> List[float]:
         """Mock get query embedding."""
         if query == "Foo?":
diff --git a/tests/indices/query/test_query_bundle.py b/tests/indices/query/test_query_bundle.py
index 54d36dd425..ea2901a453 100644
--- a/tests/indices/query/test_query_bundle.py
+++ b/tests/indices/query/test_query_bundle.py
@@ -26,6 +26,25 @@ def documents() -> List[Document]:
 
 
 class MockEmbedding(BaseEmbedding):
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        text_embed_map: Dict[str, List[float]] = {
+            "It is what it is.": [1.0, 0.0, 0.0, 0.0, 0.0],
+            "The meaning of life": [0.0, 1.0, 0.0, 0.0, 0.0],
+        }
+
+        return text_embed_map[query]
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        text_embed_map: Dict[str, List[float]] = {
+            "Correct.": [0.5, 0.5, 0.0, 0.0, 0.0],
+            "Hello world.": [1.0, 0.0, 0.0, 0.0, 0.0],
+            "This is a test.": [0.0, 1.0, 0.0, 0.0, 0.0],
+            "This is another test.": [0.0, 0.0, 1.0, 0.0, 0.0],
+            "This is a test v2.": [0.0, 0.0, 0.0, 1.0, 0.0],
+        }
+
+        return text_embed_map[text]
+
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get node text embedding."""
         text_embed_map: Dict[str, List[float]] = {
diff --git a/tests/indices/vector_store/mock_services.py b/tests/indices/vector_store/mock_services.py
index 4d4b26dc6a..dd570d0d5c 100644
--- a/tests/indices/vector_store/mock_services.py
+++ b/tests/indices/vector_store/mock_services.py
@@ -4,6 +4,30 @@ from llama_index.embeddings.base import BaseEmbedding
 
 
 class MockEmbedding(BaseEmbedding):
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        del query
+        return [0, 0, 1, 0, 0]
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        # assume dimensions are 5
+        if text == "Hello world.":
+            return [1, 0, 0, 0, 0]
+        elif text == "This is a test.":
+            return [0, 1, 0, 0, 0]
+        elif text == "This is another test.":
+            return [0, 0, 1, 0, 0]
+        elif text == "This is a test v2.":
+            return [0, 0, 0, 1, 0]
+        elif text == "This is a test v3.":
+            return [0, 0, 0, 0, 1]
+        elif text == "This is bar test.":
+            return [0, 0, 1, 0, 0]
+        elif text == "Hello world backup.":
+            # this is used when "Hello world." is deleted.
+            return [1, 0, 0, 0, 0]
+        else:
+            return [0, 0, 0, 0, 0]
+
     def _get_query_embedding(self, query: str) -> List[float]:
         del query  # Unused
         return [0, 0, 1, 0, 0]
diff --git a/tests/playground/test_base.py b/tests/playground/test_base.py
index 83ff82440a..b16355f1b2 100644
--- a/tests/playground/test_base.py
+++ b/tests/playground/test_base.py
@@ -14,6 +14,21 @@ from llama_index.schema import Document
 
 
 class MockEmbedding(BaseEmbedding):
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        del query
+        return [0, 0, 1, 0, 0]
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        # assume dimensions are 5
+        if text == "They're taking the Hobbits to Isengard!":
+            return [1, 0, 0, 0, 0]
+        elif text == "I can't carry it for you.":
+            return [0, 1, 0, 0, 0]
+        elif text == "But I can carry you!":
+            return [0, 0, 1, 0, 0]
+        else:
+            raise ValueError("Invalid text for `mock_get_text_embedding`.")
+
     def _get_text_embedding(self, text: str) -> List[float]:
         """Mock get text embedding."""
         # assume dimensions are 5
-- 
GitLab