diff --git a/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py b/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py
index 9f8a206cd38c2e0daa00b8b304b7d26149cd3c49..dc33abcc0d46a0fd50d786446f661f83d1449f52 100644
--- a/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py
+++ b/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py
@@ -12,9 +12,18 @@ _AGG_FUNC: Dict[str, Callable] = {"mean": np.mean, "median": np.median, "max": np.max}
 
 
 class HitRate(BaseRetrievalMetric):
-    """Hit rate metric."""
+    """Hit rate metric: Compute hit rate with two calculation options.
+
+    - The default method checks for a single match between any of the retrieved docs and expected docs.
+    - The more granular method checks for all potential matches between retrieved docs and expected docs.
+
+    Attributes:
+        use_granular_hit_rate (bool): Determines whether to use the granular method for calculation.
+        metric_name (str): The name of the metric.
+    """
 
     metric_name: str = "hit_rate"
+    use_granular_hit_rate: bool = False
 
     def compute(
         self,
@@ -23,21 +32,57 @@ class HitRate(BaseRetrievalMetric):
         retrieved_ids: Optional[List[str]] = None,
         expected_texts: Optional[List[str]] = None,
         retrieved_texts: Optional[List[str]] = None,
-        **kwargs: Any,
     ) -> RetrievalMetricResult:
-        """Compute metric."""
-        if retrieved_ids is None or expected_ids is None:
+        """Compute metric based on the provided inputs.
+
+        Parameters:
+            query (Optional[str]): The query string (not used in the current implementation).
+            expected_ids (Optional[List[str]]): Expected document IDs.
+            retrieved_ids (Optional[List[str]]): Retrieved document IDs.
+            expected_texts (Optional[List[str]]): Expected texts (not used in the current implementation).
+            retrieved_texts (Optional[List[str]]): Retrieved texts (not used in the current implementation).
+
+        Raises:
+            ValueError: If the necessary IDs are not provided.
+
+        Returns:
+            RetrievalMetricResult: The result with the computed hit rate score.
+        """
+        # Checking for the required arguments
+        if (
+            retrieved_ids is None
+            or expected_ids is None
+            or not retrieved_ids
+            or not expected_ids
+        ):
             raise ValueError("Retrieved ids and expected ids must be provided")
-        is_hit = any(id in expected_ids for id in retrieved_ids)
-        return RetrievalMetricResult(
-            score=1.0 if is_hit else 0.0,
-        )
+
+        if self.use_granular_hit_rate:
+            # Granular HitRate calculation: Calculate all hits and divide by the number of expected docs
+            expected_set = set(expected_ids)
+            hits = sum(1 for doc_id in retrieved_ids if doc_id in expected_set)
+            score = hits / len(expected_ids) if expected_ids else 0.0
+        else:
+            # Default HitRate calculation: Check if there is a single hit
+            is_hit = any(id in expected_ids for id in retrieved_ids)
+            score = 1.0 if is_hit else 0.0
+
+        return RetrievalMetricResult(score=score)
 
 
 class MRR(BaseRetrievalMetric):
-    """MRR metric."""
+    """MRR (Mean Reciprocal Rank) metric with two calculation options.
+
+    - The default method calculates the reciprocal rank of the first relevant retrieved document.
+    - The more granular method sums the reciprocal ranks of all relevant retrieved documents and divides by the count of relevant documents.
+
+    Attributes:
+        use_granular_mrr (bool): Determines whether to use the granular method for calculation.
+        metric_name (str): The name of the metric.
+ """ metric_name: str = "mrr" + use_granular_mrr: bool = False def compute( self, @@ -46,19 +91,53 @@ class MRR(BaseRetrievalMetric): retrieved_ids: Optional[List[str]] = None, expected_texts: Optional[List[str]] = None, retrieved_texts: Optional[List[str]] = None, - **kwargs: Any, ) -> RetrievalMetricResult: - """Compute metric.""" - if retrieved_ids is None or expected_ids is None: + """Compute MRR based on the provided inputs and selected method. + + Parameters: + query (Optional[str]): The query string (not used in the current implementation). + expected_ids (Optional[List[str]]): Expected document IDs. + retrieved_ids (Optional[List[str]]): Retrieved document IDs. + expected_texts (Optional[List[str]]): Expected texts (not used in the current implementation). + retrieved_texts (Optional[List[str]]): Retrieved texts (not used in the current implementation). + + Raises: + ValueError: If the necessary IDs are not provided. + + Returns: + RetrievalMetricResult: The result with the computed MRR score. + """ + # Checking for the required arguments + if ( + retrieved_ids is None + or expected_ids is None + or not retrieved_ids + or not expected_ids + ): raise ValueError("Retrieved ids and expected ids must be provided") - for i, id in enumerate(retrieved_ids): - if id in expected_ids: - return RetrievalMetricResult( - score=1.0 / (i + 1), - ) - return RetrievalMetricResult( - score=0.0, - ) + + if self.use_granular_mrr: + # Granular MRR calculation: All relevant retrieved docs have their reciprocal ranks summed and averaged + expected_set = set(expected_ids) + reciprocal_rank_sum = 0.0 + relevant_docs_count = 0 + for index, doc_id in enumerate(retrieved_ids): + if doc_id in expected_set: + relevant_docs_count += 1 + reciprocal_rank_sum += 1.0 / (index + 1) + mrr_score = ( + reciprocal_rank_sum / relevant_docs_count + if relevant_docs_count > 0 + else 0.0 + ) + else: + # Default MRR calculation: Reciprocal rank of the first relevant document retrieved + for i, id in enumerate(retrieved_ids): + if id in expected_ids: + return RetrievalMetricResult(score=1.0 / (i + 1)) + mrr_score = 0.0 + + return RetrievalMetricResult(score=mrr_score) class CohereRerankRelevancyMetric(BaseRetrievalMetric): diff --git a/llama-index-core/tests/evaluation/test_rr_mrr_hitrate.py b/llama-index-core/tests/evaluation/test_rr_mrr_hitrate.py new file mode 100644 index 0000000000000000000000000000000000000000..448245f29b06055ca0fa5332d171335f8b3e807e --- /dev/null +++ b/llama-index-core/tests/evaluation/test_rr_mrr_hitrate.py @@ -0,0 +1,77 @@ +import pytest +from llama_index.core.evaluation.retrieval.metrics import HitRate, MRR + + +# Test cases for the updated HitRate class using instance attribute +@pytest.mark.parametrize( + ("expected_ids", "retrieved_ids", "use_granular", "expected_result"), + [ + (["id1", "id2", "id3"], ["id3", "id1", "id2", "id4"], False, 1.0), + (["id1", "id2", "id3", "id4"], ["id1", "id5", "id2"], True, 2 / 4), + (["id1", "id2"], ["id3", "id4"], False, 0.0), + (["id1", "id2"], ["id2", "id1", "id7"], True, 2 / 2), + ], +) +def test_hit_rate(expected_ids, retrieved_ids, use_granular, expected_result): + hr = HitRate() + hr.use_granular_hit_rate = use_granular + result = hr.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids) + assert result.score == pytest.approx(expected_result) + + +# Test cases for the updated MRR class using instance attribute +@pytest.mark.parametrize( + ("expected_ids", "retrieved_ids", "use_granular", "expected_result"), + [ + (["id1", "id2", "id3"], ["id3", 
"id1", "id2", "id4"], False, 1 / 1), + (["id1", "id2", "id3", "id4"], ["id5", "id1"], False, 1 / 2), + (["id1", "id2"], ["id3", "id4"], False, 0.0), + (["id1", "id2"], ["id2", "id1", "id7"], False, 1 / 1), + ( + ["id1", "id2", "id3"], + ["id3", "id1", "id2", "id4"], + True, + (1 / 1 + 1 / 2 + 1 / 3) / 3, + ), + ( + ["id1", "id2", "id3", "id4"], + ["id1", "id2", "id5"], + True, + (1 / 1 + 1 / 2) / 2, + ), + (["id1", "id2"], ["id1", "id7", "id15", "id2"], True, (1 / 1 + 1 / 4) / 2), + ], +) +def test_mrr(expected_ids, retrieved_ids, use_granular, expected_result): + mrr = MRR() + mrr.use_granular_mrr = use_granular + result = mrr.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids) + assert result.score == pytest.approx(expected_result) + + +# Test cases for exceptions handling for both HitRate and MRR +@pytest.mark.parametrize( + ("expected_ids", "retrieved_ids", "use_granular"), + [ + ( + None, + ["id3", "id1", "id2", "id4"], + False, + ), # None expected_ids should trigger ValueError + ( + ["id1", "id2", "id3"], + None, + True, + ), # None retrieved_ids should trigger ValueError + ([], [], False), # Empty IDs should trigger ValueError + ], +) +def test_exceptions(expected_ids, retrieved_ids, use_granular): + with pytest.raises(ValueError): + hr = HitRate() + hr.use_granular_hit_rate = use_granular + hr.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids) + + mrr = MRR() + mrr.use_granular_mrr = use_granular + mrr.compute(expected_ids=expected_ids, retrieved_ids=retrieved_ids)