From e8f73a4b4791b55198788d295325887ad0c58641 Mon Sep 17 00:00:00 2001 From: Logan <logan.markewich@live.com> Date: Mon, 3 Mar 2025 13:34:27 -0600 Subject: [PATCH] MistralAI MultImodal content blocks (#17997) --- .../llama_index/core/base/llms/types.py | 15 +- .../llama_index/llms/mistralai/base.py | 62 ++-- .../llama_index/llms/mistralai/utils.py | 4 + .../llama-index-llms-mistralai/pyproject.toml | 2 +- .../multi_modal_llms/mistralai/base.py | 297 +++--------------- .../multi_modal_llms/mistralai/utils.py | 140 --------- .../pyproject.toml | 6 +- .../tests/test_multi-modal-llms_mistral.py | 6 +- 8 files changed, 106 insertions(+), 426 deletions(-) delete mode 100644 llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/utils.py diff --git a/llama-index-core/llama_index/core/base/llms/types.py b/llama-index-core/llama_index/core/base/llms/types.py index 6fa3863210..52c9f03323 100644 --- a/llama-index-core/llama_index/core/base/llms/types.py +++ b/llama-index-core/llama_index/core/base/llms/types.py @@ -1,9 +1,11 @@ from __future__ import annotations import base64 +import filetype from binascii import Error as BinasciiError from enum import Enum from io import BytesIO +from pathlib import Path from typing import ( Annotated, Any, @@ -61,10 +63,13 @@ class ImageBlock(BaseModel): @field_validator("url", mode="after") @classmethod - def urlstr_to_anyurl(cls, url: str | AnyUrl) -> AnyUrl: + def urlstr_to_anyurl(cls, url: str | AnyUrl | None) -> AnyUrl | None: """Store the url as Anyurl.""" if isinstance(url, AnyUrl): return url + if url is None: + return None + return AnyUrl(url=url) @model_validator(mode="after") @@ -76,6 +81,14 @@ class ImageBlock(BaseModel): operations, we won't load the path or the URL to guess the mimetype. 
""" if not self.image: + if not self.image_mimetype: + path = self.path or self.url + if path: + suffix = Path(str(path)).suffix.replace(".", "") or None + mimetype = filetype.get_type(ext=suffix) + if mimetype and str(mimetype.mime).startswith("image/"): + self.image_mimetype = str(mimetype.mime) + return self try: diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py index 822723ce53..929914b05c 100644 --- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py +++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py @@ -6,11 +6,14 @@ from llama_index.core.base.llms.types import ( ChatResponse, ChatResponseAsyncGen, ChatResponseGen, + ContentBlock, CompletionResponse, CompletionResponseAsyncGen, CompletionResponseGen, LLMMetadata, MessageRole, + TextBlock, + ImageBlock, ) from llama_index.core.bridge.pydantic import Field, PrivateAttr from llama_index.core.callbacks import CallbackManager @@ -43,6 +46,9 @@ from mistralai.models import ( SystemMessage, ToolMessage, UserMessage, + TextChunk, + ImageURLChunk, + ContentChunk, ) if TYPE_CHECKING: @@ -53,24 +59,52 @@ DEFAULT_MISTRALAI_ENDPOINT = "https://api.mistral.ai" DEFAULT_MISTRALAI_MAX_TOKENS = 512 +def to_mistral_chunks(content_blocks: Sequence[ContentBlock]) -> Sequence[ContentChunk]: + content_chunks = [] + for content_block in content_blocks: + if isinstance(content_block, TextBlock): + content_chunks.append(TextChunk(text=content_block.text)) + elif isinstance(content_block, ImageBlock): + if content_block.url: + content_chunks.append(ImageURLChunk(url=content_block.url)) + else: + base_64_str = ( + content_block.resolve_image(as_base64=True).read().decode("utf-8") + ) + image_mimetype = content_block.image_mimetype + if not image_mimetype: + raise ValueError( + "Image mimetype not found in chat message image block" + ) + + content_chunks.append( + ImageURLChunk( + image_url=f"data:{image_mimetype};base64,{base_64_str}" + ) + ) + else: + raise ValueError(f"Unsupported content block type {type(content_block)}") + + return content_chunks + + def to_mistral_chatmessage( messages: Sequence[ChatMessage], ) -> List[Messages]: new_messages = [] for m in messages: tool_calls = m.additional_kwargs.get("tool_calls") + chunks = to_mistral_chunks(m.blocks) if m.role == MessageRole.USER: - new_messages.append(UserMessage(content=m.content)) + new_messages.append(UserMessage(content=chunks)) elif m.role == MessageRole.ASSISTANT: - new_messages.append( - AssistantMessage(content=m.content, tool_calls=tool_calls) - ) + new_messages.append(AssistantMessage(content=chunks, tool_calls=tool_calls)) elif m.role == MessageRole.SYSTEM: - new_messages.append(SystemMessage(content=m.content)) + new_messages.append(SystemMessage(content=chunks)) elif m.role == MessageRole.TOOL or m.role == MessageRole.FUNCTION: new_messages.append( ToolMessage( - content=m.content, + content=chunks, tool_call_id=m.additional_kwargs.get("tool_call_id"), name=m.additional_kwargs.get("name"), ) @@ -284,12 +318,8 @@ class MistralAI(FunctionCallingLLM): if delta.tool_calls: additional_kwargs["tool_calls"] = delta.tool_calls - content_delta = delta.content - if content_delta is None: - pass - # continue - else: - content += content_delta + content_delta = delta.content or "" + content += content_delta yield ChatResponse( message=ChatMessage( role=role, @@ 
-360,12 +390,8 @@ class MistralAI(FunctionCallingLLM): if delta.tool_calls: additional_kwargs["tool_calls"] = delta.tool_calls - content_delta = delta.content - if content_delta is None: - pass - # continue - else: - content += content_delta + content_delta = delta.content or "" + content += content_delta yield ChatResponse( message=ChatMessage( role=role, diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py index a4faed40f4..d8ab76766a 100644 --- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/utils.py @@ -16,6 +16,8 @@ MISTRALAI_MODELS: Dict[str, int] = { "open-mistral-nemo-latest": 131000, "ministral-8b-latest": 131000, "ministral-3b-latest": 131000, + "pixtral-large-latest": 131000, + "pixtral-12b-2409": 131000, } MISTRALAI_FUNCTION_CALLING_MODELS = ( @@ -26,6 +28,8 @@ MISTRALAI_FUNCTION_CALLING_MODELS = ( "mistral-small-latest", "codestral-latest", "open-mistral-nemo-latest", + "pixtral-large-latest", + "pixtral-12b-2409", ) MISTRALAI_CODE_MODELS = "codestral-latest" diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml index a999420fe1..5214ea13ec 100644 --- a/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-mistralai" readme = "README.md" -version = "0.3.3" +version = "0.4.0" [tool.poetry.dependencies] python = ">=3.9,<4.0" diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/base.py b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/base.py index 34b06d595f..5eea424bfc 100644 --- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/base.py +++ b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/base.py @@ -1,132 +1,41 @@ -from typing import Any, Callable, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Sequence -import httpx from llama_index.core.base.llms.types import ( CompletionResponse, CompletionResponseAsyncGen, CompletionResponseGen, MessageRole, ) -from llama_index.core.bridge.pydantic import Field, PrivateAttr -from llama_index.core.callbacks import CallbackManager -from llama_index.core.constants import ( - DEFAULT_CONTEXT_WINDOW, - DEFAULT_NUM_OUTPUTS, - DEFAULT_TEMPERATURE, -) from llama_index.core.base.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, + chat_response_to_completion_response, + stream_chat_response_to_completion_response, + astream_chat_response_to_completion_response, ) -from llama_index.core.multi_modal_llms import ( - MultiModalLLM, - MultiModalLLMMetadata, +from llama_index.core.base.llms.types import ( + ChatMessage, + MessageRole, + TextBlock, + ImageBlock, ) from llama_index.core.schema import ImageNode -from llama_index.multi_modal_llms.mistralai.utils import ( - MISTRALAI_MULTI_MODAL_MODELS, - generate_mistral_multi_modal_chat_message, - 
resolve_mistral_credentials, -) - -from mistralai import Mistral +from llama_index.llms.mistralai import MistralAI -class MistralAIMultiModal(MultiModalLLM): - model: str = Field(description="The Multi-Modal model to use from Mistral.") - temperature: float = Field(description="The temperature to use for sampling.") - max_tokens: Optional[int] = Field( - description=" The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt", - gt=0, - ) - context_window: Optional[int] = Field( - description="The maximum number of context tokens for the model.", - gt=0, - ) - max_retries: int = Field( - default=3, - description="Maximum number of retries.", - ge=0, - ) - timeout: float = Field( - default=60.0, - description="The timeout, in seconds, for API requests.", - ge=0, - ) - api_key: str = Field(default=None, description="The Mistral API key.", exclude=True) - api_base: str = Field(default=None, description="The base URL for Mistral API.") - api_version: str = Field(description="The API version for Mistral API.") - additional_kwargs: Dict[str, Any] = Field( - default_factory=dict, description="Additional kwargs for the Mistral API." - ) - default_headers: Optional[Dict[str, str]] = Field( - default=None, description="The default headers for API requests." - ) - - _messages_to_prompt: Callable = PrivateAttr() - _completion_to_prompt: Callable = PrivateAttr() - _client: Mistral = PrivateAttr() - _http_client: Optional[httpx.Client] = PrivateAttr() - +class MistralAIMultiModal(MistralAI): def __init__( self, model: str = "pixtral-12b-2409", - temperature: float = DEFAULT_TEMPERATURE, - max_tokens: Optional[int] = 300, - additional_kwargs: Optional[Dict[str, Any]] = None, - context_window: Optional[int] = DEFAULT_CONTEXT_WINDOW, - max_retries: int = 3, - timeout: float = 60.0, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - messages_to_prompt: Optional[Callable] = None, - completion_to_prompt: Optional[Callable] = None, - callback_manager: Optional[CallbackManager] = None, - default_headers: Optional[Dict[str, str]] = None, - http_client: Optional[httpx.Client] = None, **kwargs: Any, ) -> None: - api_key, api_base, api_version = resolve_mistral_credentials( - api_key=api_key, - api_base=api_base, - api_version=api_version, - ) - super().__init__( model=model, - temperature=temperature, - max_tokens=max_tokens, - additional_kwargs=additional_kwargs or {}, - context_window=context_window, - max_retries=max_retries, - timeout=timeout, - api_key=api_key, - api_base=api_base, - api_version=api_version, - callback_manager=callback_manager, - default_headers=default_headers, **kwargs, ) - self._messages_to_prompt = messages_to_prompt or generic_messages_to_prompt - self._completion_to_prompt = completion_to_prompt or (lambda x: x) - self._http_client = http_client - self._client = self._get_clients(**kwargs) - - def _get_clients(self, **kwargs: Any) -> Mistral: - return Mistral(**self._get_credential_kwargs()) @classmethod def class_name(cls) -> str: return "mistral_multi_modal_llm" - @property - def metadata(self) -> MultiModalLLMMetadata: - """Multi Modal LLM metadata.""" - return MultiModalLLMMetadata( - num_output=self.max_tokens or DEFAULT_NUM_OUTPUTS, - model_name=self.model, - ) - def _get_credential_kwargs(self, **kwargs: Any) -> Dict[str, Any]: return { "api_key": self.api_key, @@ -138,183 +47,53 @@ class MistralAIMultiModal(MultiModalLLM): prompt: str, role: str, image_documents: Sequence[ImageNode], - **kwargs: 
Any, - ) -> List[Dict]: - return generate_mistral_multi_modal_chat_message( - prompt=prompt, - role=role, - image_documents=image_documents, - ) - - # Model Params for Mistral Multi Modal model. - def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]: - if self.model not in MISTRALAI_MULTI_MODAL_MODELS: - raise ValueError( - f"Invalid model {self.model}. " - f"Available models are: {list(MISTRALAI_MULTI_MODAL_MODELS.keys())}" + ) -> List[ChatMessage]: + blocks = [] + for image_document in image_documents: + blocks.append( + ImageBlock( + image=image_document.image, + path=image_document.image_path, + url=image_document.image_url, + image_mimetype=image_document.image_mimetype, + ) ) - base_kwargs = {"model": self.model, "temperature": self.temperature, **kwargs} - if self.max_tokens is not None: - base_kwargs["max_tokens"] = self.max_tokens - return {**base_kwargs, **self.additional_kwargs} - - def _get_response_token_counts(self, raw_response: Any) -> dict: - """Get the token usage reported by the response.""" - if not isinstance(raw_response, dict): - return {} - usage = raw_response.get("usage", {}) - # NOTE: other model providers that use the mistral client may not report usage - if usage is None: - return {} + blocks.append(TextBlock(text=prompt)) + return [ChatMessage(role=role, blocks=blocks)] - return { - "prompt_tokens": usage.get("prompt_tokens", 0), - "completion_tokens": usage.get("completion_tokens", 0), - "total_tokens": usage.get("total_tokens", 0), - } - - def _complete( + def complete( self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any ) -> CompletionResponse: - all_kwargs = self._get_model_kwargs(**kwargs) - message_dict = self._get_multi_modal_chat_messages( + messages = self._get_multi_modal_chat_messages( prompt=prompt, role=MessageRole.USER.value, image_documents=image_documents ) - - response = self._client.chat.complete( - messages=message_dict, - stream=False, - **all_kwargs, - ) - - return CompletionResponse( - text=response.choices[0].message.content, - raw=response, - additional_kwargs=self._get_response_token_counts(response), - ) - - def _stream_complete( - self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any - ) -> CompletionResponseGen: - all_kwargs = self._get_model_kwargs(**kwargs) - message_dict = self._get_multi_modal_chat_messages( - prompt=prompt, role=MessageRole.USER.value, image_documents=image_documents - ) - - response = self._client.chat.stream(messages=message_dict, **all_kwargs) - - def gen() -> CompletionResponseGen: - content = "" - for chunk in response: - delta = chunk.data.choices[0].delta - role = delta.role or MessageRole.ASSISTANT.value - - content_delta = delta.content or "" - if content_delta is None: - pass - # continue - else: - content += content_delta - - yield CompletionResponse( - delta=content_delta, - text=content, - raw=response, - additional_kwargs=self._get_response_token_counts(response), - ) - - return gen() - - def complete( - self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any - ) -> CompletionResponse: - return self._complete(prompt, image_documents, **kwargs) + chat_response = self.chat(messages=messages, **kwargs) + return chat_response_to_completion_response(chat_response) def stream_complete( self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any ) -> CompletionResponseGen: - return self._stream_complete(prompt, image_documents, **kwargs) - - def chat( - self, - **kwargs: Any, - ) -> Any: - raise NotImplementedError("This function is 
not yet implemented.") - - def stream_chat( - self, - **kwargs: Any, - ) -> Any: - raise NotImplementedError("This function is not yet implemented.") - - # ===== Async Endpoints ===== - - async def _acomplete( - self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any - ) -> CompletionResponse: - all_kwargs = self._get_model_kwargs(**kwargs) - message_dict = self._get_multi_modal_chat_messages( + messages = self._get_multi_modal_chat_messages( prompt=prompt, role=MessageRole.USER.value, image_documents=image_documents ) - - response = await self._client.chat.complete_async( - messages=message_dict, - stream=False, - **all_kwargs, - ) - - return CompletionResponse( - text=response.choices[0].message.content, - raw=response, - additional_kwargs=self._get_response_token_counts(response), - ) + chat_response = self.stream_chat(messages=messages, **kwargs) + return stream_chat_response_to_completion_response(chat_response) async def acomplete( self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any ) -> CompletionResponse: - return await self._acomplete(prompt, image_documents, **kwargs) - - async def _astream_complete( - self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any - ) -> CompletionResponseAsyncGen: - all_kwargs = self._get_model_kwargs(**kwargs) - message_dict = self._get_multi_modal_chat_messages( + messages = self._get_multi_modal_chat_messages( prompt=prompt, role=MessageRole.USER.value, image_documents=image_documents ) - - response = await self._client.chat.stream_async( - messages=message_dict, **all_kwargs - ) - - async def gen() -> CompletionResponseAsyncGen: - content = "" - async for chunk in response: - delta = chunk.data.choices[0].delta - role = delta.role or MessageRole.ASSISTANT.value - - content_delta = delta.content - if content_delta is None: - pass - # continue - else: - content += content_delta - yield CompletionResponse( - delta=content_delta, - text=content, - raw=response, - additional_kwargs=self._get_response_token_counts(response), - ) - - return gen() + chat_response = await self.achat(messages=messages, **kwargs) + return chat_response_to_completion_response(chat_response) async def astream_complete( self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any ) -> CompletionResponseAsyncGen: - return await self._astream_complete(prompt, image_documents, **kwargs) - - async def achat(self, **kwargs: Any) -> Any: - raise NotImplementedError("This function is not yet implemented.") - - async def astream_chat(self, **kwargs: Any) -> Any: - raise NotImplementedError("This function is not yet implemented.") + messages = self._get_multi_modal_chat_messages( + prompt=prompt, role=MessageRole.USER.value, image_documents=image_documents + ) + chat_response = await self.astream_chat(messages=messages, **kwargs) + return astream_chat_response_to_completion_response(chat_response) diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/utils.py b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/utils.py deleted file mode 100644 index e71b72b07d..0000000000 --- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/llama_index/multi_modal_llms/mistralai/utils.py +++ /dev/null @@ -1,140 +0,0 @@ -import base64 -import logging -from typing import Any, Dict, List, Optional, Sequence, Tuple -import filetype - -from 
llama_index.core.base.llms.generic_utils import get_from_param_or_env -from llama_index.core.multi_modal_llms.generic_utils import encode_image -from llama_index.core.schema import ImageDocument - -DEFAULT_MISTRALAI_API_TYPE = "mistral_ai" -DEFAULT_MISTRALAI_API_BASE = "https://api.mistral.ai/" -DEFAULT_MISTRALAI_API_VERSION = "" - - -MISTRALAI_MULTI_MODAL_MODELS = { - "pixtral-12b-2409": 128000, - "pixtral-large-latest": 128000, -} - - -MISSING_API_KEY_ERROR_MESSAGE = """No API key found for Mistral. -Please set either the MISTRAL_API_KEY environment variable \ -API keys can be found or created at \ -https://console.mistral.ai/api-keys/ -""" - -logger = logging.getLogger(__name__) - - -def infer_image_mimetype_from_base64(base64_string) -> str: - # Decode the base64 string - decoded_data = base64.b64decode(base64_string) - - # Use filetype to guess the MIME type - kind = filetype.guess(decoded_data) - - # Return the MIME type if detected, otherwise return None - return kind.mime if kind is not None else None - - -def infer_image_mimetype_from_file_path(image_file_path: str) -> str: - # Get the file extension - file_extension = image_file_path.split(".")[-1].lower() - - # Map file extensions to mimetypes - # Pixtral support the base64 source type for images, and the image/jpeg, image/png, image/gif, and image/webp media types. - # https://docs.mistral.ai/capabilities/vision/ - if file_extension == "jpg" or file_extension == "jpeg": - return "image/jpeg" - elif file_extension == "png": - return "image/png" - elif file_extension == "gif": - return "image/gif" - elif file_extension == "webp": - return "image/webp" - # Add more mappings for other image types if needed - - # If the file extension is not recognized - return "image/jpeg" - - -def generate_mistral_multi_modal_chat_message( - prompt: str, - role: str, - image_documents: Optional[Sequence[ImageDocument]] = None, -) -> List[Dict[str, Any]]: - # if image_documents is empty, return text only chat message - if image_documents is None: - return [{"role": role, "content": prompt}] - - # if image_documents is not empty, return text with images chat message - completion_content = [] - for image_document in image_documents: - image_content: Dict[str, Any] = {} - if image_document.image_path and image_document.image_path != "": - mimetype = infer_image_mimetype_from_file_path(image_document.image_path) - base64_image = encode_image(image_document.image_path) - image_content = { - "type": "image_url", - "image_url": f"data:{mimetype};base64,{base64_image}", - } - elif ( - "file_path" in image_document.metadata - and image_document.metadata["file_path"] != "" - ): - mimetype = infer_image_mimetype_from_file_path( - image_document.metadata["file_path"] - ) - base64_image = encode_image(image_document.metadata["file_path"]) - image_content = { - "type": "image_url", - "image_url": f"data:{mimetype};base64,{base64_image}", - } - elif image_document.image_url and image_document.image_url != "": - mimetype = infer_image_mimetype_from_file_path(image_document.image_url) - image_content = { - "type": "image_url", - "image_url": f"{image_document.image_url}", - } - elif image_document.image != "": - base64_image = image_document.image - mimetype = infer_image_mimetype_from_base64(base64_image) - image_content = { - "type": "image_url", - "image_url": f"data:{mimetype};base64,{base64_image}", - } - completion_content.append(image_content) - - completion_content.append({"type": "text", "text": prompt}) - - return [{"role": role, "content": 
completion_content}] - - -def resolve_mistral_credentials( - api_key: Optional[str] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, -) -> Tuple[Optional[str], str, str]: - """ - "Resolve Mistral credentials. - - The order of precedence is: - 1. param - 2. env - 3. mistral module - 4. default - """ - # resolve from param or env - api_key = get_from_param_or_env("api_key", api_key, "MISTRAL_API_KEY", "") - api_base = get_from_param_or_env("api_base", api_base, "MISTRAL_API_BASE", "") - api_version = get_from_param_or_env( - "api_version", api_version, "MISTRAL_API_VERSION", "" - ) - - # resolve from Mistral module or default - final_api_key = api_key or "" - final_api_base = api_base or DEFAULT_MISTRALAI_API_BASE - final_api_version = api_version or DEFAULT_MISTRALAI_API_VERSION - - return final_api_key, str(final_api_base), final_api_version diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/pyproject.toml b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/pyproject.toml index 4c471028ed..118a769aec 100644 --- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/pyproject.toml +++ b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/pyproject.toml @@ -27,13 +27,11 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-multi-modal-llms-mistralai" readme = "README.md" -version = "0.3.1" +version = "0.4.0" [tool.poetry.dependencies] python = ">=3.9,<4.0" -mistralai = ">=1.2.3" -llama-index-core = "^0.12.0" -filetype = "^1.2.0" +llama-index-llms-mistralai = "^0.4.0" [tool.poetry.group.dev.dependencies] ipython = "8.10.0" diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/tests/test_multi-modal-llms_mistral.py b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/tests/test_multi-modal-llms_mistral.py index 156d862b08..16538e7dad 100644 --- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/tests/test_multi-modal-llms_mistral.py +++ b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-mistralai/tests/test_multi-modal-llms_mistral.py @@ -1,12 +1,12 @@ -from llama_index.core.multi_modal_llms.base import MultiModalLLM +from llama_index.llms.mistralai import MistralAI from llama_index.multi_modal_llms.mistralai import MistralAIMultiModal def test_embedding_class(): names_of_base_classes = [b.__name__ for b in MistralAIMultiModal.__mro__] - assert MultiModalLLM.__name__ in names_of_base_classes + assert MistralAI.__name__ in names_of_base_classes def test_init(): - m = MistralAIMultiModal(max_tokens=400) + m = MistralAIMultiModal(max_tokens=400, api_key="test") assert m.max_tokens == 400 -- GitLab
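For reference, a minimal usage sketch of the content-block chat path this patch adds to llama-index-llms-mistralai 0.4.0. This is not part of the patch: it assumes MISTRAL_API_KEY is set in the environment and that a local file named image.jpg exists; the max_tokens value, file name, and prompt text are illustrative, while the pixtral model name comes from the model list added above.

    from llama_index.core.base.llms.types import ChatMessage, ImageBlock, TextBlock
    from llama_index.llms.mistralai import MistralAI

    # pixtral-large-latest is one of the vision models added to MISTRALAI_MODELS above.
    llm = MistralAI(model="pixtral-large-latest", max_tokens=256)

    message = ChatMessage(
        role="user",
        blocks=[
            # to_mistral_chunks() resolves this block to base64 and wraps it in an
            # ImageURLChunk data URL; the mimetype is inferred from the .jpg suffix
            # by the new ImageBlock validator. image.jpg is a hypothetical local file.
            ImageBlock(path="image.jpg"),
            TextBlock(text="Describe the image in one sentence."),
        ],
    )

    response = llm.chat([message])
    print(response.message.content)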
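Similarly, a sketch of the slimmed-down multi-modal wrapper, which now subclasses MistralAI and converts each ImageNode into an ImageBlock (plus a trailing TextBlock for the prompt) before delegating to the chat methods. The same MISTRAL_API_KEY assumption applies, and chart.png is a hypothetical local file.

    from llama_index.core.schema import ImageNode
    from llama_index.multi_modal_llms.mistralai import MistralAIMultiModal

    mm_llm = MistralAIMultiModal(model="pixtral-12b-2409", max_tokens=300)

    # _get_multi_modal_chat_messages() maps image_path onto ImageBlock.path and
    # appends the prompt as a TextBlock; complete() then delegates to chat().
    image_documents = [ImageNode(image_path="chart.png")]

    response = mm_llm.complete(
        prompt="What does this chart show?",
        image_documents=image_documents,
    )
    print(response.text)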