From c2272b5730d34937710c0bd8cda03359792fd69f Mon Sep 17 00:00:00 2001
From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com>
Date: Mon, 11 Mar 2024 12:38:19 -0400
Subject: [PATCH] Adds new LabelledSimpleDataset (llama-dataset) (#11805)

---
 .../llama_index/core/llama_dataset/base.py    |   3 +-
 .../llama_index/core/llama_dataset/simple.py  | 114 ++++++++++++++++++
 2 files changed, 116 insertions(+), 1 deletion(-)
 create mode 100644 llama-index-core/llama_index/core/llama_dataset/simple.py

diff --git a/llama-index-core/llama_index/core/llama_dataset/base.py b/llama-index-core/llama_index/core/llama_dataset/base.py
index 07eb5d926..ed3873b6e 100644
--- a/llama-index-core/llama_index/core/llama_dataset/base.py
+++ b/llama-index-core/llama_index/core/llama_dataset/base.py
@@ -8,12 +8,13 @@ from typing import Generator, Generic, List, Optional, Type, TypeVar, Union
 import tqdm
 from llama_index.core.async_utils import asyncio_module
 from llama_index.core.base.base_query_engine import BaseQueryEngine
+from llama_index.core.llms import LLM
 from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr
 from llama_index.core.evaluation import BaseEvaluator
 from openai import RateLimitError
 from pandas import DataFrame as PandasDataFrame
 
-PredictorType = Union[BaseQueryEngine, BaseEvaluator]
+PredictorType = Union[BaseQueryEngine, BaseEvaluator, LLM]
 P = TypeVar("P", bound=PredictorType)
 
 
diff --git a/llama-index-core/llama_index/core/llama_dataset/simple.py b/llama-index-core/llama_index/core/llama_dataset/simple.py
new file mode 100644
index 000000000..a1712e624
--- /dev/null
+++ b/llama-index-core/llama_index/core/llama_dataset/simple.py
@@ -0,0 +1,114 @@
+from typing import Optional, List
+from llama_index.core.llama_dataset.base import (
+    BaseLlamaDataExample,
+    BaseLlamaDataset,
+    CreatedBy,
+    BaseLlamaExamplePrediction,
+    BaseLlamaPredictionDataset,
+)
+from llama_index.core.llms import LLM
+from llama_index.core.bridge.pydantic import Field
+from pandas import DataFrame as PandasDataFrame
+
+
+class SimpleExamplePrediction(BaseLlamaExamplePrediction):
+    """RAG example prediction class.
+
+    Args:
+        response (str): The response generated by the LLM.
+        contexts (Optional[List[str]]): The retrieved context (text) for generating
+                                        response.
+    """
+
+    label: str = Field(
+        default_factory=str,
+        description="The generated (predicted) label that can be compared to a reference (ground-truth) label.",
+    )
+
+    @property
+    def class_name(self) -> str:
+        """Data example class name."""
+        return "SimpleExamplePrediction"
+
+
+class SimplePredictionDataset(BaseLlamaPredictionDataset):
+    """RagDataset class."""
+
+    _prediction_type = SimpleExamplePrediction
+
+    def to_pandas(self) -> PandasDataFrame:
+        """Create pandas dataframe."""
+        data = {}
+        if self.predictions:
+            data = {
+                "label": [t.label for t in self.predictions],
+            }
+
+        return PandasDataFrame(data)
+
+    @property
+    def class_name(self) -> str:
+        """Class name."""
+        return "SimplePredictionDataset"
+
+
+class LabelledSimpleDataExample(BaseLlamaDataExample):
+    reference_label: str = Field(default_factory=str, description="Class label")
+    text: str = Field(default_factory=str, description="Text body of example")
+    text_by: Optional[CreatedBy] = Field(
+        default=None, description="What generated the query."
+    )
+
+    @property
+    def class_name(self) -> str:
+        """Data example class name."""
+        return "LabelledSimpleDataExample"
+
+
+class LabelledSimpleDataset(BaseLlamaDataset[LLM]):
+    _example_type = LabelledSimpleDataExample
+
+    def _construct_prediction_dataset(
+        self, predictions: List[SimpleExamplePrediction]
+    ) -> SimplePredictionDataset:
+        """Construct the specific prediction dataset.
+
+        Args:
+            predictions (List[BaseLlamaExamplePrediction]): the list of predictions.
+
+        Returns:
+            BaseLlamaPredictionDataset: A dataset of predictions.
+        """
+        return SimplePredictionDataset(predictions=predictions)
+
+    def to_pandas(self) -> PandasDataFrame:
+        """Create pandas dataframe."""
+        data = {
+            "reference_label": [t.reference_label for t in self.examples],
+            "text": [t.text for t in self.examples],
+            "text_by": [str(t.text_by) for t in self.examples],
+        }
+
+        return PandasDataFrame(data)
+
+    async def _apredict_example(
+        self,
+        predictor: LLM,
+        example: LabelledSimpleDataExample,
+        sleep_time_in_seconds: int,
+    ) -> SimpleExamplePrediction:
+        """Async predict RAG example with a query engine."""
+        raise NotImplementedError("This method has not yet been implemented.")
+
+    def _predict_example(
+        self,
+        predictor: LLM,
+        example: BaseLlamaDataExample,
+        sleep_time_in_seconds: int = 0,
+    ) -> BaseLlamaExamplePrediction:
+        raise NotImplementedError("This method has not yet been implemented.")
+
+    @property
+    def class_name(self) -> str:
+        """Data example class name."""
+        return "LabelledSimpleDataset"
-- 
GitLab