Commit 4a12b4a8 authored by zahid-syed

fix mistral ai llm issue

parent 09f0335b
@@ -25,7 +25,7 @@ class MistralEncoder(BaseEncoder):
         if name is None:
             name = EncoderDefault.MISTRAL.value["embedding_model"]
         super().__init__(name=name, score_threshold=score_threshold)
         self._client, self._mistralai = self._initialize_client(mistralai_api_key)

     def _initialize_client(self, api_key):
         try:
@@ -67,7 +67,9 @@ class MistralEncoder(BaseEncoder):
         if (
             not embeds
-            or not isinstance(embeds, self._mistralai.models.embeddings.EmbeddingResponse)
+            or not isinstance(
+                embeds, self._mistralai.models.embeddings.EmbeddingResponse
+            )
             or not embeds.data
         ):
             raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
...
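For reference, the check rewrapped in the second encoder hunk is the guard on the embedding response. A minimal standalone sketch of that validation pattern, assuming a hypothetical helper name and return shape that are not part of this commit (only the isinstance/empty-data check mirrors the diff above):

# Hypothetical helper illustrating the response check rewrapped above.
# `mistralai_module` stands in for the module handle the encoder stores as
# self._mistralai; the list-comprehension return value is an assumption.
def validate_embeddings(embeds, mistralai_module, error_message="no response"):
    if (
        not embeds
        or not isinstance(
            embeds, mistralai_module.models.embeddings.EmbeddingResponse
        )
        or not embeds.data
    ):
        raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
    return [item.embedding for item in embeds.data]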
@@ -14,6 +14,7 @@ class MistralAILLM(BaseLLM):
     _client: Any = PrivateAttr()
     temperature: Optional[float]
     max_tokens: Optional[int]
+    _mistralai: Any = PrivateAttr()

     def __init__(
         self,
@@ -25,12 +26,13 @@ class MistralAILLM(BaseLLM):
         if name is None:
             name = EncoderDefault.MISTRAL.value["language_model"]
         super().__init__(name=name)
-        self._client = self._initialize_client(mistralai_api_key)
+        self._client, self._mistralai = self._initialize_client(mistralai_api_key)
         self.temperature = temperature
         self.max_tokens = max_tokens

     def _initialize_client(self, api_key):
         try:
+            import mistralai
             from mistralai.client import MistralClient
         except ImportError:
             raise ImportError(
@@ -47,15 +49,22 @@ class MistralAILLM(BaseLLM):
             raise ValueError(
                 f"MistralAI API client failed to initialize. Error: {e}"
             ) from e
-        return client
+        return client, mistralai

     def __call__(self, messages: List[Message]) -> str:
         if self._client is None:
             raise ValueError("MistralAI client is not initialized.")
+        chat_messages = [
+            self._mistralai.models.chat_completion.ChatMessage(
+                role=m.role, content=m.content
+            )
+            for m in messages
+        ]
         try:
             completion = self._client.chat(
                 model=self.name,
-                messages=[m.to_mistral() for m in messages],
+                messages=chat_messages,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
             )
...
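Taken together, the LLM now mirrors the encoder: _initialize_client returns both the MistralClient instance and the imported mistralai module, and __call__ converts semantic-router Message objects into mistralai ChatMessage objects before calling chat(). A minimal usage sketch, assuming import paths from the package layout at the time of this commit and a placeholder API key (neither appears in the diff):

# Import paths are assumptions; only the class name and the Message fields
# (role, content) come from the diff above.
from semantic_router.llms import MistralAILLM
from semantic_router.schema import Message

llm = MistralAILLM(mistralai_api_key="YOUR_MISTRAL_API_KEY")  # placeholder key

messages = [
    Message(role="user", content="Summarise what semantic routing does."),
]

# __call__ builds the ChatMessage list internally (see the chat_messages
# comprehension above), so plain Message objects can be passed straight in.
reply = llm(messages)
print(reply)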