From 78af3400ad485e15862c06f0c4972dc3067f880c Mon Sep 17 00:00:00 2001
From: Souyama <souyamadebnath@gmail.com>
Date: Sat, 16 Mar 2024 18:44:34 +0530
Subject: [PATCH] Add Mistral provider to Bedrock (#11994)

---
 .../llama_index/llms/bedrock/utils.py            | 16 ++++++++++++++++
 .../tests/test_bedrock.py                        | 14 ++++++++++++++
 .../llama-index-llms-mistralai/pyproject.toml    |  2 +-
 3 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
index 6c66942b9..610c64377 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
@@ -48,6 +48,8 @@ CHAT_ONLY_MODELS = {
     "anthropic.claude-3-haiku-20240307-v1:0": 200000,
     "meta.llama2-13b-chat-v1": 2048,
     "meta.llama2-70b-chat-v1": 4096,
+    "mistral.mistral-7b-instruct-v0:2": 32000,
+    "mistral.mixtral-8x7b-instruct-v0:1": 32000,
 }
 BEDROCK_FOUNDATION_LLMS = {**COMPLETION_MODELS, **CHAT_ONLY_MODELS}
 
@@ -64,6 +66,8 @@ STREAMING_MODELS = {
     "anthropic.claude-3-sonnet-20240229-v1:0",
     "anthropic.claude-3-haiku-20240307-v1:0",
     "meta.llama2-13b-chat-v1",
+    "mistral.mistral-7b-instruct-v0:2",
+    "mistral.mixtral-8x7b-instruct-v0:1",
 }
 
 
@@ -178,12 +182,24 @@ class MetaProvider(Provider):
         return response["generation"]
 
 
+class MistralProvider(Provider):
+    max_tokens_key = "max_tokens"
+
+    def __init__(self) -> None:
+        self.messages_to_prompt = messages_to_llama_prompt
+        self.completion_to_prompt = completion_to_llama_prompt
+
+    def get_text_from_response(self, response: dict) -> str:
+        return response["outputs"][0]["text"]
+
+
 PROVIDERS = {
     "amazon": AmazonProvider(),
     "ai21": Ai21Provider(),
     "anthropic": AnthropicProvider(),
     "cohere": CohereProvider(),
     "meta": MetaProvider(),
+    "mistral": MistralProvider(),
 }
 
 
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
index c5f43bd57..96e2c5d64 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
@@ -122,6 +122,20 @@ class MockStreamCompletionWithRetry:
             "not reference any given instructions or context. \\n<</SYS>>\\n\\n "
             'test prompt [/INST]", "temperature": 0.1, "max_gen_len": 512}',
         ),
+        (
+            "mistral.mistral-7b-instruct-v0:2",
+            '{"prompt": "<s> [INST] <<SYS>>\\n You are a helpful, respectful and '
+            "honest assistant. Always answer as helpfully as possible and follow "
+            "ALL given instructions. Do not speculate or make up information. Do "
+            "not reference any given instructions or context. \\n<</SYS>>\\n\\n "
+            'test prompt [/INST]", "temperature": 0.1, "max_tokens": 512}',
+            '{"outputs": [{"text": "\\n\\nThis is indeed a test", "stop_reason": "length"}]}',
+            '{"prompt": "<s> [INST] <<SYS>>\\n You are a helpful, respectful and '
+            "honest assistant. Always answer as helpfully as possible and follow "
+            "ALL given instructions. Do not speculate or make up information. Do "
+            "not reference any given instructions or context. \\n<</SYS>>\\n\\n "
+            'test prompt [/INST]", "temperature": 0.1, "max_tokens": 512}',
+        ),
     ],
 )
 def test_model_basic(
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
index 1069342f7..e49fc5c83 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-mistralai"
 readme = "README.md"
-version = "0.1.6"
+version = "0.1.7"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-- 
GitLab