feat(grad_norm): vocab grad norm profiling (#519)

* compute vocab grad norm && save pt

* add grad_norm profiling interval && refactor save grad norm

* fix ci test_pipeline
jiaopenglong 2023-12-06 13:52:42 +08:00 committed by GitHub
parent 9fc252f40e
commit 112c34ae09
5 changed files with 197 additions and 81 deletions
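
For reference, all of the profiling added here is driven by a "grad_profiling" block read via gpc.config.get("grad_profiling", {}). Below is a minimal sketch of such a block, using only the keys this PR reads; the surrounding config layout and the example values are assumptions, not part of the change.

# Hypothetical excerpt from a training config; only keys read by this PR are shown.
grad_profiling = dict(
    grad_norm_profiling=False,        # per-layer / per-param gradient norms
    zero_grad_profiling=False,        # zero-gradient counts
    vocab_grad_norm_profiling=True,   # new: per-vocab-token gradient norm, dumped to vocab_grad_norm.pt
    interval_steps=5,                 # new: only profile when batch_count % interval_steps == 0
    layers=[],                        # optional layer-name filter applied before writer.add_scalars
)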

View File

@@ -356,6 +356,9 @@ def args_sanity_check():
             f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
         )
 
+    if "batch_count" not in gpc.config:
+        gpc.config._add_item("batch_count", 0)
+
     if "moe_loss_coeff" not in gpc.config.loss:
         gpc.config.loss._add_item("moe_loss_coeff", 1.0)

View File

@@ -39,6 +39,7 @@ from .utils import (
     compute_layer_zero_grad_count,
     compute_norm,
     compute_param_norm,
+    compute_vocab_grad_norm,
     compute_zero_grad_count,
 )
@@ -563,6 +564,28 @@ class HybridZeroOptimizer(BaseOptimizer):
             )
         return total_param_norms
 
+    def _compute_vocab_grad_norm_stage(
+        self, group_id: int = 0, last_bucket: bool = False, last_stage: bool = False, previous_vocab_grad_norm=None
+    ):
+        params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
+
+        if len(params) == 0:
+            dtype = self.param_groups[group_id]["dtype"]
+            grads = [self.padding_grad.to(dtype)]
+            params = [self.padding_tensor.to(dtype)]
+
+        vocab_grad_norm = None
+
+        if self._clip_grad_norm > 0:
+            vocab_grad_norm = compute_vocab_grad_norm(
+                grads,
+                params,
+                last_stage=last_stage,
+                previous_vocab_grad_norm=previous_vocab_grad_norm,
+                zero_mode=self._broadcast_parallel_mode[group_id],
+            )
+
+        return vocab_grad_norm
+
     def _count_zero_grads_stage(
         self, group_id: int = 0, last_bucket: bool = False, last_stage: bool = False, previous_zero_grad_count=None
     ):
@@ -615,12 +638,19 @@ class HybridZeroOptimizer(BaseOptimizer):
         groups_norms = []
         groups_param_norms = []
         group_param_zero_grad_count = []
+        group_vocab_norms = []
+        batch_count = gpc.config.get("batch_count")
+        interval_steps = grad_profiling_config.get("interval_steps", 1)
+        is_profiling = batch_count % interval_steps == 0 if batch_count is not None else False
         for group_id in range(self.num_param_groups):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))
 
         # clear reduced grads
         # grads in the last bucket is reduced
@@ -637,6 +667,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         total_layer_grad_norms = {}
         total_param_zero_grad_count = {}
         total_layer_zero_grad_count = {}
+        total_vocab_grad_norms = {}
         for group_id in range(self.num_param_groups):
             group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
             group_name = f"{group_id}_{group_name}"
@@ -646,6 +677,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 last_stage=True,
                 previous_norm=groups_norms[group_id],
             )
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                param_norms = self._compute_param_norm_stage(
-                    group_id=group_id,
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    param_norms = self._compute_param_norm_stage(
+                        group_id=group_id,
@@ -667,18 +699,34 @@ class HybridZeroOptimizer(BaseOptimizer):
-                    total_layer_zero_grad_count[group_name],
-                    total_param_zero_grad_count[group_name],
-                ) = compute_layer_zero_grad_count(zero_grad_count)
+                        total_layer_zero_grad_count[group_name],
+                        total_param_zero_grad_count[group_name],
+                    ) = compute_layer_zero_grad_count(zero_grad_count)
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    vocab_grad_norms = self._compute_vocab_grad_norm_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_vocab_grad_norm=group_vocab_norms[group_id],
+                    )
+                    inf_mask = vocab_grad_norms == -1
+                    nan_mask = vocab_grad_norms == -2
+                    vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
+                    vocab_grad_norms[inf_mask] = -1
+                    vocab_grad_norms[nan_mask] = -2
+                    total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")
 
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
 
         state, global_norms = self._step(closure=closure, norms=total_norms)
 
-        if grad_profiling_config.get("grad_norm_profiling", False):
-            global_norms["layer_grad_norm"] = total_layer_grad_norms
-            global_norms["param_grad_norm"] = total_param_grad_norms
-        if grad_profiling_config.get("zero_grad_profiling", False):
-            global_norms["layer_zero_grad"] = total_layer_zero_grad_count
-            global_norms["param_zero_grad"] = total_param_zero_grad_count
+        if is_profiling:
+            if grad_profiling_config.get("grad_norm_profiling", False):
+                global_norms["layer_grad_norm"] = total_layer_grad_norms
+                global_norms["param_grad_norm"] = total_param_grad_norms
+            if grad_profiling_config.get("zero_grad_profiling", False):
+                global_norms["layer_zero_grad"] = total_layer_zero_grad_count
+                global_norms["param_zero_grad"] = total_param_zero_grad_count
+            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                global_norms["vocab_grad_norm"] = total_vocab_grad_norms
 
         return state, global_norms

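Note on the post-processing above: per-token values are accumulated as squared norms, so after the final reduction they are square-rooted and divided by the loss scale, while -1 and -2 are kept as sentinels for inf and nan entries. A small standalone helper (hypothetical, not part of this PR) shows how a consumer of the saved tensor might separate those cases:

import torch

def split_vocab_grad_norm(vocab_grad_norm: torch.Tensor):
    # Token ids whose accumulated norm overflowed to inf (encoded as -1) or became nan (encoded as -2).
    inf_token_ids = (vocab_grad_norm == -1).nonzero(as_tuple=True)[0]
    nan_token_ids = (vocab_grad_norm == -2).nonzero(as_tuple=True)[0]
    # Remaining entries are valid per-token gradient norms.
    valid_norms = vocab_grad_norm[vocab_grad_norm >= 0]
    return valid_norms, inf_token_ids, nan_token_ids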
View File

@@ -218,7 +218,17 @@ def calc_zero_grad(grads):
     return torch.tensor([zero_count, grad_size])
 
 
-def reduce_grads(gradients, parameters, fine_grained=False):
+def get_norm(grads, norm_type, enable_cuda_kernels):
+    if norm_type == inf:
+        grad_norm = max(g.data.abs().max() for g in grads)
+    elif norm_type == 2.0 and enable_cuda_kernels:
+        grad_norm = calc_l2_norm(grads) ** norm_type
+    else:
+        grad_norm = calc_lp(grads, norm_type)
+    return grad_norm
+
+
+def reduce_grads(gradients, parameters, fine_grained=False, only_output=False):
     parallel_grads = []
     if fine_grained:
         parallel_grads = {}
@@ -229,6 +239,14 @@ def reduce_grads(gradients, parameters, fine_grained=False):
             if param_name not in parallel_grads:
                 parallel_grads[param_name] = []
             parallel_grads[param_name].append(g.data.float())
+        elif only_output:
+            param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
+            if (
+                gpc.config.model["vocab_size"] == g.shape[0]
+                and gpc.config.model["hidden_size"] == g.shape[1]
+                and "embedding" not in param_name.lower()
+            ):
+                parallel_grads.append(g.data.float())
         else:
             parallel_grads.append(g.data.float())
@@ -306,10 +324,7 @@ def compute_norm(
     else:
         tensor_parallel_grads = reduce_grads(gradients, parameters)
 
-    if norm_type == 2.0 and enable_cuda_kernels:
-        tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
-    else:
-        tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
+    tensor_parallel_norm = get_norm(tensor_parallel_grads, norm_type, enable_cuda_kernels)
 
     # If norm is type of float, then we convert them into torch.Tensor.
     tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
@@ -359,6 +374,59 @@ def compute_norm(
     return total_norm
 
 
+def compute_vocab_grad_norm(
+    gradients,
+    parameters,
+    last_stage=False,
+    previous_vocab_grad_norm=None,
+    norm_type=2,
+    zero_mode=ParallelMode.ZERO1,
+):
+    enable_cuda_kernels = gradients[0].device.type == "cuda"
+    # Norm parameters.
+    norm_type = float(norm_type)
+    vocab_size = gpc.config.model["vocab_size"]
+
+    param_grads = reduce_grads(gradients, parameters, only_output=True)
+
+    vocab_grad_norm = torch.zeros((vocab_size,), dtype=torch.float32).to(get_current_device())
+    if param_grads:
+        for grad in param_grads:
+            # get grad norm of each vocab
+            for i in range(vocab_size):
+                cur_vocab_grad_norm = get_norm([grad[i, :]], norm_type, enable_cuda_kernels)[0]
+                vocab_grad_norm[i] += get_tensor_norm(cur_vocab_grad_norm, move_to_cuda=True)
+
+    if last_stage is False:
+        return vocab_grad_norm
+
+    if previous_vocab_grad_norm is not None:
+        vocab_grad_norm = vocab_grad_norm + previous_vocab_grad_norm
+
+    if gpc.is_initialized(ParallelMode.MODEL):
+        dist.all_reduce(
+            vocab_grad_norm,
+            op=dist.ReduceOp.SUM,
+            group=gpc.get_group(ParallelMode.MODEL),
+        )
+
+    dist.all_reduce(vocab_grad_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))
+
+    if zero_mode == ParallelMode.EXPERT_DATA:
+        pg = gpc.get_group(ParallelMode.EXPERT)
+        scaled_norm = vocab_grad_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        vocab_grad_norm = scaled_norm_tensor.item()
+
+    # Scale.
+    vocab_grad_norm[vocab_grad_norm == float("inf")] = -1
+    vocab_grad_norm[vocab_grad_norm == -float("inf")] = -1
+    vocab_grad_norm[torch.isnan(vocab_grad_norm)] = -2
+
+    return vocab_grad_norm
+
+
 def compute_param_metric(
     gradients,
     parameters,
@@ -384,12 +452,7 @@ def compute_param_metric(
         for param_name, grads in param_grads.items():
             if metric_type == "norm":
-                if norm_type == inf:
-                    param_metric = max(g.data.abs().max() for g in grads)
-                elif norm_type == 2.0 and enable_cuda_kernels:
-                    param_metric = calc_l2_norm(grads) ** norm_type
-                else:
-                    param_metric = calc_lp(grads, norm_type)
+                param_metric = get_norm(grads, norm_type, enable_cuda_kernels)
                 param_metrics[param_name] = param_metric.item() if torch.is_tensor(param_metric) else param_metric
             elif metric_type == "zero_grad":
                 param_zero_grad_count = calc_zero_grad(grads)

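The inner loop of compute_vocab_grad_norm above accumulates one squared norm per row of an output-layer gradient shaped (vocab_size, hidden_size). For the default norm_type=2, the per-row quantity is simply the squared L2 norm of each row, which the following standalone sketch restates in vectorized form (an illustration of the math, not code from this PR):

import torch

def per_vocab_squared_l2(grad: torch.Tensor) -> torch.Tensor:
    # grad: (vocab_size, hidden_size) gradient shard of the output projection.
    # Returns the squared L2 norm of each vocab row, i.e. what the per-row loop
    # accumulates into vocab_grad_norm for norm_type=2 before the final sqrt.
    return grad.float().pow(2).sum(dim=1)

# Quick check against a per-row loop on a toy gradient.
g = torch.randn(8, 4)
vectorized = per_vocab_squared_l2(g)
looped = torch.stack([g[i].norm(2) ** 2 for i in range(g.shape[0])])
assert torch.allclose(vectorized, looped, atol=1e-5)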
View File

@@ -2,6 +2,8 @@
 # -*- encoding: utf-8 -*-
 
 import functools
+import os
+import pickle
 import time
 from functools import partial
 from typing import Callable, Iterable, Optional, Union
@@ -49,7 +51,7 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.train.utils import create_param_groups
-from internlm.utils.common import DummyProfile
+from internlm.utils.common import DummyProfile, launch_time
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
@@ -164,8 +166,10 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         A tuple of (optimizer, beta2_scheduler, lr_scheduler).
     """
     grad_profiling_config = gpc.config.get("grad_profiling", {})
-    if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
-        "zero_grad_profiling", False
+    if (
+        grad_profiling_config.get("grad_norm_profiling", False)
+        or grad_profiling_config.get("zero_grad_profiling", False)
+        or grad_profiling_config.get("vocab_grad_norm_profiling", False)
     ):
         # set the layer name as an attribute of the model parameters
         set_model_params_layer_name(model)
@@ -522,43 +526,40 @@ def record_current_batch_training_metrics(
                 infos[key] = value
 
     grad_profiling_config = gpc.config.get("grad_profiling", {})
-    if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
-        "zero_grad_profiling", False
-    ):
-        layer_metrics = ["layer_grad_norm", "layer_zero_grad"]
-        param_metrics = ["param_grad_norm", "param_zero_grad"]
+    interval_steps = grad_profiling_config.get("interval_steps", 1)
+    if batch_count % interval_steps == 0:
+        layer_metrics = [metric for metric in ["layer_grad_norm", "layer_zero_grad"] if metric in grad_norm]
+        param_metrics = [metric for metric in ["param_grad_norm", "param_zero_grad"] if metric in grad_norm]
         layer_names = grad_profiling_config.get("layers", [])
-        for layer_metric_name in layer_metrics:
-            layer_metric = grad_norm.get(layer_metric_name, {})
-            if layer_metric:
-                for group_name, layer_group in layer_metric.items():
-                    if layer_group:
-                        title = f"{layer_metric_name}/{group_name}"
-                        if layer_names:
-                            filter_layer_metrics = {}
-                            for layer_name, metric_value in layer_group.items():
-                                if layer_name in layer_names:
-                                    filter_layer_metrics[layer_name] = metric_value
-                            writer.add_scalars(key=title, value=filter_layer_metrics, step=train_state.step_count)
-                        else:
-                            writer.add_scalars(key=title, value=layer_group, step=train_state.step_count)
-                del grad_norm[layer_metric_name]
-        for param_metric_name in param_metrics:
-            param_metric = grad_norm.get(param_metric_name, {})
-            if param_metric:
-                for group_name, layer_group in param_metric.items():
-                    for param_name, param_group in layer_group.items():
-                        title = f"{param_name}/{group_name}_{param_metric_name}"
-                        if layer_names:
-                            filter_param_group = {}
-                            for layer_name, metric_value in param_group.items():
-                                if layer_name in layer_names:
-                                    filter_param_group[layer_name] = param_group[layer_name]
-                            writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count)
-                        else:
-                            writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
-                del grad_norm[param_metric_name]
+        for metric_name in layer_metrics:
+            metric = grad_norm.get(metric_name)
+            for group_name, layer_group in metric.items():
+                title = f"{metric_name}/{group_name}"
+                metrics = {k: v for k, v in layer_group.items() if not layer_names or k in layer_names}
+                if metrics:
+                    writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+            del grad_norm[metric_name]
+        for metric_name in param_metrics:
+            metric = grad_norm.get(metric_name)
+            for group_name, layer_group in metric.items():
+                for param_name, param_group in layer_group.items():
+                    title = f"{param_name}/{group_name}_{metric_name}"
+                    metrics = {k: v for k, v in param_group.items() if not layer_names or k in layer_names}
+                    if metrics:
+                        writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+            del grad_norm[metric_name]
+        if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+            local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
+            os.makedirs(local_save_path, exist_ok=True)
+            local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
+            vocab_grad_norms = grad_norm.get("vocab_grad_norm")
+            if vocab_grad_norms:
+                try:
+                    with open(local_save_file, "ab+") as vocab_f:
+                        pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
+                except IOError as e:
+                    logger.warning(f"Error saving vocab_grad_norm: {e}")
+                del grad_norm["vocab_grad_norm"]
 
     line = ""
     for key, value in infos.items():

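Since the save path appends one pickled (step_count, vocab_grad_norms) tuple per profiled step to vocab_grad_norm.pt, reading the file back means unpickling until EOF rather than a single torch.load. A reader sketch follows; the path layout mirrors the f-string built above, and the helper name is hypothetical:

import pickle

def load_vocab_grad_norm_records(path: str):
    # Return the list of (step_count, vocab_grad_norms) tuples appended by
    # record_current_batch_training_metrics.
    records = []
    with open(path, "rb") as f:
        while True:
            try:
                records.append(pickle.load(f))
            except EOFError:
                break
    return records

# e.g. records = load_vocab_grad_norm_records("RUN/<job_name>/<launch_time>/grad_norm/vocab_grad_norm.pt")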
View File

@@ -197,6 +197,7 @@ def main(args):
         empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
         start_time = time.time()
         timer("one-batch").start()
+        gpc.config.batch_count = batch_count
 
         # load batch data
         batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)