Skip to content

Commit 0169e9a

Browse files
committed
opの名前の取得方法をsignature経由に変更する
1 parent 3165ada commit 0169e9a

File tree

5 files changed

+36
-18
lines changed

5 files changed

+36
-18
lines changed

examples/regression_savedmodel.rs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,41 @@ fn main() -> Result<(), Box<dyn Error>> {
4141

4242
// Load the saved model exported by regression_savedmodel.py.
4343
let mut graph = Graph::new();
44-
let session =
45-
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?.session;
46-
let op_x = graph.operation_by_name_required("train_x")?;
47-
let op_y = graph.operation_by_name_required("train_y")?;
48-
let op_train = graph.operation_by_name_required("StatefulPartitionedCall")?;
49-
let op_w = graph.operation_by_name_required("StatefulPartitionedCall_1")?;
50-
let op_b = graph.operation_by_name_required("StatefulPartitionedCall_1")?;
44+
let bundle =
45+
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?;
46+
let session = &bundle.session;
47+
48+
let train_signature = bundle.meta_graph_def().get_signature("train")?;
49+
let x_info = train_signature.get_input("x")?;
50+
let y_info = train_signature.get_input("y")?;
51+
let train_info = train_signature.get_output("train")?;
52+
let op_x = graph.operation_by_name_required(&x_info.name().name)?;
53+
let op_y = graph.operation_by_name_required(&y_info.name().name)?;
54+
let op_train = graph.operation_by_name_required(&train_info.name().name)?;
55+
let w_info = bundle
56+
.meta_graph_def()
57+
.get_signature("w")?
58+
.get_output("output")?;
59+
let op_w = graph.operation_by_name_required(&w_info.name().name)?;
60+
let b_info = bundle
61+
.meta_graph_def()
62+
.get_signature("b")?
63+
.get_output("output")?;
64+
let op_b = graph.operation_by_name_required(&b_info.name().name)?;
5165

5266
// Train the model (e.g. for fine tuning).
5367
let mut train_step = SessionRunArgs::new();
5468
train_step.add_feed(&op_x, 0, &x);
5569
train_step.add_feed(&op_y, 0, &y);
56-
train_step.request_fetch(&op_train, 0);
70+
train_step.add_target(&op_train);
5771
for _ in 0..steps {
5872
session.run(&mut train_step)?;
5973
}
6074

6175
// Grab the data out of the session.
6276
let mut output_step = SessionRunArgs::new();
6377
let w_ix = output_step.request_fetch(&op_w, 0);
64-
let b_ix = output_step.request_fetch(&op_b, 1);
78+
let b_ix = output_step.request_fetch(&op_b, 0);
6579
session.run(&mut output_step)?;
6680

6781
// Check our results.

examples/regression_savedmodel/regression_savedmodel.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,39 @@ def __call__(self, x):
1414
return y_hat
1515

1616
@tf.function
17-
def get_weights(self):
18-
return self.w, self.b
17+
def get_w(self):
18+
return {"output": self.w}
19+
20+
@tf.function
21+
def get_b(self):
22+
return {"output": self.b}
1923

2024
@tf.function
2125
def train(self, x, y):
2226
with tf.GradientTape() as tape:
2327
y_hat = self(x)
2428
loss = tf.reduce_mean(tf.square(y_hat - y))
2529
grads = tape.gradient(loss, self.trainable_variables)
26-
_ = self.optimizer.apply_gradients(
27-
zip(grads, self.trainable_variables), name="train"
28-
)
29-
return loss
30+
_ = self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
31+
return {"train": loss}
3032

3133

3234
model = LinearRegresstion()
3335

3436
x = tf.TensorSpec([None], tf.float32, name="x")
3537
y = tf.TensorSpec([None], tf.float32, name="y")
3638
train = model.train.get_concrete_function(x, y)
37-
weights = model.get_weights.get_concrete_function()
39+
w = model.get_w.get_concrete_function()
40+
b = model.get_b.get_concrete_function()
3841

3942
directory = "examples/regression_savedmodel"
40-
signatures = {"train": train, "weights": weights}
43+
signatures = {"train": train, "w": w, "b": b}
4144
tf.saved_model.save(model, directory, signatures=signatures)
4245

4346
# export graph info to TensorBoard
4447
logdir = "logs/regression_savedmodel"
4548
writer = tf.summary.create_file_writer(logdir)
4649
with writer.as_default():
4750
tf.summary.graph(train.graph)
48-
tf.summary.graph(weights.graph)
51+
tf.summary.graph(w.graph)
52+
tf.summary.graph(b.graph)
-5.14 KB
Binary file not shown.
-39 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)