From e8f73a4b4791b55198788d295325887ad0c58641 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Mon, 3 Mar 2025 13:34:27 -0600
Subject: [PATCH] MistralAI MultiModal content blocks (#17997)

---
 .../llama_index/core/base/llms/types.py       |  15 +-
 .../llama_index/llms/mistralai/base.py        |  62 ++--
 .../llama_index/llms/mistralai/utils.py       |   4 +
 .../llama-index-llms-mistralai/pyproject.toml |   2 +-
 .../multi_modal_llms/mistralai/base.py        | 297 +++---------------
 .../multi_modal_llms/mistralai/utils.py       | 140 ---------
 .../pyproject.toml                            |   6 +-
 .../tests/test_multi-modal-llms_mistral.py    |   6 +-
 8 files changed, 106 insertions(+), 426 deletions(-)
 delete mode 100644 llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/utils.py

diff --git a/llama-index-core/llama_index/core/base/llms/types.py b/llama-index-core/llama_index/core/base/llms/types.py
index 6fa3863210..52c9f03323 100644
--- a/llama-index-core/llama_index/core/base/llms/types.py
+++ b/llama-index-core/llama_index/core/base/llms/types.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import base64
+import filetype
 from binascii import Error as BinasciiError
 from enum import Enum
 from io import BytesIO
+from pathlib import Path
 from typing import (
     Annotated,
     Any,
@@ -61,10 +63,13 @@ class ImageBlock(BaseModel):
 
     @field_validator("url", mode="after")
     @classmethod
-    def urlstr_to_anyurl(cls, url: str | AnyUrl) -> AnyUrl:
+    def urlstr_to_anyurl(cls, url: str | AnyUrl | None) -> AnyUrl | None:
         """Store the url as Anyurl."""
         if isinstance(url, AnyUrl):
             return url
+        if url is None:
+            return None
+
         return AnyUrl(url=url)
 
     @model_validator(mode="after")
@@ -76,6 +81,14 @@ class ImageBlock(BaseModel):
         operations, we won't load the path or the URL to guess the mimetype.
         """
         if not self.image:
+            if not self.image_mimetype:
+                path = self.path or self.url
+                if path:
+                    suffix = Path(str(path)).suffix.replace(".", "") or None
+                    mimetype = filetype.get_type(ext=suffix)
+                    if mimetype and str(mimetype.mime).startswith("image/"):
+                        self.image_mimetype = str(mimetype.mime)
+
             return self
 
         try:
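
A minimal sketch of the new mimetype inference above, assuming a hypothetical image URL (the block is illustrative and not part of the patch):

    from llama_index.core.base.llms.types import ImageBlock

    # No raw image bytes are provided, so the validator falls back to the
    # URL suffix: filetype.get_type(ext="png") resolves it to "image/png".
    block = ImageBlock(url="https://example.com/cat.png")  # hypothetical URL
    assert block.image_mimetype == "image/png"
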
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
index 822723ce53..929914b05c 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
@@ -6,11 +6,14 @@ from llama_index.core.base.llms.types import (
     ChatResponse,
     ChatResponseAsyncGen,
     ChatResponseGen,
+    ContentBlock,
     CompletionResponse,
     CompletionResponseAsyncGen,
     CompletionResponseGen,
     LLMMetadata,
     MessageRole,
+    TextBlock,
+    ImageBlock,
 )
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
@@ -43,6 +46,9 @@ from mistralai.models import (
     SystemMessage,
     ToolMessage,
     UserMessage,
+    TextChunk,
+    ImageURLChunk,
+    ContentChunk,
 )
 
 if TYPE_CHECKING:
@@ -53,24 +59,52 @@ DEFAULT_MISTRALAI_ENDPOINT = "https://api.mistral.ai"
 DEFAULT_MISTRALAI_MAX_TOKENS = 512
 
 
+def to_mistral_chunks(content_blocks: Sequence[ContentBlock]) -> Sequence[ContentChunk]:
+    content_chunks = []
+    for content_block in content_blocks:
+        if isinstance(content_block, TextBlock):
+            content_chunks.append(TextChunk(text=content_block.text))
+        elif isinstance(content_block, ImageBlock):
+            if content_block.url:
+                content_chunks.append(
+                    ImageURLChunk(image_url=str(content_block.url))
+                )
+            else:
+                base_64_str = (
+                    content_block.resolve_image(as_base64=True).read().decode("utf-8")
+                )
+                image_mimetype = content_block.image_mimetype
+                if not image_mimetype:
+                    raise ValueError(
+                        "Image mimetype not found in chat message image block"
+                    )
+
+                content_chunks.append(
+                    ImageURLChunk(
+                        image_url=f"data:{image_mimetype};base64,{base_64_str}"
+                    )
+                )
+        else:
+            raise ValueError(f"Unsupported content block type {type(content_block)}")
+
+    return content_chunks
+
+
 def to_mistral_chatmessage(
     messages: Sequence[ChatMessage],
 ) -> List[Messages]:
     new_messages = []
     for m in messages:
         tool_calls = m.additional_kwargs.get("tool_calls")
+        chunks = to_mistral_chunks(m.blocks)
         if m.role == MessageRole.USER:
-            new_messages.append(UserMessage(content=m.content))
+            new_messages.append(UserMessage(content=chunks))
         elif m.role == MessageRole.ASSISTANT:
-            new_messages.append(
-                AssistantMessage(content=m.content, tool_calls=tool_calls)
-            )
+            new_messages.append(AssistantMessage(content=chunks, tool_calls=tool_calls))
         elif m.role == MessageRole.SYSTEM:
-            new_messages.append(SystemMessage(content=m.content))
+            new_messages.append(SystemMessage(content=chunks))
         elif m.role == MessageRole.TOOL or m.role == MessageRole.FUNCTION:
             new_messages.append(
                 ToolMessage(
-                    content=m.content,
+                    content=chunks,
                     tool_call_id=m.additional_kwargs.get("tool_call_id"),
                     name=m.additional_kwargs.get("name"),
                 )
@@ -284,12 +318,8 @@ class MistralAI(FunctionCallingLLM):
                 if delta.tool_calls:
                     additional_kwargs["tool_calls"] = delta.tool_calls
 
-                content_delta = delta.content
-                if content_delta is None:
-                    pass
-                    # continue
-                else:
-                    content += content_delta
+                content_delta = delta.content or ""
+                content += content_delta
                 yield ChatResponse(
                     message=ChatMessage(
                         role=role,
@@ -360,12 +390,8 @@ class MistralAI(FunctionCallingLLM):
                 if delta.tool_calls:
                     additional_kwargs["tool_calls"] = delta.tool_calls
 
-                content_delta = delta.content
-                if content_delta is None:
-                    pass
-                    # continue
-                else:
-                    content += content_delta
+                content_delta = delta.content or ""
+                content += content_delta
                 yield ChatResponse(
                     message=ChatMessage(
                         role=role,
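
A minimal usage sketch of the new block-to-chunk conversion, assuming a hypothetical image URL (to_mistral_chatmessage applies the same conversion to every message's blocks):

    from llama_index.core.base.llms.types import ChatMessage, ImageBlock, TextBlock
    from llama_index.llms.mistralai.base import to_mistral_chunks

    msg = ChatMessage(
        role="user",
        blocks=[
            TextBlock(text="What is in this image?"),
            ImageBlock(url="https://example.com/cat.png"),  # hypothetical URL
        ],
    )
    # TextBlock -> TextChunk; ImageBlock with a URL -> ImageURLChunk; image
    # blocks without a URL are resolved to a base64 data URL instead.
    chunks = to_mistral_chunks(msg.blocks)
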
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py
index a4faed40f4..d8ab76766a 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py
@@ -16,6 +16,8 @@ MISTRALAI_MODELS: Dict[str, int] = {
     "open-mistral-nemo-latest": 131000,
     "ministral-8b-latest": 131000,
     "ministral-3b-latest": 131000,
+    "pixtral-large-latest": 131000,
+    "pixtral-12b-2409": 131000,
 }
 
 MISTRALAI_FUNCTION_CALLING_MODELS = (
@@ -26,6 +28,8 @@ MISTRALAI_FUNCTION_CALLING_MODELS = (
     "mistral-small-latest",
     "codestral-latest",
     "open-mistral-nemo-latest",
+    "pixtral-large-latest",
+    "pixtral-12b-2409",
 )
 
 MISTRALAI_CODE_MODELS = "codestral-latest"
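
A quick check of the new Pixtral entries, using the values added above:

    from llama_index.llms.mistralai.utils import (
        MISTRALAI_FUNCTION_CALLING_MODELS,
        MISTRALAI_MODELS,
    )

    # Pixtral models are registered with a 131k context window and are
    # treated as function-calling capable.
    assert MISTRALAI_MODELS["pixtral-large-latest"] == 131000
    assert "pixtral-12b-2409" in MISTRALAI_FUNCTION_CALLING_MODELS
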
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
index a999420fe1..5214ea13ec 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-mistralai"
 readme = "README.md"
-version = "0.3.3"
+version = "0.4.0"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/base.py b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/base.py
index 34b06d595f..5eea424bfc 100644
--- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/base.py
+++ b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/base.py
@@ -1,132 +1,41 @@
-from typing import Any, Callable, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Sequence
 
-import httpx
 from llama_index.core.base.llms.types import (
     CompletionResponse,
     CompletionResponseAsyncGen,
     CompletionResponseGen,
     MessageRole,
 )
-from llama_index.core.bridge.pydantic import Field, PrivateAttr
-from llama_index.core.callbacks import CallbackManager
-from llama_index.core.constants import (
-    DEFAULT_CONTEXT_WINDOW,
-    DEFAULT_NUM_OUTPUTS,
-    DEFAULT_TEMPERATURE,
-)
 from llama_index.core.base.llms.generic_utils import (
-    messages_to_prompt as generic_messages_to_prompt,
+    chat_response_to_completion_response,
+    stream_chat_response_to_completion_response,
+    astream_chat_response_to_completion_response,
 )
-from llama_index.core.multi_modal_llms import (
-    MultiModalLLM,
-    MultiModalLLMMetadata,
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    MessageRole,
+    TextBlock,
+    ImageBlock,
 )
 from llama_index.core.schema import ImageNode
-from llama_index.multi_modal_llms.mistralai.utils import (
-    MISTRALAI_MULTI_MODAL_MODELS,
-    generate_mistral_multi_modal_chat_message,
-    resolve_mistral_credentials,
-)
-
-from mistralai import Mistral
+from llama_index.llms.mistralai import MistralAI
 
 
-class MistralAIMultiModal(MultiModalLLM):
-    model: str = Field(description="The Multi-Modal model to use from Mistral.")
-    temperature: float = Field(description="The temperature to use for sampling.")
-    max_tokens: Optional[int] = Field(
-        description=" The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt",
-        gt=0,
-    )
-    context_window: Optional[int] = Field(
-        description="The maximum number of context tokens for the model.",
-        gt=0,
-    )
-    max_retries: int = Field(
-        default=3,
-        description="Maximum number of retries.",
-        ge=0,
-    )
-    timeout: float = Field(
-        default=60.0,
-        description="The timeout, in seconds, for API requests.",
-        ge=0,
-    )
-    api_key: str = Field(default=None, description="The Mistral API key.", exclude=True)
-    api_base: str = Field(default=None, description="The base URL for Mistral API.")
-    api_version: str = Field(description="The API version for Mistral API.")
-    additional_kwargs: Dict[str, Any] = Field(
-        default_factory=dict, description="Additional kwargs for the Mistral API."
-    )
-    default_headers: Optional[Dict[str, str]] = Field(
-        default=None, description="The default headers for API requests."
-    )
-
-    _messages_to_prompt: Callable = PrivateAttr()
-    _completion_to_prompt: Callable = PrivateAttr()
-    _client: Mistral = PrivateAttr()
-    _http_client: Optional[httpx.Client] = PrivateAttr()
-
+class MistralAIMultiModal(MistralAI):
     def __init__(
         self,
         model: str = "pixtral-12b-2409",
-        temperature: float = DEFAULT_TEMPERATURE,
-        max_tokens: Optional[int] = 300,
-        additional_kwargs: Optional[Dict[str, Any]] = None,
-        context_window: Optional[int] = DEFAULT_CONTEXT_WINDOW,
-        max_retries: int = 3,
-        timeout: float = 60.0,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-        api_version: Optional[str] = None,
-        messages_to_prompt: Optional[Callable] = None,
-        completion_to_prompt: Optional[Callable] = None,
-        callback_manager: Optional[CallbackManager] = None,
-        default_headers: Optional[Dict[str, str]] = None,
-        http_client: Optional[httpx.Client] = None,
         **kwargs: Any,
     ) -> None:
-        api_key, api_base, api_version = resolve_mistral_credentials(
-            api_key=api_key,
-            api_base=api_base,
-            api_version=api_version,
-        )
-
         super().__init__(
             model=model,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            additional_kwargs=additional_kwargs or {},
-            context_window=context_window,
-            max_retries=max_retries,
-            timeout=timeout,
-            api_key=api_key,
-            api_base=api_base,
-            api_version=api_version,
-            callback_manager=callback_manager,
-            default_headers=default_headers,
             **kwargs,
         )
-        self._messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
-        self._completion_to_prompt = completion_to_prompt or (lambda x: x)
-        self._http_client = http_client
-        self._client = self._get_clients(**kwargs)
-
-    def _get_clients(self, **kwargs: Any) -> Mistral:
-        return Mistral(**self._get_credential_kwargs())
 
     @classmethod
     def class_name(cls) -> str:
         return "mistral_multi_modal_llm"
 
-    @property
-    def metadata(self) -> MultiModalLLMMetadata:
-        """Multi Modal LLM metadata."""
-        return MultiModalLLMMetadata(
-            num_output=self.max_tokens or DEFAULT_NUM_OUTPUTS,
-            model_name=self.model,
-        )
-
     def _get_credential_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
         return {
             "api_key": self.api_key,
@@ -138,183 +47,53 @@ class MistralAIMultiModal(MultiModalLLM):
         prompt: str,
         role: str,
         image_documents: Sequence[ImageNode],
-        **kwargs: Any,
-    ) -> List[Dict]:
-        return generate_mistral_multi_modal_chat_message(
-            prompt=prompt,
-            role=role,
-            image_documents=image_documents,
-        )
-
-    # Model Params for Mistral Multi Modal model.
-    def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
-        if self.model not in MISTRALAI_MULTI_MODAL_MODELS:
-            raise ValueError(
-                f"Invalid model {self.model}. "
-                f"Available models are: {list(MISTRALAI_MULTI_MODAL_MODELS.keys())}"
+    ) -> List[ChatMessage]:
+        blocks = []
+        for image_document in image_documents:
+            blocks.append(
+                ImageBlock(
+                    image=image_document.image,
+                    path=image_document.image_path,
+                    url=image_document.image_url,
+                    image_mimetype=image_document.image_mimetype,
+                )
             )
-        base_kwargs = {"model": self.model, "temperature": self.temperature, **kwargs}
-        if self.max_tokens is not None:
-            base_kwargs["max_tokens"] = self.max_tokens
-        return {**base_kwargs, **self.additional_kwargs}
-
-    def _get_response_token_counts(self, raw_response: Any) -> dict:
-        """Get the token usage reported by the response."""
-        if not isinstance(raw_response, dict):
-            return {}
 
-        usage = raw_response.get("usage", {})
-        # NOTE: other model providers that use the mistral client may not report usage
-        if usage is None:
-            return {}
+        blocks.append(TextBlock(text=prompt))
+        return [ChatMessage(role=role, blocks=blocks)]
 
-        return {
-            "prompt_tokens": usage.get("prompt_tokens", 0),
-            "completion_tokens": usage.get("completion_tokens", 0),
-            "total_tokens": usage.get("total_tokens", 0),
-        }
-
-    def _complete(
+    def complete(
         self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
     ) -> CompletionResponse:
-        all_kwargs = self._get_model_kwargs(**kwargs)
-        message_dict = self._get_multi_modal_chat_messages(
+        messages = self._get_multi_modal_chat_messages(
             prompt=prompt, role=MessageRole.USER.value, image_documents=image_documents
         )
-
-        response = self._client.chat.complete(
-            messages=message_dict,
-            stream=False,
-            **all_kwargs,
-        )
-
-        return CompletionResponse(
-            text=response.choices[0].message.content,
-            raw=response,
-            additional_kwargs=self._get_response_token_counts(response),
-        )
-
-    def _stream_complete(
-        self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
-    ) -> CompletionResponseGen:
-        all_kwargs = self._get_model_kwargs(**kwargs)
-        message_dict = self._get_multi_modal_chat_messages(
-            prompt=prompt, role=MessageRole.USER.value, image_documents=image_documents
-        )
-
-        response = self._client.chat.stream(messages=message_dict, **all_kwargs)
-
-        def gen() -> CompletionResponseGen:
-            content = ""
-            for chunk in response:
-                delta = chunk.data.choices[0].delta
-                role = delta.role or MessageRole.ASSISTANT.value
-
-                content_delta = delta.content or ""
-                if content_delta is None:
-                    pass
-                    # continue
-                else:
-                    content += content_delta
-
-                yield CompletionResponse(
-                    delta=content_delta,
-                    text=content,
-                    raw=response,
-                    additional_kwargs=self._get_response_token_counts(response),
-                )
-
-        return gen()
-
-    def complete(
-        self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
-    ) -> CompletionResponse:
-        return self._complete(prompt, image_documents, **kwargs)
+        chat_response = self.chat(messages=messages, **kwargs)
+        return chat_response_to_completion_response(chat_response)
 
     def stream_complete(
         self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
     ) -> CompletionResponseGen:
-        return self._stream_complete(prompt, image_documents, **kwargs)
-
-    def chat(
-        self,
-        **kwargs: Any,
-    ) -> Any:
-        raise NotImplementedError("This function is not yet implemented.")
-
-    def stream_chat(
-        self,
-        **kwargs: Any,
-    ) -> Any:
-        raise NotImplementedError("This function is not yet implemented.")
-
-    # ===== Async Endpoints =====
-
-    async def _acomplete(
-        self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
-    ) -> CompletionResponse:
-        all_kwargs = self._get_model_kwargs(**kwargs)
-        message_dict = self._get_multi_modal_chat_messages(
+        messages = self._get_multi_modal_chat_messages(
             prompt=prompt, role=MessageRole.USER.value, image_documents=image_documents
         )
-
-        response = await self._client.chat.complete_async(
-            messages=message_dict,
-            stream=False,
-            **all_kwargs,
-        )
-
-        return CompletionResponse(
-            text=response.choices[0].message.content,
-            raw=response,
-            additional_kwargs=self._get_response_token_counts(response),
-        )
+        chat_response = self.stream_chat(messages=messages, **kwargs)
+        return stream_chat_response_to_completion_response(chat_response)
 
     async def acomplete(
         self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
     ) -> CompletionResponse:
-        return await self._acomplete(prompt, image_documents, **kwargs)
-
-    async def _astream_complete(
-        self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
-    ) -> CompletionResponseAsyncGen:
-        all_kwargs = self._get_model_kwargs(**kwargs)
-        message_dict = self._get_multi_modal_chat_messages(
+        messages = self._get_multi_modal_chat_messages(
             prompt=prompt, role=MessageRole.USER.value, image_documents=image_documents
         )
-
-        response = await self._client.chat.stream_async(
-            messages=message_dict, **all_kwargs
-        )
-
-        async def gen() -> CompletionResponseAsyncGen:
-            content = ""
-            async for chunk in response:
-                delta = chunk.data.choices[0].delta
-                role = delta.role or MessageRole.ASSISTANT.value
-
-                content_delta = delta.content
-                if content_delta is None:
-                    pass
-                    # continue
-                else:
-                    content += content_delta
-                yield CompletionResponse(
-                    delta=content_delta,
-                    text=content,
-                    raw=response,
-                    additional_kwargs=self._get_response_token_counts(response),
-                )
-
-        return gen()
+        chat_response = await self.achat(messages=messages, **kwargs)
+        return chat_response_to_completion_response(chat_response)
 
     async def astream_complete(
         self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
     ) -> CompletionResponseAsyncGen:
-        return await self._astream_complete(prompt, image_documents, **kwargs)
-
-    async def achat(self, **kwargs: Any) -> Any:
-        raise NotImplementedError("This function is not yet implemented.")
-
-    async def astream_chat(self, **kwargs: Any) -> Any:
-        raise NotImplementedError("This function is not yet implemented.")
+        messages = self._get_multi_modal_chat_messages(
+            prompt=prompt, role=MessageRole.USER.value, image_documents=image_documents
+        )
+        chat_response = await self.astream_chat(messages=messages, **kwargs)
+        return astream_chat_response_to_completion_response(chat_response)
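
A minimal sketch of the slimmed-down multimodal class in use, assuming a valid MISTRAL_API_KEY in the environment and a hypothetical image URL:

    from llama_index.core.schema import ImageDocument
    from llama_index.multi_modal_llms.mistralai import MistralAIMultiModal

    llm = MistralAIMultiModal(model="pixtral-12b-2409", max_tokens=300)
    image_docs = [
        ImageDocument(image_url="https://example.com/cat.png")  # hypothetical URL
    ]

    # complete() now builds a ChatMessage from ImageBlock/TextBlock contents and
    # delegates to MistralAI.chat(), converting the result to a CompletionResponse.
    response = llm.complete(prompt="Describe the image.", image_documents=image_docs)
    print(response.text)
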
diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/utils.py b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/utils.py
deleted file mode 100644
index e71b72b07d..0000000000
--- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/utils.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import base64
-import logging
-from typing import Any, Dict, List, Optional, Sequence, Tuple
-import filetype
-
-from llama_index.core.base.llms.generic_utils import get_from_param_or_env
-from llama_index.core.multi_modal_llms.generic_utils import encode_image
-from llama_index.core.schema import ImageDocument
-
-DEFAULT_MISTRALAI_API_TYPE = "mistral_ai"
-DEFAULT_MISTRALAI_API_BASE = "https://api.mistral.ai/"
-DEFAULT_MISTRALAI_API_VERSION = ""
-
-
-MISTRALAI_MULTI_MODAL_MODELS = {
-    "pixtral-12b-2409": 128000,
-    "pixtral-large-latest": 128000,
-}
-
-
-MISSING_API_KEY_ERROR_MESSAGE = """No API key found for Mistral.
-Please set either the MISTRAL_API_KEY environment variable \
-API keys can be found or created at \
-https://console.mistral.ai/api-keys/
-"""
-
-logger = logging.getLogger(__name__)
-
-
-def infer_image_mimetype_from_base64(base64_string) -> str:
-    # Decode the base64 string
-    decoded_data = base64.b64decode(base64_string)
-
-    # Use filetype to guess the MIME type
-    kind = filetype.guess(decoded_data)
-
-    # Return the MIME type if detected, otherwise return None
-    return kind.mime if kind is not None else None
-
-
-def infer_image_mimetype_from_file_path(image_file_path: str) -> str:
-    # Get the file extension
-    file_extension = image_file_path.split(".")[-1].lower()
-
-    # Map file extensions to mimetypes
-    # Pixtral support the base64 source type for images, and the image/jpeg, image/png, image/gif, and image/webp media types.
-    # https://docs.mistral.ai/capabilities/vision/
-    if file_extension == "jpg" or file_extension == "jpeg":
-        return "image/jpeg"
-    elif file_extension == "png":
-        return "image/png"
-    elif file_extension == "gif":
-        return "image/gif"
-    elif file_extension == "webp":
-        return "image/webp"
-    # Add more mappings for other image types if needed
-
-    # If the file extension is not recognized
-    return "image/jpeg"
-
-
-def generate_mistral_multi_modal_chat_message(
-    prompt: str,
-    role: str,
-    image_documents: Optional[Sequence[ImageDocument]] = None,
-) -> List[Dict[str, Any]]:
-    # if image_documents is empty, return text only chat message
-    if image_documents is None:
-        return [{"role": role, "content": prompt}]
-
-    # if image_documents is not empty, return text with images chat message
-    completion_content = []
-    for image_document in image_documents:
-        image_content: Dict[str, Any] = {}
-        if image_document.image_path and image_document.image_path != "":
-            mimetype = infer_image_mimetype_from_file_path(image_document.image_path)
-            base64_image = encode_image(image_document.image_path)
-            image_content = {
-                "type": "image_url",
-                "image_url": f"data:{mimetype};base64,{base64_image}",
-            }
-        elif (
-            "file_path" in image_document.metadata
-            and image_document.metadata["file_path"] != ""
-        ):
-            mimetype = infer_image_mimetype_from_file_path(
-                image_document.metadata["file_path"]
-            )
-            base64_image = encode_image(image_document.metadata["file_path"])
-            image_content = {
-                "type": "image_url",
-                "image_url": f"data:{mimetype};base64,{base64_image}",
-            }
-        elif image_document.image_url and image_document.image_url != "":
-            mimetype = infer_image_mimetype_from_file_path(image_document.image_url)
-            image_content = {
-                "type": "image_url",
-                "image_url": f"{image_document.image_url}",
-            }
-        elif image_document.image != "":
-            base64_image = image_document.image
-            mimetype = infer_image_mimetype_from_base64(base64_image)
-            image_content = {
-                "type": "image_url",
-                "image_url": f"data:{mimetype};base64,{base64_image}",
-            }
-        completion_content.append(image_content)
-
-    completion_content.append({"type": "text", "text": prompt})
-
-    return [{"role": role, "content": completion_content}]
-
-
-def resolve_mistral_credentials(
-    api_key: Optional[str] = None,
-    api_base: Optional[str] = None,
-    api_version: Optional[str] = None,
-) -> Tuple[Optional[str], str, str]:
-    """
-    "Resolve Mistral credentials.
-
-    The order of precedence is:
-    1. param
-    2. env
-    3. mistral module
-    4. default
-    """
-    # resolve from param or env
-    api_key = get_from_param_or_env("api_key", api_key, "MISTRAL_API_KEY", "")
-    api_base = get_from_param_or_env("api_base", api_base, "MISTRAL_API_BASE", "")
-    api_version = get_from_param_or_env(
-        "api_version", api_version, "MISTRAL_API_VERSION", ""
-    )
-
-    # resolve from Mistral module or default
-    final_api_key = api_key or ""
-    final_api_base = api_base or DEFAULT_MISTRALAI_API_BASE
-    final_api_version = api_version or DEFAULT_MISTRALAI_API_VERSION
-
-    return final_api_key, str(final_api_base), final_api_version
diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/pyproject.toml b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/pyproject.toml
index 4c471028ed..118a769aec 100644
--- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/pyproject.toml
+++ b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/pyproject.toml
@@ -27,13 +27,11 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-multi-modal-llms-mistralai"
 readme = "README.md"
-version = "0.3.1"
+version = "0.4.0"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-mistralai = ">=1.2.3"
-llama-index-core = "^0.12.0"
-filetype = "^1.2.0"
+llama-index-llms-mistralai = "^0.4.0"
 
 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"
diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/tests/test_multi-modal-llms_mistral.py b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/tests/test_multi-modal-llms_mistral.py
index 156d862b08..16538e7dad 100644
--- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/tests/test_multi-modal-llms_mistral.py
+++ b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/tests/test_multi-modal-llms_mistral.py
@@ -1,12 +1,12 @@
-from llama_index.core.multi_modal_llms.base import MultiModalLLM
+from llama_index.llms.mistralai import MistralAI
 from llama_index.multi_modal_llms.mistralai import MistralAIMultiModal
 
 
 def test_embedding_class():
     names_of_base_classes = [b.__name__ for b in MistralAIMultiModal.__mro__]
-    assert MultiModalLLM.__name__ in names_of_base_classes
+    assert MistralAI.__name__ in names_of_base_classes
 
 
 def test_init():
-    m = MistralAIMultiModal(max_tokens=400)
+    m = MistralAIMultiModal(max_tokens=400, api_key="test")
     assert m.max_tokens == 400
-- 
GitLab