@@ -99,7 +99,6 @@ def sgd_train_linear_model(
     This will return the final training loss (averaged with
     `running_loss_window`)
     """
-
     loss_window: List[torch.Tensor] = []
     min_avg_loss = None
     convergence_counter = 0
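
The loop re-indented in the next hunk implements early stopping on a running average: it keeps a sliding window of recent losses, and if the window average fails to beat the best average seen so far by more than `threshold` for `patience` consecutive steps, training stops. A minimal sketch of that bookkeeping, pulled out into a free-standing helper (the helper name and signature are illustrative, not part of this patch):

```python
from typing import List, Optional, Tuple

import torch


def update_convergence_state(
    loss: torch.Tensor,
    loss_window: List[torch.Tensor],
    window_size: int,
    min_avg_loss: Optional[torch.Tensor],
    convergence_counter: int,
    patience: int,
    threshold: float,
) -> Tuple[List[torch.Tensor], torch.Tensor, Optional[torch.Tensor], int, bool]:
    # Keep only the most recent `window_size` losses.
    if len(loss_window) >= window_size:
        loss_window = loss_window[1:]
    loss_window.append(loss.clone().detach())

    average_loss = torch.mean(torch.stack(loss_window))
    converged = False
    if min_avg_loss is not None and (
        average_loss > min_avg_loss
        or torch.isclose(min_avg_loss, average_loss, atol=threshold)
    ):
        # No improvement of at least `threshold`: one more strike
        # toward the `patience` budget.
        convergence_counter += 1
        converged = convergence_counter >= patience
    else:
        convergence_counter = 0
    # Track the best running average seen so far.
    if min_avg_loss is None or min_avg_loss >= average_loss:
        min_avg_loss = average_loss.clone()
    return loss_window, average_loss, min_avg_loss, convergence_counter, converged
```
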
@@ -145,77 +144,77 @@ def get_point(datapoint):
     if model.linear.bias is not None:
         model.linear.bias.zero_()

-    optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
-    if reduce_lr:
-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optim, factor=0.5, patience=patience, threshold=threshold
-        )
-
-    t1 = time.time()
-    epoch = 0
-    i = 0
-    while epoch < max_epoch:
-        while True:  # for x, y, w in dataloader
-            if running_loss_window is None:
-                running_loss_window = x.shape[0] * len(dataloader)
-
-            y = y.view(x.shape[0], -1)
-            if w is not None:
-                w = w.view(x.shape[0], -1)
-
-            i += 1
-
-            out = model(x)
-
-            loss = loss_fn(y, out, w)
-            if reg_term is not None:
-                reg = torch.norm(model.linear.weight, p=reg_term)
-                loss += reg.sum() * alpha
-
-            if len(loss_window) >= running_loss_window:
-                loss_window = loss_window[1:]
-            loss_window.append(loss.clone().detach())
-            assert len(loss_window) <= running_loss_window
-
-            average_loss = torch.mean(torch.stack(loss_window))
-            if min_avg_loss is not None:
-                # if we haven't improved by at least `threshold`
-                if average_loss > min_avg_loss or torch.isclose(
-                    min_avg_loss, average_loss, atol=threshold
-                ):
-                    convergence_counter += 1
-                    if convergence_counter >= patience:
-                        converged = True
-                        break
-                else:
-                    convergence_counter = 0
-            if min_avg_loss is None or min_avg_loss >= average_loss:
-                min_avg_loss = average_loss.clone()
-
-            if debug:
-                print(
-                    f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
-                    + f"Aloss={average_loss}, min_avg_loss={min_avg_loss}"
-                )
-
-            loss.backward()
-
-            optim.step()
-            model.zero_grad()
-            if scheduler:
-                scheduler.step(average_loss)
-
-            temp = next(data_iter, None)
-            if temp is None:
+    with torch.enable_grad():
+        optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
+        if reduce_lr:
+            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                optim, factor=0.5, patience=patience, threshold=threshold
+            )
+
+        t1 = time.time()
+        epoch = 0
+        i = 0
+        while epoch < max_epoch:
+            while True:  # for x, y, w in dataloader
+                if running_loss_window is None:
+                    running_loss_window = x.shape[0] * len(dataloader)
+
+                y = y.view(x.shape[0], -1)
+                if w is not None:
+                    w = w.view(x.shape[0], -1)
+
+                i += 1
+
+                out = model(x)
+
+                loss = loss_fn(y, out, w)
+                if reg_term is not None:
+                    reg = torch.norm(model.linear.weight, p=reg_term)
+                    loss += reg.sum() * alpha
+
+                if len(loss_window) >= running_loss_window:
+                    loss_window = loss_window[1:]
+                loss_window.append(loss.clone().detach())
+                assert len(loss_window) <= running_loss_window
+
+                average_loss = torch.mean(torch.stack(loss_window))
+                if min_avg_loss is not None:
+                    # if we haven't improved by at least `threshold`
+                    if average_loss > min_avg_loss or torch.isclose(
+                        min_avg_loss, average_loss, atol=threshold
+                    ):
+                        convergence_counter += 1
+                        if convergence_counter >= patience:
+                            converged = True
+                            break
+                    else:
+                        convergence_counter = 0
+                if min_avg_loss is None or min_avg_loss >= average_loss:
+                    min_avg_loss = average_loss.clone()
+
+                if debug:
+                    print(
+                        f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
+                        + f"Aloss={average_loss}, min_avg_loss={min_avg_loss}"
+                    )
+
+                loss.backward()
+                optim.step()
+                model.zero_grad()
+                if scheduler:
+                    scheduler.step(average_loss)
+
+                temp = next(data_iter, None)
+                if temp is None:
+                    break
+                x, y, w = get_point(temp)
+
+            if converged:
                 break
-            x, y, w = get_point(temp)
-
-        if converged:
-            break

-        epoch += 1
-        data_iter = iter(dataloader)
-        x, y, w = get_point(next(data_iter))
+            epoch += 1
+            data_iter = iter(dataloader)
+            x, y, w = get_point(next(data_iter))

     t2 = time.time()
     return {
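
The point of the new `torch.enable_grad()` block is that a caller may invoke this training routine from inside a `torch.no_grad()` context; without re-enabling grad mode, no autograd graph would be recorded and `loss.backward()` would fail. A self-contained sketch of that interaction, with a made-up model, data, and hyperparameters purely for illustration:

```python
import torch

model = torch.nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

with torch.no_grad():  # caller has disabled gradients globally
    with torch.enable_grad():  # re-enable autograd just for training
        optim = torch.optim.SGD(model.parameters(), lr=0.01)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, factor=0.5, patience=10, threshold=1e-4
        )
        for _ in range(5):
            loss = torch.nn.functional.mse_loss(model(x), y)
            loss.backward()  # would raise under no_grad() alone
            optim.step()
            model.zero_grad()
            scheduler.step(loss.item())  # plateau scheduler steps on a metric
```

Wrapping only the training loop keeps the caller's grad-mode choice intact everywhere else in the surrounding code.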