diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b0812b7a3ee0130355656d623cd8baed99cd75d..af99da4f35f9ff5f0b6a64e15ba1fd50b899a871 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ### New Features +- Make `reference_contexts` optional in `LabelledRagDataset` (#9266) +- Re-organize `download` module (#9253) - Added document management to ingestion pipeline (#9135) - Add docs for `LabelledRagDataset` (#9228) diff --git a/llama_index/llama_dataset/base.py b/llama_index/llama_dataset/base.py index 353bb52f98bf82f1895e2e1d536aebc8508ce1cc..5e25b4b7bf676f1cb4e07b95837ae37037b8bae6 100644 --- a/llama_index/llama_dataset/base.py +++ b/llama_index/llama_dataset/base.py @@ -3,7 +3,7 @@ import json from abc import abstractmethod from enum import Enum -from typing import List, Optional, Type +from typing import List, Optional, Type, Union import tqdm from pandas import DataFrame as PandasDataFrame @@ -58,10 +58,17 @@ class BaseLlamaDataExample(BaseModel): class BaseLlamaPredictionDataset(BaseModel): _prediction_type: Type[BaseLlamaExamplePrediction] = BaseLlamaExamplePrediction # type: ignore[misc] - predictions: Optional[List[BaseLlamaExamplePrediction]] = Field( - default=None, description="Predictions on train_examples." + predictions: List[BaseLlamaExamplePrediction] = Field( + default_factory=list, description="Predictions on train_examples." ) + def __getitem__(self, val: Union[slice, int]) -> List[BaseLlamaExamplePrediction]: + """Enable slicing and indexing. + + Returns the desired slice on `predictions`. + """ + return self.predictions[val] + @abstractmethod def to_pandas(self) -> PandasDataFrame: """Create pandas dataframe.""" @@ -99,6 +106,13 @@ class BaseLlamaDataset(BaseModel): default=[], description="Data examples of this dataset." ) + def __getitem__(self, val: Union[slice, int]) -> List[BaseLlamaDataExample]: + """Enable slicing and indexing. + + Returns the desired slice on `examples`. 
+ """ + return self.examples[val] + @abstractmethod def to_pandas(self) -> PandasDataFrame: """Create pandas dataframe.""" diff --git a/llama_index/llama_dataset/rag.py b/llama_index/llama_dataset/rag.py index 14dd504b4413d5a41ca1d9905597eef0a64f5547..1d96a3d56b4a204fb9270be05033371cc38abab2 100644 --- a/llama_index/llama_dataset/rag.py +++ b/llama_index/llama_dataset/rag.py @@ -19,16 +19,17 @@ class RagExamplePrediction(BaseLlamaExamplePrediction): """RAG example prediction class. Args: - response: str - contexts: List[str] + response (str): The response generated by the LLM. + contexts (Optional[List[str]]): The retrieved context (text) for generating + response. """ response: str = Field( default_factory=str, description="The generated (predicted) response that can be compared to a reference (ground-truth) answer.", ) - contexts: List[str] = Field( - default_factory=List, + contexts: Optional[List[str]] = Field( + default=None, description="The contexts in raw text form used to generate the response.", ) @@ -45,10 +46,11 @@ class LabelledRagDataExample(BaseLlamaDataExample): Args: query (str): The user query - kind (LlamaRagDataExampleKind): The example is generated by human or ai - reference_contexts (List[str] or List[TextNode]): The contexts used for response + query_by (CreatedBy): Query generated by human or ai (model-name) + reference_contexts (Optional[List[str]]): The contexts used for response reference_answer ([str]): Reference answer to the query. An answer that would receive full marks upon evaluation. + reference_answer_by: The reference answer generated by human or ai (model-name). """ query: str = Field( @@ -57,8 +59,8 @@ class LabelledRagDataExample(BaseLlamaDataExample): query_by: Optional[CreatedBy] = Field( default=None, description="What generated the query." 
) - reference_contexts: List[str] = Field( - default_factory=List, + reference_contexts: Optional[List[str]] = Field( + default=None, description="The contexts used to generate the reference answer.", ) reference_answer: str = Field(