Skip to content
Snippets Groups Projects
Unverified Commit 5671177d authored by Jithin James's avatar Jithin James Committed by GitHub
Browse files

feat: async retriever (embeddings) (#6587)

parent fd47b4a3
No related branches found
No related tags found
No related merge requests found
Showing
with 217 additions and 12 deletions
......@@ -5,6 +5,7 @@
### New Features
- Added `kwargs` to `ComposableGraph` for the underlying query engines (#6990)
- Validate openai key on init (#6940)
- Added async embeddings and async RetrieverQueryEngine (#6587)
### Bug Fixes / Nits
- Fix achat memory initialization for data agents (#7000)
......
......@@ -72,6 +72,10 @@ class BaseEmbedding:
def _get_query_embedding(self, query: str) -> List[float]:
    """Get query embedding.

    Provider-specific synchronous embedding hook; ``get_query_embedding``
    wraps it with callback events and token accounting.
    """
@abstractmethod
async def _aget_query_embedding(self, query: str) -> List[float]:
    """Get query embedding asynchronously.

    Provider-specific async embedding hook; ``aget_query_embedding``
    wraps it with callback events and token accounting.
    """
def get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
event_id = self.callback_manager.on_event_start(CBEventType.EMBEDDING)
......@@ -88,6 +92,19 @@ class BaseEmbedding:
)
return query_embedding
async def aget_query_embedding(self, query: str) -> List[float]:
    """Asynchronously embed a query string.

    Fires an EMBEDDING callback event around the provider call and adds
    the query's token count to the running total.
    """
    event_id = self.callback_manager.on_event_start(CBEventType.EMBEDDING)
    embedding = await self._aget_query_embedding(query)
    self._total_tokens_used += len(self._tokenizer(query))
    self.callback_manager.on_event_end(
        CBEventType.EMBEDDING,
        payload={EventPayload.CHUNKS: [query]},
        event_id=event_id,
    )
    return embedding
def get_agg_embedding_from_queries(
self,
queries: List[str],
......@@ -98,6 +115,16 @@ class BaseEmbedding:
agg_fn = agg_fn or mean_agg
return agg_fn(query_embeddings)
async def aget_agg_embedding_from_queries(
    self,
    queries: List[str],
    agg_fn: Optional[Callable[..., List[float]]] = None,
) -> List[float]:
    """Get aggregated embedding from multiple queries.

    Args:
        queries: Query strings to embed.
        agg_fn: Aggregation over the per-query embeddings; defaults to
            ``mean_agg``.

    Returns:
        A single embedding aggregated from all query embeddings.
    """
    import asyncio

    # Embed all queries concurrently instead of awaiting them one-by-one;
    # asyncio.gather preserves input order in its results.
    query_embeddings = list(
        await asyncio.gather(*(self.aget_query_embedding(query) for query in queries))
    )
    agg_fn = agg_fn or mean_agg
    return agg_fn(query_embeddings)
@abstractmethod
def _get_text_embedding(self, text: str) -> List[float]:
    """Get text embedding.

    Provider-specific synchronous text-embedding hook.
    """
......
......@@ -2,7 +2,6 @@
from typing import List, Optional
from llama_index.embeddings.base import BaseEmbedding
# Google Universal Sentence Encode v5
......@@ -28,6 +27,11 @@ class GoogleUnivSentEncoderEmbedding(BaseEmbedding):
"""Get query embedding."""
return self._get_embedding(query)
# TODO: use proper async methods
async def _aget_text_embedding(self, query: str) -> List[float]:
"""Get text embedding."""
return self._get_embedding(query)
def _get_text_embedding(self, text: str) -> List[float]:
    """Get text embedding.

    Delegates to the shared ``_get_embedding`` helper.
    """
    return self._get_embedding(text)
......
......@@ -4,7 +4,6 @@
from typing import Any, List
from llama_index.bridge.langchain import Embeddings as LCEmbeddings
from llama_index.embeddings.base import BaseEmbedding
......@@ -25,6 +24,13 @@ class LangchainEmbedding(BaseEmbedding):
"""Get query embedding."""
return self._langchain_embedding.embed_query(query)
async def _aget_query_embedding(self, query: str) -> List[float]:
    """Get query embedding asynchronously via langchain's async API."""
    return await self._langchain_embedding.aembed_query(query)
async def _aget_text_embedding(self, text: str) -> List[float]:
    """Get text embedding asynchronously.

    ``aembed_documents`` embeds a batch, so wrap the single text in a
    list and unwrap the first result.
    """
    embeds = await self._langchain_embedding.aembed_documents([text])
    return embeds[0]
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self._langchain_embedding.embed_documents([text])[0]
......
......@@ -263,6 +263,15 @@ class OpenAIEmbedding(BaseEmbedding):
**self.openai_kwargs,
)
async def _aget_query_embedding(self, query: str) -> List[float]:
    """The asynchronous version of _get_query_embedding.

    Awaits ``aget_embedding`` with the configured query engine,
    deployment name, and any extra OpenAI kwargs.
    """
    return await aget_embedding(
        query,
        engine=self.query_engine,
        deployment_id=self.deployment_name,
        **self.openai_kwargs,
    )
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return get_embedding(
......
from abc import ABC, abstractmethod
from typing import List, Optional
from llama_index.schema import NodeWithScore
from llama_index.indices.query.schema import QueryBundle, QueryType
from llama_index.indices.service_context import ServiceContext
from llama_index.schema import NodeWithScore
class BaseRetriever(ABC):
......@@ -21,6 +21,12 @@ class BaseRetriever(ABC):
str_or_query_bundle = QueryBundle(str_or_query_bundle)
return self._retrieve(str_or_query_bundle)
async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
    """Asynchronously retrieve nodes for a query string or QueryBundle."""
    # Normalize a bare string into a QueryBundle before dispatching.
    query_bundle = (
        QueryBundle(str_or_query_bundle)
        if isinstance(str_or_query_bundle, str)
        else str_or_query_bundle
    )
    return await self._aretrieve(query_bundle)
@abstractmethod
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Retrieve nodes given query.
......@@ -30,6 +36,16 @@ class BaseRetriever(ABC):
"""
pass
# TODO: make this abstract
# @abstractmethod
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
    """Asynchronously retrieve nodes given query.

    Implemented by the user.  The default falls back to the synchronous
    ``_retrieve`` so retrievers without an async override still return
    results instead of a silent empty list.
    """
    return self._retrieve(query_bundle)
def get_service_context(self) -> Optional[ServiceContext]:
"""Attempts to resolve a service context.
Short-circuits at self.service_context, self._service_context,
......
......@@ -71,13 +71,29 @@ class VectorIndexRetriever(BaseRetriever):
query_bundle.embedding_strs
)
)
return self._get_nodes_with_embeddings(query_bundle)
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
    """Async retrieve: fill in the query embedding if needed, then query the store."""
    needs_embedding = (
        self._vector_store.is_embedding_query and query_bundle.embedding is None
    )
    if needs_embedding:
        # Compute (and cache on the bundle) the aggregated query embedding.
        embed_model = self._service_context.embed_model
        query_bundle.embedding = await embed_model.aget_agg_embedding_from_queries(
            query_bundle.embedding_strs
        )
    return self._get_nodes_with_embeddings(query_bundle)
def _get_nodes_with_embeddings(
self, query_bundle_with_embeddings: QueryBundle
) -> List[NodeWithScore]:
query = VectorStoreQuery(
query_embedding=query_bundle.embedding,
query_embedding=query_bundle_with_embeddings.embedding,
similarity_top_k=self._similarity_top_k,
node_ids=self._node_ids,
doc_ids=self._doc_ids,
query_str=query_bundle.query_str,
query_str=query_bundle_with_embeddings.query_str,
mode=self._vector_store_query_mode,
alpha=self._alpha,
filters=self._filters,
......
......@@ -104,13 +104,22 @@ class RetrieverQueryEngine(BaseQueryEngine):
node_postprocessors=node_postprocessors,
)
def _apply_node_postprocessors(
    self, nodes: List[NodeWithScore]
) -> List[NodeWithScore]:
    """Run the retrieved nodes through each configured postprocessor, in order."""
    processed = nodes
    for postprocessor in self._node_postprocessors:
        processed = postprocessor.postprocess_nodes(processed)
    return processed
def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
    """Retrieve nodes for the query and apply the node postprocessors.

    Postprocessors run exactly once, via the shared
    ``_apply_node_postprocessors`` helper; a stale inline loop previously
    applied them a second time after the helper had already run.
    """
    nodes = self._retriever.retrieve(query_bundle)
    return self._apply_node_postprocessors(nodes)
async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
    """Asynchronously retrieve nodes and apply the node postprocessors."""
    retrieved = await self._retriever.aretrieve(query_bundle)
    return self._apply_node_postprocessors(retrieved)
......@@ -171,7 +180,9 @@ class RetrieverQueryEngine(BaseQueryEngine):
)
retrieve_id = self.callback_manager.on_event_start(CBEventType.RETRIEVE)
nodes = self.retrieve(query_bundle)
nodes = await self.aretrieve(query_bundle)
self.callback_manager.on_event_end(
CBEventType.RETRIEVE,
payload={EventPayload.NODES: nodes},
......
......@@ -20,10 +20,19 @@ class MockEmbedding(BaseEmbedding):
super().__init__(*args, **kwargs)
self.embed_dim = embed_dim
def _get_vector(self) -> List[float]:
return [0.5] * self.embed_dim
async def _aget_text_embedding(self, text: str) -> List[float]:
    """Mock async text embedding: constant vector, text ignored."""
    return self._get_vector()
async def _aget_query_embedding(self, query: str) -> List[float]:
    """Mock async query embedding: constant vector, query ignored."""
    return self._get_vector()
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return [0.5] * self.embed_dim
return self._get_vector()
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return [0.5] * self.embed_dim
return self._get_vector()
......@@ -17,6 +17,23 @@ from tests.mock_utils.mock_prompts import (
class MockEmbedding(BaseEmbedding):
async def _aget_query_embedding(self, query: str) -> List[float]:
    """Mock async query embedding: fixed unit vector, query ignored."""
    del query  # unused
    return [0, 0, 1, 0, 0]
async def _aget_text_embedding(self, text: str) -> List[float]:
# assume dimensions are 4
if text == "('foo', 'is', 'bar')":
return [1, 0, 0, 0]
elif text == "('hello', 'is not', 'world')":
return [0, 1, 0, 0]
elif text == "('Jane', 'is mother of', 'Bob')":
return [0, 0, 1, 0]
elif text == "foo":
return [0, 0, 0, 1]
else:
raise ValueError("Invalid text for `mock_get_text_embedding`.")
def _get_text_embedding(self, text: str) -> List[float]:
"""Mock get text embedding."""
# assume dimensions are 4
......
......@@ -17,6 +17,37 @@ from tests.mock_utils.mock_prompts import MOCK_QUERY_KEYWORD_EXTRACT_PROMPT
class MockEmbedding(BaseEmbedding):
async def _aget_query_embedding(self, query: str) -> List[float]:
if query == "Foo?":
return [0, 0, 1, 0, 0]
elif query == "Orange?":
return [0, 1, 0, 0, 0]
elif query == "Cat?":
return [0, 0, 0, 1, 0]
else:
raise ValueError("Invalid query for `_get_query_embedding`.")
async def _aget_text_embedding(self, text: str) -> List[float]:
# assume dimensions are 5
if text == "Hello world.":
return [1, 0, 0, 0, 0]
elif text == "This is a test.":
return [0, 1, 0, 0, 0]
elif text == "This is another test.":
return [0, 0, 1, 0, 0]
elif text == "This is a test v2.":
return [0, 0, 0, 1, 0]
elif text == "foo bar":
return [0, 0, 1, 0, 0]
elif text == "apple orange":
return [0, 1, 0, 0, 0]
elif text == "toronto london":
return [1, 0, 0, 0, 0]
elif text == "cat dog":
return [0, 0, 0, 1, 0]
else:
raise ValueError("Invalid text for `mock_get_text_embedding`.")
def _get_query_embedding(self, query: str) -> List[float]:
"""Mock get query embedding."""
if query == "Foo?":
......
......@@ -26,6 +26,25 @@ def documents() -> List[Document]:
class MockEmbedding(BaseEmbedding):
async def _aget_query_embedding(self, query: str) -> List[float]:
text_embed_map: Dict[str, List[float]] = {
"It is what it is.": [1.0, 0.0, 0.0, 0.0, 0.0],
"The meaning of life": [0.0, 1.0, 0.0, 0.0, 0.0],
}
return text_embed_map[query]
async def _aget_text_embedding(self, text: str) -> List[float]:
text_embed_map: Dict[str, List[float]] = {
"Correct.": [0.5, 0.5, 0.0, 0.0, 0.0],
"Hello world.": [1.0, 0.0, 0.0, 0.0, 0.0],
"This is a test.": [0.0, 1.0, 0.0, 0.0, 0.0],
"This is another test.": [0.0, 0.0, 1.0, 0.0, 0.0],
"This is a test v2.": [0.0, 0.0, 0.0, 1.0, 0.0],
}
return text_embed_map[text]
def _get_text_embedding(self, text: str) -> List[float]:
"""Get node text embedding."""
text_embed_map: Dict[str, List[float]] = {
......
......@@ -4,6 +4,30 @@ from llama_index.embeddings.base import BaseEmbedding
class MockEmbedding(BaseEmbedding):
async def _aget_query_embedding(self, query: str) -> List[float]:
    """Mock async query embedding: fixed unit vector, query ignored."""
    del query
    return [0, 0, 1, 0, 0]
async def _aget_text_embedding(self, text: str) -> List[float]:
# assume dimensions are 5
if text == "Hello world.":
return [1, 0, 0, 0, 0]
elif text == "This is a test.":
return [0, 1, 0, 0, 0]
elif text == "This is another test.":
return [0, 0, 1, 0, 0]
elif text == "This is a test v2.":
return [0, 0, 0, 1, 0]
elif text == "This is a test v3.":
return [0, 0, 0, 0, 1]
elif text == "This is bar test.":
return [0, 0, 1, 0, 0]
elif text == "Hello world backup.":
# this is used when "Hello world." is deleted.
return [1, 0, 0, 0, 0]
else:
return [0, 0, 0, 0, 0]
def _get_query_embedding(self, query: str) -> List[float]:
    """Mock sync query embedding: fixed unit vector, query ignored."""
    del query  # Unused
    return [0, 0, 1, 0, 0]
......
......@@ -14,6 +14,21 @@ from llama_index.schema import Document
class MockEmbedding(BaseEmbedding):
async def _aget_query_embedding(self, query: str) -> List[float]:
    """Mock async query embedding: fixed unit vector, query ignored."""
    del query
    return [0, 0, 1, 0, 0]
async def _aget_text_embedding(self, text: str) -> List[float]:
# assume dimensions are 5
if text == "They're taking the Hobbits to Isengard!":
return [1, 0, 0, 0, 0]
elif text == "I can't carry it for you.":
return [0, 1, 0, 0, 0]
elif text == "But I can carry you!":
return [0, 0, 1, 0, 0]
else:
raise ValueError("Invalid text for `mock_get_text_embedding`.")
def _get_text_embedding(self, text: str) -> List[float]:
"""Mock get text embedding."""
# assume dimensions are 5
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment