diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 760933bb96..ee19640d76 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -295,7 +295,7 @@ end function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T} b, o = m.b, size(h, 1) - g = m.Wi*x .+ m.Wh*h .+ b + g = muladd(m.Wi, x, muladd(m.Wh, h, b)) input, forget, cell, output = multigate(g, o, Val(4)) c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell) h′ = @. sigmoid_fast(output) * tanh_fast(c′)