diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 505d17b..865b959 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -154,7 +154,7 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=8,
+    zero1=dict(size=8, fsdp=False),
     tensor=1,
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=False,
diff --git a/internlm/core/context/__init__.py b/internlm/core/context/__init__.py
index 3fc7deb..5cbb832 100644
--- a/internlm/core/context/__init__.py
+++ b/internlm/core/context/__init__.py
@@ -11,6 +11,7 @@ from .process_group_initializer import (
     Initializer_Pipeline,
     Initializer_Tensor,
     Initializer_Zero1,
+    Initializer_Zero3_dp,
     ParallelMode,
     ProcessGroupInitializer,
 )
@@ -36,6 +37,7 @@ __all__ = [
     "Initializer_Data",
     "Initializer_Zero1",
     "Initializer_Nettest",
+    "Initializer_Zero3_dp",
     "ProcessGroupInitializer",
     "Initializer_Model",
     "seed",
diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index da6a0d7..997bd46 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -328,6 +328,9 @@ class ParallelContext(metaclass=SingletonMeta):
             return False
         return self.is_last_rank(ParallelMode.PIPELINE)

+    def is_no_pp_or_last_stage(self):
+        return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
+
     def get_world_size(self, parallel_mode: ParallelMode):
         """Returns the world size for `parallel_mode`.

@@ -429,6 +432,16 @@ class ParallelContext(metaclass=SingletonMeta):
         assert self.zero1_parallel_size > 0
         assert self.data_parallel_size % self.zero1_parallel_size == 0

+        # check for fsdp:
+        # if zero1 size < dp size, checkpoint saving introduces redundant storage for model weights,
+        # because pytorch "ShardedTensor" requires the current global rank to equal the saved shard's global rank
+        # (pytorch version: 1.13.1+cu117)
+        if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.zero1.get("fsdp", False):
+            logger.warning(
+                f"zero1 size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, "
+                "which introduces redundancy when saving fsdp model checkpoints; setting them to the same value is recommended"
+            )
+
     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
         if key in config:
             ele = config[key]
@@ -495,6 +508,8 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
+        if isinstance(self.config.parallel.zero1, dict) and self.config.parallel.zero1.get("fsdp", False):
+            initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
         if self.pipeline_parallel_size > 1:
             initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
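Aside (not part of the patch): a minimal sketch of the zero1/FSDP size check added in the parallel_context.py hunk above, assuming dp = world_size / (tp * pp); the helper name check_zero1_config and its arguments are hypothetical.

def check_zero1_config(world_size: int, tp: int, pp: int, zero1: int, fsdp: bool) -> None:
    dp = world_size // (tp * pp)
    assert zero1 > 0 and dp % zero1 == 0, "zero1 size must divide the data parallel size"
    if fsdp and dp > zero1:
        # mirrors the warning above: every FSDP replica saves its own shards,
        # so checkpoints end up holding dp // zero1 copies of each weight
        print(f"zero1={zero1} < dp={dp}: expect {dp // zero1}x redundant model checkpoint storage")

check_zero1_config(world_size=16, tp=1, pp=1, zero1=8, fsdp=True)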
diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py
index 7f61e64..e9afa2e 100644
--- a/internlm/core/context/process_group_initializer.py
+++ b/internlm/core/context/process_group_initializer.py
@@ -37,6 +37,11 @@ class ParallelMode(Enum):
     # runntime network test
     NETTEST = "nettest"

+    # zero3-dp parallel
+    # if fsdp is activated and the fsdp parallel size is smaller than the dp parallel size,
+    # manual communication only happens between fsdp process groups, while reduction inside each group is done by fsdp
+    ZERO3_DP = "zero3_dp"
+
     # expert parallel
     EXPERT = "expert"

@@ -594,3 +599,62 @@ class Initializer_Expert_Data(ProcessGroupInitializer):
             )

         return groups
+
+
+class Initializer_Zero3_dp(ProcessGroupInitializer):
+    """A ProcessGroupInitializer for the data-parallel groups that span FSDP shards (zero3-dp).
+
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
+        zero1_parallel_size (int): Size of zero1 parallel.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.data_parallel_size % self.zero1_parallel_size == 0
+
+        # the only difference between this initializer and the DP initializer:
+        # when FSDP is enabled, only corresponding pairs are in the same actual DP group due to parameter sharding
+        # e.g. when zero1=4 and dp=8
+        # no fsdp: ranks [0-7] share the same model parameters, and [0-3], [4-7] are two separate zero groups
+        # fsdp: the params of (0, 4), (1, 5), (2, 6), (3, 7) are actually the same
+
+        self.data_parallel_size //= self.zero1_parallel_size
+        self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
+
+        assert self.world_size % self.data_parallel_size == 0
+
+    def init_dist_group(self, use_cpu: bool = False):
+        """Initialize zero3-dp parallel groups, and assign local_ranks and groups to each gpu.
+
+        Returns:
+            Tuple (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode):
+                A zero3-dp parallelism's information tuple.
+        """
+        local_rank = None
+        ranks_in_group = None
+        process_group = None
+        cpu_group = None
+        group_world_size = None
+        mode = ParallelMode.ZERO3_DP
+
+        for i in range(self.rank_num_per_dp_group):
+            ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
+            group = dist.new_group(ranks)
+            if use_cpu:
+                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+            else:
+                group_cpu = None
+
+            if self.rank in ranks:
+                local_rank = ranks.index(self.rank)
+                group_world_size = len(ranks)
+                process_group = group
+                cpu_group = group_cpu
+                ranks_in_group = ranks
+
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
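Aside (not part of the patch): the rank layout produced by Initializer_Zero3_dp, reproduced as plain arithmetic (no torch.distributed) so the grouping in the comments above can be checked directly; the function name zero3_dp_groups is hypothetical.

def zero3_dp_groups(world_size: int, zero1_size: int, dp_size: int):
    # same arithmetic as Initializer_Zero3_dp.init_dist_group, minus dist.new_group
    num_fsdp_replicas = dp_size // zero1_size          # how many identical FSDP replicas exist
    ranks_per_group = world_size // num_fsdp_replicas  # ranks inside one replica
    return [[i + j * ranks_per_group for j in range(num_fsdp_replicas)] for i in range(ranks_per_group)]

# zero1=4, dp=8: FSDP shards its replica over [0-3] and [4-7]; matching shard owners form the ZERO3_DP groups
assert zero3_dp_groups(world_size=8, zero1_size=4, dp_size=8) == [[0, 4], [1, 5], [2, 6], [3, 7]]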
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 985e57f..fead575 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -16,7 +16,7 @@ from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout

-# check pacakge
+# check package
 try:
     import numa
     from numa import memory, schedule
@@ -65,14 +65,36 @@ def args_sanity_check():
     # procssing the parallel config in gpc
     if "zero1" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("zero1", -1)
+        gpc.config.parallel._add_item("zero1", dict(size=-1, fsdp=False))
+
+    if isinstance(gpc.config.parallel.zero1, int):
+        zero1_size = gpc.config.parallel.zero1
+        gpc.config.parallel._add_item("zero1", dict(size=zero1_size, fsdp=False))

     if "pipeline" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("pipeline", 1)
+        gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False))

     if "tensor" not in gpc.config.parallel:
         gpc.config.parallel._add_item("tensor", 1)

+    if isinstance(gpc.config.parallel.pipeline, int):
+        pp = gpc.config.parallel.pipeline
+    else:
+        pp = gpc.config.parallel.pipeline.size
+
+    # check fsdp config
+    if "fsdp" not in gpc.config.parallel.zero1:
+        gpc.config.parallel.zero1._add_item("fsdp", False)
+
+    assert not (
+        gpc.config.parallel.zero1.fsdp and pp > 1
+    ), "FSDP is not supported when pipeline size > 1; please set pipeline size to 1 or disable FSDP"
+
+    if gpc.config.parallel.zero1.fsdp:
+        assert (
+            torch.__version__ >= "2.0.1"
+        ), f"requires torch>=2.0.1 when using fsdp, but the current version is {torch.__version__}"
+
     # processing the data config in gpc
     data = gpc.config.data

@@ -271,6 +293,9 @@ def args_sanity_check():
         model._add_item("moe_use_residual", False)
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
+    assert not (
+        gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp
+    ), "FSDP does not support num_experts > 1"

     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py
index 24ce592..3a77f8b 100644
--- a/internlm/model/metrics.py
+++ b/internlm/model/metrics.py
@@ -6,7 +6,6 @@
 from torch_scatter import scatter

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.utils.parallel import is_no_pp_or_last_stage


 class AccPerplex:
@@ -138,7 +137,7 @@ class AccPerplex:
         self.total_log_probs += total_log_probs

     def get_metric(self, reset=True):
-        if is_no_pp_or_last_stage() and self.dp_pg is not None:
+        if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
             torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
@@ -236,7 +235,7 @@ class LossWithTypeId:
             self.ds_token_num += token_num_type

     def get_metric(self, reset=True):
-        if is_no_pp_or_last_stage() and self.dp_pg is not None:
+        if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
             torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             if hasattr(self, "total_type_count"):
diff --git a/internlm/solver/optimizer/__init__.py b/internlm/solver/optimizer/__init__.py
index 99051f4..7c6a1c6 100644
--- a/internlm/solver/optimizer/__init__.py
+++ b/internlm/solver/optimizer/__init__.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from .fsdp_optimizer import FSDPadaptOptimizer
 from .hybrid_zero_optim import HybridZeroOptimizer, reload_zero_fp32_buff

-__all__ = ["HybridZeroOptimizer", "reload_zero_fp32_buff"]
+__all__ = ["FSDPadaptOptimizer", "HybridZeroOptimizer", "reload_zero_fp32_buff"]
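Aside (not part of the patch): the zero1 normalization performed in args_sanity_check, sketched on a plain dict instead of gpc.config (the real code goes through Config._add_item); normalize_zero1 is a hypothetical name.

def normalize_zero1(parallel: dict) -> dict:
    zero1 = parallel.get("zero1", -1)
    if isinstance(zero1, int):       # legacy form: zero1=8
        zero1 = dict(size=zero1, fsdp=False)
    zero1.setdefault("fsdp", False)  # dict form may omit the flag
    parallel["zero1"] = zero1
    return parallel

assert normalize_zero1({"zero1": 8})["zero1"] == {"size": 8, "fsdp": False}
assert normalize_zero1({})["zero1"] == {"size": -1, "fsdp": False}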
diff --git a/internlm/solver/optimizer/base_optimizer.py b/internlm/solver/optimizer/base_optimizer.py
new file mode 100644
index 0000000..61d26ca
--- /dev/null
+++ b/internlm/solver/optimizer/base_optimizer.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import torch
+from torch.optim import Optimizer
+
+
+class BaseOptimizer(Optimizer):
+    """
+    Base Optimizer.
+    """
+
+    def __init__(self, optim: Optimizer):  # pylint: disable=W0231
+        self.optim = optim
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    def add_param_group(self, *args, **kwargs):
+        return self.optim.add_param_group(*args, **kwargs)
+
+    def step(self, *args, **kwargs):
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        self.optim.zero_grad(*args, **kwargs)
+
+    def load_state_dict(self, *args, **kwargs):
+        self.optim.load_state_dict(*args, **kwargs)
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def backward(self, loss):
+        loss.backward()
+
+    def backward_by_grad(self, tensor, grad):
+        torch.autograd.backward(tensors=tensor, grad_tensors=grad)
+
+    def clip_grad_norm(self):
+        pass
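Aside (not part of the patch): the fp32 master-weight pattern that FSDPadaptOptimizer (next file) applies to FSDP flat parameters, shown on a single tensor; the names and the SGD-style update are illustrative only.

import torch

def master_weight_step(fp16_param: torch.Tensor, fp32_param: torch.Tensor, lr: float = 1e-3):
    fp32_param.grad = fp16_param.grad.float()  # gradients come from the fp16 model copy
    with torch.no_grad():
        fp32_param -= lr * fp32_param.grad     # stand-in for self.optim.step() on the fp32 copy
    fp16_param.data.copy_(fp32_param)          # write the updated master weights back to the model

p16 = torch.zeros(4, dtype=torch.float16, requires_grad=True)
p16.grad = torch.ones_like(p16)
p32 = p16.detach().float()
master_weight_step(p16, p32)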
diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py
new file mode 100644
index 0000000..6000185
--- /dev/null
+++ b/internlm/solver/optimizer/fsdp_optimizer.py
@@ -0,0 +1,221 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from torch.optim import Optimizer
+
+from internlm.core.context import Config, ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.solver.optimizer.utils import (
+    DynamicGradScaler,
+    reduce_tensor,
+    release_param_grad,
+)
+from internlm.utils.logger import get_logger
+
+from .base_optimizer import BaseOptimizer
+from .utils import compute_norm
+
+logger = get_logger(__file__)
+
+
+class FSDPadaptOptimizer(BaseOptimizer):
+    """
+    Optimizer wrapper for pytorch FSDP, used when 'parallel.zero1.fsdp' is True in the config file.
+    It keeps the necessary components of the hybrid-zero optimizer:
+        grad_scaler;
+        grad clip and unscale;
+        state_dict and load_state_dict
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        grad_scal_cfg: Config = None,
+        zero_cfg: Config = None,
+    ):
+        super().__init__(optim=optimizer)
+
+        # gradient scaler
+        self.grad_scaler = DynamicGradScaler(
+            initial_scale=grad_scal_cfg.fp16.initial_scale,
+            min_scale=grad_scal_cfg.fp16.min_scale,
+            growth_factor=grad_scal_cfg.growth_factor,
+            backoff_factor=grad_scal_cfg.backoff_factor,
+            growth_interval=grad_scal_cfg.fp16.growth_interval,
+            hysteresis=grad_scal_cfg.hysteresis,
+            max_scale=grad_scal_cfg.max_scale,
+        )
+
+        # clip gradient
+        self._clip_grad_norm = zero_cfg.clip_grad_norm
+
+        # fp16 and fp32 params
+        # fp16 shares memory with model.FlatParam, fp32 shares memory with optim.param_group
+        self._fp16_param_groups = dict()
+        self._fp32_param_tensor_groups = dict()
+
+        # init fp16 and fp32 params
+        for group_idx, param_group in enumerate(self.optim.param_groups):
+            group_params = param_group["params"]
+
+            # fp16 FlatParam storage
+            self._fp16_param_groups[group_idx] = group_params
+
+            # create copy of fp32 weight
+            fp32_tensor_param = [param.data.float() for param in group_params]
+            self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param
+
+            # replace
+            param_group["params"] = fp32_tensor_param
+
+    @property
+    def loss_scale(self):
+        return self.grad_scaler.scale
+
+    def backward(self, loss, retain_graph=False):
+        loss = self.loss_scale * loss
+        loss.backward(retain_graph=retain_graph)
+
+    def _compute_norm_with_fsdp_flatten(self, group_id):
+        params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
+        gradients = [p.grad for p in params if p.untyped_storage().size() != 0]
+        norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
+
+        return norm_group
+
+    def zero_grad(self):
+        for _, param_group in self._fp16_param_groups.items():
+            for param in param_group:
+                param.grad = None
+
+    def step(self):
+        # in case the fsdp-zero3 size is not equal to the dp size,
+        # the FSDP module only reduces gradients within its own FSDP process group,
+        # so gradients must be reduced manually across the parallel FSDP process groups
+        for group_idx in range(len(self.param_groups)):
+            params = self._fp16_param_groups[group_idx]
+            for param in params:
+                if param.requires_grad and param.grad is not None:
+                    handle = reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
+                    handle.wait()
+
+        # compute norm
+        found_inf = False
+        norm_groups = {}
+        for group_idx in range(len(self.param_groups)):
+            group_name = self.param_groups[group_idx]["name"] if "name" in self.param_groups[group_idx] else "default"
+            group_name = f"{group_idx}_{group_name}"
+            norm_group = self._compute_norm_with_fsdp_flatten(group_idx)
+            if norm_group == -1:
+                found_inf = True
+            norm_groups[group_name] = norm_group
+
+        loss_scale = float(self.loss_scale.item())  # backup
+        self.grad_scaler.update(found_inf)
+        if found_inf:
+            if gpc.is_rank_for_log():
+                logger.warning("Overflow occurs, please check it.")
+            self.zero_grad()
+            return False, norm_groups
+
+        # get the global norm
+        global_norm_groups = {}
+        if self._clip_grad_norm > 0:
+            for group_name, norm in norm_groups.items():
+                global_norm_groups[group_name] = norm**0.5
+
+        # create gradient for fp32 params
+        for group_idx in range(len(self.param_groups)):
+            dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
+            grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
+
+            device = self._fp32_param_tensor_groups[group_idx][0].device
+            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0]
+            for p, g in zip(nonzero_fp32, grad_fp32):
+                p.grad = g.to(device)
+
+        # unscale
+        self._unscale_and_clip_grads(list(global_norm_groups.values()), loss_scale)
+
+        self.optim.step()
+        self.zero_grad()
+
+        for group_idx in range(len(self._fp16_param_groups)):
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
+            fp32_tensor_params = [
+                p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0
+            ]
+            # release fp32 grad
+            release_param_grad(fp32_tensor_params)
+            # update fp16 param
+            for p, q in zip(fp16_params, fp32_tensor_params):
+                p.data.copy_(q)
+
+        for group_name, global_norm in global_norm_groups.items():
+            global_norm_groups[group_name] = global_norm / loss_scale
+        return True, global_norm_groups
+
+    def clip_grad_norm(self, model, max_norm):
+        # will be handled in step()
+        pass
+
+    #########################
+    # utils from hybrid zero #
+    #########################
+
+    def _unscale_and_clip_grads(self, total_norm_groups, loss_scale):
+        # compute combined scale factor for this group
+        combined_scale_groups = []
+
+        if self._clip_grad_norm > 0.0:
+            # norm is in fact norm*scale
+            for group_id, total_norm in enumerate(total_norm_groups):
+                combined_scale_groups.append(loss_scale)
+                clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
+                if clip > 1.0:
+                    combined_scale_groups[group_id] = clip * loss_scale
+
+        for group_id, param in self._fp32_param_tensor_groups.items():
+            for p in param:
+                if p.untyped_storage().size() != 0:
+                    p.grad.data.mul_(1.0 / combined_scale_groups[group_id])
+
+    def state_dict(self):
states = {} + grad_scaler = self.grad_scaler.state_dict() + states["grad_scaler"] = grad_scaler + optim_states = self.optim.state_dict() + states["base_optim_states"] = optim_states + + flat_fp32_weights = {} + for group_idx, param in self._fp32_param_tensor_groups.items(): + flat_fp32_weights[group_idx] = param + states["flat_fp32_weights"] = flat_fp32_weights + + return states + + def load_state_dict(self, states): + assert "grad_scaler" in states, "Not found grad_scaler state!" + grad_scaler = states["grad_scaler"] + self.grad_scaler.load_state_dict(grad_scaler) + optim_states = states["base_optim_states"] + self.optim.load_state_dict(optim_states) + + # load fp32 optimizer weight + flat_fp32_weights = states["flat_fp32_weights"] + assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups) + for group_idx, param in flat_fp32_weights.items(): + self_param = self._fp32_param_tensor_groups[group_idx] + assert len(self_param) == len( + param + ), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}" + for p, q in zip(self_param, param): + p.data.copy_(q.data) + + # load fp16 model weight + for group_idx, param in flat_fp32_weights.items(): + fp16_param = self._fp16_param_groups[group_idx] + fp32_param = self._fp32_param_tensor_groups[group_idx] + for p, q in zip(fp16_param, fp32_param): + p.data.copy_(q.data) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index b2680ed..97004eb 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -33,53 +33,13 @@ from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.timeout import llm_timeout +from .base_optimizer import BaseOptimizer from .utils import compute_norm inf = math.inf logger = get_logger(__file__) -class BaseOptimizer(Optimizer): - """ - Base Optimizer. - """ - - def __init__(self, optim: Optimizer): # pylint: disable=W0231 - self.optim = optim - - @property - def param_groups(self): - return self.optim.param_groups - - @property - def defaults(self): - return self.optim.defaults - - def add_param_group(self, *args, **kwargs): - return self.optim.add_param_group(*args, **kwargs) - - def step(self, *args, **kwargs): - return self.optim.step(*args, **kwargs) - - def zero_grad(self, *args, **kwargs): - self.optim.zero_grad(*args, **kwargs) - - def load_state_dict(self, *args, **kwargs): - self.optim.load_state_dict(*args, **kwargs) - - def state_dict(self): - return self.optim.state_dict() - - def backward(self, loss): - loss.backward() - - def backward_by_grad(self, tensor, grad): - torch.autograd.backward(tensors=tensor, grad_tensors=grad) - - def clip_grad_norm(self): - pass - - class HybridZeroOptimizer(BaseOptimizer): """ Hybrid Zero Optimizer. 
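Aside (not part of the patch): the combined unscale-and-clip factor used by FSDPadaptOptimizer._unscale_and_clip_grads above; gradients hold loss_scale * grad, so dividing by this factor removes the loss scale and, when the unscaled norm exceeds clip_grad_norm, clips at the same time. combined_scale is a hypothetical helper name.

def combined_scale(total_norm: float, loss_scale: float, clip_grad_norm: float) -> float:
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad_norm
    return clip * loss_scale if clip > 1.0 else loss_scale

# an unscaled grad norm of 10 with clip_grad_norm=1.0 divides the grads by roughly 10 * loss_scale
assert abs(combined_scale(total_norm=10 * 1024, loss_scale=1024, clip_grad_norm=1.0) - 10 * 1024) < 1.0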
diff --git a/internlm/train/__init__.py b/internlm/train/__init__.py index 457d7a4..1fd0802 100644 --- a/internlm/train/__init__.py +++ b/internlm/train/__init__.py @@ -6,6 +6,7 @@ from .training_internlm import ( initialize_optimizer, load_new_batch, record_current_batch_training_metrics, + wrap_FSDP_model, ) __all__ = [ @@ -16,4 +17,5 @@ __all__ = [ "initialize_optimizer", "load_new_batch", "record_current_batch_training_metrics", + "wrap_FSDP_model", ] diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index b7a369a..7af58dd 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -1,13 +1,22 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import functools import time from functools import partial from typing import Callable, Iterable, Union import torch import torch.distributed as dist +from flash_attn.modules.embedding import ParallelGPT2Embeddings +from flash_attn.modules.mlp import ParallelFusedMLP from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + ShardingStrategy, +) +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.utils.data import ConcatDataset, DataLoader from internlm.core.context import ParallelMode @@ -25,24 +34,29 @@ from internlm.data.packed_dataset import ( get_packed_dataset_without_short_length, ) from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data +from internlm.model.embedding import Embedding1D +from internlm.model.linear import ( + FeedForward, + RewardModelLinear, + ScaleColumnParallelLinear, +) +from internlm.model.multi_head_attention import MHA +from internlm.model.utils import try_import_RMSNorm from internlm.monitor import send_heartbeat, set_env_var from internlm.monitor.monitor import monitor_manager as mm from internlm.solver.beta2_scheduler import Beta2Scheduler from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR -from internlm.solver.optimizer import HybridZeroOptimizer +from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer from internlm.solver.optimizer.utils import ParamBcastSyncHandler from internlm.train.utils import create_param_groups from internlm.utils.common import DummyProfile from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.parallel import ( - is_no_pp_or_last_stage, - sync_model_param, - sync_model_param_within_tp, -) +from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp from internlm.utils.registry import MODEL_INITIALIZER from internlm.utils.timeout import llm_timeout +RMSNorm = try_import_RMSNorm() logger = get_logger(__file__) @@ -72,7 +86,7 @@ def initialize_model(): else: model = NaiveAMPModel( model=model, - output_to_fp32=is_no_pp_or_last_stage(), + output_to_fp32=gpc.is_no_pp_or_last_stage(), dtype=gpc.config.model.get("dtype", torch.half), sync_buffer=False, ) @@ -90,6 +104,42 @@ def initialize_model(): # state in the same dp group are all the same. 
set_mode(ParallelMode.DATA) + # if fsdp enabled, wrap the model + model = wrap_FSDP_model(model) + + return model + + +def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): + if gpc.config.parallel.zero1.fsdp: + # set wrap_policy for fsdp wrap + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + Embedding1D, + ParallelGPT2Embeddings, + MHA, + RMSNorm, + FeedForward, + ParallelFusedMLP, + RewardModelLinear, + ScaleColumnParallelLinear, + }, + ) + + # wrap the model + grp = gpc.get_group(ParallelMode.ZERO1) + model = FSDP( + module=model, + process_group=grp, + sharding_strategy=ShardingStrategy.FULL_SHARD, + auto_wrap_policy=transformer_wrap_policy, + forward_prefetch=True, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + limit_all_gathers=True, + use_orig_params=True, + ) + return model @@ -118,12 +168,19 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]): eps=adam_cfg.adam_eps, ) - optimizer = HybridZeroOptimizer( - naive_optimizer, - grad_scal_cfg=gpc.config.grad_scaler, - zero_cfg=gpc.config.hybrid_zero_optimizer, - param_bcast_sync_handler=param_bcast_sync_handler, - ) + if not gpc.config.parallel.zero1.fsdp: + optimizer = HybridZeroOptimizer( + naive_optimizer, + grad_scal_cfg=gpc.config.grad_scaler, + zero_cfg=gpc.config.hybrid_zero_optimizer, + param_bcast_sync_handler=param_bcast_sync_handler, + ) + else: + optimizer = FSDPadaptOptimizer( + naive_optimizer, + grad_scal_cfg=gpc.config.grad_scaler, + zero_cfg=gpc.config.hybrid_zero_optimizer, + ) beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler) @@ -360,7 +417,7 @@ def record_current_batch_training_metrics( timer.store_last_timers() if success_update in (0, True): train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA) - if is_no_pp_or_last_stage(): + if gpc.is_no_pp_or_last_stage(): acc_perplex = metric.get_metric() if success_update and gpc.is_rank_for_log(): diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 00e7436..d63ed7a 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -12,6 +12,9 @@ from enum import Enum from typing import Callable, Dict, Union import torch +from torch.distributed._shard.api import load_with_process_group +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc @@ -157,11 +160,38 @@ def get_model_topology(model): return topos +def get_shard_state_dict(shard_model): + """ + Only used for FSDP module saving. + It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter + (saved as model.flat_param_xx in sharded FSDP module) will be gathered at every gpu. + 'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu + + """ + + # FSDP model can only save with sharded shape SHARDED_STATE_DICT when set use_orig_params=True + with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT): + shard_states = shard_model.state_dict() + + return shard_states + + +def load_shard_state_dict(shard_model, shard_state, **kwargs): + """ + Only used for FSDP module loading. 
+ + """ + + with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT): + missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, kwargs) + + return (missing_k, unexpected_keys) + + def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState): load_content_str = "" load_ckpt_folder = load_info["path"] load_content: CheckpointLoadMask = load_info["content"] - if gpc.is_rank_for_log(): logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}") @@ -224,6 +254,10 @@ def save_model_checkpoint(folder, model): - folder - model_tp{tp_rank}_pp{pp_rank}.pt + If fsdp is activated, the saved weight is named: + - folder + - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank} + If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading. Args: @@ -231,7 +265,11 @@ def save_model_checkpoint(folder, model): model: The model to be saved """ - states = model.state_dict() + if gpc.config.parallel.zero1.fsdp: + states = get_shard_state_dict(model) + else: + states = model.state_dict() + # get non-expert parameters states = get_non_moe_state_dict(states) topo = get_model_topology(model) @@ -247,15 +285,21 @@ def save_model_checkpoint(folder, model): # even if pp is not considered, it will definitely not be written on the same machine. should_save_rank_pair = set() # (tp_rank, dp_rank) for i in range(tp_size): - should_save_rank_pair.add((i, i % dp_size)) + if gpc.config.parallel.zero1.fsdp: + for j in range(dp_size): + should_save_rank_pair.add((i, j)) + else: + should_save_rank_pair.add((i, i % dp_size)) - if (tp_rank, dp_rank) in should_save_rank_pair: - fn = f"model_tp{tp_rank}_pp{pp_rank}.pt" - fp = os.path.join(folder, fn) - llm_save(fp, saved_obj=states) - topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json" - topo_fp = os.path.join(folder, topo_fn) - llm_save(topo_fp, saved_obj=topo) + if (tp_rank, dp_rank) in should_save_rank_pair: + f_dp = f"_dp{dp_rank}" if gpc.config.parallel.zero1.fsdp else "" + fn = f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt" + fp = os.path.join(folder, fn) + llm_save(fp, saved_obj=states) + if not gpc.config.parallel.zero1.fsdp or dp_rank == tp_rank % dp_size: + topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json" + topo_fp = os.path.join(folder, topo_fn) + llm_save(topo_fp, saved_obj=topo) # try to save expert parameter to separate files if model have moe layer expert_dp_size = gpc.get_world_size(ParallelMode.EXPERT_DATA) @@ -276,21 +320,39 @@ def load_model_checkpoint(folder, model): - folder - model_tp{tp_rank}_pp{pp_rank}.pt + If fsdp is activated, the saved weight is named: + - folder + - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank} + If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading. 
""" tp_size = gpc.get_world_size(ParallelMode.TENSOR) pp_size = gpc.get_world_size(ParallelMode.PIPELINE) + dp_size = gpc.get_world_size(ParallelMode.DATA) tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + dp_rank = gpc.get_local_rank(ParallelMode.DATA) fns = get_fns(folder) - max_pp, max_tp = 0, 0 + + # avoid ckpt misuse between FSDP and no-FSDP + test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop() + assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or ( + "_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp + ), "FSDP model wants to load no-FSDP ckpts or reverse" + + max_pp, max_tp, max_zo = 0, 0, 0 for fn in fns: if fn.startswith("model_t") and not fn.endswith(".md5"): segements = os.path.splitext(fn)[0].split("_") - max_pp = max(max_pp, int(segements[-1][2:])) - max_tp = max(max_tp, int(segements[-2][2:])) + if gpc.config.parallel.zero1.fsdp: + max_zo = max(max_zo, int(segements[-1][2:])) + max_pp = max(max_pp, int(segements[-2][2:])) + max_tp = max(max_tp, int(segements[-3][2:])) + else: + max_pp = max(max_pp, int(segements[-1][2:])) + max_tp = max(max_tp, int(segements[-2][2:])) assert ( pp_size == max_pp + 1 @@ -298,10 +360,20 @@ def load_model_checkpoint(folder, model): assert ( tp_size == max_tp + 1 ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism" + if gpc.config.parallel.zero1.fsdp: + assert ( + dp_size == max_zo + 1 + ), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards" - should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt" + if gpc.config.parallel.zero1.fsdp: + should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt" + else: + should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt" fp = os.path.join(folder, should_load_name) - states = llm_load(fp, map_location=get_current_device()) + + # for FSDP shards loading, we need to set process group + with load_with_process_group(gpc.get_group(ParallelMode.ZERO1)): + states = llm_load(fp, map_location=get_current_device()) """ # need convert the gate parameters to float32 (to fit deepspeed style mechanism), it may cause round-off in @@ -316,7 +388,10 @@ def load_model_checkpoint(folder, model): # try to load expert parameter to separate files if model have moe layer try_load_moe_checkpoint(folder, model, states, tp_rank, pp_rank) - missing_k, unexpected_keys = model.load_state_dict(states, strict=False) + if gpc.config.parallel.zero1.fsdp: + missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False) + else: + missing_k, unexpected_keys = model.load_state_dict(states, strict=False) if len(missing_k) != 0: logger.warning(f"Warning: missing keys {missing_k}") if len(unexpected_keys) != 0: diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index 9efef10..3029af5 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -50,10 +50,6 @@ def sync_model_param_within_tp(model): dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) -def is_no_pp_or_last_stage(): - return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) - - def get_parallel_log_file_name(): if gpc.is_rank_for_log(): fn_prefix = "main_" # Indicates a rank with more output information