diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 1f1993f..264927b 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -7,22 +7,23 @@ MLP_RATIO = 8 / 3
 NUM_LAYER = 32
 VOCAB_SIZE = 103168
 
-MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
+MODEL_ONLY_FOLDER = "local:llm_ckpts/20"
 # Ckpt folder format:
 # fs: 'local:/mnt/nfs/XXX'
 SAVE_CKPT_FOLDER = "local:llm_ckpts"
-LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
+LOAD_CKPT_FOLDER = "local:llm_ckpts/20"
 # boto3 Ckpt folder format:
 # import os
 # BOTO3_IP = os.environ["BOTO3_IP"]  # boto3 bucket endpoint
 # SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
 # LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
-CHECKPOINT_EVERY = 50
+CHECKPOINT_EVERY = 20
 ckpt = dict(
     enable_save_ckpt=False,  # enable ckpt save.
     save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
-    # load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
+    load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
+    load_given_ckpt=False,
     # load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
     load_optimizer=True,  # Wheter to load optimizer states when continuing training.
     checkpoint_every=CHECKPOINT_EVERY,
@@ -32,7 +33,7 @@ ckpt = dict(
     oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
 )
 
-TRAIN_FOLDER = "/path/to/dataset"
+TRAIN_FOLDER = "../../train_data"  # "/path/to/dataset"
 VALID_FOLDER = "/path/to/dataset"
 data = dict(
     seq_len=SEQ_LEN,
@@ -50,7 +51,7 @@ data = dict(
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
     min_length=50,
-    # train_folder=TRAIN_FOLDER,
+    train_folder=TRAIN_FOLDER,
     # valid_folder=VALID_FOLDER,
 )
 
@@ -111,7 +112,7 @@ beta2_scheduler = dict(
 )
 
 model = dict(
-    checkpoint=False,  # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
+    checkpoint=True,  # The proportion of layers for activation checkpointing, the optional value are True/False/[0-1]
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
@@ -121,7 +122,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
+    dtype="torch.bfloat16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
@@ -140,9 +141,11 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
""" parallel = dict( - zero1=8, + zero1=-1, pipeline=dict(size=1, interleaved_overlap=True), + tensor=2, sequence_parallel=False, + use_fsdp = False, ) cudnn_deterministic = False diff --git a/internlm/core/context/__init__.py b/internlm/core/context/__init__.py index 97021dc..e80ed3c 100644 --- a/internlm/core/context/__init__.py +++ b/internlm/core/context/__init__.py @@ -12,6 +12,7 @@ from .process_group_initializer import ( Initializer_Zero1, ParallelMode, ProcessGroupInitializer, + Initializer_Zero3_dp, ) from .random import ( add_seed, @@ -34,6 +35,7 @@ __all__ = [ "Initializer_Pipeline", "Initializer_Data", "Initializer_Zero1", + "Initializer_Zero3_dp", "ProcessGroupInitializer", "Initializer_Model", "seed", diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 87d3114..47796fb 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -462,6 +462,8 @@ class ParallelContext(metaclass=SingletonMeta): initializers.append(pgroup_initializer.Initializer_Model(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args)) + if self.config.parallel.use_fsdp: + initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args)) if self.pipeline_parallel_size > 1: initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args)) for initializer in initializers: diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py index 56cf16d..cae609a 100644 --- a/internlm/core/context/process_group_initializer.py +++ b/internlm/core/context/process_group_initializer.py @@ -31,6 +31,11 @@ class ParallelMode(Enum): # zero1 parallel ZERO1 = "zero1" + # zero3-dp parallel + # if fsdp is activated and size of fsdp-parallel-size is less than dp-parallel-size + # then manual communication only happens between inter-fsdp-modules, while intra-modules reduction is done by fsdp + ZERO3_DP = "zero3_dp" + class ProcessGroupInitializer(ABC): """An object, knowing the parallelism configuration, that initializes parallel groups. @@ -332,3 +337,62 @@ class Initializer_Zero1(ProcessGroupInitializer): ranks_in_group = ranks return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + +class Initializer_Zero3_dp(ProcessGroupInitializer): + """A ProcessGroupInitializer for data parallelism. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + zero1_parallel_size (int): Size of zero1 parallel. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.data_parallel_size % self.zero1_parallel_size == 0 + + # the only difference between this initializer and DP_initializer + # when FSDP is enabled, only corresponding pairs are in the same actual DP group due to parameter sharding + # eg: when zero=4 and dp=8 + # no fsdp: rank [0-7] share same model paramters, and [0-3], [4-7] are two separate zero group + # fsdp: params of (0, 4), (1, 5), (2, 6), (3, 7) are the same actually + + self.data_parallel_size //= self.zero1_parallel_size + self.rank_num_per_dp_group = self.world_size // self.data_parallel_size + + assert self.world_size % self.data_parallel_size == 0 + + + def init_dist_group(self, use_cpu: bool = False): + """Initialize data parallel groups, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + A Data parallelism's information tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.ZERO3_DP + + for i in range(self.rank_num_per_dp_group): + ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)] + group = dist.new_group(ranks) + if use_cpu: + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + else: + group_cpu = None + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode \ No newline at end of file diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index a69a506..18a3cbd 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -61,6 +61,17 @@ def args_sanity_check(): if "tensor" not in gpc.config.parallel: gpc.config.parallel._add_item("tensor", 1) + if isinstance(gpc.config.parallel.pipeline, int): + pp = gpc.config.parallel.pipeline + else: + pp = gpc.config.parallel.pipeline.size + + if "use_fsdp" not in gpc.config.parallel: + gpc.config.parallel._add_item("use_fsdp", False) + elif gpc.config.parallel.use_fsdp and pp > 1: + logger.warning("FSDP not support when pipeline parallel is enabled, auto-close FSDP") + gpc.config.parallel._add_item("use_fsdp", False) + # processing the data config in gpc data = gpc.config.data diff --git a/internlm/solver/optimizer/__init__.py b/internlm/solver/optimizer/__init__.py index 3da5bbe..b7178ad 100644 --- a/internlm/solver/optimizer/__init__.py +++ b/internlm/solver/optimizer/__init__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from .hybrid_zero_optim import HybridZeroOptimizer +from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer -__all__ = ["HybridZeroOptimizer"] +__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer"] diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 8bdeccf..6cb4f3d 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -80,6 +80,132 @@ class BaseOptimizer(Optimizer): pass +class FSDPadaptOptimizer(BaseOptimizer): + ''' + optimizer for Pytorch FSDP if 'use_fsdp' is True in config file + reserve some necessary components of hybird-optim: + grad_scaler; + grad_clip and unscale; + state_dict and load_state_dict + ''' + + 
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        grad_scal_cfg: Config = None,
+        zero_cfg: Config = None,
+    ):
+        super().__init__(optim=optimizer)
+
+        # gradient scaler
+        self.grad_scaler = DynamicGradScaler(
+            initial_scale=grad_scal_cfg.fp16.initial_scale,
+            min_scale=grad_scal_cfg.fp16.min_scale,
+            growth_factor=grad_scal_cfg.growth_factor,
+            backoff_factor=grad_scal_cfg.backoff_factor,
+            growth_interval=grad_scal_cfg.fp16.growth_interval,
+            hysteresis=grad_scal_cfg.hysteresis,
+            max_scale=grad_scal_cfg.max_scale,
+        )
+
+        # gradient clipping
+        self._clip_grad_norm = zero_cfg.clip_grad_norm
+        self.use_fsdp = gpc.config.parallel.use_fsdp
+
+    @property
+    def loss_scale(self):
+        return self.grad_scaler.scale
+
+    def backward(self, loss, retain_graph=False):
+        loss = self.loss_scale * loss
+        loss.backward(retain_graph=retain_graph)
+
+    def step(self):
+        # In case the fsdp-zero3 size is not equal to the dp size, the FSDP module only reduces
+        # gradients within its own FSDP process group, so gradients must be reduced manually
+        # across the parallel FSDP process groups (the ZERO3_DP groups).
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            for param in params:
+                if param.requires_grad:
+                    reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
+
+        # compute norm
+        found_inf = False
+        norm_groups = []
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            gradients = [p.grad for p in params]
+            norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
+            if norm_group == -1:
+                found_inf = True
+                break
+            norm_groups.append(norm_group)
+
+        loss_scale = float(self.loss_scale.item())  # backup
+        self.grad_scaler.update(found_inf)
+        if found_inf:
+            if gpc.is_rank_for_log():
+                logger.warning("Overflow occurs, please check it.")
+            self.zero_grad()
+            return False, None
+
+        global_norm = 0
+        if self._clip_grad_norm > 0:
+            global_norm = sum(norm_groups) ** 0.5
+
+        # unscale (and clip) gradients
+        for group_idx in range(len(self.param_groups)):
+            params = self.param_groups[group_idx]["params"]
+            for p in params:
+                self._unscale_and_clip_grads(p.grad, global_norm, loss_scale)
+
+        self.optim.step()
+        self.zero_grad()
+
+        return True, [global_norm / loss_scale]
+
+    def clip_grad_norm(self, model, max_norm):
+        # gradient clipping is handled in step()
+        pass
+
+    ##########################
+    # utils from hybrid zero #
+    ##########################
+
+    def _unscale_and_clip_grads(self, grad, total_norm, loss_scale):
+        # compute the combined scale factor for this group
+        combined_scale = loss_scale
+
+        if self._clip_grad_norm > 0.0:
+            # norm is in fact norm * scale
+            clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
+            if clip > 1.0:
+                combined_scale = clip * loss_scale
+
+        grad.data.mul_(1.0 / combined_scale)
+
+    def state_dict(self):
+        states = {}
+        grad_scaler = self.grad_scaler.state_dict()
+        states["grad_scaler"] = grad_scaler
+        optim_states = self.optim.state_dict()
+        states["base_optim_states"] = optim_states
+
+        return states
+
+    def load_state_dict(self, states):
+        assert "grad_scaler" in states, "grad_scaler state not found!"
+        grad_scaler = states["grad_scaler"]
+        self.grad_scaler.load_state_dict(grad_scaler)
+        optim_states = states["base_optim_states"]
+        self.optim.load_state_dict(optim_states)
+
+
 class HybridZeroOptimizer(BaseOptimizer):
     """
     Hybrid Zero Optimizer.
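Note: to make the ZERO3_DP grouping concrete, below is a minimal, dependency-free sketch (plain Python, no torch.distributed) of the rank arithmetic that Initializer_Zero3_dp performs and that FSDPadaptOptimizer.step() reduces gradients over. It assumes the zero1=4, dp=8 example from the comment in that class; the variable names here are illustrative only.

# Sketch of the ZERO3_DP group layout for world_size=8, zero1 (FSDP) size=4, dp size=8.
world_size = 8
zero1_parallel_size = 4   # ranks inside one FSDP sharding group
data_parallel_size = 8    # original data-parallel size

# same arithmetic as Initializer_Zero3_dp.__init__
zero3_dp_size = data_parallel_size // zero1_parallel_size   # -> 2
rank_num_per_dp_group = world_size // zero3_dp_size         # -> 4

# same loop as init_dist_group: one ZERO3_DP group per position inside an FSDP group
zero3_dp_groups = [
    [i + j * rank_num_per_dp_group for j in range(zero3_dp_size)]
    for i in range(rank_num_per_dp_group)
]
print(zero3_dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]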
diff --git a/internlm/train/__init__.py b/internlm/train/__init__.py
index 457d7a4..a5a2995 100644
--- a/internlm/train/__init__.py
+++ b/internlm/train/__init__.py
@@ -6,6 +6,7 @@ from .training_internlm import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
+    warp_FSDP_model,
 )
 
 __all__ = [
@@ -16,4 +17,5 @@ __all__ = [
     "initialize_optimizer",
     "load_new_batch",
     "record_current_batch_training_metrics",
+    "warp_FSDP_model",
 ]
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index bab56f1..1a9f9fa 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -28,7 +28,7 @@ from internlm.monitor import set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
-from internlm.solver.optimizer import HybridZeroOptimizer
+from internlm.solver.optimizer import HybridZeroOptimizer, FSDPadaptOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
@@ -40,6 +40,24 @@ from internlm.utils.parallel import (
 )
 from internlm.utils.registry import MODEL_INITIALIZER
 
+import functools
+
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    CPUOffload,
+    MixedPrecision,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.wrap import (
+    size_based_auto_wrap_policy,
+    transformer_auto_wrap_policy,
+    enable_wrap,
+    wrap,
+)
+
+from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
+
 
 logger = get_logger(__file__)
@@ -83,6 +101,30 @@ def initialize_model():
 
     return model
 
 
+def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+    """Wrap the model with PyTorch FSDP when 'use_fsdp' is enabled in the parallel config."""
+    if gpc.config.parallel.use_fsdp:
+        # wrap each transformer block (and the top-level model) as its own FSDP unit
+        transformer_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D},
+        )
+        mx = MixedPrecision(
+            param_dtype=gpc.config.model.dtype,
+            reduce_dtype=gpc.config.model.dtype,
+            buffer_dtype=gpc.config.model.dtype,
+            keep_low_precision_grads=True,
+        )
+        # reuse the ZERO1 group as the FSDP sharding group
+        grp = gpc.get_group(ParallelMode.ZERO1)
+        model = FSDP(
+            module=model,
+            process_group=grp,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            auto_wrap_policy=transformer_wrap_policy,
+            forward_prefetch=True,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            # cpu_offload=CPUOffload(offload_params=True),
+            # mixed_precision=mx,
+            # device_id=torch.cuda.current_device(),
+        )
+
+    return model
+
+
 def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     """
     Initialize optimizer.
@@ -105,12 +147,19 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         eps=adam_cfg.adam_eps,
     )
 
-    optimizer = HybridZeroOptimizer(
-        naive_optimizer,
-        grad_scal_cfg=gpc.config.grad_scaler,
-        zero_cfg=gpc.config.hybrid_zero_optimizer,
-        param_bcast_sync_handler=param_bcast_sync_handler,
-    )
+    if not gpc.config.parallel.use_fsdp:
+        optimizer = HybridZeroOptimizer(
+            naive_optimizer,
+            grad_scal_cfg=gpc.config.grad_scaler,
+            zero_cfg=gpc.config.hybrid_zero_optimizer,
+            param_bcast_sync_handler=param_bcast_sync_handler,
+        )
+    else:
+        optimizer = FSDPadaptOptimizer(
+            naive_optimizer,
+            grad_scal_cfg=gpc.config.grad_scaler,
+            zero_cfg=gpc.config.hybrid_zero_optimizer,
+        )
 
     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
 
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 09bafa5..120df0c 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -55,6 +55,26 @@ def get_model_topology(model):
     return topos
 
 
+def get_state_dict(model):
+    """
+    Only used for saving FSDP modules.
+    It is a wrapper around model.state_dict(): inside the 'FSDP.state_dict_type' context, the sharded
+    parameters (stored as model.flat_param_xx in a sharded FSDP module) are gathered on every gpu.
+    'offload_to_cpu' means the gathered states are offloaded to cpu chunk by chunk, avoiding OOM on gpu.
+    """
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import FullStateDictConfig, StateDictType  # , FullOptimStateDictConfig
+
+    # TODO: rank0_only could save memory on non-rank0 gpus, but when tp is enabled, model saving would leave out some parameters
+    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
+        states = model.state_dict()
+
+    return states
+
+
 def save_model_checkpoint(folder, model):
     """
     Save the model according to the relationship between tp and dp. The principle is that the data of each tp
@@ -69,7 +89,11 @@ def save_model_checkpoint(folder, model):
         model: The model to be saved
     """
 
-    states = model.state_dict()
+    if gpc.config.parallel.use_fsdp:
+        states = get_state_dict(model)
+    else:
+        states = model.state_dict()
+
     topo = get_model_topology(model)
 
     if folder is not None:
diff --git a/train.py b/train.py
index de7cc7c..1134ebe 100644
--- a/train.py
+++ b/train.py
@@ -28,6 +28,7 @@ from internlm.train import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
+    warp_FSDP_model,
 )
 from internlm.utils.common import (
     BatchSkipper,
@@ -123,6 +124,9 @@ def main(args):
     # Loading model weights must be done before zero is initialized.
     ckpt_manager.try_load_model(current_time)
 
+    # if fsdp is enabled, wrap the model with FSDP
+    model = warp_FSDP_model(model)
+
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
 
     # Loading other persistent training states.
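For reference, a minimal sketch of how this code path would be enabled from the config, assuming the same keys as configs/7B_sft.py above (the sizes are illustrative, not recommendations): FSDP is auto-disabled when the pipeline size is greater than 1 (see the args_sanity_check change), and the ZERO1 group built from zero1 is reused as the FSDP process group in warp_FSDP_model.

# Illustrative parallel section for enabling the FSDP path in a 7B_sft.py-style config.
parallel = dict(
    zero1=4,  # each group of 4 ranks shards parameters via FSDP; also serves as the FSDP process group
    pipeline=dict(size=1, interleaved_overlap=True),  # pipeline size > 1 auto-disables FSDP
    tensor=1,
    sequence_parallel=False,
    use_fsdp=True,  # routes training through warp_FSDP_model and FSDPadaptOptimizer
)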