@@ -41,27 +41,41 @@ fn main() -> Result<(), Box<dyn Error>> {
41
41
42
42
// Load the saved model exported by regression_savedmodel.py.
43
43
let mut graph = Graph :: new ( ) ;
44
- let session =
45
- SavedModelBundle :: load ( & SessionOptions :: new ( ) , & [ "serve" ] , & mut graph, export_dir) ?. session ;
46
- let op_x = graph. operation_by_name_required ( "train_x" ) ?;
47
- let op_y = graph. operation_by_name_required ( "train_y" ) ?;
48
- let op_train = graph. operation_by_name_required ( "StatefulPartitionedCall" ) ?;
49
- let op_w = graph. operation_by_name_required ( "StatefulPartitionedCall_1" ) ?;
50
- let op_b = graph. operation_by_name_required ( "StatefulPartitionedCall_1" ) ?;
44
+ let bundle =
45
+ SavedModelBundle :: load ( & SessionOptions :: new ( ) , & [ "serve" ] , & mut graph, export_dir) ?;
46
+ let session = & bundle. session ;
47
+
48
+ let train_signature = bundle. meta_graph_def ( ) . get_signature ( "train" ) ?;
49
+ let x_info = train_signature. get_input ( "x" ) ?;
50
+ let y_info = train_signature. get_input ( "y" ) ?;
51
+ let train_info = train_signature. get_output ( "train" ) ?;
52
+ let op_x = graph. operation_by_name_required ( & x_info. name ( ) . name ) ?;
53
+ let op_y = graph. operation_by_name_required ( & y_info. name ( ) . name ) ?;
54
+ let op_train = graph. operation_by_name_required ( & train_info. name ( ) . name ) ?;
55
+ let w_info = bundle
56
+ . meta_graph_def ( )
57
+ . get_signature ( "w" ) ?
58
+ . get_output ( "output" ) ?;
59
+ let op_w = graph. operation_by_name_required ( & w_info. name ( ) . name ) ?;
60
+ let b_info = bundle
61
+ . meta_graph_def ( )
62
+ . get_signature ( "b" ) ?
63
+ . get_output ( "output" ) ?;
64
+ let op_b = graph. operation_by_name_required ( & b_info. name ( ) . name ) ?;
51
65
52
66
// Train the model (e.g. for fine tuning).
53
67
let mut train_step = SessionRunArgs :: new ( ) ;
54
68
train_step. add_feed ( & op_x, 0 , & x) ;
55
69
train_step. add_feed ( & op_y, 0 , & y) ;
56
- train_step. request_fetch ( & op_train, 0 ) ;
70
+ train_step. add_target ( & op_train) ;
57
71
for _ in 0 ..steps {
58
72
session. run ( & mut train_step) ?;
59
73
}
60
74
61
75
// Grab the data out of the session.
62
76
let mut output_step = SessionRunArgs :: new ( ) ;
63
77
let w_ix = output_step. request_fetch ( & op_w, 0 ) ;
64
- let b_ix = output_step. request_fetch ( & op_b, 1 ) ;
78
+ let b_ix = output_step. request_fetch ( & op_b, 0 ) ;
65
79
session. run ( & mut output_step) ?;
66
80
67
81
// Check our results.
0 commit comments