mirror of https://github.com/InternLM/InternLM
update layer norm to tensorboard
parent a94f429a67
commit 641ee14bbf
@@ -115,7 +115,7 @@ class Engine:
        self._all_reduce_gradients()
        self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)

-        success, grad_norm = self.optimizer.step()
+        success, grad_norm, layer_grad_norm = self.optimizer.step()

        if success and self._lr_scheduler is not None:
            self._lr_scheduler.step()
@@ -123,7 +123,7 @@ class Engine:
        if success and self._beta2_scheduler is not None:
            self._beta2_scheduler.step()

-        return success, grad_norm
+        return success, grad_norm, layer_grad_norm

    def train(self):
        """Sets the model to training mode."""
@@ -564,27 +564,55 @@ class HybridZeroOptimizer(BaseOptimizer):
            total_layernorms[group_name] = self._compute_norm_with_stage(
                group_id=group_id, last_bucket=True, last_stage=True, previous_layer_norms=groups_layer_norms[group_id]
            )
-            total_norms[group_name] = sum(total_layernorms[group_name].values())

            # Need to allreduce(avg) the norms across different ranks because moe params will not be synced
            # during allreduce
            if self._is_moe_group(self.optim.param_groups[group_id]):
                # model and zero have been reduced!!!
                pg = gpc.get_group(ParallelMode.EXPERT)
-                scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT))
-                scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
-                dist.all_reduce(scaled_norm_tensor, group=pg)
-                total_norms[group_name] = scaled_norm_tensor.item()
+                # layer_norms allreduce
+                scaled_layer_norm = torch.cuda.FloatTensor(
+                    list(total_layernorms[group_name].values()), device=get_current_device()
+                )
+                scaled_layer_norm = scaled_layer_norm / float(gpc.get_world_size(ParallelMode.EXPERT))
+                dist.all_reduce(scaled_layer_norm, group=pg)
+                for i, layer_name in enumerate(total_layernorms[group_name].keys()):
+                    total_layernorms[group_name][layer_name] = scaled_layer_norm[i].item()
+
+            # compute total_norms using the layer grad_norm
+            total_layer_norms_values = list(total_layernorms[group_name].values())
+            # inf flag
+            if -1 in total_layer_norms_values:
+                total_norms[group_name] = -1
+            # nan flag
+            elif -2 in total_layer_norms_values:
+                total_norms[group_name] = -2
+            else:
+                total_norms[group_name] = sum(total_layer_norms_values)

        timer("sync_grad").start()
        self._sync_grad()
        timer("sync_grad").stop()

-        return self._step(closure=closure, norms=total_norms)
+        return self._step(closure=closure, norms=total_norms, layer_norms=total_layernorms)

-    def _step(self, closure=None, norms=None):
+    def _step(self, closure=None, norms=None, layer_norms=None):
        assert closure is None, "closure is not supported by step()"

+        def scale_layer_norm(layer_norms, loss_scale):
+            global_layer_norm_groups = {}
+            if layer_norms:
+                for group_name, layer_norm_dict in layer_norms.items():
+                    global_layer_norm_groups[group_name] = {}
+                    for layer_name, norm in layer_norm_dict.items():
+                        # filter unknown
+                        if layer_name == "unknown" and norm == 0:
+                            continue
+                        # handle inf (-1) and nan (-2)
+                        if norm != -1 or norm != -2:
+                            global_layer_norm_groups[group_name][layer_name] = norm**0.5 / loss_scale
+            return global_layer_norm_groups
+
        # check for overflow
        found_inf = False
        found_nan = False
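The hunk above changes how a group's total gradient norm is derived: the per-layer squared norms are kept, averaged across expert-parallel ranks for MoE groups, and the group total then inherits the -1 (inf) / -2 (nan) sentinels; scale_layer_norm later square-roots and de-scales the surviving values. A minimal, framework-free sketch of that bookkeeping follows; the helper names are illustrative, and the sentinel check is written as `norm not in (-1, -2)` (the hunk's `norm != -1 or norm != -2` as written is always true):

from typing import Dict


def total_from_layer_norms(layer_norms: Dict[str, float]) -> float:
    """Collapse per-layer squared norms into a group total, propagating the -1/-2 sentinels."""
    values = list(layer_norms.values())
    if -1 in values:  # some layer hit inf
        return -1
    if -2 in values:  # some layer hit nan
        return -2
    return sum(values)


def scale_layer_norms(layer_norms: Dict[str, float], loss_scale: float) -> Dict[str, float]:
    """Square-root the accumulated squared norms and undo the loss scale, skipping sentinels."""
    scaled = {}
    for name, norm in layer_norms.items():
        if name == "unknown" and norm == 0:
            continue
        if norm not in (-1, -2):  # drop inf/nan flags rather than reporting them
            scaled[name] = norm ** 0.5 / loss_scale
    return scaled


if __name__ == "__main__":
    norms = {"layers.0": 4.0, "layers.1": 9.0}           # illustrative squared norms
    print(total_from_layer_norms(norms))                 # 13.0
    print(scale_layer_norms(norms, loss_scale=2.0))      # {'layers.0': 1.0, 'layers.1': 1.5}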
@@ -603,6 +631,9 @@ class HybridZeroOptimizer(BaseOptimizer):
        if gpc.config.model.dtype is not torch.float32:
            self.grad_scaler.update(found_inf)

+        # scale layer norm
+        global_layer_norm_groups = scale_layer_norm(layer_norms, loss_scale)
+
        # update loss scale if overflow occurs
        if found_inf:
            if gpc.is_rank_for_log():
@@ -613,7 +644,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                )
            self._grad_store._averaged_gradients = dict()
            self.zero_grad()
-            return False, norms
+            return False, norms, global_layer_norm_groups

        if found_nan:
            if gpc.is_rank_for_log():
@@ -624,7 +655,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                )
            self._grad_store._averaged_gradients = dict()
            self.zero_grad()
-            return False, norms
+            return False, norms, global_layer_norm_groups

        # copy the grad of fp16 param to fp32 param
        single_grad_partition_groups = []
@@ -711,7 +742,7 @@ class HybridZeroOptimizer(BaseOptimizer):
        # so synchronization is maintained
        for group_name, global_norm in global_norm_groups.items():
            global_norm_groups[group_name] = global_norm / loss_scale
-        return True, global_norm_groups
+        return True, global_norm_groups, global_layer_norm_groups

    def broadcast_params(self):
        handles = []
@@ -319,7 +319,7 @@ def compute_norm(
        dist.all_reduce(total_layer_norms_values, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
        dist.all_reduce(total_layer_norms_values, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))

-        for idx in range(len(total_layer_norms_keys)):
+        for idx, layer_name in enumerate(total_layer_norms.keys()):
            layer_norm = total_layer_norms_values[idx]
            if torch.is_tensor(layer_norm):
                layer_norm = layer_norm.item()
@@ -328,7 +328,7 @@ def compute_norm(
            if math.isnan(layer_norm):
                layer_norm = -2
-            total_layer_norms[total_layer_norms_keys[idx]] = layer_norm
+            total_layer_norms[layer_name] = layer_norm

    return total_layer_norms

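For reference, the per-layer path in compute_norm sums the stacked per-layer norms across the model-parallel and ZeRO process groups and then re-encodes non-finite results as sentinels (-2 for nan in the visible lines; -1 for inf is assumed by symmetry with the "inf flag" consumed in the optimizer hunk). A hedged, standalone sketch of that encoding step, with made-up layer names and plain floats standing in for the reduced tensor values:

import math


def encode_layer_norm_sentinels(layer_names, reduced_values):
    """Map each all-reduced per-layer norm to a float, flagging inf as -1 and nan as -2."""
    total_layer_norms = {}
    for idx, layer_name in enumerate(layer_names):
        layer_norm = float(reduced_values[idx])
        if math.isinf(layer_norm):   # assumed counterpart of the nan branch shown in the diff
            layer_norm = -1
        elif math.isnan(layer_norm):
            layer_norm = -2
        total_layer_norms[layer_name] = layer_norm
    return total_layer_norms


if __name__ == "__main__":
    print(encode_layer_norm_sentinels(
        ["layers.0", "layers.1", "layers.2"],
        [4.0, float("inf"), float("nan")],
    ))  # {'layers.0': 4.0, 'layers.1': -1, 'layers.2': -2}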
@@ -405,6 +405,7 @@ def record_current_batch_training_metrics(
    loss,
    moe_loss,
    grad_norm,
+    layer_grad_norm,
    metric,
    update_panel,
):
@@ -526,6 +527,11 @@ def record_current_batch_training_metrics(
            else:
                writer.add_scalar(key=key, value=value, step=train_state.step_count)

+        # add layer grad norm
+        for key, value in layer_grad_norm.items():
+            title = f"layer_grad_norm_group_{key}"
+            writer.add_scalars(key=title, value=value, step=train_state.step_count)
+
        if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0:
            send_heartbeat("train_metrics", infos)

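The new block writes each param group's per-layer norm dict with writer.add_scalars, so every layer of a group lands on one chart tagged layer_grad_norm_group_<name>. The writer here is InternLM's own wrapper; a sketch of the equivalent call with the stock torch.utils.tensorboard API (keyword names differ from the wrapper, log dir and values below are made up):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/layer_grad_norm_demo")  # illustrative log dir

# one dict of per-layer grad norms per param group, as produced by the optimizer change
layer_grad_norm = {"default": {"layers.0.mlp": 0.12, "layers.1.mlp": 0.08}}

step = 10
for group_name, per_layer in layer_grad_norm.items():
    writer.add_scalars(
        main_tag=f"layer_grad_norm_group_{group_name}",  # one chart per param group
        tag_scalar_dict=per_layer,                       # one curve per layer
        global_step=step,
    )
writer.close()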
train.py
@@ -240,7 +240,7 @@ def main(args):
        trainer_result = trainer.step()
        assert trainer_result is not None

-        success_update, grad_norm_groups = trainer_result
+        success_update, grad_norm_groups, layer_grad_norm_groups = trainer_result
        if success_update:  # update parameters successfully
            train_state.step_count += 1
        else:
@@ -268,6 +268,7 @@ def main(args):
            loss=loss,
            moe_loss=moe_loss,
            grad_norm=grad_norm_groups,
+            layer_grad_norm=layer_grad_norm_groups,
            metric=metric,
            update_panel=uniscale_logger is not None,
        )
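Taken together, train.py now unpacks a three-element result from trainer.step() and forwards the per-layer dict to the metrics recorder. A self-contained shape check with a stand-in recorder and a faked trainer result (no InternLM imports; all names and values are illustrative):

def record_metrics(grad_norm, layer_grad_norm, step):
    """Stand-in for record_current_batch_training_metrics: print instead of writing to TensorBoard."""
    for group, norm in grad_norm.items():
        print(f"step {step} grad_norm_group_{group}: {norm:.4f}")
    for group, per_layer in layer_grad_norm.items():
        for layer, norm in per_layer.items():
            print(f"step {step} layer_grad_norm_group_{group}/{layer}: {norm:.4f}")


# faked trainer result with the new (success, grad_norm_groups, layer_grad_norm_groups) shape
trainer_result = (True, {"default": 0.73}, {"default": {"layers.0": 0.21, "layers.1": 0.11}})
success_update, grad_norm_groups, layer_grad_norm_groups = trainer_result
if success_update:
    record_metrics(grad_norm_groups, layer_grad_norm_groups, step=1)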