add grad_norm profiling interval && refactor save grad norm

pull/519/head
JiaoPL 2023-11-28 20:41:29 +08:00
parent 4eed07a3c3
commit 83ebebd5bc
5 changed files with 95 additions and 99 deletions

View File

@@ -355,6 +355,9 @@ def args_sanity_check():
            f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
        )
    if "batch_count" not in gpc.config:
        gpc.config._add_item("batch_count", 0)
    if "moe_loss_coeff" not in gpc.config.loss:
        gpc.config.loss._add_item("moe_loss_coeff", 1.0)
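For context, the gate that this default feeds is a simple modulo check: `batch_count` starts at 0 and `interval_steps` defaults to 1, so profiling stays active on every step unless an interval is configured. A minimal sketch (illustrative only, not taken from the repository):

```python
# Minimal sketch of the profiling gate added in this commit. The function name
# is made up; the defaults mirror the config values seeded above.
def is_profiling_step(batch_count: int = 0, interval_steps: int = 1) -> bool:
    return batch_count % interval_steps == 0

assert is_profiling_step()            # defaults: profile every step
assert not is_profiling_step(7, 5)    # off-interval step is skipped
assert is_profiling_step(10, 5)       # every 5th batch is profiled
```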

View File

@@ -639,14 +639,18 @@ class HybridZeroOptimizer(BaseOptimizer):
        groups_param_norms = []
        group_param_zero_grad_count = []
        group_vocab_norms = []
        batch_count = gpc.config.batch_count
        interval_steps = grad_profiling_config.get("interval_steps", 1)
        is_profiling = batch_count % interval_steps == 0
        for group_id in range(self.num_param_groups):
            groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
            if grad_profiling_config.get("grad_norm_profiling", False):
                groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
            if grad_profiling_config.get("zero_grad_profiling", False):
                group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
                group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))
            if is_profiling:
                if grad_profiling_config.get("grad_norm_profiling", False):
                    groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
                if grad_profiling_config.get("zero_grad_profiling", False):
                    group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
                    group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))
        # clear reduced grads
        # grads in the last bucket is reduced
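The pass above only accumulates per-group partial results; the hunk below finishes the reduction with `last_bucket=True, last_stage=True`, folding in the previously accumulated values. A toy version of that accumulate-then-finalize pattern (the helper names are stand-ins, not the `_compute_*_stage` API):

```python
# Toy accumulate-then-finalize norm computation, illustrative only.
from typing import List

def partial_sq_norm(grads: List[float]) -> float:
    # first stage: sum of squares over the gradients reduced so far
    return sum(g * g for g in grads)

def finalize_norm(previous: float, last_bucket: List[float]) -> float:
    # last stage: fold in the final bucket, then take the square root
    return (previous + partial_sq_norm(last_bucket)) ** 0.5

acc = partial_sq_norm([3.0])       # earlier buckets
print(finalize_norm(acc, [4.0]))   # 5.0
```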
@@ -673,54 +677,56 @@ class HybridZeroOptimizer(BaseOptimizer):
                last_stage=True,
                previous_norm=groups_norms[group_id],
            )
            if grad_profiling_config.get("grad_norm_profiling", False):
                param_norms = self._compute_param_norm_stage(
                    group_id=group_id,
                    last_bucket=True,
                    last_stage=True,
                    previous_param_norms=groups_param_norms[group_id],
                )
                total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
                    param_norms=param_norms, loss_scale=self.loss_scale.item()
                )
            if grad_profiling_config.get("zero_grad_profiling", False):
                zero_grad_count = self._count_zero_grads_stage(
                    group_id=group_id,
                    last_bucket=True,
                    last_stage=True,
                    previous_zero_grad_count=group_param_zero_grad_count[group_id],
                )
                (
                    total_layer_zero_grad_count[group_name],
                    total_param_zero_grad_count[group_name],
                ) = compute_layer_zero_grad_count(zero_grad_count)
            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
                vocab_grad_norms = self._compute_vocab_grad_norm_stage(
                    group_id=group_id,
                    last_bucket=True,
                    last_stage=True,
                    previous_vocab_grad_norm=group_vocab_norms[group_id],
                )
                inf_mask = vocab_grad_norms == -1
                nan_mask = vocab_grad_norms == -2
                vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
                vocab_grad_norms[inf_mask] = -1
                vocab_grad_norms[nan_mask] = -2
                total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")
            if is_profiling:
                if grad_profiling_config.get("grad_norm_profiling", False):
                    param_norms = self._compute_param_norm_stage(
                        group_id=group_id,
                        last_bucket=True,
                        last_stage=True,
                        previous_param_norms=groups_param_norms[group_id],
                    )
                    total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
                        param_norms=param_norms, loss_scale=self.loss_scale.item()
                    )
                if grad_profiling_config.get("zero_grad_profiling", False):
                    zero_grad_count = self._count_zero_grads_stage(
                        group_id=group_id,
                        last_bucket=True,
                        last_stage=True,
                        previous_zero_grad_count=group_param_zero_grad_count[group_id],
                    )
                    (
                        total_layer_zero_grad_count[group_name],
                        total_param_zero_grad_count[group_name],
                    ) = compute_layer_zero_grad_count(zero_grad_count)
                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
                    vocab_grad_norms = self._compute_vocab_grad_norm_stage(
                        group_id=group_id,
                        last_bucket=True,
                        last_stage=True,
                        previous_vocab_grad_norm=group_vocab_norms[group_id],
                    )
                    inf_mask = vocab_grad_norms == -1
                    nan_mask = vocab_grad_norms == -2
                    vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
                    vocab_grad_norms[inf_mask] = -1
                    vocab_grad_norms[nan_mask] = -2
                    total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")
timer("sync_grad").start()
self._sync_grad()
timer("sync_grad").stop()
state, global_norms = self._step(closure=closure, norms=total_norms)
if grad_profiling_config.get("grad_norm_profiling", False):
global_norms["layer_grad_norm"] = total_layer_grad_norms
global_norms["param_grad_norm"] = total_param_grad_norms
if grad_profiling_config.get("zero_grad_profiling", False):
global_norms["layer_zero_grad"] = total_layer_zero_grad_count
global_norms["param_zero_grad"] = total_param_zero_grad_count
if grad_profiling_config.get("vocab_grad_norm_profiling", False):
global_norms["vocab_grad_norm"] = total_vocab_grad_norms
if is_profiling:
if grad_profiling_config.get("grad_norm_profiling", False):
global_norms["layer_grad_norm"] = total_layer_grad_norms
global_norms["param_grad_norm"] = total_param_grad_norms
if grad_profiling_config.get("zero_grad_profiling", False):
global_norms["layer_zero_grad"] = total_layer_zero_grad_count
global_norms["param_zero_grad"] = total_param_zero_grad_count
if grad_profiling_config.get("vocab_grad_norm_profiling", False):
global_norms["vocab_grad_norm"] = total_vocab_grad_norms
return state, global_norms
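Read as a whole, this change wraps every profiling-only computation in the optimizer's step path behind the new `is_profiling` flag, so per-layer norms, zero-grad counts, and vocab grad norms are only gathered on every `interval_steps`-th batch. A stripped-down sketch of that pattern (the config dict and function below are stand-ins, not the HybridZeroOptimizer internals):

```python
# Illustrative gating pattern; not the optimizer's real API.
from typing import Any, Dict

def gather_profiling_metrics(batch_count: int, cfg: Dict[str, Any]) -> Dict[str, Any]:
    is_profiling = batch_count % cfg.get("interval_steps", 1) == 0
    metrics: Dict[str, Any] = {}
    if is_profiling:
        if cfg.get("grad_norm_profiling", False):
            metrics["layer_grad_norm"] = {}   # placeholder for the real reductions
        if cfg.get("zero_grad_profiling", False):
            metrics["layer_zero_grad"] = {}
        if cfg.get("vocab_grad_norm_profiling", False):
            metrics["vocab_grad_norm"] = {}
    return metrics

cfg = {"grad_norm_profiling": True, "interval_steps": 4}
print(gather_profiling_metrics(8, cfg))   # {'layer_grad_norm': {}} -- profiled step
print(gather_profiling_metrics(9, cfg))   # {} -- skipped step
```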

View File

@@ -241,9 +241,11 @@ def reduce_grads(gradients, parameters, fine_grained=False, only_output=False):
            parallel_grads[param_name].append(g.data.float())
        elif only_output:
            param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
            grad_profiling_config = gpc.config.get("grad_profiling", {})
            layer_names = grad_profiling_config.get("layers", [])
            if gpc.config.model["vocab_size"] == g.shape[0] and param_name.split("-")[0] in layer_names:
            if (
                gpc.config.model["vocab_size"] == g.shape[0]
                and gpc.config.model["hidden_size"] == g.shape[1]
                and "embedding" not in param_name.lower()
            ):
                parallel_grads.append(g.data.float())
        else:
            parallel_grads.append(g.data.float())
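The selection rule changes from "the parameter name must appear in the configured `layers` list" to a shape-based check: keep the gradient only when it matches the `vocab_size x hidden_size` output head and the parameter is not an embedding. A self-contained sketch of that check (dimensions and names are made up; the real code reads them from `gpc.config.model`):

```python
# Illustrative shape-based selection of the output-head gradient.
import torch

VOCAB_SIZE, HIDDEN_SIZE = 1000, 64  # stand-ins for gpc.config.model values

def is_output_head_grad(grad: torch.Tensor, param_name: str) -> bool:
    return (
        grad.ndim == 2
        and grad.shape[0] == VOCAB_SIZE
        and grad.shape[1] == HIDDEN_SIZE
        and "embedding" not in param_name.lower()
    )

head_grad = torch.randn(VOCAB_SIZE, HIDDEN_SIZE)
assert is_output_head_grad(head_grad, "head.weight")
assert not is_output_head_grad(head_grad, "embedding.weight")  # embeddings are excluded
```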

View File

@@ -521,55 +521,39 @@ def record_current_batch_training_metrics(
            infos[key] = value
        grad_profiling_config = gpc.config.get("grad_profiling", {})
        if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
            "zero_grad_profiling", False
        ):
            layer_metrics = ["layer_grad_norm", "layer_zero_grad"]
            param_metrics = ["param_grad_norm", "param_zero_grad"]
        interval_steps = grad_profiling_config.get("interval_steps", 1)
        if batch_count % interval_steps == 0:
            layer_metrics = [metric for metric in ["layer_grad_norm", "layer_zero_grad"] if metric in grad_norm]
            param_metrics = [metric for metric in ["param_grad_norm", "param_zero_grad"] if metric in grad_norm]
            layer_names = grad_profiling_config.get("layers", [])
            for layer_metric_name in layer_metrics:
                layer_metric = grad_norm.get(layer_metric_name, {})
                if layer_metric:
                    for group_name, layer_group in layer_metric.items():
                        if layer_group:
                            title = f"{layer_metric_name}/{group_name}"
                            if layer_names:
                                filter_layer_metrics = {}
                                for layer_name, metric_value in layer_group.items():
                                    if layer_name in layer_names:
                                        filter_layer_metrics[layer_name] = metric_value
                                if filter_layer_metrics:
                                    writer.add_scalars(
                                        key=title, value=filter_layer_metrics, step=train_state.step_count
                                    )
                            else:
                                writer.add_scalars(key=title, value=layer_group, step=train_state.step_count)
                del grad_norm[layer_metric_name]
            for param_metric_name in param_metrics:
                param_metric = grad_norm.get(param_metric_name, {})
                if param_metric:
                    for group_name, layer_group in param_metric.items():
                        for param_name, param_group in layer_group.items():
                            title = f"{param_name}/{group_name}_{param_metric_name}"
                            if layer_names:
                                filter_param_group = {}
                                for layer_name, metric_value in param_group.items():
                                    if layer_name in layer_names:
                                        filter_param_group[layer_name] = param_group[layer_name]
                                if filter_param_group:
                                    writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count)
                            else:
                                writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
                del grad_norm[param_metric_name]
            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
                local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
                os.makedirs(local_save_path, exist_ok=True)
                local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
                vocab_grad_norms = grad_norm.get("vocab_grad_norm", {})
                if vocab_grad_norms:
                    with open(local_save_file, "ab+") as vocab_f:
                        pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
            for metric_name in layer_metrics:
                metric = grad_norm.get(metric_name)
                for group_name, layer_group in metric.items():
                    title = f"{metric_name}/{group_name}"
                    metrics = {k: v for k, v in layer_group.items() if not layer_names or k in layer_names}
                    if metrics:
                        writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
                del grad_norm[metric_name]
            for metric_name in param_metrics:
                metric = grad_norm.get(metric_name)
                for group_name, layer_group in metric.items():
                    for param_name, param_group in layer_group.items():
                        title = f"{param_name}/{group_name}_{metric_name}"
                        metrics = {k: v for k, v in param_group.items() if not layer_names or k in layer_names}
                        if metrics:
                            writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
                del grad_norm[metric_name]
            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
                local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
                os.makedirs(local_save_path, exist_ok=True)
                local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
                vocab_grad_norms = grad_norm.get("vocab_grad_norm")
                if vocab_grad_norms:
                    try:
                        with open(local_save_file, "ab+") as vocab_f:
                            pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
                    except IOError as e:
                        logger.warning(f"Error saving vocab_grad_norm: {e}")
                    del grad_norm["vocab_grad_norm"]
        line = ""

View File

@@ -197,6 +197,7 @@ def main(args):
        empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
        start_time = time.time()
        timer("one-batch").start()
        gpc.config.batch_count = batch_count
        # load batch data
        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
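The training loop publishes the current `batch_count` into the global config on every iteration, which is how the optimizer's `is_profiling` check and the metrics recorder end up agreeing on the same step index. For reference, a hypothetical `grad_profiling` config section covering only the keys read in this commit (every value is an example, not a recommended setting):

```python
# Hypothetical grad_profiling config block; keys match those consumed in the
# diff above, values are illustrative.
grad_profiling = dict(
    grad_norm_profiling=True,         # per-layer / per-parameter grad norms
    zero_grad_profiling=False,        # count zero entries in the grads
    vocab_grad_norm_profiling=True,   # save per-vocab output-head grad norms
    interval_steps=5,                 # new in this commit: profile every 5th batch
    layers=[],                        # optional name filter for TensorBoard; empty keeps all
)
```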