mirror of https://github.com/InternLM/InternLM
add zero grad count
parent
949a0a1d55
commit
e900a1e45f
|
@ -34,7 +34,13 @@ from internlm.utils.megatron_timers import megatron_timer as timer
|
||||||
from internlm.utils.timeout import llm_timeout
|
from internlm.utils.timeout import llm_timeout
|
||||||
|
|
||||||
from .base_optimizer import BaseOptimizer
|
from .base_optimizer import BaseOptimizer
|
||||||
from .utils import compute_layer_norm, compute_norm, compute_param_norm
|
from .utils import (
|
||||||
|
compute_layer_norm,
|
||||||
|
compute_layer_zero_grad_count,
|
||||||
|
compute_norm,
|
||||||
|
compute_param_norm,
|
||||||
|
compute_zero_grad_count,
|
||||||
|
)
|
||||||
|
|
||||||
inf = math.inf
|
inf = math.inf
|
||||||
logger = get_logger(__file__)
|
logger = get_logger(__file__)
|
||||||
|
@ -543,6 +549,29 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
)
|
)
|
||||||
return total_param_norms
|
return total_param_norms
|
||||||
|
|
||||||
|
def _count_zero_grads_stage(
|
||||||
|
self, group_id: int = 0, last_bucket: bool = False, last_stage: bool = False, previous_zero_grad_count=None
|
||||||
|
):
|
||||||
|
params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
|
||||||
|
|
||||||
|
total_zero_grad_count = {}
|
||||||
|
|
||||||
|
if len(params) == 0:
|
||||||
|
dtype = self.param_groups[group_id]["dtype"]
|
||||||
|
grads = [self.padding_grad.to(dtype)]
|
||||||
|
params = [self.padding_tensor.to(dtype)]
|
||||||
|
|
||||||
|
if self._clip_grad_norm > 0:
|
||||||
|
total_zero_grad_count = compute_zero_grad_count(
|
||||||
|
grads,
|
||||||
|
params,
|
||||||
|
last_stage=last_stage,
|
||||||
|
previous_zero_grad_count=previous_zero_grad_count,
|
||||||
|
zero_mode=self._broadcast_parallel_mode[group_id],
|
||||||
|
is_moe_group=self._is_moe_group(self.optim.param_groups[group_id]),
|
||||||
|
)
|
||||||
|
return total_zero_grad_count
|
||||||
|
|
||||||
@llm_timeout(func_name="optim_step")
|
@llm_timeout(func_name="optim_step")
|
||||||
def step(self, closure=None):
|
def step(self, closure=None):
|
||||||
"""Performs a single optimization step.
|
"""Performs a single optimization step.
|
||||||
|
@ -571,10 +600,12 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
# compute norm for gradients in the before bucket
|
# compute norm for gradients in the before bucket
|
||||||
groups_norms = []
|
groups_norms = []
|
||||||
groups_param_norms = []
|
groups_param_norms = []
|
||||||
|
group_param_zero_grad_count = []
|
||||||
for group_id in range(self.num_param_groups):
|
for group_id in range(self.num_param_groups):
|
||||||
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
|
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
|
||||||
if gpc.config.get("grad_norm_profiling", False):
|
if gpc.config.get("grad_norm_profiling", False):
|
||||||
groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
|
groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
|
||||||
|
group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
|
||||||
|
|
||||||
# clear reduced grads
|
# clear reduced grads
|
||||||
# grads in the last bucket is reduced
|
# grads in the last bucket is reduced
|
||||||
|
@ -588,6 +619,8 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
# compute norm for gradients in the last bucket
|
# compute norm for gradients in the last bucket
|
||||||
total_norms = {}
|
total_norms = {}
|
||||||
total_param_norms = {}
|
total_param_norms = {}
|
||||||
|
total_param_zero_grad_count = {}
|
||||||
|
total_layer_zero_grad_count = {}
|
||||||
total_layer_norms = {}
|
total_layer_norms = {}
|
||||||
for group_id in range(self.num_param_groups):
|
for group_id in range(self.num_param_groups):
|
||||||
group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
|
group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
|
||||||
|
@ -608,6 +641,16 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
total_layer_norms[group_name], total_param_norms[group_name] = compute_layer_norm(
|
total_layer_norms[group_name], total_param_norms[group_name] = compute_layer_norm(
|
||||||
param_norms=param_norms, loss_scale=self.loss_scale.item()
|
param_norms=param_norms, loss_scale=self.loss_scale.item()
|
||||||
)
|
)
|
||||||
|
zero_grad_count = self._count_zero_grads_stage(
|
||||||
|
group_id=group_id,
|
||||||
|
last_bucket=True,
|
||||||
|
last_stage=True,
|
||||||
|
previous_zero_grad_count=group_param_zero_grad_count[group_id],
|
||||||
|
)
|
||||||
|
(
|
||||||
|
total_layer_zero_grad_count[group_name],
|
||||||
|
total_param_zero_grad_count[group_name],
|
||||||
|
) = compute_layer_zero_grad_count(zero_grad_count)
|
||||||
|
|
||||||
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced
|
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced
|
||||||
# during allreduce
|
# during allreduce
|
||||||
|
@ -627,6 +670,8 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
if gpc.config.get("grad_norm_profiling", False):
|
if gpc.config.get("grad_norm_profiling", False):
|
||||||
global_norms["layer_norms"] = total_layer_norms
|
global_norms["layer_norms"] = total_layer_norms
|
||||||
global_norms["param_norms"] = total_param_norms
|
global_norms["param_norms"] = total_param_norms
|
||||||
|
global_norms["layer_zero_grad"] = total_layer_zero_grad_count
|
||||||
|
global_norms["param_zero_grad"] = total_param_zero_grad_count
|
||||||
|
|
||||||
return state, global_norms
|
return state, global_norms
|
||||||
|
|
||||||
|
|
|
@ -209,6 +209,13 @@ def calc_lp(grads, norm_type):
|
||||||
return norm
|
return norm
|
||||||
|
|
||||||
|
|
||||||
|
def calc_zero_grad(grads):
|
||||||
|
zero_count = 0
|
||||||
|
for grad in grads:
|
||||||
|
zero_count += (grad == 0).sum().item()
|
||||||
|
return zero_count
|
||||||
|
|
||||||
|
|
||||||
def reduce_grads(gradients, parameters, fine_grained=False):
|
def reduce_grads(gradients, parameters, fine_grained=False):
|
||||||
parallel_grads = []
|
parallel_grads = []
|
||||||
if fine_grained:
|
if fine_grained:
|
||||||
|
@ -336,6 +343,102 @@ def compute_norm(
|
||||||
return total_norm
|
return total_norm
|
||||||
|
|
||||||
|
|
||||||
|
def compute_param_metric(
|
||||||
|
gradients,
|
||||||
|
parameters,
|
||||||
|
metric_type: str,
|
||||||
|
last_stage=False,
|
||||||
|
previous_param_metrics=None,
|
||||||
|
norm_type=2,
|
||||||
|
zero_mode=ParallelMode.ZERO1,
|
||||||
|
is_moe_group=False,
|
||||||
|
):
|
||||||
|
"""Get the metrics of params
|
||||||
|
Argumemts:
|
||||||
|
metric_type: (norm | zero_grad)
|
||||||
|
"""
|
||||||
|
|
||||||
|
enable_cuda_kernels = gradients[0].device.type == "cuda"
|
||||||
|
total_metrics = {}
|
||||||
|
param_metrics = {}
|
||||||
|
param_grads = reduce_grads(gradients, parameters, fine_grained=True)
|
||||||
|
|
||||||
|
if metric_type == "norm":
|
||||||
|
# Norm parameters.
|
||||||
|
norm_type = float(norm_type)
|
||||||
|
|
||||||
|
for param_name, grads in param_grads.items():
|
||||||
|
if metric_type == "norm":
|
||||||
|
if norm_type == inf:
|
||||||
|
param_norm = max(g.data.abs().max() for g in grads)
|
||||||
|
elif norm_type == 2.0 and enable_cuda_kernels:
|
||||||
|
param_norm = calc_l2_norm(grads) ** norm_type
|
||||||
|
else:
|
||||||
|
param_norm = calc_lp(grads, norm_type)
|
||||||
|
param_metrics[param_name] = param_norm.item() if torch.is_tensor(param_norm) else param_norm
|
||||||
|
elif metric_type == "zero_grad":
|
||||||
|
param_zero_grad_count = calc_zero_grad(grads)
|
||||||
|
param_metrics[param_name] = param_zero_grad_count
|
||||||
|
|
||||||
|
if last_stage is False:
|
||||||
|
return param_metrics
|
||||||
|
|
||||||
|
if previous_param_metrics is not None:
|
||||||
|
for key, value in previous_param_metrics.items():
|
||||||
|
if key not in param_metrics:
|
||||||
|
param_metrics[key] = value
|
||||||
|
continue
|
||||||
|
if metric_type == "norm" and norm_type == inf:
|
||||||
|
param_metrics[key] = max(param_metrics[key], value)
|
||||||
|
else:
|
||||||
|
param_metrics[key] += value
|
||||||
|
|
||||||
|
# model parallel
|
||||||
|
model_parallel_param_metrics = {}
|
||||||
|
if gpc.is_initialized(ParallelMode.MODEL):
|
||||||
|
parallel_param_norms = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))]
|
||||||
|
dist.all_gather_object(parallel_param_norms, param_metrics, group=gpc.get_group(ParallelMode.MODEL))
|
||||||
|
for local_param_norm in parallel_param_norms:
|
||||||
|
for param_name, param_norm in local_param_norm.items():
|
||||||
|
if param_name not in model_parallel_param_metrics:
|
||||||
|
model_parallel_param_metrics[param_name] = 0.0
|
||||||
|
if metric_type == "norm" and norm_type == inf:
|
||||||
|
model_parallel_param_metrics[param_name] = max(model_parallel_param_metrics[param_name], param_norm)
|
||||||
|
else:
|
||||||
|
model_parallel_param_metrics[param_name] += param_norm
|
||||||
|
|
||||||
|
# zero parallel
|
||||||
|
zero_param_metrics = [None for _ in range(gpc.get_world_size(zero_mode))]
|
||||||
|
dist.all_gather_object(zero_param_metrics, model_parallel_param_metrics, group=gpc.get_group(zero_mode))
|
||||||
|
for local_param_norm in zero_param_metrics:
|
||||||
|
for param_name, param_norm in local_param_norm.items():
|
||||||
|
if param_name not in total_metrics:
|
||||||
|
total_metrics[param_name] = 0.0
|
||||||
|
if metric_type == "norm" and norm_type == inf:
|
||||||
|
total_metrics[param_name] = max(total_metrics[param_name], param_norm)
|
||||||
|
else:
|
||||||
|
total_metrics[param_name] += param_norm
|
||||||
|
|
||||||
|
# moe
|
||||||
|
if is_moe_group:
|
||||||
|
pg = gpc.get_group(ParallelMode.EXPERT)
|
||||||
|
scaled_param_metric = torch.cuda.FloatTensor(list(total_metrics.values()), device=get_current_device())
|
||||||
|
scaled_param_metric = scaled_param_metric / float(gpc.get_world_size(ParallelMode.EXPERT))
|
||||||
|
dist.all_reduce(scaled_param_metric, group=pg)
|
||||||
|
for i, param_name in enumerate(total_metrics.keys()):
|
||||||
|
total_metrics[param_name] = scaled_param_metric[i].item()
|
||||||
|
|
||||||
|
# scale norm
|
||||||
|
if metric_type == "norm":
|
||||||
|
for param_name, param_norm in total_metrics.items():
|
||||||
|
if param_norm in (inf, -inf):
|
||||||
|
total_metrics[param_name] = -1
|
||||||
|
elif math.isnan(param_norm):
|
||||||
|
total_metrics[param_name] = -2
|
||||||
|
|
||||||
|
return total_metrics
|
||||||
|
|
||||||
|
|
||||||
def compute_param_norm(
|
def compute_param_norm(
|
||||||
gradients,
|
gradients,
|
||||||
parameters,
|
parameters,
|
||||||
|
@ -355,80 +458,45 @@ def compute_param_norm(
|
||||||
Returns:
|
Returns:
|
||||||
The norm of the parameters.
|
The norm of the parameters.
|
||||||
"""
|
"""
|
||||||
enable_cuda_kernels = gradients[0].device.type == "cuda"
|
|
||||||
# Norm parameters.
|
|
||||||
norm_type = float(norm_type)
|
|
||||||
total_param_norms = {}
|
|
||||||
|
|
||||||
param_grads = reduce_grads(gradients, parameters, fine_grained=True)
|
return compute_param_metric(
|
||||||
|
gradients,
|
||||||
|
parameters,
|
||||||
|
metric_type="norm",
|
||||||
|
last_stage=last_stage,
|
||||||
|
previous_param_metrics=previous_param_norms,
|
||||||
|
norm_type=norm_type,
|
||||||
|
zero_mode=zero_mode,
|
||||||
|
is_moe_group=is_moe_group,
|
||||||
|
)
|
||||||
|
|
||||||
param_norms = {}
|
|
||||||
for param_name, grads in param_grads.items():
|
|
||||||
if norm_type == inf:
|
|
||||||
param_norm = max(g.data.abs().max() for g in grads)
|
|
||||||
elif norm_type == 2.0 and enable_cuda_kernels:
|
|
||||||
param_norm = calc_l2_norm(grads) ** norm_type
|
|
||||||
else:
|
|
||||||
param_norm = calc_lp(grads, norm_type)
|
|
||||||
param_norms[param_name] = param_norm.item() if torch.is_tensor(param_norm) else param_norm
|
|
||||||
|
|
||||||
if last_stage is False:
|
def compute_zero_grad_count(
|
||||||
return param_norms
|
gradients,
|
||||||
|
parameters,
|
||||||
|
last_stage=False,
|
||||||
|
previous_zero_grad_count=None,
|
||||||
|
zero_mode=ParallelMode.ZERO1,
|
||||||
|
is_moe_group=False,
|
||||||
|
):
|
||||||
|
"""Get the count of zero gradient for each parameters
|
||||||
|
Arguments:
|
||||||
|
gradients (Iterable[Tensor]): The gradient value.
|
||||||
|
parameters (Iterable[Tensor]): The parameter each gradient corresponds to.
|
||||||
|
|
||||||
if previous_param_norms is not None:
|
Returns:
|
||||||
for key, value in previous_param_norms.items():
|
The count of zero gradient for each parameters
|
||||||
if key not in param_norms:
|
"""
|
||||||
param_norms[key] = value
|
|
||||||
continue
|
|
||||||
|
|
||||||
if norm_type == inf:
|
return compute_param_metric(
|
||||||
param_norms[key] = max(param_norms[key], value)
|
gradients,
|
||||||
else:
|
parameters,
|
||||||
param_norms[key] += value
|
metric_type="zero_grad",
|
||||||
|
last_stage=last_stage,
|
||||||
# model parallel
|
previous_param_metrics=previous_zero_grad_count,
|
||||||
model_parallel_param_norms = {}
|
zero_mode=zero_mode,
|
||||||
if gpc.is_initialized(ParallelMode.MODEL):
|
is_moe_group=is_moe_group,
|
||||||
parallel_param_norms = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))]
|
)
|
||||||
dist.all_gather_object(parallel_param_norms, param_norms, group=gpc.get_group(ParallelMode.MODEL))
|
|
||||||
for local_param_norm in parallel_param_norms:
|
|
||||||
for param_name, param_norm in local_param_norm.items():
|
|
||||||
if param_name not in model_parallel_param_norms:
|
|
||||||
model_parallel_param_norms[param_name] = 0.0
|
|
||||||
if norm_type == inf:
|
|
||||||
model_parallel_param_norms[param_name] = max(model_parallel_param_norms[param_name], param_norm)
|
|
||||||
else:
|
|
||||||
model_parallel_param_norms[param_name] += param_norm
|
|
||||||
|
|
||||||
# zero parallel
|
|
||||||
zero_param_norms = [None for _ in range(gpc.get_world_size(zero_mode))]
|
|
||||||
dist.all_gather_object(zero_param_norms, model_parallel_param_norms, group=gpc.get_group(zero_mode))
|
|
||||||
for local_param_norm in zero_param_norms:
|
|
||||||
for param_name, param_norm in local_param_norm.items():
|
|
||||||
if param_name not in total_param_norms:
|
|
||||||
total_param_norms[param_name] = 0.0
|
|
||||||
if norm_type == inf:
|
|
||||||
total_param_norms[param_name] = max(total_param_norms[param_name], param_norm)
|
|
||||||
else:
|
|
||||||
total_param_norms[param_name] += param_norm
|
|
||||||
|
|
||||||
# moe
|
|
||||||
if is_moe_group:
|
|
||||||
pg = gpc.get_group(ParallelMode.EXPERT)
|
|
||||||
scaled_param_norm = torch.cuda.FloatTensor(list(total_param_norms.values()), device=get_current_device())
|
|
||||||
scaled_param_norm = scaled_param_norm / float(gpc.get_world_size(ParallelMode.EXPERT))
|
|
||||||
dist.all_reduce(scaled_param_norm, group=pg)
|
|
||||||
for i, param_name in enumerate(total_param_norms.keys()):
|
|
||||||
total_param_norms[param_name] = scaled_param_norm[i].item()
|
|
||||||
|
|
||||||
# scale
|
|
||||||
for param_name, param_norm in total_param_norms.items():
|
|
||||||
if param_norm in (inf, -inf):
|
|
||||||
total_param_norms[param_name] = -1
|
|
||||||
elif math.isnan(param_norm):
|
|
||||||
total_param_norms[param_name] = -2
|
|
||||||
|
|
||||||
return total_param_norms
|
|
||||||
|
|
||||||
|
|
||||||
def compute_layer_norm(param_norms, loss_scale):
|
def compute_layer_norm(param_norms, loss_scale):
|
||||||
|
@ -454,6 +522,23 @@ def compute_layer_norm(param_norms, loss_scale):
|
||||||
return layer_norms, param_norms_groupby_layer
|
return layer_norms, param_norms_groupby_layer
|
||||||
|
|
||||||
|
|
||||||
|
def compute_layer_zero_grad_count(param_zero_grad_count):
|
||||||
|
param_zero_grad_count_groupby_layer = {}
|
||||||
|
layer_zero_grad_count = {}
|
||||||
|
|
||||||
|
for param_name, zero_grad_count in param_zero_grad_count.items():
|
||||||
|
layer_name, param_key = param_name.split("-")
|
||||||
|
if layer_name not in param_zero_grad_count_groupby_layer:
|
||||||
|
param_zero_grad_count_groupby_layer[layer_name] = {}
|
||||||
|
if layer_name not in layer_zero_grad_count:
|
||||||
|
layer_zero_grad_count[layer_name] = 0.0
|
||||||
|
|
||||||
|
param_zero_grad_count_groupby_layer[layer_name][param_key] = zero_grad_count
|
||||||
|
layer_zero_grad_count[layer_name] += zero_grad_count
|
||||||
|
|
||||||
|
return layer_zero_grad_count, param_zero_grad_count_groupby_layer
|
||||||
|
|
||||||
|
|
||||||
class BaseGradScaler(ABC):
|
class BaseGradScaler(ABC):
|
||||||
"""A base class for the gradient scaler.
|
"""A base class for the gradient scaler.
|
||||||
|
|
||||||
|
|
|
@ -530,17 +530,30 @@ def record_current_batch_training_metrics(
|
||||||
if gpc.config.get("grad_norm_profiling", False):
|
if gpc.config.get("grad_norm_profiling", False):
|
||||||
layer_norms = copy.deepcopy(grad_norm["layer_norms"])
|
layer_norms = copy.deepcopy(grad_norm["layer_norms"])
|
||||||
param_norms = copy.deepcopy(grad_norm["param_norms"])
|
param_norms = copy.deepcopy(grad_norm["param_norms"])
|
||||||
|
layer_zero_grad_count = copy.deepcopy(grad_norm["layer_zero_grad"])
|
||||||
|
param_zero_grad_count = copy.deepcopy(grad_norm["param_zero_grad"])
|
||||||
for group_name, value in layer_norms.items():
|
for group_name, value in layer_norms.items():
|
||||||
if value:
|
if value:
|
||||||
title = f"laye_norm_group_{group_name}"
|
title = f"laye_norm/{group_name}"
|
||||||
writer.add_scalars(key=title, value=value, step=train_state.step_count)
|
writer.add_scalars(key=title, value=value, step=train_state.step_count)
|
||||||
for group_name, layer_group in param_norms.items():
|
for group_name, layer_group in param_norms.items():
|
||||||
if layer_group:
|
if layer_group:
|
||||||
for layer_name, param_group in layer_group.items():
|
for layer_name, param_group in layer_group.items():
|
||||||
title = f"param_norm_{layer_name}_{group_name}"
|
title = f"param_norm/{group_name}/{layer_name}"
|
||||||
|
writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
|
||||||
|
for group_name, value in layer_zero_grad_count.items():
|
||||||
|
if value:
|
||||||
|
title = f"laye_zero_grad/{group_name}"
|
||||||
|
writer.add_scalars(key=title, value=value, step=train_state.step_count)
|
||||||
|
for group_name, layer_group in param_zero_grad_count.items():
|
||||||
|
if layer_group:
|
||||||
|
for layer_name, param_group in layer_group.items():
|
||||||
|
title = f"param_zero_grad/{group_name}/{layer_name}"
|
||||||
writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
|
writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
|
||||||
del grad_norm["layer_norms"]
|
del grad_norm["layer_norms"]
|
||||||
del grad_norm["param_norms"]
|
del grad_norm["param_norms"]
|
||||||
|
del grad_norm["layer_zero_grad"]
|
||||||
|
del grad_norm["param_zero_grad"]
|
||||||
|
|
||||||
line = ""
|
line = ""
|
||||||
for key, value in infos.items():
|
for key, value in infos.items():
|
||||||
|
|
Loading…
Reference in New Issue