Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 0 additions & 88 deletions .clang-format

This file was deleted.

18 changes: 9 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ repos:
- id: pycln
args: [--config=pyproject.toml]

- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black

- repo: https://github.com/PyCQA/isort
rev: 5.9.3
hooks:
- id: isort
# - repo: https://github.com/psf/black
# rev: 22.3.0
# hooks:
# - id: black
#
# - repo: https://github.com/PyCQA/isort
# rev: 5.9.3
# hooks:
# - id: isort
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ Before submitting a pull request, we recommend that all linting and unit tests
should pass, by running the following command locally:

```bash
./runtests.sh -u --net
./runtests.sh -f -u --net
```
or (for new features that would not break existing functionality):

Expand Down
40 changes: 20 additions & 20 deletions generative/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence

import torch
from monai.config import IgniteInfo
Expand Down Expand Up @@ -88,34 +88,34 @@ class AdversarialTrainer(Trainer):

def __init__(
self,
device: Union[torch.device, str],
device: torch.device | str,
max_epochs: int,
train_data_loader: Union[Iterable, DataLoader],
train_data_loader: Iterable | DataLoader,
g_network: torch.nn.Module,
g_optimizer: Optimizer,
g_loss_function: Callable,
recon_loss_function: Callable,
d_network: torch.nn.Module,
d_optimizer: Optimizer,
d_loss_function: Callable,
epoch_length: Optional[int] = None,
epoch_length: int | None = None,
non_blocking: bool = False,
prepare_batch: Union[Callable[[Engine, Any], Any], None] = default_prepare_batch,
iteration_update: Optional[Callable] = None,
g_inferer: Optional[Inferer] = None,
d_inferer: Optional[Inferer] = None,
postprocessing: Optional[Transform] = None,
key_train_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
prepare_batch: Callable[[Engine, Any], Any] | None = default_prepare_batch,
iteration_update: Callable | None = None,
g_inferer: Inferer | None = None,
d_inferer: Inferer | None = None,
postprocessing: Transform | None = None,
key_train_metric: dict[str, Metric] | None = None,
additional_metrics: dict[str, Metric] | None = None,
metric_cmp_fn: Callable = default_metric_cmp_fn,
train_handlers: Optional[Sequence] = None,
train_handlers: Sequence | None = None,
amp: bool = False,
event_names: Union[List[Union[str, EventEnum]], None] = None,
event_to_attr: Union[dict, None] = None,
event_names: list[str | EventEnum] | None = None,
event_to_attr: dict | None = None,
decollate: bool = True,
optim_set_to_none: bool = False,
to_kwargs: Union[dict, None] = None,
amp_kwargs: Union[dict, None] = None,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
):
super().__init__(
device=device,
Expand Down Expand Up @@ -183,8 +183,8 @@ def _complete_state_dict_user_keys(self) -> None:
self._state_dict_user_keys.append("recon_loss_function")

def _iteration(
self, engine: AdversarialTrainer, batchdata: Dict[str, torch.Tensor]
) -> Dict[str, Union[torch.Tensor, int, float, bool]]:
self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor | int | float | bool]:
"""
Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine.
Return below items in a dictionary:
Expand Down Expand Up @@ -219,8 +219,8 @@ def _iteration(

if len(batch) == 2:
inputs, targets = batch
args: Tuple = ()
kwargs: Dict = {}
args: tuple = ()
kwargs: dict = {}
else:
inputs, targets, args, kwargs = batch

Expand Down
12 changes: 4 additions & 8 deletions generative/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def get_likelihood(
progress_bar = iter(scheduler.timesteps)
intermediates = []
noise = torch.randn_like(inputs).to(inputs.device)
total_kl = torch.zeros((inputs.shape[0])).to(inputs.device)
total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
for t in progress_bar:
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
Expand Down Expand Up @@ -228,8 +228,8 @@ def _get_decoder_log_likelihood(
inputs: torch.Tensor,
means: torch.Tensor,
log_scales: torch.Tensor,
original_input_range: Optional[Tuple] = [0, 255],
scaled_input_range: Optional[Tuple] = [0, 1],
original_input_range: Optional[Tuple] = (0, 255),
scaled_input_range: Optional[Tuple] = (0, 1),
) -> torch.Tensor:
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
Expand Down Expand Up @@ -304,11 +304,7 @@ def __call__(
latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor

prediction = super().__call__(
inputs=latent,
diffusion_model=diffusion_model,
noise=noise,
timesteps=timesteps,
condition=condition,
inputs=latent, diffusion_model=diffusion_model, noise=noise, timesteps=timesteps, condition=condition
)

return prediction
Expand Down
2 changes: 1 addition & 1 deletion generative/losses/adversarial_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def forward(
if type(input) is not list:
input = [input]
target_ = []
for disc_ind, disc_out in enumerate(input):
for _, disc_out in enumerate(input):
if self.criterion != AdversarialCriterions.HINGE.value:
target_.append(self.get_target_tensor(disc_out, target_is_real))
else:
Expand Down
42 changes: 6 additions & 36 deletions generative/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@ class PerceptualLoss(nn.Module):
"""

def __init__(
self,
spatial_dims: int,
network_type: str = "alex",
is_fake_3d: bool = True,
fake_3d_ratio: float = 0.5,
self, spatial_dims: int, network_type: str = "alex", is_fake_3d: bool = True, fake_3d_ratio: float = 0.5
):
super().__init__()

Expand All @@ -58,11 +54,7 @@ def __init__(
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
else:
self.perceptual_function = LPIPS(
pretrained=True,
net=network_type,
verbose=False,
)
self.perceptual_function = LPIPS(pretrained=True, net=network_type, verbose=False)
self.is_fake_3d = is_fake_3d
self.fake_3d_ratio = fake_3d_ratio

Expand Down Expand Up @@ -90,26 +82,12 @@ def batchify_axis(x: torch.Tensor, fake_3d_perm: Tuple) -> torch.Tensor:
preserved_axes.remove(spatial_axis)

channel_axis = 1
input_slices = batchify_axis(
x=input,
fake_3d_perm=(
spatial_axis,
channel_axis,
)
+ tuple(preserved_axes),
)
input_slices = batchify_axis(x=input, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes))
indices = torch.randperm(input_slices.shape[0])[: int(input_slices.shape[0] * self.fake_3d_ratio)].to(
input_slices.device
)
input_slices = torch.index_select(input_slices, dim=0, index=indices)
target_slices = batchify_axis(
x=target,
fake_3d_perm=(
spatial_axis,
channel_axis,
)
+ tuple(preserved_axes),
)
target_slices = batchify_axis(x=target, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes))
target_slices = torch.index_select(target_slices, dim=0, index=indices)

axis_loss = torch.mean(self.perceptual_function(input_slices, target_slices))
Expand Down Expand Up @@ -150,11 +128,7 @@ class MedicalNetPerceptualSimilarity(nn.Module):
verbose: if false, mute messages from torch Hub load function.
"""

def __init__(
self,
net: str = "medicalnet_resnet10_23datasets",
verbose: bool = False,
) -> None:
def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose)
Expand Down Expand Up @@ -216,11 +190,7 @@ class RadImageNetPerceptualSimilarity(nn.Module):
verbose: if false, mute messages from torch Hub load function.
"""

def __init__(
self,
net: str = "radimagenet_resnet50",
verbose: bool = False,
) -> None:
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
super().__init__()
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose)
self.eval()
Expand Down
7 changes: 1 addition & 6 deletions generative/losses/spectral_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,7 @@ def _get_fft_amplitude(self, images: torch.Tensor) -> torch.Tensor:
Returns:
fourier transformation amplitude
"""
img_fft = fftn(
images,
s=self.fft_signal_size,
dim=self.fft_dim,
norm=self.fft_norm,
)
img_fft = fftn(images, s=self.fft_signal_size, dim=self.fft_dim, norm=self.fft_norm)

amplitude = torch.sqrt(torch.real(img_fft) ** 2 + torch.imag(img_fft) ** 2)

Expand Down
8 changes: 4 additions & 4 deletions generative/metrics/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
y_pred = y_pred.float()

if y.ndimension() > 2:
raise ValueError(f"Inputs should have (number images, number of features) shape.")
raise ValueError("Inputs should have (number images, number of features) shape.")

mu_y_pred = torch.mean(y_pred, dim=0)
sigma_y_pred = _cov(y_pred, rowvar=False)
Expand Down Expand Up @@ -114,9 +114,9 @@ def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[to
error = torch.empty(1, device=matrix.device, dtype=matrix.dtype)

for _ in range(num_iters):
T = 0.5 * (3.0 * i_matrix - z_matrix.mm(y_matrix))
y_matrix = y_matrix.mm(T)
z_matrix = T.mm(z_matrix)
t = 0.5 * (3.0 * i_matrix - z_matrix.mm(y_matrix))
y_matrix = y_matrix.mm(t)
z_matrix = t.mm(z_matrix)

s_matrix = y_matrix * torch.sqrt(norm_of_matrix)

Expand Down
3 changes: 2 additions & 1 deletion generative/metrics/mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def _compute_metric(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor

if y_pred.shape != y.shape:
raise ValueError(
f"y_pred and y shapes dont match after being processed by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}"
"y_pred and y shapes dont match after being processed "
f"by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}"
)

for d in range(len(y.shape) - 1, 1, -1):
Expand Down
8 changes: 1 addition & 7 deletions generative/metrics/ms_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,7 @@ def __init__(
self.weights = weights
self.reduction = reduction

self.SSIM = SSIMMetric(
self.data_range,
self.win_size,
self.k1,
self.k2,
self.spatial_dims,
)
self.SSIM = SSIMMetric(self.data_range, self.win_size, self.k1, self.k2, self.spatial_dims)

def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Expand Down
Loading