From 33fef0a788f37ca0884bd4457972c330b5ebd9bb Mon Sep 17 00:00:00 2001 From: junliang-lin Date: Tue, 10 Jan 2023 14:32:06 -0500 Subject: [PATCH] Add support for output padding in flipout layers --- bayesian_torch/layers/flipout_layers/conv_flipout.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/bayesian_torch/layers/flipout_layers/conv_flipout.py b/bayesian_torch/layers/flipout_layers/conv_flipout.py index 311462a..c92d24b 100644 --- a/bayesian_torch/layers/flipout_layers/conv_flipout.py +++ b/bayesian_torch/layers/flipout_layers/conv_flipout.py @@ -557,6 +557,7 @@ def __init__(self, padding=0, dilation=1, groups=1, + output_padding=0, prior_mean=0, prior_variance=1, posterior_mu_init=0, @@ -588,6 +589,7 @@ def __init__(self, self.kernel_size = kernel_size self.stride = stride self.padding = padding + self.output_padding = output_padding self.dilation = dilation self.groups = groups self.bias = bias @@ -669,6 +671,7 @@ def forward(self, x, return_kl=True): bias=self.mu_bias, stride=self.stride, padding=self.padding, + output_padding=self.output_padding, dilation=self.dilation, groups=self.groups) @@ -702,6 +705,7 @@ def forward(self, x, return_kl=True): bias=bias, stride=self.stride, padding=self.padding, + output_padding=self.output_padding, dilation=self.dilation, groups=self.groups) * sign_output @@ -719,6 +723,7 @@ def __init__(self, kernel_size, stride=1, padding=0, + output_padding=0, dilation=1, groups=1, prior_mean=0, @@ -752,6 +757,7 @@ def __init__(self, self.kernel_size = kernel_size self.stride = stride self.padding = padding + self.output_padding = output_padding self.dilation = dilation self.groups = groups self.bias = bias @@ -837,6 +843,7 @@ def forward(self, x, return_kl=True): weight=self.mu_kernel, stride=self.stride, padding=self.padding, + output_padding=self.output_padding, dilation=self.dilation, groups=self.groups) @@ -870,6 +877,7 @@ def forward(self, x, return_kl=True): weight=delta_kernel, stride=self.stride, padding=self.padding, + output_padding=self.output_padding, dilation=self.dilation, groups=self.groups) * sign_output @@ -887,6 +895,7 @@ def __init__(self, kernel_size, stride=1, padding=0, + output_padding=0, dilation=1, groups=1, prior_mean=0, @@ -920,6 +929,7 @@ def __init__(self, self.kernel_size = kernel_size self.stride = stride self.padding = padding + self.output_padding = output_padding self.dilation = dilation self.groups = groups @@ -1005,6 +1015,7 @@ def forward(self, x, return_kl=True): bias=self.mu_bias, stride=self.stride, padding=self.padding, + output_padding=self.output_padding, dilation=self.dilation, groups=self.groups) @@ -1037,6 +1048,7 @@ def forward(self, x, return_kl=True): bias=bias, stride=self.stride, padding=self.padding, + output_padding=self.output_padding, dilation=self.dilation, groups=self.groups) * sign_output