mirror of https://github.com/InternLM/InternLM
update layer norm to tensorboard
parent a94f429a67
commit 641ee14bbf
@@ -115,7 +115,7 @@ class Engine:
         self._all_reduce_gradients()
         self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
 
-        success, grad_norm = self.optimizer.step()
+        success, grad_norm, layer_grad_norm = self.optimizer.step()
 
         if success and self._lr_scheduler is not None:
             self._lr_scheduler.step()
@@ -123,7 +123,7 @@ class Engine:
         if success and self._beta2_scheduler is not None:
            self._beta2_scheduler.step()
 
-        return success, grad_norm
+        return success, grad_norm, layer_grad_norm
 
     def train(self):
         """Sets the model to training mode."""
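For illustration, a minimal sketch of how a caller might consume the widened return value shown above: the three-element tuple (success flag, per-group grad norms, per-group per-layer grad norms) follows the diff, while the toy optimizer function and the dict contents are illustrative assumptions, not InternLM's actual classes.

from typing import Dict, Tuple

# Hypothetical stand-in for the widened step() contract in the diff above:
# it now reports (success, grad_norm_groups, layer_grad_norm_groups).
def toy_optimizer_step() -> Tuple[bool, Dict[str, float], Dict[str, Dict[str, float]]]:
    grad_norms = {"default": 0.87}
    layer_grad_norms = {"default": {"embedding": 0.12, "blocks.0.attn": 0.45, "norm": 0.03}}
    return True, grad_norms, layer_grad_norms

success, grad_norm, layer_grad_norm = toy_optimizer_step()
if success:
    for group, layers in layer_grad_norm.items():
        # report the layer with the largest gradient norm in each group
        print(group, max(layers, key=layers.get), max(layers.values()))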
@@ -564,27 +564,55 @@ class HybridZeroOptimizer(BaseOptimizer):
             total_layernorms[group_name] = self._compute_norm_with_stage(
                 group_id=group_id, last_bucket=True, last_stage=True, previous_layer_norms=groups_layer_norms[group_id]
             )
             total_norms[group_name] = sum(total_layernorms[group_name].values())
 
             # Need to allreduce(avg) the norms across different ranks because moe params will not be synced
             # during allreduce
             if self._is_moe_group(self.optim.param_groups[group_id]):
                 # model and zero have been reduced!!!
                 pg = gpc.get_group(ParallelMode.EXPERT)
                 scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT))
                 scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
                 dist.all_reduce(scaled_norm_tensor, group=pg)
                 total_norms[group_name] = scaled_norm_tensor.item()
                 # layer_norms allreduce
                 scaled_layer_norm = torch.cuda.FloatTensor(
                     list(total_layernorms[group_name].values()), device=get_current_device()
                 )
                 scaled_layer_norm = scaled_layer_norm / float(gpc.get_world_size(ParallelMode.EXPERT))
                 dist.all_reduce(scaled_layer_norm, group=pg)
                 for i, layer_name in enumerate(total_layernorms[group_name].keys()):
                     total_layernorms[group_name][layer_name] = scaled_layer_norm[i].item()
 
             # compute total_norms using the layer grad_norm
             total_layer_norms_values = list(total_layernorms[group_name].values())
             # inf flag
             if -1 in total_layer_norms_values:
                 total_norms[group_name] = -1
             # nan flag
             elif -2 in total_layer_norms_values:
                 total_norms[group_name] = -2
             else:
                 total_norms[group_name] = sum(total_layer_norms_values)
 
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
 
-        return self._step(closure=closure, norms=total_norms)
+        return self._step(closure=closure, norms=total_norms, layer_norms=total_layernorms)
 
-    def _step(self, closure=None, norms=None):
+    def _step(self, closure=None, norms=None, layer_norms=None):
         assert closure is None, "closure is not supported by step()"
 
+        def scale_layer_norm(layer_norms, loss_scale):
+            global_layer_norm_groups = {}
+            if layer_norms:
+                for group_name, layer_norm_dict in layer_norms.items():
+                    global_layer_norm_groups[group_name] = {}
+                    for layer_name, norm in layer_norm_dict.items():
+                        # filter unknown
+                        if layer_name == "unknown" and norm == 0:
+                            continue
+                        # handle inf (-1) and nan (-2)
+                        if norm != -1 or norm != -2:
+                            global_layer_norm_groups[group_name][layer_name] = norm**0.5 / loss_scale
+            return global_layer_norm_groups
+
         # check for overflow
         found_inf = False
         found_nan = False
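For context, the per-layer values above are squared gradient norms carrying a sentinel convention (-1 marks inf, -2 marks nan), and scale_layer_norm turns them into reportable norms by taking the square root and undoing the loss scale. Below is a minimal standalone sketch of that convention; the layer names and values are made up for illustration, and the sketch filters both sentinel values explicitly.

import math

# Sentinel convention for per-layer squared gradient norms (as in the diff above):
#   -1 marks an inf norm, -2 marks a nan norm; anything else is a squared L2 norm.
INF_FLAG, NAN_FLAG = -1, -2

def flag_squared_norm(sq_norm: float) -> float:
    # Map inf/nan squared norms to their sentinels; pass finite values through.
    if math.isinf(sq_norm):
        return INF_FLAG
    if math.isnan(sq_norm):
        return NAN_FLAG
    return sq_norm

def report_layer_norms(layer_sq_norms: dict, loss_scale: float) -> dict:
    # Convert squared norms to reportable norms (sqrt, then undo the loss scale),
    # skipping both sentinel values; mirrors the intent of scale_layer_norm.
    return {
        name: sq ** 0.5 / loss_scale
        for name, sq in layer_sq_norms.items()
        if sq not in (INF_FLAG, NAN_FLAG)
    }

# Hypothetical values for illustration only.
sq_norms = {"embedding": 4.0, "blocks.0.mlp": float("inf"), "norm": 0.25}
flagged = {k: flag_squared_norm(v) for k, v in sq_norms.items()}
print(flagged)                           # {'embedding': 4.0, 'blocks.0.mlp': -1, 'norm': 0.25}
print(report_layer_norms(flagged, 2.0))  # {'embedding': 1.0, 'norm': 0.25}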
@@ -603,6 +631,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         if gpc.config.model.dtype is not torch.float32:
             self.grad_scaler.update(found_inf)
 
+        # scale layer norm
+        global_layer_norm_groups = scale_layer_norm(layer_norms, loss_scale)
+
         # update loss scale if overflow occurs
         if found_inf:
             if gpc.is_rank_for_log():
@@ -613,7 +644,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 )
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
-            return False, norms
+            return False, norms, global_layer_norm_groups
 
         if found_nan:
             if gpc.is_rank_for_log():
@@ -624,7 +655,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 )
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
-            return False, norms
+            return False, norms, global_layer_norm_groups
 
         # copy the grad of fp16 param to fp32 param
         single_grad_partition_groups = []
@@ -711,7 +742,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # so synchronization is maintained
         for group_name, global_norm in global_norm_groups.items():
             global_norm_groups[group_name] = global_norm / loss_scale
-        return True, global_norm_groups
+        return True, global_norm_groups, global_layer_norm_groups
 
     def broadcast_params(self):
         handles = []
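The _step path above returns early when an inf or nan is found: accumulated gradients are cleared and the parameter update is skipped, while the norms are still reported to the caller. A standalone sketch of that skip-on-overflow pattern with a plain PyTorch optimizer follows; the function name and structure are hypothetical, not the HybridZeroOptimizer implementation.

import torch

def guarded_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    # Compute the total gradient norm and skip the update if it is inf/nan,
    # analogous to the found_inf / found_nan early returns above (a sketch only).
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2).item() if grads else 0.0
    if not torch.isfinite(torch.tensor(total_norm)):
        optimizer.zero_grad()  # drop the bad gradients, keep the old weights
        return False, total_norm
    optimizer.step()
    optimizer.zero_grad()
    return True, total_norm

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(8, 4)).sum().backward()
print(guarded_step(model, opt))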
@@ -319,7 +319,7 @@ def compute_norm(
         dist.all_reduce(total_layer_norms_values, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
         dist.all_reduce(total_layer_norms_values, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))
 
-        for idx in range(len(total_layer_norms_keys)):
+        for idx, layer_name in enumerate(total_layer_norms.keys()):
             layer_norm = total_layer_norms_values[idx]
             if torch.is_tensor(layer_norm):
                 layer_norm = layer_norm.item()
@@ -328,7 +328,7 @@ def compute_norm(
 
             if math.isnan(layer_norm):
                 layer_norm = -2
-            total_layer_norms[total_layer_norms_keys[idx]] = layer_norm
+            total_layer_norms[layer_name] = layer_norm
 
     return total_layer_norms
 
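compute_norm sums per-layer squared norms across ranks with an all_reduce before mapping inf/nan values to their sentinels. A minimal self-contained sketch of that reduction follows; the layer names are illustrative, and the all_reduce is only attempted when a torch.distributed process group has already been initialized.

import math
import torch
import torch.distributed as dist

def reduce_layer_norms(layer_sq_norms: dict, group=None) -> dict:
    # Sum per-layer squared norms across ranks, then flag inf as -1 and nan as -2,
    # mirroring the aggregation in compute_norm above (a sketch, not the real function).
    names = list(layer_sq_norms.keys())
    values = torch.tensor([layer_sq_norms[n] for n in names], dtype=torch.float32)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(values, op=dist.ReduceOp.SUM, group=group)
    reduced = {}
    for name, v in zip(names, values.tolist()):
        if math.isinf(v):
            v = -1
        elif math.isnan(v):
            v = -2
        reduced[name] = v
    return reduced

# Single-process example (no all_reduce is performed when torch.distributed
# has not been initialized).
print(reduce_layer_norms({"embedding": 1.5, "blocks.0.attn": float("nan")}))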
@@ -405,6 +405,7 @@ def record_current_batch_training_metrics(
     loss,
     moe_loss,
     grad_norm,
+    layer_grad_norm,
     metric,
     update_panel,
 ):
@@ -526,6 +527,11 @@ def record_current_batch_training_metrics(
         else:
             writer.add_scalar(key=key, value=value, step=train_state.step_count)
 
+        # add layer grad norm
+        for key, value in layer_grad_norm.items():
+            title = f"layer_grad_norm_group_{key}"
+            writer.add_scalars(key=title, value=value, step=train_state.step_count)
+
         if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0:
             send_heartbeat("train_metrics", infos)
 
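The new block writes one grouped chart per parameter group, with each layer's gradient norm as a separate curve. The `writer` above is the project's own wrapper (note its key=/value=/step= signature); the sketch below assumes the standard torch.utils.tensorboard API as an equivalent to show the call shape, with hypothetical tags and values.

from torch.utils.tensorboard import SummaryWriter

# Standard TensorBoard API (assumed equivalent of the project's writer wrapper):
# add_scalars writes several curves under one chart tag, which is how the
# per-layer gradient norms of a parameter group end up grouped together.
writer = SummaryWriter(log_dir="runs/layer_grad_norm_demo")

step = 100  # hypothetical training step
layer_grad_norm = {
    "default": {"embedding": 0.12, "blocks.0.attn": 0.45, "norm": 0.03},  # illustrative values
}

for group_name, per_layer in layer_grad_norm.items():
    writer.add_scalars(main_tag=f"layer_grad_norm_group_{group_name}",
                       tag_scalar_dict=per_layer,
                       global_step=step)
writer.close()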
train.py
@@ -240,7 +240,7 @@ def main(args):
         trainer_result = trainer.step()
         assert trainer_result is not None
 
-        success_update, grad_norm_groups = trainer_result
+        success_update, grad_norm_groups, layer_grad_norm_groups = trainer_result
         if success_update:  # update parameters successfully
             train_state.step_count += 1
         else:
@@ -268,6 +268,7 @@ def main(args):
             loss=loss,
             moe_loss=moe_loss,
             grad_norm=grad_norm_groups,
+            layer_grad_norm=layer_grad_norm_groups,
             metric=metric,
             update_panel=uniscale_logger is not None,
         )