@@ -640,6 +640,8 @@ end function get_num_params
640640 module function get_params (self ) result(params)
641641 class(layer), intent (in ) :: self
642642 real , allocatable :: params(:)
643+ real , pointer :: w_ptr(:)
644+ real , pointer :: b_ptr(:)
643645
644646 select type (this_layer = > self % p)
645647 type is (input1d_layer)
@@ -649,15 +651,27 @@ module function get_params(self) result(params)
649651 type is (input3d_layer)
650652 ! No parameters to get.
651653 type is (dense_layer)
652- params = this_layer % get_params()
654+ call this_layer % get_params_ptr(w_ptr, b_ptr)
655+ allocate (params(size (w_ptr) + size (b_ptr)))
656+ params(1 :size (w_ptr)) = w_ptr
657+ params(size (w_ptr)+ 1 :) = b_ptr
653658 type is (dropout_layer)
654659 ! No parameters to get.
655660 type is (conv1d_layer)
656- params = this_layer % get_params()
661+ call this_layer % get_params_ptr(w_ptr, b_ptr)
662+ allocate (params(size (w_ptr) + size (b_ptr)))
663+ params(1 :size (w_ptr)) = w_ptr
664+ params(size (w_ptr)+ 1 :) = b_ptr
657665 type is (conv2d_layer)
658- params = this_layer % get_params()
666+ call this_layer % get_params_ptr(w_ptr, b_ptr)
667+ allocate (params(size (w_ptr) + size (b_ptr)))
668+ params(1 :size (w_ptr)) = w_ptr
669+ params(size (w_ptr)+ 1 :) = b_ptr
659670 type is (locally_connected2d_layer)
660- params = this_layer % get_params()
671+ call this_layer % get_params_ptr(w_ptr, b_ptr)
672+ allocate (params(size (w_ptr) + size (b_ptr)))
673+ params(1 :size (w_ptr)) = w_ptr
674+ params(size (w_ptr)+ 1 :) = b_ptr
661675 type is (maxpool1d_layer)
662676 ! No parameters to get.
663677 type is (maxpool2d_layer)
@@ -669,7 +683,10 @@ module function get_params(self) result(params)
669683 type is (reshape3d_layer)
670684 ! No parameters to get.
671685 type is (linear2d_layer)
672- params = this_layer % get_params()
686+ call this_layer % get_params_ptr(w_ptr, b_ptr)
687+ allocate (params(size (w_ptr) + size (b_ptr)))
688+ params(1 :size (w_ptr)) = w_ptr
689+ params(size (w_ptr)+ 1 :) = b_ptr
673690 type is (self_attention_layer)
674691 params = this_layer % get_params()
675692 type is (embedding_layer)
0 commit comments