diff --git a/src/dim_helpers.jl b/src/dim_helpers.jl
index 14f06038e..22d5636a7 100644
--- a/src/dim_helpers.jl
+++ b/src/dim_helpers.jl
@@ -62,11 +62,15 @@ spatial dimension at the end of the spatial dimensions. This does so for a Conv
 )
 end
 
-@inline function insert_singleton_spatial_dimension(x::AbstractArray)
-    return reshape(x, size(x)[1:end-2]..., 1, size(x)[end-1:end]...)
+# We specialize common cases
+@inline function insert_singleton_spatial_dimension(x::AbstractArray{T,3}) where {T}
+    return reshape(x, size(x,1), 1, size(x,2), size(x,3))
+end
+@inline function insert_singleton_spatial_dimension(x::AbstractArray{T,4}) where {T}
+    return reshape(x, size(x,1), size(x,2), 1, size(x,3), size(x,4))
 end
 
-# Helper to do this multiple times
+# Helper to do this as many times as needed
 @inline function insert_singleton_spatial_dimension(x, reps::Int)
     for r in 1:reps
         x = insert_singleton_spatial_dimension(x)
diff --git a/src/dim_helpers/DenseConvDims.jl b/src/dim_helpers/DenseConvDims.jl
index df559e6fa..9509f5b42 100644
--- a/src/dim_helpers/DenseConvDims.jl
+++ b/src/dim_helpers/DenseConvDims.jl
@@ -62,16 +62,16 @@ end
 
 function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M}
     # First, check that channel counts are all correct:
-    @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
-    @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
-    @assert w[end-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end-1]) vs. $(channels_in(cdims)))")
-    @assert w[end] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[end]) vs. $(channels_out(cdims)))")
+    @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
+    @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
+    @assert w[M-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)))")
+    @assert w[M] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))")
 
     # Next, check that the spatial dimensions match up
-    @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
-    @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")
-    @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))")
+    @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
+    @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
+    @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")
 
     # Finally, check that the batch size matches
-    @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))")
+    @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
 end
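Note on the `insert_singleton_spatial_dimension` change above: splatting `size(x)` leaves the output rank opaque to inference, while the 3d/4d specializations encode the rank in the method signature. A minimal standalone sketch of the specialized style (hypothetical `specialized_insert` name, assuming only Base and Test):

```julia
using Test

# Stand-in for the specialized 3d method above; the output rank (4) is
# visible in the method body itself, so inference can see it.
specialized_insert(x::AbstractArray{T,3}) where {T} =
    reshape(x, size(x, 1), 1, size(x, 2), size(x, 3))

x = rand(Float32, 8, 3, 2)              # (width, channels, batch)
@test size(specialized_insert(x)) == (8, 1, 3, 2)
@inferred specialized_insert(x)         # passes: return type is Array{Float32,4}
```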
$(y[M]))") end diff --git a/src/dim_helpers/DepthwiseConvDims.jl b/src/dim_helpers/DepthwiseConvDims.jl index a0555ff58..4c25eea6f 100644 --- a/src/dim_helpers/DepthwiseConvDims.jl +++ b/src/dim_helpers/DepthwiseConvDims.jl @@ -7,19 +7,16 @@ Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from DenseConvDims primarily for channel calculation differences. """ -struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F} +struct DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F} <: ConvDims{N,S,P,D,F} I::NTuple{N, Int} - K::NTuple{N, Int} - C_in::Int - C_mult::Int end # Getters for the fields input_size(c::DepthwiseConvDims) = c.I -kernel_size(c::DepthwiseConvDims) = c.K -channels_in(c::DepthwiseConvDims) = c.C_in -channels_out(c::DepthwiseConvDims) = c.C_in * channel_multiplier(c) -channel_multiplier(c::DepthwiseConvDims) = c.C_mult +kernel_size(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = K +channels_in(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_in +channels_out(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_in * C_mult +channel_multiplier(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_mult # Convenience wrapper to create DepthwiseConvDims objects @@ -37,6 +34,12 @@ function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M}; return DepthwiseConvDims{ M - 2, + # Kernel spatial size + w_size[1:end-2], + # Input channels + x_size[end-1], + # Channel multiplier + w_size[end-1], stride, padding, dilation, @@ -44,15 +47,6 @@ function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M}; }( # Image spatial size x_size[1:end-2], - - # Kernel spatial size - w_size[1:end-2], - - # Input channels - x_size[end-1], - - # Channel multiplier - w_size[end-1], ) end @@ -69,22 +63,22 @@ end function DepthwiseConvDims(c::DepthwiseConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c), C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c), P=padding(c), D=dilation(c), F=flipkernel(c)) - return DepthwiseConvDims{N, S, P, D, F}(I, K, C_in, C_m) + return DepthwiseConvDims{N, K, C_in, C_m, S, P, D, F}(I) end # This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M} # First, check that channel counts are all correct: - @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))") - @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))") - @assert w[end-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[end-1]) vs. $(channel_multiplier(cdims))") - @assert w[end] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end]) vs. $(channels_in(cdims)))") + @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))") + @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))") + @assert w[M-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[M-1]) vs. 
$(channel_multiplier(cdims))") + @assert w[M] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M]) vs. $(channels_in(cdims)))") # Next, check that the spatial dimensions match up - @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))") - @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))") - @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))") + @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))") + @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))") + @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))") # Finally, check that the batch size matches - @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))") + @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))") end \ No newline at end of file diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index 0aacdec8f..5f5b7c4c3 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -1,4 +1,5 @@ ## This file contains direct Julia implementations of 2d and 3d convolutions +using Base.Threads # Helper functions for restricting x/w overreach function clamp_lo(x, w) @@ -57,50 +58,87 @@ function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, stride_w, stride_h, stride_d = stride(cdims) out_width, out_height, out_depth = output_size(cdims) - # If we're doing crosscorr instead of conv, then don't bother to flip `w` - if !flipkernel(cdims) - w = w[end:-1:1, end:-1:1, end:-1:1, :, :] - end - + # Create a method that, at compile-time, determines how we're going to index into `w` + kproj(k, M, cdims::ConvDims{N,S,P,D,true}) where {N, S, P, D} = k + kproj(k, M, cdims::ConvDims{N,S,P,D,false}) where {N, S, P, D} = M - k + 1 + # A helper function to project from output (w, h) to input (input_w, input_h) - @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 + project(idx, stride, pad) = (idx - 1)*stride - pad + 1 - # explicit formulation of convolution. Oh hoisting gods, hear my plea. - @inbounds for batch in 1:size(x)[end], + # Use `calc_padding_regions` to determine where we do or don't need to worry about padding + padded_regions, central_region = calc_padding_regions(cdims) + + # Start with the central region + w_region, h_region, d_region = central_region + @inbounds for batch in 1:size(x, 5), + c_out in 1:out_c, + d_idx in d_region, + h_idx in h_region, + w_idx in w_region + + # Since we're in the central region, we don't need to worry about clamping + dotprod = yT(0) + for c_in in 1:channels_in(cdims), + kd in 1:kernel_d, + kh in 1:kernel_h, + kw in 1:kernel_w + + # Hoist me, you coward. 
+ x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d + x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h + x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w + + x_val = x[x_w, x_h, x_d, c_in, batch] + w_val = w[kproj(kw, kernel_w, cdims), + kproj(kh, kernel_h, cdims), + kproj(kd, kernel_d, cdims), + c_in, c_out] + dotprod = muladd(x_val, w_val, dotprod) + end + y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] + end + + # Next, do potentially-padded regions: + @inbounds for (w_region, h_region, d_region) in padded_regions, + batch in 1:size(x, 5), c_out in 1:out_c, - d_idx in 1:out_depth, - h_idx in 1:out_height, - w_idx in 1:out_width - - # Starting points of the window of x we're going to grab - x_w = project(w_idx, stride_w, pad_w_lo) - x_h = project(h_idx, stride_h, pad_h_lo) - x_d = project(d_idx, stride_d, pad_d_lo) - - # Grow that starting point into ranges - x_widxs = x_w .+ (0:dil_w:(dil_w*kernel_w-1)) - x_hidxs = x_h .+ (0:dil_h:(dil_h*kernel_h-1)) - x_didxs = x_d .+ (0:dil_d:(dil_d*kernel_d-1)) - w_widxs = 1:kernel_w - w_hidxs = 1:kernel_h - w_didxs = 1:kernel_d - - # Clamp the ranges to simulate padding - x_widxs, w_widxs = clamp_lo(x_widxs, w_widxs) - x_widxs, w_widxs = clamp_hi(x_widxs, w_widxs, width) - x_hidxs, w_hidxs = clamp_lo(x_hidxs, w_hidxs) - x_hidxs, w_hidxs = clamp_hi(x_hidxs, w_hidxs, height) - x_didxs, w_didxs = clamp_lo(x_didxs, w_didxs) - x_didxs, w_didxs = clamp_hi(x_didxs, w_didxs, depth) - - # Grab our slices - x_slice = view(x, x_widxs, x_hidxs, x_didxs, :, batch) - w_slice = view(w, w_widxs, w_hidxs, w_didxs, :, c_out) - - # Do the dotproduct dance, then weight by alpha/beta and git 'er done - dotprod = sum(x_slice .* w_slice) - y[w_idx, h_idx, d_idx, c_out, batch] = alpha*convert(yT, dotprod) + - beta*y[w_idx, h_idx, d_idx, c_out, batch] + d_idx in d_region, + h_idx in h_region, + w_idx in w_region + + # Probe for out-of-bounds accesses on `x` and `continue` if we hit one + dotprod = yT(0) + for c_in in 1:channels_in(cdims), + kd in 1:kernel_d + + x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d + if x_d <= 0 || x_d > depth + continue + end + + for kh in 1:kernel_h + x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h + if x_h <= 0 || x_h > height + continue + end + + for kw in 1:kernel_w + x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w + if x_w <= 0 || x_w > width + continue + end + + x_val = x[x_w, x_h, x_d, c_in, batch] + w_val = w[kproj(kw, kernel_w, cdims), + kproj(kh, kernel_h, cdims), + kproj(kd, kernel_d, cdims), + c_in, c_out] + dotprod = muladd(x_val, w_val, dotprod) + end + end + end + + y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] end return y diff --git a/src/impl/conv_im2col.jl b/src/impl/conv_im2col.jl index a932df8c1..17e6aaa61 100644 --- a/src/impl/conv_im2col.jl +++ b/src/impl/conv_im2col.jl @@ -46,7 +46,7 @@ function conv_im2col!( N = channels_out(cdims) K = prod(kernel_size(cdims))*channels_in(cdims) - @inbounds for batch_idx in 1:size(x,5) + @threads for batch_idx in 1:size(x,5) # We invoke `@timeit_debug` on the outside of `im2col!()` because inference # doesn't like us putting it on the inside. 
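The `kproj` closures above replace the eager `w[end:-1:1, ...]` flipped copy: dispatching on the `F` (flipkernel) type parameter selects the index mapping at compile time. A self-contained sketch of the same trick, with a toy `Flip{F}` type standing in for `ConvDims`:

```julia
# The Bool type parameter picks the method during specialization, so there is
# no per-element branch and no flipped copy of the kernel is materialized.
struct Flip{F} end
kproj(k, M, ::Flip{true})  = k           # cross-correlation: index as-is
kproj(k, M, ::Flip{false}) = M - k + 1   # convolution: walk the kernel backwards

@assert kproj(1, 3, Flip{false}()) == 3  # first tap reads the last kernel element
@assert kproj(3, 3, Flip{false}()) == 1
@assert kproj(2, 3, Flip{true}())  == 2
```

Splitting the output into a central region plus padded border regions then lets the central loop skip all bounds probing, while only the (usually thin) border loops pay for the `continue` checks.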
diff --git a/src/impl/conv_im2col.jl b/src/impl/conv_im2col.jl
index a932df8c1..17e6aaa61 100644
--- a/src/impl/conv_im2col.jl
+++ b/src/impl/conv_im2col.jl
@@ -46,7 +46,7 @@ function conv_im2col!(
     N = channels_out(cdims)
     K = prod(kernel_size(cdims))*channels_in(cdims)
 
-    @inbounds for batch_idx in 1:size(x,5)
+    @threads for batch_idx in 1:size(x,5)
         # We invoke `@timeit_debug` on the outside of `im2col!()` because inference
         # doesn't like us putting it on the inside.
         im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
@@ -94,7 +94,7 @@ function ∇conv_filter_im2col!(
     N = channels_out(cdims)
     K = prod(output_size(cdims))
 
-    @inbounds for batch_idx in 1:size(x,5)
+    @threads for batch_idx in 1:size(x,5)
         im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
         GC.@preserve col, dw, dy, begin
             col_ptr = pointer(col)
@@ -142,7 +142,7 @@ function ∇conv_data_im2col!(
     N = prod(kernel_size(cdims))*channels_in(cdims)
     K = channels_out(cdims)
 
-    @inbounds for batch_idx in 1:size(dx, 5)
+    @threads for batch_idx in 1:size(dx, 5)
         GC.@preserve col, w, dy, begin
             dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
             w_ptr = pointer(w)
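`@threads` comes into scope here via the `using Base.Threads` added in `conv_direct.jl`, since both files are included into the same module. One caveat with batch-level threading: any scratch buffer shared across iterations (such as `col` above) must be safe to write concurrently, typically by giving each iteration its own slice. A hedged, standalone sketch of the pattern:

```julia
using Base.Threads

# Each iteration writes only its own column of `out`, so the loop is race-free;
# a shared scratch buffer would instead need one slice per batch element/task.
function threaded_colsums!(out::Vector{Float64}, x::Matrix{Float64})
    @threads for b in 1:size(x, 2)
        out[b] = sum(@view x[:, b])
    end
    return out
end

x = rand(100, 8)
@assert threaded_colsums!(zeros(8), x) ≈ vec(sum(x, dims=1))
```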
diff --git a/src/impl/depthwiseconv_direct.jl b/src/impl/depthwiseconv_direct.jl
index 5ebbb99c1..7e2e02bd5 100644
--- a/src/impl/depthwiseconv_direct.jl
+++ b/src/impl/depthwiseconv_direct.jl
@@ -18,10 +18,9 @@ channels in `x` is the last, not the second-to-last, as in a normal dense convolution.
 
 See the docstring for `conv_direct!()` for more on the optional parameters.
 """
-function depthwiseconv_direct!(
-                y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
-                w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;
-                alpha::yT = yT(1), beta::yT = yT(0)) where {yT, xT, wT}
+function depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
+                               w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;
+                               alpha::yT = yT(1), beta = false) where {yT, xT, wT}
     check_dims(size(x), size(w), size(y), cdims)
 
     width, height, depth = input_size(cdims)
@@ -32,54 +31,91 @@ function depthwiseconv_direct!(
     stride_w, stride_h, stride_d = stride(cdims)
     out_width, out_height, out_depth = output_size(cdims)
 
-    # If we're doing crosscorr instead of conv, then don't bother to flip `w`
-    if !flipkernel(cdims)
-        w = w[end:-1:1, end:-1:1, end:-1:1, :, :]
-    end
-
+    # Create a method that, at compile-time, determines how we're going to index into `w`
+    kproj(k, M, cdims::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,true}) where {N, K, C_in, C_mult, S, P, D} = k
+    kproj(k, M, cdims::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,false}) where {N, K, C_in, C_mult, S, P, D} = M - k + 1
+
     # A helper function to project from output (w, h) to input (input_w, input_h)
-    @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1
+    project(idx, stride, pad) = (idx - 1)*stride - pad + 1
 
-    # explicit formulation of convolution.  Oh hoisting gods, hear my plea.
+    # Use `calc_padding_regions` to determine where we do or don't need to worry about padding
+    padded_regions, central_region = calc_padding_regions(cdims)
+
+    # Start with the central region
+    w_region, h_region, d_region = central_region
     @inbounds for batch in 1:size(x)[end],
         c_mult in 1:channel_multiplier(cdims),
         c_in in 1:channels_in(cdims),
-        h_idx in 1:out_height,
-        w_idx in 1:out_width,
-        d_idx in 1:out_depth
-
-        # Starting points of the window of x we're going to grab
-        x_w = project(w_idx, stride_w, pad_w_lo)
-        x_h = project(h_idx, stride_h, pad_h_lo)
-        x_d = project(d_idx, stride_d, pad_d_lo)
-
-        # Grow that starting point into ranges
-        x_widxs = x_w .+ (0:dil_w:(dil_w*kernel_w-1))
-        x_hidxs = x_h .+ (0:dil_h:(dil_h*kernel_h-1))
-        x_didxs = x_d .+ (0:dil_d:(dil_d*kernel_d-1))
-        w_widxs = 1:kernel_w
-        w_hidxs = 1:kernel_h
-        w_didxs = 1:kernel_d
-
-        # Clamp the ranges to simulate padding
-        x_widxs, w_widxs = clamp_lo(x_widxs, w_widxs)
-        x_widxs, w_widxs = clamp_hi(x_widxs, w_widxs, width)
-        x_hidxs, w_hidxs = clamp_lo(x_hidxs, w_hidxs)
-        x_hidxs, w_hidxs = clamp_hi(x_hidxs, w_hidxs, height)
-        x_didxs, w_didxs = clamp_lo(x_didxs, w_didxs)
-        x_didxs, w_didxs = clamp_hi(x_didxs, w_didxs, depth)
-
-        # Grab our slices (for a single channel pairing, as this is depthwise)
+        d_idx in d_region,
+        h_idx in h_region,
+        w_idx in w_region
+
+        # Since we're in the central region, we don't need to worry about clamping
+        dotprod = yT(0)
         c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult
-        x_slice = view(x, x_widxs, x_hidxs, x_didxs, c_in, batch)
-        w_slice = view(w, w_widxs, w_hidxs, w_didxs, c_mult, c_in)
-
-        # Do the dotproduct dance, then weight by alpha/beta and git 'er done
-        dotprod = sum(x_slice .* w_slice)
-        prev_yval::yT = beta*y[w_idx, h_idx, d_idx, c_out, batch]
-        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*convert(yT, dotprod) + prev_yval
+        for kd in 1:kernel_d,
+            kh in 1:kernel_h,
+            kw in 1:kernel_w
+
+            # Hoist me, you coward.
+            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
+            x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
+            x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
+
+            x_val = x[x_w, x_h, x_d, c_in, batch]
+            w_val = w[kproj(kw, kernel_w, cdims),
+                      kproj(kh, kernel_h, cdims),
+                      kproj(kd, kernel_d, cdims),
+                      c_mult, c_in]
+            dotprod = muladd(x_val, w_val, dotprod)
+        end
+        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
     end
-
+
+    # Next, do potentially-padded regions:
+    @inbounds for (w_region, h_region, d_region) in padded_regions,
+        batch in 1:size(x)[end],
+        c_mult in 1:channel_multiplier(cdims),
+        c_in in 1:channels_in(cdims),
+        d_idx in d_region,
+        h_idx in h_region,
+        w_idx in w_region
+
+        # Probe for out-of-bounds accesses on `x` and `continue` if we hit one
+        dotprod = yT(0)
+        c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult
+        for kd in 1:kernel_d
+            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
+            if x_d <= 0 || x_d > depth
+                continue
+            end
+
+            for kh in 1:kernel_h
+                x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
+                if x_h <= 0 || x_h > height
+                    continue
+                end
+
+                for kw in 1:kernel_w
+                    x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
+                    if x_w <= 0 || x_w > width
+                        continue
+                    end
+
+                    x_val = x[x_w, x_h, x_d, c_in, batch]
+                    w_val = w[kproj(kw, kernel_w, cdims),
+                              kproj(kh, kernel_h, cdims),
+                              kproj(kd, kernel_d, cdims),
+                              c_mult, c_in]
+                    dotprod = muladd(x_val, w_val, dotprod)
+                end
+            end
+        end
+
+        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
+    end
+
+    return y
 end
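For reference, the channel bookkeeping assumed in both loops above: each input channel `c_in` fans out to `C_mult` consecutive output channels, and the padded-region loop reuses the outer `c_in` rather than re-looping over input channels, since a depthwise output channel sees exactly one input channel. A small sanity check of the layout:

```julia
# Depthwise channel layout: output channel (c_in - 1)*C_mult + c_mult.
C_in, C_mult = 3, 2
c_outs = [(c_in - 1)*C_mult + c_mult for c_in in 1:C_in for c_mult in 1:C_mult]
@assert sort(c_outs) == collect(1:C_in*C_mult)   # every output channel hit exactly once
```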
diff --git a/src/impl/padding_edges.jl b/src/impl/padding_edges.jl
index 1d436ea56..df70645b7 100644
--- a/src/impl/padding_edges.jl
+++ b/src/impl/padding_edges.jl
@@ -26,10 +26,10 @@ function calc_padding_regions(dims)
     # spillage is slightly more complicated; we first figure out how many elements of
     # high padding are wasted (e.g. through strides not fitting to the end perfectly)
     # subtract that from the high padding, then do the same:
-    calc_lo_spill(O, S, P) = min(ceil(Int, P/S), O)
+    calc_lo_spill(O, S, P) = max(min(ceil(Int, P/S), O), 0)
     @inline function calc_hi_spill(O, S, Pl, Ph, K, D, I)
         wasted_Ph = (I + Pl + Ph - (K - 1)*D - 1)%S
-        return min(ceil(Int, (Ph - wasted_Ph)/S), O)
+        return max(min(ceil(Int, (Ph - wasted_Ph)/S), O), 0)
     end
 
     spill_w_lo = calc_lo_spill(out_width, stride_w, pad_w_lo)
diff --git a/test/pooling.jl b/test/pooling.jl
index 1be1e094d..f7ada801c 100644
--- a/test/pooling.jl
+++ b/test/pooling.jl
@@ -300,9 +300,9 @@ end
 x = rand(10, 10, 3, 10)
 
 @test size(maxpool(x, (2, 2))) == (5, 5, 3, 10)
-@test size(maxpool(x, (2, 2); pad = (2, 2), stride = (2, 2))) == (7, 7, 3, 10)
+@test size(maxpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
 @test size(meanpool(x, (2, 2))) == (5, 5, 3, 10)
-@test size(meanpool(x, (2, 2); pad = (2, 2), stride = (2, 2))) == (7, 7, 3, 10)
+@test size(meanpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
 
 # Add another test for 2d maxpool that uses an odd-length size:
 @testset "Issue #133" begin
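The updated pooling expectations follow from the usual output-size formula: with input `I = 10`, window `K = 2`, and stride `S = 2`, one element of padding per side yields 6, where the old `pad = (2, 2)` had yielded 7. A quick check of the arithmetic, assuming symmetric padding:

```julia
# out = fld(I + 2P - K, S) + 1 for input I, window K, stride S, pad P per side.
pool_out(I, K, S, P) = fld(I + 2P - K, S) + 1

@assert pool_out(10, 2, 2, 1) == 6   # new expectation: pad = (1, 1)
@assert pool_out(10, 2, 2, 2) == 7   # old expectation: pad = (2, 2)
@assert pool_out(10, 2, 2, 0) == 5   # unpadded baseline
```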