73 changes: 73 additions & 0 deletions example/merge_networks.f90
@@ -0,0 +1,73 @@
program merge_networks
use nf, only: dense, input, network, sgd
use nf_dense_layer, only: dense_layer
implicit none

type(network) :: net1, net2, net3
real, allocatable :: x1(:), x2(:)
real, pointer :: y1(:), y2(:)
real, allocatable :: y(:)
integer, parameter :: num_iterations = 500
integer :: n
integer :: net1_output_size, net2_output_size

x1 = [0.1, 0.3, 0.5]
x2 = [0.2, 0.4]
y = [0.123456, 0.246802, 0.369258, 0.482604, 0.505050, 0.628406, 0.741852]

net1 = network([ &
input(3), &
dense(2), &
dense(3), &
dense(2) &
])

net2 = network([ &
input(2), &
dense(5), &
dense(3) &
])

net1_output_size = product(net1 % layers(size(net1 % layers)) % layer_shape)
net2_output_size = product(net2 % layers(size(net2 % layers)) % layer_shape)

! Network 3
net3 = network([ &
input(net1_output_size + net2_output_size), &
dense(7) &
])

do n = 1, num_iterations

! Forward propagate two network branches
call net1 % forward(x1)
call net2 % forward(x2)

! Get outputs of net1 and net2, concatenate, and pass to net3
call net1 % get_output(y1)
call net2 % get_output(y2)
call net3 % forward([y1, y2])

! First compute the gradients on net3. The gradient stored in net3's first layer
! after the input is with respect to that layer's input, i.e. the concatenated
! [y1, y2]: its first net1_output_size elements belong to net1 and the rest to
! net2, so slice it and pass each piece to the corresponding branch. (The output
! argument y is not used by backward when gradient is passed.)
call net3 % backward(y)

select type (next_layer => net3 % layers(2) % p)
type is (dense_layer)
call net1 % backward(y, gradient=next_layer % gradient(1:net1_output_size))
call net2 % backward(y, gradient=next_layer % gradient(net1_output_size+1:size(next_layer % gradient)))
end select

! Gradients are now computed on all networks and we can update the weights
call net1 % update(optimizer=sgd(learning_rate=1.))
call net2 % update(optimizer=sgd(learning_rate=1.))
call net3 % update(optimizer=sgd(learning_rate=1.))

if (mod(n, 50) == 0) then
print *, "Iteration ", n, ", output RMSE = ", &
sqrt(sum((net3 % predict([net1 % predict(x1), net2 % predict(x2)]) - y)**2) / size(y))
end if

end do

end program merge_networks
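For inference only, the same merge needs no manual gradient handling. Below is a minimal sketch of a hypothetical helper (not part of this change); it assumes `network` is in scope from the `nf` module and that all three networks have already been trained:

function merged_predict(net1, net2, net3, x1, x2) result(res)
  ! Run both branches forward and feed the concatenated outputs to net3.
  type(network), intent(in out) :: net1, net2, net3
  real, intent(in) :: x1(:), x2(:)
  real, allocatable :: res(:)
  res = net3 % predict([net1 % predict(x1), net2 % predict(x2)])
end function merged_predict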
26 changes: 21 additions & 5 deletions src/nf/nf_network.f90
@@ -33,6 +33,7 @@ module nf_network
procedure, private :: forward_1d_int
procedure, private :: forward_2d
procedure, private :: forward_3d
procedure, private :: get_output_1d
procedure, private :: predict_1d
procedure, private :: predict_1d_int
procedure, private :: predict_2d
@@ -42,6 +43,7 @@

generic :: evaluate => evaluate_batch_1d
generic :: forward => forward_1d, forward_1d_int, forward_2d, forward_3d
generic :: get_output => get_output_1d
generic :: predict => predict_1d, predict_1d_int, predict_2d, predict_3d
generic :: predict_batch => predict_batch_1d, predict_batch_3d

@@ -131,7 +133,7 @@ end subroutine forward_3d

end interface forward

interface output
interface predict

module function predict_1d(self, input) result(res)
!! Return the output of the network given the input 1-d array.
@@ -169,9 +171,10 @@ module function predict_3d(self, input) result(res)
real, allocatable :: res(:)
!! Output of the network
end function predict_3d
end interface output

interface output_batch
end interface predict

interface predict_batch
module function predict_batch_1d(self, input) result(res)
!! Return the output of the network given an input batch of 1-d data.
class(network), intent(in out) :: self
@@ -191,11 +194,18 @@ module function predict_batch_3d(self, input) result(res)
real, allocatable :: res(:,:)
!! Output of the network; the last dimension is the batch
end function predict_batch_3d
end interface output_batch
end interface predict_batch

interface get_output
module subroutine get_output_1d(self, output)
class(network), intent(in), target :: self
real, pointer, intent(out) :: output(:)
end subroutine get_output_1d
end interface get_output

interface

module subroutine backward(self, output, loss)
module subroutine backward(self, output, loss, gradient)
!! Apply one backward pass through the network.
!! This changes the state of layers on the network.
!! Typically used only internally from the `train` method,
@@ -206,6 +216,12 @@ module subroutine backward(self, output, loss)
!! Output data
class(loss_type), intent(in), optional :: loss
!! Loss instance to use. If not provided, the default is quadratic().
real, intent(in), optional :: gradient(:)
!! Gradient to use for the output layer.
!! If not provided, the gradient in the last layer is computed using
!! the loss function.
!! Passing the gradient is useful for merging/concatenating multiple
!! networks.
end subroutine backward

module integer function get_num_params(self)
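A distilled sketch of the new optional `gradient` argument (hypothetical program, not part of this change; the variable names are illustrative). When `gradient` is supplied, the output-layer gradient is taken directly from it and the `output` argument is not used to evaluate the loss derivative:

program backward_gradient_sketch
  use nf, only: dense, input, network, sgd
  implicit none
  type(network) :: net
  real :: x(4), y(3), external_grad(3)

  net = network([input(4), dense(3)])
  x = [0.1, 0.2, 0.3, 0.4]
  y = 0.                              ! placeholder target; ignored when gradient is passed
  external_grad = [0.1, -0.2, 0.05]   ! e.g. a slice of a downstream network's input gradient

  call net % forward(x)
  call net % backward(y, gradient=external_grad)
  call net % update(optimizer=sgd(learning_rate=1.))
end program backward_gradient_sketch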
140 changes: 88 additions & 52 deletions src/nf/nf_network_submodule.f90
@@ -115,10 +115,11 @@ module function network_from_layers(layers) result(res)
end function network_from_layers


module subroutine backward(self, output, loss)
module subroutine backward(self, output, loss, gradient)
class(network), intent(in out) :: self
real, intent(in) :: output(:)
class(loss_type), intent(in), optional :: loss
real, intent(in), optional :: gradient(:)
integer :: n, num_layers

! Passing the loss instance is optional. If not provided, and if the
@@ -140,58 +141,71 @@

! Iterate backward over layers, from the output layer
! to the first non-input layer
do n = num_layers, 2, -1

if (n == num_layers) then
! Output layer; apply the loss function
select type(this_layer => self % layers(n) % p)
type is(dense_layer)
call self % layers(n) % backward( &
self % layers(n - 1), &
self % loss % derivative(output, this_layer % output) &
)
type is(flatten_layer)
call self % layers(n) % backward( &
self % layers(n - 1), &
self % loss % derivative(output, this_layer % output) &
)
end select
else
! Hidden layer; take the gradient from the next layer
select type(next_layer => self % layers(n + 1) % p)
type is(dense_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(dropout_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(conv2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(flatten_layer)
if (size(self % layers(n) % layer_shape) == 2) then
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
else
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
end if
type is(maxpool2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(reshape3d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(linear2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(self_attention_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(maxpool1d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(reshape2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(conv1d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(locally_connected2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(layernorm_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
end select
end if

! Output layer first
n = num_layers
if (present(gradient)) then

! If the gradient is passed, use it directly for the output layer
select type(this_layer => self % layers(n) % p)
type is(dense_layer)
call self % layers(n) % backward(self % layers(n - 1), gradient)
type is(flatten_layer)
call self % layers(n) % backward(self % layers(n - 1), gradient)
end select

else

! Apply the loss function
select type(this_layer => self % layers(n) % p)
type is(dense_layer)
call self % layers(n) % backward( &
self % layers(n - 1), &
self % loss % derivative(output, this_layer % output) &
)
type is(flatten_layer)
call self % layers(n) % backward( &
self % layers(n - 1), &
self % loss % derivative(output, this_layer % output) &
)
end select

end if

! Hidden layers; take the gradient from the next layer
do n = num_layers - 1, 2, -1
select type(next_layer => self % layers(n + 1) % p)
type is(dense_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(dropout_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(conv2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(flatten_layer)
if (size(self % layers(n) % layer_shape) == 2) then
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
else
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
end if
type is(maxpool2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(reshape3d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(linear2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(self_attention_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(maxpool1d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(reshape2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(conv1d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(locally_connected2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
type is(layernorm_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
end select
end do

end subroutine backward
@@ -497,6 +511,28 @@ module subroutine print_info(self)
end subroutine print_info


module subroutine get_output_1d(self, output)
class(network), intent(in), target :: self
real, pointer, intent(out) :: output(:)
integer :: last

last = size(self % layers)

select type(output_layer => self % layers(last) % p)
type is(dense_layer)
output => output_layer % output
type is(dropout_layer)
output => output_layer % output
type is(flatten_layer)
output => output_layer % output
class default
error stop 'network % get_output not implemented for ' // &
trim(self % layers(last) % name) // ' layer'
end select

end subroutine get_output_1d


module function get_num_params(self)
class(network), intent(in) :: self
integer :: get_num_params
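Since `get_output` associates a pointer with the last layer's output array rather than copying it, the association remains valid across subsequent forward passes. A minimal sketch of that behavior (hypothetical program, assuming the output array is allocated once at network construction and not reallocated):

program get_output_sketch
  use nf, only: dense, input, network
  implicit none
  type(network) :: net
  real, pointer :: out(:)
  real :: x(2)

  net = network([input(2), dense(4)])

  x = [0.5, 0.5]
  call net % forward(x)
  call net % get_output(out)   ! out points at the dense layer's output; no copy is made

  x = [1.0, 0.0]
  call net % forward(x)        ! out now reflects the new forward pass without re-association
end program get_output_sketch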