import torch
import torch.distributed as dist
import torch.nn as nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

@@ -40,23 +41,26 @@ def __init__(
    def _train_batch(
        self,
        batch: DatasetBatch,
-        model: nn.Module,
+        model: FSDP,
        optimizer: Optimizer,
        scheduler: LRScheduler,
        loss_fun: Loss,
        train_step_id: int,
        data_loader: LLMDataLoader,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        result_batch = model_predict_batch(model=model, batch=batch)
-        loss = loss_fun(result_batch) / self.gradient_acc_steps
-        loss.backward()
-        gradient_norm_score = self.gradient_clipper(model)
+        loss = loss_fun(result_batch)
+        (loss / self.gradient_acc_steps).backward()

        if (train_step_id + 1) % self.gradient_acc_steps == 0 or (train_step_id + 1) == len(data_loader):
+            # gradient_norm_score = self.gradient_clipper(model)
+            gradient_norm_score = model.clip_grad_norm_(max_norm=1, norm_type=2).sum()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
-        return loss, gradient_norm_score
+            return loss, gradient_norm_score
+        else:
+            return loss, None

    def train(
        self,
@@ -82,6 +86,7 @@ def train(
        dist.barrier()
        forward_backward_time_recorder = TimeRecorder()
        forward_backward_time_recorder.start()
+        gradient_norm_scores = []
        for batch_id, batch in enumerate(train_loader):
            # Because we might resume training, we add the starting batch id of the data loader
            train_step_id = batch_id + train_loader.fast_forward_batch_id
@@ -98,9 +103,13 @@ def train(
            forward_backward_time_recorder.stop()
            # Save the batch loss
            cumulated_loss_and_gradient_norm[0] += batch_loss.item()
-            cumulated_loss_and_gradient_norm[1] += gradient_norm_score.item()
            # This works, because we always drop the last batch in case it has less samples than the batch size
            cumulated_loss_and_gradient_norm[-1] += 1  # number of local batches
+
+            # gradient norm is already synced across all ranks
+            if gradient_norm_score is not None:
+                gradient_norm_scores.append(gradient_norm_score.item())
+
            batch_length_tensor = torch.tensor(len(batch)).to(device)
            thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)

@@ -124,35 +133,37 @@ def train(
                )
                synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time
                # TODO: insert reducer from outside so Trainer is independent of FSDP
-                cumulated_loss_and_gradient_norm[2] = batch_loss.item()
-                cumulated_loss_and_gradient_norm[3] = gradient_norm_score.item()
+                # add the loss for the LAST batch
+                cumulated_loss_and_gradient_norm[1] = batch_loss.item()

                reduced_loss_and_gradient_norm = Reducer.reduce(
                    tensor=cumulated_loss_and_gradient_norm,
                    operation=dist.ReduceOp.SUM,
-                    # divide the first two elements by the last one
-                    # i.e., summed batch loss / (num batches * world size)
-                    # and summed gradient norm/ (num batches * world size).
-                    # keep the other elements as is
-                    post_processing_fun=lambda t: torch.cat((t[:2] / t[-1], t[2:-1] / dist.get_world_size())),
+                    # 1.) summed batch loss / (num batches * world size)
+                    # 2.) last batch loss / world size
+                    post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]),
                )

-                train_loss_avg, train_gradient_norm_avg, train_loss_last_batch, train_gradient_norm_last_batch = (
+                train_loss_avg, train_loss_last_batch = (
                    reduced_loss_and_gradient_norm[0],
                    reduced_loss_and_gradient_norm[1],
-                    reduced_loss_and_gradient_norm[2],
-                    reduced_loss_and_gradient_norm[3],
                )
+                losses = {
+                    f"{loss_fun.tag} average": train_loss_avg,
+                    f"{loss_fun.tag} last step": train_loss_last_batch,
+                }
+                if len(gradient_norm_scores) > 0:
+                    metrics = {
+                        "grad_norm_avg": torch.mean(torch.Tensor(gradient_norm_scores)),
+                        "grad_norm_last_batch": gradient_norm_scores[-1],
+                    }
+                    gradient_norm_scores = []
+                else:
+                    metrics = {}

                training_metrics = EvaluationResultBatch(
-                    losses={
-                        f"{loss_fun.tag} interval average": train_loss_avg,
-                        f"{loss_fun.tag} last batch": train_loss_last_batch,
-                    },
-                    metrics={
-                        "grad_norm_avg": train_gradient_norm_avg,
-                        "grad_norm_last_batch": train_gradient_norm_last_batch,
-                    },
+                    losses=losses,
+                    metrics=metrics,
                    # TODO: hardcoded metric key
                    throughput_metrics={
                        "training_synced_num_samples_per_second": synced_num_samples_per_second,
@@ -181,7 +192,8 @@ def train(

    def _reset_loss_and_gradient_norm(self):
        # TODO: we should handle the device assignment more centrally.
-        cumulated_loss_and_gradient_norm = torch.zeros(5)
+        # summed local losses, loss of last local batch, number of local batches (i.e., number of steps)
+        cumulated_loss_and_gradient_norm = torch.zeros(3)
        if torch.cuda.is_available():
            cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to(torch.device(self.local_rank))
        else:
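
For context, the _train_batch change above scales each micro-batch loss by the number of gradient-accumulation steps and only clips, steps the optimizer/scheduler and zeroes gradients on every k-th micro-batch (or the final one), returning None as the gradient norm otherwise. A minimal, self-contained sketch of that pattern follows; the function name and arguments are illustrative, not the repository's Trainer API:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler


def accumulate_and_step(
    model: FSDP,
    optimizer: Optimizer,
    scheduler: LRScheduler,
    loss: torch.Tensor,
    step_id: int,
    grad_acc_steps: int,
    num_batches: int,
):
    # Scale the loss so the accumulated gradient matches a full-batch update.
    (loss / grad_acc_steps).backward()
    if (step_id + 1) % grad_acc_steps == 0 or (step_id + 1) == num_batches:
        # FSDP's clip_grad_norm_ works on the sharded parameters and returns the total norm.
        gradient_norm = model.clip_grad_norm_(max_norm=1.0, norm_type=2)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        return loss, gradient_norm
    # No optimizer step on this micro-batch, hence no gradient norm to report.
    return loss, None

Returning None on the accumulation-only steps is what lets the training loop collect gradient_norm_scores only for the steps where clipping actually happened.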
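The reworked reduction keeps a three-element tensor per rank (summed local losses, loss of the last local batch, number of local batches), sums it across ranks and post-processes it into the two reported losses. Below is a sketch of the same arithmetic using a plain all_reduce instead of the project's Reducer.reduce helper; the function name is illustrative:

import torch
import torch.distributed as dist


def reduce_losses(cumulated: torch.Tensor) -> torch.Tensor:
    # cumulated = [summed local losses, loss of last local batch, number of local batches]
    dist.all_reduce(cumulated, op=dist.ReduceOp.SUM)
    return torch.stack(
        [
            cumulated[0] / cumulated[-1],          # summed batch loss / (num batches * world size)
            cumulated[1] / dist.get_world_size(),  # last batch loss / world size
        ]
    )

Gradient norms no longer go through this reduction: as the comment in the loop notes, FSDP's clip_grad_norm_ already returns a norm computed across all ranks, so the trainer simply averages the locally collected gradient_norm_scores.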