From c2272b5730d34937710c0bd8cda03359792fd69f Mon Sep 17 00:00:00 2001 From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com> Date: Mon, 11 Mar 2024 12:38:19 -0400 Subject: [PATCH] Adds new LabelledSimpleDataset (llama-dataset) (#11805) --- .../llama_index/core/llama_dataset/base.py | 3 +- .../llama_index/core/llama_dataset/simple.py | 114 ++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 llama-index-core/llama_index/core/llama_dataset/simple.py diff --git a/llama-index-core/llama_index/core/llama_dataset/base.py b/llama-index-core/llama_index/core/llama_dataset/base.py index 07eb5d926..ed3873b6e 100644 --- a/llama-index-core/llama_index/core/llama_dataset/base.py +++ b/llama-index-core/llama_index/core/llama_dataset/base.py @@ -8,12 +8,13 @@ from typing import Generator, Generic, List, Optional, Type, TypeVar, Union import tqdm from llama_index.core.async_utils import asyncio_module from llama_index.core.base.base_query_engine import BaseQueryEngine +from llama_index.core.llms import LLM from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr from llama_index.core.evaluation import BaseEvaluator from openai import RateLimitError from pandas import DataFrame as PandasDataFrame -PredictorType = Union[BaseQueryEngine, BaseEvaluator] +PredictorType = Union[BaseQueryEngine, BaseEvaluator, LLM] P = TypeVar("P", bound=PredictorType) diff --git a/llama-index-core/llama_index/core/llama_dataset/simple.py b/llama-index-core/llama_index/core/llama_dataset/simple.py new file mode 100644 index 000000000..a1712e624 --- /dev/null +++ b/llama-index-core/llama_index/core/llama_dataset/simple.py @@ -0,0 +1,114 @@ +from typing import Optional, List +from llama_index.core.llama_dataset.base import ( + BaseLlamaDataExample, + BaseLlamaDataset, + CreatedBy, + BaseLlamaExamplePrediction, + BaseLlamaPredictionDataset, +) +from llama_index.core.llms import LLM +from llama_index.core.bridge.pydantic import Field +from pandas import DataFrame as PandasDataFrame + + +class SimpleExamplePrediction(BaseLlamaExamplePrediction): + """RAG example prediction class. + + Args: + response (str): The response generated by the LLM. + contexts (Optional[List[str]]): The retrieved context (text) for generating + response. + """ + + label: str = Field( + default_factory=str, + description="The generated (predicted) label that can be compared to a reference (ground-truth) label.", + ) + + @property + def class_name(self) -> str: + """Data example class name.""" + return "SimpleExamplePrediction" + + +class SimplePredictionDataset(BaseLlamaPredictionDataset): + """RagDataset class.""" + + _prediction_type = SimpleExamplePrediction + + def to_pandas(self) -> PandasDataFrame: + """Create pandas dataframe.""" + data = {} + if self.predictions: + data = { + "label": [t.label for t in self.predictions], + } + + return PandasDataFrame(data) + + @property + def class_name(self) -> str: + """Class name.""" + return "SimplePredictionDataset" + + +class LabelledSimpleDataExample(BaseLlamaDataExample): + reference_label: str = Field(default_factory=str, description="Class label") + text: str = Field(default_factory=str, description="Text body of example") + text_by: Optional[CreatedBy] = Field( + default=None, description="What generated the query." + ) + + @property + def class_name(self) -> str: + """Data example class name.""" + return "LabelledSimpleDataExample" + + +class LabelledSimpleDataset(BaseLlamaDataset[LLM]): + _example_type = LabelledSimpleDataExample + + def _construct_prediction_dataset( + self, predictions: List[SimpleExamplePrediction] + ) -> SimplePredictionDataset: + """Construct the specific prediction dataset. + + Args: + predictions (List[BaseLlamaExamplePrediction]): the list of predictions. + + Returns: + BaseLlamaPredictionDataset: A dataset of predictions. + """ + return SimplePredictionDataset(predictions=predictions) + + def to_pandas(self) -> PandasDataFrame: + """Create pandas dataframe.""" + data = { + "reference_label": [t.reference_label for t in self.examples], + "text": [t.text for t in self.examples], + "text_by": [str(t.text_by) for t in self.examples], + } + + return PandasDataFrame(data) + + async def _apredict_example( + self, + predictor: LLM, + example: LabelledSimpleDataExample, + sleep_time_in_seconds: int, + ) -> SimpleExamplePrediction: + """Async predict RAG example with a query engine.""" + raise NotImplementedError("This method has not yet been implemented.") + + def _predict_example( + self, + predictor: LLM, + example: BaseLlamaDataExample, + sleep_time_in_seconds: int = 0, + ) -> BaseLlamaExamplePrediction: + raise NotImplementedError("This method has not yet been implemented.") + + @property + def class_name(self) -> str: + """Data example class name.""" + return "LabelledSimpleDataset" -- GitLab