add grad_norm profiling interval && refactor save grad norm

pull/519/head
JiaoPL 2023-11-28 20:41:29 +08:00
parent 4eed07a3c3
commit 83ebebd5bc
5 changed files with 95 additions and 99 deletions

View File

@@ -355,6 +355,9 @@ def args_sanity_check():
            f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
        )
+    if "batch_count" not in gpc.config:
+        gpc.config._add_item("batch_count", 0)
    if "moe_loss_coeff" not in gpc.config.loss:
        gpc.config.loss._add_item("moe_loss_coeff", 1.0)
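For context, a sketch of what a `grad_profiling` entry in the training config could look like with the new interval control. Only the key names (`interval_steps`, `grad_norm_profiling`, `zero_grad_profiling`, `vocab_grad_norm_profiling`, `layers`) and the `interval_steps` default of 1 come from this diff; the dict layout and the example values are assumptions.

# Hypothetical excerpt of a training config; key names follow the diff,
# the layout and the values are illustrative assumptions.
grad_profiling = dict(
    grad_norm_profiling=True,         # collect per-layer / per-param grad norms
    zero_grad_profiling=False,        # collect zero-grad counts
    vocab_grad_norm_profiling=False,  # append vocab grad norms to vocab_grad_norm.pt
    interval_steps=10,                # only profile when batch_count % interval_steps == 0
    layers=[],                        # optional layer-name filter; empty keeps all layers
)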

View File

@@ -639,8 +639,12 @@ class HybridZeroOptimizer(BaseOptimizer):
        groups_param_norms = []
        group_param_zero_grad_count = []
        group_vocab_norms = []
+        batch_count = gpc.config.batch_count
+        interval_steps = grad_profiling_config.get("interval_steps", 1)
+        is_profiling = batch_count % interval_steps == 0
        for group_id in range(self.num_param_groups):
            groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
-            if grad_profiling_config.get("zero_grad_profiling", False):
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
+                if grad_profiling_config.get("zero_grad_profiling", False):
@@ -673,6 +677,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                last_stage=True,
                previous_norm=groups_norms[group_id],
            )
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                param_norms = self._compute_param_norm_stage(
-                    group_id=group_id,
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    param_norms = self._compute_param_norm_stage(
+                        group_id=group_id,
@@ -713,6 +718,7 @@ class HybridZeroOptimizer(BaseOptimizer):
            timer("sync_grad").stop()
        state, global_norms = self._step(closure=closure, norms=total_norms)
-        if grad_profiling_config.get("grad_norm_profiling", False):
-            global_norms["layer_grad_norm"] = total_layer_grad_norms
-            global_norms["param_grad_norm"] = total_param_grad_norms
+        if is_profiling:
+            if grad_profiling_config.get("grad_norm_profiling", False):
+                global_norms["layer_grad_norm"] = total_layer_grad_norms
+                global_norms["param_grad_norm"] = total_param_grad_norms
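The three hunks above share one pattern: the per-layer and per-parameter norm collection now only runs on batches where `batch_count % interval_steps == 0`. A minimal standalone sketch of that gating; `should_profile` is a hypothetical helper name, not something in the optimizer:

# Minimal sketch of the interval gating used above; only interval_steps and
# batch_count are names from the diff, the rest is illustrative.
def should_profile(batch_count: int, grad_profiling_config: dict) -> bool:
    """Return True on the batches where grad-norm profiling should run."""
    interval_steps = grad_profiling_config.get("interval_steps", 1)  # default: every batch
    return batch_count % interval_steps == 0

# Example: with interval_steps=10, profiling fires on batches 0, 10, 20, ...
assert should_profile(0, {"interval_steps": 10})
assert not should_profile(7, {"interval_steps": 10})
assert should_profile(20, {"interval_steps": 10})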

View File

@@ -241,9 +241,11 @@ def reduce_grads(gradients, parameters, fine_grained=False, only_output=False):
            parallel_grads[param_name].append(g.data.float())
        elif only_output:
            param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
-            grad_profiling_config = gpc.config.get("grad_profiling", {})
-            layer_names = grad_profiling_config.get("layers", [])
-            if gpc.config.model["vocab_size"] == g.shape[0] and param_name.split("-")[0] in layer_names:
+            if (
+                gpc.config.model["vocab_size"] == g.shape[0]
+                and gpc.config.model["hidden_size"] == g.shape[1]
+                and "embedding" not in param_name.lower()
+            ):
                parallel_grads.append(g.data.float())
        else:
            parallel_grads.append(g.data.float())
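The rewritten condition no longer consults the configured `layers` list here; it identifies the output-projection gradient purely by shape, matching both `vocab_size` and `hidden_size`, and excludes embedding weights by name since they share that shape. A small sketch of the same check pulled out as a helper (the helper name, the tiny shapes, and the example parameter names are illustrative):

# Sketch of the shape-based check above, as a standalone helper; the function
# name and the example values are assumptions, the logic mirrors the diff.
import torch

def is_output_head_grad(grad: torch.Tensor, param_name: str, vocab_size: int, hidden_size: int) -> bool:
    """True for gradients shaped [vocab_size, hidden_size] that do not belong to an embedding."""
    return (
        grad.ndim == 2
        and grad.shape[0] == vocab_size
        and grad.shape[1] == hidden_size
        and "embedding" not in param_name.lower()
    )

vocab_size, hidden_size = 8, 4                    # tiny stand-in dimensions
head_grad = torch.zeros(vocab_size, hidden_size)  # output-head weight gradient
embed_grad = torch.zeros(vocab_size, hidden_size) # embedding gradient has the same shape
print(is_output_head_grad(head_grad, "head.weight-0", vocab_size, hidden_size))        # True
print(is_output_head_grad(embed_grad, "embedding.weight-0", vocab_size, hidden_size))  # False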

View File

@@ -521,55 +521,39 @@ def record_current_batch_training_metrics(
            infos[key] = value
        grad_profiling_config = gpc.config.get("grad_profiling", {})
-        if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
-            "zero_grad_profiling", False
-        ):
-            layer_metrics = ["layer_grad_norm", "layer_zero_grad"]
-            param_metrics = ["param_grad_norm", "param_zero_grad"]
+        interval_steps = grad_profiling_config.get("interval_steps", 1)
+        if batch_count % interval_steps == 0:
+            layer_metrics = [metric for metric in ["layer_grad_norm", "layer_zero_grad"] if metric in grad_norm]
+            param_metrics = [metric for metric in ["param_grad_norm", "param_zero_grad"] if metric in grad_norm]
            layer_names = grad_profiling_config.get("layers", [])
-            for layer_metric_name in layer_metrics:
-                layer_metric = grad_norm.get(layer_metric_name, {})
-                if layer_metric:
-                    for group_name, layer_group in layer_metric.items():
-                        if layer_group:
-                            title = f"{layer_metric_name}/{group_name}"
-                            if layer_names:
-                                filter_layer_metrics = {}
-                                for layer_name, metric_value in layer_group.items():
-                                    if layer_name in layer_names:
-                                        filter_layer_metrics[layer_name] = metric_value
-                                if filter_layer_metrics:
-                                    writer.add_scalars(
-                                        key=title, value=filter_layer_metrics, step=train_state.step_count
-                                    )
-                            else:
-                                writer.add_scalars(key=title, value=layer_group, step=train_state.step_count)
-                    del grad_norm[layer_metric_name]
-            for param_metric_name in param_metrics:
-                param_metric = grad_norm.get(param_metric_name, {})
-                if param_metric:
-                    for group_name, layer_group in param_metric.items():
-                        for param_name, param_group in layer_group.items():
-                            title = f"{param_name}/{group_name}_{param_metric_name}"
-                            if layer_names:
-                                filter_param_group = {}
-                                for layer_name, metric_value in param_group.items():
-                                    if layer_name in layer_names:
-                                        filter_param_group[layer_name] = param_group[layer_name]
-                                if filter_param_group:
-                                    writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count)
-                            else:
-                                writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
-                    del grad_norm[param_metric_name]
+            for metric_name in layer_metrics:
+                metric = grad_norm.get(metric_name)
+                for group_name, layer_group in metric.items():
+                    title = f"{metric_name}/{group_name}"
+                    metrics = {k: v for k, v in layer_group.items() if not layer_names or k in layer_names}
+                    if metrics:
+                        writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+                del grad_norm[metric_name]
+            for metric_name in param_metrics:
+                metric = grad_norm.get(metric_name)
+                for group_name, layer_group in metric.items():
+                    for param_name, param_group in layer_group.items():
+                        title = f"{param_name}/{group_name}_{metric_name}"
+                        metrics = {k: v for k, v in param_group.items() if not layer_names or k in layer_names}
+                        if metrics:
+                            writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+                del grad_norm[metric_name]
        if grad_profiling_config.get("vocab_grad_norm_profiling", False):
            local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
            os.makedirs(local_save_path, exist_ok=True)
            local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
-            vocab_grad_norms = grad_norm.get("vocab_grad_norm", {})
+            vocab_grad_norms = grad_norm.get("vocab_grad_norm")
            if vocab_grad_norms:
-                with open(local_save_file, "ab+") as vocab_f:
-                    pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
+                try:
+                    with open(local_save_file, "ab+") as vocab_f:
+                        pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
+                except IOError as e:
+                    logger.warning(f"Error saving vocab_grad_norm: {e}")
                del grad_norm["vocab_grad_norm"]
        line = ""
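The refactor collapses the nested filtering loops into dict comprehensions keyed on `layer_names` and wraps the vocab-norm dump in a try/except. Each profiled step appends one pickled `(step_count, vocab_grad_norms)` tuple to `vocab_grad_norm.pt` because the file is opened in `ab+` mode, so reading it back means repeated `pickle.load` calls until EOF. A sketch of such a reader; the function name is hypothetical:

# Hedged sketch: read back the (step_count, vocab_grad_norms) records appended
# by the pickle.dump(...) call above. The loader name is not part of the codebase.
import pickle

def load_vocab_grad_norms(path: str):
    """Yield (step_count, vocab_grad_norms) tuples, one per profiled step."""
    with open(path, "rb") as f:
        while True:
            try:
                yield pickle.load(f)  # one record per dump() call
            except EOFError:
                break

# Example usage, path format taken from the save code above:
# for step, norms in load_vocab_grad_norms("RUN/<JOB_NAME>/<launch_time>/grad_norm/vocab_grad_norm.pt"):
#     print(step, type(norms))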

View File

@@ -197,6 +197,7 @@ def main(args):
        empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
        start_time = time.time()
        timer("one-batch").start()
+        gpc.config.batch_count = batch_count
        # load batch data
        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
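Writing `batch_count` into `gpc.config` each iteration is what lets the optimizer, whose `step()` signature is unchanged, see the current batch index; `args_sanity_check` guarantees the key exists with a default of 0 before the first step. A toy illustration of that flow, with stand-ins replacing `gpc.config` and `HybridZeroOptimizer`:

# Toy illustration of the data flow added by this commit: the train loop writes
# the batch index into a shared config object, and the optimizer reads it back
# to decide whether to profile. The stand-ins below are assumptions, not real APIs.
from types import SimpleNamespace

config = SimpleNamespace(batch_count=0, grad_profiling={"interval_steps": 5})

def optimizer_step() -> bool:
    # mirrors the gating added in HybridZeroOptimizer above
    interval_steps = config.grad_profiling.get("interval_steps", 1)
    return config.batch_count % interval_steps == 0

profiled = []
for batch_count in range(12):
    config.batch_count = batch_count  # mirrors gpc.config.batch_count = batch_count
    if optimizer_step():
        profiled.append(batch_count)

print(profiled)  # [0, 5, 10]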