update layer norm to tensorboard

pull/412/head
JiaoPL 2023-10-13 12:07:58 +08:00
parent a94f429a67
commit 641ee14bbf
5 changed files with 53 additions and 15 deletions

View File

@ -115,7 +115,7 @@ class Engine:
self._all_reduce_gradients() self._all_reduce_gradients()
self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm) self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
success, grad_norm = self.optimizer.step() success, grad_norm, layer_grad_norm = self.optimizer.step()
if success and self._lr_scheduler is not None: if success and self._lr_scheduler is not None:
self._lr_scheduler.step() self._lr_scheduler.step()
@ -123,7 +123,7 @@ class Engine:
if success and self._beta2_scheduler is not None: if success and self._beta2_scheduler is not None:
self._beta2_scheduler.step() self._beta2_scheduler.step()
return success, grad_norm return success, grad_norm, layer_grad_norm
def train(self): def train(self):
"""Sets the model to training mode.""" """Sets the model to training mode."""

View File

@ -564,27 +564,55 @@ class HybridZeroOptimizer(BaseOptimizer):
total_layernorms[group_name] = self._compute_norm_with_stage( total_layernorms[group_name] = self._compute_norm_with_stage(
group_id=group_id, last_bucket=True, last_stage=True, previous_layer_norms=groups_layer_norms[group_id] group_id=group_id, last_bucket=True, last_stage=True, previous_layer_norms=groups_layer_norms[group_id]
) )
total_norms[group_name] = sum(total_layernorms[group_name].values())
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced # Need to allreduce(avg) the norms across different ranks because moe params will not be synced
# during allreduce # during allreduce
if self._is_moe_group(self.optim.param_groups[group_id]): if self._is_moe_group(self.optim.param_groups[group_id]):
# model and zero have been reduced!!! # model and zero have been reduced!!!
pg = gpc.get_group(ParallelMode.EXPERT) pg = gpc.get_group(ParallelMode.EXPERT)
scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)) # layer_norms allreduce
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) scaled_layer_norm = torch.cuda.FloatTensor(
dist.all_reduce(scaled_norm_tensor, group=pg) list(total_layernorms[group_name].values()), device=get_current_device()
total_norms[group_name] = scaled_norm_tensor.item() )
scaled_layer_norm = scaled_layer_norm / float(gpc.get_world_size(ParallelMode.EXPERT))
dist.all_reduce(scaled_layer_norm, group=pg)
for i, layer_name in enumerate(total_layernorms[group_name].keys()):
total_layernorms[group_name][layer_name] = scaled_layer_norm[i].item()
# compute total_norms using the layer grad_norm
total_layer_norms_values = list(total_layernorms[group_name].values())
# inf flag
if -1 in total_layer_norms_values:
total_norms[group_name] = -1
# nan flag
elif -2 in total_layer_norms_values:
total_norms[group_name] = -2
else:
total_norms[group_name] = sum(total_layer_norms_values)
timer("sync_grad").start() timer("sync_grad").start()
self._sync_grad() self._sync_grad()
timer("sync_grad").stop() timer("sync_grad").stop()
return self._step(closure=closure, norms=total_norms) return self._step(closure=closure, norms=total_norms, layer_norms=total_layernorms)
def _step(self, closure=None, norms=None): def _step(self, closure=None, norms=None, layer_norms=None):
assert closure is None, "closure is not supported by step()" assert closure is None, "closure is not supported by step()"
def scale_layer_norm(layer_norms, loss_scale):
    """Turn raw per-layer norm values into loss-scale-corrected gradient norms.

    Args:
        layer_norms: mapping of group name -> {layer name: norm value}.
            May be None or empty, in which case an empty dict is returned.
        loss_scale: divisor applied to the square root of every norm value.

    Returns:
        dict mapping group name -> {layer name: norm ** 0.5 / loss_scale}.
        Every group present in ``layer_norms`` gets an entry (possibly empty).
        Entries are omitted for the "unknown" layer when its norm is 0, and
        for layers whose value is the inf flag (-1) or the nan flag (-2).
    """
    global_layer_norm_groups = {}
    if not layer_norms:
        return global_layer_norm_groups

    for group_name, layer_norm_dict in layer_norms.items():
        scaled = {}
        for layer_name, norm in layer_norm_dict.items():
            # filter unknown
            if layer_name == "unknown" and norm == 0:
                continue
            # -1 (inf) and -2 (nan) are flags, not norms. The original test
            # `norm != -1 or norm != -2` was always true, so the flags reached
            # `norm**0.5`, which is complex for a negative base. Both
            # comparisons must hold, hence `and`.
            if norm != -1 and norm != -2:
                scaled[layer_name] = norm**0.5 / loss_scale
        global_layer_norm_groups[group_name] = scaled
    return global_layer_norm_groups
# check for overflow # check for overflow
found_inf = False found_inf = False
found_nan = False found_nan = False
@ -603,6 +631,9 @@ class HybridZeroOptimizer(BaseOptimizer):
if gpc.config.model.dtype is not torch.float32: if gpc.config.model.dtype is not torch.float32:
self.grad_scaler.update(found_inf) self.grad_scaler.update(found_inf)
# scale layer norm
global_layer_norm_groups = scale_layer_norm(layer_norms, loss_scale)
# update loss scale if overflow occurs # update loss scale if overflow occurs
if found_inf: if found_inf:
if gpc.is_rank_for_log(): if gpc.is_rank_for_log():
@ -613,7 +644,7 @@ class HybridZeroOptimizer(BaseOptimizer):
) )
self._grad_store._averaged_gradients = dict() self._grad_store._averaged_gradients = dict()
self.zero_grad() self.zero_grad()
return False, norms return False, norms, global_layer_norm_groups
if found_nan: if found_nan:
if gpc.is_rank_for_log(): if gpc.is_rank_for_log():
@ -624,7 +655,7 @@ class HybridZeroOptimizer(BaseOptimizer):
) )
self._grad_store._averaged_gradients = dict() self._grad_store._averaged_gradients = dict()
self.zero_grad() self.zero_grad()
return False, norms return False, norms, global_layer_norm_groups
# copy the grad of fp16 param to fp32 param # copy the grad of fp16 param to fp32 param
single_grad_partition_groups = [] single_grad_partition_groups = []
@ -711,7 +742,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# so synchronization is maintained # so synchronization is maintained
for group_name, global_norm in global_norm_groups.items(): for group_name, global_norm in global_norm_groups.items():
global_norm_groups[group_name] = global_norm / loss_scale global_norm_groups[group_name] = global_norm / loss_scale
return True, global_norm_groups return True, global_norm_groups, global_layer_norm_groups
def broadcast_params(self): def broadcast_params(self):
handles = [] handles = []

View File

@ -319,7 +319,7 @@ def compute_norm(
dist.all_reduce(total_layer_norms_values, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL)) dist.all_reduce(total_layer_norms_values, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
dist.all_reduce(total_layer_norms_values, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode)) dist.all_reduce(total_layer_norms_values, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))
for idx in range(len(total_layer_norms_keys)): for idx, layer_name in enumerate(total_layer_norms.keys()):
layer_norm = total_layer_norms_values[idx] layer_norm = total_layer_norms_values[idx]
if torch.is_tensor(layer_norm): if torch.is_tensor(layer_norm):
layer_norm = layer_norm.item() layer_norm = layer_norm.item()
@ -328,7 +328,7 @@ def compute_norm(
if math.isnan(layer_norm): if math.isnan(layer_norm):
layer_norm = -2 layer_norm = -2
total_layer_norms[total_layer_norms_keys[idx]] = layer_norm total_layer_norms[layer_name] = layer_norm
return total_layer_norms return total_layer_norms

View File

@ -405,6 +405,7 @@ def record_current_batch_training_metrics(
loss, loss,
moe_loss, moe_loss,
grad_norm, grad_norm,
layer_grad_norm,
metric, metric,
update_panel, update_panel,
): ):
@ -526,6 +527,11 @@ def record_current_batch_training_metrics(
else: else:
writer.add_scalar(key=key, value=value, step=train_state.step_count) writer.add_scalar(key=key, value=value, step=train_state.step_count)
# add layer grad norm
for key, value in layer_grad_norm.items():
title = f"layer_grad_norm_group_{key}"
writer.add_scalars(key=title, value=value, step=train_state.step_count)
if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0: if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0:
send_heartbeat("train_metrics", infos) send_heartbeat("train_metrics", infos)

View File

@ -240,7 +240,7 @@ def main(args):
trainer_result = trainer.step() trainer_result = trainer.step()
assert trainer_result is not None assert trainer_result is not None
success_update, grad_norm_groups = trainer_result success_update, grad_norm_groups, layer_grad_norm_groups = trainer_result
if success_update: # update parameters successfully if success_update: # update parameters successfully
train_state.step_count += 1 train_state.step_count += 1
else: else:
@ -268,6 +268,7 @@ def main(args):
loss=loss, loss=loss,
moe_loss=moe_loss, moe_loss=moe_loss,
grad_norm=grad_norm_groups, grad_norm=grad_norm_groups,
layer_grad_norm=layer_grad_norm_groups,
metric=metric, metric=metric,
update_panel=uniscale_logger is not None, update_panel=uniscale_logger is not None,
) )