From 098de29b231e427bbb1ede10e63edc2bd801e7f3 Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Thu, 10 Oct 2024 23:00:35 +0200
Subject: [PATCH] feat: move cohere to optional dep

---
 poetry.lock                        | 27 +++++++++++++------------
 pyproject.toml                     |  3 ++-
 semantic_router/encoders/cohere.py | 32 ++++++++++++++++++++++++------
 semantic_router/llms/cohere.py     | 20 +++++++++++++++----
 4 files changed, 58 insertions(+), 24 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 0c6f20e6..11dbb7ed 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -327,7 +327,7 @@ uvloop = ["uvloop (>=0.15.2)"]
 name = "boto3"
 version = "1.35.32"
 description = "The AWS SDK for Python"
-optional = false
+optional = true
 python-versions = ">=3.8"
 files = [
     {file = "boto3-1.35.32-py3-none-any.whl", hash = "sha256:786a243f4b4827c6ae149442bf544c2ae449570cf23616a5d386f7a2633e0e08"},
@@ -346,7 +346,7 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
 name = "botocore"
 version = "1.35.32"
 description = "Low-level, data-driven core of boto 3."
-optional = false
+optional = true
 python-versions = ">=3.8"
 files = [
     {file = "botocore-1.35.32-py3-none-any.whl", hash = "sha256:2c0c2b62dd156daf904525f3f523ae22bf34ac109d727acf0bbfbca291440fc3"},
@@ -582,7 +582,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
 name = "cohere"
 version = "5.10.0"
 description = ""
-optional = false
+optional = true
 python-versions = "<4.0,>=3.8"
 files = [
     {file = "cohere-5.10.0-py3-none-any.whl", hash = "sha256:46e50e3e8514a99cf77b4c022c8077a6205fba948051c33087ddeb66ec706f0a"},
@@ -982,7 +982,7 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth
 name = "fastavro"
 version = "1.9.7"
 description = "Fast read/write of AVRO files"
-optional = false
+optional = true
 python-versions = ">=3.8"
 files = [
     {file = "fastavro-1.9.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc811fb4f7b5ae95f969cda910241ceacf82e53014c7c7224df6f6e0ca97f52f"},
@@ -1056,7 +1056,7 @@ tqdm = ">=4.66,<5.0"
 name = "filelock"
 version = "3.16.1"
 description = "A platform independent file lock."
-optional = false
+optional = true
 python-versions = ">=3.8"
 files = [
     {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"},
@@ -1240,7 +1240,7 @@ files = [
 name = "fsspec"
 version = "2024.9.0"
 description = "File-system specification"
-optional = false
+optional = true
 python-versions = ">=3.8"
 files = [
     {file = "fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b"},
@@ -1796,7 +1796,7 @@ socks = ["socksio (==1.*)"]
 name = "httpx-sse"
 version = "0.4.0"
 description = "Consume Server-Sent Event (SSE) messages with HTTPX."
-optional = false
+optional = true
 python-versions = ">=3.8"
 files = [
     {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
@@ -1807,7 +1807,7 @@ files = [
 name = "huggingface-hub"
 version = "0.25.1"
 description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
-optional = false
+optional = true
 python-versions = ">=3.8.0"
 files = [
     {file = "huggingface_hub-0.25.1-py3-none-any.whl", hash = "sha256:a5158ded931b3188f54ea9028097312cb0acd50bffaaa2612014c3c526b44972"},
@@ -2123,7 +2123,7 @@ files = [
 name = "jmespath"
 version = "1.0.1"
 description = "JSON Matching Expressions"
-optional = false
+optional = true
 python-versions = ">=3.7"
 files = [
     {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"},
@@ -3211,7 +3211,7 @@ files = [
 name = "parameterized"
 version = "0.9.0"
 description = "Parameterized testing with any Python test framework"
-optional = false
+optional = true
 python-versions = ">=3.7"
 files = [
     {file = "parameterized-0.9.0-py2.py3-none-any.whl", hash = "sha256:4e0758e3d41bea3bbd05ec14fc2c24736723f243b28d702081aef438c9372b1b"},
@@ -4347,7 +4347,7 @@ files = [
 name = "s3transfer"
 version = "0.10.2"
 description = "An Amazon S3 Transfer Manager"
-optional = false
+optional = true
 python-versions = ">=3.8"
 files = [
     {file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"},
@@ -4862,7 +4862,7 @@ files = [
 name = "tokenizers"
 version = "0.20.0"
 description = ""
-optional = false
+optional = true
 python-versions = ">=3.7"
 files = [
     {file = "tokenizers-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6cff5c5e37c41bc5faa519d6f3df0679e4b37da54ea1f42121719c5e2b4905c0"},
@@ -5488,6 +5488,7 @@ type = ["pytest-mypy"]
 
 [extras]
 bedrock = ["boto3", "botocore"]
+cohere = ["cohere"]
 docs = ["sphinx", "sphinxawesome-theme"]
 fastembed = ["fastembed"]
 google = ["google-cloud-aiplatform"]
@@ -5503,4 +5504,4 @@ vision = ["pillow", "torch", "torchvision", "transformers"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "3b6d8cef3e0d6c516a9d9704350e8ff6dac7277cabed851f8c4ccc84214df6ea"
+content-hash = "b0ddd77f2b9a210601eba56f69630eaa6a53cb358cae95bace2ae080d51c7812"
diff --git a/pyproject.toml b/pyproject.toml
index 99f91f2d..42abf2c8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ license = "MIT"
 python = ">=3.9,<3.13"
 pydantic = "^2.5.3"
 openai = ">=1.10.0,<2.0.0"
-cohere = ">=5.9.4,<6.00"
+cohere = {version = ">=5.9.4,<6.00", optional = true}
 mistralai= {version = ">=0.0.12,<0.1.0", optional = true}
 numpy = "^1.25.2"
 colorlog = "^6.8.0"
@@ -52,6 +52,7 @@ bedrock = ["boto3", "botocore"]
 postgres = ["psycopg2"]
 fastembed = ["fastembed"]
 docs = ["sphinx", "sphinxawesome-theme"]
+cohere = ["cohere"]
 
 [tool.poetry.group.dev.dependencies]
 ipykernel = "^6.25.0"
diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py
index 01426e9f..4dc02146 100644
--- a/semantic_router/encoders/cohere.py
+++ b/semantic_router/encoders/cohere.py
@@ -1,15 +1,14 @@
 import os
-from typing import List, Optional
+from typing import Any, List, Optional
 
-import cohere
-from cohere.types.embed_response import EmbeddingsByTypeEmbedResponse
+from pydantic.v1 import PrivateAttr
 
 from semantic_router.encoders import BaseEncoder
 from semantic_router.utils.defaults import EncoderDefault
 
 
 class CohereEncoder(BaseEncoder):
-    client: Optional[cohere.Client] = None
+    client: Any = PrivateAttr()
     type: str = "cohere"
     input_type: Optional[str] = "search_query"
 
@@ -28,15 +27,36 @@ class CohereEncoder(BaseEncoder):
             input_type=input_type,  # type: ignore
         )
         self.input_type = input_type
+        self.client = self._initialize_client(cohere_api_key)
+    
+    def _initialize_client(self, cohere_api_key: Optional[str] = None):
+        """Initializes the Cohere client.
+
+        :param cohere_api_key: The API key for the Cohere client, can also
+        be set via the COHERE_API_KEY environment variable.
+
+        :return: An instance of the Cohere client.
+        """
+        try:
+            import cohere
+            from cohere.types.embed_response import EmbeddingsByTypeEmbedResponse
+            self.EmbeddingsByTypeEmbedResponse = EmbeddingsByTypeEmbedResponse
+        except ImportError:
+            raise ImportError(
+                "Please install Cohere to use CohereEncoder. "
+                "You can install it with: "
+                "`pip install 'semantic-router[cohere]'`"
+            )
         cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
         if cohere_api_key is None:
             raise ValueError("Cohere API key cannot be 'None'.")
         try:
-            self.client = cohere.Client(cohere_api_key)
+            client = cohere.Client(cohere_api_key)
         except Exception as e:
             raise ValueError(
                 f"Cohere API client failed to initialize. Error: {e}"
             ) from e
+        return client
 
     def __call__(self, docs: List[str]) -> List[List[float]]:
         if self.client is None:
@@ -46,7 +66,7 @@ class CohereEncoder(BaseEncoder):
                 texts=docs, input_type=self.input_type, model=self.name
             )
             # Check for unsupported type.
-            if isinstance(embeds, EmbeddingsByTypeEmbedResponse):
+            if isinstance(embeds, self.EmbeddingsByTypeEmbedResponse):
                 raise NotImplementedError(
                     "Handling of EmbedByTypeResponseEmbeddings is not implemented."
                 )
diff --git a/semantic_router/llms/cohere.py b/semantic_router/llms/cohere.py
index 37eb4338..98dc445c 100644
--- a/semantic_router/llms/cohere.py
+++ b/semantic_router/llms/cohere.py
@@ -1,14 +1,14 @@
 import os
-from typing import List, Optional
+from typing import Any, List, Optional
 
-import cohere
+from pydantic.v1 import PrivateAttr
 
 from semantic_router.llms import BaseLLM
 from semantic_router.schema import Message
 
 
 class CohereLLM(BaseLLM):
-    client: Optional[cohere.Client] = None
+    client: Any = PrivateAttr()
 
     def __init__(
         self,
@@ -18,15 +18,27 @@ class CohereLLM(BaseLLM):
         if name is None:
             name = os.getenv("COHERE_CHAT_MODEL_NAME", "command")
         super().__init__(name=name)
+        self.client = self._initialize_client(cohere_api_key)
+
+    def _initialize_client(self, cohere_api_key: Optional[str] = None):
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError(
+                "Please install Cohere to use CohereLLM. "
+                "You can install it with: "
+                "`pip install 'semantic-router[cohere]'`"
+            )
         cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
         if cohere_api_key is None:
             raise ValueError("Cohere API key cannot be 'None'.")
         try:
-            self.client = cohere.Client(cohere_api_key)
+            client = cohere.Client(cohere_api_key)
         except Exception as e:
             raise ValueError(
                 f"Cohere API client failed to initialize. Error: {e}"
             ) from e
+        return client
 
     def __call__(self, messages: List[Message]) -> str:
         if self.client is None:
-- 
GitLab