feat(utils/writer.py): support writer add_scalars for writing dict data (#257)

* feat(utils/writer.py): support writer add_scalars interface for writing dict data

* feat(hybrid_zero_optim.py): change grad_norm_groups list to dict
huangting4201 2023-09-01 13:24:46 +08:00 committed by GitHub
parent c516602e9a
commit b9202b12bc
4 changed files with 52 additions and 32 deletions
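
To make the commit message concrete, the sketch below mirrors the new scalar-vs-dict dispatch using TensorBoard's SummaryWriter directly. The DemoWriter class, log directory, and metric values are illustrative stand-ins, not code from this commit; only the add_scalar/add_scalars signatures and the dict-dispatch pattern follow the diffs below.

    from torch.utils.tensorboard import SummaryWriter

    class DemoWriter:
        """Illustrative stand-in for the project's Writer; not repository code."""

        def __init__(self, log_dir):
            self.tb_writer = SummaryWriter(log_dir=log_dir)

        def add_scalar(self, key, value, step):
            self.tb_writer.add_scalar(tag=key, scalar_value=value, global_step=step)

        def add_scalars(self, key, value, step):
            # A dict of related scalars is plotted together under one main tag.
            self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step)

    writer = DemoWriter(log_dir="./tb_demo")  # hypothetical log directory
    # Invented metric values; "grad_norm" is now a dict keyed "{group_id}_{group_name}".
    infos = {"loss": 2.15, "grad_norm": {"0_default": 1.37, "1_default": 0.42}}
    for key, value in infos.items():
        if isinstance(value, dict):
            writer.add_scalars(key=key, value=value, step=100)
        else:
            writer.add_scalar(key=key, value=value, step=100)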

hybrid_zero_optim.py

@@ -497,6 +497,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             grads = [self.padding_grad]
             params = [self.padding_tensor]
 
+        norm = 0
         if self._clip_grad_norm > 0:
             # this norm is before scaling, it will be very large
             norm = compute_norm(
@@ -542,15 +543,15 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._param_store.clear_grads_of_previous_reduced_params()
 
         # compute norm for gradients in the last bucket
-        total_norms = []
+        total_norms = {}
         for group_id in range(self.num_param_groups):
-            total_norms.append(
-                self._compute_norm_with_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_norm=groups_norms[group_id],
-                )
-            )
+            group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
+            group_name = f"{group_id}_{group_name}"
+            total_norms[group_name] = self._compute_norm_with_stage(
+                group_id=group_id,
+                last_bucket=True,
+                last_stage=True,
+                previous_norm=groups_norms[group_id],
+            )
 
         timer("sync_grad").start()
@@ -569,7 +570,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # found_inf = self._check_overflow()
         # Because you may encounter inf when computing norm
-        if -1 in norms:
+        if -1 in norms.values():
             found_inf = True
 
         loss_scale = float(self.loss_scale.item())  # backup
@@ -617,15 +618,17 @@ class HybridZeroOptimizer(BaseOptimizer):
         # unscale and clip grads
         # get the global norm
-        global_norm_groups = []
+        global_norm_groups = {}
         if self._clip_grad_norm > 0:
-            for norm in norms:
-                global_norm_groups.append(norm**0.5)
+            for group_name, norm in norms.items():
+                global_norm_groups[group_name] = norm**0.5
 
         # the following operations are performed only on the rank to which parameters are assigned.
         if gpc.config.model.dtype is not torch.float32:
-            if len(single_grad_partition_groups) != 0:
-                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)
+            if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
+                self._unscale_and_clip_grads(
+                    single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
+                )
 
         # update the parameters
         timer("step").start()
@@ -652,7 +655,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
-        return True, [global_norm / loss_scale for global_norm in global_norm_groups]
+        for group_name, global_norm in global_norm_groups.items():
+            global_norm_groups[group_name] = global_norm / loss_scale
+        return True, global_norm_groups
 
     def broadcast_params(self):
         handles = []
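
A standalone illustration of the naming scheme introduced above may help: gradient norms are now keyed by "{group_id}_{group_name}" rather than stored positionally, and the overflow check switches from list membership to dict.values(). The parameter groups and norm values below are invented; only the key construction and the -1 check mirror the diff.

    # Invented parameter groups; the second one has no "name" and falls back to "default".
    param_groups = [{"name": "embedding", "params": []}, {"params": []}]

    total_norms = {}
    for group_id, group in enumerate(param_groups):
        group_name = group["name"] if "name" in group else "default"
        group_name = f"{group_id}_{group_name}"
        total_norms[group_name] = 1.0  # placeholder value instead of a computed norm

    print(total_norms)                 # {'0_embedding': 1.0, '1_default': 1.0}
    print(-1 in total_norms.values())  # False; -1 would signal inf/nan in a group's gradients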


@@ -389,23 +389,31 @@ def record_current_batch_training_metrics(
         line = ""
         for key, value in infos.items():
             line += f"{key}={value} "
-            writer.add_scalar(key=key, value=value, step=train_state.step_count)
+            if isinstance(value, dict):
+                writer.add_scalars(key=key, value=value, step=train_state.step_count)
+            else:
+                writer.add_scalar(key=key, value=value, step=train_state.step_count)
 
         if update_panel:
+            # metrics shown with dashboard panels
+            panel_metrics = {
+                "step": batch_count,
+                "lr": lr,
+                "num_consumed_tokens": train_state.num_consumed_tokens,
+                "loss": loss.item(),
+                "flops": tflops,
+                "tgs": tk_per_gpu,
+                "acc": acc_perplex["acc"],
+                "perplexity": acc_perplex["perplexity"],
+                "fwd_bwd_time": fwd_bwd_time,
+            }
+            for norm_key, norm_value in grad_norm.items():
+                panel_metrics[norm_key] = norm_value
+
             logger.info(
-                line,
-                extra={
-                    "step": batch_count,
-                    "lr": lr,
-                    "num_consumed_tokens": train_state.num_consumed_tokens,
-                    "grad_norm": grad_norm,
-                    "loss": loss.item(),
-                    "flops": tflops,
-                    "tgs": tk_per_gpu,
-                    "acc": acc_perplex["acc"],
-                    "perplexity": acc_perplex["perplexity"],
-                    "fwd_bwd_time": fwd_bwd_time,
-                },
+                "{line}",
+                line=line,
+                extra=panel_metrics,
             )
         else:
             logger.info(line)

utils/writer.py

@@ -134,6 +134,14 @@ class Writer:
         except Exception:
             traceback.print_exc()
 
+    def add_scalars(self, key, value, step):
+        try:
+            assert isinstance(value, dict)
+            if self.enable_tb and self.tb_writer is not None:
+                self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step)
+        except Exception:
+            traceback.print_exc()
+
     def add_text(self, key, value, step):
         try:
             if self.enable_tb and self.tb_writer is not None:
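
For reference, torch.utils.tensorboard.SummaryWriter.add_scalars, which the new method delegates to when TensorBoard logging is enabled, fans a dict out into one curve per key under the shared main tag, so related series (for example per-group gradient norms) land on a single chart. A minimal standalone call, with an illustrative log directory and invented values:

    from torch.utils.tensorboard import SummaryWriter

    tb_writer = SummaryWriter(log_dir="./tb_demo")  # illustrative log directory
    # One "grad_norm" chart with a curve per parameter group (values invented).
    tb_writer.add_scalars(
        main_tag="grad_norm",
        tag_scalar_dict={"0_default": 1.37, "1_default": 0.42},
        global_step=100,
    )
    tb_writer.close()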


@@ -6,7 +6,6 @@ import time
 import traceback
 from functools import partial
 
-import numpy as np
 import torch
 import torch.distributed as dist
@@ -236,7 +235,7 @@ def main(args):
                 train_state.step_count += 1
             else:
                 train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
-                if -1 in grad_norm_groups and gpc.is_rank_for_log():  # -1 encodes a specific failure case
+                if -1 in grad_norm_groups.values() and gpc.is_rank_for_log():  # -1 encodes a specific failure case
                     logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                     send_alert_message(
                         address=gpc.config.alert_address,
@@ -257,7 +256,7 @@ def main(args):
                 trainer=trainer,
                 start_time=start_time,
                 loss=loss,
-                grad_norm=np.array(grad_norm_groups),
+                grad_norm=grad_norm_groups,
                 metric=metric,
                 update_panel=uniscale_logger is not None,
             )