Skip to content
Snippets Groups Projects
Unverified Commit 1c15bab7 authored by Steven Palma's avatar Steven Palma Committed by GitHub
Browse files

fix(codec): hot-fix for default codec in linux arm platforms (#868)

parent 9f0a8a49
No related branches found
No related tags found
No related merge requests found
...@@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d ...@@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
### Decoding parameters ### Decoding parameters
**Decoder** **Decoder**
We tested two video decoding backends from torchvision: We tested two video decoding backends from torchvision:
- `pyav` (default) - `pyav`
- `video_reader` (requires to build torchvision from source) - `video_reader` (requires to build torchvision from source)
**Requested timestamps** **Requested timestamps**
......
...@@ -69,6 +69,7 @@ from lerobot.common.datasets.video_utils import ( ...@@ -69,6 +69,7 @@ from lerobot.common.datasets.video_utils import (
VideoFrame, VideoFrame,
decode_video_frames, decode_video_frames,
encode_video_frames, encode_video_frames,
get_safe_default_codec,
get_video_info, get_video_info,
) )
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
...@@ -462,7 +463,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ...@@ -462,7 +463,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to video files are already present on local disk, they won't be downloaded again. Defaults to
True. True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
""" """
super().__init__() super().__init__()
...@@ -473,7 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ...@@ -473,7 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes self.episodes = episodes
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "torchcodec" self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None self.delta_indices = None
# Unused attributes # Unused attributes
...@@ -1027,7 +1028,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ...@@ -1027,7 +1028,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None obj.delta_timestamps = None
obj.delta_indices = None obj.delta_indices = None
obj.episode_data_index = None obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "torchcodec" obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
return obj return obj
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
import json import json
import logging import logging
import subprocess import subprocess
...@@ -27,14 +28,23 @@ import torch ...@@ -27,14 +28,23 @@ import torch
import torchvision import torchvision
from datasets.features.features import register_feature from datasets.features.features import register_feature
from PIL import Image from PIL import Image
from torchcodec.decoders import VideoDecoder
def get_safe_default_codec():
    """Return the name of the default video decoding backend for this platform.

    Returns:
        str: "torchcodec" when the `torchcodec` package can be located by the
        import system, otherwise "pyav" (a warning is logged in that case).
    """
    # A bare `import importlib` does not guarantee that the `importlib.util`
    # submodule is bound as an attribute; import it explicitly so that
    # `find_spec` is always reachable regardless of what else was imported.
    import importlib.util

    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"

    logging.warning(
        "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
    )
    return "pyav"
def decode_video_frames( def decode_video_frames(
video_path: Path | str, video_path: Path | str,
timestamps: list[float], timestamps: list[float],
tolerance_s: float, tolerance_s: float,
backend: str = "torchcodec", backend: str | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Decodes video frames using the specified backend. Decodes video frames using the specified backend.
...@@ -43,13 +53,15 @@ def decode_video_frames( ...@@ -43,13 +53,15 @@ def decode_video_frames(
video_path (Path): Path to the video file. video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames. timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval. tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec". backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available on the platform; otherwise, defaults to "pyav".
Returns: Returns:
torch.Tensor: Decoded frames. torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav. Currently supports torchcodec on cpu and pyav.
""" """
if backend is None:
backend = get_safe_default_codec()
if backend == "torchcodec": if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]: elif backend in ["pyav", "video_reader"]:
...@@ -173,6 +185,12 @@ def decode_video_frames_torchcodec( ...@@ -173,6 +185,12 @@ def decode_video_frames_torchcodec(
and all subsequent frames until reaching the requested frame. The number of key frames in a video and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes. can be adjusted during encoding to take into account decoding time and video size in bytes.
""" """
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
raise ImportError("torchcodec is required but not available.")
# initialize video decoder # initialize video decoder
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate") decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
loaded_frames = [] loaded_frames = []
......
...@@ -20,6 +20,7 @@ from lerobot.common import ( ...@@ -20,6 +20,7 @@ from lerobot.common import (
policies, # noqa: F401 policies, # noqa: F401
) )
from lerobot.common.datasets.transforms import ImageTransformsConfig from lerobot.common.datasets.transforms import ImageTransformsConfig
from lerobot.common.datasets.video_utils import get_safe_default_codec
@dataclass @dataclass
...@@ -35,7 +36,7 @@ class DatasetConfig: ...@@ -35,7 +36,7 @@ class DatasetConfig:
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None revision: str | None = None
use_imagenet_stats: bool = True use_imagenet_stats: bool = True
video_backend: str = "pyav" video_backend: str = field(default_factory=get_safe_default_codec)
@dataclass @dataclass
......
...@@ -69,7 +69,7 @@ dependencies = [ ...@@ -69,7 +69,7 @@ dependencies = [
"rerun-sdk>=0.21.0", "rerun-sdk>=0.21.0",
"termcolor>=2.4.0", "termcolor>=2.4.0",
"torch>=2.2.1", "torch>=2.2.1",
"torchcodec>=0.2.1", "torchcodec>=0.2.1 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')",
"torchvision>=0.21.0", "torchvision>=0.21.0",
"wandb>=0.16.3", "wandb>=0.16.3",
"zarr>=2.17.0", "zarr>=2.17.0",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment