mirror of https://github.com/InternLM/InternLM
feat(utils/writer.py): support writer add_scalars for writing dict data (#257)
* feat(utils/writer.py): support writer add_scalars interface for writing dict data * feat(hybrid_zero_optim.py): change grad_norm_groups list to dictpull/123/head
parent
c516602e9a
commit
b9202b12bc
|
@ -497,6 +497,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
grads = [self.padding_grad]
|
||||
params = [self.padding_tensor]
|
||||
|
||||
norm = 0
|
||||
if self._clip_grad_norm > 0:
|
||||
# this norm is before scaling, it will be very large
|
||||
norm = compute_norm(
|
||||
|
@ -542,15 +543,15 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
|
||||
# compute norm for gradients in the last bucket
|
||||
total_norms = []
|
||||
total_norms = {}
|
||||
for group_id in range(self.num_param_groups):
|
||||
total_norms.append(
|
||||
self._compute_norm_with_stage(
|
||||
group_id=group_id,
|
||||
last_bucket=True,
|
||||
last_stage=True,
|
||||
previous_norm=groups_norms[group_id],
|
||||
)
|
||||
group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
|
||||
group_name = f"{group_id}_{group_name}"
|
||||
total_norms[group_name] = self._compute_norm_with_stage(
|
||||
group_id=group_id,
|
||||
last_bucket=True,
|
||||
last_stage=True,
|
||||
previous_norm=groups_norms[group_id],
|
||||
)
|
||||
|
||||
timer("sync_grad").start()
|
||||
|
@ -569,7 +570,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# found_inf = self._check_overflow()
|
||||
# Because you may encounter inf when computing norm
|
||||
|
||||
if -1 in norms:
|
||||
if -1 in norms.values():
|
||||
found_inf = True
|
||||
|
||||
loss_scale = float(self.loss_scale.item()) # backup
|
||||
|
@ -617,15 +618,17 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
# unscale and clip grads
|
||||
# get the global norm
|
||||
global_norm_groups = []
|
||||
global_norm_groups = {}
|
||||
if self._clip_grad_norm > 0:
|
||||
for norm in norms:
|
||||
global_norm_groups.append(norm**0.5)
|
||||
for group_name, norm in norms.items():
|
||||
global_norm_groups[group_name] = norm**0.5
|
||||
|
||||
# the following operations are performed only on the rank to which parameters are assigned.
|
||||
if gpc.config.model.dtype is not torch.float32:
|
||||
if len(single_grad_partition_groups) != 0:
|
||||
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)
|
||||
if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
|
||||
self._unscale_and_clip_grads(
|
||||
single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
|
||||
)
|
||||
|
||||
# update the parameters
|
||||
timer("step").start()
|
||||
|
@ -652,7 +655,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
# update gradients may not be needed here, because the sync_params function is used in initialization,
|
||||
# so synchronization is maintained
|
||||
return True, [global_norm / loss_scale for global_norm in global_norm_groups]
|
||||
for group_name, global_norm in global_norm_groups.items():
|
||||
global_norm_groups[group_name] = global_norm / loss_scale
|
||||
return True, global_norm_groups
|
||||
|
||||
def broadcast_params(self):
|
||||
handles = []
|
||||
|
|
|
@ -389,23 +389,31 @@ def record_current_batch_training_metrics(
|
|||
line = ""
|
||||
for key, value in infos.items():
|
||||
line += f"{key}={value} "
|
||||
writer.add_scalar(key=key, value=value, step=train_state.step_count)
|
||||
if isinstance(value, dict):
|
||||
writer.add_scalars(key=key, value=value, step=train_state.step_count)
|
||||
else:
|
||||
writer.add_scalar(key=key, value=value, step=train_state.step_count)
|
||||
|
||||
if update_panel:
|
||||
# metrics shown with dashboard panels
|
||||
panel_metrics = {
|
||||
"step": batch_count,
|
||||
"lr": lr,
|
||||
"num_consumed_tokens": train_state.num_consumed_tokens,
|
||||
"loss": loss.item(),
|
||||
"flops": tflops,
|
||||
"tgs": tk_per_gpu,
|
||||
"acc": acc_perplex["acc"],
|
||||
"perplexity": acc_perplex["perplexity"],
|
||||
"fwd_bwd_time": fwd_bwd_time,
|
||||
}
|
||||
for norm_key, norm_value in grad_norm.items():
|
||||
panel_metrics[norm_key] = norm_value
|
||||
|
||||
logger.info(
|
||||
line,
|
||||
extra={
|
||||
"step": batch_count,
|
||||
"lr": lr,
|
||||
"num_consumed_tokens": train_state.num_consumed_tokens,
|
||||
"grad_norm": grad_norm,
|
||||
"loss": loss.item(),
|
||||
"flops": tflops,
|
||||
"tgs": tk_per_gpu,
|
||||
"acc": acc_perplex["acc"],
|
||||
"perplexity": acc_perplex["perplexity"],
|
||||
"fwd_bwd_time": fwd_bwd_time,
|
||||
},
|
||||
"{line}",
|
||||
line=line,
|
||||
extra=panel_metrics,
|
||||
)
|
||||
else:
|
||||
logger.info(line)
|
||||
|
|
|
@ -134,6 +134,14 @@ class Writer:
|
|||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
def add_scalars(self, key, value, step):
|
||||
try:
|
||||
assert isinstance(value, dict)
|
||||
if self.enable_tb and self.tb_writer is not None:
|
||||
self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
def add_text(self, key, value, step):
|
||||
try:
|
||||
if self.enable_tb and self.tb_writer is not None:
|
||||
|
|
5
train.py
5
train.py
|
@ -6,7 +6,6 @@ import time
|
|||
import traceback
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
@ -236,7 +235,7 @@ def main(args):
|
|||
train_state.step_count += 1
|
||||
else:
|
||||
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
|
||||
if -1 in grad_norm_groups and gpc.is_rank_for_log(): # -1 encodes a specific failure case
|
||||
if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address,
|
||||
|
@ -257,7 +256,7 @@ def main(args):
|
|||
trainer=trainer,
|
||||
start_time=start_time,
|
||||
loss=loss,
|
||||
grad_norm=np.array(grad_norm_groups),
|
||||
grad_norm=grad_norm_groups,
|
||||
metric=metric,
|
||||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue