mirror of https://github.com/InternLM/InternLM
feat(grad_norm): vocab grad norm profiling (#519)
* compute vocab grad norm && save pt
* add grad_norm profiling interval && refactor save grad norm
* fix ci test_pipeline

pull/530/head
parent 9fc252f40e
commit 112c34ae09
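For context, the switches read throughout this diff (grad_norm_profiling, zero_grad_profiling, vocab_grad_norm_profiling, interval_steps, layers) come from an optional grad_profiling section of the training config. A minimal sketch of that section, assuming the key names and defaults seen in the .get(...) calls below; the values themselves are illustrative:

# Illustrative grad_profiling config section; only the key names and the fallback
# defaults are taken from the diff, the concrete values are examples.
grad_profiling = dict(
    grad_norm_profiling=False,
    zero_grad_profiling=False,
    vocab_grad_norm_profiling=True,
    interval_steps=5,   # profile every N steps; the code falls back to 1
    layers=[],          # optional layer-name filter for the logged scalars
)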
@@ -356,6 +356,9 @@ def args_sanity_check():
            f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
        )

    if "batch_count" not in gpc.config:
        gpc.config._add_item("batch_count", 0)

    if "moe_loss_coeff" not in gpc.config.loss:
        gpc.config.loss._add_item("moe_loss_coeff", 1.0)
@@ -39,6 +39,7 @@ from .utils import (
    compute_layer_zero_grad_count,
    compute_norm,
    compute_param_norm,
    compute_vocab_grad_norm,
    compute_zero_grad_count,
)
@@ -563,6 +564,28 @@ class HybridZeroOptimizer(BaseOptimizer):
            )
        return total_param_norms

    def _compute_vocab_grad_norm_stage(
        self, group_id: int = 0, last_bucket: bool = False, last_stage: bool = False, previous_vocab_grad_norm=None
    ):
        params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
        if len(params) == 0:
            dtype = self.param_groups[group_id]["dtype"]
            grads = [self.padding_grad.to(dtype)]
            params = [self.padding_tensor.to(dtype)]

        vocab_grad_norm = None

        if self._clip_grad_norm > 0:
            vocab_grad_norm = compute_vocab_grad_norm(
                grads,
                params,
                last_stage=last_stage,
                previous_vocab_grad_norm=previous_vocab_grad_norm,
                zero_mode=self._broadcast_parallel_mode[group_id],
            )

        return vocab_grad_norm

    def _count_zero_grads_stage(
        self, group_id: int = 0, last_bucket: bool = False, last_stage: bool = False, previous_zero_grad_count=None
    ):
@@ -615,12 +638,19 @@ class HybridZeroOptimizer(BaseOptimizer):
        groups_norms = []
        groups_param_norms = []
        group_param_zero_grad_count = []
        group_vocab_norms = []
        batch_count = gpc.config.get("batch_count")
        interval_steps = grad_profiling_config.get("interval_steps", 1)
        is_profiling = batch_count % interval_steps == 0 if batch_count is not None else False
        for group_id in range(self.num_param_groups):
            groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
            if is_profiling:
                if grad_profiling_config.get("grad_norm_profiling", False):
                    groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
                if grad_profiling_config.get("zero_grad_profiling", False):
                    group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
                    group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))

        # clear reduced grads
        # grads in the last bucket is reduced
@@ -637,6 +667,7 @@ class HybridZeroOptimizer(BaseOptimizer):
        total_layer_grad_norms = {}
        total_param_zero_grad_count = {}
        total_layer_zero_grad_count = {}
        total_vocab_grad_norms = {}
        for group_id in range(self.num_param_groups):
            group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
            group_name = f"{group_id}_{group_name}"
@@ -646,6 +677,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                last_stage=True,
                previous_norm=groups_norms[group_id],
            )
            if is_profiling:
                if grad_profiling_config.get("grad_norm_profiling", False):
                    param_norms = self._compute_param_norm_stage(
                        group_id=group_id,
@@ -667,18 +699,34 @@ class HybridZeroOptimizer(BaseOptimizer):
                        total_layer_zero_grad_count[group_name],
                        total_param_zero_grad_count[group_name],
                    ) = compute_layer_zero_grad_count(zero_grad_count)
                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
                    vocab_grad_norms = self._compute_vocab_grad_norm_stage(
                        group_id=group_id,
                        last_bucket=True,
                        last_stage=True,
                        previous_vocab_grad_norm=group_vocab_norms[group_id],
                    )
                    inf_mask = vocab_grad_norms == -1
                    nan_mask = vocab_grad_norms == -2
                    vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
                    vocab_grad_norms[inf_mask] = -1
                    vocab_grad_norms[nan_mask] = -2
                    total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")

        timer("sync_grad").start()
        self._sync_grad()
        timer("sync_grad").stop()

        state, global_norms = self._step(closure=closure, norms=total_norms)
        if is_profiling:
            if grad_profiling_config.get("grad_norm_profiling", False):
                global_norms["layer_grad_norm"] = total_layer_grad_norms
                global_norms["param_grad_norm"] = total_param_grad_norms
            if grad_profiling_config.get("zero_grad_profiling", False):
                global_norms["layer_zero_grad"] = total_layer_zero_grad_count
                global_norms["param_zero_grad"] = total_param_zero_grad_count
            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
                global_norms["vocab_grad_norm"] = total_vocab_grad_norms

        return state, global_norms
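Aside: the -1/-2 entries are sentinels for inf and nan per-token norms; the step above masks them out, converts the accumulated squared norms back to norms divided by the loss scale, then restores the sentinels. A minimal plain-torch sketch of that pattern (the loss_scale value is illustrative, not from the config):

import torch

# Mask-then-rescale sketch: -1 marks inf, -2 marks nan; everything else goes from
# squared norm to norm / loss_scale, then the sentinels are written back.
loss_scale = 65536.0  # illustrative value
squared_norms = torch.tensor([4.0, -1.0, 9.0, -2.0])
inf_mask = squared_norms == -1
nan_mask = squared_norms == -2
norms = squared_norms**0.5 / loss_scale
norms[inf_mask] = -1
norms[nan_mask] = -2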
@@ -218,7 +218,17 @@ def calc_zero_grad(grads):
    return torch.tensor([zero_count, grad_size])


def reduce_grads(gradients, parameters, fine_grained=False):
def get_norm(grads, norm_type, enable_cuda_kernels):
    if norm_type == inf:
        grad_norm = max(g.data.abs().max() for g in grads)
    elif norm_type == 2.0 and enable_cuda_kernels:
        grad_norm = calc_l2_norm(grads) ** norm_type
    else:
        grad_norm = calc_lp(grads, norm_type)
    return grad_norm


def reduce_grads(gradients, parameters, fine_grained=False, only_output=False):
    parallel_grads = []
    if fine_grained:
        parallel_grads = {}
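Note that the factored-out helper returns the norm raised to norm_type (so the squared norm in the L2 case), not the norm itself. A plain-torch restatement of that contract, since calc_l2_norm and calc_lp are InternLM-internal helpers; the assumption here is that calc_lp sums |g|**p over all gradients:

import torch
from math import inf

# Sketch of get_norm's contract using plain torch in place of the internal kernels.
def get_norm_sketch(grads, norm_type):
    if norm_type == inf:
        return max(g.abs().max() for g in grads)
    # returns the norm raised to norm_type, e.g. the squared L2 norm
    return sum(g.float().norm(norm_type) ** norm_type for g in grads)

grads = [torch.randn(3, 4), torch.randn(5)]
print(get_norm_sketch(grads, 2.0))  # squared L2 norm across both tensors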
@@ -229,6 +239,14 @@ def reduce_grads(gradients, parameters, fine_grained=False):
            if param_name not in parallel_grads:
                parallel_grads[param_name] = []
            parallel_grads[param_name].append(g.data.float())
        elif only_output:
            param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
            if (
                gpc.config.model["vocab_size"] == g.shape[0]
                and gpc.config.model["hidden_size"] == g.shape[1]
                and "embedding" not in param_name.lower()
            ):
                parallel_grads.append(g.data.float())
        else:
            parallel_grads.append(g.data.float())
@@ -306,10 +324,7 @@ def compute_norm(
    else:
        tensor_parallel_grads = reduce_grads(gradients, parameters)

    if norm_type == 2.0 and enable_cuda_kernels:
        tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
    else:
        tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
    tensor_parallel_norm = get_norm(tensor_parallel_grads, norm_type, enable_cuda_kernels)

    # If norm is type of float, then we convert them into torch.Tensor.
    tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
@@ -359,6 +374,59 @@ def compute_norm(
    return total_norm


def compute_vocab_grad_norm(
    gradients,
    parameters,
    last_stage=False,
    previous_vocab_grad_norm=None,
    norm_type=2,
    zero_mode=ParallelMode.ZERO1,
):
    enable_cuda_kernels = gradients[0].device.type == "cuda"
    # Norm parameters.
    norm_type = float(norm_type)
    vocab_size = gpc.config.model["vocab_size"]

    param_grads = reduce_grads(gradients, parameters, only_output=True)

    vocab_grad_norm = torch.zeros((vocab_size,), dtype=torch.float32).to(get_current_device())
    if param_grads:
        for grad in param_grads:
            # get grad norm of each vocab
            for i in range(vocab_size):
                cur_vocab_grad_norm = get_norm([grad[i, :]], norm_type, enable_cuda_kernels)[0]
                vocab_grad_norm[i] += get_tensor_norm(cur_vocab_grad_norm, move_to_cuda=True)

    if last_stage is False:
        return vocab_grad_norm

    if previous_vocab_grad_norm is not None:
        vocab_grad_norm = vocab_grad_norm + previous_vocab_grad_norm

    if gpc.is_initialized(ParallelMode.MODEL):
        dist.all_reduce(
            vocab_grad_norm,
            op=dist.ReduceOp.SUM,
            group=gpc.get_group(ParallelMode.MODEL),
        )

    dist.all_reduce(vocab_grad_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))

    if zero_mode == ParallelMode.EXPERT_DATA:
        pg = gpc.get_group(ParallelMode.EXPERT)
        scaled_norm = vocab_grad_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
        dist.all_reduce(scaled_norm_tensor, group=pg)
        vocab_grad_norm = scaled_norm_tensor.item()

    # Scale.
    vocab_grad_norm[vocab_grad_norm == float("inf")] = -1
    vocab_grad_norm[vocab_grad_norm == -float("inf")] = -1
    vocab_grad_norm[torch.isnan(vocab_grad_norm)] = -2

    return vocab_grad_norm


def compute_param_metric(
    gradients,
    parameters,
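In isolation, the per-row accumulation in compute_vocab_grad_norm amounts to the following sketch (toy shapes, plain torch standing in for the InternLM helpers): the output-projection gradient has shape [vocab_size, hidden_size], and each vocab row contributes its squared L2 norm, with the square root taken only after all reductions.

import torch

# Toy sketch of the per-vocab-row accumulation (shapes are illustrative).
vocab_size, hidden_size = 8, 4
grad = torch.randn(vocab_size, hidden_size)
vocab_grad_norm = torch.zeros(vocab_size)
for i in range(vocab_size):
    vocab_grad_norm[i] += grad[i, :].float().norm(2) ** 2  # squared norms; sqrt happens later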
@@ -384,12 +452,7 @@ def compute_param_metric(

    for param_name, grads in param_grads.items():
        if metric_type == "norm":
            if norm_type == inf:
                param_metric = max(g.data.abs().max() for g in grads)
            elif norm_type == 2.0 and enable_cuda_kernels:
                param_metric = calc_l2_norm(grads) ** norm_type
            else:
                param_metric = calc_lp(grads, norm_type)
            param_metric = get_norm(grads, norm_type, enable_cuda_kernels)
            param_metrics[param_name] = param_metric.item() if torch.is_tensor(param_metric) else param_metric
        elif metric_type == "zero_grad":
            param_zero_grad_count = calc_zero_grad(grads)
@@ -2,6 +2,8 @@
# -*- encoding: utf-8 -*-

import functools
import os
import pickle
import time
from functools import partial
from typing import Callable, Iterable, Optional, Union
@@ -49,7 +51,7 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.train.utils import create_param_groups
from internlm.utils.common import DummyProfile
from internlm.utils.common import DummyProfile, launch_time
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import (
@@ -164,8 +166,10 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
        A tuple of (optimizer, beta2_scheduler, lr_scheduler).
    """
    grad_profiling_config = gpc.config.get("grad_profiling", {})
    if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
        "zero_grad_profiling", False
    if (
        grad_profiling_config.get("grad_norm_profiling", False)
        or grad_profiling_config.get("zero_grad_profiling", False)
        or grad_profiling_config.get("vocab_grad_norm_profiling", False)
    ):
        # set the layer name as an attribute of the model parameters
        set_model_params_layer_name(model)
@@ -522,43 +526,40 @@ def record_current_batch_training_metrics(
            infos[key] = value

    grad_profiling_config = gpc.config.get("grad_profiling", {})
    if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
        "zero_grad_profiling", False
    ):
        layer_metrics = ["layer_grad_norm", "layer_zero_grad"]
        param_metrics = ["param_grad_norm", "param_zero_grad"]
    interval_steps = grad_profiling_config.get("interval_steps", 1)
    if batch_count % interval_steps == 0:
        layer_metrics = [metric for metric in ["layer_grad_norm", "layer_zero_grad"] if metric in grad_norm]
        param_metrics = [metric for metric in ["param_grad_norm", "param_zero_grad"] if metric in grad_norm]
        layer_names = grad_profiling_config.get("layers", [])
        for layer_metric_name in layer_metrics:
            layer_metric = grad_norm.get(layer_metric_name, {})
            if layer_metric:
                for group_name, layer_group in layer_metric.items():
                    if layer_group:
                        title = f"{layer_metric_name}/{group_name}"
                        if layer_names:
                            filter_layer_metrics = {}
                            for layer_name, metric_value in layer_group.items():
                                if layer_name in layer_names:
                                    filter_layer_metrics[layer_name] = metric_value
                            writer.add_scalars(key=title, value=filter_layer_metrics, step=train_state.step_count)
                        else:
                            writer.add_scalars(key=title, value=layer_group, step=train_state.step_count)
                del grad_norm[layer_metric_name]

        for param_metric_name in param_metrics:
            param_metric = grad_norm.get(param_metric_name, {})
            if param_metric:
                for group_name, layer_group in param_metric.items():
        for metric_name in layer_metrics:
            metric = grad_norm.get(metric_name)
            for group_name, layer_group in metric.items():
                title = f"{metric_name}/{group_name}"
                metrics = {k: v for k, v in layer_group.items() if not layer_names or k in layer_names}
                if metrics:
                    writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
            del grad_norm[metric_name]
        for metric_name in param_metrics:
            metric = grad_norm.get(metric_name)
            for group_name, layer_group in metric.items():
                for param_name, param_group in layer_group.items():
                        title = f"{param_name}/{group_name}_{param_metric_name}"
                        if layer_names:
                            filter_param_group = {}
                            for layer_name, metric_value in param_group.items():
                                if layer_name in layer_names:
                                    filter_param_group[layer_name] = param_group[layer_name]
                            writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count)
                        else:
                            writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
                del grad_norm[param_metric_name]
                    title = f"{param_name}/{group_name}_{metric_name}"
                    metrics = {k: v for k, v in param_group.items() if not layer_names or k in layer_names}
                    if metrics:
                        writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
            del grad_norm[metric_name]
        if grad_profiling_config.get("vocab_grad_norm_profiling", False):
            local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
            os.makedirs(local_save_path, exist_ok=True)
            local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
            vocab_grad_norms = grad_norm.get("vocab_grad_norm")
            if vocab_grad_norms:
                try:
                    with open(local_save_file, "ab+") as vocab_f:
                        pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
                except IOError as e:
                    logger.warning(f"Error saving vocab_grad_norm: {e}")
            del grad_norm["vocab_grad_norm"]

    line = ""
    for key, value in infos.items():
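For anyone inspecting the dumped file: each pickle.dump call above appends one (step_count, vocab_grad_norms) record to vocab_grad_norm.pt, so reading it back is a loop until EOF. A minimal reader sketch (the path prefix depends on JOB_NAME and the launch time, so only the file name is shown):

import pickle

# Reader for the appended records written by the block above.
records = []
with open("vocab_grad_norm.pt", "rb") as f:
    while True:
        try:
            records.append(pickle.load(f))  # each record: (step_count, vocab_grad_norms)
        except EOFError:
            break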
train.py
@@ -197,6 +197,7 @@ def main(args):
        empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
        start_time = time.time()
        timer("one-batch").start()
        gpc.config.batch_count = batch_count

        # load batch data
        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)