diff --git a/benchmarks/video/README.md b/benchmarks/video/README.md index 49e498113a53446075351b44afc0e97e645a7ca1..daa3e1f48d118c24fb8cf9d854c290a7b0195b69 100644 --- a/benchmarks/video/README.md +++ b/benchmarks/video/README.md @@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d ### Decoding parameters **Decoder** We tested two video decoding backends from torchvision: -- `pyav` (default) +- `pyav` - `video_reader` (requires to build torchvision from source) **Requested timestamps** diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 101e71f446cefbba222bf107b1b961fed8504a6b..6ef955dd7361e0e4f54546c742de2eab31396d9f 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -69,6 +69,7 @@ from lerobot.common.datasets.video_utils import ( VideoFrame, decode_video_frames, encode_video_frames, + get_safe_default_codec, get_video_info, ) from lerobot.common.robot_devices.robots.utils import Robot @@ -462,7 +463,7 @@ class LeRobotDataset(torch.utils.data.Dataset): download_videos (bool, optional): Flag to download the videos. Note that when set to True but the video files are already present on local disk, they won't be downloaded again. Defaults to True. - video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec. + video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available in the platform; otherwise, defaults to 'pyav'. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. 
""" super().__init__() @@ -473,7 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episodes = episodes self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION - self.video_backend = video_backend if video_backend else "torchcodec" + self.video_backend = video_backend if video_backend else get_safe_default_codec() self.delta_indices = None # Unused attributes @@ -1027,7 +1028,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_timestamps = None obj.delta_indices = None obj.episode_data_index = None - obj.video_backend = video_backend if video_backend is not None else "torchcodec" + obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() return obj diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 3fe19d8b6e269a9902ad76d596da5201ff870f33..4f69686175f7abf025de0d6582b15e1b18734260 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import json import logging import subprocess @@ -27,14 +28,23 @@ import torch import torchvision from datasets.features.features import register_feature from PIL import Image -from torchcodec.decoders import VideoDecoder + + +def get_safe_default_codec(): + if importlib.util.find_spec("torchcodec"): + return "torchcodec" + else: + logging.warning( + "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder" + ) + return "pyav" def decode_video_frames( video_path: Path | str, timestamps: list[float], tolerance_s: float, - backend: str = "torchcodec", + backend: str | None = None, ) -> torch.Tensor: """ Decodes video frames using the specified backend. 
@@ -43,13 +53,15 @@ def decode_video_frames( video_path (Path): Path to the video file. timestamps (list[float]): List of timestamps to extract frames. tolerance_s (float): Allowed deviation in seconds for frame retrieval. - backend (str, optional): Backend to use for decoding. Defaults to "torchcodec". + backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav". Returns: torch.Tensor: Decoded frames. Currently supports torchcodec on cpu and pyav. """ + if backend is None: + backend = get_safe_default_codec() if backend == "torchcodec": return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) elif backend in ["pyav", "video_reader"]: @@ -173,6 +185,12 @@ def decode_video_frames_torchcodec( and all subsequent frames until reaching the requested frame. The number of key frames in a video can be adjusted during encoding to take into account decoding time and video size in bytes. """ + + if importlib.util.find_spec("torchcodec"): + from torchcodec.decoders import VideoDecoder + else: + raise ImportError("torchcodec is required but not available.") + # initialize video decoder decoder = VideoDecoder(video_path, device=device, seek_mode="approximate") loaded_frames = [] diff --git a/lerobot/configs/default.py b/lerobot/configs/default.py index dee0649aa0c5a0d281e9260c667f12368032c47d..b23bbb6d90859143957828ba9ef64a5c806d57e1 100644 --- a/lerobot/configs/default.py +++ b/lerobot/configs/default.py @@ -20,6 +20,7 @@ from lerobot.common import ( policies, # noqa: F401 ) from lerobot.common.datasets.transforms import ImageTransformsConfig +from lerobot.common.datasets.video_utils import get_safe_default_codec @dataclass @@ -35,7 +36,7 @@ class DatasetConfig: image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) revision: str | None = None use_imagenet_stats: bool = True - video_backend: str = "pyav" + video_backend: str = 
field(default_factory=get_safe_default_codec) @dataclass diff --git a/pyproject.toml b/pyproject.toml index f1f836b4c4231d05280fc63187effa2884dc003f..665c743a89e9622225e7a143fef40eacedc98318 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ dependencies = [ "rerun-sdk>=0.21.0", "termcolor>=2.4.0", "torch>=2.2.1", - "torchcodec>=0.2.1", + "torchcodec>=0.2.1 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')", "torchvision>=0.21.0", "wandb>=0.16.3", "zarr>=2.17.0",