diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b401684148002105bb984563aa902f076e7e86a1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,24 @@ +Copyright (c) 2024 Boston Dynamics AI Institute LLC + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the copyright notice included +with the software, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the copyright notice, this +list of conditions and the following disclaimer in the documentation and/or +other materials provided with the distribution. +3. Modified versions of the software must be conspicuously marked as such. +4. The software may only be used for non-commercial research purposes. +For profit enterprises may use the software, subject to this limitation. + +THIS SOFTWARE IS PROVIDED BY THE AI INSTITUTE AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, NON- +INFRINGEMENT,TITLE, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE AI INSTITUTE OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, DAMAGES ARISING OUT OF CLAIMS OF +INTELLECTUAL PROPERTY RIGHTS INFRINGEMENT; PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/theia/__init__.py b/theia/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9 --- /dev/null +++ b/theia/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. 
diff --git a/theia/configs/dataset/ego4d.yaml b/theia/configs/dataset/ego4d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cd10b2691475e805bf856df9ef6850aef950791d --- /dev/null +++ b/theia/configs/dataset/ego4d.yaml @@ -0,0 +1,5 @@ +defaults: + - image_video_default + +dataset_mix: + - "ego4d_1in150" diff --git a/theia/configs/dataset/epic_kitchen.yaml b/theia/configs/dataset/epic_kitchen.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d93b4d5fa488ee0689d964e42969c8c55e7a0971 --- /dev/null +++ b/theia/configs/dataset/epic_kitchen.yaml @@ -0,0 +1,5 @@ +defaults: + - image_video_default + +dataset_mix: + - "epic_kitchen_1in60" diff --git a/theia/configs/dataset/image_video_default.yaml b/theia/configs/dataset/image_video_default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2f7754ba06e96322b377aed7f28830ffd971365a --- /dev/null +++ b/theia/configs/dataset/image_video_default.yaml @@ -0,0 +1,7 @@ +return_metadata: False +shuffle: True +shuffle_buffer_size: 1024 +feature_norm: True +dataset_root: "/storage/nfs/datasets/jshang/" +dataset_ratio: 0.1 +load_action: False diff --git a/theia/configs/dataset/image_video_mix.yaml b/theia/configs/dataset/image_video_mix.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e842cb9690501a1e021bfb27b4989fc3aef930dc --- /dev/null +++ b/theia/configs/dataset/image_video_mix.yaml @@ -0,0 +1,8 @@ +defaults: + - image_video_default + +dataset_mix: + - "ego4d_1in150" + - "ssv2_1in32" + - "epic_kitchen_1in60" + - "imagenet" diff --git a/theia/configs/dataset/imagenet.yaml b/theia/configs/dataset/imagenet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc4513fcb01f324033a662c459a1c53fcf8dd82e --- /dev/null +++ b/theia/configs/dataset/imagenet.yaml @@ -0,0 +1,5 @@ +defaults: + - image_video_default + +dataset_mix: + - "imagenet" diff --git a/theia/configs/dataset/oxe_octo_mix.yaml b/theia/configs/dataset/oxe_octo_mix.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d77cde0cf133f6c69b0c30343ddfd5aa9c5475e1 --- /dev/null +++ b/theia/configs/dataset/oxe_octo_mix.yaml @@ -0,0 +1,12 @@ +_target_: dataset.oxe.oxe_data_utils.OXEDataset +dataset_mix: "oxe_magic_soup" +image_action_set_root: "/storage/nfs/datasets/jshang/oxe_image_action" +feature_set_root: "/storage/nfs/datasets/jshang/oxe_vfm_features" +image_views: null +split: "train" +data_portion: 0.01 +load_action: False +bf16: True +safe_tensors: True +trajectory_subsample_len: 32 +return_metadata: False diff --git a/theia/configs/dataset/ssv2.yaml b/theia/configs/dataset/ssv2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29f00acd2f741aaf55036df9c70cd22410367211 --- /dev/null +++ b/theia/configs/dataset/ssv2.yaml @@ -0,0 +1,5 @@ +defaults: + - image_video_default + +dataset_mix: + - "ssv2_1in32" diff --git a/theia/configs/logging/default.yaml b/theia/configs/logging/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..73271701f2fb98479bc7ab2774f8a2ca9830b1e0 --- /dev/null +++ b/theia/configs/logging/default.yaml @@ -0,0 +1,6 @@ +model_path: "/storage/nfs/jshang/trained_models" +log_path: "/storage/nfs/jshang/logs" +save_ckpt_interval: 20000 +notes: "" +run_identifier_prefix: "" +project: "theia" diff --git a/theia/configs/model/backbone/deit.yaml b/theia/configs/model/backbone/deit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f2ff8291a73dc05d9a3542eff065314019bc07c --- /dev/null 
+++ b/theia/configs/model/backbone/deit.yaml @@ -0,0 +1,2 @@ +backbone: facebook/deit-small-patch16-224 +pretrained: False diff --git a/theia/configs/model/backbone/deit_nocls.yaml b/theia/configs/model/backbone/deit_nocls.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0ddada1a029a866d6b9848a35fd746653bbfb3c --- /dev/null +++ b/theia/configs/model/backbone/deit_nocls.yaml @@ -0,0 +1,2 @@ +backbone: nocls-facebook/deit-tiny-patch16-224 +pretrained: False diff --git a/theia/configs/model/backbone/deit_reg.yaml b/theia/configs/model/backbone/deit_reg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ee8f155401453bdb065d506cd1aaad58bf22e87 --- /dev/null +++ b/theia/configs/model/backbone/deit_reg.yaml @@ -0,0 +1,3 @@ +backbone: reg-facebook/deit-tiny-patch16-224 +pretrained: False +num_reg_tokens: 7 diff --git a/theia/configs/model/translator/conv.yaml b/theia/configs/model/translator/conv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4d78b80229b387be57abd33c69f0b203c0d69fd8 --- /dev/null +++ b/theia/configs/model/translator/conv.yaml @@ -0,0 +1,3 @@ +type: "conv" +kwargs: + translator_hidden_size: 1024 diff --git a/theia/configs/model/translator/lconv.yaml b/theia/configs/model/translator/lconv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3abaef5f4dee8a99c8aa281780b0aa7913b7c605 --- /dev/null +++ b/theia/configs/model/translator/lconv.yaml @@ -0,0 +1,3 @@ +type: "lconv" +kwargs: + hidden_size_factor: 1.0 diff --git a/theia/configs/model/translator/mlp.yaml b/theia/configs/model/translator/mlp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a8e2b66e5bdded40dbba486be72cb7055d95d89f --- /dev/null +++ b/theia/configs/model/translator/mlp.yaml @@ -0,0 +1,4 @@ +type: "mlp" +kwargs: + translator_n_layer: 3 + hidden_size: 1024 diff --git a/theia/configs/model/translator/transformer.yaml b/theia/configs/model/translator/transformer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..604d06434a8d36e51899cc0a4bc9a495b063c606 --- /dev/null +++ b/theia/configs/model/translator/transformer.yaml @@ -0,0 +1,5 @@ +type: "transformer" +kwargs: + translator_n_layers: 2 + translator_n_heads: 8 + translator_hidden_size: 1024 diff --git a/theia/configs/train_rvfm_imagenet.yaml b/theia/configs/train_rvfm_imagenet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..476f2f0a77e54febbd0bcff8db2d2efdfb1334a8 --- /dev/null +++ b/theia/configs/train_rvfm_imagenet.yaml @@ -0,0 +1,9 @@ +defaults: + - dataset: imagenet + - model/backbone: deit + - model/translator: lconv + - training: frame_level + - logging: default + - _self_ + +seed: 0 diff --git a/theia/configs/training/frame_level.yaml b/theia/configs/training/frame_level.yaml new file mode 100644 index 0000000000000000000000000000000000000000..34748e30e3e166d87e41aced2a2cb56974fdd934 --- /dev/null +++ b/theia/configs/training/frame_level.yaml @@ -0,0 +1,35 @@ +defaults: + - target_models: cdiv + +epochs: 50 +warm_up_steps_ratio: 0.1 + +base_lr: 2e-3 +batch_size: 16 +random_target_models: -1 +num_workers: 8 +# base training settings to scale lr, rarely changed +base_batch_size: 64 +base_world_size: 8 + +weight_decay: 0.01 + + +optimizer: + _target_: torch.optim.AdamW + betas: [0.9, 0.999] + +lr_scheduler: + _target_: theia.lr_schedulers.get_constant_lrs_with_linear_warm_up + warm_up_lr_start_factor: 1e-2 + + +grad_clip: False +grad_clip_norm_warmup: 10.0 +grad_clip_norm: 1.0 + 
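+# NOTE (editor assumption, not taken from this config or the training code): the
+# base_batch_size / base_world_size keys above suggest linear LR scaling, e.g.
+# lr = base_lr * (batch_size * world_size) / (base_batch_size * base_world_size);
+# with batch_size=16 on 8 GPUs that would give 2e-3 * (16 * 8) / (64 * 8) = 5e-4.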
+freeze_translator: False +freeze_translator_start_steps_ratio: 0.2 +translator_lr_factor: 1.0 + +main_loss: cos_l1 diff --git a/theia/configs/training/target_models/cdds.yaml b/theia/configs/training/target_models/cdds.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f3d1e92b2f22d27c4806d76af7b3baa58952a5cd --- /dev/null +++ b/theia/configs/training/target_models/cdds.yaml @@ -0,0 +1,6 @@ +target_model_names: + - "facebook/dinov2-large" + - "openai/clip-vit-large-patch14" + - "facebook/sam-vit-huge" + - "LiheYoung/depth-anything-large-hf" +target_model_weights: null diff --git a/theia/configs/training/target_models/cddsv.yaml b/theia/configs/training/target_models/cddsv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f037a78518a7f2d03a1bf24068bff38ed6581839 --- /dev/null +++ b/theia/configs/training/target_models/cddsv.yaml @@ -0,0 +1,7 @@ +target_model_names: + - "google/vit-huge-patch14-224-in21k" + - "facebook/dinov2-large" + - "openai/clip-vit-large-patch14" + - "facebook/sam-vit-huge" + - "LiheYoung/depth-anything-large-hf" +target_model_weights: null diff --git a/theia/configs/training/target_models/cddv.yaml b/theia/configs/training/target_models/cddv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..13db812c5cde279865f0c08677023c93ceb709fd --- /dev/null +++ b/theia/configs/training/target_models/cddv.yaml @@ -0,0 +1,6 @@ +target_model_names: + - "google/vit-huge-patch14-224-in21k" + - "facebook/dinov2-large" + - "openai/clip-vit-large-patch14" + - "LiheYoung/depth-anything-large-hf" +target_model_weights: null diff --git a/theia/configs/training/target_models/cdesv.yaml b/theia/configs/training/target_models/cdesv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5696b402f914f5cbf7a35cd10c3991ff1445185e --- /dev/null +++ b/theia/configs/training/target_models/cdesv.yaml @@ -0,0 +1,6 @@ +target_model_names: + - "google/vit-huge-patch14-224-in21k" + - "openai/clip-vit-large-patch14" + - "facebook/sam-vit-huge" + - "LiheYoung/depth-anything-large-hf" +target_model_weights: null diff --git a/theia/configs/training/target_models/cdis.yaml b/theia/configs/training/target_models/cdis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f298afcca8ef9bb6b6c58b972b5987a33090bec1 --- /dev/null +++ b/theia/configs/training/target_models/cdis.yaml @@ -0,0 +1,5 @@ +target_model_names: + - "facebook/dinov2-large" + - "openai/clip-vit-large-patch14" + - "facebook/sam-vit-huge" +target_model_weights: null diff --git a/theia/configs/training/target_models/cdisv.yaml b/theia/configs/training/target_models/cdisv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f2e8f585a3428486842fa9538cbad26a498f2afd --- /dev/null +++ b/theia/configs/training/target_models/cdisv.yaml @@ -0,0 +1,6 @@ +target_model_names: + - "google/vit-huge-patch14-224-in21k" + - "facebook/dinov2-large" + - "openai/clip-vit-large-patch14" + - "facebook/sam-vit-huge" +target_model_weights: null diff --git a/theia/configs/training/target_models/cdiv.yaml b/theia/configs/training/target_models/cdiv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a9130d9ecf8ee16537af8e6037cd091fd839df7b --- /dev/null +++ b/theia/configs/training/target_models/cdiv.yaml @@ -0,0 +1,5 @@ +target_model_names: + - "google/vit-huge-patch14-224-in21k" + - "facebook/dinov2-large" + - "openai/clip-vit-large-patch14" +target_model_weights: null diff --git 
a/theia/configs/training/target_models/clip.yaml b/theia/configs/training/target_models/clip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f2a31815bae9bfda0859bedd7854b19e6d276961 --- /dev/null +++ b/theia/configs/training/target_models/clip.yaml @@ -0,0 +1,3 @@ +target_model_names: + - "openai/clip-vit-large-patch14" +target_model_weights: null diff --git a/theia/configs/training/target_models/ddsv.yaml b/theia/configs/training/target_models/ddsv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..df5197be300fc301f39b847f7e88d8df78b3ca1a --- /dev/null +++ b/theia/configs/training/target_models/ddsv.yaml @@ -0,0 +1,6 @@ +target_model_names: + - "google/vit-huge-patch14-224-in21k" + - "facebook/dinov2-large" + - "facebook/sam-vit-huge" + - "LiheYoung/depth-anything-large-hf" +target_model_weights: null diff --git a/theia/configs/training/target_models/depth_anything.yaml b/theia/configs/training/target_models/depth_anything.yaml new file mode 100644 index 0000000000000000000000000000000000000000..57367c3cb527860d348c501e70abc009bdd2fa40 --- /dev/null +++ b/theia/configs/training/target_models/depth_anything.yaml @@ -0,0 +1,3 @@ +target_model_names: + - "LiheYoung/depth-anything-large-hf" +target_model_weights: null diff --git a/theia/configs/training/target_models/dinov2.yaml b/theia/configs/training/target_models/dinov2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8aca1f2452e022512dde25ffa416a613dbb4734 --- /dev/null +++ b/theia/configs/training/target_models/dinov2.yaml @@ -0,0 +1,3 @@ +target_model_names: + - "facebook/dinov2-large" +target_model_weights: null diff --git a/theia/configs/training/target_models/sam.yaml b/theia/configs/training/target_models/sam.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9c9c4a53ca57b399fc5f03f00bfb449906a6f4e9 --- /dev/null +++ b/theia/configs/training/target_models/sam.yaml @@ -0,0 +1,3 @@ +target_model_names: + - "facebook/sam-vit-huge" +target_model_weights: null diff --git a/theia/configs/training/target_models/vit.yaml b/theia/configs/training/target_models/vit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8ae383eb0d8dd87a25658e900adaaa7680c22e91 --- /dev/null +++ b/theia/configs/training/target_models/vit.yaml @@ -0,0 +1,3 @@ +target_model_names: + - "google/vit-huge-patch14-224-in21k" +target_model_weights: null diff --git a/theia/dataset/__init__.py b/theia/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4216d210244524d8b1abec2ad5fddd19f068faa0 --- /dev/null +++ b/theia/dataset/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from .image.image_common import ALL_IMAGE_DATASETS +from .oxe.oxe_common import ALL_OXE_DATASETS +from .video.video_common import ALL_VIDEO_DATASETS diff --git a/theia/dataset/data_utils.py b/theia/dataset/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4edf36c441220172610b46039cafdbe4d3156643 --- /dev/null +++ b/theia/dataset/data_utils.py @@ -0,0 +1,591 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +"""Defines PyTorch datasets of dataloaders for multiple image, video, and OXE datasets. +Should use with webdataset >= 0.2.90. 
See https://github.com/webdataset/webdataset/pull/347""" + +import glob +import json +import math +import os.path as osp +from collections import OrderedDict +from functools import partial +from io import BytesIO +from typing import Any, Callable, Generator, Iterator, Literal, Optional + +import cv2 +import numpy as np +import omegaconf +import torch +import webdataset as wds +from datasets.combine import DatasetType +from einops import rearrange +from numpy.typing import NDArray +from safetensors.torch import load as sft_load +from torch import default_generator +from torch.utils.data import DataLoader, Dataset, IterableDataset, default_collate + +from theia.foundation_models.common import MODELS +from theia.dataset.oxe.oxe_common import ALL_OXE_DATASETS +from theia.dataset.oxe.oxe_mixes import OXE_NAMED_MIXES + +PACKED_FEATURES = [model_name for model_name in MODELS if "llava" not in model_name] + + +def normalize_ds_weights_by_ds_len(weights: list[float], lengths: list[int]) -> tuple[list[float], float | Literal[0]]: + """Normalize dataset weights by dataset lengths (frames). + + Args: + weights (list[float]): assigned weights. + lengths (list[int]): lengths of datasets. + + Returns: + tuple[list[float], int]: normalized weights, and sum of the expected lengths of datasets + """ + expected_lengths = [weight * length for weight, length in zip(weights, lengths, strict=False)] + sum_expected_lengths = sum(expected_lengths) + if sum_expected_lengths == 0: + raise ValueError("Sum of dataset length is 0.") + normalized_weights = [length * 1.0 / sum_expected_lengths for length in expected_lengths] + return normalized_weights, sum_expected_lengths + + +def get_vo_keys(dataset_name: str, image_views: Optional[list | str | dict[str, str | list[str]]] = None) -> list[str]: + """Get visual observation keys of datasets (to be compatible with OXE). + + Args: + dataset_name (str): name of the dataset. + image_views (Optional[dict[str, str | list[str]]], optional): keys of selected views. + Defaults to None. + + Returns: + list[str]: keys to the views in the dataset. + """ + default_visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"][:1] + visual_observation_keys = [] + if image_views is None: + visual_observation_keys = default_visual_observation_keys + elif isinstance(image_views, list): + visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"] + elif isinstance(image_views, str): + if image_views == "static": + visual_observation_keys = [ + k + for k in ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"] + if "wrist" not in k and "hand" not in k + ] + elif image_views == "wrist": + visual_observation_keys = [ + k for k in ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"] if "wrist" in k or "hand" in k + ] + if len(visual_observation_keys) == 0: + visual_observation_keys = default_visual_observation_keys + return visual_observation_keys + + +class RandomMix(IterableDataset): + """A random interleave of multiple iterable datasets.""" + + def __init__( + self, + datasets: list[IterableDataset], + probs: list[float] | NDArray | None = None, + stopping_strategy: str = "all_exhausted", + seed: Optional[int | str] = 0, + ) -> None: + """Initialization of a random interleave dataset. + + Args: + datasets (list[IterableDataset]): datasets to be interleaved. + probs (list[float] | NDArray, optional): probability of each dataset. Defaults to None. + stopping_strategy (str, optional): when to end the sampling for one epoch. 
Defaults to `all_exhausted`. + `all_exhausted`: each sample in the dataset will be sampled at least once. + `first_exhausted`: when the first dataset is ran out, this episode ends. + See also https://huggingface.co/docs/datasets/en/stream#interleave for definitions. + seed (Optional[int | str]): seed. Defaults to 0. + """ + self.datasets = datasets + if probs is None: + self.probs = [1.0] * len(self.datasets) + elif isinstance(probs, np.ndarray): + self.probs = probs.tolist() + else: + self.probs = probs + self.stopping_strategy = stopping_strategy + self.seed = seed + + def __iter__(self) -> Generator: + """Return an iterator over the sources.""" + sources = [iter(d) for d in self.datasets] + probs = self.probs[:] + seed_gen = torch.Generator() + seed_gen.manual_seed(self.seed) + cum = (np.array(probs) / np.sum(probs)).cumsum() + while len(sources) > 0: + r = torch.rand(1, generator=seed_gen).item() + i = np.searchsorted(cum, r) + try: + yield next(sources[i]) + except StopIteration: + if self.stopping_strategy == "all_exhausted": + del sources[i] + del probs[i] + cum = (np.array(probs) / np.sum(probs)).cumsum() + elif self.stopping_strategy == "first_exhausted": + break + + +def decode_sample( + key: str, data: bytes, image_transform: Optional[Callable] = None, feature_transform: Optional[Callable] = None +) -> Any: + """Decode a sample from bytes with optional image and feature transforms + + Args: + key (str): key of an attribute (a column) of the sample. + data (bytes): original data bytes. + image_transform (Optional[Callable], optional): image transform. Defaults to None. + feature_transform (Optional[Callable], optional): feature transform. Defaults to None. + + Returns: + Any: decoded data. + """ + if ".safetensors" in key: + sft = sft_load(data) + embedding = rearrange(sft["embedding"], "c h w -> (h w) c") + if feature_transform is not None: + embedding = feature_transform(embedding) + if "cls_token" in sft: + cls = sft["cls_token"] + if feature_transform is not None: + cls = feature_transform(cls) + return {"embedding": embedding, "cls": cls} + return {"embedding": embedding} + elif key == ".image": + image = np.load(BytesIO(data)) + if len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + elif len(image.shape) == 3 and image.shape[-1] == 4: + image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) + if image_transform is not None: + return image_transform(image) + return image + else: + return data + + +def get_oxe_frame_dataset( + dataset_root: str, + dataset_mix: Optional[str | dict[str, float] | list] = "oxe_magic_soup", + feature_models: Optional[list[str]] = None, + split: str = "train", + dataset_ratio: float = 1.0, + image_views: Optional[dict[str, str | list[str]]] = None, + image_transform: Optional[Callable[[Any], torch.Tensor]] = None, + seed: Optional[int | str] = 0, + shuffle: bool = False, + world_size: int = 1, +) -> tuple[dict[str, DatasetType], float | Literal[0]]: + """Get OXE datasets at frame level. + + Args: + dataset_root (str): root dir of the datasets. + dataset_mix (Optional[str | dict[str, float] | list], optional): how to mix the datasets. + Defaults to "oxe_magic_soup". + feature_models (Optional[list[str]], optional): models to load their features. Defaults to None. + split (str, optional): split "train" or "val" or "test". Defaults to "train". + dataset_ratio (float, optional): how much data use for the (combined) dataset. Defaults to 1.0. + image_views (Optional[dict[str, str | list[str]]], optional): image views to select. 
Defaults to None. + image_transform (Optional[Callable[[Any], torch.Tensor]], optional): image transform applied to samples. + Defaults to None. + seed (Optional[int | str], optional): seed. Defaults to 0. + shuffle (bool, optional): shuffle or not. Defaults to False. + world_size (int, optional): world size of DDP training. Defaults to 1. + + Returns: + tuple[dict[str, DatasetType], int]: a dict of {dataset name: dataset class}. + """ + # read dataset mix from any acceptable form + if isinstance(dataset_mix, str) and dataset_mix in OXE_NAMED_MIXES: + dataset_mix = OrderedDict({k: v for k, v in OXE_NAMED_MIXES[dataset_mix]}) + elif isinstance(dataset_mix, dict): + dataset_mix = OrderedDict(**dataset_mix) + elif isinstance(dataset_mix, list): + dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix}) + else: + raise ValueError(f"dataset_mix of {dataset_mix}:{type(dataset_mix)} is not supported.") + + if split == "eval" or split == "val": + dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix}) + + # note down the dataset weights + dataset_weights: list[float] = [] + # get frame level length + dataset_lens: list[int] = [] + + all_feature_datasets: dict[str, DatasetType] = {} + for dataset in dataset_mix: + visual_observation_keys = get_vo_keys(dataset_name=dataset, image_views=image_views) + + if feature_models is None: + feature_models = PACKED_FEATURES + + with open(osp.join(dataset_root, dataset, "splits.json"), "r") as splitf: + dataset_len = json.load(splitf)[split] + # if the length is 0, skip + # this may happen for small datasets with very few shards + if dataset_len == 0: + continue + + for vo_key in visual_observation_keys: + for model_name in feature_models: + if model_name not in PACKED_FEATURES: + feature_set_name = model_name + path_pattern = osp.join( + dataset_root, dataset, vo_key + f"_{model_name.replace('/', '_')}", f"*-{split}*.tar" + ) + rename_kw = {model_name: model_name.replace("/", "_") + ".safetensors"} # replace v by k + elif "packed" in all_feature_datasets: + continue + else: + feature_set_name = "packed" + path_pattern = osp.join(dataset_root, dataset, vo_key, f"*-{split}*.tar") + rename_kw = { + name: name.replace("/", "_") + ".safetensors" for name in PACKED_FEATURES + } # replace v by k + rename_kw["image"] = "image" + + if feature_set_name not in all_feature_datasets: + all_feature_datasets[feature_set_name] = [] + + shard_paths = sorted(glob.glob(path_pattern)) + num_shards = len(shard_paths) + if num_shards < world_size * 8: + shard_paths *= math.ceil(world_size * 8 / num_shards) + ds = ( + wds.WebDataset( + shard_paths, + nodesplitter=wds.split_by_node, + workersplitter=wds.split_by_worker, + detshuffle=True, + shardshuffle=shuffle, + seed=seed, + ) + .decode(partial(decode_sample, image_transform=image_transform)) + .rename(keep=False, **rename_kw) + ) + all_feature_datasets[feature_set_name].append(ds) + + dataset_weights.append(dataset_mix[dataset]) + dataset_lens.append(math.ceil(dataset_len * dataset_ratio)) + + normalized_dataset_weights, sum_expected_lengths = normalize_ds_weights_by_ds_len(dataset_weights, dataset_lens) + + combined_feature_datasets: dict[str, Dataset] = {} + for feature_set_name, fds in all_feature_datasets.items(): + ds = RandomMix(fds, probs=normalized_dataset_weights, stopping_strategy="all_exhausted") + combined_feature_datasets[feature_set_name] = ds + + return combined_feature_datasets, sum_expected_lengths + + +def get_oxe_frame_dataloader( + datasets: dict[str, DatasetType], batch_size: Optional[int] = None, 
shuffle_buffer_size: int = 1_000, **kwargs: Any +) -> dict[str, DataLoader]: + """Get dataloaders of OXE datasets. Corresponding to `get_oxe_frame_dataset()`. + + Args: + datasets (dict[str, DatasetType]): OXE datasets from `get_oxe_frame_dataset()`. + batch_size (Optional[int], optional): batch size. Defaults to None. + shuffle_buffer_size (int, optional): buffer for shuffle while streaming. Defaults to 1_000. + + Returns: + dict[str, DataLoader]: dataloaders. a dict of {dataset name: dataloader}. + """ + loaders = { + k: ( + wds.WebLoader(datasets[k], batch_size=None, **kwargs) + .shuffle(shuffle_buffer_size) # shuffle after mix + .batched(batch_size, collation_fn=default_collate) + ) + for k in datasets + } + return loaders + + +def get_oxe_frame_iterator( + data_loaders: dict[str, DataLoader], +) -> Iterator[dict[str, Any]]: + """Get iterator from dataloaders. Corresponding to `get_oxe_frame_dataloader()`. + + Args: + data_loaders (dict[str, DataLoader]): dataloaders from `get_oxe_frame_dataloader()`. + + Yields: + Iterator[dict[str, Any]]: data sample. + """ + packed_loader = data_loaders.get("packed", None) + # place packed_loader at the first + if packed_loader is not None: + loaders = [packed_loader, *[data_loaders[k] for k in data_loaders if k != "packed"]] + else: + loaders = list(data_loaders.values()) + + # merge dicts + for data in zip(*loaders, strict=False): + # yield data + for i in range(1, len(loaders)): + for k in data[i]: + if k not in data[0]: + data[0][k] = data[i][k] + yield data[0] + + +def normalize_feature( + x: torch.Tensor, mean: Optional[torch.Tensor] = None, std: Optional[torch.Tensor] = None +) -> torch.Tensor: + """Normalize the feature given mean and std. + + Args: + x (torch.Tensor): input features + mean (Optional[torch.Tensor], optional): mean values. Defaults to None. + std (Optional[torch.Tensor], optional): std values. Defaults to None. + + Returns: + torch.Tensor: feature after normalization + """ + return x if mean is None or std is None else (x - mean) / std + + +def load_feature_stats( + dataset_root: str, feature_models: list[str] +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + """Load feature statistics (mean and variance). + + Args: + dataset_root (str): root dir of the dataset (or where to hold the statistics). + feature_models (list[str]): names of the models/features. + + Returns: + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: means and variances. Keys are model names. + """ + feature_means: dict[str, torch.Tensor] = {} + feature_vars: dict[str, torch.Tensor] = {} + for model in feature_models: + model_name = model.replace("/", "_") + feature_means[model] = torch.from_numpy(np.load(osp.join(dataset_root, f"imagenet_mean_{model_name}.npy"))).to( + torch.bfloat16 + ) + feature_vars[model] = torch.from_numpy(np.load(osp.join(dataset_root, f"imagenet_var_{model_name}.npy"))).to( + torch.bfloat16 + ) + return feature_means, feature_vars + + +def pad_shard_paths(shard_paths: list[str], num_shards: int, num_parts: int) -> list[str]: + """Pad shard paths to be divisible by the number of partitions (ranks*nodes). + + Args: + shard_paths (list[str]): paths of dataset shards. + num_shards (int): number of shards. + num_parts (int): number of partitions. + + Returns: + list[str]: shard paths padded.
+ """ + final_shard_paths = shard_paths + if num_shards % num_parts != 0: + if num_shards < num_parts - num_shards: + for _ in range(math.floor((num_parts - num_shards) / num_shards)): + final_shard_paths += shard_paths[:] + final_shard_paths += shard_paths[: num_parts - len(final_shard_paths)] + else: + final_shard_paths += shard_paths[: num_parts - len(final_shard_paths)] + return final_shard_paths + + +def get_image_video_dataset( + dataset_root: str, + feature_models: list[str], + dataset_mix: Optional[str | dict[str, float] | list] = None, + split: str = "train", + dataset_ratio: float = 1.0, + image_transform: Optional[Callable[[Any], torch.Tensor]] = None, + feature_norm: bool = False, + seed: Optional[int | str] = 0, + shuffle: bool = False, + world_size: int = 1, + **kwargs: Any, +) -> tuple[dict[str, DatasetType], float | Literal[0]]: + """Get image and video datasets at frame level. + + Args: + dataset_root (str): root dir of the datasets. + feature_models (list[str]): models to load their features. + dataset_mix (Optional[str | dict[str, float] | list], optional): how to mix the datasets. + split (str, optional): split "train" or "val" or "test". Defaults to "train". + dataset_ratio (float, optional): how much data use for the (combined) dataset. Defaults to 1.0. + image_transform (Optional[Callable[[Any], torch.Tensor]], optional): image transform applied to samples. + Defaults to None. + feature_norm: (bool, optional): whether to normalize the feature. Defaults to False. + seed (Optional[int | str], optional): seed. Defaults to 0. + shuffle (bool, optional): shuffle or not. Defaults to False. + world_size (int, optional): world size of DDP training. Defaults to 1. + kwargs (Any): arguments to pass-through. + + Returns: + tuple[dict[str, DatasetType], int]: a dict of {dataset name: dataset class}. 
+ """ + # read dataset mix from any acceptable form + if isinstance(dataset_mix, str) and dataset_mix in OXE_NAMED_MIXES: + dataset_mix = OrderedDict({k: v for k, v in OXE_NAMED_MIXES[dataset_mix]}) + elif isinstance(dataset_mix, dict): + dataset_mix = OrderedDict(**dataset_mix) + elif isinstance(dataset_mix, list) or isinstance(dataset_mix, omegaconf.listconfig.ListConfig): + dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix}) + else: + raise ValueError(f"dataset_mix of {dataset_mix}:{type(dataset_mix)} is not supported.") + + if split == "eval" or split == "val": + dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix}) + + # note down the dataset weights + dataset_weights: list[float] = [] + # get frame level length + dataset_lens: list[int] = [] + + all_feature_datasets: dict[str, DatasetType] = {} + + if feature_norm: + feature_means, feature_vars = load_feature_stats(dataset_root, feature_models) + + for d in dataset_mix: + + with open(osp.join(dataset_root, d, "splits.json"), "r") as splitf: + dataset_len = json.load(splitf)[split] + + # if the length is 0, skip + # this may happen for small datasets with very few shards + if dataset_len == 0: + continue + + path_pattern = osp.join(dataset_root, d, "images", f"*-{split}.tar") + if "image" not in all_feature_datasets: + all_feature_datasets["image"] = [] + shard_paths = sorted(glob.glob(path_pattern)) + num_shards = len(shard_paths) + num_parts = world_size + final_shard_paths = pad_shard_paths(shard_paths, num_shards, num_parts) + ds = wds.WebDataset( + final_shard_paths, + nodesplitter=wds.split_by_node, + workersplitter=wds.split_by_worker, + detshuffle=True, + shardshuffle=shuffle, + seed=seed, + ).decode(partial(decode_sample, image_transform=image_transform)) + all_feature_datasets["image"].append(ds) + + for model_name in feature_models: + path_pattern = osp.join(dataset_root, d, f"{model_name.replace('/', '_')}", f"*-{split}.tar") + rename_kw = {model_name: model_name.replace("/", "_").lower() + ".safetensors"} # replace v by k + + if model_name not in all_feature_datasets: + all_feature_datasets[model_name] = [] + + shard_paths = sorted(glob.glob(path_pattern)) + num_shards = len(shard_paths) + num_parts = world_size + final_shard_paths = pad_shard_paths(shard_paths, num_shards, num_parts) + if feature_norm: + feature_transform = partial( + normalize_feature, mean=feature_means[model_name], std=feature_vars[model_name] + ) + else: + feature_transform = None + ds = ( + wds.WebDataset( + final_shard_paths, + nodesplitter=wds.split_by_node, + workersplitter=wds.split_by_worker, + detshuffle=True, + shardshuffle=shuffle, + seed=seed, + ) + .decode(partial(decode_sample, image_transform=image_transform, feature_transform=feature_transform)) + .rename(keep=False, **rename_kw) + ) + all_feature_datasets[model_name].append(ds) + + dataset_weights.append(dataset_mix[d]) + dataset_lens.append(math.ceil(dataset_len * dataset_ratio)) + + normalized_dataset_weights, sum_expected_lengths = normalize_ds_weights_by_ds_len(dataset_weights, dataset_lens) + + combined_feature_datasets: dict[str, Dataset] = {} + for feature_set_name, fds in all_feature_datasets.items(): + ds = RandomMix(fds, probs=normalized_dataset_weights, stopping_strategy="all_exhausted", seed=seed) + combined_feature_datasets[feature_set_name] = ds + + return combined_feature_datasets, sum_expected_lengths + + +def get_frame_dataloader( + datasets: dict[str, DatasetType], + batch_size: Optional[int] = None, + shuffle: bool = False, + shuffle_buffer_size: int = 1_000, 
+ seed: Optional[int] = 0, + **kwargs: Any, +) -> dict[str, DataLoader]: + """Get dataloaders of image and video datasets. Corresponding to `get_image_video_dataset()`. + + Args: + datasets (dict[str, DatasetType]): image and video datasets from `get_image_video_dataset(). + batch_size (Optional[int], optional): batch size. Defaults to None. + shuffle_buffer_size (int, optional): buffer for shuffle while streaming. Defaults to 1_000. + + Returns: + dict[str, DataLoader]: dataloaders. a dict of {dataset name: dataloader}. + """ + loaders = {} + for k in datasets: + loader = wds.WebLoader(datasets[k], batch_size=None, generator=default_generator, **kwargs) + if shuffle: + loader = loader.shuffle(shuffle_buffer_size, seed=seed) # shuffle after mix + loader = loader.batched(batch_size, collation_fn=default_collate) + loaders[k] = loader + return loaders + + +def get_frame_iterator( + data_loaders: dict[str, DataLoader], +) -> Iterator[dict[str, Any]]: + """Get iterator from image and video dataset dataloders. Corresponding to `get_frame_dataloader()`. + + Args: + data_loaders (dict[str, DataLoader]): dataloaders from `get_frame_dataloader()`. + + Yields: + Iterator[dict[str, Any]]: data sample. + """ + packed_loader = data_loaders.get("packed", None) + # place packed_loader at the first + if packed_loader is not None: + loaders = [packed_loader, *[data_loaders[k] for k in data_loaders if k != "packed"]] + else: + loaders = list(data_loaders.values()) + + # merge dicts + # this is to accommodate the old organization of datasets (each shard contains one or more columns, + # and images are duplicated columns). + # In new (current) dataset organization (columns are completely separated), + # column keys are all different except some "built-in" keys added by webdataset, + # but they are not related to any data, training, so on. + # During transit from old to new, where two organizations exist at the same time, + # this is to ignore extra "image" field in datasets loaded. + for data in zip(*loaders, strict=False): + # yield data + for i in range(1, len(loaders)): + for k in data[i]: + if k not in data[0]: + data[0][k] = data[i][k] + yield data[0] diff --git a/theia/dataset/image/__init__.py b/theia/dataset/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81d96b9788b4242dec0396c89f7cd1037cffcb05 --- /dev/null +++ b/theia/dataset/image/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from .image_common import ALL_IMAGE_DATASETS diff --git a/theia/dataset/image/image_common.py b/theia/dataset/image/image_common.py new file mode 100644 index 0000000000000000000000000000000000000000..dd657b3f8c00528109448433f0963b1a85efd89c --- /dev/null +++ b/theia/dataset/image/image_common.py @@ -0,0 +1,5 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from collections import OrderedDict + +ALL_IMAGE_DATASETS = OrderedDict({"imagenet": {"steps": 1_281_167}}) diff --git a/theia/dataset/oxe/__init__.py b/theia/dataset/oxe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9 --- /dev/null +++ b/theia/dataset/oxe/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. 
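The frame-level loading utilities in data_utils.py above (dataset construction, per-column dataloaders, and the merging iterator) are meant to be chained. The sketch below shows the intended composition; the dataset root, the model list, and the exact batch keys are illustrative assumptions, not values taken from this diff.

```python
# Minimal usage sketch for the image/video pipeline above; paths, the target
# model list, and the batch keys are assumptions for illustration only.
import torch
from torchvision.transforms.v2 import Compose, Normalize, ToDtype, ToImage

from theia.dataset.data_utils import (
    get_frame_dataloader,
    get_frame_iterator,
    get_image_video_dataset,
)

# ImageNet-style normalization, mirroring the transform defined in oxe_transforms.py.
image_transform = Compose(
    [ToImage(), ToDtype(torch.float32, scale=True),
     Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)

# One webdataset pipeline per column ("image" plus one per target model),
# interleaved across datasets by RandomMix with length-normalized weights.
datasets, steps_per_epoch = get_image_video_dataset(
    dataset_root="/path/to/datasets",          # assumed location of shards + splits.json
    feature_models=["facebook/dinov2-large"],  # any subset of the target models
    dataset_mix=["imagenet"],
    split="train",
    image_transform=image_transform,
    feature_norm=False,                        # True also requires the imagenet_mean/var .npy stats
    shuffle=True,
    world_size=1,
)

loaders = get_frame_dataloader(datasets, batch_size=16, shuffle=True, num_workers=0)

# get_frame_iterator zips the per-column loaders and merges each step into a
# single dict, so one batch carries the image together with the teacher features.
for batch in get_frame_iterator(loaders):
    images = batch["image"]                             # decoded + transformed frames
    dino = batch["facebook/dinov2-large"]["embedding"]  # target features after rename/decode
    break
```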
diff --git a/theia/dataset/oxe/oxe_common.py b/theia/dataset/oxe/oxe_common.py new file mode 100644 index 0000000000000000000000000000000000000000..36c7ffed4eaf81c2a58d2d717df0a86434b4ef8e --- /dev/null +++ b/theia/dataset/oxe/oxe_common.py @@ -0,0 +1,430 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from collections import OrderedDict +from typing import Optional + +""" +This ALL_OXE_DATASETS below records metadata of all subsets of OXE dataset. +The datasets are in alphabetical order. + +versions (list[str]): available and usable versions, sorted from older to newer. + Usually use the last one. +episodes (int): total episodes in the dataset. +steps (int): total steps in the dataset. +visual_observation_keys (list[str]): keys to specify image observations. +""" +ALL_OXE_DATASETS: OrderedDict = OrderedDict( + { + "agent_aware_affordances": { + "versions": ["1.0.0"], + "episodes": 118, + "steps": 151628, + "visual_observation_keys": ["image"], + }, + "asu_table_top_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 110, + "steps": 26113, + "visual_observation_keys": ["image"], + }, + "austin_buds_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 50, + "steps": 34112, + "visual_observation_keys": ["image", "wrist_image"], + }, + "austin_sailor_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 240, + "steps": 353094, + "visual_observation_keys": ["image", "wrist_image"], + }, + "austin_sirius_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 559, + "steps": 279939, + "visual_observation_keys": ["image", "wrist_image"], + }, + "bc_z": { + "versions": [ + "0.1.0", # "1.0.0", "old1.0.1", and "1.0.1" are not usable + ], + "episodes": 39350, + "steps": 5471693, + "visual_observation_keys": ["image"], + }, + "berkeley_autolab_ur5": { + "versions": ["0.1.0"], + "episodes": 896, + "steps": 87783, + "visual_observation_keys": ["image", "hand_image"], + }, + "berkeley_cable_routing": { + "versions": ["0.1.0"], + "episodes": 1482, + "steps": 38240, + "visual_observation_keys": ["image", "top_image", "wrist225_image", "wrist45_image"], + }, + "berkeley_fanuc_manipulation": { + "versions": ["0.1.0"], + "episodes": 415, + "steps": 62613, + "visual_observation_keys": ["image", "wrist_image"], + }, + "berkeley_gnm_cory_hall": { + "versions": ["0.1.0"], + "episodes": 7331, + "steps": 156012, + "visual_observation_keys": ["image"], + }, + "berkeley_gnm_recon": { + "versions": ["0.1.0"], + "episodes": 11834, + "steps": 610907, + "visual_observation_keys": ["image"], + }, + "berkeley_gnm_sac_son": { + "versions": ["0.1.0"], + "episodes": 2955, + "steps": 241059, + "visual_observation_keys": ["image"], + }, + "berkeley_mvp_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 480, + "steps": 45308, + "visual_observation_keys": ["hand_image"], + }, + "berkeley_rpt_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 908, + "steps": 392578, + "visual_observation_keys": ["hand_image"], + }, + "bridge": {"versions": ["0.1.0"], "episodes": 25460, "steps": 864292, "visual_observation_keys": ["image"]}, + "cmu_franka_exploration_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 199, + "steps": 1990, + "visual_observation_keys": ["image"], + }, + "cmu_play_fusion": { + "versions": ["0.1.0"], + "episodes": 576, + "steps": 235922, + "visual_observation_keys": ["image"], + }, + "cmu_playing_with_food": { # this 
dataset seems to be corrupted + "versions": ["1.0.0"], + "episodes": 4200, + "steps": 83240, + "visual_observation_keys": ["image"], + }, + "cmu_stretch": {"versions": ["0.1.0"], "episodes": 135, "steps": 25016, "visual_observation_keys": ["image"]}, + "columbia_cairlab_pusht_real": { + "versions": ["0.1.0"], + "episodes": 122, + "steps": 24924, + "visual_observation_keys": ["image", "wrist_image"], + }, + "dlr_edan_shared_control_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 104, + "steps": 8928, + "visual_observation_keys": ["image"], + }, + "dlr_sara_grid_clamp_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 107, + "steps": 7622, + "visual_observation_keys": ["image"], + }, + "dlr_sara_pour_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 100, + "steps": 12971, + "visual_observation_keys": ["image"], + }, + "eth_agent_affordances": { + "versions": ["0.1.0"], + "episodes": 118, + "steps": 151628, + "visual_observation_keys": ["image"], + }, + "fanuc_manipulation_v2": { + "versions": ["1.0.0"], + "episodes": 415, + "steps": 62613, + "visual_observation_keys": ["image", "wrist_image"], + }, + "fractal20220817_data": { + "versions": ["0.1.0"], + "episodes": 87212, + "steps": 3786400, + "visual_observation_keys": ["image"], + }, + "furniture_bench_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 5100, + "steps": 3948057, + "visual_observation_keys": ["image", "wrist_image"], + }, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 631, + "steps": 146241, + "visual_observation_keys": ["image", "wrist_image"], + }, + "imperial_wrist_dataset": { + "versions": ["1.0.0"], + "episodes": 170, + "steps": 7148, + "visual_observation_keys": ["image", "wrist_image"], + }, + "imperialcollege_sawyer_wrist_cam": { + "versions": ["0.1.0"], + "episodes": 170, + "steps": 7148, + "visual_observation_keys": ["image", "wrist_image"], + }, + "jaco_play": { + "versions": ["0.1.0"], + "episodes": 976, + "steps": 70127, + "visual_observation_keys": ["image", "image_wrist"], + }, + "kaist_nonprehensile_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 201, + "steps": 32429, + "visual_observation_keys": ["image"], + }, + "kuka": {"versions": ["0.1.0"], "episodes": 580392, "steps": 8583978, "visual_observation_keys": ["image"]}, + "language_table": { + "versions": ["0.0.1", "0.1.0"], + "episodes": 442226, + "steps": 7045476, + "visual_observation_keys": ["rgb"], + }, + "language_table_blocktoabsolute_oracle_sim": { + "versions": ["0.0.1"], + "episodes": 200000, + "steps": 15866385, + "visual_observation_keys": ["rgb"], + }, + "language_table_blocktoblock_4block_sim": { + "versions": ["0.0.1"], + "episodes": 8298, + "steps": 326768, + "visual_observation_keys": ["rgb"], + }, + "language_table_blocktoblock_oracle_sim": { + "versions": ["0.0.1"], + "episodes": 200000, + "steps": 12970620, + "visual_observation_keys": ["rgb"], + }, + "language_table_blocktoblock_sim": { + "versions": ["0.0.1"], + "episodes": 8000, + "steps": 351688, + "visual_observation_keys": ["rgb"], + }, + "language_table_blocktoblockrelative_oracle_sim": { + "versions": ["0.0.1"], + "episodes": 200000, + "steps": 13016749, + "visual_observation_keys": ["rgb"], + }, + "language_table_blocktorelative_oracle_sim": { + "versions": ["0.0.1"], + "episodes": 200000, + "steps": 8655815, + "visual_observation_keys": ["rgb"], + }, + "language_table_separate_oracle_sim": { + "versions": 
["0.0.1"], + "episodes": 200000, + "steps": 3196661, + "visual_observation_keys": ["rgb"], + }, + "language_table_sim": { + "versions": ["0.0.1"], + "episodes": 181020, + "steps": 4665423, + "visual_observation_keys": ["rgb"], + }, + "maniskill_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 30213, + "steps": 4537402, + "visual_observation_keys": ["image", "wrist_image"], + }, + "mutex_dataset": { + "versions": ["1.0.0"], + "episodes": 1500, + "steps": 361883, + "visual_observation_keys": ["image", "wrist_image"], + }, + "nyu_door_opening_surprising_effectiveness": { + "versions": ["0.1.0"], + "episodes": 435, + "steps": 18196, + "visual_observation_keys": ["image"], + }, + "nyu_franka_play_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 365, + "steps": 34448, + "visual_observation_keys": ["image", "image_additional_view"], + }, + "nyu_rot_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 14, + "steps": 440, + "visual_observation_keys": ["image"], + }, + "qut_dexterous_manpulation": { + "versions": ["0.1.0"], + "episodes": 200, + "steps": 176278, + "visual_observation_keys": ["image", "wrist_image"], + }, + "robo_net": { + "versions": ["0.1.0", "1.0.0"], + "episodes": 82775, + "steps": 2483250, + "visual_observation_keys": ["image", "image1", "image2"], + }, + "robot_vqa": { + "versions": ["0.1.0"], + "episodes": 3331523, + "steps": 3331523, + "visual_observation_keys": ["images"], + }, + "roboturk": { + "versions": ["0.1.0"], + "episodes": 1796, + "steps": 168423, + "visual_observation_keys": ["front_rgb"], + }, + "stanford_hydra_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 570, + "steps": 358234, + "visual_observation_keys": ["image", "wrist_image"], + }, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 3000, + "steps": 149985, + "visual_observation_keys": ["image"], + }, + "stanford_mask_vit_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 9109, + "steps": 282379, + "visual_observation_keys": ["image"], + }, + "stanford_robocook_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 2460, + "steps": 112980, + "visual_observation_keys": ["image_1", "image_2", "image_3", "image_4"], + }, + "taco_play": { + "versions": ["0.1.0"], + "episodes": 3242, + "steps": 213972, + "visual_observation_keys": ["rgb_static", "rgb_gripper"], + }, + "tokyo_u_lsmo_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 50, + "steps": 11925, + "visual_observation_keys": ["image"], + }, + "toto": {"versions": ["0.1.0"], "episodes": 902, "steps": 294139, "visual_observation_keys": ["image"]}, + "ucsd_kitchen_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 150, + "steps": 3970, + "visual_observation_keys": ["image"], + }, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 1355, + "steps": 67750, + "visual_observation_keys": ["image"], + }, + "uiuc_d3field": { # this dataset seems to be corrupted + "versions": ["0.1.0", "1.1.2"], + "episodes": 196, + "steps": 13384, + "visual_observation_keys": ["image_1", "image_2", "image_3", "image_4"], + }, + "usc_cloth_sim_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 800, + "steps": 80000, + "visual_observation_keys": ["image"], + }, + "utaustin_mutex": { + "versions": ["0.1.0"], + "episodes": 1500, + "steps": 361883, + 
"visual_observation_keys": ["image", "wrist_image"], + }, + "utokyo_pr2_opening_fridge_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 64, + "steps": 9140, + "visual_observation_keys": ["image"], + }, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 192, + "steps": 26346, + "visual_observation_keys": ["image"], + }, + "utokyo_saytap_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 20, + "steps": 22937, + "visual_observation_keys": ["image", "wrist_image"], + }, + "utokyo_xarm_bimanual_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 64, + "steps": 1388, + "visual_observation_keys": ["image"], + }, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": { + "versions": ["0.1.0"], + "episodes": 92, + "steps": 6789, + "visual_observation_keys": ["image", "hand_image", "image2"], + }, + "viola": { + "versions": ["0.1.0"], + "episodes": 135, + "steps": 68913, + "visual_observation_keys": ["agentview_rgb", "eye_in_hand_rgb"], + }, + } +) + + +def oxe_dsname2path(dataset_name: str, version: Optional[str] = None) -> str: + """From dataset name to remote google clound path to the dataset. + + Args: + dataset_name (str): dataset name. + version (Optional[str]): version string. + + Returns: + str: google clound path + """ + if version is None: + version = ALL_OXE_DATASETS[dataset_name]["versions"][-1] + return f"gs://gresearch/robotics/{dataset_name}/{version}" diff --git a/theia/dataset/oxe/oxe_mixes.py b/theia/dataset/oxe/oxe_mixes.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb793ce649b8d8f7f689c21363a464189910371 --- /dev/null +++ b/theia/dataset/oxe/oxe_mixes.py @@ -0,0 +1,139 @@ +# File modified. Modifications Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. 
+ +"""MIT License Copyright (c) 2023 Robotic AI & Learning Lab Berkeley + +From Octo https://github.com/octo-models/octo/blob/main/octo/data/oxe/oxe_dataset_mixes.py +""" + +BRIDGE_MIX = [ + ("bridge_dataset", 1.0), +] + +RT_X_MIX = [ + ("fractal20220817_data", 0.54087122203), + ("kuka", 0.8341046294), + ("bridge_dataset", 1.0), + ("taco_play", 2.0), + ("jaco_play", 2.0), + ("berkeley_cable_routing", 3.0), + ("roboturk", 1.0), + ("nyu_door_opening_surprising_effectiveness", 5.0), + ("viola", 2.0), + ("berkeley_autolab_ur5", 1.0), + ("toto", 1.0), +] + + +OXE_FRANKA_MIX = [ + ("taco_play", 1.0), + ("berkeley_cable_routing", 1.0), + ("viola", 1.0), + ("toto", 1.0), + ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0), + ("austin_buds_dataset_converted_externally_to_rlds", 3.0), + ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), + ("maniskill_dataset_converted_externally_to_rlds", 0.1), + ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), + ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0), + ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), + ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), + ("berkeley_rpt_converted_externally_to_rlds", 1.0), + ("kaist_nonprehensile_converted_externally_to_rlds", 3.0), + ("stanford_robocook_converted_externally_to_rlds", 1.0), + ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), + ("utaustin_mutex", 1.0), + # ("cmu_playing_with_food", 1.0), + ("cmu_play_fusion", 1.0), +] + +OXE_MAGIC_SOUP = [ + ("fractal20220817_data", 0.54087122203), + ("kuka", 0.8341046294), + ("bridge", 1.0), + ("taco_play", 2.0), + ("jaco_play", 1.0), + ("berkeley_cable_routing", 1.0), + ("roboturk", 2.0), + ("nyu_door_opening_surprising_effectiveness", 1.0), + ("viola", 2.0), + ("berkeley_autolab_ur5", 2.0), + ("toto", 1.0), + ("language_table", 0.1), + ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), + ("austin_buds_dataset_converted_externally_to_rlds", 1.0), + ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), + ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), + ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), + ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), + ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), + ("bc_z", 0.2), + ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), + ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), + # ("uiuc_d3field", 1.0), --> somehow raw data is broken + ("utaustin_mutex", 1.0), + ("berkeley_fanuc_manipulation", 2.0), + ("cmu_stretch", 1.0), +] + + +OXE_FULL_MIX = [ + ("fractal20220817_data", 1.0), + ("kuka", 1.0), + ("bridge_dataset", 1), + ("taco_play", 1.0), + ("jaco_play", 1.0), + ("berkeley_cable_routing", 1.0), + ("roboturk", 1.0), + ("nyu_door_opening_surprising_effectiveness", 1.0), + ("viola", 1.0), + ("berkeley_autolab_ur5", 1.0), + ("toto", 1.0), + ("language_table", 1.0), + ("columbia_cairlab_pusht_real", 1.0), + ("stanford_kuka_multimodal_dataset_converted_externally_to_rlds", 1.0), + ("nyu_rot_dataset_converted_externally_to_rlds", 1.0), + ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0), + ("austin_buds_dataset_converted_externally_to_rlds", 1.0), + ("nyu_franka_play_dataset_converted_externally_to_rlds", 1.0), + ("maniskill_dataset_converted_externally_to_rlds", 1.0), + ("furniture_bench_dataset_converted_externally_to_rlds", 1.0), + ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 1.0), + 
("ucsd_kitchen_dataset_converted_externally_to_rlds", 1.0), + ("ucsd_pick_and_place_dataset_converted_externally_to_rlds", 1.0), + ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), + ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), + ("bc_z", 1.0), + ("utokyo_pr2_opening_fridge_converted_externally_to_rlds", 1.0), + ("utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", 1.0), + ("utokyo_xarm_pick_and_place_converted_externally_to_rlds", 1.0), + ("utokyo_xarm_bimanual_converted_externally_to_rlds", 1.0), + ("robo_net", 1.0), + ("berkeley_mvp_converted_externally_to_rlds", 1.0), + ("berkeley_rpt_converted_externally_to_rlds", 1.0), + ("kaist_nonprehensile_converted_externally_to_rlds", 1.0), + ("stanford_mask_vit_converted_externally_to_rlds", 1.0), + ("tokyo_u_lsmo_converted_externally_to_rlds", 1.0), + ("dlr_sara_pour_converted_externally_to_rlds", 1.0), + ("dlr_sara_grid_clamp_converted_externally_to_rlds", 1.0), + ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), + ("asu_table_top_converted_externally_to_rlds", 1.0), + ("stanford_robocook_converted_externally_to_rlds", 1.0), + ("imperialcollege_sawyer_wrist_cam", 1.0), + ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), + ("uiuc_d3field", 1.0), + ("utaustin_mutex", 1.0), + ("berkeley_fanuc_manipulation", 1.0), + ("cmu_playing_with_food", 1.0), + ("cmu_play_fusion", 1.0), + ("cmu_stretch", 1.0), + ("berkeley_gnm_recon", 1.0), + ("berkeley_gnm_cory_hall", 1.0), + ("berkeley_gnm_sac_son", 1.0), +] + +OXE_NAMED_MIXES = { + "bridge": BRIDGE_MIX, + "rtx": RT_X_MIX, + "rtx_franka": RT_X_MIX + OXE_FRANKA_MIX, + "oxe_magic_soup": OXE_MAGIC_SOUP, +} diff --git a/theia/dataset/oxe/oxe_transforms.py b/theia/dataset/oxe/oxe_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..43cca3118bc0bd419388faee57cd6dcdd89f62ec --- /dev/null +++ b/theia/dataset/oxe/oxe_transforms.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import torch +from numpy.typing import NDArray +from torchvision.transforms.v2 import Compose, Normalize, ToDtype, ToImage + + +def totensor(arr: NDArray) -> torch.Tensor: + """Convert ndarray to tensor.""" + return torch.from_numpy(arr) + + +oxe_image_transform = Compose( + [ToImage(), ToDtype(torch.float32, scale=True), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] +) # ImageNet statistics normalization diff --git a/theia/dataset/video/__init__.py b/theia/dataset/video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04f7ecdb72ccb4f56e2b19f30f7c49d2245d8195 --- /dev/null +++ b/theia/dataset/video/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from .video_common import ALL_VIDEO_DATASETS diff --git a/theia/dataset/video/video_common.py b/theia/dataset/video/video_common.py new file mode 100644 index 0000000000000000000000000000000000000000..5b950ad31dbaa6a68ce0c7b8d0ca778bb8dcbd85 --- /dev/null +++ b/theia/dataset/video/video_common.py @@ -0,0 +1,11 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. 
+ +from collections import OrderedDict + +ALL_VIDEO_DATASETS = OrderedDict( + { + "ego4d_1in150": {"steps": 2_800_871}, + "epic_kitchen_1in60": {"steps": 333_117}, + "ssv2_1in32": {"steps": 312_772}, + } +) diff --git a/theia/decoding/__init__.py b/theia/decoding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af4707e751ee70d8a72adb50df0ef4e559b320fc --- /dev/null +++ b/theia/decoding/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from .decode import decode_everything, load_feature_stats +from .depth_anything import prepare_depth_decoder +from .sam import prepare_mask_generator diff --git a/theia/decoding/decode.py b/theia/decoding/decode.py new file mode 100644 index 0000000000000000000000000000000000000000..99a387fd136b65f63425269f22d94fa0eaec41d1 --- /dev/null +++ b/theia/decoding/decode.py @@ -0,0 +1,198 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import os +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from numpy.typing import NDArray +from PIL import Image +from sklearn.decomposition import PCA +from transformers import SamModel, SamProcessor +from transformers.pipelines import MaskGenerationPipeline + +from theia.decoding.depth_anything import decode_depth_anything +from theia.decoding.dinov2 import decode_dinov2 +from theia.decoding.sam import decode_sam +from theia.preprocessing.feature_extraction_core import ( + get_feature_outputs, + get_model, +) + + +def denormalize_feature( + x: torch.Tensor, mean: Optional[torch.Tensor] = None, std: Optional[torch.Tensor] = None +) -> torch.Tensor: + """Denormalize the features using mean and std. + + Args: + x (torch.Tensor): features to be denomalized. + mean (Optional[torch.Tensor], optional): mean value of the features. Defaults to None + std (Optional[torch.Tensor], optional): std value of the features. Defaults to None. + + Returns: + torch.Tensor: denormalized features. + """ + if mean is None and std is None: + return x + elif mean is None and std is not None: + return x * std + elif mean is not None and std is None: + return x + mean + return x * std + mean + + +def load_feature_stats( + feature_models: list[str], stat_file_root: str +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + """Load the statistics (mean and variance) of the features, per model. + + Args: + feature_models (list[str]): names of the models. Note: there are `/` in the name. + stat_file_root (str): directory that holds feature stat files. + + Returns: + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: means and variance. 
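A tiny usage sketch of denormalize_feature with dummy statistics (shapes are illustrative; the real per-model mean/variance arrays are loaded by load_feature_stats below, and the import path assumes the file layout shown in this diff):

import torch
from theia.decoding.decode import denormalize_feature  # assumed import path

x = torch.randn(2, 256, 1024)        # normalized features: batch of 2, 256 tokens, 1024 channels
mean = torch.zeros(1024)
std = torch.full((1024,), 2.0)

x_denorm = denormalize_feature(x, mean, std)        # x * std + mean, broadcast over the last dim
assert x_denorm.shape == x.shape
assert torch.allclose(denormalize_feature(x), x)    # both stats None -> identity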
+ """ + feature_means: dict[str, torch.Tensor] = {} + feature_vars: dict[str, torch.Tensor] = {} + for model in feature_models: + model_name = model.replace("/", "_") + feature_means[model] = torch.from_numpy( + np.load(os.path.join(stat_file_root, f"imagenet_mean_{model_name}.npy")) + ) + feature_vars[model] = torch.from_numpy(np.load(os.path.join(stat_file_root, f"imagenet_var_{model_name}.npy"))) + return feature_means, feature_vars + + +def decode_everything( + theia_model: nn.Module, + feature_means: dict[str, torch.Tensor], + feature_vars: dict[str, torch.Tensor], + images: list[Image.Image], + mask_generator: MaskGenerationPipeline, + sam_model: SamModel, + depth_anything_decoder: nn.Module, + pred_iou_thresh: float = 0.9, + stability_score_thresh: float = 0.9, + gt: bool = False, + pca: Optional[PCA] = None, + device: int | str | torch.device = 0, +) -> tuple[list[NDArray], Optional[list[NDArray]]]: + """Decode features from given `theia_model` into different outputs corresponding to upstream models including + DINOv2, Sam, and Depth-Anything. + + Args: + theia_model (nn.Module): theia model. + feature_means (dict[str, torch.Tensor]): means of the features for denormalization. + feature_vars (dict[str, torch.Tensor]): variance of the features for denormalization. + images (list[Image.Image]): input images. + mask_generator (MaskGenerationPipeline): mask generation pipeline. + sam_model (SamModel): sam model. + depth_anything_decoder (nn.Module): depth anything decoder. + pred_iou_thresh (float, optional): iou threshold for mask generation. + See transformers.pipelines.MaskGenerationPipeline for more details. Defaults to 0.9. + stability_score_thresh (float, optional): stability score threshold for mask generation. + See transformers.pipelines.MaskGenerationPipeline for more details. Defaults to 0.9. + gt (bool): whether to attach ground truth result in the visualization. Defaults to False. + pca (Optional[PCA]): pca for DINOv2 decoding. If provided, will use this pca particular. Defaults to None. + device (int | str | torch.device, optional): device for decoding. Defaults to 0. + + Returns: + tuple[list[NDArray], Optional[list[NDArray]]]: decoding results from given model, + and ground truth (if `gt=True`). 
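load_feature_stats expects one mean file and one variance file per model under stat_file_root, with "/" in the model name replaced by "_". A sketch of writing compatible files (the directory and array shape are illustrative):

import os
import numpy as np

stat_file_root = "/tmp/feature_stats"        # hypothetical location
os.makedirs(stat_file_root, exist_ok=True)

model = "facebook/dinov2-large"
model_name = model.replace("/", "_")         # -> "facebook_dinov2-large"

# File names must match imagenet_mean_<model_name>.npy / imagenet_var_<model_name>.npy.
np.save(os.path.join(stat_file_root, f"imagenet_mean_{model_name}.npy"), np.zeros(1024, dtype=np.float32))
np.save(os.path.join(stat_file_root, f"imagenet_var_{model_name}.npy"), np.ones(1024, dtype=np.float32))

# feature_means, feature_vars = load_feature_stats([model], stat_file_root)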
+ """ + features: dict[str, torch.Tensor] = {} + with torch.no_grad(): + for im in images: + feature = theia_model([im]) + if len(features) == 0: + features = {k: [] for k in feature} + for k in feature: + features[k].append(feature[k].detach().cpu()) + for k in features: + features[k] = torch.cat(features[k], dim=0) + for m in features: + features[m] = denormalize_feature(features[m], feature_means[m], feature_vars[m]) + + dino_model_name = "facebook/dinov2-large" + sam_model_name = "facebook/sam-vit-huge" + depth_anything_model_name = "LiheYoung/depth-anything-large-hf" + + pca = None + # gt + gt_decode_results = None + if gt: + def legit_model_name(model_name: str) -> str: + return model_name.replace("/", "_") + + dino_model, dino_processor = get_model(dino_model_name, device=device) + dino_gt_feature = [] + for im in images: + dino_gt_feature.append( + get_feature_outputs( + legit_model_name(dino_model_name), dino_model, dino_processor, [im], dtype=torch.float + )[legit_model_name(dino_model_name)]["embedding"] + .detach() + .cpu() + ) + dino_gt_feature = torch.cat(dino_gt_feature, dim=0) + dino_gt_feature = rearrange(dino_gt_feature, "b c h w -> b (h w) c") + dino_gt_dec, pca = decode_dinov2(dino_gt_feature, pca=pca) + sam_processor = SamProcessor.from_pretrained(sam_model_name) + sam_gt_feature = [] + for im in images: + sam_inputs = sam_processor(images=[im], return_tensors="pt").to(device) + with torch.no_grad(): + sam_gt_feature.append(sam_model.get_image_embeddings(sam_inputs["pixel_values"]).detach().cpu()) + sam_gt_feature = torch.cat(sam_gt_feature, dim=0) + sam_gt_feature = rearrange(sam_gt_feature, "b c h w -> b (h w) c") + sam_gt_dec = decode_sam( + sam_gt_feature, images, mask_generator, pred_iou_thresh=0.9, stability_score_thresh=0.9, device=device + ) + depth_anything_model, depth_anything_processor = get_model(depth_anything_model_name, device=device) + depth_anything_gt_feature = [] + for im in images: + depth_anything_gt_feature.append( + get_feature_outputs( + legit_model_name(depth_anything_model_name), + depth_anything_model, + depth_anything_processor, + [im], + dtype=torch.float, + )[legit_model_name(depth_anything_model_name)]["embedding"] + .detach() + .cpu() + ) + depth_anything_gt_feature = torch.cat(depth_anything_gt_feature, dim=0) + depth_anything_gt_feature = rearrange(depth_anything_gt_feature, "b c h w -> b (h w) c") + depth_gt_dec = decode_depth_anything(depth_anything_gt_feature, depth_anything_decoder, device=device) + + gt_decode_results = [ + np.hstack([np.array(images[i]).astype(np.float32) / 255.0, dino_gt_dec[i], sam_gt_dec[i], depth_gt_dec[i]]) + for i in range(len(images)) + ] + + dino_dec, _ = decode_dinov2(features[dino_model_name], pca=pca) + + try: + sam_dec = decode_sam( + features[sam_model_name], + images, + mask_generator, + pred_iou_thresh=pred_iou_thresh, + stability_score_thresh=stability_score_thresh, + device=device, + ) + except IndexError: + sam_dec = np.zeros_like(dino_dec) + depth_dec = decode_depth_anything(features[depth_anything_model_name], depth_anything_decoder, device=device) + + theia_decode_results = [ + np.hstack([np.array(images[i]).astype(np.float32) / 255.0, dino_dec[i], sam_dec[i], depth_dec[i]]) + for i in range(len(images)) + ] + + return theia_decode_results, gt_decode_results diff --git a/theia/decoding/depth_anything.py b/theia/decoding/depth_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..26836d53966d4f676afe02d62ec30f78d148cce9 --- /dev/null +++ 
b/theia/decoding/depth_anything.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import torch +import torch.nn as nn +from einops import rearrange +from theia.foundation_models.vision_models.depth_anything import DepthAnythingForDepthEstimation +from numpy.typing import NDArray +from torch.nn.functional import interpolate + + +def prepare_depth_decoder(model_name: str, device: int | str | torch.device = 0) -> tuple[nn.Module, int]: + """Prepare a depth decoder using DepthAnythingForDepthEstimation. + + Args: + model_name (str): name of the depth anything model. + device (int | str | torch.device, optional): device to put the model on. Defaults to 0. + + Returns: + tuple[nn.Module, int]: the decoder, and the patch size for depth anything model. + """ + decoder_head = DepthAnythingForDepthEstimation.from_pretrained(model_name) + patch_size = decoder_head.config.patch_size + decoder_head = decoder_head.head + decoder_head = decoder_head.to(device) + return decoder_head, patch_size + + +def decode_depth_anything(features: torch.Tensor, decoder: nn.Module, device: int | str | torch.device = 0) -> NDArray: + """Decode features to predicted depth using depth anything + + Args: + features (torch.Tensor): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim]. + decoder (nn.Module): depth anything decoder + device (int | str | torch.device, optional): device to perform the decoding. Defaults to 0. + + Returns: + NDArray: decoded depth in image format, represented by an NDArray in size [batch_size, height, width, channels] + with value between [0, 1]. The depth values are min-max normalized to [0, 1] to generate images. + """ + with torch.no_grad(): + P = int(features.size(1) ** 0.5) + features = rearrange(features, "b (h w) c -> b c h w", h=P, w=P) + features = interpolate(features, (224, 224)) + predicted_depths = [] + for feature in features: + feature = feature.unsqueeze(0).to(device) + + predicted_depth = decoder.activation1(feature) + predicted_depth = decoder.conv3(predicted_depth) + predicted_depth = decoder.activation2(predicted_depth) + predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width) + for i in range(len(predicted_depth)): + min_depth, max_depth = predicted_depth[i].min(), predicted_depth[i].max() + predicted_depth[i] = (predicted_depth[i] - min_depth) / (max_depth - min_depth) + predicted_depths.append(predicted_depth.detach().cpu()) + predicted_depths = torch.cat(predicted_depths, dim=0) + return predicted_depths.unsqueeze(-1).repeat((1, 1, 1, 3)).numpy() # type: ignore [attr-defined] diff --git a/theia/decoding/dinov2.py b/theia/decoding/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..29c4801c7846dc7f7767993472ca8c59033a94c6 --- /dev/null +++ b/theia/decoding/dinov2.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from typing import Optional + +import cv2 +import numpy as np +from numpy.typing import NDArray +from sklearn.decomposition import PCA +from sklearn.preprocessing import minmax_scale + + +def decode_dinov2( + features: NDArray, threshold: int | float = -100, interpolation: bool = False, pca: Optional[PCA] = None +) -> tuple[NDArray, PCA]: + """ + Decode the input `features` in DINOv2 style using PCA. + + Args: + features (NDArray): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim]. 
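A shape-level sketch for the two helpers above. Loading the decoder requires downloading the Depth-Anything checkpoint, so those calls are left as comments; the 32-channel feature size matches the "LiheYoung/depth-anything-large-hf" entry in MODEL_FEATURE_SIZES later in this diff.

import torch

# decoder, patch_size = prepare_depth_decoder("LiheYoung/depth-anything-large-hf", device="cpu")

# Features must be [batch_size, num_tokens, latent_dim] with a square token grid,
# e.g. (2, 64 * 64, 32) for the large model's head features.
features = torch.randn(2, 64 * 64, 32)

# depth = decode_depth_anything(features, decoder, device="cpu")
# depth.shape == (2, 224, 224, 3); each depth map is min-max normalized to [0, 1].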
+ threshold (int | float): threshold of foreground-background split in PCA visualization. + Defaults to -100 (all patches are included). + interpolation (bool): whether interpolate the 16x16 pca map to the original image size. + pca (Optional[PCA]): if provided, use the provided PCA. This is to keep visualizations stable across samples. + + Returns: + tuple[NDArray, PCA]: the rendered image of this visualization, in NDArray in size + [batch_size, height, width, channels] with value ranges [0, 1], and the PCA used in this visualization. + """ + features = features.numpy() + batch_size, spatial_size, latent_dim = features.shape + h = w = int(spatial_size**0.5) + + features = features.reshape(-1, latent_dim) + + if pca is None: + pca = PCA(n_components=3) + pca.fit(features) + + pca_features = pca.transform(features) + + # segment using the first component + bg_mask = pca_features[:, 0] < threshold + fg_mask = ~bg_mask + + # PCA for only foreground patches + # pca.fit(features[fg_mask]) + pca_features_fg = pca.transform(features[fg_mask]) + for i in range(3): + pca_features_fg[:, i] = minmax_scale(pca_features_fg[:, i]) + + pca_features_rgb = pca_features.copy() + pca_features_rgb[bg_mask] = 0 + pca_features_rgb[fg_mask] = pca_features_fg + + pca_features_rgb = pca_features_rgb.reshape(batch_size, h, w, 3) + if not interpolation: + H = W = 224 + scale = H // h + interpolated_pca_features = np.zeros((batch_size, H, W, 3), dtype=pca_features_rgb.dtype) + for i in range(len(pca_features_rgb)): + for j in range(h): + for k in range(w): + interpolated_pca_features[i, scale * j : scale * (j + 1), scale * k : scale * (k + 1)] = ( + pca_features_rgb[i, j, k] + ) + pca_features_rgb = interpolated_pca_features + else: + pca_features_rgb = np.stack([cv2.resize(p, (224, 224)) for p in pca_features_rgb]) + return pca_features_rgb, pca diff --git a/theia/decoding/sam.py b/theia/decoding/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..476d9c4af6c829a42d1621886beb7a8a807976d1 --- /dev/null +++ b/theia/decoding/sam.py @@ -0,0 +1,191 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from typing import Any, Generator, Optional + +import numpy as np +import torch +from einops import rearrange +from numpy.typing import NDArray +from PIL import Image +from transformers import SamModel, SamProcessor +from transformers.image_utils import load_image +from transformers.pipelines import MaskGenerationPipeline + + +class MaskGenerationPipelineWithEmbeddings(MaskGenerationPipeline): + """ + The wrapper class for huggingface transformers.pipelines.MaskGenerationPipeline + that can decode from intermediate SAM embeddings. 
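A runnable sketch of decode_dinov2 on random features, reusing the fitted PCA across calls so colors stay comparable between batches (the same trick decode_everything relies on when comparing predictions to ground truth); the import path is assumed from the file layout:

import torch
from theia.decoding.dinov2 import decode_dinov2  # assumed import path

# 2 images, 16x16 = 256 tokens, 1024-dim features (random, for illustration only).
features = torch.randn(2, 256, 1024)

rgb, pca = decode_dinov2(features)                  # fits a 3-component PCA
assert rgb.shape == (2, 224, 224, 3)                # values in [0, 1]

more_rgb, _ = decode_dinov2(torch.randn(2, 256, 1024), pca=pca)   # reuse the fitted PCA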
+ """ + + def _sanitize_parameters(self, **kwargs: Any) -> tuple[dict[str, Any], ...]: + preprocess_kwargs = {} + postprocess_kwargs = {} + forward_params = {} + # preprocess args + if "embeddings" in kwargs: # inject embeddings here + preprocess_kwargs["embeddings"] = kwargs["embeddings"] + if "points_per_batch" in kwargs: + preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"] + if "points_per_crop" in kwargs: + preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"] + if "crops_n_layers" in kwargs: + preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"] + if "crop_overlap_ratio" in kwargs: + preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"] + if "crop_n_points_downscale_factor" in kwargs: + preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"] + if "timeout" in kwargs: + preprocess_kwargs["timeout"] = kwargs["timeout"] + # postprocess args + if "pred_iou_thresh" in kwargs: + forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"] + if "stability_score_offset" in kwargs: + forward_params["stability_score_offset"] = kwargs["stability_score_offset"] + if "mask_threshold" in kwargs: + forward_params["mask_threshold"] = kwargs["mask_threshold"] + if "stability_score_thresh" in kwargs: + forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"] + if "crops_nms_thresh" in kwargs: + postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"] + if "output_rle_mask" in kwargs: + postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"] + if "output_bboxes_mask" in kwargs: + postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"] + return preprocess_kwargs, forward_params, postprocess_kwargs + + def preprocess( + self, + image: list[Image.Image], + points_per_batch: int = 64, + crops_n_layers: int = 0, + crop_overlap_ratio: float = 512 / 1500, + points_per_crop: int = 32, + crop_n_points_downscale_factor: int = 1, + timeout: Optional[float] = None, + embeddings: Optional[torch.Tensor] = None, + ) -> Generator[Any, Any, Any]: + image = load_image(image, timeout=timeout) + target_size = self.image_processor.size["longest_edge"] + crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes( + image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor + ) + model_inputs = self.image_processor(images=cropped_images, return_tensors="pt") + + with self.device_placement(): + if self.framework == "pt": + inference_context = self.get_inference_context() + with inference_context(): + model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device) + if embeddings is None: + image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values")) + else: + model_inputs.pop("pixel_values") + image_embeddings = embeddings + model_inputs["image_embeddings"] = image_embeddings + + n_points = grid_points.shape[1] + points_per_batch = points_per_batch if points_per_batch is not None else n_points + + if points_per_batch <= 0: + raise ValueError( + "Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. 
" + "To return all points at once, set points_per_batch to None" + ) + + for i in range(0, n_points, points_per_batch): + batched_points = grid_points[:, i : i + points_per_batch, :, :] + labels = input_labels[:, i : i + points_per_batch] + is_last = i == n_points - points_per_batch + yield { + "input_points": batched_points, + "input_labels": labels, + "input_boxes": crop_boxes, + "is_last": is_last, + **model_inputs, + } + + +def draw_mask(mask: NDArray, random_color: bool = False) -> NDArray: + """Draw the mask on an image. + + Args: + mask (NDArray): mask in shape [height, width]. + random_color (bool): if using a random color. Defaults to False. + + Returns: + NDArray: NDArray format of the image. + """ + if random_color: + color = np.concatenate([np.random.random(3)], axis=0) + else: + color = np.array([30 / 255, 144 / 255, 255 / 255]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + return mask_image + + +def decode_sam( + features: torch.Tensor, + images: list[Image.Image], + mask_generator: Any, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.5, + stability_score_thresh: float = 0.6, + random_color: bool = True, + device: int | str | torch.device = 0, +) -> NDArray: + """Decode features using SAM (auto-prompting) mask generation pipeline. + + Args: + features (torch.Tensor): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim]. + images (list[Image.Image]): images corresponding to these features. + mask_generator (Any): mask generation pipeline. + points_per_batch (int): points per batch for auto-prompting. Defaults to 64. + See transformers.pipelines.MaskGenerationPipeline for more details. Same below. + pred_iou_thresh (float): iou threshold. Defaults to 0.5. + stability_score_thresh (float): stability threshold. Defaults to 0.6. + random_color (bool): if using a random color. Defaults to True. + device (int | str | torch.device): device to perform the decoding. Defaults to 0. + + Returns: + NDArray: decoded masks rendered in image format, represented by an NDArray in size + [batch_size, height, width, channels] with value between [0, 1]. + """ + masks_rgbs = [] + num_patches = int(features.size(1) ** 0.5) + features = rearrange(features, "b (h w) c -> b c h w", h=num_patches, w=num_patches) + with torch.no_grad(): + for im, feature in zip(images, features, strict=False): + predicted_ouputs = mask_generator( + im, + points_per_batch=points_per_batch, + embeddings=feature.unsqueeze(0).to(device), + pred_iou_thresh=pred_iou_thresh, + stability_score_thresh=stability_score_thresh, + ) + predicted_masks = predicted_ouputs["masks"] + masks_rgb = np.zeros((224, 224, 3), dtype=np.float32) + for mask in predicted_masks: + masks_rgb += draw_mask(mask, random_color=random_color) + # masks_rgb = cv2.cvtColor(masks_rgb, cv2.COLOR_RGBA2RGB) + masks_rgbs.append(masks_rgb) + return np.stack(masks_rgbs) + + +def prepare_mask_generator(device: int | str | torch.device = 0) -> MaskGenerationPipeline: + """Prepare a mask generation pipeline on device `device`. + + Args: + device (int | str | torch.device): device to perform mask generation. Defaults to 0. + + Returns: + MaskGenerationPipeline: mask generator. 
+ """ + sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device) + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + sam_model.eval() + mask_generator = MaskGenerationPipelineWithEmbeddings( + task="mask_generation", model=sam_model, image_processor=processor.image_processor, device=device + ) + return mask_generator, sam_model diff --git a/theia/example/decode_to_vfms.ipynb b/theia/example/decode_to_vfms.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5f5aae323740508df6e30289fc9979c1960fd50c --- /dev/null +++ b/theia/example/decode_to_vfms.ipynb @@ -0,0 +1,69 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import cv2\n", + "import torch\n", + "from PIL import Image\n", + "import numpy as np\n", + "from transformers import AutoModel\n", + "from torchvision.io import read_video, write_video\n", + "from theia.decoding import load_feature_stats, prepare_depth_decoder, prepare_mask_generator, decode_everything\n", + "\n", + "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", + "theia_model = AutoModel.from_pretrained(\"theaiinstitute/theia-base-patch16-224-cdiv\", trust_remote_code=True)\n", + "theia_model = theia_model.to(device)\n", + "target_model_names = [\n", + " \"google/vit-huge-patch14-224-in21k\",\n", + " \"facebook/dinov2-large\",\n", + " \"openai/clip-vit-large-patch14\",\n", + " \"facebook/sam-vit-huge\",\n", + " \"LiheYoung/depth-anything-large-hf\",\n", + "]\n", + "feature_means, feature_vars = load_feature_stats(target_model_names, stat_file_root=\"../../../feature_stats\")\n", + "\n", + "mask_generator, sam_model = prepare_mask_generator(device)\n", + "depth_anything_model_name = \"LiheYoung/depth-anything-large-hf\"\n", + "depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, device)\n", + "\n", + "example_video_path = \"../../../media/example_video_to_visualize.mp4\"\n", + "video, _, _ = read_video(example_video_path, pts_unit=\"sec\", output_format=\"THWC\")\n", + "video = video.numpy()\n", + "images = [Image.fromarray(cv2.resize(im, (224, 224))) for im in video]\n", + "\n", + "theia_decode_results, gt_decode_results = decode_everything(\n", + " theia_model=theia_model,\n", + " feature_means=feature_means,\n", + " feature_vars=feature_vars,\n", + " images=images,\n", + " mask_generator=mask_generator,\n", + " sam_model=sam_model,\n", + " depth_anything_decoder=depth_anything_decoder,\n", + " pred_iou_thresh=0.5,\n", + " stability_score_thresh=0.7,\n", + " gt=True,\n", + " device=device,\n", + ")\n", + "\n", + "vis_video = np.stack(\n", + " [np.vstack([tr, gtr]) for tr, gtr in zip(theia_decode_results, gt_decode_results, strict=False)]\n", + ")\n", + "vis_video = torch.from_numpy(vis_video * 255.0).to(torch.uint8)\n", + "vis_save_path = \"./visualized.mp4\"\n", + "write_video(vis_save_path, vis_video, fps=10)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/theia/foundation_models/__init__.py b/theia/foundation_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..adcb4c57dd9d8d87bc7c4e0fbc4a9da7a0e5cc5d --- /dev/null +++ b/theia/foundation_models/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. 
+ +from .vision_language_models.clip import get_clip_feature, get_clip_model +from .vision_language_models.llava import get_llava_vision_model, get_llava_visual_feature +from .vision_models.deit import get_deit_feature, get_deit_model +from .vision_models.depth_anything import get_depth_anything_feature, get_depth_anything_model +from .vision_models.dinov2 import get_dinov2_feature, get_dinov2_model +from .vision_models.sam import get_sam_feature, get_sam_model +from .vision_models.vit import get_vit_feature, get_vit_model diff --git a/theia/foundation_models/common.py b/theia/foundation_models/common.py new file mode 100644 index 0000000000000000000000000000000000000000..0358b35ebdae4129df80a09dc9e46d991e3d61e3 --- /dev/null +++ b/theia/foundation_models/common.py @@ -0,0 +1,87 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import math + +import torch + +MODELS = [ + "facebook/dinov2-large", + "facebook/sam-vit-huge", + "google/vit-huge-patch14-224-in21k", + "llava-hf/llava-1.5-7b-hf", + "openai/clip-vit-large-patch14", + "LiheYoung/depth-anything-large-hf", +] + +# handy model feature size constants +# in the format of (latent_dim, width, height) +MODEL_FEATURE_SIZES = { + "facebook/dinov2-large": (1024, 16, 16), + "facebook/sam-vit-huge": (256, 64, 64), + "google/vit-huge-patch14-224-in21k": (1280, 16, 16), + "llava-hf/llava-1.5-7b-hf": (1024, 24, 24), + "openai/clip-vit-large-patch14": (1024, 16, 16), + "LiheYoung/depth-anything-large-hf": (32, 64, 64), +} + + +def get_model_feature_size( + model_name: str, keep_spatial: bool = False, return_torch_size: bool = False +) -> tuple[int, ...] | torch.Size: + """ + Get the size of queried model feature. + + Args: + model_name (str): name of the model. + keep_spatial (bool): whether to preserve spatial dim. Defaults to False. + return_torch_size (bool): return torch.Size instead of python tuple. Defaults to False. + + Returns: + tuple[int, ...] | torch.Size: the size of the feature. + """ + size: tuple[int, ...] = MODEL_FEATURE_SIZES[model_name] + + if not keep_spatial: + size = (size[0], math.prod(size[1:])) + + if return_torch_size: + size = torch.Size(size) + + return size + + +def get_max_model_spatial_size( + keep_spatial: bool = True, + return_torch_size: bool = False, + return_model_name: bool = False, +) -> tuple[int, ...] | tuple[tuple[int, ...], str]: + """Get the maximal spatial dimensions from available models + + Args: + keep_spatial (bool): whether to preserve spatial dim. Defaults to True. + return_torch_size (bool): return torch.Size instead of python tuple. Defaults to False. + return_model_name (bool): the name of the model with maximal size. Defaults to False. + + Returns: + tuple[int, ...] | tuple[tuple[int, ...], str]: the maximal size and optional model name. + """ + max_flatten_size = -1 + max_size: tuple[int, ...] 
= () + max_size_model_name: str = "" + for model, size in MODEL_FEATURE_SIZES.items(): + flatten_size = math.prod(size[1:]) + if flatten_size > max_flatten_size: + max_flatten_size = flatten_size + max_size = size[1:] + max_size_model_name = model + + if not keep_spatial: + max_size = (max_flatten_size,) + + if return_torch_size: + max_size = torch.Size(max_size) + + if return_model_name: + return max_size, max_size_model_name + else: + return max_size diff --git a/theia/foundation_models/vision_language_models/__init__.py b/theia/foundation_models/vision_language_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9 --- /dev/null +++ b/theia/foundation_models/vision_language_models/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. diff --git a/theia/foundation_models/vision_language_models/clip.py b/theia/foundation_models/vision_language_models/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..eb54c1522e220637ced52204b5ec5ce6a2054466 --- /dev/null +++ b/theia/foundation_models/vision_language_models/clip.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import numpy as np +import torch +from transformers import AutoProcessor, CLIPVisionModel + + +def get_clip_feature( + model: CLIPVisionModel, processor: AutoProcessor, images: list[np.ndarray], requires_grad: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Get features from the visual encoder of CLIP. + + Args: + model (CLIPVisionModel): CLIP model. + processor (AutoProcessor): CLIP input processor. + images (list[np.ndarray]): images to be encoded, in RGB, uint8. + requires_grad (bool): maintains gradient. Defaults to False. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: features from clip ( + cls_token: last layer embedding from cls token # (1, 1, 1024) if vit-large, + visual_tokens: last layer embeddings from image # (1, 1024, 16, 16) BCHW if vit-large, + pooled_cls_token: last layer embedding from cls + layernorm # (1, 1, 1024) if vit-large + ) + """ + inputs = processor(images=images, return_tensors="pt").to(model.device) + if requires_grad: + outputs = model(**inputs) + else: + with torch.no_grad(): + outputs = model(**inputs) + cls_token = outputs.last_hidden_state[:, :1] # (1, 1, 1024) if vit-large + visual_tokens = outputs.last_hidden_state[:, 1:] # (1, 256, 1024) if vit-large + pooled_cls_token = outputs.pooler_output.unsqueeze(1) # (1, 1, 1024) if vit-large + batch_size, num_patches, num_channels = visual_tokens.size() + visual_tokens = visual_tokens.transpose(1, 2) + visual_tokens = visual_tokens.reshape( + batch_size, num_channels, int(np.sqrt(num_patches)), int(np.sqrt(num_patches)) + ) # (1, 1024, 16, 16) BCHW for vit-huge + return cls_token, visual_tokens, pooled_cls_token + + +def get_clip_model( + model_name: str = "openai/clip-vit-large-patch14", device: str | torch.device = "cuda" +) -> tuple[CLIPVisionModel, AutoProcessor]: + """Get CLIP model and its input processor. + + Args: + model_name (str, optional): name of CLIP model. Defaults to "openai/clip-vit-large-patch14". + device (str | torch.device, optional): device to put the model on. Defaults to "cuda". + + Returns: + tuple[CLIPVisionModel, AutoProcessor]: CLIP model and the correponding input processor. 
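For quick reference, the helpers in common.py reduce to simple lookups over MODEL_FEATURE_SIZES; the values below follow directly from the constants defined above (import path assumed from the file layout).

from theia.foundation_models.common import (  # assumed import path
    get_max_model_spatial_size,
    get_model_feature_size,
)

assert get_model_feature_size("facebook/dinov2-large") == (1024, 256)            # (latent_dim, h * w)
assert get_model_feature_size("facebook/dinov2-large", keep_spatial=True) == (1024, 16, 16)
# SAM and Depth-Anything share the largest 64x64 token grid; the first one encountered wins.
assert get_max_model_spatial_size(return_model_name=True) == ((64, 64), "facebook/sam-vit-huge")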
+ """ + processor = AutoProcessor.from_pretrained(model_name) + model = CLIPVisionModel.from_pretrained(model_name).to(device) + return model, processor + + +def print_feature_size(model_name: str = "openai/clip-vit-large-patch14") -> None: + """Print the sizes of features from CLIP. + + Args: + model_name (str, optional): the name of CLIP model. Defaults to "openai/clip-vit-large-patch14". + """ + import requests + from PIL import Image + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = [np.array(Image.open(requests.get(url, stream=True).raw))] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model, processor = get_clip_model(model_name, device=device) + cls_token, visual_tokens, pooled_cls_token = get_clip_feature(model, processor, image) + + print(model_name, cls_token.size(), visual_tokens.size(), pooled_cls_token.size()) + + +if __name__ == "__main__": + print_feature_size() diff --git a/theia/foundation_models/vision_language_models/llava.py b/theia/foundation_models/vision_language_models/llava.py new file mode 100644 index 0000000000000000000000000000000000000000..87ea2f7aea12ef7a9c03cd73fc4c8e8110ea14ef --- /dev/null +++ b/theia/foundation_models/vision_language_models/llava.py @@ -0,0 +1,145 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import torch +from transformers import AutoProcessor, LlavaForConditionalGeneration +from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast + + +@dataclass +class LlavaVisualFeatureOutput(LlavaCausalLMOutputWithPast): + """Visual feature output for LLaVA. + + Args: + visual_embeddings (Optional[torch.FloatTensor]): feature from visual encoder. + """ + + visual_embeddings: Optional[torch.FloatTensor] = None + + +class LlavaVisualFeature(LlavaForConditionalGeneration): + """LLaVA model with only visual feature returned. Borrowed from transformers.""" + + # TODO: reduce VRAM use of language model part, because only vocabulary is used, not the whole model + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> tuple | LlavaVisualFeatureOutput: + """LLaVA visual encoder forward pass, from transformers package. + + Returns: + tuple | LlavaVisualFeatureOutput: feature from visual encoder. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + image_features = None + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + if pixel_values is not None and input_ids.shape[1] != 1: + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + image_features = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + image_features = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + return LlavaVisualFeatureOutput(visual_embeddings=image_features) + + +def get_llava_visual_feature( + model: LlavaVisualFeature, processor: AutoProcessor, images: list[np.array], requires_grad: bool = False +) -> torch.FloatTensor: + """Get the feature from the visual encoder of LLaVA. + + Args: + model (LlavaVisualFeature): LLaVA model + processor (AutoProcessor): LLaVA input processor + images (list[np.array]): images to be encoded, in RGB, uint8 + requires_grad (bool): maintains gradient. Defaults to False. + + Returns: + torch.FloatTensor: LLaVA feature. (1, 1024, 24, 24) if using llava-7b + """ + inputs = processor(text=["placeholder"], images=images, return_tensors="pt").to(model.device) + if requires_grad: + outputs = model(**inputs) + else: + with torch.no_grad(): + outputs = model(**inputs) + batch_size, num_patches, num_channels = outputs.visual_embeddings.size() + visual_tokens = outputs.visual_embeddings.transpose(1, 2).reshape( + batch_size, num_channels, int(np.sqrt(num_patches)), int(np.sqrt(num_patches)) + ) + return visual_tokens # (1, 1024, 24, 24) if llava-7b + + +def get_llava_vision_model( + model_name: str = "llava-hf/llava-1.5-7b-hf", device: str | torch.device = "cuda" +) -> tuple[LlavaVisualFeature, AutoProcessor]: + """Get LLaVA model and its input processor. + + Args: + model_name (str, optional): name of LLaVA model. Defaults to "llava-hf/llava-1.5-7b-hf". + device (str | torch.device, optional): device to put the model on. Defaults to "cuda". + + Returns: + tuple[LlavaVisualFeature, AutoProcessor]: LLaVA model and the corresponding input processor. + """ + model = LlavaVisualFeature.from_pretrained(model_name).to(device) + processor = AutoProcessor.from_pretrained(model_name) + return model, processor + + +def print_feature_size(model_name: str = "llava-hf/llava-1.5-7b-hf") -> None: + """Print the size of the feature from LLaVA. + + Args: + model_name (str, optional): the name of LLaVA model. Defaults to "llava-hf/llava-1.5-7b-hf". 
+ """ + from datasets import load_dataset + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0] + image = [np.array(image)] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model, processor = get_llava_vision_model(model_name=model_name, device=device) + feature = get_llava_visual_feature(model, processor, image) + print(model_name, feature.size()) + # (1, 1024, 24, 24) if llava-7b + + +if __name__ == "__main__": + print_feature_size() diff --git a/theia/foundation_models/vision_models/__init__.py b/theia/foundation_models/vision_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9 --- /dev/null +++ b/theia/foundation_models/vision_models/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. diff --git a/theia/foundation_models/vision_models/deit.py b/theia/foundation_models/vision_models/deit.py new file mode 100644 index 0000000000000000000000000000000000000000..bf4c618cc4ba308521051d28cd782a87984358a6 --- /dev/null +++ b/theia/foundation_models/vision_models/deit.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import numpy as np +import torch +from transformers import AutoImageProcessor, AutoModel + + +def get_deit_feature( + model: AutoModel, processor: AutoImageProcessor, images: list[np.ndarray], requires_grad: bool = False +) -> torch.Tensor: + """Get feature from DeiT model. + + Args: + model (AutoModel): DeiT model. + processor (AutoImageProcessor): DeiT input processor. + images (list[np.ndarray]): images to be encoded. + requires_grad (bool): maintains gradient. Defaults to False. + + Returns: + torch.Tensor: feature from last layer, (1, 768, 14, 14) BCHW deit-base + """ + inputs = processor(images, return_tensors="pt").to(model.device) + if requires_grad: + outputs = model(**inputs) + else: + with torch.no_grad(): + outputs = model(**inputs) + last_hidden_state = outputs.last_hidden_state[:, 1:] + batch_size, num_patches, num_channels = last_hidden_state.size() + last_hidden_state = last_hidden_state.transpose(1, 2) + last_hidden_state = last_hidden_state.reshape( + batch_size, num_channels, int(np.sqrt(num_patches)), int(np.sqrt(num_patches)) + ) + return last_hidden_state # (1, 768, 14, 14) BCHW for deit-base + + +def get_deit_model( + model_name: str = "facebook/deit-tiny-patch16-224", device: str | torch.device = "cuda" +) -> tuple[AutoModel, AutoImageProcessor]: + """Get DeiT model and its corresponding input processor. + + Args: + model_name (str, optional): the name of DeiT model. Defaults to "facebook/deit-tiny-patch16-224". + device (str | torch.device, optional): device to put model on. Defaults to "cuda". + + Returns: + tuple[DeiTModel, AutoImageProcessor]: DeiT model and its processor. + """ + processor = AutoImageProcessor.from_pretrained(model_name) + model = AutoModel.from_pretrained(model_name).to(device) + return model, processor + + +def print_feature_size(model_name: str = "facebook/deit-tiny-patch16-224") -> None: + """Print the size of the feature from ViT. + + Args: + model_name (str, optional): the name of ViT model. Defaults to "facebook/deit-tiny-patch16-224". 
+ """ + from datasets import load_dataset + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0] + image = np.array(image) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model, processor = get_deit_model(model_name=model_name, device=device) + feature = get_deit_feature(model, processor, image) + print(feature.size()) + # (1, 768, 14, 14) BCHW for deit-base diff --git a/theia/foundation_models/vision_models/depth_anything.py b/theia/foundation_models/vision_models/depth_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..3ad54cdede8c3bd4b15cddfb773139bc47e4396b --- /dev/null +++ b/theia/foundation_models/vision_models/depth_anything.py @@ -0,0 +1,681 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. +# File modified. + +# ----------------------------------------------------------------------- +# Copyright 2024 TikTok and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Depth Anything model.""" +import copy +from typing import Any, Optional + +import numpy.typing as npt +import torch +import torch.utils.checkpoint +from torch import nn +from transformers import AutoImageProcessor +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import DepthEstimatorOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.models.auto import AutoBackbone +from transformers.models.auto.configuration_auto import CONFIG_MAPPING +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class DepthAnythingConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DepthAnythingModel`]. + It is used to instantiate an DepthAnything model according to the specified arguments, + defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DepthAnything + [LiheYoung/depth-anything-small-hf](https://huggingface.co/LiheYoung/depth-anything-small-hf) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`dict[str, Any] | PretrainedConfig`, *optional*): + The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to + leverage the [`AutoBackbone`] API. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights + from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. 
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + patch_size (`int`, *optional*, defaults to 14): + The size of the patches to extract from the backbone features. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + reassemble_hidden_size (`int`, *optional*, defaults to 384): + The number of input channels of the reassemble layers. + reassemble_factors (`tuple[int | float, ...]`, *optional*, defaults to `[4, 2, 1, 0.5]`): + The up/downsampling factors of the reassemble layers. + neck_hidden_sizes (`tuple[int]`, *optional*, defaults to `[48, 96, 192, 384]`): + The hidden sizes to project to for the feature maps of the backbone. + fusion_hidden_size (`int`, *optional*, defaults to 64): + The number of channels before fusion. + head_in_index (`int`, *optional*, defaults to -1): + The index of the features to use in the depth estimation head. + head_hidden_size (`int`, *optional*, defaults to 32): + The number of output channels in the second convolution of the depth estimation head. + ```""" + + model_type = "depth_anything" + + def __init__( + self, + backbone_config: dict[str, Any] | PretrainedConfig = None, + backbone: Optional[str] = None, + use_pretrained_backbone: bool = False, + patch_size: int = 14, + initializer_range: float = 0.02, + reassemble_hidden_size: int = 384, + reassemble_factors: tuple[int | float, ...] = (4, 2, 1, 0.5), + neck_hidden_sizes: tuple[int, ...] = (48, 96, 192, 384), + fusion_hidden_size: int = 64, + head_in_index: int = -1, + head_hidden_size: int = 32, + **kwargs: Any, + ): + super().__init__(**kwargs) + + if use_pretrained_backbone: + raise ValueError("Pretrained backbones are not supported yet.") + + if backbone_config is not None and backbone is not None: + raise ValueError("You can't specify both `backbone` and `backbone_config`.") + + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `Dinov2` backbone.") + backbone_config = CONFIG_MAPPING["dinov2"]( + image_size=518, + hidden_size=384, + num_attention_heads=6, + out_indices=[9, 10, 11, 12], + apply_layernorm=True, + reshape_hidden_states=False, + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.reassemble_hidden_size = reassemble_hidden_size + self.patch_size = patch_size + self.initializer_range = initializer_range + self.reassemble_factors = reassemble_factors + self.neck_hidden_sizes = neck_hidden_sizes + self.fusion_hidden_size = fusion_hidden_size + self.head_in_index = head_in_index + self.head_hidden_size = head_hidden_size + + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 
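A small sketch of building the configuration with its default DINOv2 backbone; no weights are downloaded, so it is useful mainly for inspecting the fields documented above (import path assumed from the file layout).

from theia.foundation_models.vision_models.depth_anything import DepthAnythingConfig  # assumed path

config = DepthAnythingConfig()                 # falls back to the default Dinov2 backbone config
assert config.patch_size == 14
assert config.backbone_config.hidden_size == 384
assert config.reassemble_factors == (4, 2, 1, 0.5)

# A randomly initialized model for shape debugging could then be built with
# DepthAnythingForDepthEstimation(config); use from_pretrained(...) for real weights.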
Returns: + `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + if output["backbone_config"] is not None: + output["backbone_config"] = self.backbone_config.to_dict() + + output["model_type"] = self.__class__.model_type + return output + + +class DepthAnythingReassembleLayer(nn.Module): + def __init__(self, config: DepthAnythingConfig, channels: int, factor: int | float): + super().__init__() + self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1) + + # up/down sampling depending on factor + if factor > 1: + self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0) + elif factor == 1: + self.resize = nn.Identity() + elif factor < 1: + # so should downsample + self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1) + + # Copied from transformers.models.dpt.modeling_dpt.DPTReassembleLayer.forward + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.projection(hidden_state) + hidden_state = self.resize(hidden_state) + + return hidden_state + + +class DepthAnythingReassembleStage(nn.Module): + """ + This class reassembles the hidden states of the backbone into image-like feature representations at various + resolutions. + + This happens in 3 stages: + 1. Take the patch embeddings and reshape them to image-like feature representations. + 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`. + 3. Resizing the spatial dimensions (height, width). + + Args: + config (`[DepthAnythingConfig]`): + Model configuration class defining the model architecture. + """ + + def __init__(self, config: DepthAnythingConfig): + super().__init__() + + self.config = config + self.layers = nn.ModuleList() + for channels, factor in zip(config.neck_hidden_sizes, config.reassemble_factors, strict=False): + self.layers.append(DepthAnythingReassembleLayer(config, channels=channels, factor=factor)) + + def forward( + self, hidden_states: list[torch.Tensor], patch_height: Optional[int] = None, patch_width: Optional[int] = None + ) -> list[torch.Tensor]: + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`): + List of hidden states from the backbone. + """ + out = [] + + for i, hidden_state in enumerate(hidden_states): + # reshape to (batch_size, num_channels, height, width) + hidden_state = hidden_state[:, 1:] + batch_size, _, num_channels = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + hidden_state = self.layers[i](hidden_state) + out.append(hidden_state) + + return out + + +class DepthAnythingPreActResidualLayer(nn.Module): + """ + ResidualConvUnit, pre-activate residual unit. + + Args: + config (`[DepthAnythingConfig]`): + Model configuration class defining the model architecture. 
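The reassemble factors map the backbone's token grid to feature maps at several scales. A standalone shape check of the resize rule used in DepthAnythingReassembleLayer (plain torch.nn layers, independent of the classes above):

import torch
import torch.nn as nn

tokens = torch.randn(1, 64, 37, 49)                               # a 37x49 patch grid, 64 channels

up4 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=4)         # factor 4
same = nn.Identity()                                              # factor 1
down2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)     # factor 0.5 -> stride int(1 / 0.5)

assert up4(tokens).shape == (1, 64, 148, 196)
assert same(tokens).shape == (1, 64, 37, 49)
assert down2(tokens).shape == (1, 64, 19, 25)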
+ """ + + def __init__(self, config: DepthAnythingConfig): + super().__init__() + + self.activation1 = nn.ReLU() + self.convolution1 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + self.activation2 = nn.ReLU() + self.convolution2 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + residual = hidden_state + hidden_state = self.activation1(hidden_state) + hidden_state = self.convolution1(hidden_state) + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution2(hidden_state) + + return hidden_state + residual + + +class DepthAnythingFeatureFusionLayer(nn.Module): + """Feature fusion layer, merges feature maps from different stages. + + Args: + config (`[DepthAnythingConfig]`): + Model configuration class defining the model architecture. + """ + + def __init__(self, config: DepthAnythingConfig): + super().__init__() + + self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True) + + self.residual_layer1 = DepthAnythingPreActResidualLayer(config) + self.residual_layer2 = DepthAnythingPreActResidualLayer(config) + + def forward( + self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] = None, size: Optional[int] = None + ) -> torch.Tensor: + if residual is not None: + if hidden_state.shape != residual.shape: + residual = nn.functional.interpolate( + residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False + ) + hidden_state = hidden_state + self.residual_layer1(residual) + + hidden_state = self.residual_layer2(hidden_state) + + modifier = {"scale_factor": 2} if size is None else {"size": size} + + hidden_state = nn.functional.interpolate( + hidden_state, + **modifier, + mode="bilinear", + align_corners=True, + ) + hidden_state = self.projection(hidden_state) + + return hidden_state + + +class DepthAnythingFeatureFusionStage(nn.Module): + # Copied from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage.__init__ with DPT->DepthAnything + def __init__(self, config: DepthAnythingConfig): + super().__init__() + self.layers = nn.ModuleList() + for _ in range(len(config.neck_hidden_sizes)): + self.layers.append(DepthAnythingFeatureFusionLayer(config)) + + def forward(self, hidden_states: torch.Tensor, size: Optional[int] = None) -> list[torch.Tensor]: + # reversing the hidden_states, we start from the last + hidden_states = hidden_states[::-1] + + fused_hidden_states = [] + # first layer only uses the last hidden_state + size = hidden_states[1].shape[2:] + fused_hidden_state = self.layers[0](hidden_states[0], size=size) + fused_hidden_states.append(fused_hidden_state) + + # looping from the last layer to the second + for idx, (hidden_state, layer) in enumerate(zip(hidden_states[1:], self.layers[1:], strict=False)): + size = hidden_states[1:][idx + 1].shape[2:] if idx != (len(hidden_states[1:]) - 1) else None + + fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size) + + fused_hidden_states.append(fused_hidden_state) + + return fused_hidden_states + + +# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->DepthAnything,dpt->depth_anything +class DepthAnythingPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading 
pretrained + models. + """ + + config_class = DepthAnythingConfig + base_model_prefix = "depth_anything" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class DepthAnythingNeck(nn.Module): + """ + DepthAnythingNeck. A neck is a module that is normally used between the backbone and the head. + It takes a list of tensors as input and produces another list of tensors as output. + For DepthAnything, it includes 2 stages: + + * DepthAnythingReassembleStage + * DepthAnythingFeatureFusionStage. + + Args: + config (dict): config dict. + """ + + def __init__(self, config: DepthAnythingConfig): + super().__init__() + self.config = config + + self.reassemble_stage = DepthAnythingReassembleStage(config) + + self.convs = nn.ModuleList() + for channel in config.neck_hidden_sizes: + self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False)) + + # fusion + self.fusion_stage = DepthAnythingFeatureFusionStage(config) + + def forward( + self, hidden_states: list[torch.Tensor], patch_height: Optional[int] = None, patch_width: Optional[int] = None + ) -> list[torch.Tensor]: + """ + Args: + hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` + or `(batch_size, hidden_size, height, width)`): List of hidden states from the backbone. + """ + if not isinstance(hidden_states, (tuple, list)): + raise ValueError("hidden_states should be a tuple or list of tensors") + + if len(hidden_states) != len(self.config.neck_hidden_sizes): + raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.") + + # postprocess hidden states + hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width) + + features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)] + + # fusion blocks + output = self.fusion_stage(features) + + return output + + +class DepthAnythingDepthEstimationHead(nn.Module): + """ + Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples + the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's + supplementary material). 
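To summarize the data flow built from the classes above, a comment-level sketch (not runnable; shapes assume a 518x518 input with patch size 14, i.e. a 37x37 token grid, and the default 384-dim backbone config, so pretrained variants will differ):

# pixel_values (B, 3, 518, 518)
#   -> backbone: 4 hidden states of shape (B, 37*37 + 1, 384)       # out_indices [9, 10, 11, 12]
#   -> DepthAnythingReassembleStage: drop CLS, reshape to (B, C, 37, 37), rescale by [4, 2, 1, 0.5]
#   -> 3x3 convs: project every scale to fusion_hidden_size channels
#   -> DepthAnythingFeatureFusionStage: merge coarse-to-fine with residual units + bilinear upsampling
#   -> DepthAnythingDepthEstimationHead (next class): conv, upsample to (518, 518), conv-relu-conv-relu
# predicted_depth (B, 518, 518), non-negative relative depth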
+ """ + + def __init__(self, config: DepthAnythingConfig): + super().__init__() + + self.head_in_index = config.head_in_index + self.patch_size = config.patch_size + + features = config.fusion_hidden_size + self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1) + self.activation1 = nn.ReLU() + self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0) + self.activation2 = nn.ReLU() + + def forward(self, hidden_states: list[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor: + hidden_states = hidden_states[self.head_in_index] + + predicted_depth = self.conv1(hidden_states) + predicted_depth = nn.functional.interpolate( + predicted_depth, + (int(patch_height * self.patch_size), int(patch_width * self.patch_size)), + mode="bilinear", + align_corners=True, + ) + predicted_depth = self.conv2(predicted_depth) + predicted_depth = self.activation1(predicted_depth) + predicted_depth = self.conv3(predicted_depth) + predicted_depth = self.activation2(predicted_depth) + predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width) + + return predicted_depth + + +class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel): + def __init__(self, config: DepthAnythingConfig): + super().__init__(config) + + self.backbone = AutoBackbone.from_config(config.backbone_config) + self.neck = DepthAnythingNeck(config) + self.head = DepthAnythingDepthEstimationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> tuple[torch.Tensor, ...] | DepthEstimatorOutput: + r""" + Forward pass for Depth Anything. + + Args: + pixel_values (torch.FloatTensor): input images. + labels (Optional[torch.LongTensor]: labels for loss. Defaults to None. + output_attentions (Optional[bool]): whether to return attentions. Defaults to None. + output_hidden_states (Optional[bool]): whether to return hidden_states. Defaults to None. + return_dict (Optional[bool]): whether to return dict. Defaults to None. 
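A usage sketch for the vendored DepthAnythingForDepthEstimation, mirroring how prepare_depth_decoder earlier in this diff loads it (requires downloading the checkpoint; the import path is assumed from the file layout):

import numpy as np
import torch
from transformers import AutoImageProcessor
from theia.foundation_models.vision_models.depth_anything import DepthAnythingForDepthEstimation  # assumed path

model = DepthAnythingForDepthEstimation.from_pretrained("LiheYoung/depth-anything-large-hf")
processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-large-hf")

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)   # placeholder RGB image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.predicted_depth.shape)   # (1, H, W) relative depth at the processor's working resolution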
+ + Returns: + Tuple[torch.Tensor] | DepthEstimatorOutput: forward output + + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + hidden_states = outputs.feature_maps + + _, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + patch_height = height // patch_size + patch_width = width // patch_size + + hidden_states = self.neck(hidden_states, patch_height, patch_width) + + predicted_depth = self.head(hidden_states, patch_height, patch_width) + + loss = None + if labels is not None: + raise NotImplementedError("Training is not implemented yet") + + if not return_dict: + if output_hidden_states: + output = (predicted_depth,) + outputs[1:] + else: + output = (predicted_depth,) + outputs[2:] + return ((loss,) + output) if loss is not None else output # noqa + + return DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +class DepthAnythingNeckFeature(DepthAnythingForDepthEstimation): + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.Tensor: + """Forward pass for Depth Anything with only neck feature returned. + + Args: + pixel_values (torch.FloatTensor): input images. + labels (Optional[torch.LongTensor]: labels for loss. Defaults to None. + output_attentions (Optional[bool]): whether to return attentions. Defaults to None. + output_hidden_states (Optional[bool]): whether to return hidden_states. Defaults to None. + return_dict (Optional[bool]): whether to return dict. Defaults to None. + + Returns: + torch.Tensor: neck feature. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + hidden_states = outputs.feature_maps + + _, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + patch_height = height // patch_size + patch_width = width // patch_size + + hidden_states = self.neck(hidden_states, patch_height, patch_width) + + return hidden_states + + +class DepthAnythingHeadFeature(DepthAnythingForDepthEstimation): + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.Tensor: + """Forward pass for Depth Anything with only last layer (head) feature returned. + + Args: + pixel_values (torch.FloatTensor): input images. + labels (Optional[torch.LongTensor]: labels for loss. 
Defaults to None. + output_attentions (Optional[bool]): whether to return attentions. Defaults to None. + output_hidden_states (Optional[bool]): whether to return hidden_states. Defaults to None. + return_dict (Optional[bool]): whether to return dict. Defaults to None. + + Returns: + torch.Tensor: last layer (head) feature + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + hidden_states = outputs.feature_maps + + _, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + patch_height = height // patch_size + patch_width = width // patch_size + + hidden_states = self.neck(hidden_states, patch_height, patch_width) + + hidden_states = hidden_states[-1] + + head_feature = self.head.conv1(hidden_states) + head_feature = nn.functional.interpolate( + head_feature, + (int(patch_height * patch_size), int(patch_width * patch_size)), + mode="bilinear", + align_corners=True, + ) + head_feature = self.head.conv2(head_feature) + + return head_feature + + +def get_depth_anything_feature( + model: DepthAnythingForDepthEstimation, + processor: AutoImageProcessor, + images: list[npt.NDArray], + requires_grad: Optional[bool] = False, +) -> torch.Tensor | list[torch.Tensor]: + """Get feature (after neck) from depth anything model. + + Args: + model (DepthAnythingNeckFeature): Depth Anything model. + processor (AutoImageProcessor): Depth Anything processor. + images (list[npt.NDArray]): images to extract feature. + requires_grad (Optional[bool], optional): whether to keep gradient. Defaults to False. + + Returns: + torch.Tensor: feature from depth anything model. + """ + inputs = processor(images, return_tensors="pt").to(model.device) + if requires_grad: + outputs = model(**inputs) + else: + with torch.no_grad(): + outputs = model(**inputs) + # if neck + # [torch.Size([1, D, 37, 49]), torch.Size([1, D, 74, 98]), + # torch.Size([1, D, 148, 196]), torch.Size([1, D, 296, 392])] + # D = 64, 128, 256 for small, base, large + # if head + # torch.Size([1, 32, 518, 686]) + return outputs + + +def get_depth_anything_model( + model_name: Optional[str] = "LiheYoung/depth-anything-large-hf", + device: Optional[str | torch.device] = "cuda", + selected_feature: Optional[str] = "neck", +) -> tuple[DepthAnythingForDepthEstimation, AutoImageProcessor]: + """Get depth anything model. + + Args: + model_name (Optional[str]): name of the model. Defaults to "LiheYoung/depth-anything-large-hf". + device (Optional[str | torch.device]): device to put model on. Defaults to "cuda". + + Returns: + Tuple[DepthAnythingForDepthEstimation, AutoImageProcessor]: Depth Anything model and the processor. 
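+
+ Example (an illustrative sketch; downloads the checkpoint from the HuggingFace hub, and `image`
+ is an assumed HxWx3 uint8 RGB array):
+
+     model, processor = get_depth_anything_model(selected_feature="neck")
+     features = get_depth_anything_feature(model, processor, [image])
+     # "neck" returns a list of multi-scale feature maps; "head" returns a single feature tensor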
+ """ + processor = AutoImageProcessor.from_pretrained(model_name) + if selected_feature == "neck": + model = DepthAnythingNeckFeature.from_pretrained(model_name).to(device) + elif selected_feature == "head": + model = DepthAnythingHeadFeature.from_pretrained(model_name).to(device) + else: + raise ValueError(f"{selected_feature} is not supported for Depth Anything") + return model, processor + + +def print_feature_size( + model_name: Optional[str] = "LiheYoung/depth-anything-large-hf", selected_feature: Optional[str] = "neck" +) -> None: + """Print the size of the feature from Depth Anything. + + Args: + model_name (Optional[str]): the name of Depth Anything model. + Defaults to "LiheYoung/depth-anything-large-hf". + """ + import requests + from PIL import Image + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = [Image.open(requests.get(url, stream=True).raw)] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model, processor = get_depth_anything_model(model_name=model_name, device=device, selected_feature=selected_feature) + + with torch.no_grad(): + embedding = get_depth_anything_feature(model, processor, image) + + print([x.size() for x in embedding] if isinstance(embedding, list) else embedding.size()) diff --git a/theia/foundation_models/vision_models/dinov2.py b/theia/foundation_models/vision_models/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..aa8b0a76a08b39bf8327bd5da5da3b729b9e9f2b --- /dev/null +++ b/theia/foundation_models/vision_models/dinov2.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import numpy as np +import torch +from transformers import AutoImageProcessor, Dinov2Model + + +def get_dinov2_feature( + model: Dinov2Model, processor: AutoImageProcessor, images: list[np.ndarray], requires_grad: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Get DINOv2 features. + + Args: + model (Dinov2Model): DINOv2 model. + processor (AutoImageProcessor): DINOv2 input processor. + images (list[np.ndarray]): images to be encoded, in RGB, uint8. + requires_grad (bool): maintains gradient. Defaults to False. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ( + cls_token: last layer embedding from cls token # (1, 1, 1024) if dinov2-large, + visual_tokens: last layer embeddings from image # (1, 1024, 16, 16) BCHW if dinov2-large, + pooled_cls_token: last layer embedding from cls + layernorm # (1, 1, 1024) if dinov2-large + ) + """ + inputs = processor(images, return_tensors="pt").to(model.device) + if requires_grad: + outputs = model(**inputs) + else: + with torch.no_grad(): + outputs = model(**inputs) + cls_token = outputs.last_hidden_state[:, :1] # (1, 1, 1024) if vit-large + visual_tokens = outputs.last_hidden_state[:, 1:] # (1, 256, 1024) if vit-large + pooled_cls_token = outputs.pooler_output.unsqueeze(1) # (1, 1, 1024) if vit-large + batch_size, num_patches, num_channels = visual_tokens.size() + visual_tokens = visual_tokens.transpose(1, 2) + visual_tokens = visual_tokens.reshape( + batch_size, num_channels, int(np.sqrt(num_patches)), int(np.sqrt(num_patches)) + ) # (1, 1024, 16, 16) BCHW for vit-huge + return cls_token, visual_tokens, pooled_cls_token + + +def get_dinov2_model( + model_name: str = "facebook/dinov2-large", device: str | torch.device = "cuda" +) -> tuple[Dinov2Model, AutoImageProcessor]: + """Get DINOv2 model and its input processor. + + Args: + model_name (str, optional): name of DINOv2 model. 
Defaults to "facebook/dinov2-large". + device (str | torch.device, optional): device to put the model on. Defaults to "cuda". + + Returns: + tuple[Dinov2Model, AutoImageProcessor]: DINOv2 model and the corresponding input processor + """ + processor = AutoImageProcessor.from_pretrained(model_name) + model = Dinov2Model.from_pretrained(model_name).to(device) + return model, processor + + +def print_feature_size(model_name: str = "facebook/dinov2-large") -> None: + """Print the sizes of features from DINOv2. + + Args: + model_name (str, optional): the name of DINOv2. Defaults to "facebook/dinov2-large". + """ + from datasets import load_dataset + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0] + image = [np.array(image)] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model, processor = get_dinov2_model(model_name=model_name, device=device) + cls_token, visual_tokens, pooled_cls_token = get_dinov2_feature(model, processor, image) + print(cls_token.size(), visual_tokens.size(), pooled_cls_token.size()) + # (1, 1, 1024), (1, 1024, 16, 16), (1, 1, 1024) for dinov2-large diff --git a/theia/foundation_models/vision_models/sam.py b/theia/foundation_models/vision_models/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..3a0a3df3ea0800c6c9cd264a7346d137ab449091 --- /dev/null +++ b/theia/foundation_models/vision_models/sam.py @@ -0,0 +1,393 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np +import torch +from PIL import Image +from transformers import SamConfig, SamModel, SamProcessor +from transformers.models.sam.modeling_sam import SamMaskDecoder, SamMaskDecoderConfig +from transformers.utils import ModelOutput + + +class SamMaskDecoderWithFeature(SamMaskDecoder): + """Mask decoder with upscaled feature exposed. 
Borrowed from transformers.""" + + def __init__(self, config: SamMaskDecoderConfig): + super().__init__(config) + + # borrowd from huggingface transformer + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + output_attentions: Optional[bool] = None, + attention_similarity: Optional[torch.Tensor] = None, + target_embedding: Optional[torch.Tensor] = None, + ) -> Any: + """Predict masks given image and prompt embeddings.""" + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings.sum().item() != 0: + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-point + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + + # Run the transformer, image_positional_embedding are consumed + point_embedding, image_embeddings, attentions = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).reshape( + batch_size * point_batch_size, num_channels, height, width + ) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) + + hyper_in_list = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = torch.stack(hyper_in_list, dim=2) + + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + outputs: tuple[Any, ...] 
= (masks, iou_pred) + + if output_attentions: + outputs = (*outputs, attentions) + else: + outputs = (*outputs, None) + + outputs = (*outputs, upscaled_embedding.reshape(batch_size * point_batch_size, num_channels, height, width)) + return outputs + + +@dataclass +class SamImageSegmentationWithFeatureOutput(ModelOutput): + """Sam segmentation output plus features.""" + + iou_scores: torch.FloatTensor = None + pred_masks: torch.FloatTensor = None + vision_hidden_states: Optional[tuple[torch.FloatTensor]] = None + vision_attentions: Optional[tuple[torch.FloatTensor]] = None + mask_decoder_attentions: Optional[tuple[torch.FloatTensor]] = None + image_embeddings: Optional[tuple[torch.FloatTensor]] = None + upscaled_image_embeddings: Optional[tuple[torch.FloatTensor]] = None + + +class SamModelWithFeature(SamModel): + """SAM model with feature exposed. Borrowed from transformers.""" + + def __init__(self, config: SamConfig): + super().__init__(config) + self.mask_decoder = SamMaskDecoderWithFeature(config.mask_decoder_config) + self.post_init() + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + multimask_output: bool = True, + attention_similarity: Optional[torch.FloatTensor] = None, + target_embedding: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Optional[dict[str, Any]], + ) -> tuple | SamImageSegmentationWithFeatureOutput: + """Sam forward pass with feature returned""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and image_embeddings is None: + raise ValueError("Either pixel_values or image_embeddings must be provided.") + + if pixel_values is not None and image_embeddings is not None: + raise ValueError("Only one of pixel_values and image_embeddings can be provided.") + + if input_points is not None and len(input_points.shape) != 4: + raise ValueError( + "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`," + " `nb_points_per_image`, `2`.", + " got {}.".format(input_points.shape), + ) + if input_boxes is not None and len(input_boxes.shape) != 3: + raise ValueError( + "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.", + " got {}.".format(input_boxes.shape), + ) + if input_points is not None and input_boxes is not None: + point_batch_size = input_points.shape[1] + box_batch_size = input_boxes.shape[1] + if point_batch_size != box_batch_size: + raise ValueError( + "You should provide as many bounding boxes as input points per box. 
Got {} and {}.".format( + point_batch_size, box_batch_size + ) + ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0] # type: ignore [union-attr] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + vision_outputs = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings = vision_outputs[0] + + if output_hidden_states: + vision_hidden_states = vision_outputs[1] + if output_attentions: + vision_attentions = vision_outputs[-1] + + if input_points is not None and input_labels is None: + input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) + + if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]: # type: ignore [union-attr] + raise ValueError( + "The batch size of the image embeddings and the input points must be the same. ", + "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]), # type: ignore [union-attr] + " if you want to pass multiple points for the same image, make sure that you passed ", + " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ", + " input_labels of shape (batch_size, point_batch_size, num_points_per_image)", + ) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + + low_res_masks, iou_predictions, mask_decoder_attentions, upscaled_image_embeddings = self.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + + if not return_dict: + output: tuple[Any, ...] = (iou_predictions, low_res_masks) + if output_hidden_states: + output = (*output, vision_hidden_states) + if output_attentions: + output = (*output, vision_attentions, mask_decoder_attentions) + + output = (*output,) + return output + + return SamImageSegmentationWithFeatureOutput( + iou_scores=iou_predictions, + pred_masks=low_res_masks, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + mask_decoder_attentions=mask_decoder_attentions, + image_embeddings=image_embeddings, + upscaled_image_embeddings=upscaled_image_embeddings, + ) + + +class SamModelVisionFeature(SamModel): + """Sam with only feature from the vision backbone. 
Borrowed from transformers.""" + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + multimask_output: bool = True, + attention_similarity: Optional[torch.FloatTensor] = None, + target_embedding: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Optional[dict[str, Any]], + ) -> list[dict[str, torch.Tensor]]: + """Sam forward pass that only goes through vision backbone and returns visual feature.""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and image_embeddings is None: + raise ValueError("Either pixel_values or image_embeddings must be provided.") + + if pixel_values is not None and image_embeddings is not None: + raise ValueError("Only one of pixel_values and image_embeddings can be provided.") + + if input_points is not None and len(input_points.shape) != 4: + raise ValueError( + "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`," + " `nb_points_per_image`, `2`.", + " got {}.".format(input_points.shape), + ) + if input_boxes is not None and len(input_boxes.shape) != 3: + raise ValueError( + "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.", + " got {}.".format(input_boxes.shape), + ) + if input_points is not None and input_boxes is not None: + point_batch_size = input_points.shape[1] + box_batch_size = input_boxes.shape[1] + if point_batch_size != box_batch_size: + raise ValueError( + "You should provide as many bounding boxes as input points per box. Got {} and {}.".format( + point_batch_size, box_batch_size + ) + ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0] # type: ignore [union-attr] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + vision_outputs = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings = vision_outputs[0] + + if output_hidden_states: + vision_hidden_states = vision_outputs[1] + if output_attentions: + vision_attentions = vision_outputs[-1] + + return SamImageSegmentationWithFeatureOutput( + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + image_embeddings=image_embeddings, + ) + + +def get_sam_feature( + model: SamModel, processor: SamProcessor, images: list[np.ndarray], requires_grad: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: + """Get features from SAM. + + Args: + model (SamModel): SAM model. + processor (SamProcessor): SAM input processor. + images (list[np.ndarray]): images to be encoded, in RGB, uint8. + requires_grad (bool): maintains gradient. 
Defaults to False. + + Returns: + tuple[torch.Tensor, torch.Tensor]: ( + image_embeddings: feature from SAM visual encoder # (1, 256, 64, 64) if BCHW vit-huge + upscaled_image_embeddings: features from mask decoder # (1, 32, 256, 256) + ) + """ + inputs = processor(images, return_tensors="pt").to(model.device) + if requires_grad: + outputs = model(**inputs) + else: + with torch.no_grad(): + outputs = model(**inputs) + return (outputs.image_embeddings, outputs.upscaled_image_embeddings) + + +def get_sam_model( + model_name: str = "facebook/sam-vit-huge", device: str | torch.device = "cuda", with_upscaled: bool = False +) -> tuple[SamModelWithFeature, SamProcessor]: + """Get sam model and its input processor. + + Args: + model_name (str, optional): name of SAM model. Defaults to "facebook/sam-vit-huge". + device (str | torch.device, optional): device to put the model on. Defaults to "cuda". + with_upscaled (bool, optional): if return upscaled features. Defaults to False. + + Returns: + tuple[SamModelWithFeature, SamProcessor]: SAM and its corresponding input processor + """ + if with_upscaled: + model = SamModelWithFeature.from_pretrained(model_name).to(device) + else: + model = SamModelVisionFeature.from_pretrained(model_name).to(device) + processor = SamProcessor.from_pretrained(model_name) + return model, processor + + +def print_feature_size(model_name: str = "facebook/sam-vit-huge") -> None: + """Print the size of features from sam. + + Args: + model_name (str, optional): the name of SAM model. Defaults to "facebook/sam-vit-huge". + """ + import requests + + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + image_array = [np.array(raw_image)] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model, processor = get_sam_model(model_name=model_name, device=device) + image_embeddings, upscaled_embeddings = get_sam_feature(model, processor, image_array) + + print(image_embeddings.size(), upscaled_embeddings.size() if upscaled_embeddings is not None else None) + # (1, 256, 64, 64) and (1, 32, 256, 256) for vit-huge diff --git a/theia/foundation_models/vision_models/vit.py b/theia/foundation_models/vision_models/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..d240e95e677a261f429de90f9ef400d935848d32 --- /dev/null +++ b/theia/foundation_models/vision_models/vit.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import numpy as np +import torch +from transformers import AutoImageProcessor, ViTModel + + +def get_vit_feature( + model: ViTModel, processor: AutoImageProcessor, images: list[np.ndarray], requires_grad: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: + """Get feature from ViT model. + + Args: + model (ViTModel): ViT model. + processor (AutoImageProcessor): ViT input processor. + images (list[np.ndarray]): images to be encoded. + requires_grad (bool): maintains gradient. Defaults to False. 
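+
+ Example (an illustrative sketch; vit-base is used here only to keep the download small):
+
+     model, processor = get_vit_model("google/vit-base-patch16-224-in21k", device="cpu")
+     cls_token, feature = get_vit_feature(model, processor, [np.zeros((224, 224, 3), dtype=np.uint8)])
+     # cls_token: (1, 768), feature: (1, 768, 14, 14) BCHW for vit-base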
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: cls token (1, 1280) and patch features (1, 1280, 16, 16) BCHW for vit-huge.
+ """
+ inputs = processor(images, return_tensors="pt").to(model.device)
+ if requires_grad:
+ outputs = model(**inputs)
+ else:
+ with torch.no_grad():
+ outputs = model(**inputs)
+ cls_token, last_hidden_state = outputs.last_hidden_state[:, 0], outputs.last_hidden_state[:, 1:]
+ batch_size, num_patches, num_channels = last_hidden_state.size()
+ last_hidden_state = last_hidden_state.transpose(1, 2)
+ last_hidden_state = last_hidden_state.reshape(
+ batch_size, num_channels, int(np.sqrt(num_patches)), int(np.sqrt(num_patches))
+ )
+ return cls_token, last_hidden_state # (1, 1280, 16, 16) BCHW for vit-huge
+
+
+def get_vit_model(
+ model_name: str = "google/vit-huge-patch14-224-in21k", device: str | torch.device = "cuda"
+) -> tuple[ViTModel, AutoImageProcessor]:
+ """Get ViT model and its corresponding input processor.
+
+ Args:
+ model_name (str, optional): the name of the ViT model. Defaults to "google/vit-huge-patch14-224-in21k".
+ device (str | torch.device, optional): device to put model on. Defaults to "cuda".
+
+ Returns:
+ tuple[ViTModel, AutoImageProcessor]: ViT model and its corresponding input processor.
+ """
+ processor = AutoImageProcessor.from_pretrained(model_name)
+ model = ViTModel.from_pretrained(model_name).to(device)
+ return model, processor
+
+
+def print_feature_size(model_name: str = "google/vit-huge-patch14-224-in21k") -> None:
+ """Print the size of the feature from ViT.
+
+ Args:
+ model_name (str, optional): the name of the ViT model. Defaults to "google/vit-huge-patch14-224-in21k".
+ """
+ from datasets import load_dataset
+
+ dataset = load_dataset("huggingface/cats-image")
+ image = dataset["test"]["image"][0]
+ image = [np.array(image)]
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model, processor = get_vit_model(model_name=model_name, device=device)
+ cls_token, feature = get_vit_feature(model, processor, image)
+ print(cls_token.size(), feature.size())
+ # cls (1, 1280)
+ # feature (1, 1280, 16, 16) BCHW for vit-huge
diff --git a/theia/lr_schedulers/__init__.py b/theia/lr_schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c85cbe8c221af7fa03172959d23da67d622af727
--- /dev/null
+++ b/theia/lr_schedulers/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from .lr_schedulers import get_cos_lrs_with_linear_warm_up, get_constant_lrs_with_linear_warm_up
diff --git a/theia/lr_schedulers/lr_schedulers.py b/theia/lr_schedulers/lr_schedulers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f3431ffa3a5061db4af2bdac9bf7a66ee2e054a
--- /dev/null
+++ b/theia/lr_schedulers/lr_schedulers.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from typing import Any
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR, ConstantLR
+
+
+def get_cos_lrs_with_linear_warm_up(
+ optimizer: Optimizer,
+ warm_up_steps: int = 2000,
+ warm_up_lr_start_factor: float = 1e-2,
+ warm_up_lr_end_factor: float = 1.0,
+ cos_lrs_T_0: int = 5000,
+) -> SequentialLR:
+ """Get a cosine annealing warm restarts lr scheduler with linear warm up at the beginning.
+
+ Args:
+ optimizer (Optimizer): original optimizer to be scheduled.
+ warm_up_steps (int): number of warm up steps. Defaults to 2000.
+ warm_up_lr_start_factor (float): start factor of the linear scheduler. Defaults to 1e-2.
+ warm_up_lr_end_factor (float): end factor of the linear scheduler. Defaults to 1. + cos_lrs_T_0 (int): T_0 param of cos lrs. Number of steps per cycle. Defaults to 5000. + + Returns: + SequentialLR: a sequential lrs that combines linear and cos to implement warm up. + """ + linear_lrs = LinearLR( + optimizer=optimizer, + start_factor=warm_up_lr_start_factor, + end_factor=warm_up_lr_end_factor, + total_iters=warm_up_steps, + ) + + cos_lrs = CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=cos_lrs_T_0, T_mult=1) + + lrs = SequentialLR(optimizer=optimizer, schedulers=[linear_lrs, cos_lrs], milestones=[warm_up_steps]) + + return lrs + + +def get_constant_lrs_with_linear_warm_up( + optimizer: Optimizer, + warm_up_steps: int = 2000, + warm_up_lr_start_factor: float = 1e-2, + warm_up_lr_end_factor: float = 1., + **kwargs: Any +) -> SequentialLR: + """Get a constant lr scheduler with linear warm up at the beginning. + + Args: + optimizer (Optimizer): original optimizer to be scheduled. + warm_up_steps (int): number of warm up steps. Defaults to 2000. + warm_up_lr_start_factor (float): start factor of the linear schedular. Defaults to 1e-2. + warm_up_lr_end_factor (float): end factor of the linear scheduler. Defaults to 1. + + Returns: + SequentialLR: a sequential lrs that combines linear and constant lrs to implement warm up. + """ + linear_lrs = LinearLR( + optimizer = optimizer, + start_factor = warm_up_lr_start_factor, + end_factor = warm_up_lr_end_factor, + total_iters = warm_up_steps + ) + + constant_lrs = ConstantLR( + optimizer = optimizer, + factor=1.0 + ) + + lrs = SequentialLR( + optimizer = optimizer, + schedulers = [linear_lrs, constant_lrs], + milestones = [warm_up_steps] + ) + + return lrs diff --git a/theia/models/__init__.py b/theia/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9 --- /dev/null +++ b/theia/models/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. diff --git a/theia/models/activations.py b/theia/models/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..641776b99afd5b991ec47cf9785d527e8e85c2ae --- /dev/null +++ b/theia/models/activations.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import torch.nn as nn + + +def get_activation_fn(activation: str) -> nn.Module: + """Return specified activation function. + + Args: + activation (str): the name of the activation function. + + Returns: + nn.Module: the activation function in nn.Module. + """ + if activation == "relu": + return nn.ReLU() + elif activation == "gelu": + return nn.GELU() + elif activation == "tanh": + return nn.Tanh() + elif activation == "leaky_relu": + return nn.LeakyReLU() + else: + raise ValueError(f"{activation} is not defined in theia/models/activations.py:get_activation_fn()") diff --git a/theia/models/adapter_heads.py b/theia/models/adapter_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a12e2ec936bcb6125319acfec85799aeaf0ef0 --- /dev/null +++ b/theia/models/adapter_heads.py @@ -0,0 +1,359 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + + +from itertools import chain + +import torch +import torch.nn as nn +from einops.layers.torch import Rearrange +from torch.nn.functional import interpolate + + +class Interpolation(nn.Module): + """Interpolation nn.Module wrap for nn.functional.interpolate. 
+
+ Attributes:
+ target_size (tuple[int, int] | torch.Size): target spatial size of this interpolation.
+ """
+
+ def __init__(self, target_size: tuple[int, int] | torch.Size) -> None:
+ super().__init__()
+ self.target_size = target_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Simple forward pass that calls interpolate()."""
+ return interpolate(x, self.target_size)
+
+
+class LinearAdapterHead(nn.Module):
+ """Adapter head consisting of a single linear layer."""
+
+ def __init__(
+ self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size
+ ):
+ """Initialization function for LinearAdapterHead.
+
+ Args:
+ source_size (tuple[int, ...] | torch.Size): the size of the source feature.
+ target_size (tuple[int, ...] | torch.Size): the size of the target feature.
+ """
+ super().__init__()
+
+ self.source_size = source_size
+ self.target_size = target_size
+
+ source_channel_size = self.source_size[0]
+ target_channel_size = self.target_size[0]
+
+ self.adapter = nn.Sequential(
+ nn.Linear(source_channel_size, target_channel_size),
+ )
+
+ def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
+ """Forward pass for the adapter."""
+ # LinearAdapterHead is used only when there is a cls token in the backbone.
+ assert not backbone_no_cls, "LinearAdapterHead requires a backbone with a cls token."
+ # x: [B, (1+H*W), C]
+ x = x[:, 0]
+ x = self.adapter(x)
+ return x # [B, C_target]
+
+
+class MLPAdapterHead(nn.Module):
+ """MLP Adapter module.
+
+ Transforms features from source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t].
+ It first interpolates to match the spatial size [H_t, W_t], then applies an MLP to project to
+ the target channel dimension [C_t].
+
+ Attributes:
+ source_size (tuple[int, ...] | torch.Size): the size of the source feature. [C, H, W]
+ target_size (tuple[int, ...] | torch.Size): the size of the target feature. [C, H, W]
+ adapter (nn.Module): the adapter module.
+ interpolation (nn.Module): interpolation to adjust sizes before the MLP.
+ """
+
+ def __init__(
+ self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size, num_layer: int
+ ):
+ """Initialization function for MLPAdapter.
+
+ Args:
+ source_size (tuple[int, ...] | torch.Size): the size of the source feature.
+ target_size (tuple[int, ...] | torch.Size): the size of the target feature.
+ num_layer (int): number of MLP layers (one linear layer if num_layer = 1).
+ """
+ super().__init__()
+ assert num_layer >= 1, f"`num_layer` in {self._get_name()} should be >= 1.
Got {num_layer}" + + self.source_size = source_size + self.target_size = target_size + + source_channel_size = self.source_size[0] + target_channel_size = self.target_size[0] + + self.interpolation = nn.Sequential( + nn.Identity(), + ) + if self.source_size[1] != self.target_size[1]: + self.interpolation = nn.Sequential( + Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), + Interpolation(self.target_size[1:]), + Rearrange("b c h w-> b (h w) c"), + ) + + if num_layer == 1: + self.adapter = nn.Sequential( + nn.Linear(source_channel_size, target_channel_size), + ) + elif num_layer >= 2: + hidden_dim = source_channel_size * 2 + self.adapter = nn.Sequential( + nn.Linear(source_channel_size, hidden_dim), + *list( + chain.from_iterable([[nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)] for _ in range(num_layer - 2)]) + ), + nn.ReLU(), + nn.Linear(hidden_dim, target_channel_size), + ) + + def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: + """Forward pass for the adapter. First interpolation then MLP.""" + # x: [B, (1)+H*W, C] + if not backbone_no_cls: + x = x[:, 1:] + # x: [B, (H*W), C] + x = self.interpolation(x) + x = self.adapter(x) + return x # [B, (H*W), C] + + +class ConvAdapterHead(nn.Module): + """Convolutional Adapter module. + + Transforms features in shape source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t]. + Uses CNN to map channel and spatial sizes jointly. + Note: only work for (16, 16), (any, any), any <= 14, and (64, 64) spatial sizes for now. + + Attributes: + source_size (tuple[int, ...] | torch.Size): the size of the source feature. + target_size (tuple[int, ...] | torch.Size): the size of the target feature. + adapter (nn.Module): the adapter module. + interpolation (nn.Module): interpolation to adjust sizes before MLP. + """ + + def __init__( + self, + source_size: tuple[int, ...] | torch.Size, + target_size: tuple[int, ...] | torch.Size, + ): + """Initialization function for ConvAdapter. + + Args: + source_size (tuple[int, ...] | torch.Size): the size of the source feature. + target_size (tuple[int, ...] | torch.Size): the size of the target feature. 
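+
+ Example (an illustrative sketch; the sizes are assumptions chosen to exercise the upsampling branch):
+
+     adapter = ConvAdapterHead(source_size=(768, 16, 16), target_size=(256, 64, 64))
+     tokens = torch.randn(2, 1 + 16 * 16, 768)  # [B, 1 + H*W, C], cls token included
+     out = adapter(tokens)                      # [B, 64 * 64, 256]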
+ """ + super().__init__() + self.source_size = source_size + self.target_size = target_size + + hidden_dim = self.source_size[0] * 2 + source_channel_size = self.source_size[0] + target_channel_size = self.target_size[0] + + if self.source_size[1] < 12: + raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.") + elif self.source_size[1] < 16: # pad (any, any), any <= 14 to (16, 16) + self.pad = nn.Sequential( + Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), + nn.ConvTranspose2d( + source_channel_size, + source_channel_size, + kernel_size=3, + stride=1, + output_padding=14 - self.source_size[1], + ), + ) + self.source_size = (self.source_size[0], 16, 16) + elif self.source_size[1] == 16 or self.source_size[1] == 64: # do nothing for (16, 16) and (64, 64) + self.pad = nn.Sequential( + Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), + ) + else: + raise NotImplementedError("feature spatial size (>=16x16) other than 16x16 and 64x64 is not supported.") + + if self.source_size[1] < self.target_size[1]: # (16, 16) / (14, 14) to (64, 64) + self.adapter = nn.Sequential( + nn.LayerNorm(self.source_size), + nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 31 + nn.ReLU(), + nn.LayerNorm([hidden_dim, 31, 31]), + nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), # 64 + nn.ReLU(), + nn.LayerNorm([hidden_dim, 64, 64]), + nn.ConvTranspose2d(hidden_dim, target_channel_size, kernel_size=3, stride=1, padding=1), # 64 + Rearrange("b c h w-> b (h w) c"), + ) + elif self.source_size[1] == self.target_size[1]: # (16, 16) to (16, 16) + self.adapter = nn.Sequential( + nn.LayerNorm(self.source_size), + nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), # 16 + nn.ReLU(), + nn.LayerNorm([hidden_dim, *self.source_size[1:]]), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), # 16 + nn.ReLU(), + nn.LayerNorm([hidden_dim, *self.source_size[1:]]), + nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), # 16 + Rearrange("b c h w-> b (h w) c"), + ) + else: # (64, 64) to (16, 16) + self.adapter = nn.Sequential( + nn.LayerNorm(self.source_size), + nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 32 + nn.ReLU(), + nn.LayerNorm([hidden_dim, 32, 32]), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), # 16 + nn.ReLU(), + nn.LayerNorm([hidden_dim, 16, 16]), + nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), # 16 + Rearrange("b c h w-> b (h w) c"), + ) + + def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: + """Forward pass for ConvAdapter""" + # x: [B, (1)+H*W, C] + if not backbone_no_cls: + x = x[:, 1:] + # x: [B, H*W, C] + x = self.pad(x) + x = self.adapter(x) + return x # B, (H*W), C + + +class LightConvAdapterHead(nn.Module): + """Light Convolutional Adapter module. + + Transforms features from source size in [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t]. + Uses CNN to map channel and spatial sizes jointly. + Note: only work for source sizes (H_s, W_s): (16, 16), (any, any), 12 <= any <= 14, + and target sizes (H_t, W_t): (16, 16) and (64, 64) for now. + + Attributes: + source_size (tuple[int, ...] | torch.Size): the size of the source feature, + channel first (C, H, W). + target_size (tuple[int, ...] | torch.Size): the size of the target feature, + channel first (C, H, W). 
+ adapter (nn.Module): the adapter module. + interpolation (nn.Module): interpolation to adjust sizes before MLP. + """ + + def __init__( + self, + source_size: tuple[int, ...] | torch.Size, + target_size: tuple[int, ...] | torch.Size, + hidden_size_factor: int | float = 1.0, + ): + """Initialization function for ConvAdapter. + + Args: + source_size (tuple[int, ...] | torch.Size): the size of the source feature. + target_size (tuple[int, ...] | torch.Size): the size of the target feature. + hidden_size_factor (int | float): the size of hidden dim of feature translator + as a factor of input feature hidden dim. + """ + super().__init__() + if source_size[1] != source_size[2] or target_size[1] != target_size[2]: + raise NotImplementedError( + "Currently does not support non-square feature maps like source size" + "{source_size} and target size {target_size}." + ) + self.source_size = source_size + self.target_size = target_size + self.hidden_size_factor = hidden_size_factor + + hidden_dim = int(self.source_size[0] * hidden_size_factor) + source_channel_size = self.source_size[0] + target_channel_size = self.target_size[0] + + if self.source_size[1] < 12: + raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.") + elif self.source_size[1] < 16 and self.target_size[1] >= 16: # pad (any, any), any <= 14 to (16, 16) + self.pad = nn.Sequential( + Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), + nn.ConvTranspose2d( + source_channel_size, + source_channel_size, + kernel_size=3, + stride=1, + output_padding=14 - self.source_size[1], + ), + ) + self.source_size = (self.source_size[0], 16, 16) + elif (self.source_size[1] == 16 or self.source_size[1] == 64) or \ + (self.source_size[1] == 14 and self.target_size[1] == 14): + # no padding for (16, 16), (64, 64) and (14, 14) <-> (14, 14) + self.pad = nn.Sequential( + Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), + ) + elif self.target_size[1] < 14: + self.pad = nn.Sequential( + Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), + ) + else: + raise NotImplementedError("feature spatial size larger than 16x16 (other than 64x64) is not supported.") + + if self.source_size[1] == 16 and self.target_size[1] == 64: # (16, 16) to (64, 64) + self.adapter = nn.Sequential( + nn.LayerNorm(self.source_size), + nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 31 + nn.ReLU(), + nn.LayerNorm([hidden_dim, 31, 31]), + nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), # 64 + nn.ReLU(), + nn.LayerNorm([hidden_dim, 64, 64]), + Rearrange("b c h w-> b (h w) c"), + nn.Linear(hidden_dim, target_channel_size), + ) + elif self.source_size[1] == self.target_size[1]: # (16, 16) to (16, 16) + self.adapter = nn.Sequential( + nn.LayerNorm(self.source_size), + nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), # 16 + nn.ReLU(), + nn.LayerNorm([hidden_dim, *self.source_size[1:]]), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), # 16 + nn.ReLU(), + nn.LayerNorm([hidden_dim, *self.source_size[1:]]), + Rearrange("b c h w-> b (h w) c"), + nn.Linear(hidden_dim, target_channel_size), + ) + elif self.source_size[1] == 64 and self.target_size[1] == 16: # (64, 64) to (16, 16) + self.adapter = nn.Sequential( + nn.LayerNorm(self.source_size), + nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 32 + nn.ReLU(), + 
nn.LayerNorm([hidden_dim, 32, 32]),
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), # 16
+ nn.ReLU(),
+ nn.LayerNorm([hidden_dim, 16, 16]),
+ Rearrange("b c h w-> b (h w) c"),
+ nn.Linear(hidden_dim, target_channel_size),
+ )
+ elif self.target_size[1] == 7:
+ self.adapter = nn.Sequential(
+ nn.LayerNorm(self.source_size),
+ nn.Conv2d(source_channel_size, hidden_dim, kernel_size=4, stride=2, padding=1), # 14x14 -> 7x7
+ nn.ReLU(),
+ nn.LayerNorm([hidden_dim, 7, 7]),
+ Rearrange("b c h w-> b (h w) c"),
+ nn.Linear(hidden_dim, target_channel_size)
+ )
+ else:
+ raise NotImplementedError(f"{self.source_size} to {self.target_size} is not supported.")
+
+ def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
+ """Forward pass for LightConvAdapterHead."""
+ # x: [B, (1)+H*W, C]
+ if not backbone_no_cls:
+ x = x[:, 1:]
+ x = self.pad(x)
+ x = self.adapter(x)
+ return x # [B, H*W, C]
diff --git a/theia/models/backbones.py b/theia/models/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4ea20e7cd42f249122b275d3fd1def5bb70b448
--- /dev/null
+++ b/theia/models/backbones.py
@@ -0,0 +1,526 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import math
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoModel, AutoProcessor
+from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTModel
+
+
+# Modified from huggingface transformers ViTEmbeddings
+# Original Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+class ViTEmbeddingsNoCLS(ViTEmbeddings):
+ """ViT Embedding Module without CLS token."""
+
+ def __init__(self, config: AutoConfig, use_mask_token: bool = False):
+ """Initialization.
+
+ Args:
+ config (AutoConfig): config for ViT.
+ use_mask_token (bool, optional): whether to use mask token. Defaults to False.
+ """
+ super(ViTEmbeddingsNoCLS, self).__init__(config, use_mask_token=use_mask_token)
+ self.cls_token = None
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
+ resolution images.
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings[:, 1:] + + embeddings = self.dropout(embeddings) + + return embeddings + + +# modified from huggingface transformers ViTModel +class ViTModelNoCLS(ViTModel): + """ViT Model without CLS token.""" + + def __init__(self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None: + super(ViTModelNoCLS, self).__init__(config, add_pooling_layer, use_mask_token) + self.embeddings = ViTEmbeddingsNoCLS(config, use_mask_token=use_mask_token) + self.no_cls = True + + def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ViTEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +# modified from huggingface transformers ViTEmbeddings +class ViTEmbeddingsReg(ViTEmbeddings): + """ + ViT Embedding Module with 
register tokens. https://openreview.net/forum?id=2dnO3LLiJ1 + """ + + def __init__(self, config: AutoConfig, use_mask_token: bool = False, num_reg_tokens: int = 7): + super(ViTEmbeddingsReg, self).__init__(config, use_mask_token=use_mask_token) + self.reg_token = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size)) + self.num_reg_tokens = num_reg_tokens + self.reg_pos_embed = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size)) + + self.reg_pos_embed.data = nn.init.trunc_normal_( + self.reg_pos_embed.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(self.reg_pos_embed.dtype) + + self.reg_token.data = nn.init.trunc_normal_( + self.reg_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(self.reg_token.dtype) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 - self.num_reg_tokens + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + reg_pos_embed = self.reg_pos_embed + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, reg_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + reg_tokens = self.reg_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings, reg_tokens), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + torch.cat([self.position_embeddings, 
self.reg_pos_embed], dim=1) + + embeddings = self.dropout(embeddings) + + return embeddings + + +# modified from huggingface transformers ViTModel +class ViTModelReg(ViTModel): + """ViT Model with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1""" + + def __init__( + self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, num_reg_tokens: int = 7 + ): + super(ViTModelReg, self).__init__(config, add_pooling_layer, use_mask_token) + self.embeddings = ViTEmbeddingsReg(config, use_mask_token=use_mask_token, num_reg_tokens=num_reg_tokens) + self.num_reg_tokens = num_reg_tokens + + def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ViTEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + +class DeiT(nn.Module): + """DeiT model. + + Paper: Training data-efficient image transformers & distillation through attention + https://arxiv.org/abs/2012.12877 + Huggingface Reference: https://huggingface.co/docs/transformers/en/model_doc/deit + + Attributes: + model_name (str): name of the model. + pretrained (bool): whether to use pretrained weights. + """ + + def __init__( + self, + model_name: str = "facebook/deit-small-patch16-224", + pretrained: bool = False, + image_size: int = 224, + ): + super().__init__() + self.image_size = image_size + model = AutoModel.from_pretrained(model_name) + if pretrained: + self.model = model + else: + deit_config = model.config + self.model = AutoModel.from_config(deit_config) + del model + + self.model.pooler = nn.Identity() + + self.processor = AutoProcessor.from_pretrained(model_name) + + def get_feature_size( + self, + keep_spatial: bool = False, + return_torch_size: bool = False, + ) -> torch.Size | tuple[int, ...]: + """Get the size of the feature. + + Args: + keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False. + return_torch_size (bool): if true, return torch.Size type. Defaults to False. + + Returns: + torch.Size | tuple[int, ...]: returned feature shape. 
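A minimal sketch of how the feature-size probe above can be used, assuming these backbone classes live in theia.models.backbones (as imports elsewhere in this diff suggest) and that the Hugging Face checkpoint "facebook/deit-small-patch16-224" is downloadable:

import torch
from theia.models.backbones import DeiT  # assumed module path

backbone = DeiT(model_name="facebook/deit-small-patch16-224", pretrained=False)
print(backbone.get_feature_size())                   # channels-first, flattened, e.g. (384, 196)
print(backbone.get_feature_size(keep_spatial=True))  # e.g. (384, 14, 14) == (C, H, W)

x = torch.zeros((1, 224, 224, 3), dtype=torch.uint8)  # raw [B, H, W, C] uint8 pixels
tokens = backbone(x)                                   # [B, 1 + H*W, C]; index 0 is the CLS token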
+ """ + with torch.inference_mode(): + image_size = (224, 224) + x = torch.zeros((1, *image_size, 3), dtype=torch.uint8) + y = self.forward(x)[:, 1:] # for getting feature size, discard cls token + size = y.size()[1:][::-1] + if keep_spatial: + assert math.isqrt(size[-1]) + h = w = int(math.sqrt(size[-1])) + size = (size[0], h, w) + if return_torch_size: + size = torch.Size(size) + return size + + def forward( + self, + x: torch.Tensor, + do_resize: bool = True, + interpolate_pos_encoding: Optional[bool] = None, + do_rescale: bool = True, + do_normalize: bool = True, + ) -> torch.Tensor: + """Forward pass of the model + + Args: + x (torch.Tensor): model input. + + - arguments for self.processor. Details can be find at + https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor + do_resize (bool): if do resizing in processor. Defaults to True. + interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None. + do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True. + do_normalize (bool): if do normalize in processor. Defaults to True. + + Returns: + torch.Tensor: model output. + """ + input = self.processor( + x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize + ).to(self.model.device) + y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding) + return y.last_hidden_state + + +class DeiTNoCLS(nn.Module): + """Modified DeiT model without CLS token.""" + + def __init__( + self, model_name: str = "nocls-facebook/deit-small-patch16-224", pretrained: bool = False, image_size: int = 224 + ): + super().__init__() + self.image_size = image_size + pretrained_model_name = model_name.replace("nocls-", "") + deit_config = AutoConfig.from_pretrained(pretrained_model_name) + self.model = ViTModelNoCLS(deit_config) + if pretrained: + pretrained_model = AutoModel.from_pretrained(pretrained_model_name) + pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()} + self.load_state_dict(pretrained_dict, strict=False) + del pretrained_model, pretrained_dict + + self.model.pooler = nn.Identity() + self.processor = AutoProcessor.from_pretrained(pretrained_model_name) + self.no_cls = True + + def get_feature_size( + self, + keep_spatial: bool = False, + return_torch_size: bool = False, + ) -> torch.Size | tuple[int, ...]: + """Get the size of the feature. + + Args: + keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False. + return_torch_size (bool): if true, return torch.Size type. Defaults to False. + + Returns: + torch.Size | tuple[int, ...]: returned feature shape. + """ + with torch.inference_mode(): + image_size = (self.image_size, self.image_size) + x = torch.zeros((1, *image_size, 3), dtype=torch.uint8) + y = self.forward(x) + size = y.size()[1:][::-1] + if keep_spatial: + assert math.isqrt(size[-1]) + h = w = int(math.sqrt(size[-1])) + size = (size[0], h, w) + if return_torch_size: + size = torch.Size(size) + return size + + def forward( + self, + x: torch.Tensor, + do_resize: bool = True, + interpolate_pos_encoding: Optional[bool] = None, + do_rescale: bool = True, + do_normalize: bool = True, + ) -> torch.Tensor: + """Forward pass of the model + + Args: + x (torch.Tensor): model input. + + - arguments for self.processor. 
Details can be find at + https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor + do_resize (bool): if do resizing in processor. Defaults to True. + do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True. + do_normalize (bool): if do normalize in processor. Defaults to True. + + - argument for forward + interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None. + + Returns: + torch.Tensor: model output. + """ + input = self.processor( + x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize + ).to(self.model.device) + y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding) + return y.last_hidden_state + + +class DeiTReg(nn.Module): + """Modified DeiT model with register tokens.""" + + def __init__( + self, + model_name: str = "reg-facebook/deit-small-patch16-224", + pretrained: bool = False, + image_size: int = 224, + num_reg_tokens: int = 7, + ): + super().__init__() + self.image_size = image_size + pretrained_model_name = model_name.replace("reg-", "") + deit_config = AutoConfig.from_pretrained(pretrained_model_name) + self.model = ViTModelReg(deit_config, num_reg_tokens=num_reg_tokens) + if pretrained: + pretrained_model = AutoModel.from_pretrained(pretrained_model_name) + pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()} + self.load_state_dict(pretrained_dict, strict=False) + del pretrained_model, pretrained_dict + + self.model.pooler = nn.Identity() + self.processor = AutoProcessor.from_pretrained(pretrained_model_name) + self.num_reg_tokens = num_reg_tokens + + def get_feature_size( + self, + keep_spatial: bool = False, + return_torch_size: bool = False, + ) -> torch.Size | tuple[int, ...]: + """Get the size of the feature. + + Args: + keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False. + return_torch_size (bool): if true, return torch.Size type. Defaults to False. + + Returns: + torch.Size | tuple[int, ...]: returned feature shape. + """ + with torch.inference_mode(): + image_size = (self.image_size, self.image_size) + x = torch.zeros((1, *image_size, 3), dtype=torch.uint8) + y = self.forward(x)[:, 1 : -self.num_reg_tokens] + size = y.size()[1:][::-1] + if keep_spatial: + assert math.isqrt(size[-1]) + h = w = int(math.sqrt(size[-1])) + size = (size[0], h, w) + if return_torch_size: + size = torch.Size(size) + return size + + def forward( + self, + x: torch.Tensor, + do_resize: bool = True, + interpolate_pos_encoding: Optional[bool] = None, + do_rescale: bool = True, + do_normalize: bool = True, + ) -> torch.Tensor: + """Forward pass of the model + + Args: + x (torch.Tensor): model input. + + - arguments for self.processor. Details can be find at + https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor + do_resize (bool): if do resizing in processor. Defaults to True. + interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None. + do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True. + do_normalize (bool): if do normalize in processor. Defaults to True. + + Returns: + torch.Tensor: model output. 
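A short sketch of the token layout produced by the register-token variant above, under the same theia.models.backbones import assumption; the slice at the end mirrors what get_feature_size() does:

import torch
from theia.models.backbones import DeiTReg  # assumed module path

backbone = DeiTReg(model_name="reg-facebook/deit-small-patch16-224", pretrained=False, num_reg_tokens=7)
x = torch.zeros((2, 224, 224, 3), dtype=torch.uint8)
tokens = backbone(x)                              # [B, 1 + H*W + 7, C]: [CLS | patch tokens | register tokens]
patches = tokens[:, 1:-backbone.num_reg_tokens]   # spatial tokens only, as in get_feature_size()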
+ """ + input = self.processor( + x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize + ).to(self.model.device) + y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding) + return y.last_hidden_state + + +def build_backbone(model_name: str, pretrained: bool = False, image_size: int = 224, **kwargs: Any) -> nn.Module: + """Build the backbone visual encoder of robot vision foundation model. + + Args: + model_name (str): name of the model. + pretrained (bool): whether to use pretrained weights. Defaults to False. + image_size (int): size of the image. Assume a square image. Defaults to 224 + kwargs (Any): any kwargs specific to some models. For example, + `num_reg_tokens` for `DeiTReg` when `"reg"` in `model_name` + + Returns: + nn.Module: backbone network. + """ + if "reg" in model_name: + return DeiTReg(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs) + elif "nocls" in model_name: + return DeiTNoCLS(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs) + elif "deit" in model_name: + return DeiT(model_name=model_name, pretrained=pretrained, image_size=image_size) + else: + raise NotImplementedError(f"Requested {model_name} is not implemented.") diff --git a/theia/models/feature_translators.py b/theia/models/feature_translators.py new file mode 100644 index 0000000000000000000000000000000000000000..a30422b08963638609fa2a0f96defd8e4e744e40 --- /dev/null +++ b/theia/models/feature_translators.py @@ -0,0 +1,313 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import math +from typing import Any, Optional + +import torch +import torch.nn as nn + +from theia.models.adapter_heads import ConvAdapterHead, LightConvAdapterHead, MLPAdapterHead, LinearAdapterHead + + +class FeatureTranslator(nn.Module): + """Base class for the feature translator. + + The flow is backbone_adapter -> translator_stem -> translator_heads. + + Attributes: + backbone_feature_size (torch.Size): the size of features of the backbone. + target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. + translator_hidden_size (int): the hidden dim of the translator. Defaults to 2048. + target_model_names (list[str]): convenient attribute to hold all the names of the target models. + + backbone_adapter (nn.Module): the adapter to map channel dim of backbone to the translator hidden dim. + translator_stem (nn.Module): the shared stem for all target models. + translator_heads (nn.ModuleDict): specific heads for different target models. + """ + + def __init__( + self, + backbone_feature_size: torch.Size, + target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], + translator_hidden_size: int = 1024, + ) -> None: + """Initalization function for FeatureTranslator. + + Args: + backbone_feature_size (torch.Size): the size of features of the backbone. + target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. + translator_hidden_size (int): the hidden dim of the translator. Defaults to 2048. 
+ """ + super().__init__() + self.backbone_feature_size = backbone_feature_size # (C, H, W) + self.target_feature_sizes = target_feature_sizes # [(C, H, W)] + self.translator_hidden_size = translator_hidden_size # C + self.target_model_names = list(target_feature_sizes.keys()) + self.legit_target_model_name_map: dict[str, str] = {t: t.replace(".", "_") for t in self.target_model_names} + self.translator_heads: nn.ModuleDict = None + + self.backbone_adapter = nn.Sequential( + nn.LayerNorm(self.backbone_feature_size[0]), # do a pre-norm + nn.Linear( + self.backbone_feature_size[0], # C in [C,H,W] + self.translator_hidden_size, + ), + ) + self.translator_stem: nn.Module = nn.Identity() + self.build_translator_heads() + + def build_translator_heads(self) -> None: + """Build translator heads to match the dimension of each target feature set. + + Example: + translator_heads: dict[str, nn.Module] = ... + self.translator_heads = nn.ModuleDict(translator_heads) + """ + raise NotImplementedError("build_translator_heads() should be overridden") + + def forward( + self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False + ) -> torch.Tensor: + """Forward pass for a base feature translator. + + Args: + x (torch.Tensor): input features from the backbone. [B, (1)+H*W, C]. + (1) means optional CLS token. If `backbone_no_cls==True`, then [B, H*W, C]. + target_model_names (Optional[list[str]]): names of the target models. + backbone_no_cls (bool): indicate backbone has cls token or not. + Can use it to customize whether to drop cls. + + Returns: + dict[str, torch.Tensor]: predicted features for target models. + """ + # x: [B, (1)+H*W, C] + x = self.backbone_adapter(x) + x = self.translator_stem(x) + target_model_names = target_model_names if target_model_names is not None else self.target_model_names + features = {t: self.translator_heads[self.legit_target_model_name_map[t]](x, backbone_no_cls=backbone_no_cls) for t in target_model_names} + return features + + +class MLPFeatureTranslator(FeatureTranslator): + def __init__( + self, + backbone_feature_size: torch.Size, + target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], + translator_hidden_size: int = 1024, + translator_n_layer: int = 3, + ) -> None: + """Initalization function for MLPFeatureTranslator. + + Args: + backbone_feature_size (torch.Size): the size of features of the backbone. + target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. + translator_hidden_size (Optional[int]): the hidden dim of the translator. Defaults to 2048. + translator_n_layer (int): number of MLP layers. Defaults to 3. 
+ """ + self.translator_n_layer = translator_n_layer + + super().__init__( + backbone_feature_size=backbone_feature_size, + target_feature_sizes=target_feature_sizes, + translator_hidden_size=translator_hidden_size, + ) + + def build_translator_heads(self) -> nn.ModuleDict: + """Build MLP translator heads to match the dimension of each target feature set.""" + translator_heads = {} + source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:]) + for target_model, target_size in self.target_feature_sizes.items(): + head = MLPAdapterHead(source_size=source_size, target_size=target_size, num_layer=self.translator_n_layer) + translator_heads[self.legit_target_model_name_map[target_model]] = head + self.translator_heads = nn.ModuleDict(translator_heads) + + +class ConvFeatureTranslator(FeatureTranslator): + def __init__( + self, + backbone_feature_size: torch.Size, + target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], + translator_hidden_size: int = 1024, + ) -> None: + """Initalization function for ConvFeatureTranslator. + + Args: + backbone_feature_size (torch.Size): the size of features of the backbone. + target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. + translator_hidden_size (Optional[int]): the hidden dim of the translator. Defaults to 2048. + """ + super().__init__( + backbone_feature_size=backbone_feature_size, + target_feature_sizes=target_feature_sizes, + translator_hidden_size=translator_hidden_size, + ) + + def build_translator_heads(self) -> nn.ModuleDict: + """Build translator heads to match the dimension of each target feature set. + + Returns: + nn.ModuleDict: the translator heads. + """ + translator_heads = {} + source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:]) + for target_model, target_size in self.target_feature_sizes.items(): + head = ConvAdapterHead(source_size=source_size, target_size=target_size) + translator_heads[self.legit_target_model_name_map[target_model]] = head + self.translator_heads = nn.ModuleDict(translator_heads) + + +class LightConvFeatureTranslator(FeatureTranslator): + def __init__( + self, + backbone_feature_size: torch.Size, + target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], + translator_hidden_size: int = 1024, + hidden_size_factor: int | float = 1.0, + ) -> None: + """Initalization function for LightConvFeatureTranslator. + It's for a smaller translator compared to ConvFeatureTranslator. + + Args: + backbone_feature_size (torch.Size): the size of features of the backbone. + target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. + translator_hidden_size (Optional[int]): the hidden dim of the translator. Defaults to 1024. + hidden_size_factor: the size of hidden dim of feature translator + as a factor of input feature hidden dim. Defaults to 1.0 + """ + self.hidden_size_factor = hidden_size_factor + super().__init__( + backbone_feature_size=backbone_feature_size, + target_feature_sizes=target_feature_sizes, + translator_hidden_size=translator_hidden_size, + ) + self.backbone_adapter = nn.Identity() + + def build_translator_heads(self) -> nn.ModuleDict: + """Build translator heads to match the dimension of each target feature set. + + Returns: + nn.ModuleDict: the translator heads. 
+ """ + translator_heads = {} + for target_model, target_size in self.target_feature_sizes.items(): + if "_cls" in target_model: + head = LinearAdapterHead( + source_size=self.backbone_feature_size, + target_size=target_size + ) + else: + head = LightConvAdapterHead( + source_size=self.backbone_feature_size, + target_size=target_size, + hidden_size_factor=self.hidden_size_factor + ) + translator_heads[self.legit_target_model_name_map[target_model]] = head + self.translator_heads = nn.ModuleDict(translator_heads) + + +class TransformerFreatureTranslator(FeatureTranslator): + def __init__( + self, + backbone_feature_size: torch.Size, + target_feature_sizes: dict[str, torch.Size | tuple[int, int]], + translator_hidden_size: int = 1024, + translator_n_layers: int = 2, + translator_n_heads: int = 8, + translator_activation: str = "gelu", + ) -> None: + super().__init__( + backbone_feature_size=backbone_feature_size, + target_feature_sizes=target_feature_sizes, + translator_hidden_size=translator_hidden_size, + ) + + self.translator_stem = nn.TransformerDecoder( + nn.TransformerDecoderLayer( + d_model=translator_hidden_size, + nhead=translator_n_heads, + dim_feedforward=translator_hidden_size * 2, + activation=translator_activation, + batch_first=True, + norm_first=True, + ), + num_layers=translator_n_layers, + ) + + self.decode_tokens = nn.Parameter( + torch.randn((1, math.prod(self.backbone_feature_size[1:]), translator_hidden_size)) + ) + + self.target_model_emb = nn.ParameterDict( + { + self.legit_target_model_name_map[t]: torch.randn(1, 1, translator_hidden_size) + for t in self.target_model_names + } + ) + + def build_translator_heads(self) -> None: + """Build Transformer translator heads to match the dimension of each target feature set.""" + translator_heads = {} + for target_model, target_size in self.target_feature_sizes.items(): + head = MLPAdapterHead( + source_size=(self.translator_hidden_size, *self.backbone_feature_size[1:]), + target_size=target_size, + num_layer=2, + ) + translator_heads[self.legit_target_model_name_map[target_model]] = head + self.translator_heads = nn.ModuleDict(translator_heads) + + def forward( + self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False + ) -> torch.Tensor: + """Forward pass for a simple linear translator. + + Args: + x (torch.Tensor): input features from the backbone. + target_model_names (Optional[str]): names of the target models. + backbone_no_cls (bool): indicate backbone has cls token or not. + Can use it to customize whether to drop cls. + + Returns: + dict[str, torch.Tensor]: predicted features for target models. + """ + if not backbone_no_cls: + x = x[:, 1:] + x = self.backbone_adapter(x) + features = {} + target_model_names = target_model_names if target_model_names is not None else self.target_model_names + for t in target_model_names: + feature = self.translator_stem( + torch.cat( + [ + self.decode_tokens.repeat(x.size(0), 1, 1), + self.target_model_emb[self.legit_target_model_name_map[t]].repeat(x.size(0), 1, 1), + ], + dim=1, + ), + memory=x, + )[:, 1:, ...] + features[t] = self.translator_heads[self.legit_target_model_name_map[t]](feature) + return features + + +def build_feature_translator(translator_type: str, **kwargs: Any) -> FeatureTranslator: + """Handy function to build feature translators given the type + + Args: + translator_type (str): the type of the translator, + one in `"mlp"`, `"conv"`, `"lconv"`, `"transformer"` (or `"trans"`). 
+ At the moment we are actively using `"lconv"`. + + Returns: + FeatureTranslator: the corresponding FeatureTranslator + """ + if translator_type == "mlp": + return MLPFeatureTranslator(**kwargs) + elif translator_type == "conv": + return ConvFeatureTranslator(**kwargs) + elif translator_type == "lconv": + return LightConvFeatureTranslator(**kwargs) + elif translator_type == "transformer" or translator_type == "trans": + return TransformerFreatureTranslator(**kwargs) + else: + raise NotImplementedError(f"Requested {translator_type} is not implemented yet.") diff --git a/theia/models/rvfm.py b/theia/models/rvfm.py new file mode 100644 index 0000000000000000000000000000000000000000..958a9ba781382b1014ba2a2c08090d9923fee950 --- /dev/null +++ b/theia/models/rvfm.py @@ -0,0 +1,185 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import OmegaConf + +from theia.models.backbones import build_backbone +from theia.models.feature_translators import build_feature_translator +from theia.models.utils import handle_feature_output + + +class RobotVisionFM(nn.Module): + """Robot Vision Foundation Model (temporary name). + + Attributes: + backbone (str | nn.Module): backbone network. Defaults to "deit-small-patch16-224". + pretrained (bool): whether to use pretrained weights. Default to False. + translator (str | nn.Module): feature translator module. Defaults to "conv". + target_feature_sizes (Optional[dict[str, torch.Size | tuple[int, ...]]]): + a dict to hold target feature sizes. + translator_kwargs (Optional[dict[str, Any]]): other keyword arguments to the translator. + target_loss_weights (Optional[dict[str, float]]): + weights to balance loss from different target models. If not specified, use even weights. + checkpoint_path: (Optional[str]): filename of pretrained weights to load. + feature_reduce_method: (Optional[str]): how to reduce the feature in downstream applications. 
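A hedged end-to-end sketch of how the pieces compose, based on the class description above and the decoding script later in this diff; the target feature sizes are placeholders standing in for theia.foundation_models.common.get_model_feature_size, and the predicted shapes depend on the adapter heads not shown here:

import torch
from theia.models.rvfm import RobotVisionFM

target_feature_sizes = {
    "facebook/dinov2-large": (1024, 16, 16),   # placeholder (C, H, W)
    "facebook/sam-vit-huge": (256, 64, 64),    # placeholder (C, H, W)
}
model = RobotVisionFM(
    backbone="facebook/deit-tiny-patch16-224",
    pretrained=False,
    translator="lconv",
    target_feature_sizes=target_feature_sizes,
)

images = torch.zeros((2, 224, 224, 3), dtype=torch.uint8)     # raw [B, H, W, C] pixels
pred = model(images)                                          # dict: one tensor per target model
targets = {k: torch.randn_like(v) for k, v in pred.items()}   # stand-in teacher features
losses = model.get_loss(pred, targets)
print(losses["mse_loss"], losses["cos_loss"], losses["l1_loss"])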
+ """ + + def __init__( + self, + backbone: str | nn.Module = "deit-small-patch16-224", + pretrained: bool = False, + translator: str | nn.Module = "lconv", + target_feature_sizes: Optional[dict[str, torch.Size | tuple[int, ...]]] = None, + translator_kwargs: Optional[dict[str, Any]] = None, + target_loss_weights: Optional[dict[str, float]] = None, + checkpoint_path: Optional[str] = None, + feature_reduce_method: Optional[str] = None, + image_size: int = 224, + **kwargs: Any + ) -> None: + super().__init__() + + self.target_feature_sizes = target_feature_sizes + self.preprocessor = None + self.pretrained = pretrained + + # backbone + self.image_size = image_size + self.backbone: nn.Module = build_backbone(backbone, pretrained, image_size=image_size, **kwargs) + self.final_spatial = None + if hasattr(self.backbone, "final_spatial"): + self.final_spatial = self.backbone.final_spatial + + # handle output feature (feature reduce) + self.feature_reduce_method = feature_reduce_method + self.no_cls = hasattr(self.backbone, "no_cls") + self.num_reg_tokens = self.backbone.num_reg_tokens if hasattr(self.backbone, "num_reg_tokens") else 0 + + # translator + backbone_feature_size = self.backbone.get_feature_size(keep_spatial=True) + if self.target_feature_sizes: + translator_kwargs = {} if translator_kwargs is None else OmegaConf.to_container(translator_kwargs) + translator_kwargs["backbone_feature_size"] = backbone_feature_size + translator_kwargs["target_feature_sizes"] = target_feature_sizes + self.translator = build_feature_translator(translator, **translator_kwargs) + + # loss + self.mse_loss = nn.MSELoss() + self.l1_loss = nn.SmoothL1Loss() + self.cos_loss = nn.CosineEmbeddingLoss() + self.cos_target = torch.ones((1), dtype=torch.int, requires_grad=False) + self.target_loss_weights = target_loss_weights + + def load_pretrained_weights(self, checkpoint_path: str): + """Load pretrained weights. + + Args: + checkpoint_path (str): path to checkpoint / weight. + """ + if checkpoint_path: + weights_dict = torch.load(checkpoint_path, map_location="cpu") + # Filter out unnecessary keys + pretrained_dict = {k: v for k, v in weights_dict.items() if k in self.state_dict()} + self.load_state_dict(pretrained_dict, strict=False) + + def freeze_translator(self) -> None: + """Freeze the feature translator.""" + for param in self.translator.parameters(): + param.requires_grad = False + + def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """Forward RVFM feature only (before translators). + + Args: + x (torch.Tensor): input image. By default it accepts images + in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8. + kwargs (Any): kwargs including mainly those for huggingface preprocessor: + `do_resize` (bool) defaults to True. + `interpolate_pos_encoding` (Optional[bool]) defaults to None. + `do_rescale` (bool) defaults to True. + `do_normalize` (bool) defaults to True. + + Returns: + torch.Tensor: RVFM feature. + """ + feature = self.backbone(x, **kwargs) + # [B, 1+H*W+N, C] if including both CLS and register tokens. + # [B, 1+H*W, C] for standard model (N=0). + # [B, H*W, C] for model without CLS. + return handle_feature_output(feature, num_discard_tokens=self.num_reg_tokens) + + def forward(self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, **kwargs: Any) -> dict[str, torch.Tensor]: + """Forward pass of Robot Vision Foundation Model. + + Args: + x (torch.Tensor): input image. 
By default it accepts images + in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8. + target_model_names (Optional[list[str]]): names of the target foundation models. + kwargs (Any): kwargs including mainly those for huggingface preprocessor: + `do_resize` (bool) defaults to True. + `interpolate_pos_encoding` (Optional[bool]) defaults to None. + `do_rescale` (bool) defaults to True. + `do_normalize` (bool) defaults to True. + + Returns: + dict[str, torch.Tensor]: features that match to each foundation model. + Each feature is in [B, (H*W), C] or [B, C]. + """ + x = self.backbone(x, **kwargs) + if self.num_reg_tokens > 0: + x = x[:, :-self.num_reg_tokens] # [B, (1)+H*W, C] + features = self.translator(x, target_model_names, backbone_no_cls=self.no_cls) # each is [B, H*W, C] or [B, C] + return features + + def get_loss(self, pred_features: dict[str, torch.Tensor], y: dict[str, torch.Tensor]) -> dict[str, Any]: + """Get loss terms given predictions and targets. + + Args: + pred_features (dict[str, torch.Tensor]): predictions. + y (dict[str, torch.Tensor]): targets. + + Returns: + tuple[Any, ...]: loss terms + """ + mse_loss_avg, cos_loss_avg, l1_loss_avg = 0, 0, 0 + mse_losses_per_model = {} + cos_losses_per_model = {} + l1_losses_per_model = {} + + for t in pred_features: + pred = pred_features[t] + target = y[t] + + # mse loss + mse_loss = self.mse_loss(pred, target) + weight = self.target_loss_weights if self.target_loss_weights else 1.0 / len(pred_features) + + # l1 loss + l1_loss = self.l1_loss(pred, target) + + # cos loss + pred_norm = F.normalize(pred.flatten(start_dim=1), dim=1, p=2) + target_norm = F.normalize(target.flatten(start_dim=1), dim=1, p=2) + target = self.cos_target.repeat(pred.size(0)).to(pred.device) + cos_loss = self.cos_loss(pred_norm, target_norm, target) + + mse_loss_avg += mse_loss * weight + cos_loss_avg += cos_loss / len(pred_features) # balance cos by default for meaningful eval + l1_loss_avg += l1_loss * weight + + mse_losses_per_model[t] = mse_loss.item() + cos_losses_per_model[t] = cos_loss.item() + l1_losses_per_model[t] = l1_loss.item() + + return { + "mse_loss": mse_loss_avg, + "cos_loss": cos_loss_avg, + "l1_loss": l1_loss_avg, + "mse_losses_per_model": mse_losses_per_model, + "cos_losses_per_model": cos_losses_per_model, + "l1_losses_per_model": l1_losses_per_model, + } diff --git a/theia/models/utils.py b/theia/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1391c375e87c207c6917de3a50ed3a596065c08e --- /dev/null +++ b/theia/models/utils.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from typing import Optional + +import torch + + +def handle_feature_output( + x: torch.Tensor, feature_reduce_method: Optional[str] = None, num_discard_tokens: int = 0 +) -> torch.Tensor: + """Handle feature output from transformer. + + Args: + x (torch.Tensor): input feature to be handled. shape is + [B, 1+H*W+N, C] if including both CLS and register tokens. + [B, 1+H*W, C] for standard model (N=0). + [B, H*W, C] for model without CLS. + feature_reduce_method (Optional[str]): method to select token. Options: + - `mean_pooling`: average over spatial tokens (non CLS tokens), output shape = [B, C]. + - `max_pooling`: max over spatial tokens, output shape = [B, C]. + - `cls`: return CLS token only, output shape = [B, C]. + - `identity`: return the feature without touching it, output shape = input shape. 
+ - `None`: return spatial tokens, output shape = [B, H*W, C] (assuming input is [B, 1+H*W, C]). + suppose raw feature is in shape [B, 1+H*W, C], `1` corresponds to CLS token. + num_discard_tokens (int): + number of tokens to be discarded. Assuming they are at the end of the sequence. + Returns: + torch.Tensor: selected feature tokens. + """ + + match feature_reduce_method: + case "mean_pooling": + return torch.mean(x[:, 1 : x.size(1) - num_discard_tokens], dim=1) # [B, C] + case "max_pooling": + return torch.amax(x[:, 1 : x.size(1) - num_discard_tokens], dim=1) # [B, C] + case "cls": + return x[:, 0] # [B, C] + case "identity": + return x + case None: + return x[:, 1 : x.size(1) - num_discard_tokens] + case _: + raise NotImplementedError(f"feature_reduce_method {feature_reduce_method} it not implemented.") diff --git a/theia/models/vfm.py b/theia/models/vfm.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4e21970ff8f36c26b035f55c4a694f080df975 --- /dev/null +++ b/theia/models/vfm.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from typing import Any, Optional + +import torch +import torch.nn as nn +from theia.foundation_models import get_clip_model, get_deit_model, get_dinov2_model, get_sam_model, get_vit_model +from transformers import AutoImageProcessor, AutoModel + +from theia.models.utils import handle_feature_output + + +class VFMEncoder(nn.Module): + """Wrapper class of an individual VFM Encoder for feature extraction. + + Attrs: + model_name (str): name of the model. + feature_reduce_method (str): how to select the output feature token and shape. + processor (AutoProcessor): input pre-processor. + """ + + def __init__(self, model_name: str, feature_reduce_method: Optional[str] = None, **kwargs: Any): + """Instanciate a (off-the-shelf) VFM encoder. + + Args: + model_name (str): name of the encoder + feature_reduce_method (Optional[str]): how to select the output feature token and shape. Defaults to None. 
+ **kwargs (Any): anything not needed got pass-through + """ + super().__init__() + self.model_name = model_name + if "google/vit" in model_name: + model, processor = get_vit_model(model_name, device="cpu") + elif "facebook/dino" in model_name: + model, processor = get_dinov2_model(model_name, device="cpu") + elif "facebook/sam" in model_name: + model, processor = get_sam_model(model_name, device="cpu") + elif "openai/clip" in model_name: + model, processor = get_clip_model(model_name, device="cpu") + elif "facebook/deit" in model_name: + model, processor = get_deit_model(model_name, device="cpu") + elif "nvidia" in model_name: + model = AutoModel.from_pretrained(model_name, trust_remote_code=True) + processor = AutoImageProcessor.from_pretrained(model_name) + elif "mvp" in model_name: + import mvp + + model_name_mvp = model_name.replace("mvp-", "") + model = mvp.load(model_name_mvp) + processor = None + elif "vip" in model_name: + from vip import load_vip + + model = load_vip() + processor = None + elif "r3m" in model_name: + from r3m import load_r3m + + model_name_r3m = model_name.replace("r3m-", "") + model = load_r3m(model_name_r3m) + processor = None + else: + raise NotImplementedError(f"{model_name} is not supported in theia.models.vfm.VFM") + + self.model = model + self.processor = processor + self.feature_reduce_method = feature_reduce_method + if "image_size" in kwargs: + self.image_size = kwargs["image_size"] + if "final_spatial" in kwargs: + self.final_spatial = kwargs["final_spatial"] + + def get_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """Return the feature from the VFM. + + Args: + x (torch.Tensor): input image. + kwargs: any arguments pass-through (mainly for processor currently). + For example, `do_rescale`, `do_resize`, `interpolate_pos_encoding` + to control image preprocessing pipeline. + + Returns: + torch.Tensor: feature. + """ + if ( + "google/vit" in self.model_name + or "facebook/dinov2" in self.model_name + or "facebook/deit" in self.model_name + ): + inputs = self.processor(x, return_tensors="pt", **kwargs).to(self.model.device) + feature = self.model(**inputs).last_hidden_state + elif "openai/clip" in self.model_name: + inputs = self.processor(images=x, return_tensors="pt", **kwargs).to(self.model.device) + feature = self.model(**inputs).last_hidden_state + elif "facebook/sam" in self.model_name: + inputs = self.processor(x, return_tensors="pt", **kwargs).to(self.model.device) + feature = self.model(**inputs).image_embeddings + elif "nvidia" in self.model_name: + inputs = ( + self.processor(images=x, return_tensors="pt", **kwargs) + .pixel_values.to(torch.bfloat16) + .to(self.model.device) + ) + summary, feature = self.model(inputs) + if self.feature_reduce_method == "cls_identity": + feature = summary.to(torch.float32) + else: + feature = feature.to(torch.float32) + elif "mvp" in self.model_name: + feature = self.model(x) + elif "vip" in self.model_name: + feature = self.model(x) + elif "r3m" in self.model_name: + feature = self.model(x) + return feature + + def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """Forward method, including getting the feature and handle the output token / shape. + + Args: + x (torch.Tensor): input image. + + Returns: + torch.Tensor: output feature with token or shape handled. 
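A short sketch of extracting teacher features with the wrapper above, assuming the "facebook/dinov2-large" checkpoint is downloadable; with feature_reduce_method=None, handle_feature_output (theia/models/utils.py in this diff) strips the CLS token and keeps the spatial tokens:

import torch
from theia.models.vfm import VFMEncoder

encoder = VFMEncoder("facebook/dinov2-large", feature_reduce_method=None)
images = torch.zeros((2, 224, 224, 3), dtype=torch.uint8)   # raw [B, H, W, C] pixels
with torch.inference_mode():
    feature = encoder(images)                               # [B, H*W, C] spatial tokens
# feature_reduce_method="mean_pooling", "max_pooling", or "cls" would yield [B, C] instead.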
+ """ + feature = self.get_feature(x, **kwargs) # [B, 1+H*W, C] + return handle_feature_output(feature, self.feature_reduce_method) + + def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """Alias of forward() to accommandate some downstream usage. + + Args: + x (torch.Tensor): input image. + + Returns: + torch.Tensor: output feature with token or shape handled. + """ + return self.forward(x, **kwargs) + + +class ConcatVFMEncoder(nn.Module): + """Wrapper class that combines features from multiple VFM Encoders. The combination is channel-wise concatenation. + + Attrs: + model_names (list[str]): names of the models. + feature_reduce_method (Optional[str]): how to select the output feature token and shape. + model (nn.ModuleDict): a dict to hold different VFM encoders. + """ + + def __init__(self, model_names: list[str], feature_reduce_method: Optional[str] = None, **kwargs: Any): + """Instanciate a (off-the-shelf) VFM encoder. + + Args: + model_names (list[str]): name of the encoder + feature_reduce_method (str, optional): how to select the output feature token and shape. Defaults to None. + **kwargs (Any): anything not needed got pass-through + """ + super().__init__() + self.model_names = model_names + self.model = {} + for model_name in model_names: + model = VFMEncoder(model_name, feature_reduce_method=feature_reduce_method, **kwargs) + self.model[model_name] = model + + self.model = nn.ModuleDict(self.model) + self.feature_reduce_method = feature_reduce_method + + def get_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """Get different features from VFMs. + + Args: + x (torch.Tensor): input image. + + Returns: + torch.Tensor: features concatenated at channel dimension. + """ + features = [] + for model_name in self.model_names: + features.append(self.model[model_name](x, **kwargs)) + features = torch.cat(features, dim=-1) + return features + + def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """Forward method, including getting the feature and handle the output token / shape. + + Args: + x (torch.Tensor): input image. + + Returns: + torch.Tensor: output feature with token or shape handled. + """ + feature = self.get_feature(x, **kwargs) # [B, 1+H*W, C] + return handle_feature_output(feature, self.feature_reduce_method) + + def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """Alias of forward() to accommandate some downstream usage. + + Args: + x (torch.Tensor): input image. + + Returns: + torch.Tensor: output feature with token or shape handled. + """ + return self.forward(x, **kwargs) diff --git a/theia/optimizers/__init__.py b/theia/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9 --- /dev/null +++ b/theia/optimizers/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. diff --git a/theia/optimizers/utils.py b/theia/optimizers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4db1331a8fb940f18af7ebe48b49ba5285eec149 --- /dev/null +++ b/theia/optimizers/utils.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from typing import Any, Iterable + +import torch.nn as nn + + +def param_groups_weight_decay( + model: nn.Module, weight_decay: float = 1e-5, no_weight_decay_parameters: Iterable[str] = () +) -> list[dict[str, Any]]: + """Group parameters into sets with decay applied and no decay. 
+ + Args: + model (nn.Module): the model. + weight_decay (float): weight decay. Defaults to 1e-5. + no_weight_decay_parameters (Iterable[str]): parameters added to no weight decay + in addition to defaults. Defaults to (). + + Returns: + list[dict[str, Any]]: parameter groups with different weight decay values. + Follow the format required by torch.optim.Optimizer. + """ + no_weight_decay_parameters = set(no_weight_decay_parameters) + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_parameters: + no_decay.append(param) + else: + decay.append(param) + + return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}] + + +def param_groups_lr_weight_decay( + model: nn.Module, + backbone_lr: float = 1e-3, + translator_lr: float = 1e-3, + weight_decay: float = 1e-5, + no_weight_decay_parameters: Iterable[str] = (), +) -> list[dict[str, Any]]: + """Group parameters into set with decay applied and no decay. + + Args: + model (nn.Module): the model. + weight_decay (float): weight decay. Defaults to 1e-5. + no_weight_decay_parameters (Iterable[str]): parameters added to no weight decay + in addition to defaults. Defaults to (). + + Returns: + list[dict[str, Any]]: parameter groups with different weight decay values. + Follow the format required by torch.optim.Optimizer. + """ + no_weight_decay_parameters = set(no_weight_decay_parameters) + decay_backbone = [] + no_decay_backbone = [] + decay_translator = [] + no_decay_translator = [] + + for name, param in model.module.backbone.named_parameters(): + if not param.requires_grad: + continue + + if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_parameters: + no_decay_backbone.append(param) + else: + decay_backbone.append(param) + + for name, param in model.module.translator.named_parameters(): + if not param.requires_grad: + continue + + if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_parameters: + no_decay_translator.append(param) + else: + decay_translator.append(param) + + return [ + {"params": no_decay_backbone, "weight_decay": 0.0, "lr": backbone_lr}, + {"params": decay_backbone, "weight_decay": weight_decay, "lr": backbone_lr}, + {"params": no_decay_translator, "weight_decay": 0.0, "lr": translator_lr}, + {"params": decay_translator, "weight_decay": weight_decay, "lr": translator_lr}, + ] diff --git a/theia/preprocessing/feature_extraction_core/__init__.py b/theia/preprocessing/feature_extraction_core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1f2b64556713b62ff6daceee10c1f2a43b8290 --- /dev/null +++ b/theia/preprocessing/feature_extraction_core/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from .models import get_feature_outputs, get_model, get_models +from .webdataset_utils import check_existing_shard, decode_image_npy_only, read_shard diff --git a/theia/preprocessing/feature_extraction_core/models.py b/theia/preprocessing/feature_extraction_core/models.py new file mode 100644 index 0000000000000000000000000000000000000000..63d1b0d351bc01200a000c0d331f2c738b1a51f4 --- /dev/null +++ b/theia/preprocessing/feature_extraction_core/models.py @@ -0,0 +1,97 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. 
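A minimal sketch of feeding the parameter-group helpers above into an optimizer; the model here is a placeholder, and param_groups_lr_weight_decay additionally assumes a DDP-style wrapper exposing model.module.backbone and model.module.translator:

import torch
import torch.nn as nn
from theia.optimizers.utils import param_groups_weight_decay

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 8))  # placeholder model
groups = param_groups_weight_decay(model, weight_decay=1e-5)
# groups[0]: biases and 1-D params (e.g. LayerNorm) with weight_decay=0.0
# groups[1]: weight matrices with weight_decay=1e-5
optimizer = torch.optim.AdamW(groups, lr=1e-3)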
+ +from typing import Any + +import torch +import torch.nn as nn +from numpy.typing import NDArray +from torch.nn.functional import interpolate +from theia.foundation_models import ( + get_clip_feature, + get_clip_model, + get_depth_anything_feature, + get_depth_anything_model, + get_dinov2_feature, + get_dinov2_model, + get_llava_vision_model, + get_llava_visual_feature, + get_sam_feature, + get_sam_model, + get_vit_feature, + get_vit_model, +) + + +def get_model(model_name: str, device: int | str | torch.device = "cpu") -> tuple[nn.Module, Any]: + if "google/vit" in model_name: + model, processor = get_vit_model(model_name, device=device) + elif "facebook/sam" in model_name: + model, processor = get_sam_model(model_name, device=device, with_upscaled=False) + elif "openai/clip" in model_name: + model, processor = get_clip_model(model_name, device=device) + elif "facebook/dinov2" in model_name: + model, processor = get_dinov2_model(model_name, device=device) + elif "llava" in model_name: + model, processor = get_llava_vision_model(model_name, device=device) + elif "depth-anything" in model_name: + model, processor = get_depth_anything_model(model_name, device=device, selected_feature="head") + else: + raise NotImplementedError(f"{model_name} is not implemented") + return model, processor + + +def get_models( + model_names: list[str], device: int | str | torch.device = "cpu" +) -> tuple[dict[str, nn.Module], dict[str, Any]]: + models: dict[str, nn.Module] = {} + processors: dict[str, Any] = {} + for model_name in model_names: + model, processor = get_model(model_name, device) + models[model_name.replace("/", "_")] = model + processors[model_name.replace("/", "_")] = processor + return models, processors + + +def get_feature_outputs( + model_name: str, model: nn.Module, processor: Any, batch_images: list[NDArray], dtype: torch.dtype = torch.bfloat16 +) -> dict[str, dict[str, torch.Tensor]]: + features: dict[str, dict[str, torch.Tensor]] = {model_name: {}} + if "google_vit" in model_name: + cls_token, feature = get_vit_feature(model, processor, batch_images) + features[model_name] = { + "cls_token": cls_token.detach().cpu().to(dtype).contiguous(), + "embedding": feature.detach().cpu().to(dtype).contiguous() + } + elif "facebook_sam" in model_name: + feature, upscaled_feature = get_sam_feature(model, processor, batch_images) + features[model_name] = {"embedding": feature.detach().cpu().to(dtype).contiguous()} + features[model_name + "_32"] = { + "embedding": interpolate(feature, (32, 32)).detach().cpu().to(dtype).contiguous() + } + + if upscaled_feature: + features[model_name]["upscaled_embedding"] = upscaled_feature.detach().cpu().to(dtype).contiguous() + elif "openai_clip" in model_name: + cls_token, visual_tokens, pooled_cls_token = get_clip_feature(model, processor, batch_images) + features[model_name] = { + "embedding": visual_tokens.detach().cpu().to(dtype).contiguous(), + "cls_token": cls_token.detach().cpu().to(dtype).contiguous(), + "pooled_cls_token": pooled_cls_token.detach().cpu().to(dtype).contiguous(), + } + elif "facebook_dinov2" in model_name: + cls_token, visual_tokens, pooled_cls_token = get_dinov2_feature(model, processor, batch_images) + features[model_name] = { + "embedding": visual_tokens.detach().cpu().to(dtype).contiguous(), + "cls_token": cls_token.detach().cpu().to(dtype).contiguous(), + "pooled_cls_token": pooled_cls_token.detach().cpu().to(dtype).contiguous(), + } + elif "llava" in model_name: + feature = get_llava_visual_feature(model, processor, batch_images) + 
features[model_name] = {"embedding": feature.detach().cpu().to(dtype).contiguous()} + elif "depth-anything" in model_name: + feature = get_depth_anything_feature(model, processor, batch_images) + features[model_name] = {"embedding": interpolate(feature, (64, 64)).detach().cpu().to(dtype).contiguous()} + else: + raise NotImplementedError(f"model {model_name} is not supported") + + return features diff --git a/theia/preprocessing/feature_extraction_core/webdataset_utils.py b/theia/preprocessing/feature_extraction_core/webdataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9400eb973c329256c37306cbfb0b811f4fe1d567 --- /dev/null +++ b/theia/preprocessing/feature_extraction_core/webdataset_utils.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import os +import tarfile +from io import BytesIO + +import cv2 +import numpy as np +from numpy.typing import NDArray + + +def check_existing_shard(path: str, keys: list[str]) -> tuple[int, dict]: + """ + Check the integrity of a shard given path. + + Returns: + tuple[int, dict]: + code (int): 1 the file is ok, 0 not + count_per_key (dict): if the file is ok, how many samples are generated per key + """ + count_per_key = {k: 0 for k in keys} + if os.path.exists(path): + try: + with tarfile.open(path, "r") as tarf: + tar_members = tarf.getmembers() + tar_members = sorted(tar_members, key=lambda x: x.name) + for tar_mem in tar_members: + for k in keys: + if k in tar_mem.name: + count_per_key[k] += 1 + return 1, count_per_key + except tarfile.TarError: + return 0, count_per_key + else: + return 0, count_per_key + + +def read_shard(path: str) -> dict[str, bytes]: + """Read a (half) processed tar shard and store file contents in bytes. + + The tar should be complete to read. + + Args: + path (str): path to the tar file. + + Returns: + dict[str, bytes]: tarfile content in a dictionary where key is the tarinfo.name of each member + """ + samples = {} + with tarfile.open(path, "r") as tarf: + tar_members = tarf.getmembers() + tar_members = sorted(tar_members, key=lambda x: x.name) + for tar_mem in tar_members: + f = tarf.extractfile(tar_mem.name) + if f: + samples[tar_mem.name] = f.read() + return samples + + +def decode_image_npy_only(key: str, data: bytes) -> NDArray | bytes: + if "image" in key: + image = np.load(BytesIO(data)) + if len(image.shape) == 2: + return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + elif len(image.shape) == 3 and image.shape[-1] == 4: + return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) + else: + return image + else: + return data diff --git a/theia/scripts/decoding/decoding_example.py b/theia/scripts/decoding/decoding_example.py new file mode 100644 index 0000000000000000000000000000000000000000..da3a5313782472b38dec17f64f035772863232cc --- /dev/null +++ b/theia/scripts/decoding/decoding_example.py @@ -0,0 +1,107 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +""" +Example script to decode features from theia model to corresponding visual task output, + including DINOv2 visualization, SAM segmentation masks, and Depth Anything predicted depths. 
+""" + +import argparse +import os + +import cv2 +import numpy as np +import torch +import transformers + +from PIL import Image +from theia.foundation_models.common import get_model_feature_size +from theia.decoding import decode_everything, load_feature_stats, prepare_depth_decoder, prepare_mask_generator +from theia.models.rvfm import RobotVisionFM +from theia.utils.seed import seed_everything +from torchvision.io import read_video, write_video + +transformers.logging.set_verbosity_error() + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--backbone", type=str, default="facebook/deit-tiny-patch16-224", help="name of the backbone") + parser.add_argument("--checkpoint-path", type=str, help="path to the model weights") + parser.add_argument("--feature-stat-dir", type=str, help="the directory to find feature stats") + parser.add_argument("--media-to-vis-path", type=str, help="the location of source video / image for visualization") + parser.add_argument( + "--vis-output-dir", type=str, default="./vis_output/", help="output dir to save visualization result" + ) + args = parser.parse_args() + seed_everything(0) + device = 0 + + target_model_names = [ + "google/vit-huge-patch14-224-in21k", + "facebook/dinov2-large", + "openai/clip-vit-large-patch14", + "facebook/sam-vit-huge", + "LiheYoung/depth-anything-large-hf", + ] + target_feature_sizes = {t: get_model_feature_size(t, keep_spatial=True) for t in target_model_names} + theia_model = RobotVisionFM( + translator="lconv", target_feature_sizes=target_feature_sizes, backbone=args.backbone, pretrained=False + ) + + theia_model.load_pretrained_weights(args.checkpoint_path) + theia_model = theia_model.to(device) + feature_means, feature_vars = load_feature_stats(target_model_names, stat_file_root=args.feature_stat_dir) + + mask_generator, sam_model = prepare_mask_generator(device) + depth_anything_model_name = "LiheYoung/depth-anything-large-hf" + depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, device) + + if args.media_to_vis_path.lower().endswith((".mp4")): + video, _, _ = read_video(args.media_to_vis_path, pts_unit="sec", output_format="THWC") + video = video.numpy() + images = [Image.fromarray(cv2.resize(im, (224, 224))) for im in video] + elif args.media_to_vis_path.lower().endswith((".jpg", ".png", ".jpeg", ".bmp")): + images = [Image.open(args.media_to_vis_path).resize((224, 224))] + + theia_decode_results, gt_decode_results = decode_everything( + theia_model=theia_model, + feature_means=feature_means, + feature_vars=feature_vars, + images=images, + mask_generator=mask_generator, + sam_model=sam_model, + depth_anything_decoder=depth_anything_decoder, + pred_iou_thresh=0.5, + stability_score_thresh=0.7, + gt=True, + device=device, + ) + + + if not os.path.exists(args.vis_output_dir): + os.makedirs(args.vis_output_dir) + if len(images) > 1: + vis_output_save_fn = ( + f"{args.media_to_vis_path.split('/')[-1].split('.')[0]}_{args.checkpoint_path.split('/')[-1].replace('.pth', '')}.mp4" + ) + vis_video = np.stack( + [np.vstack([tr, gtr]) for tr, gtr in zip(theia_decode_results, gt_decode_results, strict=False)] + ) + vis_video = torch.from_numpy(vis_video * 255.0).to(torch.uint8) + + vis_save_path = os.path.join(args.vis_output_dir, vis_output_save_fn) + write_video(vis_save_path, vis_video, fps=10) + else: + vis_output_save_fn = ( + f"{args.media_to_vis_path.split('/')[-1].split('.')[0]}_{args.checkpoint_path.split('/')[-1].replace('.pth', '')}.png" + ) + vis_image = np.stack( + 
[np.vstack([tr, gtr]) for tr, gtr in zip(theia_decode_results, gt_decode_results, strict=False)] + ) + vis_image = Image.fromarray((vis_image * 255.0).astype(np.uint8)[0]) + vis_save_path = os.path.join(args.vis_output_dir, vis_output_save_fn) + vis_image.save(vis_save_path) + + +if __name__ == "__main__": + main() diff --git a/theia/scripts/preprocessing/calc_feature_mean.py b/theia/scripts/preprocessing/calc_feature_mean.py new file mode 100644 index 0000000000000000000000000000000000000000..f27bb8a4c495138b74df1f8641c89f172643ab15 --- /dev/null +++ b/theia/scripts/preprocessing/calc_feature_mean.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +""" +Calculate the channel-wise mean and var of extracted features on ImageNet dataset. +The resulting mean and var will be used in distillation process. +""" + +import argparse +import glob +import os +from io import BytesIO + +import numpy as np +import torch +import webdataset as wds +from einops import rearrange +from safetensors.torch import load as sft_load +from torch.utils.data import default_collate + + +def decode_dataset_sample(key: str, data: bytes) -> bytes | torch.Tensor: + """ + Decode a feature / column in webdataset sample in bytes to its original format. + + Args: + key (str): name of the feature / column. + data (bytes): data in bytes. + + Returns: + bytes | torch.Tensor: decoded feature. + """ + if ".safetensors" in key: + sft = sft_load(data) + return rearrange(sft["embedding"], "c h w -> (h w) c") + elif key == ".image": + return torch.from_numpy(np.load(BytesIO(data))) + else: + return data + + +def main() -> None: + """Entry point of this script for calculating mean and var.""" + parser = argparse.ArgumentParser() + parser.add_argument("--dataset-path", type=str) + parser.add_argument("--output-path", type=str) + args = parser.parse_args() + + all_datasets = {} + all_datasets.update({"imagenet": {"steps": 1_281_167}}) + ds_dir = args.dataset_path + models = [m for m in os.listdir(ds_dir) if os.path.isdir(os.path.join(ds_dir, m))] + for model in models: + print(model) + if model == "images" or model == "image" or model == "images_val": + continue + if os.path.exists(f"{args.output_path}/imagenet_mean_{model}.npy"): + continue + model_mean: torch.Tensor = None + model_var_sum: torch.Tensor = None + n = 0 + ds = ( + wds.WebDataset( + sorted(glob.glob(f"{ds_dir}/{model}/*.tar")), + shardshuffle=False, + ) + .decode(decode_dataset_sample) + .batched(256, collation_fn=default_collate) + ) + + key = f"{model}.safetensors".lower() + for batch_idx, batch in enumerate(ds): + if model_mean is None: + model_mean = torch.zeros((batch[key].size(-1))) + new_n = np.prod(batch[key].size()[:2]) + batch_mean = batch[key].float().mean((0, 1)) + model_mean = (model_mean * n + batch_mean * new_n) / (n + new_n) + n += new_n + print(f"calc {model} mean {batch_idx*256:07d}\r", end="") + + model_mean_npy = model_mean.numpy() + np.save(f"{args.output_path}/imagenet_mean_{model}.npy", model_mean_npy) + + # var + for i, b in enumerate(ds): + if model_var_sum is None: + model_var_sum = torch.zeros((b[key].size(-1))) + model_var_sum += ((b[key].float() - model_mean) ** 2).sum((0, 1)) + print(f"calc {model} var {i*256:07d}\r", end="") + + model_var = torch.sqrt(model_var_sum / (n - 1)) + np.save(f"{args.output_path}/imagenet_var_{model}.npy", model_var.numpy()) + + +if __name__ == "__main__": + main() diff --git a/theia/scripts/preprocessing/check_feature.py b/theia/scripts/preprocessing/check_feature.py 
new file mode 100644 index 0000000000000000000000000000000000000000..f19403eee16668e2dca8cee8c1c5ba46c9bd36ca --- /dev/null +++ b/theia/scripts/preprocessing/check_feature.py @@ -0,0 +1,205 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import argparse +import json +import os +import tarfile +from io import BytesIO +from typing import Any + +import cv2 +import numpy as np +import torch +from numpy.typing import NDArray +from PIL import Image +from safetensors.torch import load as sft_load + +from theia.dataset import ALL_IMAGE_DATASETS, ALL_VIDEO_DATASETS +from theia.foundation_models.common import MODELS +from theia.preprocessing.feature_extraction_core import ( + get_feature_outputs, + get_model, +) +from theia.utils.seed import seed_everything + + +def decode_oxe_sample(data: bytes, data_type: str) -> Any: + """Decode the sample from bytes. + + Args: + data (bytes): data to be decoded. + data_type (str): the type of the data. + Usually is part of the key (filename of the sample) in the webdataset. + + Returns: + Any: decoded data or pass-through bytes without touch. + """ + if ".safetensors" in data_type: + sftensor = sft_load(data) + return sftensor["embedding"] + elif data_type == ".image": + image = np.load(BytesIO(data)) + if len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + elif len(image.shape) == 3 and image.shape[-1] == 4: + image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) + # return torch.from_numpy(image) + return image + else: + return data + + +def get_tar_sample(tarf: tarfile.TarFile, sample_index: int) -> bytes: + """Get bytes of a sample with index `sample_index` in tarfile `tarf`. + + Args: + tarf (tarfile.TarFile): tar file. + sample_index (int): index of the sample + + Returns: + bytes: content of the sample in bytes + """ + tar_members = tarf.getmembers() + tar_members = sorted(tar_members, key=lambda x: x.name) + tar_mem = tar_members[sample_index] + f = tarf.extractfile(tar_mem.name) + if f: + return f.read() + else: + raise IOError(f"failed to read tarfile {tarf}.") + + +def get_tar_sample_name(tarf: tarfile.TarFile, sample_index: int) -> str: + """Get the name of the sample with index `sample_index` in the tarfile `tarf`. + + Args: + tarf (tarfile.TarFile): tar file. + sample_index (int): index of the sample + + Returns: + str: name of the file + """ + tar_members = tarf.getmembers() + tar_members = sorted(tar_members, key=lambda x: x.name) + tar_mem = tar_members[sample_index] + return tar_mem.name + + +def check_feature( + args: argparse.Namespace, + dataset: str, + modelnames_to_check: list[str], + models: dict[str, Any], + processors: dict[str, Any], + shard_idx: int, + sample_indices: list[int] | NDArray, + split: str = "train", + dtype: torch.dtype = torch.bfloat16, +) -> dict[str, bool]: + """Check feature consistency given a dataset, names of models to check, + shard index and sample indices within that shard. + + Args: + args (argparse.Namespace): arguments. + dataset (str): name of the dataset + modelnames_to_check (list[str]): names of the features (models) to check. + models (dict[str, Any]): original models to produce features on the fly. + processors (dict[str, Any]): original processor of the models. + shard_idx (int): index of the shard. + sample_indices (list[int] | NDArray): indices of samples to be checked. + split (str, optional): name of the split of the dataset. Defaults to "train". + dtype (torch.dtype, optional): dtype of the generated feature. Defaults to torch.bfloat16. 
+ + Returns: + dict[str, bool]: check result. The keys are model names. True means passing the check. + """ + data_dir = os.path.join(args.dataset_root, dataset, "images") + shard_filenames = sorted([filename for filename in os.listdir(data_dir) if f"{split}.tar" in filename]) + image_tar = tarfile.open(os.path.join(data_dir, shard_filenames[shard_idx]), "r") + images = [ + decode_oxe_sample(get_tar_sample(image_tar, sample_index), data_type=".image") + for sample_index in sample_indices + ] + for image, sample_index in zip(images, sample_indices, strict=False): + if args.save_image: + if not os.path.exists(args.image_save_dir): + os.makedirs(args.image_save_dir) + image = Image.fromarray(image) + image.save(os.path.join(args.image_save_dir, f"image_{shard_idx}_{sample_index}.jpg")) + image_names = [get_tar_sample_name(image_tar, sample_index).split(".")[0] for sample_index in sample_indices] + + model_check_pass = {m: False for m in modelnames_to_check} + for model_name in modelnames_to_check: + legit_model = model_name.replace("/", "_") + data_dir = os.path.join(args.dataset_root, dataset, legit_model) + shard_filenames = sorted([filename for filename in os.listdir(data_dir) if f"{split}.tar" in filename]) + feature_tar = tarfile.open(os.path.join(data_dir, shard_filenames[shard_idx]), "r") + features = torch.stack( + [ + decode_oxe_sample(get_tar_sample(feature_tar, sample_index), data_type=".safetensors") + for sample_index in sample_indices + ] + ) + gt_features = get_feature_outputs( + legit_model, models[legit_model], processors[legit_model], images, dtype=dtype + )[legit_model]["embedding"] + print(torch.sum(torch.abs(features - gt_features)), torch.max(torch.abs(features - gt_features))) + model_check_pass[model_name] = torch.all((features - gt_features) == 0) + if args.check_feature_name: + names = [get_tar_sample_name(feature_tar, sample_index).split(".")[0] for sample_index in sample_indices] + model_check_pass[model_name] = ( + all([imname == filename for imname, filename in zip(image_names, names, strict=False)]) + and model_check_pass[model_name] + ) + return model_check_pass + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--dataset-root", type=str) + parser.add_argument("--dataset", type=str) + parser.add_argument("--split", type=str, default="val") + parser.add_argument("--samples-per-shard", type=int, default=1000, help="number of samples per webdataset shard.") + parser.add_argument("--check-feature-name", action="store_true") + parser.add_argument("--save-image", action="store_true") + parser.add_argument("--image-save-dir", type=str, default="./tmp") + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + + seed_everything(0) + + all_datasets = {} + all_datasets.update(ALL_IMAGE_DATASETS) + all_datasets.update(ALL_VIDEO_DATASETS) + + with open(os.path.join(args.dataset_root, args.dataset, "splits.json"), "r") as f: + dataset_len = json.load(f)[args.split] + + n_shards = dataset_len // args.samples_per_shard + + model_names = [model_name for model_name in MODELS if "llava" not in model_name] + models, processors = {}, {} + for model_name in model_names: + legit_model_name = model_name.replace("/", "_") + model, processor = get_model(model_name, device=0) + models[legit_model_name] = model + processors[legit_model_name] = processor + + shard_indices = np.random.permutation(n_shards)[:5] + print(f"randomly check {args.dataset} shards {shard_indices}") + model_check_pass: dict[str, list[bool]] = {model_name: 
[] for model_name in model_names} + for shard_idx in shard_indices: + sample_indices = np.random.permutation(1000)[:8] + print(f"randomly check {args.dataset} shard {shard_idx} sample_indices {sample_indices}") + check_result = check_feature( + args, args.dataset, model_names, models, processors, shard_idx, sample_indices, split=args.split + ) + for model_name in model_check_pass: + model_check_pass[model_name].append(check_result[model_name]) + for model_name in model_check_pass: + if not all(model_check_pass[model_name]): + print(f"{args.dataset} {args.split} {model_name} check failed!!!") + + +if __name__ == "__main__": + main() diff --git a/theia/scripts/preprocessing/feature_extraction.py b/theia/scripts/preprocessing/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..1c671009d78107d70c79eff46ad8cb314671f722 --- /dev/null +++ b/theia/scripts/preprocessing/feature_extraction.py @@ -0,0 +1,401 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import argparse +import gc +import glob +import json +import math +import multiprocessing +import os +from io import BytesIO +from os.path import join +from typing import Any, Generator, Iterable, Optional + +import cv2 +import numpy as np +import torch +import webdataset as wds +from numpy.typing import NDArray +from safetensors.torch import save as safe_torch_save + +try: + import tensorflow_datasets as tfds + from tensorflow.python.ops.numpy_ops import np_config +except ImportError as e: + print (e) + print ("No TF usable. It's ok if you are not processing OXE dataset.") + +from theia.dataset import ALL_IMAGE_DATASETS, ALL_OXE_DATASETS, ALL_VIDEO_DATASETS +from theia.dataset.oxe.oxe_common import oxe_dsname2path +from theia.preprocessing.feature_extraction_core import ( + check_existing_shard, + decode_image_npy_only, + get_feature_outputs, + get_model, +) +from torch.utils.data import IterableDataset + + +def get_dataset(dataset_name: str, split: str, dataset_root: Optional[str] = None) -> tuple[Iterable, list[str]]: + """Get the dataset and its subset keys (if has) given a dataset name. + + Args: + dataset_name (str): name of the dataset. + split (str): split of the dataset. + dataset_root (Optional[str]): root dir of the dataset, if the dataset is stored locally. + Defaults to None (remote dataset). 
+ + Returns: + tuple[Iterable, list[str]]: dataset and its subset keys + """ + if dataset_name in ALL_OXE_DATASETS: + builder = tfds.builder_from_directory(builder_dir=oxe_dsname2path(dataset_name)) + split = f"{split}[0:]" # don't change this to skip samples + dataset = builder.as_dataset(split=split) + visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"] + return dataset, visual_observation_keys + elif dataset_name in ALL_VIDEO_DATASETS or dataset_name in ALL_IMAGE_DATASETS: + if dataset_root is None: + raise ValueError("`dataset_root` is not given.") + dataset_dir = os.path.join(dataset_root, dataset_name, "images") + if not os.path.exists(dataset_dir) or not os.path.isdir(dataset_dir): + raise ValueError(f"{dataset_dir} is not found or is not a directory.") + print("dataset shards", sorted(glob.glob(f"{dataset_dir}/*-{split}.tar"))) + dataset = wds.WebDataset( + sorted(glob.glob(f"{dataset_dir}/*-{split}.tar")), + shardshuffle=False, + ).decode(decode_image_npy_only) + return dataset, ["__self__"] + else: + raise NotImplementedError(f"{dataset_name} is not available") + + +def get_episode(ds: Any) -> Generator[tuple[Any, int], Any, Any]: + """Get an episode / a trajectory / a segment form the dataset + + Args: + ds (Any): oxe dataset in tfds format or image/video dataset in webdataset format. + + Yields: + Generator[tuple[Any, int], Any, Any]: a trajectory with its length. + """ + if isinstance(ds, IterableDataset): + it = iter(ds) + while True: + sample_buff = [] + try: + for _ in range(1000): + sample = next(it) + sample_buff.append(sample) + yield sample_buff, len(sample_buff) + except StopIteration: + yield sample_buff, len(sample_buff) + break + else: + for ep in ds: + yield ep, len(ep["steps"]) + + +def get_images(ep: Any, subset: str) -> tuple[list[NDArray], Optional[list[str]]]: + """Get images from an episode / a trajectory. + + Args: + ep (Any): an episode / a trajectory. + subset (str): subset name. + + Returns: + tuple[list[NDArray], Optional[list[str]]]: extracted images with optional info. + """ + if isinstance(ep, list): # for image / video dataset, no subsets + return [step["image"] for step in ep], [step["__key__"] for step in ep] + else: # for oxe dataset, subset means multiple camera views + images: list[NDArray] = [] + for step in ep["steps"]: + image = cv2.resize(step["observation"][subset].numpy(), (224, 224)) + images.append(image) + return images, None + + +def get_shard_dir(root: str, subset: str, key: str) -> str: + """Get the directory to hold shards. + + Args: + root (str): root directory. + subset (str): subset name. + key (str): key (column) name of the processed dataset. Usually it is the name of the feature / input. + + Returns: + str: directory to hold the shards. + """ + if subset == "__self__": + return os.path.join(root, key) + else: + return os.path.join(root, subset, key) + + +def get_shard_filename(dataset_name: str, subset: str, split: str, shard_idx: int) -> str: + """Get file name of the shard. + + Args: + dataset_name (str): name of the dataset. + subset (str): name of the subset. + split (str): name of the split. + shard_idx (int): index of this shard. + + Returns: + str: shard file name. 
+ """ + if dataset_name in ALL_OXE_DATASETS: + if subset == "__self__": + return f"{dataset_name}_{split}-{shard_idx:06d}.tar" + else: + return f"{dataset_name}_{subset}_{split}-{shard_idx:06d}.tar" + else: + if subset == "__self__": + return f"{dataset_name}_{split}-{shard_idx:06d}-{split}.tar" + else: + return f"{dataset_name}_{subset}_{split}-{shard_idx:06d}-{split}.tar" + + +def feature_extractor( + args: argparse.Namespace, + shard_queue: multiprocessing.Queue, + worker_id: int, + dataset_len: int = 0, +) -> None: + """Feature extractor, operating on each `worker_id`. + + Args: + args (argparse.Namespace): configurations. + shard_queue (multiprocessing.Queue): queue to get shard index to work on. + worker_id (int): id of this worker. + dataset_len (int): length of the entire dataset to be processed. + """ + if args.model != "image": + model, processor = get_model(args.model, device=worker_id) + else: + model, processor = None, None + dataset, subsets = get_dataset(args.dataset, args.split, args.dataset_root) + dataset_output_root = join(args.output_path, args.dataset) + + cum_traj_len, traj_index = 0, 0 + shard_idx = shard_queue.get() + data_iter = get_episode(dataset) + episode, traj_len = next(data_iter) + remain_traj_len = traj_len + while shard_idx is not None: + print(f"{args.dataset} {args.model} shard {shard_idx:04d} worker {worker_id} " f"Subsets: {subsets}") + # navigate (stream) the dataset to the correct trajectory + while (cum_traj_len + remain_traj_len) <= shard_idx * args.samples_per_shard: + cum_traj_len += remain_traj_len + try: + episode, traj_len = next(data_iter) + remain_traj_len = traj_len + traj_index += 1 + except StopIteration: + break + + # check shard + model_names_legit = args.model.replace("/", "_") + shard_keys = [model_names_legit] + subset_check_codes = {subset: {k: 0 for k in shard_keys} for subset in subsets} + + for subset in subsets: + for k in shard_keys: + shard_dir = get_shard_dir(dataset_output_root, subset, k) + shard_filename = get_shard_filename(args.dataset, subset, args.split, shard_idx) + shard_path = os.path.join(shard_dir, shard_filename) + shard_check_code, _ = check_existing_shard(shard_path, shard_keys) + subset_check_codes[subset][k] = shard_check_code + + # generate data to the shard buffers + subset_shard_buffers: dict[str, dict[str, list[dict[str, str | bytes]]]] = { + subset: {k: [] for k in shard_keys} for subset in subsets + } + while cum_traj_len < min((shard_idx + 1) * args.samples_per_shard, dataset_len): + for subset in subsets: + images, info = None, None + + start_frame_index = traj_len - remain_traj_len + if start_frame_index >= traj_len: + raise ValueError("calculate start frame index error, needs more trajectories") + # end of the trajectory + end_frame_index = min((shard_idx + 1) * args.samples_per_shard - cum_traj_len, traj_len) + + # generate shard data per key, including images and model features + # skip any indices that are completed + for k in shard_keys: + if subset_check_codes[subset][k] == 1: + print(f"{args.dataset} {subset} {k} shard {shard_idx:04d} check pass") + continue + if k == "image": + if images is None: + # read all the images in the trejectory + images, info = get_images(episode, subset) + for frame_index in range(start_frame_index, end_frame_index): + if args.dataset in ALL_OXE_DATASETS: + basename = ( + f"{args.dataset}" + f"{'' if subset=='__self__' else '_'+subset}_seq{traj_index:06d}_{frame_index:06d}" + ) + else: + basename = info[frame_index] if info else "" + if not args.dry_run: + 
image_out = BytesIO() + np.save(image_out, images[frame_index]) + subset_shard_buffers[subset][k].append({"__key__": basename, k: image_out.getvalue()}) + else: + if images is None: + images, info = get_images(episode, subset) + processed = start_frame_index + # batch processing images + while processed < end_frame_index: + # take a batch + batch_images = images[processed : processed + args.batch_size] + if not args.dry_run: + effective_batch_size = len(batch_images) + features = get_feature_outputs(k, model, processor, batch_images) + for frame_index in range(processed, processed + effective_batch_size): + if args.dataset in ALL_OXE_DATASETS: + basename = ( + f"{args.dataset}" + f"{'' if subset=='__self__' else '_'+subset}" + f"_seq{traj_index:06d}_{frame_index:06d}" + ) + else: + basename = info[frame_index] if info else "" + tensor_sample_buffer = {} + for feature_key in features[k]: + tensor_sample_buffer[feature_key] = features[k][feature_key][ + frame_index - processed + ] + subset_shard_buffers[subset][k].append( + {"__key__": basename, f"{k}.safetensors": safe_torch_save(tensor_sample_buffer)} + ) + + # next batch + processed += args.batch_size + + cum_traj_len += ( + end_frame_index - start_frame_index + ) # only increase processed traj len by the actual number of frames processed + remain_traj_len -= end_frame_index - start_frame_index + print(f"{args.dataset} {args.model} shard {shard_idx:04d} traj {traj_index:06d} remains {remain_traj_len}") + # if the trajectory is exhausted, get the next one + if remain_traj_len == 0: + try: + episode, traj_len = next(data_iter) + remain_traj_len = traj_len + traj_index += 1 + except StopIteration: + break + + # shard_buffer generate done, write shard + if not args.dry_run: + for subset in subsets: + for k in shard_keys: + if subset_check_codes[subset][k] == 1: + continue + shard_dir = get_shard_dir(dataset_output_root, subset, k) + shard_filename = get_shard_filename(args.dataset, subset, args.split, shard_idx) + shard_path = os.path.join(shard_dir, shard_filename) + if not os.path.exists(shard_dir): + os.makedirs(shard_dir) + print(len(subset_shard_buffers[subset][k])) + with wds.TarWriter(shard_path) as tar_writer: + for sample in subset_shard_buffers[subset][k]: + tar_writer.write(sample) + + print(f"{args.dataset} {args.model} shard {shard_idx:04d} done") + del subset_shard_buffers + gc.collect() + # get a new shard to process + shard_idx = shard_queue.get() + + +def main() -> None: + """Main entry of feature extraction""" + np_config.enable_numpy_behavior() + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str) + parser.add_argument("--dataset-root", type=str) + parser.add_argument("--output-path", type=str) + parser.add_argument("--model", type=str) + parser.add_argument("--split", default="train") + parser.add_argument("--start", type=int, default=0, help="start index (form 0) of **steps** to process") + parser.add_argument( + "--num-to-process", + type=int, + default=-1, + help="number of **steps** to process based on start. -1 means all remaining from the start.", + ) + parser.add_argument("--batch-size", type=int, default=8, help="batch size for the model forward pass") + parser.add_argument("--force", action="store_true", help="force overwrite existing feature files.") + parser.add_argument("--dry-run", action="store_true", help="do not do model forward pass and write out.") + parser.add_argument( + "--samples-per-shard", type=int, default=1000, help="number of samples per webdataset shard. 
Rarely changed." + ) + parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus to parallel") + args = parser.parse_args() + + if torch.cuda.is_available(): + args.num_gpus = min(args.num_gpus, torch.cuda.device_count()) + else: + args.num_gpus = 0 + + # make directories + dataset_output_root = os.path.join(args.output_path, args.dataset) + if not os.path.exists(dataset_output_root): + os.makedirs(dataset_output_root) + + # organize the start index to start of a shard + start_fi = args.start // args.samples_per_shard * args.samples_per_shard + start_shard_idx = start_fi // args.samples_per_shard + + all_datasets = {} + all_datasets.update(ALL_OXE_DATASETS) + all_datasets.update(ALL_IMAGE_DATASETS) + all_datasets.update(ALL_VIDEO_DATASETS) + dataset_dir = os.path.join(args.dataset_root, args.dataset) + + if args.dataset in ALL_IMAGE_DATASETS or args.dataset in ALL_VIDEO_DATASETS: + with open(os.path.join(dataset_dir, "splits.json"), "r") as f: + splits = json.load(f) + dataset_len = splits[args.split] + else: + dataset_len = all_datasets[args.dataset]["steps"] + + # calculate how many shards to create + if args.num_to_process > 0: + end_sample_index = args.start + args.num_to_process + else: + end_sample_index = dataset_len + + if end_sample_index % args.samples_per_shard == 0: + end_shard_idx = end_sample_index // args.samples_per_shard + else: + end_shard_idx = math.ceil((end_sample_index) / args.samples_per_shard) + shards = list(range(start_shard_idx, end_shard_idx)) + + # create a queue to hold shards + shard_queue: multiprocessing.Queue = multiprocessing.Queue() + for shard_idx in shards: + shard_queue.put(shard_idx) + for _ in range(args.num_gpus * 2 + 1): + shard_queue.put(None) + + # create workers + workers = [ + multiprocessing.Process(target=feature_extractor, args=(args, shard_queue, worker_id, dataset_len)) + for worker_id in range(max(args.num_gpus, 1)) + ] + + for w in workers: + w.start() + for w in workers: + w.join() + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + main() diff --git a/theia/scripts/preprocessing/image_datasets/organize_imagenet_webdataset.py b/theia/scripts/preprocessing/image_datasets/organize_imagenet_webdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d3bf47191be0043aacf8811c13017f714cd5a224 --- /dev/null +++ b/theia/scripts/preprocessing/image_datasets/organize_imagenet_webdataset.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +"""Organize imagefolder-like images (ImageNet) to webdataset format.""" + +import argparse +import glob +import os +import shutil +import tarfile +from io import BytesIO + +import numpy as np +import webdataset as wds +from numpy.typing import NDArray +from PIL import Image +from torchvision.transforms.v2 import Compose, Resize + + +def check_existing_shard(path: str) -> bool: + """Check the integrity of the existing webdataset shard. + + Args: + path (str): path to the webdataset shard. + + Returns: + bool: True for complete shard. + False for non-existing or broken shard. + """ + try: + tarf = tarfile.open(path) + for _ in tarf.getmembers(): + pass + except (ValueError, tarfile.ReadError, tarfile.CompressionError) as e: + print(e) + return False + return True + + +def create_shard( + args: argparse.Namespace, + shard_idx: int, + shard_path: str | None, + remote_shard_path: str, + frames: list[tuple[NDArray, str]], +) -> None: + """Create a webdataset shard. 
+ + Args: + args (argparse.Namespace): arguments. + shard_idx (int): index of this shard. + shard_path (str): (local) path to save the shard. + remote_shard_path (str): final destination (remote) to save the shard. + frames (list[tuple[NDArray, str]]): images to save in this shard. + """ + if check_existing_shard(remote_shard_path): + print(f"creating {args.dataset} shard {shard_idx:06d} - check pass, skip\r", end="") + return + print(f"creating {args.dataset} shard {shard_idx:06d}\r", end="") + if shard_path is None: + shard_path = remote_shard_path + with wds.TarWriter(shard_path) as tar_writer: + for i, (image, basename) in enumerate(frames): + image_out = BytesIO() + np.save(image_out, image) + sample = {"__key__": basename, "image": image_out.getvalue()} + tar_writer.write(sample) + if (i + 1) % 20 == 0: + print(f"creating {args.dataset} shard {shard_idx:06d} - {(i+1) * 100 // len(frames):02d}%\r", end="") + if shard_path != remote_shard_path: + shutil.move(shard_path, remote_shard_path) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str) + parser.add_argument("--output-path", type=str) + parser.add_argument("--imagenet-raw-path", type=str) + parser.add_argument("--tmp-shard-path", type=str, default="None") + parser.add_argument("--split", type=str, default="train") + parser.add_argument("--samples-per-shard", type=int, default=1000) + args = parser.parse_args() + + match args.dataset: + case "imagenet": + IMAGE_DATASET_RAW_DIR = args.imagenet_raw_path + case _: + raise NotImplementedError(f"{args.dataset} is not supported") + + if args.tmp_shard_path == "None": + TMP_SHARD_PATH = None + else: + TMP_SHARD_PATH = os.path.join(args.tmp_shard_path, args.dataset) + if not os.path.exists(TMP_SHARD_PATH): + os.makedirs(TMP_SHARD_PATH) + + OUTPUT_SHARD_PATH = os.path.join(args.output_path, args.dataset) + if not os.path.exists(OUTPUT_SHARD_PATH): + os.makedirs(OUTPUT_SHARD_PATH, exist_ok=True) + + if args.split == "train": + image_paths = sorted(glob.glob(f"{IMAGE_DATASET_RAW_DIR}/{args.split}/*/*.JPEG")) + else: + image_paths = sorted(glob.glob(f"{IMAGE_DATASET_RAW_DIR}/{args.split}/*.JPEG")) + + transform = Compose([Resize((224, 224), antialias=True)]) + + shard_idx = 0 + shard_buffer: list[tuple[NDArray, str]] = [] + for image_path in image_paths: + basename = image_path.split("/")[-1].split(".")[0] + image = np.array(transform(Image.open(image_path))) + shard_buffer.append((image, basename)) + if len(shard_buffer) % 20 == 0: + print(f"shard {shard_idx: 04d} frames {len(shard_buffer)}\r", end="") + if len(shard_buffer) == args.samples_per_shard: + shard_fn = f"{args.dataset}_{args.split}-{shard_idx:06d}-{args.split}.tar" + local_shard_path = os.path.join(TMP_SHARD_PATH, shard_fn) if TMP_SHARD_PATH else None + remote_shard_path = os.path.join(OUTPUT_SHARD_PATH, shard_fn) + create_shard(args, shard_idx, local_shard_path, remote_shard_path, shard_buffer) + shard_buffer = [] + shard_idx += 1 + + shard_fn = f"{args.dataset}_{args.split}-{shard_idx:06d}-{args.split}.tar" + local_shard_path = os.path.join(TMP_SHARD_PATH, shard_fn) if TMP_SHARD_PATH else None + remote_shard_path = os.path.join(OUTPUT_SHARD_PATH, shard_fn) + if len(shard_buffer) > 0: + create_shard(args, shard_idx, local_shard_path, remote_shard_path, shard_buffer) + + +if __name__ == "__main__": + main() diff --git a/theia/scripts/preprocessing/iv_feature_extraction.sh b/theia/scripts/preprocessing/iv_feature_extraction.sh new file mode 100644 index 
0000000000000000000000000000000000000000..258cde177b80856c41903e30d3c2980e0d6b4cb4 --- /dev/null +++ b/theia/scripts/preprocessing/iv_feature_extraction.sh @@ -0,0 +1,16 @@ +#!/bin/bash +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +dataset=$1 +numgpus=$2 + +# modify models below +models=(facebook/dinov2-large google/vit-huge-patch14-224-in21k openai/clip-vit-large-patch14 LiheYoung/depth-anything-large-hf) # facebook/sam-vit-huge +for model in ${models[@]} +do + ( + python feature_extraction.py --dataset $dataset --output-path /storage/nfs/datasets/jshang/ --model $model --split train --num-gpus $numgpus; \ + python feature_extraction.py --dataset $dataset --output-path /storage/nfs/datasets/jshang/ --model $model --split val --num-gpus $numgpus + ) & +done +wait diff --git a/theia/scripts/preprocessing/split_dataset.py b/theia/scripts/preprocessing/split_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b0c79d429812a16902dc119e649a294002562389 --- /dev/null +++ b/theia/scripts/preprocessing/split_dataset.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import argparse +import json +import math +import os +import tarfile +from collections import OrderedDict + +from theia.dataset.oxe.oxe_common import ALL_OXE_DATASETS +from theia.dataset.video import ALL_VIDEO_DATASETS + +DATASET_RATIOS = OrderedDict({"train": 0.8, "val": 0.05, "test": 0.15}) + +all_datasets = {} +all_datasets.update(ALL_OXE_DATASETS) +all_datasets.update(ALL_VIDEO_DATASETS) +# all_datasets.update(ALL_IMAGE_DATASETS) imagenet has its own splits, can be done separately + + +def count_steps(tar_path: str) -> int: + """Count how many samples are in the shard + + Args: + tar_path (str): path to the shard + """ + with tarfile.open(tar_path) as tarf: + return len(list(set([x.name.split(".")[0] for x in tarf.getmembers()]))) + + +def do_dataset_split(args: argparse.Namespace, dataset_name: str) -> None: + """Split the dataset given a dataset name. + The dataset will be split based on shards in the lexical order of their filenames. + The first part goes to `training` set, the second part goes to `validation` set, + and the last part goes to `test` set. + + Args: + args (argparse.Namespace): arguments. + dataset_name (str): name of the dataset.
+ """ + dataset_dir = os.path.join(args.dataset_root, dataset_name) + split_json_file = os.path.join(dataset_dir, "splits.json") + + if os.path.exists(split_json_file): + return + + # only apply to images + # then feature extraction script will handle splits for features + shard_dirs = [os.path.join(dataset_dir, "images")] + for shard_dir in shard_dirs: + shard_names = sorted( + [filename for filename in os.listdir(shard_dir) if filename.endswith(".tar") and "-" in filename] + ) + n_shards = len(shard_names) + print(f"{dataset_name} total {n_shards} shards") + + cum_n_shards = 0 + split_steps_count = {} + for _, split in enumerate(DATASET_RATIOS): + ratio = DATASET_RATIOS[split] + split_n_shards = math.ceil(n_shards * ratio) + split_steps_count[split] = 0 + print(f"{dataset_name} {split} {split_n_shards} shards") + + for shard_idx in range(cum_n_shards, min(cum_n_shards + split_n_shards, n_shards)): + original_path = os.path.join(shard_dir, shard_names[shard_idx]) + if shard_idx == n_shards - 1: + split_steps_count[split] += count_steps(original_path) + else: + split_steps_count[split] += args.samples_per_shard + split_shard_filename = shard_names[shard_idx].replace(".tar", f"-{split}.tar") + split_shard_path = os.path.join(shard_dir, split_shard_filename) + + if not args.dry_run: + os.rename(original_path, split_shard_path) + cum_n_shards += split_n_shards + + with open(os.path.join(dataset_dir, "splits.json"), "w") as f: + json.dump(split_steps_count, f, indent=4) + print(split_steps_count) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--dataset-root", type=str) + parser.add_argument("--dry-run", action="store_true") + parser.add_argument( + "--samples-per-shard", + type=int, + default=1000, + help="Number of samples per webdataset shard. Rarely changed. Replace with your actual setting.", + ) + args = parser.parse_args() + for dataset in all_datasets: + if dataset in ALL_OXE_DATASETS: + if "_sim" in dataset: + continue + if "uiuc_d3field" in dataset or "cmu_playing_with_food" in dataset or "robot_vqa" in dataset: + continue + do_dataset_split(args, dataset) + + +if __name__ == "__main__": + main() diff --git a/theia/scripts/preprocessing/video_datasets/subsampling_videos.py b/theia/scripts/preprocessing/video_datasets/subsampling_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..eda7d4d250fc52970bf7c7d6c5026b60170746c1 --- /dev/null +++ b/theia/scripts/preprocessing/video_datasets/subsampling_videos.py @@ -0,0 +1,182 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. 
+ +import argparse +import os +import shutil +import tarfile +from io import BytesIO + +import numpy as np +import torch +import webdataset as wds +from numpy.typing import NDArray +from PIL import Image +from torchvision.io import VideoReader, read_video +from torchvision.transforms import Compose, Resize, ToPILImage + +# torchvision.set_video_backend("video_reader") + +parser = argparse.ArgumentParser() +parser.add_argument("--dataset", type=str) +parser.add_argument( + "--dataset-path", + type=str, + help="path to the dataset directory that directly contains videos (.mp4, .webm) or frames (.tar for epic_kitchen)", +) +parser.add_argument("--output-path", type=str, help="will create a subfolder within this output path") +parser.add_argument("--subsampling-rate", type=int, default=-1) +parser.add_argument("--samples-per-shard", type=int, default=1000) +args = parser.parse_args() + + +if args.dataset == "ego4d": + # use the provided rate if positive, otherwise fall back to the default 1/150 for ego4d + SUBSAMPLING_RATE = args.subsampling_rate if args.subsampling_rate > 0 else 150 + video_ext = ".mp4" +elif args.dataset == "ssv2": + # use the provided rate if positive, otherwise fall back to the default 1/32 for ssv2 + SUBSAMPLING_RATE = args.subsampling_rate if args.subsampling_rate > 0 else 32 + video_ext = ".webm" +elif args.dataset == "epic_kitchen": + # use the provided rate if positive, otherwise fall back to the default 1/60 for epic_kitchen + SUBSAMPLING_RATE = args.subsampling_rate if args.subsampling_rate > 0 else 60 + video_ext = ".tar" +else: + raise NotImplementedError(f"{args.dataset} is not supported.") + +print(f"subsampling {args.dataset} by 1/{SUBSAMPLING_RATE}") + +RAW_VIDEO_PATH = args.dataset_path +TMP_SAMPLED_FRAMES_PATH = f"/storage/nvme/tmp_video_subsampling/{args.dataset}_1in{SUBSAMPLING_RATE}_images" +SAMPLED_FRAMES_PATH = os.path.join(args.output_path, f"{args.dataset}_1in{SUBSAMPLING_RATE}_images") +os.makedirs(SAMPLED_FRAMES_PATH, exist_ok=True) +os.makedirs(TMP_SAMPLED_FRAMES_PATH, exist_ok=True) + +SAMPLES_PER_SHARD = args.samples_per_shard + +video_fns = sorted([fn for fn in os.listdir(RAW_VIDEO_PATH) if video_ext in fn]) + +transform = Compose([Resize((224, 224), antialias=True), ToPILImage()]) + + +def check_existing_shard(path: str) -> bool: + """ + Check the integrity of a shard given path. + + Returns: + bool: True if the shard exists and is complete. + """ + if os.path.exists(path): + try: + tarf = tarfile.open(path) + for _ in tarf.getmembers(): + pass + except tarfile.TarError: + return False + else: + return False + return True + + +def create_shard(shard_idx: int, frames: list[tuple[NDArray, str]]) -> None: + """Create a shard given index and frame list. + + Args: + shard_idx (int): index of this shard. Used to determine file paths. + frames (list[tuple[NDArray, str]]): frames to write to this shard.
+ """ + shard_fn = f"{args.dataset}_1in{SUBSAMPLING_RATE}-{shard_idx:06d}.tar" + local_shard_path = os.path.join(TMP_SAMPLED_FRAMES_PATH, shard_fn) + remote_shard_path = os.path.join(SAMPLED_FRAMES_PATH, shard_fn) + if check_existing_shard(remote_shard_path): + print(f"creating {args.dataset} shard {shard_idx:06d} - check pass, skip\r", end="") + return + print(f"creating {args.dataset} shard {shard_idx:06d}\r", end="") + with wds.TarWriter(local_shard_path) as tar_writer: + for i, (image, basename) in enumerate(frames): + image_out = BytesIO() + np.save(image_out, image) + sample = {"__key__": basename, "image": image_out.getvalue()} + tar_writer.write(sample) + if (i + 1) % 20 == 0: + print( + f"creating {args.dataset} shard {shard_idx:06d} - {int((i+1) / len(frames) * 100):02d}%\r", end="" + ) + + # move from local to remote + shutil.move(local_shard_path, remote_shard_path) + + +shard_idx = 0 +shard_buffer: list[tuple[NDArray, str]] = [] +cum_video_len = 0 +for vfn in video_fns: + if args.dataset == "ego4d": + print(vfn) + video_path = os.path.join(RAW_VIDEO_PATH, vfn) + if video_ext == ".mp4": # for ego4d + video = VideoReader(video_path, stream="video", num_threads=32) + metadata = video.get_metadata() + fps = metadata["video"]["fps"][0] + duration = metadata["video"]["duration"][0] + fi = 0 + while fi < (duration * fps): + frame = next(video.seek(fi / fps)) + basename = f"{vfn.replace(video_ext, '')}_{fi:06d}" + image = np.array(transform(frame["data"])) + # print (image.dtype, image.shape) + shard_buffer.append((image, basename)) + if len(shard_buffer) % 20 == 0: + print(f"shard {shard_idx: 04d} frames {len(shard_buffer)}\r", end="") + if len(shard_buffer) == SAMPLES_PER_SHARD: + create_shard(shard_idx, shard_buffer) + shard_buffer = [] + shard_idx += 1 + fi += SUBSAMPLING_RATE + + elif video_ext == ".webm": # for ssv2 + video, _, info = read_video(video_path, output_format="TCHW") + video_len = video.size(0) # for webm, only fps is available; 12 fps for ssv2 + for fi in range(video_len): + if (fi + cum_video_len) % SUBSAMPLING_RATE == 0: + frame = video[fi] + basename = f"{vfn.replace(video_ext, '')}_{fi:06d}" + image = np.array(transform(frame)) + shard_buffer.append((image, basename)) + if len(shard_buffer) % 20 == 0: + print(f"shard {shard_idx: 04d} frames {len(shard_buffer)} - file progress {vfn} - {fi}\r", end="") + if len(shard_buffer) == SAMPLES_PER_SHARD: + create_shard(shard_idx, shard_buffer) + shard_buffer = [] + shard_idx += 1 + cum_video_len += video_len + + elif video_ext == ".tar": # for epic_kitchen + tar = tarfile.open(video_path) + frame_fns = sorted([tinfo.name for tinfo in tar.getmembers() if ".jpg" in tinfo.name]) + video_len = len(frame_fns) + for fi in range(video_len): + if (fi + cum_video_len) % SUBSAMPLING_RATE == 0: + frame_tarf = tar.extractfile(frame_fns[fi]) + if frame_tarf: + frame_bytes = frame_tarf.read() + else: + continue + image = np.array( + transform(torch.from_numpy(np.array(Image.open(BytesIO(frame_bytes)))).permute(-1, 0, 1)) + ) + basename = f"{vfn.replace(video_ext, '')}_{fi:06d}" + shard_buffer.append((image, basename)) + if len(shard_buffer) % 20 == 0: + print(f"shard {shard_idx: 04d} frames {len(shard_buffer)} - file progress {vfn} - {fi}\r", end="") + if len(shard_buffer) == SAMPLES_PER_SHARD: + create_shard(shard_idx, shard_buffer) + shard_buffer = [] + shard_idx += 1 + cum_video_len += video_len + +# create a shard for final remainings +if len(shard_buffer) > 0: + create_shard(shard_idx, shard_buffer) + shard_buffer = [] + shard_idx 
+= 1 diff --git a/theia/scripts/train/sanity_check_train_rvfm.sh b/theia/scripts/train/sanity_check_train_rvfm.sh new file mode 100755 index 0000000000000000000000000000000000000000..b0aae3572708cdd321a829f431cd453719677032 --- /dev/null +++ b/theia/scripts/train/sanity_check_train_rvfm.sh @@ -0,0 +1,5 @@ +#!/bin/bash +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +torchrun --nproc_per_node=1 --nnodes 1 --rdzv_backend c10d --rdzv_endpoint localhost:0 scripts/train/train_rvfm.py \ + +logging.note=sanitycheck +dataset.data_portion=0.001 diff --git a/theia/scripts/train/train_rvfm.py b/theia/scripts/train/train_rvfm.py new file mode 100644 index 0000000000000000000000000000000000000000..176f0a9885f989e16f2262d335c82cd17dbca39b --- /dev/null +++ b/theia/scripts/train/train_rvfm.py @@ -0,0 +1,349 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +""" +Training script for theia, also called robot visual foundation model (RVFM) in +the code. +This training script uses hydra. To change configurations go for theia/configs. +""" + +import math +import os.path as osp +import random +import warnings +from typing import Any, Callable + +import hydra +import wandb +import torch +import torch.nn as nn +import torch.distributed as dist +from torch.optim.lr_scheduler import LRScheduler +from torchvision.transforms.v2 import Compose +from torch.nn.parallel import DistributedDataParallel as DDP +from tqdm import tqdm +from omegaconf import DictConfig, OmegaConf + +from theia.models.rvfm import RobotVisionFM +from theia.optimizers.utils import param_groups_weight_decay +from theia.utils.logging import create_meters, log_metrics +from theia.utils.seed import seed_everything +from theia.foundation_models.common import MODEL_FEATURE_SIZES, get_model_feature_size +from theia.dataset.data_utils import get_frame_dataloader, get_frame_iterator, get_image_video_dataset +from theia.dataset.oxe.oxe_transforms import totensor + + +warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated") + + +def train( + rvfm: nn.Module, + target_model_names: list[str], + optimizer: torch.optim.Optimizer, + lr_scheduler: LRScheduler, + train_dataset: Any, + eval_dataset: Any, + cfg: DictConfig, + device: int = 0, + train_epoch_steps: int = 0, + eval_epoch_steps: int = 0, + total_train_steps: int = 0, + warmup_steps: int = 0, +) -> None: + """Training and evaluation for robot visual foundation model (rvfm). + + Args: + rvfm (nn.Module): model to train. + target_model_names (list[str]): list of teacher model names. + optimizer (torch.optim.Optimizer): optimizer. + lr_scheduler (LRScheduler): learning rate scheduler. + train_dataset (Any): train dataset. + eval_dataset (Any): eval dataset. + cfg (DictConfig): train config + device (int, optional): device (of this process). Defaults to 0. + train_epoch_steps (int, optional): steps per training epoch. Defaults to 0. + eval_epoch_steps (int, optional): steps per eval epoch. Defaults to 0. + total_train_steps (int, optional): total training steps. Defaults to 0. + warmup_steps (int, optional): warmup steps. Defaults to 0. 
+ """ + epochs = cfg.training.epochs + steps = 0 + # wrap the loaders so handle sync dataloaders easily + for ep in range(epochs): + + train_loaders = get_frame_dataloader( + train_dataset, + batch_size=cfg.training.batch_size, + pin_memory=True, + num_workers=cfg.training.num_workers, + shuffle=cfg.dataset.shuffle, + shuffle_buffer_size=cfg.dataset.shuffle_buffer_size, + seed=cfg.seed + device * 100 + ep, # either cfg.seed or cfg.seed + rank + ) + eval_loaders = get_frame_dataloader( + eval_dataset, + batch_size=cfg.training.batch_size, + pin_memory=True, + num_workers=cfg.training.num_workers, + shuffle_buffer_size=cfg.dataset.shuffle_buffer_size, + seed=cfg.seed, # either cfg.seed or cfg.seed + rank + ) + train_iter = get_frame_iterator(train_loaders) + + metric_meters = create_meters(target_model_names) + rvfm.train() + train_tqdm = tqdm(range(train_epoch_steps), ncols=80) if device == 0 else range(train_epoch_steps) + for _ in train_tqdm: + try: + batch = next(train_iter) + except StopIteration: + train_iter = get_frame_iterator(train_loaders) + batch = next(train_iter) + images_batch = batch["image"].to(device, non_blocking=True) + if cfg.training.random_target_models > 0: + batch_target_model_names = random.sample(target_model_names, 2) + else: + batch_target_model_names = target_model_names + + target_features_batch = {} + for t in batch_target_model_names: + base_name = t.replace("_cls", "") + cls = True if "_cls" in t else False + if cls: + target_features_batch[t] = batch[base_name]["cls"].to(device, non_blocking=True).float() + else: + target_features_batch[t] = batch[base_name]["embedding"].to(device, non_blocking=True).float() + + pred = rvfm(images_batch) + losses = rvfm.module.get_loss(pred, target_features_batch) + + if cfg.training.main_loss == "mse" or cfg.training.main_loss is None: + main_loss = losses["mse_loss"] + elif cfg.training.main_loss == "cos_l1": + main_loss = 0.9 * losses["cos_loss"] + 0.1 * losses["l1_loss"] + + optimizer.zero_grad() + main_loss.backward() + if cfg.training.grad_clip: + nn.utils.clip_grad_norm_( + rvfm.parameters(), + cfg.training.grad_clip_norm_warmup if steps < warmup_steps else cfg.training.grad_clip_norm, + ) + optimizer.step() + + lr_scheduler.step() + + steps += 1 + batch_size = images_batch.size(0) + + log_metrics( + metric_meters, + target_model_names=target_model_names, + device=device, + batch_size=batch_size, + mode="train", + upload_wandb=True, + main_loss=main_loss, + **losses, + ) + + if cfg.training.freeze_translator: + if steps == int(cfg.training.freeze_translator_start_steps_ratio * total_train_steps): + rvfm.module.freeze_translator() + + if steps % cfg.logging.save_ckpt_interval == 0 and device == 0: + model_save_fn = f"{cfg.logging.run_identifier_prefix}_step{steps:08d}.pth" + save_path = osp.join(cfg.logging.model_path, model_save_fn) + torch.save(rvfm.module.state_dict(), save_path) + + dist.barrier() + rvfm.eval() + eval_iter = get_frame_iterator(eval_loaders) + eval_tqdm = tqdm(range(eval_epoch_steps), ncols=80) if device == 0 else range(eval_epoch_steps) + with torch.no_grad(): + for _ in eval_tqdm: + batch = next(eval_iter) + images_batch = batch["image"] + target_features_batch = {} + for t in target_model_names: + base_name = t.replace("_cls", "") + cls = True if "_cls" in t else False + if cls: + target_features_batch[t] = batch[base_name]["cls"].to(device, non_blocking=True).float() + else: + target_features_batch[t] = batch[base_name]["embedding"].to(device, non_blocking=True).float() + + pred = 
rvfm(images_batch) + losses = rvfm.module.get_loss(pred, target_features_batch) + if cfg.training.main_loss == "mse" or cfg.training.main_loss is None: + main_loss = losses["mse_loss"] + elif cfg.training.main_loss == "cos_l1": + main_loss = 0.9 * losses["cos_loss"] + 0.1 * losses["l1_loss"] + + batch_size = images_batch.size(0) + log_metrics( + metric_meters, + target_model_names=target_model_names, + device=device, + batch_size=batch_size, + mode="eval", + upload_wandb=False, + main_loss=main_loss, + **losses, + ) + + log_metrics( + metric_meters, + mode="eval", + upload_wandb=True, + only_upload=True, + target_model_names=target_model_names, + device=device, + ) + + if device == 0: + model_save_fn = f"{cfg.logging.run_identifier_prefix}_step{steps:08d}.pth" + save_path = osp.join(cfg.logging.model_path, model_save_fn) + torch.save(rvfm.module.state_dict(), save_path) + + dist.barrier() + + +def ddp_setup() -> None: + """Initialize stuff for DDP.""" + dist.init_process_group("nccl") + + +def ddp_cleanup() -> None: + """Clean up stuff for DDP.""" + dist.destroy_process_group() + + +def ddp_main(cfg: DictConfig) -> None: + """Entry point of DDP. + + Args: + cfg (DictConfig): settings for training. + """ + ddp_setup() + rank, world_size = dist.get_rank(), dist.get_world_size() + + target_model_names = ( + cfg.training.target_models.target_model_names + if len(cfg.training.target_models.target_model_names) > 0 + else list(MODEL_FEATURE_SIZES.keys()) + ) + target_model_names = [t for t in target_model_names if "llava" not in t] # llava is currently not supported + target_feature_sizes = {t: get_model_feature_size(t, keep_spatial=True) for t in target_model_names} + + target_model_names_wocls = target_model_names[:] + if hasattr(cfg.training, "distill_cls") and cfg.training.distill_cls == True: + target_model_names_copy = target_model_names[:] + for t in target_model_names: + if "google/vit" in t or "facebook/dino" in t or "openai/clip" in t: + target_feature_sizes[t+"_cls"] = get_model_feature_size(t, keep_spatial=True)[:1] + target_model_names_copy.append(t+"_cls") + + target_model_names = target_model_names_copy + + rvfm = RobotVisionFM( + translator=cfg.model.translator.type, + translator_kwargs=cfg.model.translator.kwargs, + target_feature_sizes=target_feature_sizes, + target_loss_weights=cfg.training.target_models.target_model_weights, + **cfg.model.backbone, + ) + + rvfm.to(rank) + + rvfm_ddp = DDP(rvfm, device_ids=[rank], find_unused_parameters=False) + + image_transform: Compose | Callable = totensor # currently just ndarray to tensor + + train_dataset, train_dataset_expected_length = get_image_video_dataset( + dataset_root=cfg.dataset.dataset_root, + dataset_mix=cfg.dataset.dataset_mix, + split="train", + dataset_ratio=cfg.dataset.dataset_ratio, + feature_models=target_model_names_wocls, + image_transform=image_transform, + feature_norm=cfg.dataset.feature_norm, + rank=rank, + world_size=world_size, + shuffle=cfg.dataset.shuffle, + seed=cfg.seed, + shuffle_buffer_size=cfg.dataset.shuffle_buffer_size, + num_workers=cfg.training.num_workers, + ) + + eval_dataset, eval_dataset_expected_length = get_image_video_dataset( + dataset_root=cfg.dataset.dataset_root, + dataset_mix=cfg.dataset.dataset_mix, + split="val", + dataset_ratio=0.1, + feature_models=target_model_names_wocls, + image_transform=image_transform, + feature_norm=cfg.dataset.feature_norm, + rank=rank, + world_size=world_size, + shuffle=False, + seed=cfg.seed, + shuffle_buffer_size=cfg.dataset.shuffle_buffer_size, + 
num_workers=cfg.training.num_workers, + ) + + train_epoch_steps = math.ceil(train_dataset_expected_length / cfg.training.batch_size / world_size) + eval_epoch_steps = math.ceil(eval_dataset_expected_length / cfg.training.batch_size / world_size) + total_train_steps = train_epoch_steps * cfg.training.epochs + + rvfm_param_groups = param_groups_weight_decay(rvfm_ddp, cfg.training.weight_decay) + lr = cfg.training.base_lr * ( + (cfg.training.batch_size * world_size) / (cfg.training.base_batch_size * cfg.training.base_world_size) + ) + optimizer = hydra.utils.instantiate(cfg.training.optimizer, rvfm_param_groups, lr=lr) + lr_scheduler = hydra.utils.instantiate( + cfg.training.lr_scheduler, + optimizer=optimizer, + warm_up_steps=int(cfg.training.warm_up_steps_ratio * total_train_steps), + cos_lrs_T_0=int(total_train_steps * (1 - cfg.training.warm_up_steps_ratio)), + ) + + if rank == 0: + print(OmegaConf.to_yaml(cfg)) + wandb.init(project=cfg.logging.project, name=cfg.logging.run_identifier_prefix, config=OmegaConf.to_object(cfg)) + + train( + rvfm_ddp, + target_model_names, + optimizer, + lr_scheduler, + train_dataset, + eval_dataset, + cfg=cfg, + device=rank, + train_epoch_steps=train_epoch_steps, + eval_epoch_steps=eval_epoch_steps, + total_train_steps=total_train_steps, + warmup_steps=int(cfg.training.warm_up_steps_ratio * total_train_steps), + ) + + ddp_cleanup() + + +@hydra.main(version_base=None, config_path="../../configs", config_name="train_rvfm_imagenet") +def main(cfg: DictConfig) -> None: + """Main. Dealing with arguments and call DDP.""" + + backbone_fn = f"_{cfg.model.backbone.backbone.replace('/', '-')}" + notes_fn = f"_{cfg.logging.notes}" if cfg.logging.notes else "" + translator_fn = f"_{cfg.model.translator.type}" + pretrained_fn = "_pretrained" if cfg.model.backbone.pretrained else "" + dp_fn = f"_dp{cfg.dataset.dataset_ratio:.3f}" + cfg.logging.run_identifier_prefix = f"rvfm{dp_fn}{backbone_fn}{translator_fn}{pretrained_fn}{notes_fn}" + + seed_everything(cfg.seed) + + ddp_main(cfg) + + +if __name__ == "__main__": + main() diff --git a/theia/utils/__init__.py b/theia/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9 --- /dev/null +++ b/theia/utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. diff --git a/theia/utils/cortexbench/__init__.py b/theia/utils/cortexbench/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9 --- /dev/null +++ b/theia/utils/cortexbench/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. diff --git a/theia/utils/cortexbench/load_model.py b/theia/utils/cortexbench/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1c074ff0b17c9a31df59e75d3a3a202ac04abe0a --- /dev/null +++ b/theia/utils/cortexbench/load_model.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import math +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from torchvision.transforms import Compose + + +def load_model( + model: nn.Module, transform: Compose, metadata: Any, **kwargs: Any +) -> tuple[nn.Module, torch.Size, Compose, Any]: + """Helper function for loading model for cortexbench. + + Args: + model (nn.Module): model. 
+ transform (torchvision.transforms.Compose): transform applied to input image. + metadata (Any): any metadata embedded in the model. + kwargs (Any): any parameters for loading the model. Including + `checkpoint_path` for loading weights for rvfm. + + Returns: + tuple[nn.Module, torch.Size, Compose, Any]: return model, size of the embedding, transform, and the metadata. + """ + + if kwargs.get("checkpoint_path"): + model.load_pretrained_weights(kwargs["checkpoint_path"]) + + with torch.inference_mode(): + zero_img = np.array(Image.new("RGB", (100, 100))) # for getting the embedding shape + transformed_img = transform(zero_img).unsqueeze(0) + embedding_dim = model.forward_feature(transformed_img).size()[1:] # [H*W, C] + if len(embedding_dim) > 1: + h = w = int(math.sqrt(embedding_dim[0])) + embedding_dim = torch.Size((embedding_dim[1], h, w)) # [C, H, W] + + return model, embedding_dim, transform, metadata diff --git a/theia/utils/cortexbench/policy_heads.py b/theia/utils/cortexbench/policy_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..a3fe7f31ae8cf43589b9ae58f9d4dc9ac0167103 --- /dev/null +++ b/theia/utils/cortexbench/policy_heads.py @@ -0,0 +1,240 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from typing import Any, Optional + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange + +# since this code will be only used for running cortexbench +# the following dependency won't be added to the project by default +from mjrl.policies.gaussian_mlp import BatchNormMLP +from numpy.typing import NDArray + + +class ConvBatchNormMLP(BatchNormMLP): + """Convolution followed with a BatchNormMLP (BatchNormMLP is from mjrl). + + Attrs: + embedding_dim (tuple[int, ...] | list[int, ...] | torch.Size): dimension of the representation. + proprio_dim (tuple[int, ...] | list[int, ...] | torch.Size): + dimension of the proprio information from the environment. + history_window (int): the number of history observations considered. + model (nn.ModuleDict): the dict to original BatchNormMLP (as a "head") and newly created Conv (as a "neck"). + device (str | torch.device): track the device that the model is on. + """ + + def __init__( + self, + env_spec: Any, + hidden_sizes: str = "(64, 64)", # str is to adapt with mjrl side + min_log_std: float = -3.0, + init_log_std: float = 0.0, + seed: Optional[int] = None, + nonlinearity: str = "relu", + dropout: float = 0.0, + *args: Any, + **kwargs: Any, + ): + """ + Args: + env_spec (gym.EnvSpec): specs of the environment that this policy will run on. + hidden_sizes (tuple): size of hidden layers of MLP. Defaults to (64,64). + min_log_std (float): minimum log std value for action. This is to match mjrl. Defaults to -3. + init_log_std (float): initial log std value for action. This is to match mjrl. Defaults to 0. + seed (Optional[int]): seed. Defaults to None. + nonlinearity (str): kind of non-linearility activation function. Defaults to 'relu'. + dropout (float): dropout rate. Defaults to 0. 
+ """ + self.embedding_dim = kwargs["embedding_dim"] # [C, H, W] + self.proprio_dim = kwargs["proprio_dim"] + self.history_window = kwargs["history_window"] + hidden_sizes = eval(hidden_sizes) # hack to match mjrl + env_spec.observation_dim = hidden_sizes[0] + self.proprio_dim + super().__init__( + env_spec, hidden_sizes, min_log_std, init_log_std, seed, nonlinearity, dropout, *args, **kwargs + ) + + neck = nn.Sequential( + nn.Conv2d(self.embedding_dim[0] * self.history_window, 256, kernel_size=4, stride=2, padding=1), + nn.LayerNorm([256, 7, 7]), + nn.ReLU() if nonlinearity == "relu" else nn.Tanh(), # 14x14 -> 7x7 # just to keep the same as super class + nn.Conv2d(256, 256, kernel_size=3, stride=2), + nn.LayerNorm([256, 3, 3]), + nn.ReLU() if nonlinearity == "relu" else nn.Tanh(), # 7x7 -> 3x3 + nn.Conv2d(256, 256, kernel_size=3, stride=1), + nn.LayerNorm([256, 1, 1]), + nn.ReLU() if nonlinearity == "relu" else nn.Tanh(), # 3x3 -> 1x1 + nn.Flatten(), + ) + + # re-encapsule so that all nn parts are in self.model + # so that explicit operations on self.model by cortexbench are applied on all nn parts + # e.g. policy.model.eval() + head: nn.Module = self.model # type:ignore [has-type] + self.model = nn.ModuleDict({"neck": neck, "head": head}) + self.device: Optional[str | torch.device] = None + + def to(self, device: str | torch.device) -> None: + """Put the model on the `device`. + + Args: + device (str | torch.device): the device to put the model + """ + for k in self.model: + self.model[k].to(device) + self.device = device + + def eval(self) -> None: + """Set the model in eval mode.""" + for k in self.model: + self.model[k].eval() + + def train(self) -> None: + """:Set the model in train mode.""" + for k in self.model: + self.model[k].train() + + def get_action_mean(self, observation: torch.Tensor) -> torch.Tensor: + """Get the mean action given the observation. + + Args: + observation (torch.Tensor): observation. + + Returns: + torch.Tensor : mean action. + """ + if len(self.embedding_dim) > 0: + # observation (B, T*H*W*C+C_pripro) + if self.proprio_dim > 0: + emb_obs, proprio_obs = observation[..., : -self.proprio_dim], observation[..., -self.proprio_dim :] + emb_obs = rearrange( + emb_obs, + "b (t h w c) -> b (c t) h w", + t=self.history_window, + c=self.embedding_dim[0], + h=self.embedding_dim[1], + w=self.embedding_dim[2], + ) + emb_obs = self.model["neck"](emb_obs) + self.obs_var = torch.cat([emb_obs, proprio_obs], dim=1) + else: + emb_obs = rearrange( + observation, + "b (t h w c) -> b (c t) h w", + t=self.history_window, + c=self.embedding_dim[0], + h=self.embedding_dim[1], + w=self.embedding_dim[2], + ) + self.obs_var = self.model["neck"](emb_obs) + else: + raise ValueError(f"input observation {observation.size()} is not from a valid spatial embedding.") + mean = self.model["head"](self.obs_var) + return mean + + def forward(self, observation: torch.Tensor) -> torch.Tensor: + """Model forward. Wrapper for get_action_mean() used during training. + + Args: + observation (torch.Tensor): observation. + + Returns: + torch.Tensor: mean action. + """ + return self.get_action_mean(observation) + + def get_action(self, observation: NDArray) -> tuple[NDArray, dict[str, Any]]: + """Get action with some noise used in evaluation / rollout. No gradient. + + Args: + observation (NDArray): observation. 
+ + Returns: + tuple[NDArray, dict[str, Any]]: action and some statistics (required by mjrl). + """ + with torch.no_grad(): + observation = torch.from_numpy(observation.astype(np.float32)).unsqueeze(0).to(self.device) + mean = self.get_action_mean(observation).detach().cpu().numpy().ravel() + noise = np.exp(self.log_std_val) * np.random.randn(self.m) + action = mean + noise + return (action, {"mean": mean, "log_std": self.log_std_val, "evaluation": mean}) + + def get_action_deterministic(self, observation: NDArray) -> tuple[NDArray, dict[str, Any]]: + """Get action without noise (using the mean), used in evaluation / rollout. No gradient. + + Args: + observation (NDArray): observation. + + Returns: + tuple[NDArray, dict[str, Any]]: action and some statistics (required by mjrl). + """ + with torch.no_grad(): + observation = torch.from_numpy(observation.astype(np.float32)).unsqueeze(0).to(self.device) + action = self.get_action_mean(observation).detach().cpu().numpy().ravel() + return (action, {"mean": action, "log_std": 0, "evaluation": action}) + + +class ConvPolicyHead(ConvBatchNormMLP): + """A smaller convolution followed by a smaller MLP head (interface follows mjrl's BatchNormMLP). + + Attributes: + embedding_dim (tuple[int, ...] | list[int, ...] | torch.Size): dimension of the representation. + proprio_dim (int): + dimension of the proprioceptive information from the environment. + history_window (int): the number of history observations considered. + model (nn.ModuleDict): the dict holding the newly created convolutional layers (as the "neck") and a small MLP (as the "head"). + device (str | torch.device): track the device that the model is on. + """ + + def __init__( + self, + env_spec: Any, + hidden_sizes: str = "(64, 64)", # str to match the mjrl side + min_log_std: float = -3.0, + init_log_std: float = 0.0, + seed: Optional[int] = None, + nonlinearity: str = "relu", + dropout: float = 0.0, + *args: Any, + **kwargs: Any, + ): + """ + Args: + env_spec (gym.EnvSpec): specs of the environment that this policy will run on. + hidden_sizes (str): sizes of the MLP hidden layers, given as a string to match mjrl. Defaults to "(64, 64)". + min_log_std (float): minimum log std value for action. This is to match mjrl. Defaults to -3. + init_log_std (float): initial log std value for action. This is to match mjrl. Defaults to 0. + seed (Optional[int]): seed. Defaults to None. + nonlinearity (str): kind of non-linearity (activation function). Defaults to 'relu'. + dropout (float): dropout rate. Defaults to 0.
+ """ + self.embedding_dim = kwargs["embedding_dim"] # [C, H, W] + self.proprio_dim = kwargs["proprio_dim"] + self.history_window = kwargs["history_window"] + hidden_sizes = eval(hidden_sizes) # hack to match mjrl + env_spec.observation_dim = hidden_sizes[0] + self.proprio_dim + super().__init__( + env_spec, hidden_sizes, min_log_std, init_log_std, seed, nonlinearity, dropout, *args, **kwargs + ) + + del self.model + + neck = nn.Sequential( + nn.Conv2d(self.embedding_dim[0] * self.history_window, 60, kernel_size=4, stride=2, padding=1), + nn.LayerNorm([60, 7, 7]), + nn.ReLU() if nonlinearity == "relu" else nn.Tanh(), # 14x14-> 7x7 # just to keep the same as super class + nn.Conv2d(60, 60, kernel_size=3, stride=2), + nn.LayerNorm([60, 3, 3]), + nn.ReLU() if nonlinearity == "relu" else nn.Tanh(), # 3x3 + nn.Flatten(), + ) + head = nn.Sequential( + nn.Linear(60 * 3 * 3 + self.proprio_dim, 256), + nn.LayerNorm(256), + nn.ReLU() if nonlinearity == "relu" else nn.Tanh(), + nn.Linear(256, self.m), + ) + self.model = nn.ModuleDict({"neck": neck, "head": head}) + self.device = None diff --git a/theia/utils/cortexbench/transforms.py b/theia/utils/cortexbench/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..fcdaba491da065b1c05a6082fa6b63f52700c130 --- /dev/null +++ b/theia/utils/cortexbench/transforms.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import torch +import torchvision.transforms.v2 as T +from torchvision.transforms import InterpolationMode + + +def rvfm_image_transforms(output_size: int = 224) -> T.Transform: + """Image transform used by RVFM. + + Args: + output_size (int): output size of the image. + + Returns: + T.Compose: the transform + """ + return T.Compose( + [ + T.ToImage(), + T.Resize(output_size, interpolation=InterpolationMode.BICUBIC), + ] + ) + + +def vit_transforms(resize_size: int = 256, output_size: int = 224) -> T.Transform: + return T.Compose( + [ + T.ToImage(), + T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC), + T.CenterCrop(output_size), + T.ToDtype(torch.float32, scale=True), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + + +def r3m_transforms(resize_size: int = 256, output_size: int = 224) -> T.Transform: + return T.Compose( + [ + T.ToImage(), + T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC), + T.CenterCrop(output_size), + T.ToDtype(torch.float32, scale=False), + ] + ) diff --git a/theia/utils/cortexbench/trifinger/__init__.py b/theia/utils/cortexbench/trifinger/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9 --- /dev/null +++ b/theia/utils/cortexbench/trifinger/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. diff --git a/theia/utils/cortexbench/trifinger/policy.py b/theia/utils/cortexbench/trifinger/policy.py new file mode 100644 index 0000000000000000000000000000000000000000..ed4aba893c77b316d51df620c3f658fa00803e03 --- /dev/null +++ b/theia/utils/cortexbench/trifinger/policy.py @@ -0,0 +1,123 @@ +# File modified. Modifications Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the CC-BY-NC license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +from typing import Any + +import torch +import torch.nn as nn +from einops.layers.torch import Rearrange + + +class ConvBatchNormMLPDeterministicPolicy(nn.Module): + def __init__( + self, + in_dim: tuple[int, ...], + extra_dim: int, + out_dim: int, + max_a: Any = None, + hidden_size: int = 256, + nonlinearity: str = "relu", + device: str | int | torch.device = "cpu", + ) -> None: + super().__init__() + self.extra_dim = extra_dim + self.in_dim = in_dim + self.neck = nn.Sequential( + Rearrange("b (h w c) -> b c h w", h=14, w=14), + nn.Conv2d(in_dim[0], 256, kernel_size=4, stride=2, padding=1), # 14x14 -> 7x7 + nn.ReLU() if nonlinearity == "relu" else nn.Tanh(), + nn.Conv2d(256, 256, kernel_size=3, stride=2), # 7x7 -> 3x3 + nn.ReLU() if nonlinearity == "relu" else nn.Tanh(), + nn.Conv2d(256, 256, kernel_size=3, stride=1), # 3x3 -> 1x1 + nn.ReLU() if nonlinearity == "relu" else nn.Tanh(), + nn.Flatten(), + ) + self.policy = nn.Sequential( + nn.Linear(256 + extra_dim, hidden_size), + nn.ReLU() if nonlinearity == "relu" else nn.Tanh(), + nn.Linear(hidden_size, hidden_size), + nn.ReLU() if nonlinearity == "relu" else nn.Tanh(), + nn.Linear(hidden_size, out_dim), + ) + self.neck.to(device) + self.policy.to(device) + self.device = device + + self.init_state = copy.deepcopy(self.policy.state_dict()) + self.neck_init_state = copy.deepcopy(self.neck.state_dict()) + + self.max_a = max_a + self.in_dim = in_dim + self.out_dim = out_dim + + def forward(self, state: torch.Tensor) -> torch.Tensor: + visual_state = state[..., : -self.extra_dim] + feature = self.neck(visual_state) + if self.extra_dim > 0: + feature = torch.cat([feature, state[..., -self.extra_dim :]], dim=1) + action = self.policy(feature) + return action + + def reset(self) -> None: + self.policy.load_state_dict(self.init_state) + self.neck.load_state_dict(self.neck_init_state) + + def clip_action(self, a: torch.Tensor) -> torch.Tensor: + if self.max_a is not None: + a = torch.where(a > self.max_a, torch.tensor([self.max_a]).to(self.device), a) + a = torch.where(a < -self.max_a, -torch.tensor([self.max_a]).to(self.device), a) + return a + + def scale_to_range(self, a: torch.Tensor) -> torch.Tensor: + """Does not do anything; just returns a""" + return a + + +def construct_policy( + type: str, + task_state_type: str, + train_ft_state_shape: int, + pretrained_dim: tuple[int, ...], + task_goal_type: str, + out_dim: int, + max_a: Any, + device: str | int | torch.device, + hidden_size: int = 256, + nonlinearity: str = "relu", + **kwargs: Any, +) -> ConvBatchNormMLPDeterministicPolicy: + in_dim = pretrained_dim + extra_dim = 0 + if task_state_type == "obj": + extra_dim += 0 + elif task_state_type in ["ftpos_obj", "ftpos"]: + extra_dim += train_ft_state_shape + else: + raise NameError("Invalid state_type") + + if task_goal_type == "goal_none": + in_dim = pretrained_dim + elif task_goal_type == "goal_cond": + in_dim = (pretrained_dim[0] * 2, *pretrained_dim[1:]) + elif task_goal_type == "goal_o_pos": + extra_dim += 3 + else: + raise NameError("Invalid goal_type") + + if type == "ConvBatchNormMLP": + policy = ConvBatchNormMLPDeterministicPolicy( + in_dim=in_dim, + extra_dim=extra_dim, + out_dim=out_dim, + max_a=max_a, + hidden_size=hidden_size, + nonlinearity=nonlinearity, + device=device, + ) + else: + raise NotImplementedError(f"Policy network {type} is not supported.") + return policy diff --git a/theia/utils/logging.py b/theia/utils/logging.py new file mode 100644 index 
0000000000000000000000000000000000000000..8f16b3d5c473edb39b378b7e0bf55964048fbf70 --- /dev/null +++ b/theia/utils/logging.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +from enum import Enum +from typing import Any + +import torch +import torch.distributed as dist +import wandb + + +class SummaryType(Enum): + NONE = 0 + AVERAGE = 1 + SUM = 2 + COUNT = 3 + + +class AverageMeter(object): + """Computes and stores the average and current value. + Attributes: + name (str): name of the meter. + fmt (str): format string. Defaults to ':f'. + summary_type (SummaryType): reduce method. Defaults to SummaryType.AVERAGE. + + val (float): last (mean) value over a batch. + avg (float): average value since meter creation. + sum (float): sum of all the values = self.avg * self.count. + count (int): number of values considered since meter creation. + """ + + def __init__(self, name: str, fmt: str = ":f", summary_type: SummaryType = SummaryType.AVERAGE) -> None: + """Initialize an average meter.""" + self.name = name + self.fmt = fmt + self.summary_type = summary_type + self.reset() + + def reset(self) -> None: + """Reset the meter.""" + self.val: float = 0.0 + self.avg: float = 0.0 + self.sum: float = 0.0 + self.count: int = 0 + + def update(self, val: float, n: int = 1) -> None: + """Update the meter. + + Args: + val (float): (mean) value over n samples. + n (int): number of samples. Defaults to 1. + """ + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def all_reduce(self) -> None: + """Reduce meters across ranks.""" + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device) + dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) # synchronous reduce: the result is read immediately below + self.sum, self.count = total.tolist() + self.avg = self.sum / self.count + + def __str__(self) -> str: + """String representation of the meter.""" + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + def summary(self) -> str: + """Return a summary string of the meter status.""" + fmtstr = "" + match self.summary_type: + case SummaryType.NONE: + fmtstr = "" + case SummaryType.AVERAGE: + fmtstr = "{name} {avg:.3f}" + case SummaryType.SUM: + fmtstr = "{name} {sum:.3f}" + case SummaryType.COUNT: + fmtstr = "{name} {count:.3f}" + case _: + raise ValueError("invalid summary type %r" % self.summary_type) + + return fmtstr.format(**self.__dict__) + + +def create_meters(target_model_names: list[str]) -> dict[str, AverageMeter]: + """Create meters for logging statistics, including individual meters for each target model. + + Args: + target_model_names (list[str]): names of the target models. + + Returns: + dict[str, AverageMeter]: the created meters. + """ + meters = {} + for loss in ["mse", "cos", "l1"]: + meters[f"train_{loss}_loss"] = AverageMeter(f"train_{loss}_loss") + meters[f"eval_{loss}_loss"] = AverageMeter(f"eval_{loss}_loss") + + for t in target_model_names: + for loss in ["mse", "cos", "l1"]: + for mode in ["train", "eval"]: + meters[f"{mode}_{t}_{loss}_loss"] = AverageMeter(f"{mode}_{t}_{loss}_loss") + + return meters + + +def log_metrics(meters: dict[str, AverageMeter], **kwargs: Any) -> None: + """Log metrics to wandb.
+ + Args: + meters (dict[str, AverageMeter]): the meters to update and to read running averages from. + kwargs (Any): metrics and options, including `mode`, `batch_size`, `only_upload`, the loss tensors (`main_loss`, `mse_loss`, `cos_loss`, `l1_loss`), the per-model loss dicts (`mse_losses_per_model`, `cos_losses_per_model`, `l1_losses_per_model`), `target_model_names`, `upload_wandb`, and `device`. + """ + metrics = {} + + mode = kwargs["mode"] + batch_size = kwargs["batch_size"] if "batch_size" in kwargs else 0 + + if not kwargs.get("only_upload", False): + # update meters + meters[f"{mode}_mse_loss"].update(kwargs["mse_loss"].item(), n=batch_size) + meters[f"{mode}_cos_loss"].update(kwargs["cos_loss"].item(), n=batch_size) + meters[f"{mode}_l1_loss"].update(kwargs["l1_loss"].item(), n=batch_size) + + for t in kwargs["target_model_names"]: + for loss in ["mse", "cos", "l1"]: + meters[f"{mode}_{t}_{loss}_loss"].update(kwargs[f"{loss}_losses_per_model"][t], n=batch_size) + + # read out from meters or the raw values for logging + if kwargs["upload_wandb"]: + if mode == "train": + metrics["loss"] = kwargs["main_loss"].item() + metrics["mse_loss"] = kwargs["mse_loss"].item() + metrics["cos_loss"] = kwargs["cos_loss"].item() + metrics["l1_loss"] = kwargs["l1_loss"].item() + + metrics[f"avg_{mode}_mse_loss"] = meters[f"{mode}_mse_loss"].avg + metrics[f"avg_{mode}_cos_loss"] = meters[f"{mode}_cos_loss"].avg + metrics[f"avg_{mode}_l1_loss"] = meters[f"{mode}_l1_loss"].avg + for t in kwargs["target_model_names"]: + for loss in ["mse", "cos", "l1"]: + metrics[f"avg_{mode}_{t}_{loss}_loss"] = meters[f"{mode}_{t}_{loss}_loss"].avg + + if kwargs["device"] == 0: + wandb.log(metrics) diff --git a/theia/utils/seed.py b/theia/utils/seed.py new file mode 100644 index 0000000000000000000000000000000000000000..69bf644b7b76946906c5a1bdd9e0028a309869c1 --- /dev/null +++ b/theia/utils/seed.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. + +import os +import random +from typing import Any, Optional + +import numpy as np +import torch + +max_seed_value = np.iinfo(np.uint32).max +min_seed_value = np.iinfo(np.uint32).min + + +def seed_everything(seed: Optional[Any] = None, workers: bool = False) -> int: + """Seed everything, adapted from lightning_fabric.utilities.seed.seed_everything. + + This avoids depending on lightning just for seeding. + + Args: + seed (Optional[Any]): seed, preferably an integer, or anything that can be converted to an integer. + workers (bool): value written to the PL_SEED_WORKERS environment variable. Defaults to False. + + Returns: + int: the actual seed used. It should be the same as the input seed in most cases. + """ + if seed is None: + env_seed = os.environ.get("PL_GLOBAL_SEED") + if env_seed is None: + seed = 0 + else: + try: + seed = int(env_seed) + except ValueError: + seed = 0 + elif not isinstance(seed, int): + seed = int(seed) + + if not (min_seed_value <= seed <= max_seed_value): + seed = 0 + + os.environ["PL_GLOBAL_SEED"] = str(seed) + os.environ["PYTHON_SEED"] = str(seed) # add python seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + os.environ["PL_SEED_WORKERS"] = f"{int(workers)}" + + return seed
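
A minimal usage sketch of the seeding and metering utilities added above, assuming the theia package from this patch is importable; the meter name and loss values are placeholders for illustration only.

from theia.utils.logging import AverageMeter
from theia.utils.seed import seed_everything

seed = seed_everything(42)  # seeds random, numpy and torch; returns the seed actually used

meter = AverageMeter("train_mse_loss", fmt=":.4f")
meter.update(0.52, n=32)  # mean loss of 0.52 over a batch of 32 samples
meter.update(0.48, n=32)

print(meter)            # "train_mse_loss 0.4800 (0.5000)" -> last value and running average
print(meter.summary())  # "train_mse_loss 0.500" -> formatted per SummaryType.AVERAGE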