mirror of https://github.com/InternLM/InternLM

refactor grad norm profiling (#466)

parent d537e45456
commit debb7e77b9
@@ -611,14 +611,15 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)

         # compute norm for gradients in the before bucket
+        grad_profiling_config = gpc.config.get("grad_profiling", {})
         groups_norms = []
         groups_param_norms = []
         group_param_zero_grad_count = []
         for group_id in range(self.num_param_groups):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
-            if gpc.config.get("grad_norm_profiling", False):
+            if grad_profiling_config.get("grad_norm_profiling", False):
                 groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
-            if gpc.config.get("zero_grad_profiling", False):
+            if grad_profiling_config.get("zero_grad_profiling", False):
                 group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))

         # clear reduced grads
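The hunk above replaces the two standalone `gpc.config` flags with a single lookup, `grad_profiling_config = gpc.config.get("grad_profiling", {})`. A minimal sketch of the matching training-config entry under the new layout, using only the key names visible in this diff (`grad_norm_profiling`, `zero_grad_profiling`, and the `layers` list used in the last hunk); the snippet itself is illustrative and not part of the commit:

```python
# Illustrative config snippet (assumed, not from this commit): the profiling switches
# move from top-level gpc.config keys into one nested "grad_profiling" dict.
grad_profiling = dict(
    grad_norm_profiling=True,   # collect per-layer / per-parameter gradient norms
    zero_grad_profiling=True,   # count parameters whose gradients are exactly zero
    layers=[],                  # optional whitelist of layer names to log; empty logs all layers
)
```

Because every call sites uses `.get(..., False)` or `.get(..., {})`, a config that omits the `grad_profiling` dict entirely keeps both features disabled.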
@@ -645,7 +646,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 last_stage=True,
                 previous_norm=groups_norms[group_id],
             )
-            if gpc.config.get("grad_norm_profiling", False):
+            if grad_profiling_config.get("grad_norm_profiling", False):
                 param_norms = self._compute_param_norm_stage(
                     group_id=group_id,
                     last_bucket=True,
@@ -655,7 +656,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 total_layer_norms[group_name], total_param_norms[group_name] = compute_layer_norm(
                     param_norms=param_norms, loss_scale=self.loss_scale.item()
                 )
-            if gpc.config.get("zero_grad_profiling", False):
+            if grad_profiling_config.get("zero_grad_profiling", False):
                 zero_grad_count = self._count_zero_grads_stage(
                     group_id=group_id,
                     last_bucket=True,
@@ -672,10 +673,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         timer("sync_grad").stop()

         state, global_norms = self._step(closure=closure, norms=total_norms)
-        if gpc.config.get("grad_norm_profiling", False):
+        if grad_profiling_config.get("grad_norm_profiling", False):
             global_norms["layer_norm"] = total_layer_norms
             global_norms["param_norm"] = total_param_norms
-        if gpc.config.get("zero_grad_profiling", False):
+        if grad_profiling_config.get("zero_grad_profiling", False):
             global_norms["layer_zero_grad"] = total_layer_zero_grad_count
             global_norms["param_zero_grad"] = total_param_zero_grad_count

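For orientation, the four keys attached to `global_norms` here are exactly the dicts that `record_current_batch_training_metrics` unpacks in the last hunk below. The nesting sketched next is an inference from those loops (metric, then parameter group, then layer, with an extra per-parameter level for the `param_*` metrics); the group, layer, and parameter names are hypothetical:

```python
# Assumed shape of the profiling entries in `global_norms` (all names are made up):
global_norms["layer_norm"] = {"0_default": {"embedding": 0.12, "blocks.0.mlp": 0.34}}
global_norms["param_norm"] = {"0_default": {"weight": {"blocks.0.mlp": 0.08}}}
global_norms["layer_zero_grad"] = {"0_default": {"blocks.0.mlp": 0}}
global_norms["param_zero_grad"] = {"0_default": {"weight": {"blocks.0.mlp": 0}}}
```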
@@ -158,7 +158,10 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     Returns:
         A tuple of (optimizer, beta2_scheduler, lr_scheduler).
     """
-    if gpc.config.get("grad_norm_profiling", False) or gpc.config.get("zero_grad_profiling", False):
+    grad_profiling_config = gpc.config.get("grad_profiling", {})
+    if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
+        "zero_grad_profiling", False
+    ):
         # set the layer name as an attribute of the model parameters
         set_model_params_layer_name(model)

@@ -526,26 +529,42 @@ def record_current_batch_training_metrics(
         for key, value in acc_perplex.items():
             infos[key] = value

-        if gpc.config.get("grad_norm_profiling", False) or gpc.config.get("zero_grad_profiling", False):
+        grad_profiling_config = gpc.config.get("grad_profiling", {})
+        if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
+            "zero_grad_profiling", False
+        ):
             layer_metrics = ["layer_norm", "layer_zero_grad"]
             param_metrics = ["param_norm", "param_zero_grad"]
+            layer_names = grad_profiling_config.get("layers", [])
             for layer_metric_name in layer_metrics:
                 layer_metric = grad_norm.get(layer_metric_name, {})
                 if layer_metric:
-                    for group_name, value in layer_metric.items():
-                        if value:
+                    for group_name, layer_group in layer_metric.items():
+                        if layer_group:
                             title = f"{layer_metric_name}/{group_name}"
-                            writer.add_scalars(key=title, value=value, step=train_state.step_count)
+                            if layer_names:
+                                filter_layer_metrics = {}
+                                for layer_name, metric_value in layer_group.items():
+                                    if layer_name in layer_names:
+                                        filter_layer_metrics[layer_name] = metric_value
+                                writer.add_scalars(key=title, value=filter_layer_metrics, step=train_state.step_count)
+                            else:
+                                writer.add_scalars(key=title, value=layer_group, step=train_state.step_count)
                     del grad_norm[layer_metric_name]

             for param_metric_name in param_metrics:
                 param_metric = grad_norm.get(param_metric_name, {})
                 if param_metric:
                     for group_name, layer_group in param_metric.items():
-                        if layer_group:
-                            for param_name, param_group in layer_group.items():
-                                title = f"{param_name}/{group_name}_{param_metric_name}"
-                                writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
+                        for param_name, param_group in layer_group.items():
+                            title = f"{param_name}/{group_name}_{param_metric_name}"
+                            if layer_names:
+                                filter_param_group = {}
+                                for layer_name, metric_value in param_group.items():
+                                    if layer_name in layer_names:
+                                        filter_param_group[layer_name] = param_group[layer_name]
+                                writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count)
+                            else:
+                                writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
                     del grad_norm[param_metric_name]

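Both logging branches in the hunk above repeat the same pattern: if the config lists layer names, keep only those entries before calling `writer.add_scalars`, otherwise log the whole dict. A compact sketch of that filtering step in isolation; the helper below is hypothetical and not part of the commit:

```python
# Hypothetical helper equivalent to the inline filtering loops above.
def filter_by_layer_names(metric: dict, layer_names: list) -> dict:
    """Keep only whitelisted layer entries; keep everything if the whitelist is empty."""
    if not layer_names:
        return metric
    return {name: value for name, value in metric.items() if name in layer_names}

# Usage corresponding to the layer-metric branch:
#   writer.add_scalars(key=title, value=filter_by_layer_names(layer_group, layer_names),
#                      step=train_state.step_count)
```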