4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
| 7 | +import os |
7 | 8 | import unittest |
8 | 9 |
9 | 10 | import torch |
| 11 | +from omegaconf import DictConfig, OmegaConf |
10 | 12 | from pytorch3d.implicitron.models.generic_model import GenericModel |
11 | 13 | from pytorch3d.implicitron.models.renderer.base import EvaluationMode |
12 | 14 | from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args |
13 | 15 | from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras |
14 | 16 |
15 | 17 |
| 18 | +if os.environ.get("FB_TEST", False): |
| 19 | + from common_testing import get_pytorch3d_dir |
| 20 | +else: |
| 21 | + from tests.common_testing import get_pytorch3d_dir |
| 22 | + |
| 23 | +IMPLICITRON_CONFIGS_DIR = ( |
| 24 | + get_pytorch3d_dir() / "projects" / "implicitron_trainer" / "configs" |
| 25 | +) |
| 26 | + |
| 27 | + |
16 | 28 | class TestGenericModel(unittest.TestCase): |
| 29 | + def setUp(self): |
| 30 | + torch.manual_seed(42) |
| 31 | + |
17 | 32 | def test_gm(self): |
18 | 33 | # Simple test of a forward and backward pass of the default GenericModel. |
19 | 34 | device = torch.device("cuda:1") |
20 | 35 | expand_args_fields(GenericModel) |
21 | 36 | model = GenericModel() |
22 | 37 | model.to(device) |
| 38 | + self._one_model_test(model, device) |
| 39 | + |
| 40 | + def test_all_gm_configs(self): |
| 41 | + # Tests all model settings in the implicitron_trainer config folder. |
| 42 | + device = torch.device("cuda:0") |
| 43 | + config_files = [] |
| 44 | + |
| 45 | + for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"): |
| 46 | + config_files.extend( |
| 47 | + [ |
| 48 | + f |
| 49 | + for f in IMPLICITRON_CONFIGS_DIR.glob(pattern) |
| 50 | + if not f.name.endswith("_base.yaml") |
| 51 | + ] |
| 52 | + ) |
| 53 | + |
| 54 | + for config_file in config_files: |
| 55 | + with self.subTest(name=config_file.stem): |
| 56 | + cfg = _load_model_config_from_yaml(str(config_file)) |
| 57 | + model = GenericModel(**cfg) |
| 58 | + model.to(device) |
| 59 | + self._one_model_test(model, device, eval_test=True) |
| 60 | + |
| 61 | + def _one_model_test( |
| 62 | + self, |
| 63 | + model, |
| 64 | + device, |
| 65 | + n_train_cameras: int = 5, |
| 66 | + eval_test: bool = True, |
| 67 | + ): |
23 | 68 |
24 | | - n_train_cameras = 2 |
25 | 69 | R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360) |
26 | 70 | cameras = PerspectiveCameras(R=R, T=T, device=device) |
27 | 71 |
28 | | - # TODO: make these default to None? |
29 | | - defaulted_args = { |
30 | | - "fg_probability": None, |
31 | | - "depth_map": None, |
32 | | - "mask_crop": None, |
33 | | - "sequence_name": None, |
| 72 | + N, H, W = n_train_cameras, model.render_image_height, model.render_image_width |
| 73 | + |
| 74 | + random_args = { |
| 75 | + "camera": cameras, |
| 76 | + "fg_probability": _random_input_tensor(N, 1, H, W, True, device), |
| 77 | + "depth_map": _random_input_tensor(N, 1, H, W, False, device) + 0.1, |
| 78 | + "mask_crop": _random_input_tensor(N, 1, H, W, True, device), |
| 79 | + "sequence_name": ["sequence"] * N, |
| 80 | + "image_rgb": _random_input_tensor(N, 3, H, W, False, device), |
34 | 81 | } |
35 | 82 |
36 | | - with self.assertWarnsRegex(UserWarning, "No main objective found"): |
37 | | - model( |
38 | | - camera=cameras, |
39 | | - evaluation_mode=EvaluationMode.TRAINING, |
40 | | - **defaulted_args, |
41 | | - image_rgb=None, |
42 | | - ) |
43 | | - target_image_rgb = torch.rand( |
44 | | - (n_train_cameras, 3, model.render_image_height, model.render_image_width), |
45 | | - device=device, |
46 | | - ) |
 | 83 | + # training forward pass
| 84 | + model.train() |
47 | 85 | train_preds = model( |
48 | | - camera=cameras, |
| 86 | + **random_args, |
49 | 87 | evaluation_mode=EvaluationMode.TRAINING, |
50 | | - image_rgb=target_image_rgb, |
51 | | - **defaulted_args, |
52 | 88 | ) |
53 | 89 | self.assertGreater(train_preds["objective"].item(), 0) |
54 | 90 | train_preds["objective"].backward() |
55 | 91 |
56 | | - model.eval() |
57 | | - with torch.no_grad(): |
58 | | - # TODO: perhaps this warning should be skipped in eval mode? |
59 | | - with self.assertWarnsRegex(UserWarning, "No main objective found"): |
| 92 | + if eval_test: |
| 93 | + model.eval() |
| 94 | + with torch.no_grad(): |
60 | 95 | eval_preds = model( |
61 | | - camera=cameras[0], |
62 | | - **defaulted_args, |
63 | | - image_rgb=None, |
| 96 | + **random_args, |
| 97 | + evaluation_mode=EvaluationMode.EVALUATION, |
| 98 | + ) |
| 99 | + self.assertEqual( |
| 100 | + eval_preds["images_render"].shape, |
| 101 | + (1, 3, model.render_image_height, model.render_image_width), |
64 | 102 | ) |
65 | | - self.assertEqual( |
66 | | - eval_preds["images_render"].shape, |
67 | | - (1, 3, model.render_image_height, model.render_image_width), |
68 | | - ) |
69 | 103 |
70 | 104 | def test_idr(self): |
71 | 105 | # Forward pass of GenericModel with IDR. |
@@ -104,3 +138,44 @@ def test_idr(self): |
104 | 138 | **defaulted_args, |
105 | 139 | ) |
106 | 140 | self.assertGreater(train_preds["objective"].item(), 0) |
| 141 | + |
| 142 | + |
| 143 | +def _random_input_tensor( |
| 144 | + N: int, |
| 145 | + C: int, |
| 146 | + H: int, |
| 147 | + W: int, |
| 148 | + is_binary: bool, |
| 149 | + device: torch.device, |
| 150 | +) -> torch.Tensor: |
| 151 | + T = torch.rand(N, C, H, W, device=device) |
| 152 | + if is_binary: |
| 153 | + T = (T > 0.5).float() |
| 154 | + return T |
| 155 | + |
| 156 | + |
| 157 | +def _load_model_config_from_yaml(config_path, strict=True) -> DictConfig: |
| 158 | + default_cfg = get_default_args(GenericModel) |
| 159 | + cfg = _load_model_config_from_yaml_rec(default_cfg, config_path) |
| 160 | + return cfg |
| 161 | + |
| 162 | + |
| 163 | +def _load_model_config_from_yaml_rec(cfg: DictConfig, config_path: str) -> DictConfig: |
| 164 | + cfg_loaded = OmegaConf.load(config_path) |
| 165 | + if "generic_model_args" in cfg_loaded: |
| 166 | + cfg_model_loaded = cfg_loaded.generic_model_args |
| 167 | + else: |
| 168 | + cfg_model_loaded = None |
| 169 | + defaults = cfg_loaded.pop("defaults", None) |
| 170 | + if defaults is not None: |
| 171 | + for default_name in defaults: |
| 172 | + if default_name in ("_self_", "default_config"): |
| 173 | + continue |
| 174 | + default_name = os.path.splitext(default_name)[0] |
| 175 | + defpath = os.path.join(os.path.dirname(config_path), default_name + ".yaml") |
| 176 | + cfg = _load_model_config_from_yaml_rec(cfg, defpath) |
| 177 | + if cfg_model_loaded is not None: |
| 178 | + cfg = OmegaConf.merge(cfg, cfg_model_loaded) |
| 179 | + elif cfg_model_loaded is not None: |
| 180 | + cfg = OmegaConf.merge(cfg, cfg_model_loaded) |
| 181 | + return cfg |
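
A minimal sketch (not part of the commit) of the OmegaConf merge semantics that _load_model_config_from_yaml_rec relies on: OmegaConf.merge gives later arguments precedence, so configs pulled in through the Hydra-style "defaults" list are folded into the GenericModel defaults first, and the file's own generic_model_args are merged last and win. The keys below are illustrative GenericModel arguments chosen only for the example.

from omegaconf import OmegaConf

base = OmegaConf.create({"render_image_height": 400, "chunk_size_grid": 4096})
override = OmegaConf.create({"render_image_height": 800})

merged = OmegaConf.merge(base, override)
assert merged.render_image_height == 800  # later config takes precedence
assert merged.chunk_size_grid == 4096     # keys absent from the override survive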