mirror of https://github.com/InternLM/InternLM
feat(utils/writer.py): support writer add_scalars for writing dict data (#257)
* feat(utils/writer.py): support writer add_scalars interface for writing dict data
* feat(hybrid_zero_optim.py): change grad_norm_groups list to dict
parent c516602e9a
commit b9202b12bc
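The sketch below is illustrative only (group names and values are invented); it summarizes the data shape this commit introduces, where per-group grad norms flow as a dict from the optimizer into the writer.

    # Per-group grad norms are now keyed by "<group_id>_<group_name>" instead of
    # being positional list entries (keys/values here are hypothetical).
    grad_norm_groups = {"0_default": 1.92, "1_fp32": 0.37}

    # The metrics-recording code can then hand the whole dict to the writer,
    # which dispatches dict values to add_scalars and plain scalars to add_scalar.
    infos = {"grad_norm": grad_norm_groups, "loss": 4.21}
    for key, value in infos.items():
        target = "add_scalars" if isinstance(value, dict) else "add_scalar"
        print(f"{key} -> Writer.{target}")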
hybrid_zero_optim.py

@@ -497,6 +497,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 grads = [self.padding_grad]
                 params = [self.padding_tensor]

+        norm = 0
         if self._clip_grad_norm > 0:
             # this norm is before scaling, it will be very large
             norm = compute_norm(

@@ -542,15 +543,15 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._param_store.clear_grads_of_previous_reduced_params()

         # compute norm for gradients in the last bucket
-        total_norms = []
+        total_norms = {}
         for group_id in range(self.num_param_groups):
-            total_norms.append(
-                self._compute_norm_with_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_norm=groups_norms[group_id],
-                )
-            )
+            group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
+            group_name = f"{group_id}_{group_name}"
+            total_norms[group_name] = self._compute_norm_with_stage(
+                group_id=group_id,
+                last_bucket=True,
+                last_stage=True,
+                previous_norm=groups_norms[group_id],
+            )

         timer("sync_grad").start()

@@ -569,7 +570,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # found_inf = self._check_overflow()
         # Because you may encounter inf when computing norm

-        if -1 in norms:
+        if -1 in norms.values():
             found_inf = True

         loss_scale = float(self.loss_scale.item())  # backup

@@ -617,15 +618,17 @@ class HybridZeroOptimizer(BaseOptimizer):

         # unscale and clip grads
         # get the global norm
-        global_norm_groups = []
+        global_norm_groups = {}
         if self._clip_grad_norm > 0:
-            for norm in norms:
-                global_norm_groups.append(norm**0.5)
+            for group_name, norm in norms.items():
+                global_norm_groups[group_name] = norm**0.5

         # the following operations are performed only on the rank to which parameters are assigned.
         if gpc.config.model.dtype is not torch.float32:
-            if len(single_grad_partition_groups) != 0:
-                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)
+            if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
+                self._unscale_and_clip_grads(
+                    single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
+                )

         # update the parameters
         timer("step").start()

@@ -652,7 +655,9 @@ class HybridZeroOptimizer(BaseOptimizer):

         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
-        return True, [global_norm / loss_scale for global_norm in global_norm_groups]
+        for group_name, global_norm in global_norm_groups.items():
+            global_norm_groups[group_name] = global_norm / loss_scale
+        return True, global_norm_groups

     def broadcast_params(self):
         handles = []
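A minimal sketch of how the dict keys are built in the change above, assuming a hypothetical param_groups list; the real per-group norm computation is elided and replaced with a placeholder value.

    # Hypothetical param groups; groups without a "name" fall back to "default".
    param_groups = [{"name": "default"}, {"name": "fp32"}, {}]

    total_norms = {}
    for group_id, group in enumerate(param_groups):
        group_name = group["name"] if "name" in group else "default"
        group_name = f"{group_id}_{group_name}"
        total_norms[group_name] = 0.0  # placeholder for the norm computed per group
    print(total_norms)  # {'0_default': 0.0, '1_fp32': 0.0, '2_default': 0.0}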
@@ -389,23 +389,31 @@ def record_current_batch_training_metrics(
         line = ""
         for key, value in infos.items():
             line += f"{key}={value} "
-            writer.add_scalar(key=key, value=value, step=train_state.step_count)
+            if isinstance(value, dict):
+                writer.add_scalars(key=key, value=value, step=train_state.step_count)
+            else:
+                writer.add_scalar(key=key, value=value, step=train_state.step_count)

         if update_panel:
+            # metrics shown with dashboard panels
+            panel_metrics = {
+                "step": batch_count,
+                "lr": lr,
+                "num_consumed_tokens": train_state.num_consumed_tokens,
+                "loss": loss.item(),
+                "flops": tflops,
+                "tgs": tk_per_gpu,
+                "acc": acc_perplex["acc"],
+                "perplexity": acc_perplex["perplexity"],
+                "fwd_bwd_time": fwd_bwd_time,
+            }
+            for norm_key, norm_value in grad_norm.items():
+                panel_metrics[norm_key] = norm_value
+
             logger.info(
-                line,
-                extra={
-                    "step": batch_count,
-                    "lr": lr,
-                    "num_consumed_tokens": train_state.num_consumed_tokens,
-                    "grad_norm": grad_norm,
-                    "loss": loss.item(),
-                    "flops": tflops,
-                    "tgs": tk_per_gpu,
-                    "acc": acc_perplex["acc"],
-                    "perplexity": acc_perplex["perplexity"],
-                    "fwd_bwd_time": fwd_bwd_time,
-                },
+                "{line}",
+                line=line,
+                extra=panel_metrics,
             )
         else:
             logger.info(line)
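A small sketch (values invented) of how the panel metrics now carry one entry per norm group instead of a single aggregated grad_norm field:

    grad_norm = {"0_default": 1.92, "1_fp32": 0.37}  # illustrative per-group norms
    panel_metrics = {"step": 100, "loss": 4.21}      # subset of the real panel fields
    for norm_key, norm_value in grad_norm.items():
        panel_metrics[norm_key] = norm_value
    print(panel_metrics)
    # {'step': 100, 'loss': 4.21, '0_default': 1.92, '1_fp32': 0.37}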
utils/writer.py

@@ -134,6 +134,14 @@ class Writer:
         except Exception:
             traceback.print_exc()

+    def add_scalars(self, key, value, step):
+        try:
+            assert isinstance(value, dict)
+            if self.enable_tb and self.tb_writer is not None:
+                self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step)
+        except Exception:
+            traceback.print_exc()
+
     def add_text(self, key, value, step):
         try:
             if self.enable_tb and self.tb_writer is not None:
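For reference, a self-contained sketch of the underlying TensorBoard call that Writer.add_scalars wraps; the log directory, tag, and values are placeholders, not taken from this commit.

    from torch.utils.tensorboard import SummaryWriter

    tb_writer = SummaryWriter(log_dir="/tmp/tb_demo")  # hypothetical log dir
    tb_writer.add_scalars(
        main_tag="grad_norm",                                  # one chart group per key
        tag_scalar_dict={"0_default": 1.92, "1_fp32": 0.37},   # per-group norms (made up)
        global_step=100,
    )
    tb_writer.close()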
train.py (5 changed lines)
@@ -6,7 +6,6 @@ import time
 import traceback
 from functools import partial

-import numpy as np
 import torch
 import torch.distributed as dist

@@ -236,7 +235,7 @@ def main(args):
             train_state.step_count += 1
         else:
             train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
-            if -1 in grad_norm_groups and gpc.is_rank_for_log():  # -1 encodes a specific failure case
+            if -1 in grad_norm_groups.values() and gpc.is_rank_for_log():  # -1 encodes a specific failure case
                 logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                 send_alert_message(
                     address=gpc.config.alert_address,

@@ -257,7 +256,7 @@ def main(args):
             trainer=trainer,
             start_time=start_time,
             loss=loss,
-            grad_norm=np.array(grad_norm_groups),
+            grad_norm=grad_norm_groups,
             metric=metric,
             update_panel=uniscale_logger is not None,
         )
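One subtlety behind the .values() change above, shown with invented numbers: membership testing on a dict checks its keys, so the -1 sentinel has to be looked up among the values.

    grad_norm_groups = {"0_default": -1, "1_fp32": 0.37}  # -1 marks a skipped/overflowed group
    print(-1 in grad_norm_groups)           # False: this only checks the keys
    print(-1 in grad_norm_groups.values())  # True: this checks the norms themselves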