From 78af3400ad485e15862c06f0c4972dc3067f880c Mon Sep 17 00:00:00 2001
From: Souyama <souyamadebnath@gmail.com>
Date: Sat, 16 Mar 2024 18:44:34 +0530
Subject: [PATCH] Add Mistral provider to Bedrock (#11994)

---
 .../llama_index/llms/bedrock/utils.py         | 16 ++++++++++++++++
 .../tests/test_bedrock.py                     | 14 ++++++++++++++
 .../llama-index-llms-mistralai/pyproject.toml |  2 +-
 3 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
index 6c66942b9..610c64377 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
@@ -48,6 +48,8 @@ CHAT_ONLY_MODELS = {
     "anthropic.claude-3-haiku-20240307-v1:0": 200000,
     "meta.llama2-13b-chat-v1": 2048,
     "meta.llama2-70b-chat-v1": 4096,
+    "mistral.mistral-7b-instruct-v0:2": 32000,
+    "mistral.mixtral-8x7b-instruct-v0:1": 32000,
 }
 
 BEDROCK_FOUNDATION_LLMS = {**COMPLETION_MODELS, **CHAT_ONLY_MODELS}
@@ -64,6 +66,8 @@ STREAMING_MODELS = {
     "anthropic.claude-3-sonnet-20240229-v1:0",
     "anthropic.claude-3-haiku-20240307-v1:0",
     "meta.llama2-13b-chat-v1",
+    "mistral.mistral-7b-instruct-v0:2",
+    "mistral.mixtral-8x7b-instruct-v0:1",
 }
 
 
@@ -178,12 +182,24 @@ class MetaProvider(Provider):
         return response["generation"]
 
 
+class MistralProvider(Provider):
+    max_tokens_key = "max_tokens"
+
+    def __init__(self) -> None:
+        self.messages_to_prompt = messages_to_llama_prompt
+        self.completion_to_prompt = completion_to_llama_prompt
+
+    def get_text_from_response(self, response: dict) -> str:
+        return response["outputs"][0]["text"]
+
+
 PROVIDERS = {
     "amazon": AmazonProvider(),
     "ai21": Ai21Provider(),
     "anthropic": AnthropicProvider(),
     "cohere": CohereProvider(),
     "meta": MetaProvider(),
+    "mistral": MistralProvider(),
 }
 
 
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
index c5f43bd57..96e2c5d64 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
@@ -122,6 +122,20 @@ class MockStreamCompletionWithRetry:
             "not reference any given instructions or context. \\n<</SYS>>\\n\\n "
             'test prompt [/INST]", "temperature": 0.1, "max_gen_len": 512}',
         ),
+        (
+            "mistral.mistral-7b-instruct-v0:2",
+            '{"prompt": "<s> [INST] <<SYS>>\\n You are a helpful, respectful and '
+            "honest assistant. Always answer as helpfully as possible and follow "
+            "ALL given instructions. Do not speculate or make up information. Do "
+            "not reference any given instructions or context. \\n<</SYS>>\\n\\n "
+            'test prompt [/INST]", "temperature": 0.1, "max_tokens": 512}',
+            '{"outputs": [{"text": "\\n\\nThis is indeed a test", "stop_reason": "length"}]}',
+            '{"prompt": "<s> [INST] <<SYS>>\\n You are a helpful, respectful and '
+            "honest assistant. Always answer as helpfully as possible and follow "
+            "ALL given instructions. Do not speculate or make up information. Do "
+            "not reference any given instructions or context. \\n<</SYS>>\\n\\n "
+            'test prompt [/INST]", "temperature": 0.1, "max_tokens": 512}',
+        ),
     ],
 )
 def test_model_basic(
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
index 1069342f7..e49fc5c83 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-mistralai"
 readme = "README.md"
-version = "0.1.6"
+version = "0.1.7"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
--
GitLab
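
For context, a minimal usage sketch of what this patch enables, assuming AWS
credentials are configured and the Mistral models are enabled in your Bedrock
account; the region and prompts below are illustrative, not part of the patch:

    from llama_index.llms.bedrock import Bedrock

    # "mistral." model ids now resolve to MistralProvider in utils.py, which
    # sends "max_tokens" in the request body and reads the completion from
    # outputs[0]["text"] in the response.
    llm = Bedrock(
        model="mistral.mistral-7b-instruct-v0:2",
        region_name="us-east-1",  # illustrative; use a region offering the model
        temperature=0.1,
        max_tokens=512,
    )

    print(llm.complete("Explain RAG in one sentence.").text)

    # Both model ids are also added to STREAMING_MODELS, so streaming works:
    for chunk in llm.stream_complete("Write a haiku about rivers."):
        print(chunk.delta, end="")

Since MistralProvider reuses messages_to_llama_prompt and
completion_to_llama_prompt, prompts are serialized with the Llama
[INST]/<<SYS>> template, which is what the expected payloads in the new test
case reflect.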