Merge branch 'develop' of https://github.com/InternLM/InternLM into storage_multipart_upload

pull/529/head
lijiaxing 2023-12-07 10:23:05 +08:00
commit 3f49409681
9 changed files with 255 additions and 89 deletions

View File

@@ -24,6 +24,14 @@ def module_has_fp32_attr(module: nn.Module):
     return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")
 
 
+def set_output_attr_to_module(module: nn.Module):
+    setattr(module, "is_output", True)
+
+
+def module_is_output(module: nn.Module):
+    return hasattr(module, "is_output") and getattr(module, "is_output")
+
+
 class NaiveAMPModel(nn.Module):
     """
     This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
@@ -189,3 +197,8 @@ class NaiveAMPModel(nn.Module):
                 sub_module.to(fp32_dtype)
                 sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
                 sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
+            if gpc.config.get("output_tf32", False) and module_is_output(sub_module):
+                sub_module.to(fp32_dtype)
+                torch.backends.cudnn.allow_tf32 = True
+                torch.backends.cuda.matmul.allow_tf32 = True
+                sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
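For context, a minimal sketch (assuming only that output_tf32 is a plain config flag, as in the hunk above) of how the two new helpers mark the output module and how the wrapper can then keep that module in fp32 with TF32 matmuls enabled; the intent appears to be a more numerically stable logits computation while the rest of the model stays fp16:

# Minimal sketch, not the InternLM wrapper itself.
import torch
from torch import nn


def set_output_attr_to_module(module: nn.Module):
    setattr(module, "is_output", True)


def module_is_output(module: nn.Module):
    return hasattr(module, "is_output") and getattr(module, "is_output")


config = {"output_tf32": True}  # flag name taken from the diff above

head = nn.Linear(16, 32).to(torch.float16)
set_output_attr_to_module(head)

if config.get("output_tf32", False) and module_is_output(head):
    head.to(torch.float32)                       # keep the output projection in fp32
    torch.backends.cudnn.allow_tf32 = True       # allow TF32 tensor cores for cuDNN ops
    torch.backends.cuda.matmul.allow_tf32 = True  # and for matmuls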

View File

@@ -24,13 +24,16 @@ def get_dataset_type_id(dataset_type_ids_map, path):
     return match_idxes[0]
 
 
-def unpack_data(input_ids, cu_seqlens):
+def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False):
     """
-    input_ids: (n, packed_length)
-
-    Return:
-    output: (batch_size, max_length)
+    input_ids: if input_ids is not type_ids, the shape is (1, packed_length)
+               else the shape is (micro_num, packed_length)
+    is_type_ids: whether the input_ids is type_ids
+
+    Return:
+    output: if input_ids is not type ids, the shape is (micro_bsz, max_length)
+            else the shape is (micro_num, micro_bsz, max_length)
     """
 
     bsz = input_ids.shape[0]
     num_sequence = gpc.config.data["micro_bsz"]
@@ -45,7 +48,8 @@ def unpack_data(input_ids, cu_seqlens):
         output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
 
         outputs[i] = output
 
-    if bsz == 1:
+    # if the input_ids is not type_ids, we need squeeze the first dimension if it is 1.
+    if bsz == 1 and not is_type_ids:
         outputs = outputs.squeeze(0)
 
     return outputs
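As a standalone illustration of the unpacking logic (toy shapes, not the InternLM function itself): each packed row is split at the cu_seqlens boundaries and padded out to a fixed length, which is why the non-type_ids case can squeeze the leading dimension when only one packed row is present:

# Toy sketch of unpacking one packed row, assuming simple integer cu_seqlens.
import torch


def unpack_one_row(input_ids: torch.Tensor, cu_seqlens: torch.Tensor, max_length: int) -> torch.Tensor:
    """input_ids: (1, packed_length); cu_seqlens: (micro_bsz + 1,) cumulative sequence boundaries."""
    micro_bsz = cu_seqlens.numel() - 1
    output = torch.zeros(micro_bsz, max_length, dtype=input_ids.dtype)
    for j in range(micro_bsz):
        seq_length = int(cu_seqlens[j + 1] - cu_seqlens[j])
        output[j, :seq_length] = input_ids[0, cu_seqlens[j] : cu_seqlens[j + 1]]
    return output


packed = torch.arange(10).unsqueeze(0)           # (1, 10): two sequences of length 4 and 6
cu_seqlens = torch.tensor([0, 4, 10])
print(unpack_one_row(packed, cu_seqlens, max_length=6).shape)  # torch.Size([2, 6])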

View File

@@ -356,6 +356,9 @@ def args_sanity_check():
             f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
         )
 
+    if "batch_count" not in gpc.config:
+        gpc.config._add_item("batch_count", 0)
+
     if "moe_loss_coeff" not in gpc.config.loss:
         gpc.config.loss._add_item("moe_loss_coeff", 1.0)
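The same defaulting pattern, sketched with a plain dict for readers unfamiliar with gpc.config._add_item (the InternLM config object behaves like a dict here):

# Hypothetical plain-dict version of the defaulting above.
config = {"loss": {}}

if "batch_count" not in config:
    config["batch_count"] = 0            # updated each training step so profiling can key off it

if "moe_loss_coeff" not in config["loss"]:
    config["loss"]["moe_loss_coeff"] = 1.0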

View File

@@ -13,6 +13,7 @@ from torch import nn
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.context.random import _SEED_MANAGER
+from internlm.core.naive_amp import set_output_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
@@ -368,6 +369,7 @@ class PackedFlashInternLm1D(nn.Module):
                 dtype=dtype,
                 weight_scale=embed_grad_scale,
             )
+            set_output_attr_to_module(self.head)
             for _, param in self.head.named_parameters():
                 normal_(std=0.0052)(param)
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:

View File

@@ -39,6 +39,7 @@ from .utils import (
     compute_layer_zero_grad_count,
     compute_norm,
     compute_param_norm,
+    compute_vocab_grad_norm,
     compute_zero_grad_count,
 )
@@ -563,6 +564,28 @@ class HybridZeroOptimizer(BaseOptimizer):
             )
 
         return total_param_norms
 
+    def _compute_vocab_grad_norm_stage(
+        self, group_id: int = 0, last_bucket: bool = False, last_stage: bool = False, previous_vocab_grad_norm=None
+    ):
+        params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
+
+        if len(params) == 0:
+            dtype = self.param_groups[group_id]["dtype"]
+            grads = [self.padding_grad.to(dtype)]
+            params = [self.padding_tensor.to(dtype)]
+
+        vocab_grad_norm = None
+
+        if self._clip_grad_norm > 0:
+            vocab_grad_norm = compute_vocab_grad_norm(
+                grads,
+                params,
+                last_stage=last_stage,
+                previous_vocab_grad_norm=previous_vocab_grad_norm,
+                zero_mode=self._broadcast_parallel_mode[group_id],
+            )
+
+        return vocab_grad_norm
+
     def _count_zero_grads_stage(
         self, group_id: int = 0, last_bucket: bool = False, last_stage: bool = False, previous_zero_grad_count=None
     ):
@@ -615,12 +638,19 @@ class HybridZeroOptimizer(BaseOptimizer):
         groups_norms = []
         groups_param_norms = []
         group_param_zero_grad_count = []
+        group_vocab_norms = []
+        batch_count = gpc.config.get("batch_count")
+        interval_steps = grad_profiling_config.get("interval_steps", 1)
+        is_profiling = batch_count % interval_steps == 0 if batch_count is not None else False
         for group_id in range(self.num_param_groups):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    groups_param_norms.append(self._compute_param_norm_stage(group_id=group_id))
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    group_param_zero_grad_count.append(self._count_zero_grads_stage(group_id=group_id))
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    group_vocab_norms.append(self._compute_vocab_grad_norm_stage(group_id=group_id))
 
         # clear reduced grads
         # grads in the last bucket is reduced
@@ -637,6 +667,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         total_layer_grad_norms = {}
         total_param_zero_grad_count = {}
         total_layer_zero_grad_count = {}
+        total_vocab_grad_norms = {}
 
         for group_id in range(self.num_param_groups):
             group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
             group_name = f"{group_id}_{group_name}"
@@ -646,39 +677,56 @@ class HybridZeroOptimizer(BaseOptimizer):
                 last_stage=True,
                 previous_norm=groups_norms[group_id],
             )
-            if grad_profiling_config.get("grad_norm_profiling", False):
-                param_norms = self._compute_param_norm_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_param_norms=groups_param_norms[group_id],
-                )
-                total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
-                    param_norms=param_norms, loss_scale=self.loss_scale.item()
-                )
-            if grad_profiling_config.get("zero_grad_profiling", False):
-                zero_grad_count = self._count_zero_grads_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_zero_grad_count=group_param_zero_grad_count[group_id],
-                )
-                (
-                    total_layer_zero_grad_count[group_name],
-                    total_param_zero_grad_count[group_name],
-                ) = compute_layer_zero_grad_count(zero_grad_count)
+            if is_profiling:
+                if grad_profiling_config.get("grad_norm_profiling", False):
+                    param_norms = self._compute_param_norm_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_param_norms=groups_param_norms[group_id],
+                    )
+                    total_layer_grad_norms[group_name], total_param_grad_norms[group_name] = compute_layer_norm(
+                        param_norms=param_norms, loss_scale=self.loss_scale.item()
+                    )
+                if grad_profiling_config.get("zero_grad_profiling", False):
+                    zero_grad_count = self._count_zero_grads_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_zero_grad_count=group_param_zero_grad_count[group_id],
+                    )
+                    (
+                        total_layer_zero_grad_count[group_name],
+                        total_param_zero_grad_count[group_name],
+                    ) = compute_layer_zero_grad_count(zero_grad_count)
+                if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                    vocab_grad_norms = self._compute_vocab_grad_norm_stage(
+                        group_id=group_id,
+                        last_bucket=True,
+                        last_stage=True,
+                        previous_vocab_grad_norm=group_vocab_norms[group_id],
+                    )
+                    inf_mask = vocab_grad_norms == -1
+                    nan_mask = vocab_grad_norms == -2
+                    vocab_grad_norms = vocab_grad_norms**0.5 / self.loss_scale.item()
+                    vocab_grad_norms[inf_mask] = -1
+                    vocab_grad_norms[nan_mask] = -2
+                    total_vocab_grad_norms[group_name] = vocab_grad_norms.to("cpu")
 
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
 
         state, global_norms = self._step(closure=closure, norms=total_norms)
-        if grad_profiling_config.get("grad_norm_profiling", False):
-            global_norms["layer_grad_norm"] = total_layer_grad_norms
-            global_norms["param_grad_norm"] = total_param_grad_norms
-        if grad_profiling_config.get("zero_grad_profiling", False):
-            global_norms["layer_zero_grad"] = total_layer_zero_grad_count
-            global_norms["param_zero_grad"] = total_param_zero_grad_count
+        if is_profiling:
+            if grad_profiling_config.get("grad_norm_profiling", False):
+                global_norms["layer_grad_norm"] = total_layer_grad_norms
+                global_norms["param_grad_norm"] = total_param_grad_norms
+            if grad_profiling_config.get("zero_grad_profiling", False):
+                global_norms["layer_zero_grad"] = total_layer_zero_grad_count
+                global_norms["param_zero_grad"] = total_param_zero_grad_count
+            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                global_norms["vocab_grad_norm"] = total_vocab_grad_norms
 
         return state, global_norms
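A hedged sketch of the two conventions introduced above: profiling work is only done on steps where batch_count % interval_steps == 0, and per-vocab norms carry a squared-norm scale plus the sentinels -1 (inf) and -2 (nan) until they are unscaled by the loss scale. The helper names here are hypothetical, only the masking/unscaling steps mirror the diff:

import torch


def is_profiling_step(batch_count, interval_steps=1):
    # hypothetical helper: mirrors the is_profiling gate in the optimizer step above
    return batch_count is not None and batch_count % interval_steps == 0


def unscale_vocab_norms(vocab_grad_norms: torch.Tensor, loss_scale: float) -> torch.Tensor:
    inf_mask = vocab_grad_norms == -1
    nan_mask = vocab_grad_norms == -2
    vocab_grad_norms = vocab_grad_norms**0.5 / loss_scale  # squared norms -> norms, undo loss scaling
    vocab_grad_norms[inf_mask] = -1                        # keep the sentinel values intact
    vocab_grad_norms[nan_mask] = -2
    return vocab_grad_norms


norms = torch.tensor([4.0, -1.0, -2.0, 16.0])
print(unscale_vocab_norms(norms, loss_scale=2.0))  # tensor([ 1., -1., -2.,  2.])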

View File

@@ -218,7 +218,17 @@ def calc_zero_grad(grads):
     return torch.tensor([zero_count, grad_size])
 
 
-def reduce_grads(gradients, parameters, fine_grained=False):
+def get_norm(grads, norm_type, enable_cuda_kernels):
+    if norm_type == inf:
+        grad_norm = max(g.data.abs().max() for g in grads)
+    elif norm_type == 2.0 and enable_cuda_kernels:
+        grad_norm = calc_l2_norm(grads) ** norm_type
+    else:
+        grad_norm = calc_lp(grads, norm_type)
+    return grad_norm
+
+
+def reduce_grads(gradients, parameters, fine_grained=False, only_output=False):
     parallel_grads = []
     if fine_grained:
         parallel_grads = {}
@@ -229,6 +239,14 @@ def reduce_grads(gradients, parameters, fine_grained=False):
             if param_name not in parallel_grads:
                 parallel_grads[param_name] = []
             parallel_grads[param_name].append(g.data.float())
+        elif only_output:
+            param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
+            if (
+                gpc.config.model["vocab_size"] == g.shape[0]
+                and gpc.config.model["hidden_size"] == g.shape[1]
+                and "embedding" not in param_name.lower()
+            ):
+                parallel_grads.append(g.data.float())
         else:
             parallel_grads.append(g.data.float())
@@ -306,10 +324,7 @@ def compute_norm(
     else:
         tensor_parallel_grads = reduce_grads(gradients, parameters)
 
-    if norm_type == 2.0 and enable_cuda_kernels:
-        tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
-    else:
-        tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
+    tensor_parallel_norm = get_norm(tensor_parallel_grads, norm_type, enable_cuda_kernels)
 
     # If norm is type of float, then we convert them into torch.Tensor.
     tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
@@ -359,6 +374,59 @@ def compute_norm(
     return total_norm
 
 
+def compute_vocab_grad_norm(
+    gradients,
+    parameters,
+    last_stage=False,
+    previous_vocab_grad_norm=None,
+    norm_type=2,
+    zero_mode=ParallelMode.ZERO1,
+):
+    enable_cuda_kernels = gradients[0].device.type == "cuda"
+    # Norm parameters.
+    norm_type = float(norm_type)
+    vocab_size = gpc.config.model["vocab_size"]
+
+    param_grads = reduce_grads(gradients, parameters, only_output=True)
+
+    vocab_grad_norm = torch.zeros((vocab_size,), dtype=torch.float32).to(get_current_device())
+    if param_grads:
+        for grad in param_grads:
+            # get grad norm of each vocab
+            for i in range(vocab_size):
+                cur_vocab_grad_norm = get_norm([grad[i, :]], norm_type, enable_cuda_kernels)[0]
+                vocab_grad_norm[i] += get_tensor_norm(cur_vocab_grad_norm, move_to_cuda=True)
+
+    if last_stage is False:
+        return vocab_grad_norm
+
+    if previous_vocab_grad_norm is not None:
+        vocab_grad_norm = vocab_grad_norm + previous_vocab_grad_norm
+
+    if gpc.is_initialized(ParallelMode.MODEL):
+        dist.all_reduce(
+            vocab_grad_norm,
+            op=dist.ReduceOp.SUM,
+            group=gpc.get_group(ParallelMode.MODEL),
+        )
+
+    dist.all_reduce(vocab_grad_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))
+
+    if zero_mode == ParallelMode.EXPERT_DATA:
+        pg = gpc.get_group(ParallelMode.EXPERT)
+        scaled_norm = vocab_grad_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        vocab_grad_norm = scaled_norm_tensor.item()
+
+    # Scale.
+    vocab_grad_norm[vocab_grad_norm == float("inf")] = -1
+    vocab_grad_norm[vocab_grad_norm == -float("inf")] = -1
+    vocab_grad_norm[torch.isnan(vocab_grad_norm)] = -2
+
+    return vocab_grad_norm
+
+
 def compute_param_metric(
     gradients,
     parameters,
@@ -384,12 +452,7 @@ def compute_param_metric(
     for param_name, grads in param_grads.items():
         if metric_type == "norm":
-            if norm_type == inf:
-                param_metric = max(g.data.abs().max() for g in grads)
-            elif norm_type == 2.0 and enable_cuda_kernels:
-                param_metric = calc_l2_norm(grads) ** norm_type
-            else:
-                param_metric = calc_lp(grads, norm_type)
+            param_metric = get_norm(grads, norm_type, enable_cuda_kernels)
             param_metrics[param_name] = param_metric.item() if torch.is_tensor(param_metric) else param_metric
         elif metric_type == "zero_grad":
             param_zero_grad_count = calc_zero_grad(grads)
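Conceptually, compute_vocab_grad_norm reduces the gradient of the output head (shaped vocab_size x hidden_size) to one norm per vocabulary id, so gradient spikes can be traced back to specific tokens. A toy, single-process illustration of that per-row reduction (not the distributed implementation above, which also accumulates partial norms across model and ZeRO groups):

import torch
from torch import nn

vocab_size, hidden_size = 8, 4
head = nn.Linear(hidden_size, vocab_size, bias=False)  # weight shape: (vocab_size, hidden_size)

x = torch.randn(3, hidden_size)
loss = head(x).sum()
loss.backward()

grad = head.weight.grad                  # same shape as the weight
vocab_grad_norm = grad.norm(p=2, dim=1)  # one L2 norm per vocabulary id
print(vocab_grad_norm.shape)             # torch.Size([8])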

View File

@@ -2,6 +2,8 @@
 # -*- encoding: utf-8 -*-
 
 import functools
+import os
+import pickle
 import time
 from functools import partial
 from typing import Callable, Iterable, Optional, Union
@@ -49,10 +51,11 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.train.utils import create_param_groups
-from internlm.utils.common import DummyProfile
+from internlm.utils.common import DummyProfile, launch_time
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
+    check_sequence_parallel,
     set_model_params_layer_name,
     sync_model_param,
     sync_model_param_within_tp,
@@ -111,6 +114,10 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
+    # check whether the norm module has IS_SEQUENCE_PARALLEL attribute
+    if gpc.config.parallel.sequence_parallel is True:
+        check_sequence_parallel(model)
+
     return model
@@ -159,8 +166,10 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         A tuple of (optimizer, beta2_scheduler, lr_scheduler).
     """
     grad_profiling_config = gpc.config.get("grad_profiling", {})
-    if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
-        "zero_grad_profiling", False
+    if (
+        grad_profiling_config.get("grad_norm_profiling", False)
+        or grad_profiling_config.get("zero_grad_profiling", False)
+        or grad_profiling_config.get("vocab_grad_norm_profiling", False)
     ):
         # set the layer name as an attribute of the model parameters
         set_model_params_layer_name(model)
@@ -359,7 +368,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
     if batch[0].get("type_ids", None) is not None:
         # if use_flash_attn is False, we need to unpack type_ids
         if not gpc.config.model.use_flash_attn:
-            batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
+            batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True)
 
     return batch, train_iter
@@ -517,43 +526,40 @@ def record_current_batch_training_metrics(
             infos[key] = value
 
         grad_profiling_config = gpc.config.get("grad_profiling", {})
-        if grad_profiling_config.get("grad_norm_profiling", False) or grad_profiling_config.get(
-            "zero_grad_profiling", False
-        ):
-            layer_metrics = ["layer_grad_norm", "layer_zero_grad"]
-            param_metrics = ["param_grad_norm", "param_zero_grad"]
+        interval_steps = grad_profiling_config.get("interval_steps", 1)
+        if batch_count % interval_steps == 0:
+            layer_metrics = [metric for metric in ["layer_grad_norm", "layer_zero_grad"] if metric in grad_norm]
+            param_metrics = [metric for metric in ["param_grad_norm", "param_zero_grad"] if metric in grad_norm]
             layer_names = grad_profiling_config.get("layers", [])
-            for layer_metric_name in layer_metrics:
-                layer_metric = grad_norm.get(layer_metric_name, {})
-                if layer_metric:
-                    for group_name, layer_group in layer_metric.items():
-                        if layer_group:
-                            title = f"{layer_metric_name}/{group_name}"
-                            if layer_names:
-                                filter_layer_metrics = {}
-                                for layer_name, metric_value in layer_group.items():
-                                    if layer_name in layer_names:
-                                        filter_layer_metrics[layer_name] = metric_value
-                                writer.add_scalars(key=title, value=filter_layer_metrics, step=train_state.step_count)
-                            else:
-                                writer.add_scalars(key=title, value=layer_group, step=train_state.step_count)
-                    del grad_norm[layer_metric_name]
-
-            for param_metric_name in param_metrics:
-                param_metric = grad_norm.get(param_metric_name, {})
-                if param_metric:
-                    for group_name, layer_group in param_metric.items():
-                        for param_name, param_group in layer_group.items():
-                            title = f"{param_name}/{group_name}_{param_metric_name}"
-                            if layer_names:
-                                filter_param_group = {}
-                                for layer_name, metric_value in param_group.items():
-                                    if layer_name in layer_names:
-                                        filter_param_group[layer_name] = param_group[layer_name]
-                                writer.add_scalars(key=title, value=filter_param_group, step=train_state.step_count)
-                            else:
-                                writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
-                    del grad_norm[param_metric_name]
+            for metric_name in layer_metrics:
+                metric = grad_norm.get(metric_name)
+                for group_name, layer_group in metric.items():
+                    title = f"{metric_name}/{group_name}"
+                    metrics = {k: v for k, v in layer_group.items() if not layer_names or k in layer_names}
+                    if metrics:
+                        writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+                del grad_norm[metric_name]
+            for metric_name in param_metrics:
+                metric = grad_norm.get(metric_name)
+                for group_name, layer_group in metric.items():
+                    for param_name, param_group in layer_group.items():
+                        title = f"{param_name}/{group_name}_{metric_name}"
+                        metrics = {k: v for k, v in param_group.items() if not layer_names or k in layer_names}
+                        if metrics:
+                            writer.add_scalars(key=title, value=metrics, step=train_state.step_count)
+                del grad_norm[metric_name]
+            if grad_profiling_config.get("vocab_grad_norm_profiling", False):
+                local_save_path = f"RUN/{gpc.config.JOB_NAME}/{launch_time()}/grad_norm"
+                os.makedirs(local_save_path, exist_ok=True)
+                local_save_file = f"{local_save_path}/vocab_grad_norm.pt"
+                vocab_grad_norms = grad_norm.get("vocab_grad_norm")
+                if vocab_grad_norms:
+                    try:
+                        with open(local_save_file, "ab+") as vocab_f:
+                            pickle.dump((train_state.step_count, vocab_grad_norms), vocab_f)
+                    except IOError as e:
+                        logger.warning(f"Error saving vocab_grad_norm: {e}")
+                del grad_norm["vocab_grad_norm"]
 
         line = ""
         for key, value in infos.items():
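Because the vocab norms are appended with pickle.dump in "ab+" mode, the dump file holds one pickled record per logged step. A small sketch of reading such a file back (the path below is hypothetical; each record is the (step_count, vocab_grad_norms) tuple written above, where vocab_grad_norms is a dict keyed by parameter-group name):

import pickle

records = []
# hypothetical path; substitute the actual RUN/<job>/<launch_time>/grad_norm location
with open("RUN/my_job/20231207_102305/grad_norm/vocab_grad_norm.pt", "rb") as f:
    while True:
        try:
            records.append(pickle.load(f))  # each load returns the next appended tuple
        except EOFError:
            break

for step, vocab_grad_norms in records:
    print(step, {group: tensor.shape for group, tensor in vocab_grad_norms.items()})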

View File

@@ -4,11 +4,14 @@
 import torch.distributed as dist
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import try_import_RMSNorm
 from internlm.solver.pipeline_utils import partition_uniform
 
+RMSNorm = try_import_RMSNorm()
+
 
 def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
@@ -98,3 +101,26 @@ def set_model_params_layer_name(model):
                 layer_param_name = f"{layer_name}-{param_name}"
                 param.__setattr__("layer_name", layer_name)
                 param.__setattr__("param_name", f"{layer_name}-{param_name}")
+
+
+def check_sequence_parallel(model):
+    """
+    check whether the norm module has IS_SEQUENCE_PARALLEL attribute.
+    when the sequence_parallel is True, the norm module should have the IS_SEQUENCE_PARALLEL attribute
+    to illustrate the norm should conduct the all-reduce for its grad.
+    """
+
+    if not isinstance(model, nn.ModuleList):
+        model = [model]
+
+    for _chunk in model:
+        if isinstance(_chunk, NaiveAMPModel):
+            _chunk = _chunk.model
+
+        for _, module in _chunk.named_modules():
+            if isinstance(module, (RMSNorm, nn.LayerNorm)):
+                for param in module.parameters():
+                    assert hasattr(param, IS_SEQUENCE_PARALLEL), (
+                        "when the gpc.config.parallel.sequence parallel is True,"
+                        "the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
+                    )
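A minimal sketch of the attribute convention that check_sequence_parallel enforces, using a plain string in place of InternLM's IS_SEQUENCE_PARALLEL constant and only nn.LayerNorm (the real check also covers RMSNorm):

from torch import nn

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # stand-in for the InternLM constant

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# mark: tag every norm parameter so its grad is known to need an all-reduce
for module in model.modules():
    if isinstance(module, nn.LayerNorm):
        for param in module.parameters():
            setattr(param, IS_SEQUENCE_PARALLEL, True)

# check: toy mirror of check_sequence_parallel for this model
for module in model.modules():
    if isinstance(module, nn.LayerNorm):
        for param in module.parameters():
            assert hasattr(param, IS_SEQUENCE_PARALLEL)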

View File

@@ -197,6 +197,7 @@ def main(args):
         empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
         start_time = time.time()
         timer("one-batch").start()
+        gpc.config.batch_count = batch_count
 
         # load batch data
         batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)