From 81c76b51ab67765b8c4e81eb1207b83eaab3ce60 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 23 Jan 2023 10:49:22 -0800 Subject: [PATCH] add checkAlive in NasBertTrainer --- .../NasBert/NasBertTrainer.cs | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs b/src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs index 8191e6979c..c2a229413d 100644 --- a/src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs +++ b/src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs @@ -249,7 +249,8 @@ public override NasBertTransformer Fit(IDataView input) for (int i = 0; i < Option.MaxEpoch; i++) { ch.Trace($"Starting epoch {i}"); - trainer.Train(input); + Host.CheckAlive(); + trainer.Train(Host, input); ch.Trace($"Finished epoch {i}"); if (Option.ValidationSet != null) trainer.Validate(pch, ch, i); @@ -423,7 +424,7 @@ private bool ValidateStep(DataViewRowCursor cursor, return cursorValid; } - public void Train(IDataView input) + public void Train(IHost host, IDataView input) { // Get the cursor and the correct columns based on the inputs DataViewRowCursor cursor = default; @@ -443,14 +444,15 @@ public void Train(IDataView input) var cursorValid = true; while (cursorValid) { - cursorValid = TrainStep(cursor, sentence1Getter, sentence2Getter, labelGetter, ref inputTensors, ref targets); + cursorValid = TrainStep(host, cursor, sentence1Getter, sentence2Getter, labelGetter, ref inputTensors, ref targets); } } - private bool TrainStep(DataViewRowCursor cursor, - ValueGetter> sentence1Getter, - ValueGetter> sentence2Getter, - ValueGetter labelGetter, + private bool TrainStep(IHost host, + DataViewRowCursor cursor, + ValueGetter> sentence1Getter, + ValueGetter> sentence2Getter, + ValueGetter labelGetter, ref List inputTensors, ref List targets) { @@ -461,6 +463,7 @@ private bool TrainStep(DataViewRowCursor cursor, var cursorValid = true; for (int i = 0; i < Parent.Option.BatchSize && cursorValid; i++) { + host.CheckAlive(); cursorValid = cursor.MoveNext(); if (cursorValid) { @@ -479,7 +482,7 @@ private bool TrainStep(DataViewRowCursor cursor, } Updates++; - + host.CheckAlive(); torch.random.manual_seed(1 + Updates); torch.cuda.manual_seed(1 + Updates); Model.train(); @@ -497,8 +500,10 @@ private bool TrainStep(DataViewRowCursor cursor, loss = torch.nn.MSELoss(reduction: Parent.Option.Reduction).forward(logits, targetsTensor); logits = logits.squeeze(); } - + host.CheckAlive(); loss.backward(); + + host.CheckAlive(); OptimizeStep(); return cursorValid;