add zero_grad_profiling option

pull/449/head
JiaoPL 2023-10-26 17:20:44 +08:00
parent a6051335b7
commit 83cb7036a7
3 changed files with 70 additions and 59 deletions
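Note: the new flag is read with gpc.config.get("zero_grad_profiling", False), mirroring the existing grad_norm_profiling switch. Assuming the same config convention, enabling it is a one-line addition to the training config (a sketch, not part of this diff):

    grad_norm_profiling = True   # existing: per-layer / per-parameter gradient norms
    zero_grad_profiling = True   # new: per-layer / per-parameter zero-gradient ratios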

View File

@@ -605,6 +605,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
             if gpc.config.get("grad_norm_profiling", False):
                 groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
+            if gpc.config.get("zero_grad_profiling", False):
                 group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))

         # clear reduced grads
@@ -641,6 +642,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 total_layer_norms[group_name], total_param_norms[group_name] = compute_layer_norm(
                     param_norms=param_norms, loss_scale=self.loss_scale.item()
                 )
+            if gpc.config.get("zero_grad_profiling", False):
                 zero_grad_count = self._count_zero_grads_stage(
                     group_id=group_id,
                     last_bucket=True,
@@ -668,8 +670,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         state, global_norms = self._step(closure=closure, norms=total_norms)
         if gpc.config.get("grad_norm_profiling", False):
-            global_norms["layer_norms"] = total_layer_norms
-            global_norms["param_norms"] = total_param_norms
+            global_norms["layer_norm"] = total_layer_norms
+            global_norms["param_norm"] = total_param_norms
+        if gpc.config.get("zero_grad_profiling", False):
             global_norms["layer_zero_grad"] = total_layer_zero_grad_count
             global_norms["param_zero_grad"] = total_param_zero_grad_count

View File

@@ -211,9 +211,11 @@ def calc_lp(grads, norm_type):

 def calc_zero_grad(grads):
     zero_count = 0
+    grad_size = 0
     for grad in grads:
         zero_count += (grad == 0).sum().item()
-    return zero_count
+        grad_size += grad.numel()
+    return torch.tensor([zero_count, grad_size])


 def reduce_grads(gradients, parameters, fine_grained=False):
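For reference, a standalone, CPU-runnable sketch of the updated helper shows the new return value: a two-element tensor of [zero_count, total_elements] instead of a bare count (the sample gradients below are made up):

    import torch

    def calc_zero_grad(grads):
        # count exactly-zero gradient elements and the total element count across shards
        zero_count = 0
        grad_size = 0
        for grad in grads:
            zero_count += (grad == 0).sum().item()
            grad_size += grad.numel()
        return torch.tensor([zero_count, grad_size])

    grads = [torch.tensor([0.0, 1.5, 0.0]), torch.zeros(4)]
    print(calc_zero_grad(grads))  # tensor([6, 7]) -> 6 zero elements out of 7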
@@ -370,12 +372,12 @@ def compute_param_metric(
     for param_name, grads in param_grads.items():
         if metric_type == "norm":
             if norm_type == inf:
-                param_norm = max(g.data.abs().max() for g in grads)
+                param_metric = max(g.data.abs().max() for g in grads)
             elif norm_type == 2.0 and enable_cuda_kernels:
-                param_norm = calc_l2_norm(grads) ** norm_type
+                param_metric = calc_l2_norm(grads) ** norm_type
             else:
-                param_norm = calc_lp(grads, norm_type)
+                param_metric = calc_lp(grads, norm_type)
-            param_metrics[param_name] = param_norm.item() if torch.is_tensor(param_norm) else param_norm
+            param_metrics[param_name] = param_metric.item() if torch.is_tensor(param_metric) else param_metric
         elif metric_type == "zero_grad":
             param_zero_grad_count = calc_zero_grad(grads)
             param_metrics[param_name] = param_zero_grad_count
@@ -396,45 +398,59 @@ def compute_param_metric(
     # model parallel
     model_parallel_param_metrics = {}
     if gpc.is_initialized(ParallelMode.MODEL):
-        parallel_param_norms = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))]
-        dist.all_gather_object(parallel_param_norms, param_metrics, group=gpc.get_group(ParallelMode.MODEL))
-        for local_param_norm in parallel_param_norms:
-            for param_name, param_norm in local_param_norm.items():
+        parallel_param_metrics = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))]
+        dist.all_gather_object(parallel_param_metrics, param_metrics, group=gpc.get_group(ParallelMode.MODEL))
+        for local_param_metric in parallel_param_metrics:
+            for param_name, param_metric in local_param_metric.items():
                 if param_name not in model_parallel_param_metrics:
                     model_parallel_param_metrics[param_name] = 0.0
                 if metric_type == "norm" and norm_type == inf:
-                    model_parallel_param_metrics[param_name] = max(model_parallel_param_metrics[param_name], param_norm)
+                    model_parallel_param_metrics[param_name] = max(
+                        model_parallel_param_metrics[param_name], param_metric
+                    )
                 else:
-                    model_parallel_param_metrics[param_name] += param_norm
+                    model_parallel_param_metrics[param_name] += param_metric

     # zero parallel
     zero_param_metrics = [None for _ in range(gpc.get_world_size(zero_mode))]
     dist.all_gather_object(zero_param_metrics, model_parallel_param_metrics, group=gpc.get_group(zero_mode))
-    for local_param_norm in zero_param_metrics:
-        for param_name, param_norm in local_param_norm.items():
+    for local_param_metric in zero_param_metrics:
+        for param_name, param_metric in local_param_metric.items():
             if param_name not in total_metrics:
                 total_metrics[param_name] = 0.0
             if metric_type == "norm" and norm_type == inf:
-                total_metrics[param_name] = max(total_metrics[param_name], param_norm)
+                total_metrics[param_name] = max(total_metrics[param_name], param_metric)
             else:
-                total_metrics[param_name] += param_norm
+                total_metrics[param_name] += param_metric

     # moe
     if is_moe_group:
         pg = gpc.get_group(ParallelMode.EXPERT)
-        scaled_param_metric = torch.cuda.FloatTensor(list(total_metrics.values()), device=get_current_device())
+        total_metric_values = list(total_metrics.values())
+        if isinstance(total_metric_values[0], torch.Tensor):
+            scaled_param_metric = torch.stack(total_metric_values).to(device=get_current_device())
+        else:
+            scaled_param_metric = torch.cuda.FloatTensor(total_metric_values, device=get_current_device())
         scaled_param_metric = scaled_param_metric / float(gpc.get_world_size(ParallelMode.EXPERT))
         dist.all_reduce(scaled_param_metric, group=pg)
         for i, param_name in enumerate(total_metrics.keys()):
-            total_metrics[param_name] = scaled_param_metric[i].item()
+            total_metrics[param_name] = scaled_param_metric[i]

+    # calc zero grad percent
+    if metric_type == "zero_grad":
+        for param_name, param_metric in total_metrics.items():
+            total_metrics[param_name] = (param_metric[0] / param_metric[1]).item()
+
     # scale norm
     if metric_type == "norm":
-        for param_name, param_norm in total_metrics.items():
-            if param_norm in (inf, -inf):
+        for param_name, param_metric in total_metrics.items():
+            metric_value = param_metric.item()
+            if metric_value in (inf, -inf):
                 total_metrics[param_name] = -1
-            elif math.isnan(param_norm):
+            elif math.isnan(metric_value):
                 total_metrics[param_name] = -2
+            else:
+                total_metrics[param_name] = metric_value

     return total_metrics
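In the zero_grad path, every value in total_metrics is such a [zero_count, grad_size] tensor once the all-gathers are done, and the new block converts it into the fraction of exactly-zero gradient elements. A minimal illustration with made-up parameter names:

    import torch

    total_metrics = {
        "layer_0-weight": torch.tensor([6, 7]),
        "layer_1-bias": torch.tensor([0, 4096]),
    }
    for param_name, param_metric in total_metrics.items():
        total_metrics[param_name] = (param_metric[0] / param_metric[1]).item()
    print(total_metrics)  # {'layer_0-weight': 0.8571..., 'layer_1-bias': 0.0}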
@@ -508,15 +524,15 @@ def compute_layer_norm(param_norms, loss_scale):
     for param_name, param_norm in param_norms.items():
         layer_name, param_key = param_name.split("-")
-        if layer_name not in param_norms_groupby_layer:
-            param_norms_groupby_layer[layer_name] = {}
+        if param_key not in param_norms_groupby_layer:
+            param_norms_groupby_layer[param_key] = {}
         if layer_name not in layer_norms:
             layer_norms[layer_name] = 0.0

         if param_norm not in (-1, -2):
             param_norm = param_norm**0.5 / loss_scale
-        param_norms_groupby_layer[layer_name][param_key] = param_norm
+        param_norms_groupby_layer[param_key][layer_name] = param_norm
         layer_norms[layer_name] += param_norm

     return layer_norms, param_norms_groupby_layer
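The change above swaps the two grouping levels, so the per-parameter dict is now keyed by parameter key first and layer name second (compute_layer_zero_grad_count below gets the identical treatment). A small sketch of the regrouping, assuming the "{layer_name}-{param_key}" naming convention used by the split:

    # hypothetical flat input, keys follow the "{layer_name}-{param_key}" convention
    param_norms = {"layer_0-weight": 0.21, "layer_1-weight": 0.18, "layer_0-bias": 0.01}

    regrouped = {}
    for param_name, param_norm in param_norms.items():
        layer_name, param_key = param_name.split("-")
        regrouped.setdefault(param_key, {})[layer_name] = param_norm
    print(regrouped)  # {'weight': {'layer_0': 0.21, 'layer_1': 0.18}, 'bias': {'layer_0': 0.01}}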
@@ -528,12 +544,12 @@ def compute_layer_zero_grad_count(param_zero_grad_count):
     for param_name, zero_grad_count in param_zero_grad_count.items():
         layer_name, param_key = param_name.split("-")
-        if layer_name not in param_zero_grad_count_groupby_layer:
-            param_zero_grad_count_groupby_layer[layer_name] = {}
+        if param_key not in param_zero_grad_count_groupby_layer:
+            param_zero_grad_count_groupby_layer[param_key] = {}
         if layer_name not in layer_zero_grad_count:
             layer_zero_grad_count[layer_name] = 0.0

-        param_zero_grad_count_groupby_layer[layer_name][param_key] = zero_grad_count
+        param_zero_grad_count_groupby_layer[param_key][layer_name] = zero_grad_count
         layer_zero_grad_count[layer_name] += zero_grad_count

     return layer_zero_grad_count, param_zero_grad_count_groupby_layer

View File

@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-import copy
 import functools
 import time
 from functools import partial
@@ -159,7 +158,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     Returns:
         A tuple of (optimizer, beta2_scheduler, lr_scheduler).
     """
-    if gpc.config.get("grad_norm_profiling", False):
+    if gpc.config.get("grad_norm_profiling", False) or gpc.config.get("zero_grad_profiling", False):
         # set the layer name as an attribute of the model parameters
         set_model_params_layer_name(model)
@@ -527,35 +526,28 @@ def record_current_batch_training_metrics(
         for key, value in acc_perplex.items():
             infos[key] = value

-        if gpc.config.get("grad_norm_profiling", False):
-            layer_norms = copy.deepcopy(grad_norm["layer_norms"])
-            param_norms = copy.deepcopy(grad_norm["param_norms"])
-            layer_zero_grad_count = copy.deepcopy(grad_norm["layer_zero_grad"])
-            param_zero_grad_count = copy.deepcopy(grad_norm["param_zero_grad"])
-            for group_name, value in layer_norms.items():
-                if value:
-                    title = f"laye_norm/{group_name}"
-                    writer.add_scalars(key=title, value=value, step=train_state.step_count)
-            for group_name, layer_group in param_norms.items():
-                if layer_group:
-                    for layer_name, param_group in layer_group.items():
-                        for param_name, param_value in param_group.items():
-                            title = f"param_norm/{group_name}/{layer_name}/{param_name}"
-                            writer.add_scalar(key=title, value=param_value, step=train_state.step_count)
-            for group_name, value in layer_zero_grad_count.items():
-                if value:
-                    title = f"laye_zero_grad/{group_name}"
-                    writer.add_scalars(key=title, value=value, step=train_state.step_count)
-            for group_name, layer_group in param_zero_grad_count.items():
-                if layer_group:
-                    for layer_name, param_group in layer_group.items():
-                        for param_name, param_value in param_group.items():
-                            title = f"param_zero_grad/{group_name}/{layer_name}/{param_name}"
-                            writer.add_scalar(key=title, value=param_value, step=train_state.step_count)
-            del grad_norm["layer_norms"]
-            del grad_norm["param_norms"]
-            del grad_norm["layer_zero_grad"]
-            del grad_norm["param_zero_grad"]
+        if gpc.config.get("grad_norm_profiling", False) or gpc.config.get("zero_grad_profiling", False):
+            layer_metrics = ["layer_norm", "layer_zero_grad"]
+            param_metrics = ["param_norm", "param_zero_grad"]
+
+            for layer_metric_name in layer_metrics:
+                layer_metric = grad_norm.get(layer_metric_name, {})
+                if layer_metric:
+                    for group_name, value in layer_metric.items():
+                        if value:
+                            title = f"{layer_metric_name}/{group_name}"
+                            writer.add_scalars(key=title, value=value, step=train_state.step_count)
+                    del grad_norm[layer_metric_name]
+
+            for param_metric_name in param_metrics:
+                param_metric = grad_norm.get(param_metric_name, {})
+                if param_metric:
+                    for group_name, layer_group in param_metric.items():
+                        if layer_group:
+                            for param_name, param_group in layer_group.items():
+                                title = f"{param_name}/{group_name}_{param_metric_name}"
+                                writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
+                    del grad_norm[param_metric_name]

         line = ""
         for key, value in infos.items():
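With the refactored loops, one scalar group is written per metric and parameter group rather than one scalar per parameter, and the tags follow the two patterns built above. A rough sketch of the resulting tag names, assuming a single parameter group named "default" and made-up layer/parameter names and values:

    # hypothetical grad_norm contents, mirroring the structure produced by the optimizer step
    grad_norm = {
        "layer_norm": {"default": {"layer_0": 0.31, "layer_1": 0.27}},
        "param_norm": {"default": {"weight": {"layer_0": 0.21, "layer_1": 0.18}}},
    }
    for metric in ["layer_norm"]:
        for group, value in grad_norm[metric].items():
            print(f"{metric}/{group}", "->", value)            # layer_norm/default -> {...}
    for metric in ["param_norm"]:
        for group, layer_group in grad_norm[metric].items():
            for key, curves in layer_group.items():
                print(f"{key}/{group}_{metric}", "->", curves)  # weight/default_param_norm -> {...}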