diff --git a/example/merge_networks.f90 b/example/merge_networks.f90 new file mode 100644 index 00000000..590deb26 --- /dev/null +++ b/example/merge_networks.f90 @@ -0,0 +1,73 @@ +program merge_networks + use nf, only: dense, input, network, sgd + use nf_dense_layer, only: dense_layer + implicit none + + type(network) :: net1, net2, net3 + real, allocatable :: x1(:), x2(:) + real, pointer :: y1(:), y2(:) + real, allocatable :: y(:) + integer, parameter :: num_iterations = 500 + integer :: n, nn + integer :: net1_output_size, net2_output_size + + x1 = [0.1, 0.3, 0.5] + x2 = [0.2, 0.4] + y = [0.123456, 0.246802, 0.369258, 0.482604, 0.505050, 0.628406, 0.741852] + + net1 = network([ & + input(3), & + dense(2), & + dense(3), & + dense(2) & + ]) + + net2 = network([ & + input(2), & + dense(5), & + dense(3) & + ]) + + net1_output_size = product(net1 % layers(size(net1 % layers)) % layer_shape) + net2_output_size = product(net2 % layers(size(net2 % layers)) % layer_shape) + + ! Network 3 + net3 = network([ & + input(net1_output_size + net2_output_size), & + dense(7) & + ]) + + do n = 1, num_iterations + + ! Forward propagate two network branches + call net1 % forward(x1) + call net2 % forward(x2) + + ! Get outputs of net1 and net2, concatenate, and pass to net3 + call net1 % get_output(y1) + call net2 % get_output(y2) + call net3 % forward([y1, y2]) + + ! First compute the gradients on net3, then pass the gradients from the first + ! hidden layer on net3 to net1 and net2, and compute their gradients. + call net3 % backward(y) + + select type (next_layer => net3 % layers(2) % p) + type is (dense_layer) + call net1 % backward(y, gradient=next_layer % gradient(1:net1_output_size)) + call net2 % backward(y, gradient=next_layer % gradient(net1_output_size+1:size(next_layer % gradient))) + end select + + ! Gradients are now computed on all networks and we can update the weights + call net1 % update(optimizer=sgd(learning_rate=1.)) + call net2 % update(optimizer=sgd(learning_rate=1.)) + call net3 % update(optimizer=sgd(learning_rate=1.)) + + if (mod(n, 50) == 0) then + print *, "Iteration ", n, ", output RMSE = ", & + sqrt(sum((net3 % predict([net1 % predict(x1), net2 % predict(x2)]) - y)**2) / size(y)) + end if + + end do + +end program merge_networks \ No newline at end of file diff --git a/src/nf/nf_network.f90 b/src/nf/nf_network.f90 index 2743ff5b..3cfeb521 100644 --- a/src/nf/nf_network.f90 +++ b/src/nf/nf_network.f90 @@ -33,6 +33,7 @@ module nf_network procedure, private :: forward_1d_int procedure, private :: forward_2d procedure, private :: forward_3d + procedure, private :: get_output_1d procedure, private :: predict_1d procedure, private :: predict_1d_int procedure, private :: predict_2d @@ -42,6 +43,7 @@ module nf_network generic :: evaluate => evaluate_batch_1d generic :: forward => forward_1d, forward_1d_int, forward_2d, forward_3d + generic :: get_output => get_output_1d generic :: predict => predict_1d, predict_1d_int, predict_2d, predict_3d generic :: predict_batch => predict_batch_1d, predict_batch_3d @@ -131,7 +133,7 @@ end subroutine forward_3d end interface forward - interface output + interface predict module function predict_1d(self, input) result(res) !! Return the output of the network given the input 1-d array. @@ -169,9 +171,10 @@ module function predict_3d(self, input) result(res) real, allocatable :: res(:) !! Output of the network end function predict_3d - end interface output - interface output_batch + end interface predict + + interface predict_batch module function predict_batch_1d(self, input) result(res) !! Return the output of the network given an input batch of 3-d data. class(network), intent(in out) :: self @@ -191,11 +194,18 @@ module function predict_batch_3d(self, input) result(res) real, allocatable :: res(:,:) !! Output of the network; the last dimension is the batch end function predict_batch_3d - end interface output_batch + end interface predict_batch + + interface get_output + module subroutine get_output_1d(self, output) + class(network), intent(in), target :: self + real, pointer, intent(out) :: output(:) + end subroutine get_output_1d + end interface get_output interface - module subroutine backward(self, output, loss) + module subroutine backward(self, output, loss, gradient) !! Apply one backward pass through the network. !! This changes the state of layers on the network. !! Typically used only internally from the `train` method, @@ -206,6 +216,12 @@ module subroutine backward(self, output, loss) !! Output data class(loss_type), intent(in), optional :: loss !! Loss instance to use. If not provided, the default is quadratic(). + real, intent(in), optional :: gradient(:) + !! Gradient to use for the output layer. + !! If not provided, the gradient in the last layer is computed using + !! the loss function. + !! Passing the gradient is useful for merging/concatenating multiple + !! networks. end subroutine backward module integer function get_num_params(self) diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index df95963a..13df77c0 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -115,10 +115,11 @@ module function network_from_layers(layers) result(res) end function network_from_layers - module subroutine backward(self, output, loss) + module subroutine backward(self, output, loss, gradient) class(network), intent(in out) :: self real, intent(in) :: output(:) class(loss_type), intent(in), optional :: loss + real, intent(in), optional :: gradient(:) integer :: n, num_layers ! Passing the loss instance is optional. If not provided, and if the @@ -140,58 +141,71 @@ module subroutine backward(self, output, loss) ! Iterate backward over layers, from the output layer ! to the first non-input layer - do n = num_layers, 2, -1 - - if (n == num_layers) then - ! Output layer; apply the loss function - select type(this_layer => self % layers(n) % p) - type is(dense_layer) - call self % layers(n) % backward( & - self % layers(n - 1), & - self % loss % derivative(output, this_layer % output) & - ) - type is(flatten_layer) - call self % layers(n) % backward( & - self % layers(n - 1), & - self % loss % derivative(output, this_layer % output) & - ) - end select - else - ! Hidden layer; take the gradient from the next layer - select type(next_layer => self % layers(n + 1) % p) - type is(dense_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(dropout_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(conv2d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(flatten_layer) - if (size(self % layers(n) % layer_shape) == 2) then - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d) - else - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d) - end if - type is(maxpool2d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(reshape3d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(linear2d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(self_attention_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(maxpool1d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(reshape2d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(conv1d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(locally_connected2d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(layernorm_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - end select - end if + ! Output layer first + n = num_layers + if (present(gradient)) then + + ! If the gradient is passed, use it directly for the output layer + select type(this_layer => self % layers(n) % p) + type is(dense_layer) + call self % layers(n) % backward(self % layers(n - 1), gradient) + type is(flatten_layer) + call self % layers(n) % backward(self % layers(n - 1), gradient) + end select + + else + + ! Apply the loss function + select type(this_layer => self % layers(n) % p) + type is(dense_layer) + call self % layers(n) % backward( & + self % layers(n - 1), & + self % loss % derivative(output, this_layer % output) & + ) + type is(flatten_layer) + call self % layers(n) % backward( & + self % layers(n - 1), & + self % loss % derivative(output, this_layer % output) & + ) + end select + + end if + + ! Hidden layers; take the gradient from the next layer + do n = num_layers - 1, 2, -1 + select type(next_layer => self % layers(n + 1) % p) + type is(dense_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(dropout_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(conv2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(flatten_layer) + if (size(self % layers(n) % layer_shape) == 2) then + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d) + else + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d) + end if + type is(maxpool2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(reshape3d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(linear2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(self_attention_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(maxpool1d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(reshape2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(conv1d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(locally_connected2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(layernorm_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + end select end do end subroutine backward @@ -497,6 +511,42 @@ module subroutine print_info(self) end subroutine print_info + module subroutine get_output_1d(self, output) + class(network), intent(in), target :: self + real, pointer, intent(out) :: output(:) + integer :: last + + last = size(self % layers) + + select type(output_layer => self % layers(last) % p) + type is (conv1d_layer) + output(1:size(output_layer % output)) => output_layer % output + type is(conv2d_layer) + output(1:size(output_layer % output)) => output_layer % output + type is (dense_layer) + output => output_layer % output + type is (dropout_layer) + output => output_layer % output + type is (flatten_layer) + output => output_layer % output + type is (layernorm_layer) + output(1:size(output_layer % output)) => output_layer % output + type is (linear2d_layer) + output(1:size(output_layer % output)) => output_layer % output + type is (locally_connected2d_layer) + output(1:size(output_layer % output)) => output_layer % output + type is (maxpool1d_layer) + output(1:size(output_layer % output)) => output_layer % output + type is (maxpool2d_layer) + output(1:size(output_layer % output)) => output_layer % output + class default + error stop 'network % get_output not implemented for ' // & + trim(self % layers(last) % name) // ' layer' + end select + + end subroutine get_output_1d + + module function get_num_params(self) class(network), intent(in) :: self integer :: get_num_params