add grad_norm profiling interval && refactor save grad norm

pull/519/head
JiaoPL 2023-11-28 20:41:29 +08:00
parent 4eed07a3c3
commit 83ebebd5bc
5 changed files with 95 additions and 99 deletions

View File

@@ -355,6 +355,9 @@ def args_sanity_check():
            f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
        )
 
+    if "batch_count" not in gpc.config:
+        gpc.config._add_item("batch_count", 0)
+
     if "moe_loss_coeff" not in gpc.config.loss:
         gpc.config.loss._add_item("moe_loss_coeff", 1.0)
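
For reference, the profiling options that the later hunks read can be grouped in the training config roughly as follows. The key names (grad_norm_profiling, zero_grad_profiling, vocab_grad_norm_profiling, layers, and the new interval_steps) are the ones consumed in this commit; the surrounding layout is only a hedged sketch of a user config, not code taken from the repository.

# Hedged example of a grad_profiling section in a training config file.
# Key names match what this commit reads; the values are illustrative.
grad_profiling = dict(
    grad_norm_profiling=True,         # log per-layer / per-param grad norms
    zero_grad_profiling=False,        # count zero gradients per layer / param
    vocab_grad_norm_profiling=False,  # dump per-token vocab grad norms to disk
    layers=[],                        # empty list means "log all layers"
    interval_steps=10,                # new in this commit: profile every 10th batch
)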

View File

@@ -639,14 +639,18 @@ class HybridZeroOptimizer(BaseOptimizer):
         groups_param_norms = []
         group_param_zero_grad_count = []
         group_vocab_norms = []
+        batch_count = gpc.config.batch_count
+        interval_steps = grad_profiling_config.get("interval_steps", 1)
+        is_profiling = batch_count % interval_steps == 0
         for group_id in range(self.num_param_groups):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
-            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-                group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))
 
         # clear reduced grads
         # grads in the last bucket is reduced
@@ -673,54 +677,56 @@ class HybridZeroOptimizer(BaseOptimizer):
                 last_stage=True,
                 previous_norm=groups_norms[group_id],
             )
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                param_norms = self._compute_param_norm_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_param_norms=groups_param_norms[group_id],
-                )
-                total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
-                    param_norms=param_norms, loss_scale=self.loss_scale.item()
-                )
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                zero_grad_count = self._count_zero_grads_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_zero_grad_count=group_param_zero_grad_count[group_id],
-                )
-                (
-                    total_layer_zero_grad_count[group_name],
-                    total_param_zero_grad_count[group_name],
-                ) = compute_layer_zero_grad_count(zero_grad_count)
-            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-                vocab_grad_norms = self._compute_vocab_grad_norm_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_vocab_grad_norm=group_vocab_norms[group_id],
-                )
-                inf_mask = vocab_grad_norms == -1
-                nan_mask = vocab_grad_norms == -2
-                vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
-                vocab_grad_norms[inf_mask] = -1
-                vocab_grad_norms[nan_mask] = -2
-                total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    param_norms = self._compute_param_norm_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_param_norms=groups_param_norms[group_id],
+                    )
+                    total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
+                        param_norms=param_norms, loss_scale=self.loss_scale.item()
+                    )
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    zero_grad_count = self._count_zero_grads_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_zero_grad_count=group_param_zero_grad_count[group_id],
+                    )
+                    (
+                        total_layer_zero_grad_count[group_name],
+                        total_param_zero_grad_count[group_name],
+                    ) = compute_layer_zero_grad_count(zero_grad_count)
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    vocab_grad_norms = self._compute_vocab_grad_norm_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_vocab_grad_norm=group_vocab_norms[group_id],
+                    )
+                    inf_mask = vocab_grad_norms == -1
+                    nan_mask = vocab_grad_norms == -2
+                    vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
+                    vocab_grad_norms[inf_mask] = -1
+                    vocab_grad_norms[nan_mask] = -2
+                    total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")
 
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
 
         state, global_norms = self._step(closure=closure, norms=total_norms)
-        if grad_profiling_config.get("grad_norm_profiling", False):
-            global_norms["layer_grad_norm"] = total_layer_grad_norms
-            global_norms["param_grad_norm"] = total_param_grad_norms
-        if grad_profiling_config.get("zero_grad_profiling", False):
-            global_norms["layer_zero_grad"] = total_layer_zero_grad_count
-            global_norms["param_zero_grad"] = total_param_zero_grad_count
-        if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-            global_norms["vocab_grad_norm"] = total_vocab_grad_norms
+        if is_profiling:
+            if grad_profiling_config.get("grad_norm_profiling", False):
+                global_norms["layer_grad_norm"] = total_layer_grad_norms
+                global_norms["param_grad_norm"] = total_param_grad_norms
+            if grad_profiling_config.get("zero_grad_profiling", False):
+                global_norms["layer_zero_grad"] = total_layer_zero_grad_count
+                global_norms["param_zero_grad"] = total_param_zero_grad_count
+            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                global_norms["vocab_grad_norm"] = total_vocab_grad_norms
 
         return state, global_norms
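
The vocab-norm branch above post-processes accumulated squared norms of loss-scaled gradients: the final value is sqrt(x) / loss_scale, while the -1 (inf) and -2 (nan) sentinels must pass through the transform unchanged. A self-contained sketch of that step, with made-up numbers:

# Stand-alone illustration of the sentinel-preserving transform used in the
# vocab_grad_norm branch above; the tensor values here are invented.
import torch

loss_scale = 2.0
vocab_grad_norms = torch.tensor([4.0, -1.0, 16.0, -2.0])  # squared norms plus sentinels

inf_mask = vocab_grad_norms == -1
nan_mask = vocab_grad_norms == -2
vocab_grad_norms = vocab_grad_norms**0.5 / loss_scale  # sqrt, then un-scale
vocab_grad_norms[inf_mask] = -1  # restore inf marker
vocab_grad_norms[nan_mask] = -2  # restore nan marker
print(vocab_grad_norms)  # tensor([ 1., -1.,  2., -2.])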

View File

@@ -241,9 +241,11 @@ def reduce_grads(gradients, parameters, fine_grained=False, only_output=False):
             parallel_grads[param_name].append(g.data.float())
         elif only_output:
             param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
-            grad_profiling_config = gpc.config.get("grad_profiling", {})
-            layer_names = grad_profiling_config.get("layers", [])
-            if gpc.config.model["vocab_size"] == g.shape[0] and param_name.split("-")[0] in layer_names:
+            if (
+                gpc.config.model["vocab_size"] == g.shape[0]
+                and gpc.config.model["hidden_size"] == g.shape[1]
+                and "embedding" not in param_name.lower()
+            ):
                 parallel_grads.append(g.data.float())
         else:
             parallel_grads.append(g.data.float())
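
The rewritten condition above selects the vocab-parallel output-projection gradient by shape (vocab_size x hidden_size) and excludes embedding weights by name, instead of relying on the layers list from the profiling config. A minimal sketch of that filter on dummy tensors; the sizes and parameter names below are invented for illustration:

# Toy demonstration of the shape/name filter introduced above.  The real code
# walks the optimizer's gradients; here a small dict of fake grads stands in.
import torch

vocab_size, hidden_size = 1000, 64
fake_grads = {
    "head.weight-0": torch.randn(vocab_size, hidden_size),         # output projection -> kept
    "embedding.weight-0": torch.randn(vocab_size, hidden_size),    # same shape, excluded by name
    "mlp.w1.weight-0": torch.randn(4 * hidden_size, hidden_size),  # wrong shape -> excluded
}

selected = [
    name
    for name, g in fake_grads.items()
    if g.shape[0] == vocab_size and g.shape[1] == hidden_size and "embedding" not in name.lower()
]
print(selected)  # ['head.weight-0']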

View File

@@ -521,55 +521,39 @@ def record_current_batch_training_metrics(
                 infos[key] = value
 
         grad_profiling_config = gpc.config.get("grad_profiling", {})
-        if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
-            "zero_grad_profiling", False
-        ):
-            layer_metrics = ["layer_grad_norm", "layer_zero_grad"]
-            param_metrics = ["param_grad_norm", "param_zero_grad"]
-            layer_names = grad_profiling_config.get("layers", [])
-            for layer_metric_name in layer_metrics:
-                layer_metric = grad_norm.get(layer_metric_name, {})
-                if layer_metric:
-                    for group_name, layer_group in layer_metric.items():
-                        if layer_group:
-                            title = f"{layer_metric_name}/{group_name}"
-                            if layer_names:
-                                filter_layer_metrics = {}
-                                for layer_name, metric_value in layer_group.items():
-                                    if layer_name in layer_names:
-                                        filter_layer_metrics[layer_name] = metric_value
-                                if filter_layer_metrics:
-                                    writer.add_scalars(
-                                        key=title, value=filter_layer_metrics, step=train_state.step_count
-                                    )
-                            else:
-                                writer.add_scalars(key=title, value=layer_group, step=train_state.step_count)
-                    del grad_norm[layer_metric_name]
-
-            for param_metric_name in param_metrics:
-                param_metric = grad_norm.get(param_metric_name, {})
-                if param_metric:
-                    for group_name, layer_group in param_metric.items():
-                        for param_name, param_group in layer_group.items():
-                            title = f"{param_name}/{group_name}_{param_metric_name}"
-                            if layer_names:
-                                filter_param_group = {}
-                                for layer_name, metric_value in param_group.items():
-                                    if layer_name in layer_names:
-                                        filter_param_group[layer_name] = param_group[layer_name]
-                                if filter_param_group:
-                                    writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count)
-                            else:
-                                writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
-                    del grad_norm[param_metric_name]
-        if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-            local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
-            os.makedirs(local_save_path, exist_ok=True)
-            local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
-            vocab_grad_norms = grad_norm.get("vocab_grad_norm", {})
-            if vocab_grad_norms:
-                with open(local_save_file, "ab+") as vocab_f:
-                    pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
-            del grad_norm["vocab_grad_norm"]
+        interval_steps = grad_profiling_config.get("interval_steps", 1)
+        if batch_count % interval_steps == 0:
+            layer_metrics = [metric for metric in ["layer_grad_norm", "layer_zero_grad"] if metric in grad_norm]
+            param_metrics = [metric for metric in ["param_grad_norm", "param_zero_grad"] if metric in grad_norm]
+            layer_names = grad_profiling_config.get("layers", [])
+            for metric_name in layer_metrics:
+                metric = grad_norm.get(metric_name)
+                for group_name, layer_group in metric.items():
+                    title = f"{metric_name}/{group_name}"
+                    metrics = {k: v for k, v in layer_group.items() if not layer_names or k in layer_names}
+                    if metrics:
+                        writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+                del grad_norm[metric_name]
+            for metric_name in param_metrics:
+                metric = grad_norm.get(metric_name)
+                for group_name, layer_group in metric.items():
+                    for param_name, param_group in layer_group.items():
+                        title = f"{param_name}/{group_name}_{metric_name}"
+                        metrics = {k: v for k, v in param_group.items() if not layer_names or k in layer_names}
+                        if metrics:
+                            writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+                del grad_norm[metric_name]
+            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
+                os.makedirs(local_save_path, exist_ok=True)
+                local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
+                vocab_grad_norms = grad_norm.get("vocab_grad_norm")
+                if vocab_grad_norms:
+                    try:
+                        with open(local_save_file, "ab+") as vocab_f:
+                            pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
+                    except IOError as e:
+                        logger.warning(f"Error saving vocab_grad_norm: {e}")
+                del grad_norm["vocab_grad_norm"]
 
         line = ""

View File

@@ -197,6 +197,7 @@ def main(args):
         empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
         start_time = time.time()
         timer("one-batch").start()
+        gpc.config.batch_count = batch_count
 
         # load batch data
         batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
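
This one-line change is the producer side of the new interval logic: the training loop publishes the current batch index on the shared config, and the optimizer, whose step() takes no batch argument, reads it back to decide whether to profile. A toy, self-contained stand-in for that flow (the classes below are illustrative, not InternLM's):

# Toy sketch of the batch_count hand-off added above; names mirror the diff,
# but _Config/_Optimizer are invented stand-ins for gpc.config and
# HybridZeroOptimizer.
class _Config:
    batch_count = 0  # default registered in args_sanity_check above


class _Optimizer:
    def __init__(self, config, interval_steps=1):
        self.config = config
        self.interval_steps = interval_steps

    def step(self):
        if self.config.batch_count % self.interval_steps == 0:
            print(f"profiling grad norms at batch {self.config.batch_count}")


config = _Config()
optimizer = _Optimizer(config, interval_steps=5)
for batch_count in range(12):
    config.batch_count = batch_count  # the line added to train.py
    optimizer.step()                  # profiles at batches 0, 5 and 10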