Skip to content
Snippets Groups Projects
Commit 6b901f8d authored by zahid-syed's avatar zahid-syed
Browse files

optional dependency issue fix started

parent 0cd4c37a
No related branches found
No related tags found
No related merge requests found
...@@ -19,7 +19,7 @@ python = ">=3.9,<3.13" ...@@ -19,7 +19,7 @@ python = ">=3.9,<3.13"
pydantic = "^2.5.3" pydantic = "^2.5.3"
openai = "^1.10.0" openai = "^1.10.0"
cohere = "^4.32" cohere = "^4.32"
mistralai= "^0.0.12" mistralai= {version = "^0.0.12", optional = true}
numpy = "^1.25.2" numpy = "^1.25.2"
colorlog = "^6.8.0" colorlog = "^6.8.0"
pyyaml = "^6.0.1" pyyaml = "^6.0.1"
......
"""This file contains the MistralEncoder class which is used to encode text using MistralAI""" """This file contains the MistralEncoder class which is used to encode text using MistralAI"""
import os import os
from time import sleep from time import sleep
from typing import List, Optional from typing import List, Optional, Any
from mistralai.client import MistralClient
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse
from semantic_router.encoders import BaseEncoder from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.defaults import EncoderDefault
from pydantic.v1 import PrivateAttr
class MistralEncoder(BaseEncoder): class MistralEncoder(BaseEncoder):
"""Class to encode text using MistralAI""" """Class to encode text using MistralAI"""
client: Optional[MistralClient] client: Any = PrivateAttr()
embedding_response: Any = PrivateAttr()
mistral_exception: Any = PrivateAttr()
type: str = "mistral" type: str = "mistral"
def __init__( def __init__(
...@@ -29,12 +31,39 @@ class MistralEncoder(BaseEncoder): ...@@ -29,12 +31,39 @@ class MistralEncoder(BaseEncoder):
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY") api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None: if api_key is None:
raise ValueError("Mistral API key not provided") raise ValueError("Mistral API key not provided")
self._client = self._initialize_client(mistralai_api_key)
def _initialize_client(self, api_key):
try:
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)
try:
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)
try: try:
self.client = MistralClient(api_key=api_key) self.client = MistralClient(api_key=api_key)
self.embedding_response = EmbeddingResponse
self.mistral_exception = MistralException
except Exception as e: except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
def __call__(self, docs: List[str]) -> List[List[float]]: def __call__(self, docs: List[str]) -> List[List[float]]:
if self.client is None: if self.client is None:
raise ValueError("Mistral client not initialized") raise ValueError("Mistral client not initialized")
embeds = None embeds = None
...@@ -46,13 +75,13 @@ class MistralEncoder(BaseEncoder): ...@@ -46,13 +75,13 @@ class MistralEncoder(BaseEncoder):
embeds = self.client.embeddings(model=self.name, input=docs) embeds = self.client.embeddings(model=self.name, input=docs)
if embeds.data: if embeds.data:
break break
except MistralException as e: except self.mistral_exception as e:
sleep(2**_) sleep(2**_)
error_message = str(e) error_message = str(e)
except Exception as e: except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data: if not embeds or not isinstance(embeds, self.embedding_response) or not embeds.data:
raise ValueError(f"No embeddings returned from MistralAI: {error_message}") raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings return embeddings
\ No newline at end of file
import os import os
from typing import List, Optional from typing import List, Optional, Any
from mistralai.client import MistralClient
from semantic_router.llms import BaseLLM from semantic_router.llms import BaseLLM
from semantic_router.schema import Message from semantic_router.schema import Message
from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
from pydantic.v1 import PrivateAttr
class MistralAILLM(BaseLLM): class MistralAILLM(BaseLLM):
client: Optional[MistralClient] client: Any = PrivateAttr()
temperature: Optional[float] temperature: Optional[float]
max_tokens: Optional[int] max_tokens: Optional[int]
...@@ -27,15 +29,26 @@ class MistralAILLM(BaseLLM): ...@@ -27,15 +29,26 @@ class MistralAILLM(BaseLLM):
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY") api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None: if api_key is None:
raise ValueError("MistralAI API key cannot be 'None'.") raise ValueError("MistralAI API key cannot be 'None'.")
self._initialize_client(api_key)
self.temperature = temperature
self.max_tokens = max_tokens
def _initialize_client(self, api_key):
try:
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)
try: try:
self.client = MistralClient(api_key=api_key) self.client = MistralClient(api_key=api_key)
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(
f"MistralAI API client failed to initialize. Error: {e}" f"MistralAI API client failed to initialize. Error: {e}"
) from e ) from e
self.temperature = temperature
self.max_tokens = max_tokens
def __call__(self, messages: List[Message]) -> str: def __call__(self, messages: List[Message]) -> str:
if self.client is None: if self.client is None:
raise ValueError("MistralAI client is not initialized.") raise ValueError("MistralAI client is not initialized.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment