feat(utils/writer.py): support writer add_scalars for writing dict data (#257)

* feat(utils/writer.py): support writer add_scalars interface for writing dict data

* feat(hybrid_zero_optim.py): change grad_norm_groups list to dict
huangting4201 2023-09-01 13:24:46 +08:00 committed by GitHub
parent c516602e9a
commit b9202b12bc
4 changed files with 52 additions and 32 deletions
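
In short, HybridZeroOptimizer.step() now reports per-group global gradient norms as a dict keyed by parameter-group name instead of a positional list, the metrics Writer gains an add_scalars method for dict-valued entries, and the training metrics/alerting paths are updated to consume the dict. A minimal before/after sketch of the value that flows out of the optimizer (group names and numbers are invented for illustration):

# before this commit: a list with one norm per parameter group, addressed by position
grad_norm_groups = [1.37, 0.42]

# after this commit: a dict keyed by "{group_id}_{group_name}", so each norm
# can be logged and plotted under a stable, readable name
grad_norm_groups = {"0_default": 1.37, "1_embedding": 0.42}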

@@ -497,6 +497,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             grads = [self.padding_grad]
             params = [self.padding_tensor]
 
+        norm = 0
         if self._clip_grad_norm > 0:
             # this norm is before scaling, it will be very large
             norm = compute_norm(
@@ -542,15 +543,15 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._param_store.clear_grads_of_previous_reduced_params()
 
         # compute norm for gradients in the last bucket
-        total_norms = []
+        total_norms = {}
         for group_id in range(self.num_param_groups):
-            total_norms.append(
-                self._compute_norm_with_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_norm=groups_norms[group_id],
-                )
-            )
+            group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
+            group_name = f"{group_id}_{group_name}"
+            total_norms[group_name] = self._compute_norm_with_stage(
+                group_id=group_id,
+                last_bucket=True,
+                last_stage=True,
+                previous_norm=groups_norms[group_id],
+            )
 
         timer("sync_grad").start()
@@ -569,7 +570,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # found_inf = self._check_overflow()
         # Because you may encounter inf when computing norm
-        if -1 in norms:
+        if -1 in norms.values():
             found_inf = True
 
         loss_scale = float(self.loss_scale.item())  # backup
@@ -617,15 +618,17 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # unscale and clip grads
         # get the global norm
-        global_norm_groups = []
+        global_norm_groups = {}
         if self._clip_grad_norm > 0:
-            for norm in norms:
-                global_norm_groups.append(norm**0.5)
+            for group_name, norm in norms.items():
+                global_norm_groups[group_name] = norm**0.5
 
         # the following operations are performed only on the rank to which parameters are assigned.
         if gpc.config.model.dtype is not torch.float32:
-            if len(single_grad_partition_groups) != 0:
-                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)
+            if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
+                self._unscale_and_clip_grads(
+                    single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
+                )
 
         # update the parameters
         timer("step").start()
@@ -652,7 +655,9 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
-        return True, [global_norm / loss_scale for global_norm in global_norm_groups]
+        for group_name, global_norm in global_norm_groups.items():
+            global_norm_groups[group_name] = global_norm / loss_scale
+        return True, global_norm_groups
 
     def broadcast_params(self):
         handles = []
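
The optimizer-side change is essentially a container swap: each group's norm is stored under a "{group_id}_{group_name}" key, and membership tests must go through .values(), since `-1 in norms` on a dict would test the keys rather than the norms. A small standalone sketch of the keying scheme (the param_groups contents below are made up, and the norm computation is replaced by a placeholder):

# Hypothetical parameter groups; only the optional "name" entry matters for the key.
param_groups = [{"name": "embedding", "params": []}, {"params": []}]

total_norms = {}
for group_id, group in enumerate(param_groups):
    group_name = group["name"] if "name" in group else "default"
    group_name = f"{group_id}_{group_name}"
    total_norms[group_name] = 0.0  # stand-in for _compute_norm_with_stage(...)

assert list(total_norms) == ["0_embedding", "1_default"]

# The overflow sentinel (-1) lives in the values, hence the .values() checks above.
found_inf = -1 in total_norms.values()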

@@ -389,23 +389,31 @@ def record_current_batch_training_metrics(
         line = ""
         for key, value in infos.items():
             line += f"{key}={value} "
-            writer.add_scalar(key=key, value=value, step=train_state.step_count)
+            if isinstance(value, dict):
+                writer.add_scalars(key=key, value=value, step=train_state.step_count)
+            else:
+                writer.add_scalar(key=key, value=value, step=train_state.step_count)
 
         if update_panel:
+            # metrics shown with dashboard panels
+            panel_metrics = {
+                "step": batch_count,
+                "lr": lr,
+                "num_consumed_tokens": train_state.num_consumed_tokens,
+                "loss": loss.item(),
+                "flops": tflops,
+                "tgs": tk_per_gpu,
+                "acc": acc_perplex["acc"],
+                "perplexity": acc_perplex["perplexity"],
+                "fwd_bwd_time": fwd_bwd_time,
+            }
+            for norm_key, norm_value in grad_norm.items():
+                panel_metrics[norm_key] = norm_value
+
             logger.info(
-                line,
-                extra={
-                    "step": batch_count,
-                    "lr": lr,
-                    "num_consumed_tokens": train_state.num_consumed_tokens,
-                    "grad_norm": grad_norm,
-                    "loss": loss.item(),
-                    "flops": tflops,
-                    "tgs": tk_per_gpu,
-                    "acc": acc_perplex["acc"],
-                    "perplexity": acc_perplex["perplexity"],
-                    "fwd_bwd_time": fwd_bwd_time,
-                },
+                "{line}",
+                line=line,
+                extra=panel_metrics,
             )
         else:
             logger.info(line)
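
Since grad_norm is now a dict, it is no longer forwarded as a single "grad_norm" entry; dict-valued infos go to writer.add_scalars, and each per-group norm is flattened into the dashboard panel under its own key. A self-contained sketch of that dispatch and flattening (the infos contents and the print stand-ins for the writer calls are illustrative only):

infos = {"loss": 2.15, "grad_norm": {"0_default": 1.37, "1_embedding": 0.42}}

line = ""
for key, value in infos.items():
    line += f"{key}={value} "
    if isinstance(value, dict):
        print(f"add_scalars({key}, {value})")  # would call writer.add_scalars(...)
    else:
        print(f"add_scalar({key}, {value})")  # would call writer.add_scalar(...)

# For the dashboard panel, each per-group norm becomes its own flat metric:
panel_metrics = {"step": 100, "lr": 3e-4}
for norm_key, norm_value in infos["grad_norm"].items():
    panel_metrics[norm_key] = norm_value
# panel_metrics == {"step": 100, "lr": 3e-4, "0_default": 1.37, "1_embedding": 0.42}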

@@ -134,6 +134,14 @@ class Writer:
         except Exception:
             traceback.print_exc()
 
+    def add_scalars(self, key, value, step):
+        try:
+            assert isinstance(value, dict)
+            if self.enable_tb and self.tb_writer is not None:
+                self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step)
+        except Exception:
+            traceback.print_exc()
+
     def add_text(self, key, value, step):
         try:
             if self.enable_tb and self.tb_writer is not None:
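
The new Writer.add_scalars forwards to TensorBoard's SummaryWriter.add_scalars, which draws one curve per dict key under a shared parent tag. A self-contained sketch of the equivalent direct call (the log directory and values are hypothetical):

from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir="runs/demo")  # hypothetical log dir

# Equivalent of writer.add_scalars(key="grad_norm", value=..., step=100):
# one "grad_norm" group in TensorBoard with a curve per parameter group.
grad_norm_groups = {"0_default": 1.37, "1_embedding": 0.42}
tb_writer.add_scalars(main_tag="grad_norm", tag_scalar_dict=grad_norm_groups, global_step=100)
tb_writer.close()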

@@ -6,7 +6,6 @@ import time
 import traceback
 from functools import partial
 
-import numpy as np
 import torch
 import torch.distributed as dist
 
@@ -236,7 +235,7 @@ def main(args):
             train_state.step_count += 1
         else:
             train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
-            if -1 in grad_norm_groups and gpc.is_rank_for_log():  # -1 encodes a specific failure case
+            if -1 in grad_norm_groups.values() and gpc.is_rank_for_log():  # -1 encodes a specific failure case
                 logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                 send_alert_message(
                     address=gpc.config.alert_address,
@@ -257,7 +256,7 @@ def main(args):
             trainer=trainer,
             start_time=start_time,
             loss=loss,
-            grad_norm=np.array(grad_norm_groups),
+            grad_norm=grad_norm_groups,
             metric=metric,
             update_panel=uniscale_logger is not None,
         )