mirror of https://github.com/InternLM/InternLM
add grad_norm profiling interval && refactor save grad norm
parent 4eed07a3c3
commit 83ebebd5bc
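For context, the keys this change reads from the grad_profiling config appear in the hunks below. A hypothetical sketch of that config section follows; the key names are taken from the diff, but the surrounding layout and the example values are assumptions, not part of this commit:

# Hypothetical grad_profiling config sketch (key names from the diff below;
# layout and values are illustrative only).
grad_profiling = dict(
    grad_norm_profiling=True,         # log per-layer / per-param grad norms
    zero_grad_profiling=False,        # count zero gradients per param
    vocab_grad_norm_profiling=False,  # dump per-vocab-token grad norms to disk
    interval_steps=10,                # new in this commit: profile every N batches
    layers=["norm", "head"],          # optional layer-name filter for scalar logging
)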
@@ -355,6 +355,9 @@ def args_sanity_check():
             f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
         )
 
+    if "batch_count" not in gpc.config:
+        gpc.config._add_item("batch_count", 0)
+
     if "moe_loss_coeff" not in gpc.config.loss:
         gpc.config.loss._add_item("moe_loss_coeff", 1.0)
 
@@ -639,14 +639,18 @@ class HybridZeroOptimizer(BaseOptimizer):
         groups_param_norms = []
         group_param_zero_grad_count = []
         group_vocab_norms = []
+        batch_count = gpc.config.batch_count
+        interval_steps = grad_profiling_config.get("interval_steps", 1)
+        is_profiling = batch_count % interval_steps == 0
         for group_id in range(self.num_param_groups):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
-            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-                group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))
 
         # clear reduced grads
         # grads in the last bucket is reduced
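The per-group profiling work above now runs only when the batch count falls on the profiling interval. A minimal sketch of that gating, with an illustrative interval value:

interval_steps = 5  # illustrative; the code defaults to 1 (profile every batch)
profiled_steps = [step for step in range(12) if step % interval_steps == 0]
print(profiled_steps)  # [0, 5, 10] -> only these batch counts pay the profiling cost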
@@ -673,54 +677,56 @@ class HybridZeroOptimizer(BaseOptimizer):
                 last_stage=True,
                 previous_norm=groups_norms[group_id],
             )
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                param_norms = self._compute_param_norm_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_param_norms=groups_param_norms[group_id],
-                )
-                total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
-                    param_norms=param_norms, loss_scale=self.loss_scale.item()
-                )
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                zero_grad_count = self._count_zero_grads_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_zero_grad_count=group_param_zero_grad_count[group_id],
-                )
-                (
-                    total_layer_zero_grad_count[group_name],
-                    total_param_zero_grad_count[group_name],
-                ) = compute_layer_zero_grad_count(zero_grad_count)
-            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-                vocab_grad_norms = self._compute_vocab_grad_norm_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_vocab_grad_norm=group_vocab_norms[group_id],
-                )
-                inf_mask = vocab_grad_norms == -1
-                nan_mask = vocab_grad_norms == -2
-                vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
-                vocab_grad_norms[inf_mask] = -1
-                vocab_grad_norms[nan_mask] = -2
-                total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    param_norms = self._compute_param_norm_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_param_norms=groups_param_norms[group_id],
+                    )
+                    total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
+                        param_norms=param_norms, loss_scale=self.loss_scale.item()
+                    )
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    zero_grad_count = self._count_zero_grads_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_zero_grad_count=group_param_zero_grad_count[group_id],
+                    )
+                    (
+                        total_layer_zero_grad_count[group_name],
+                        total_param_zero_grad_count[group_name],
+                    ) = compute_layer_zero_grad_count(zero_grad_count)
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    vocab_grad_norms = self._compute_vocab_grad_norm_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_vocab_grad_norm=group_vocab_norms[group_id],
+                    )
+                    inf_mask = vocab_grad_norms == -1
+                    nan_mask = vocab_grad_norms == -2
+                    vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
+                    vocab_grad_norms[inf_mask] = -1
+                    vocab_grad_norms[nan_mask] = -2
+                    total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")
 
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
 
         state, global_norms = self._step(closure=closure, norms=total_norms)
-        if grad_profiling_config.get("grad_norm_profiling", False):
-            global_norms["layer_grad_norm"] = total_layer_grad_norms
-            global_norms["param_grad_norm"] = total_param_grad_norms
-        if grad_profiling_config.get("zero_grad_profiling", False):
-            global_norms["layer_zero_grad"] = total_layer_zero_grad_count
-            global_norms["param_zero_grad"] = total_param_zero_grad_count
-        if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-            global_norms["vocab_grad_norm"] = total_vocab_grad_norms
+        if is_profiling:
+            if grad_profiling_config.get("grad_norm_profiling", False):
+                global_norms["layer_grad_norm"] = total_layer_grad_norms
+                global_norms["param_grad_norm"] = total_param_grad_norms
+            if grad_profiling_config.get("zero_grad_profiling", False):
+                global_norms["layer_zero_grad"] = total_layer_zero_grad_count
+                global_norms["param_zero_grad"] = total_param_zero_grad_count
+            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                global_norms["vocab_grad_norm"] = total_vocab_grad_norms
 
         return state, global_norms
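The vocab-norm branch above carries -1/-2 through the sqrt and loss-scale division as inf/nan sentinels, then restores them. A standalone sketch of that handling on a dummy tensor (torch assumed available; the values are made up):

import torch

loss_scale = 65536.0
# squared per-token norms; -1 marks an inf entry, -2 marks a nan entry
vocab_grad_norms = torch.tensor([4.0, -1.0, -2.0, 16.0])

inf_mask = vocab_grad_norms == -1
nan_mask = vocab_grad_norms == -2
vocab_grad_norms = vocab_grad_norms**0.5 / loss_scale  # sqrt + unscale (sentinels become nan here)
vocab_grad_norms[inf_mask] = -1  # restore the sentinel values afterwards
vocab_grad_norms[nan_mask] = -2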
@@ -241,9 +241,11 @@ def reduce_grads(gradients, parameters, fine_grained=False, only_output=False):
             parallel_grads[param_name].append(g.data.float())
         elif only_output:
             param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
-            grad_profiling_config = gpc.config.get("grad_profiling", {})
-            layer_names = grad_profiling_config.get("layers", [])
-            if gpc.config.model["vocab_size"] == g.shape[0] and param_name.split("-")[0] in layer_names:
+            if (
+                gpc.config.model["vocab_size"] == g.shape[0]
+                and gpc.config.model["hidden_size"] == g.shape[1]
+                and "embedding" not in param_name.lower()
+            ):
                 parallel_grads.append(g.data.float())
         else:
             parallel_grads.append(g.data.float())
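The new condition above selects the output-projection gradient by shape (vocab_size x hidden_size, excluding embeddings) instead of relying on a configured layer list. A toy illustration of that shape check; the sizes and parameter name here are made up:

import torch

vocab_size, hidden_size = 32, 8
head_grad = torch.randn(vocab_size, hidden_size)   # stand-in for g.data
param_name = "head-0.weight"                        # hypothetical name; anything without "embedding" passes

selected = (
    head_grad.shape[0] == vocab_size
    and head_grad.shape[1] == hidden_size
    and "embedding" not in param_name.lower()
)
print(selected)  # True -> this grad would be collected for vocab-grad-norm profiling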
@@ -521,55 +521,39 @@ def record_current_batch_training_metrics(
             infos[key] = value
 
         grad_profiling_config = gpc.config.get("grad_profiling", {})
         if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
             "zero_grad_profiling", False
         ):
-            layer_metrics = ["layer_grad_norm", "layer_zero_grad"]
-            param_metrics = ["param_grad_norm", "param_zero_grad"]
-            layer_names = grad_profiling_config.get("layers", [])
-            for layer_metric_name in layer_metrics:
-                layer_metric = grad_norm.get(layer_metric_name, {})
-                if layer_metric:
-                    for group_name, layer_group in layer_metric.items():
-                        if layer_group:
-                            title = f"{layer_metric_name}/{group_name}"
-                            if layer_names:
-                                filter_layer_metrics = {}
-                                for layer_name, metric_value in layer_group.items():
-                                    if layer_name in layer_names:
-                                        filter_layer_metrics[layer_name] = metric_value
-                                if filter_layer_metrics:
-                                    writer.add_scalars(
-                                        key=title, value=filter_layer_metrics, step=train_state.step_count
-                                    )
-                            else:
-                                writer.add_scalars(key=title, value=layer_group, step=train_state.step_count)
-                    del grad_norm[layer_metric_name]
-
-            for param_metric_name in param_metrics:
-                param_metric = grad_norm.get(param_metric_name, {})
-                if param_metric:
-                    for group_name, layer_group in param_metric.items():
-                        for param_name, param_group in layer_group.items():
-                            title = f"{param_name}/{group_name}_{param_metric_name}"
-                            if layer_names:
-                                filter_param_group = {}
-                                for layer_name, metric_value in param_group.items():
-                                    if layer_name in layer_names:
-                                        filter_param_group[layer_name] = param_group[layer_name]
-                                if filter_param_group:
-                                    writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count)
-                            else:
-                                writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
-                    del grad_norm[param_metric_name]
-        if grad_profiling_config.get("vocab_grad_norm_profiling", False):
-            local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
-            os.makedirs(local_save_path, exist_ok=True)
-            local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
-            vocab_grad_norms = grad_norm.get("vocab_grad_norm", {})
-            if vocab_grad_norms:
-                with open(local_save_file, "ab+") as vocab_f:
-                    pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
+            interval_steps = grad_profiling_config.get("interval_steps", 1)
+            if batch_count % interval_steps == 0:
+                layer_metrics = [metric for metric in ["layer_grad_norm", "layer_zero_grad"] if metric in grad_norm]
+                param_metrics = [metric for metric in ["param_grad_norm", "param_zero_grad"] if metric in grad_norm]
+                layer_names = grad_profiling_config.get("layers", [])
+                for metric_name in layer_metrics:
+                    metric = grad_norm.get(metric_name)
+                    for group_name, layer_group in metric.items():
+                        title = f"{metric_name}/{group_name}"
+                        metrics = {k: v for k, v in layer_group.items() if not layer_names or k in layer_names}
+                        if metrics:
+                            writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+                    del grad_norm[metric_name]
+                for metric_name in param_metrics:
+                    metric = grad_norm.get(metric_name)
+                    for group_name, layer_group in metric.items():
+                        for param_name, param_group in layer_group.items():
+                            title = f"{param_name}/{group_name}_{metric_name}"
+                            metrics = {k: v for k, v in param_group.items() if not layer_names or k in layer_names}
+                            if metrics:
+                                writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+                    del grad_norm[metric_name]
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
+                    os.makedirs(local_save_path, exist_ok=True)
+                    local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
+                    vocab_grad_norms = grad_norm.get("vocab_grad_norm")
+                    if vocab_grad_norms:
+                        try:
+                            with open(local_save_file, "ab+") as vocab_f:
+                                pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
+                        except IOError as e:
+                            logger.warning(f"Error saving vocab_grad_norm: {e}")
+                        del grad_norm["vocab_grad_norm"]
 
         line = ""
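Because each profiled step appends another pickle.dump record to the same vocab_grad_norm.pt file, reading it back means looping until EOF. A hedged reader sketch; the path mirrors the one built above, with the job name and launch time left as placeholders to fill in:

import pickle

path = "RUN/<JOB_NAME>/<launch_time>/grad_norm/vocab_grad_norm.pt"  # placeholders to replace
records = []
with open(path, "rb") as f:
    while True:
        try:
            step, vocab_grad_norms = pickle.load(f)  # (train_state.step_count, per-group norms)
            records.append((step, vocab_grad_norms))
        except EOFError:
            break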
train.py
@@ -197,6 +197,7 @@ def main(args):
         empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
         start_time = time.time()
         timer("one-batch").start()
+        gpc.config.batch_count = batch_count
 
         # load batch data
         batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)