Skip to content
Snippets Groups Projects
Unverified Commit f04846d9 authored by Kurtis Massey's avatar Kurtis Massey
Browse files

Increase test coverage for bedrock cohere

parent d68f685f
No related branches found
No related tags found
No related merge requests found
......@@ -88,3 +88,79 @@ class TestBedrockEncoder:
)
with pytest.raises(ValueError):
bedrock_encoder(["test"])
def test_raises_value_error_if_no_aws_session_credentials(self, mocker):
mocker.patch("boto3.Session")
mock_session = mocker.Mock()
mock_session.get_credentials.return_value = None
with pytest.raises(ValueError, match="Could not get AWS session"):
BedrockEncoder(session=mock_session)
def test_raises_value_error_if_no_aws_region(self, mocker):
mocker.patch("boto3.Session")
mock_session = mocker.Mock()
mock_session.region_name = None
with pytest.raises(ValueError, match="No AWS region provided"):
BedrockEncoder(session=mock_session)
def test_raises_value_error_if_client_initialisation_fails(self, mocker):
mocker.patch("boto3.Session")
mock_session = mocker.Mock()
mock_session.client.side_effect = Exception("Client initialisation failed")
with pytest.raises(ValueError, match="Bedrock client failed to initialise"):
BedrockEncoder(session=mock_session)
def test_raises_value_error_for_unknown_model_name(self, mocker):
mocker.patch("boto3.Session")
mock_session = mocker.Mock()
mock_session.get_credentials.return_value = True
mocker.patch("boto3.Session.client")
unknown_model_name = "unknown_model"
bedrock_encoder = BedrockEncoder(
name=unknown_model_name,
session=mock_session,
region="us-west-2",
)
with pytest.raises(ValueError, match="Unknown model name"):
bedrock_encoder(["test"])
@pytest.fixture
def bedrock_encoder_with_cohere(mocker):
mocker.patch("boto3.Session")
mock_session = mocker.Mock()
mock_session.get_credentials.return_value = True
mocker.patch("boto3.Session.client")
return BedrockEncoder(name="cohere_model", session=mock_session, region="us-west-2")
class TestBedrockEncoderWithCohere:
def test_cohere_embedding_single_chunk(self, bedrock_encoder_with_cohere):
response_content = json.dumps({"embeddings": [[0.1, 0.2, 0.3]]})
response_body = BytesIO(response_content.encode("utf-8"))
mock_response = {"body": response_body}
bedrock_encoder_with_cohere.client.invoke_model.return_value = mock_response
result = bedrock_encoder_with_cohere(["short test"])
assert isinstance(result, list), "Result should be a list"
assert all(
isinstance(item, list) for item in result
), "Each item should be a list"
assert result == [[0.1, 0.2, 0.3]], "Expected embedding [0.1, 0.2, 0.3]"
def test_cohere_input_type(self, bedrock_encoder_with_cohere):
bedrock_encoder_with_cohere.input_type = "different_type"
response_content = json.dumps({"embeddings": [[0.1, 0.2, 0.3]]})
response_body = BytesIO(response_content.encode("utf-8"))
mock_response = {"body": response_body}
bedrock_encoder_with_cohere.client.invoke_model.return_value = mock_response
result = bedrock_encoder_with_cohere(["test with different input type"])
assert isinstance(result, list), "Result should be a list"
assert result == [[0.1, 0.2, 0.3]], "Expected specific embeddings"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment