Skip to content
Snippets Groups Projects
Commit 4a12b4a8 authored by zahid-syed's avatar zahid-syed
Browse files

fix mistral ai llm issue

parent 09f0335b
No related branches found
No related tags found
No related merge requests found
......@@ -25,7 +25,7 @@ class MistralEncoder(BaseEncoder):
if name is None:
name = EncoderDefault.MISTRAL.value["embedding_model"]
super().__init__(name=name, score_threshold=score_threshold)
self._client, self._mistralai = self._initialize_client(mistralai_api_key)
self._client, self._mistralai = self._initialize_client(mistralai_api_key)
def _initialize_client(self, api_key):
try:
......@@ -67,7 +67,9 @@ class MistralEncoder(BaseEncoder):
if (
not embeds
or not isinstance(embeds, self._mistralai.models.embeddings.EmbeddingResponse)
or not isinstance(
embeds, self._mistralai.models.embeddings.EmbeddingResponse
)
or not embeds.data
):
raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
......
......@@ -14,6 +14,7 @@ class MistralAILLM(BaseLLM):
_client: Any = PrivateAttr()
temperature: Optional[float]
max_tokens: Optional[int]
_mistralai: Any = PrivateAttr()
def __init__(
self,
......@@ -25,12 +26,13 @@ class MistralAILLM(BaseLLM):
if name is None:
name = EncoderDefault.MISTRAL.value["language_model"]
super().__init__(name=name)
self._client = self._initialize_client(mistralai_api_key)
self._client, self._mistralai = self._initialize_client(mistralai_api_key)
self.temperature = temperature
self.max_tokens = max_tokens
def _initialize_client(self, api_key):
try:
import mistralai
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
......@@ -47,15 +49,22 @@ class MistralAILLM(BaseLLM):
raise ValueError(
f"MistralAI API client failed to initialize. Error: {e}"
) from e
return client
return client, mistralai
def __call__(self, messages: List[Message]) -> str:
if self._client is None:
raise ValueError("MistralAI client is not initialized.")
chat_messages = [
self._mistralai.models.chat_completion.ChatMessage(
role=m.role, content=m.content
)
for m in messages
]
try:
completion = self._client.chat(
model=self.name,
messages=[m.to_mistral() for m in messages],
messages=chat_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment