From 098de29b231e427bbb1ede10e63edc2bd801e7f3 Mon Sep 17 00:00:00 2001 From: James Briggs <james.briggs@hotmail.com> Date: Thu, 10 Oct 2024 23:00:35 +0200 Subject: [PATCH] feat: move cohere to optional dep --- poetry.lock | 27 +++++++++++++------------ pyproject.toml | 3 ++- semantic_router/encoders/cohere.py | 32 ++++++++++++++++++++++++------ semantic_router/llms/cohere.py | 20 +++++++++++++++---- 4 files changed, 58 insertions(+), 24 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0c6f20e6..11dbb7ed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -327,7 +327,7 @@ uvloop = ["uvloop (>=0.15.2)"] name = "boto3" version = "1.35.32" description = "The AWS SDK for Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "boto3-1.35.32-py3-none-any.whl", hash = "sha256:786a243f4b4827c6ae149442bf544c2ae449570cf23616a5d386f7a2633e0e08"}, @@ -346,7 +346,7 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] name = "botocore" version = "1.35.32" description = "Low-level, data-driven core of boto 3." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "botocore-1.35.32-py3-none-any.whl", hash = "sha256:2c0c2b62dd156daf904525f3f523ae22bf34ac109d727acf0bbfbca291440fc3"}, @@ -582,7 +582,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "cohere" version = "5.10.0" description = "" -optional = false +optional = true python-versions = "<4.0,>=3.8" files = [ {file = "cohere-5.10.0-py3-none-any.whl", hash = "sha256:46e50e3e8514a99cf77b4c022c8077a6205fba948051c33087ddeb66ec706f0a"}, @@ -982,7 +982,7 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth name = "fastavro" version = "1.9.7" description = "Fast read/write of AVRO files" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "fastavro-1.9.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc811fb4f7b5ae95f969cda910241ceacf82e53014c7c7224df6f6e0ca97f52f"}, @@ -1056,7 +1056,7 @@ tqdm = ">=4.66,<5.0" name = "filelock" version = "3.16.1" description = "A platform independent file lock." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, @@ -1240,7 +1240,7 @@ files = [ name = "fsspec" version = "2024.9.0" description = "File-system specification" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b"}, @@ -1796,7 +1796,7 @@ socks = ["socksio (==1.*)"] name = "httpx-sse" version = "0.4.0" description = "Consume Server-Sent Event (SSE) messages with HTTPX." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"}, @@ -1807,7 +1807,7 @@ files = [ name = "huggingface-hub" version = "0.25.1" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" -optional = false +optional = true python-versions = ">=3.8.0" files = [ {file = "huggingface_hub-0.25.1-py3-none-any.whl", hash = "sha256:a5158ded931b3188f54ea9028097312cb0acd50bffaaa2612014c3c526b44972"}, @@ -2123,7 +2123,7 @@ files = [ name = "jmespath" version = "1.0.1" description = "JSON Matching Expressions" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, @@ -3211,7 +3211,7 @@ files = [ name = "parameterized" version = "0.9.0" description = "Parameterized testing with any Python test framework" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "parameterized-0.9.0-py2.py3-none-any.whl", hash = "sha256:4e0758e3d41bea3bbd05ec14fc2c24736723f243b28d702081aef438c9372b1b"}, @@ -4347,7 +4347,7 @@ files = [ name = "s3transfer" version = "0.10.2" description = "An Amazon S3 Transfer Manager" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"}, @@ -4862,7 +4862,7 @@ files = [ name = "tokenizers" version = "0.20.0" description = "" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "tokenizers-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6cff5c5e37c41bc5faa519d6f3df0679e4b37da54ea1f42121719c5e2b4905c0"}, @@ -5488,6 +5488,7 @@ type = ["pytest-mypy"] [extras] bedrock = ["boto3", "botocore"] +cohere = ["cohere"] docs = ["sphinx", "sphinxawesome-theme"] fastembed = ["fastembed"] google = ["google-cloud-aiplatform"] @@ -5503,4 +5504,4 @@ vision = ["pillow", "torch", "torchvision", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "3b6d8cef3e0d6c516a9d9704350e8ff6dac7277cabed851f8c4ccc84214df6ea" +content-hash = "b0ddd77f2b9a210601eba56f69630eaa6a53cb358cae95bace2ae080d51c7812" diff --git a/pyproject.toml b/pyproject.toml index 99f91f2d..42abf2c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ license = "MIT" python = ">=3.9,<3.13" pydantic = "^2.5.3" openai = ">=1.10.0,<2.0.0" -cohere = ">=5.9.4,<6.00" +cohere = {version = ">=5.9.4,<6.00", optional = true} mistralai= {version = ">=0.0.12,<0.1.0", optional = true} numpy = "^1.25.2" colorlog = "^6.8.0" @@ -52,6 +52,7 @@ bedrock = ["boto3", "botocore"] postgres = ["psycopg2"] fastembed = ["fastembed"] docs = ["sphinx", "sphinxawesome-theme"] +cohere = ["cohere"] [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.0" diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index 01426e9f..4dc02146 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -1,15 +1,14 @@ import os -from typing import List, Optional +from typing import Any, List, Optional -import cohere -from cohere.types.embed_response import EmbeddingsByTypeEmbedResponse +from pydantic.v1 import PrivateAttr from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault class CohereEncoder(BaseEncoder): - client: Optional[cohere.Client] = None + client: Any = PrivateAttr() type: str = "cohere" input_type: Optional[str] = "search_query" @@ -28,15 +27,36 @@ class CohereEncoder(BaseEncoder): input_type=input_type, # type: ignore ) self.input_type = input_type + self.client = self._initialize_client(cohere_api_key) + + def _initialize_client(self, cohere_api_key: Optional[str] = None): + """Initializes the Cohere client. + + :param cohere_api_key: The API key for the Cohere client, can also + be set via the COHERE_API_KEY environment variable. + + :return: An instance of the Cohere client. + """ + try: + import cohere + from cohere.types.embed_response import EmbeddingsByTypeEmbedResponse + self.EmbeddingsByTypeEmbedResponse = EmbeddingsByTypeEmbedResponse + except ImportError: + raise ImportError( + "Please install Cohere to use CohereEncoder. " + "You can install it with: " + "`pip install 'semantic-router[cohere]'`" + ) cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") if cohere_api_key is None: raise ValueError("Cohere API key cannot be 'None'.") try: - self.client = cohere.Client(cohere_api_key) + client = cohere.Client(cohere_api_key) except Exception as e: raise ValueError( f"Cohere API client failed to initialize. Error: {e}" ) from e + return client def __call__(self, docs: List[str]) -> List[List[float]]: if self.client is None: @@ -46,7 +66,7 @@ class CohereEncoder(BaseEncoder): texts=docs, input_type=self.input_type, model=self.name ) # Check for unsupported type. - if isinstance(embeds, EmbeddingsByTypeEmbedResponse): + if isinstance(embeds, self.EmbeddingsByTypeEmbedResponse): raise NotImplementedError( "Handling of EmbedByTypeResponseEmbeddings is not implemented." ) diff --git a/semantic_router/llms/cohere.py b/semantic_router/llms/cohere.py index 37eb4338..98dc445c 100644 --- a/semantic_router/llms/cohere.py +++ b/semantic_router/llms/cohere.py @@ -1,14 +1,14 @@ import os -from typing import List, Optional +from typing import Any, List, Optional -import cohere +from pydantic.v1 import PrivateAttr from semantic_router.llms import BaseLLM from semantic_router.schema import Message class CohereLLM(BaseLLM): - client: Optional[cohere.Client] = None + client: Any = PrivateAttr() def __init__( self, @@ -18,15 +18,27 @@ class CohereLLM(BaseLLM): if name is None: name = os.getenv("COHERE_CHAT_MODEL_NAME", "command") super().__init__(name=name) + self.client = self._initialize_client(cohere_api_key) + + def _initialize_client(self, cohere_api_key: Optional[str] = None): + try: + import cohere + except ImportError: + raise ImportError( + "Please install Cohere to use CohereLLM. " + "You can install it with: " + "`pip install 'semantic-router[cohere]'`" + ) cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") if cohere_api_key is None: raise ValueError("Cohere API key cannot be 'None'.") try: - self.client = cohere.Client(cohere_api_key) + client = cohere.Client(cohere_api_key) except Exception as e: raise ValueError( f"Cohere API client failed to initialize. Error: {e}" ) from e + return client def __call__(self, messages: List[Message]) -> str: if self.client is None: -- GitLab