Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 98df275

Browse files
authored
Refactor code with new pre commit configuration (#207)
* Remove CLang and build options from runtests.sh * Remove CLang and build options from runtests.sh * Run ./runtests.sh to format code, fix flake8 errors * Add changes to the ipynb files * Comment black and isort from pre-commit Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
1 parent fb6b463 commit 98df275

File tree

64 files changed

+279
-912
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+279
-912
lines changed

.clang-format

Lines changed: 0 additions & 88 deletions
This file was deleted.

.pre-commit-config.yaml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ repos:
6363
- id: pycln
6464
args: [--config=pyproject.toml]
6565

66-
- repo: https://github.com/psf/black
67-
rev: 22.3.0
68-
hooks:
69-
- id: black
70-
71-
- repo: https://github.com/PyCQA/isort
72-
rev: 5.9.3
73-
hooks:
74-
- id: isort
66+
# - repo: https://github.com/psf/black
67+
# rev: 22.3.0
68+
# hooks:
69+
# - id: black
70+
#
71+
# - repo: https://github.com/PyCQA/isort
72+
# rev: 5.9.3
73+
# hooks:
74+
# - id: isort

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ Before submitting a pull request, we recommend that all linting and unit tests
8989
should pass, by running the following command locally:
9090

9191
```bash
92-
./runtests.sh -u --net
92+
./runtests.sh -f -u --net
9393
```
9494
or (for new features that would not break existing functionality):
9595

generative/engines/trainer.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from __future__ import annotations
1313

14-
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
14+
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
1515

1616
import torch
1717
from monai.config import IgniteInfo
@@ -88,34 +88,34 @@ class AdversarialTrainer(Trainer):
8888

8989
def __init__(
9090
self,
91-
device: Union[torch.device, str],
91+
device: torch.device | str,
9292
max_epochs: int,
93-
train_data_loader: Union[Iterable, DataLoader],
93+
train_data_loader: Iterable | DataLoader,
9494
g_network: torch.nn.Module,
9595
g_optimizer: Optimizer,
9696
g_loss_function: Callable,
9797
recon_loss_function: Callable,
9898
d_network: torch.nn.Module,
9999
d_optimizer: Optimizer,
100100
d_loss_function: Callable,
101-
epoch_length: Optional[int] = None,
101+
epoch_length: int | None = None,
102102
non_blocking: bool = False,
103-
prepare_batch: Union[Callable[[Engine, Any], Any], None] = default_prepare_batch,
104-
iteration_update: Optional[Callable] = None,
105-
g_inferer: Optional[Inferer] = None,
106-
d_inferer: Optional[Inferer] = None,
107-
postprocessing: Optional[Transform] = None,
108-
key_train_metric: Optional[Dict[str, Metric]] = None,
109-
additional_metrics: Optional[Dict[str, Metric]] = None,
103+
prepare_batch: Callable[[Engine, Any], Any] | None = default_prepare_batch,
104+
iteration_update: Callable | None = None,
105+
g_inferer: Inferer | None = None,
106+
d_inferer: Inferer | None = None,
107+
postprocessing: Transform | None = None,
108+
key_train_metric: dict[str, Metric] | None = None,
109+
additional_metrics: dict[str, Metric] | None = None,
110110
metric_cmp_fn: Callable = default_metric_cmp_fn,
111-
train_handlers: Optional[Sequence] = None,
111+
train_handlers: Sequence | None = None,
112112
amp: bool = False,
113-
event_names: Union[List[Union[str, EventEnum]], None] = None,
114-
event_to_attr: Union[dict, None] = None,
113+
event_names: list[str | EventEnum] | None = None,
114+
event_to_attr: dict | None = None,
115115
decollate: bool = True,
116116
optim_set_to_none: bool = False,
117-
to_kwargs: Union[dict, None] = None,
118-
amp_kwargs: Union[dict, None] = None,
117+
to_kwargs: dict | None = None,
118+
amp_kwargs: dict | None = None,
119119
):
120120
super().__init__(
121121
device=device,
@@ -183,8 +183,8 @@ def _complete_state_dict_user_keys(self) -> None:
183183
self._state_dict_user_keys.append("recon_loss_function")
184184

185185
def _iteration(
186-
self, engine: AdversarialTrainer, batchdata: Dict[str, torch.Tensor]
187-
) -> Dict[str, Union[torch.Tensor, int, float, bool]]:
186+
self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor]
187+
) -> dict[str, torch.Tensor | int | float | bool]:
188188
"""
189189
Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine.
190190
Return below items in a dictionary:
@@ -219,8 +219,8 @@ def _iteration(
219219

220220
if len(batch) == 2:
221221
inputs, targets = batch
222-
args: Tuple = ()
223-
kwargs: Dict = {}
222+
args: tuple = ()
223+
kwargs: dict = {}
224224
else:
225225
inputs, targets, args, kwargs = batch
226226

generative/inferers/inferer.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def get_likelihood(
141141
progress_bar = iter(scheduler.timesteps)
142142
intermediates = []
143143
noise = torch.randn_like(inputs).to(inputs.device)
144-
total_kl = torch.zeros((inputs.shape[0])).to(inputs.device)
144+
total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
145145
for t in progress_bar:
146146
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
147147
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
@@ -228,8 +228,8 @@ def _get_decoder_log_likelihood(
228228
inputs: torch.Tensor,
229229
means: torch.Tensor,
230230
log_scales: torch.Tensor,
231-
original_input_range: Optional[Tuple] = [0, 255],
232-
scaled_input_range: Optional[Tuple] = [0, 1],
231+
original_input_range: Optional[Tuple] = (0, 255),
232+
scaled_input_range: Optional[Tuple] = (0, 1),
233233
) -> torch.Tensor:
234234
"""
235235
Compute the log-likelihood of a Gaussian distribution discretizing to a
@@ -304,11 +304,7 @@ def __call__(
304304
latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
305305

306306
prediction = super().__call__(
307-
inputs=latent,
308-
diffusion_model=diffusion_model,
309-
noise=noise,
310-
timesteps=timesteps,
311-
condition=condition,
307+
inputs=latent, diffusion_model=diffusion_model, noise=noise, timesteps=timesteps, condition=condition
312308
)
313309

314310
return prediction

generative/losses/adversarial_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def forward(
128128
if type(input) is not list:
129129
input = [input]
130130
target_ = []
131-
for disc_ind, disc_out in enumerate(input):
131+
for _, disc_out in enumerate(input):
132132
if self.criterion != AdversarialCriterions.HINGE.value:
133133
target_.append(self.get_target_tensor(disc_out, target_is_real))
134134
else:

generative/losses/perceptual.py

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,7 @@ class PerceptualLoss(nn.Module):
3838
"""
3939

4040
def __init__(
41-
self,
42-
spatial_dims: int,
43-
network_type: str = "alex",
44-
is_fake_3d: bool = True,
45-
fake_3d_ratio: float = 0.5,
41+
self, spatial_dims: int, network_type: str = "alex", is_fake_3d: bool = True, fake_3d_ratio: float = 0.5
4642
):
4743
super().__init__()
4844

@@ -58,11 +54,7 @@ def __init__(
5854
elif "radimagenet_" in network_type:
5955
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
6056
else:
61-
self.perceptual_function = LPIPS(
62-
pretrained=True,
63-
net=network_type,
64-
verbose=False,
65-
)
57+
self.perceptual_function = LPIPS(pretrained=True, net=network_type, verbose=False)
6658
self.is_fake_3d = is_fake_3d
6759
self.fake_3d_ratio = fake_3d_ratio
6860

@@ -90,26 +82,12 @@ def batchify_axis(x: torch.Tensor, fake_3d_perm: Tuple) -> torch.Tensor:
9082
preserved_axes.remove(spatial_axis)
9183

9284
channel_axis = 1
93-
input_slices = batchify_axis(
94-
x=input,
95-
fake_3d_perm=(
96-
spatial_axis,
97-
channel_axis,
98-
)
99-
+ tuple(preserved_axes),
100-
)
85+
input_slices = batchify_axis(x=input, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes))
10186
indices = torch.randperm(input_slices.shape[0])[: int(input_slices.shape[0] * self.fake_3d_ratio)].to(
10287
input_slices.device
10388
)
10489
input_slices = torch.index_select(input_slices, dim=0, index=indices)
105-
target_slices = batchify_axis(
106-
x=target,
107-
fake_3d_perm=(
108-
spatial_axis,
109-
channel_axis,
110-
)
111-
+ tuple(preserved_axes),
112-
)
90+
target_slices = batchify_axis(x=target, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes))
11391
target_slices = torch.index_select(target_slices, dim=0, index=indices)
11492

11593
axis_loss = torch.mean(self.perceptual_function(input_slices, target_slices))
@@ -150,11 +128,7 @@ class MedicalNetPerceptualSimilarity(nn.Module):
150128
verbose: if false, mute messages from torch Hub load function.
151129
"""
152130

153-
def __init__(
154-
self,
155-
net: str = "medicalnet_resnet10_23datasets",
156-
verbose: bool = False,
157-
) -> None:
131+
def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
158132
super().__init__()
159133
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
160134
self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose)
@@ -216,11 +190,7 @@ class RadImageNetPerceptualSimilarity(nn.Module):
216190
verbose: if false, mute messages from torch Hub load function.
217191
"""
218192

219-
def __init__(
220-
self,
221-
net: str = "radimagenet_resnet50",
222-
verbose: bool = False,
223-
) -> None:
193+
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
224194
super().__init__()
225195
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose)
226196
self.eval()

generative/losses/spectral_loss.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,7 @@ def _get_fft_amplitude(self, images: torch.Tensor) -> torch.Tensor:
8080
Returns:
8181
fourier transformation amplitude
8282
"""
83-
img_fft = fftn(
84-
images,
85-
s=self.fft_signal_size,
86-
dim=self.fft_dim,
87-
norm=self.fft_norm,
88-
)
83+
img_fft = fftn(images, s=self.fft_signal_size, dim=self.fft_dim, norm=self.fft_norm)
8984

9085
amplitude = torch.sqrt(torch.real(img_fft) ** 2 + torch.imag(img_fft) ** 2)
9186

generative/metrics/fid.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
6060
y_pred = y_pred.float()
6161

6262
if y.ndimension() > 2:
63-
raise ValueError(f"Inputs should have (number images, number of features) shape.")
63+
raise ValueError("Inputs should have (number images, number of features) shape.")
6464

6565
mu_y_pred = torch.mean(y_pred, dim=0)
6666
sigma_y_pred = _cov(y_pred, rowvar=False)
@@ -114,9 +114,9 @@ def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[to
114114
error = torch.empty(1, device=matrix.device, dtype=matrix.dtype)
115115

116116
for _ in range(num_iters):
117-
T = 0.5 * (3.0 * i_matrix - z_matrix.mm(y_matrix))
118-
y_matrix = y_matrix.mm(T)
119-
z_matrix = T.mm(z_matrix)
117+
t = 0.5 * (3.0 * i_matrix - z_matrix.mm(y_matrix))
118+
y_matrix = y_matrix.mm(t)
119+
z_matrix = t.mm(z_matrix)
120120

121121
s_matrix = y_matrix * torch.sqrt(norm_of_matrix)
122122

generative/metrics/mmd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def _compute_metric(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor
6969

7070
if y_pred.shape != y.shape:
7171
raise ValueError(
72-
f"y_pred and y shapes dont match after being processed by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}"
72+
"y_pred and y shapes dont match after being processed "
73+
f"by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}"
7374
)
7475

7576
for d in range(len(y.shape) - 1, 1, -1):

0 commit comments

Comments
 (0)