Unverified commit a89a4c79, authored by Andrei Fajardo and committed by GitHub

make reference_context optional (#9266)

* make reference_context optional

* lint

* make entry to chlog
parent 1f9ba34f
@@ -4,6 +4,8 @@
 ### New Features
+- Make `reference_contexts` optional in `LabelledRagDataset` (#9266)
+- Re-organize `download` module (#9253)
 - Added document management to ingestion pipeline (#9135)
 - Add docs for `LabelledRagDataset` (#9228)
@@ -3,7 +3,7 @@
 import json
 from abc import abstractmethod
 from enum import Enum
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Union

 import tqdm
 from pandas import DataFrame as PandasDataFrame
@@ -58,10 +58,17 @@ class BaseLlamaDataExample(BaseModel):
 class BaseLlamaPredictionDataset(BaseModel):
     _prediction_type: Type[BaseLlamaExamplePrediction] = BaseLlamaExamplePrediction  # type: ignore[misc]
-    predictions: Optional[List[BaseLlamaExamplePrediction]] = Field(
-        default=None, description="Predictions on train_examples."
+    predictions: List[BaseLlamaExamplePrediction] = Field(
+        default=list, description="Predictions on train_examples."
     )

+    def __getitem__(self, val: Union[slice, int]) -> List[BaseLlamaExamplePrediction]:
+        """Enable slicing and indexing.
+
+        Returns the desired slice on `predictions`.
+        """
+        return self.predictions[val]
+
     @abstractmethod
     def to_pandas(self) -> PandasDataFrame:
         """Create pandas dataframe."""
@@ -99,6 +106,13 @@ class BaseLlamaDataset(BaseModel):
         default=[], description="Data examples of this dataset."
     )

+    def __getitem__(self, val: Union[slice, int]) -> List[BaseLlamaDataExample]:
+        """Enable slicing and indexing.
+
+        Returns the desired slice on `examples`.
+        """
+        return self.examples[val]
+
     @abstractmethod
     def to_pandas(self) -> PandasDataFrame:
         """Create pandas dataframe."""
@@ -19,16 +19,17 @@ class RagExamplePrediction(BaseLlamaExamplePrediction):
     """RAG example prediction class.

     Args:
-        response: str
-        contexts: List[str]
+        response (str): The response generated by the LLM.
+        contexts (Optional[List[str]]): The retrieved context (text) for generating
+            response.
     """

     response: str = Field(
         default_factory=str,
         description="The generated (predicted) response that can be compared to a reference (ground-truth) answer.",
     )
-    contexts: List[str] = Field(
-        default_factory=List,
+    contexts: Optional[List[str]] = Field(
+        default_factory=None,
         description="The contexts in raw text form used to generate the response.",
     )
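With `contexts` now `Optional`, a `RagExamplePrediction` can be built from a bare response string. A small sketch, assuming the class is importable from `llama_index.llama_dataset`:

```python
# Sketch only: the import path is an assumption.
from llama_index.llama_dataset import RagExamplePrediction

# `contexts` is now Optional and falls back to None when not supplied.
prediction = RagExamplePrediction(response="Paris is the capital of France.")
assert prediction.contexts is None
```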
@@ -45,10 +46,11 @@ class LabelledRagDataExample(BaseLlamaDataExample):
     Args:
         query (str): The user query
-        kind (LlamaRagDataExampleKind): The example is generated by human or ai
-        reference_contexts (List[str] or List[TextNode]): The contexts used for response
+        query_by (CreatedBy): Query generated by human or ai (model-name)
+        reference_contexts (Optional[List[str]]): The contexts used for response
         reference_answer ([str]): Reference answer to the query. An answer
             that would receive full marks upon evaluation.
+        reference_answer_by: The reference answer generated by human or ai (model-name).
     """

     query: str = Field(
@@ -57,8 +59,8 @@ class LabelledRagDataExample(BaseLlamaDataExample):
     query_by: Optional[CreatedBy] = Field(
         default=None, description="What generated the query."
     )
-    reference_contexts: List[str] = Field(
-        default_factory=List,
+    reference_contexts: Optional[List[str]] = Field(
+        default_factory=None,
         description="The contexts used to generate the reference answer.",
     )
     reference_answer: str = Field(
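The same relaxation applies to `LabelledRagDataExample`: `reference_contexts` no longer has to be provided, which is the headline change of this commit. A hedged construction sketch (import path and the defaults of fields not shown in this hunk are assumptions):

```python
# Sketch only: the import path is an assumption.
from llama_index.llama_dataset import LabelledRagDataExample

# reference_contexts may now be omitted entirely; query_by defaults to None
# (shown above), and reference_answer_by is assumed to behave the same.
example = LabelledRagDataExample(
    query="What does making reference_contexts optional change?",
    reference_answer="Examples can be labelled without attaching retrieval contexts.",
)
assert example.reference_contexts is None
```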