Skip to content
Snippets Groups Projects
Commit 6b901f8d authored by zahid-syed's avatar zahid-syed
Browse files

optional dependency issue fix started

parent 0cd4c37a
No related branches found
No related tags found
No related merge requests found
......@@ -19,7 +19,7 @@ python = ">=3.9,<3.13"
pydantic = "^2.5.3"
openai = "^1.10.0"
cohere = "^4.32"
mistralai= "^0.0.12"
mistralai= {version = "^0.0.12", optional = true}
numpy = "^1.25.2"
colorlog = "^6.8.0"
pyyaml = "^6.0.1"
......
"""This file contains the MistralEncoder class which is used to encode text using MistralAI"""
import os
from time import sleep
from typing import List, Optional
from typing import List, Optional, Any
from mistralai.client import MistralClient
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse
from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault
from pydantic.v1 import PrivateAttr
class MistralEncoder(BaseEncoder):
"""Class to encode text using MistralAI"""
client: Optional[MistralClient]
client: Any = PrivateAttr()
embedding_response: Any = PrivateAttr()
mistral_exception: Any = PrivateAttr()
type: str = "mistral"
def __init__(
......@@ -29,12 +31,39 @@ class MistralEncoder(BaseEncoder):
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("Mistral API key not provided")
self._client = self._initialize_client(mistralai_api_key)
def _initialize_client(self, api_key):
try:
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)
try:
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)
try:
self.client = MistralClient(api_key=api_key)
self.embedding_response = EmbeddingResponse
self.mistral_exception = MistralException
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
def __call__(self, docs: List[str]) -> List[List[float]]:
if self.client is None:
raise ValueError("Mistral client not initialized")
embeds = None
......@@ -46,13 +75,13 @@ class MistralEncoder(BaseEncoder):
embeds = self.client.embeddings(model=self.name, input=docs)
if embeds.data:
break
except MistralException as e:
except self.mistral_exception as e:
sleep(2**_)
error_message = str(e)
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data:
if not embeds or not isinstance(embeds, self.embedding_response) or not embeds.data:
raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings
return embeddings
\ No newline at end of file
import os
from typing import List, Optional
from typing import List, Optional, Any
from mistralai.client import MistralClient
from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger
from pydantic.v1 import PrivateAttr
class MistralAILLM(BaseLLM):
client: Optional[MistralClient]
client: Any = PrivateAttr()
temperature: Optional[float]
max_tokens: Optional[int]
......@@ -27,15 +29,26 @@ class MistralAILLM(BaseLLM):
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("MistralAI API key cannot be 'None'.")
self._initialize_client(api_key)
self.temperature = temperature
self.max_tokens = max_tokens
def _initialize_client(self, api_key):
try:
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)
try:
self.client = MistralClient(api_key=api_key)
except Exception as e:
raise ValueError(
f"MistralAI API client failed to initialize. Error: {e}"
) from e
self.temperature = temperature
self.max_tokens = max_tokens
def __call__(self, messages: List[Message]) -> str:
if self.client is None:
raise ValueError("MistralAI client is not initialized.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment