diff --git a/poetry.lock b/poetry.lock index de3afbe884907ab36c776296a4a3e7b667eafb28..5324dfd14d40c6f7cce8e7f1fff759dd9f519d91 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +[[package]] +name = "aiofiles" +version = "24.1.0" +description = "File support for asyncio." +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, + {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, +] + [[package]] name = "aiohappyeyeballs" version = "2.4.3" @@ -226,6 +237,26 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +[[package]] +name = "aurelio-sdk" +version = "0.0.16" +description = "Aurelio Platform SDK" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "aurelio_sdk-0.0.16-py3-none-any.whl", hash = "sha256:015fb384552fea5541350f1ca1be72c0186b38ecab0a97a2a29ec38e611cfbce"}, + {file = "aurelio_sdk-0.0.16.tar.gz", hash = "sha256:afdbada21d91160044dc8b953eddfb2cdae18f1c2a946098ca3f8057d6bc9ce7"}, +] + +[package.dependencies] +aiofiles = ">=24.1.0,<25.0.0" +aiohttp = ">=3.10.5,<4.0.0" +colorlog = ">=6.8.2,<7.0.0" +pydantic = ">=2.9.2,<3.0.0" +python-dotenv = ">=1.0.1,<2.0.0" +requests = ">=2.32.3,<3.0.0" +requests-toolbelt = ">=1.0.0,<2.0.0" + [[package]] name = "babel" version = "2.16.0" @@ -4163,7 +4194,7 @@ six = ">=1.5" name = "python-dotenv" version = "1.0.1" description = "Read key-value pairs from a .env file and set them as environment variables" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, @@ -4551,6 +4582,20 @@ requests = ">=2.22,<3" [package.extras] fixture = ["fixtures"] +[[package]] +name = "requests-toolbelt" +version = "1.0.0" +description = "A utility belt for advanced users of python-requests" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"}, + {file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"}, +] + +[package.dependencies] +requests = ">=2.0.1,<3.0.0" + [[package]] name = "rsa" version = "4.9" @@ -5802,4 +5847,4 @@ vision = ["pillow", "torch", "torchvision", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "377177b0000f74fa3bcc0d7e4abde276b8b6e4954471fd9a412ccf54064f54ca" +content-hash = "4ed42980596d2cca65bdd76d68742e6cd6283f2e2320277bd5e988b376b15997" diff --git a/pyproject.toml b/pyproject.toml index fec91ddfb4124b568576d091d60cd3e68498cd1f..f2817a655bd989bee4f0efb9120fa74301fa86c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ mistralai= {version = ">=0.0.12,<0.1.0", optional = true} numpy = "^1.25.2" colorlog = "^6.8.0" pyyaml = "^6.0.1" +aurelio-sdk = {version = "^0.0.16"} pinecone-text = {version = ">=0.7.1,<0.10.0", optional = true} torch = {version = ">=2.1.0,<2.6.0", optional = true} transformers = {version = ">=4.36.2", optional = true} diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index a1026240d37fbfecb1ec8b1445d42fc05f04265f..4256e2e5e61c18a87d65b6befa2e99c90983af5c 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -1,5 +1,6 @@ from typing import List, Optional +from semantic_router.encoders.aurelio import AurelioSparseEncoder from semantic_router.encoders.base import BaseEncoder from semantic_router.encoders.bedrock import BedrockEncoder from semantic_router.encoders.bm25 import BM25Encoder @@ -17,6 +18,7 @@ from semantic_router.encoders.zure import AzureOpenAIEncoder from semantic_router.schema import EncoderType __all__ = [ + "AurelioSparseEncoder", "BaseEncoder", "AzureOpenAIEncoder", "CohereEncoder", diff --git a/semantic_router/encoders/aurelio.py b/semantic_router/encoders/aurelio.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc9fc86b300cf90cddbef4759672da205d1d4e6 --- /dev/null +++ b/semantic_router/encoders/aurelio.py @@ -0,0 +1,46 @@ +import os +from typing import Any, Dict, List, Optional +from pydantic.v1 import Field + +from aurelio_sdk import AurelioClient, AsyncAurelioClient, EmbeddingResponse + +from semantic_router.encoders import BaseEncoder + + +class AurelioSparseEncoder(BaseEncoder): + model: Optional[Any] = None + idx_mapping: Optional[Dict[int, int]] = None + client: AurelioClient = Field(default_factory=AurelioClient, exclude=True) + async_client: AsyncAurelioClient = Field(default_factory=AsyncAurelioClient, exclude=True) + type: str = "sparse" + + def __init__( + self, + name: str = "bm25", + score_threshold: float = 1.0, + api_key: Optional[str] = None, + ): + super().__init__(name=name, score_threshold=score_threshold) + if api_key is None: + api_key = os.getenv("AURELIO_API_KEY") + if api_key is None: + raise ValueError("AURELIO_API_KEY environment variable is not set.") + self.client = AurelioClient(api_key=api_key) + self.async_client = AsyncAurelioClient(api_key=api_key) + + def __call__(self, docs: list[str]) -> list[dict[int, float]]: + res: EmbeddingResponse = self.client.embedding(input=docs, model=self.name) + embeds = [r.embedding.model_dump() for r in res.data] + # convert sparse vector to {index: value} format + sparse_dicts = [{i: v for i, v in zip(e["indices"], e["values"])} for e in embeds] + return sparse_dicts + + async def acall(self, docs: list[str]) -> list[dict[int, float]]: + res: EmbeddingResponse = await self.async_client.embedding(input=docs, model=self.name) + embeds = [r.embedding.model_dump() for r in res.data] + # convert sparse vector to {index: value} format + sparse_dicts = [{i: v for i, v in zip(e["indices"], e["values"])} for e in embeds] + return sparse_dicts + + def fit(self, docs: List[str]): + raise NotImplementedError("AurelioSparseEncoder does not support fit.")