diff --git a/internlm/core/context/__init__.py b/internlm/core/context/__init__.py
index 3727012..5cbb832 100644
--- a/internlm/core/context/__init__.py
+++ b/internlm/core/context/__init__.py
@@ -11,9 +11,9 @@ from .process_group_initializer import (
     Initializer_Pipeline,
     Initializer_Tensor,
     Initializer_Zero1,
+    Initializer_Zero3_dp,
     ParallelMode,
     ProcessGroupInitializer,
-    Initializer_Zero3_dp,
 )
 from .random import (
     add_seed,
diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index 18544a7..3100236 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -478,7 +478,7 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
-        if self.config.parallel.use_fsdp:
+        if self.config.parallel.get("use_fsdp", False):
             initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
         if self.pipeline_parallel_size > 1:
diff --git a/internlm/solver/optimizer/__init__.py b/internlm/solver/optimizer/__init__.py
index 7968e75..c4a1eb7 100644
--- a/internlm/solver/optimizer/__init__.py
+++ b/internlm/solver/optimizer/__init__.py
@@ -1,6 +1,10 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer, reload_zero_fp32_buff
+from .hybrid_zero_optim import (
+    FSDPadaptOptimizer,
+    HybridZeroOptimizer,
+    reload_zero_fp32_buff,
+)

-__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer", "reload_zero_fp32_buff"]
\ No newline at end of file
+__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer", "reload_zero_fp32_buff"]
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 330b696..3b384f3 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -82,13 +82,13 @@ class BaseOptimizer(Optimizer):


 class FSDPadaptOptimizer(BaseOptimizer):
-    '''
+    """
     optimizer for Pytorch FSDP if 'use_fsdp' is True in config file
     reserve some necessary components of hybird-optim:
     grad_scaler;
     grad_clip and unscale;
     state_dict and load_state_dict
-    '''
+    """

     def __init__(
         self,
@@ -146,11 +146,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
     def _compute_norm_with_fsdp_flatten(self, group_id):
         params = self._fp16_param_groups[group_id]
         gradients = [p.grad for p in params]
-        norm_group = compute_norm(
-            gradients=gradients,
-            parameters=params,
-            last_stage=True
-        )
+        norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)

         return norm_group

@@ -178,7 +174,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
             norm_group = self._compute_norm_with_fsdp_flatten(group_idx)
             if norm_group == -1:
                 found_inf = True
-                break
             norm_groups[group_name] = norm_group

         loss_scale = float(self.loss_scale.item())  # backup
@@ -187,7 +182,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
             self.zero_grad()
-            return False, None
+            return False, norm_groups

         # get the global norm
         global_norm_groups = {}
@@ -211,10 +206,12 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.optim.step()
         self.zero_grad()

-        # update fp16 param
         for group_idx in range(len(self._fp16_param_groups)):
             fp16_params = self._fp16_param_groups[group_idx]
             fp32_tensor_params = self._fp32_param_tensor_groups[group_idx]
+            # release fp32 grad
+            release_param_grad(fp32_tensor_params)
+            # update fp16 param
             for p, q in zip(fp16_params, fp32_tensor_params):
                 p.data.copy_(q)

@@ -272,8 +269,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups)
         for group_idx, param in flat_fp32_weights.items():
             self_param = self._fp32_param_tensor_groups[group_idx]
-            assert (
-                len(self_param) == len(param)
+            assert len(self_param) == len(
+                param
             ), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}"
             for p, q in zip(self_param, param):
                 p.data.copy_(q.data)
diff --git a/internlm/train/__init__.py b/internlm/train/__init__.py
index a5a2995..1fd0802 100644
--- a/internlm/train/__init__.py
+++ b/internlm/train/__init__.py
@@ -6,7 +6,7 @@ from .training_internlm import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
-    warp_FSDP_model,
+    wrap_FSDP_model,
 )

 __all__ = [
@@ -17,5 +17,5 @@ __all__ = [
     "initialize_optimizer",
     "load_new_batch",
     "record_current_batch_training_metrics",
-    "warp_FSDP_model",
+    "wrap_FSDP_model",
 ]
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index e5b5097..b82aef2 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import functools
 import time
 from functools import partial
 from typing import Callable, Iterable, Union
@@ -8,6 +9,12 @@ from typing import Callable, Iterable, Union
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader

 from internlm.core.context import ParallelMode
@@ -25,11 +32,15 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
+from internlm.model.modeling_internlm import (
+    PackedFlashBaseLayer1D,
+    PackedFlashInternLm1D,
+)
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
-from internlm.solver.optimizer import HybridZeroOptimizer, FSDPadaptOptimizer
+from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
@@ -42,17 +53,6 @@ from internlm.utils.parallel import (
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout

-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    ShardingStrategy,
-    MixedPrecision,
-    BackwardPrefetch,
-)
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-import functools
-from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
-
-
 logger = get_logger(__file__)


@@ -103,19 +103,20 @@ def initialize_model():
     return model


-def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.use_fsdp:
         transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy,
-            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
+            transformer_auto_wrap_policy, transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
         )
         grp = gpc.get_group(ParallelMode.ZERO1)
-        model = FSDP(module=model,
-                     process_group=grp,
-                     sharding_strategy=ShardingStrategy.FULL_SHARD,
-                     auto_wrap_policy=transformer_wrap_policy,
-                     forward_prefetch=True,
-                     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+        model = FSDP(
+            module=model,
+            process_group=grp,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            auto_wrap_policy=transformer_wrap_policy,
+            forward_prefetch=True,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            limit_all_gathers=True,
         )

     return model
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index a7a0c16..1e5007b 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -11,6 +11,9 @@ from enum import Enum
 from typing import Callable, Dict, Union

 import torch
+from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import StateDictType

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
@@ -162,13 +165,10 @@ def get_state_dict(model):
     'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk,
     avoiding OOM in gpu
     """
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

     # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
     save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
-    with FSDP.state_dict_type(
-        model, StateDictType.FULL_STATE_DICT, save_policy):
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
         states = model.state_dict()

     return states
diff --git a/train.py b/train.py
index 35e9612..cab61c7 100644
--- a/train.py
+++ b/train.py
@@ -27,7 +27,7 @@ from internlm.train import (
     initialize_optimizer,
     load_new_batch,
     record_current_batch_training_metrics,
-    warp_FSDP_model,
+    wrap_FSDP_model,
 )
 from internlm.utils.common import (
     BatchSkipper,
@@ -111,8 +111,8 @@ def main(args):
     # initialize and resume train state
     train_state = TrainState(gpc.config, train_dl.batch_sampler)

-    # if fsdp enabled, warp the model
-    model = warp_FSDP_model(model)
+    # if fsdp enabled, wrap the model
+    model = wrap_FSDP_model(model)

     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
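
For readers unfamiliar with the wrapping pattern used by wrap_FSDP_model above, the following standalone sketch shows the same auto-wrap + full-shard recipe outside of InternLM. ToyBlock, ToyModel and shard_model are made-up names for illustration only; the FSDP arguments mirror the diff, and the snippet assumes torch.distributed has already been initialized so that a default (or explicitly passed) process group exists.

# Standalone sketch of the FSDP wrapping pattern (not part of the patch).
# Each module class listed in transformer_layer_cls becomes its own FSDP unit,
# and parameters are fully sharded across the given process group.
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyBlock(nn.Module):
    # stand-in for PackedFlashBaseLayer1D
    def __init__(self, dim: int = 128):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return self.ff(x)


class ToyModel(nn.Module):
    # stand-in for PackedFlashInternLm1D
    def __init__(self, dim: int = 128, layers: int = 4):
        super().__init__()
        self.blocks = nn.Sequential(*[ToyBlock(dim) for _ in range(layers)])

    def forward(self, x):
        return self.blocks(x)


def shard_model(model: nn.Module, process_group=None) -> FSDP:
    # Wrap every ToyBlock as its own FSDP unit, analogous to wrapping the
    # transformer layers in wrap_FSDP_model.
    wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={ToyBlock})
    return FSDP(
        model,
        process_group=process_group,  # e.g. the group returned by gpc.get_group(ParallelMode.ZERO1)
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=wrap_policy,
        forward_prefetch=True,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
    )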
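
Similarly, a minimal sketch of the full-state-dict gathering that get_state_dict relies on when saving FSDP checkpoints. gather_full_state_dict is a hypothetical helper name; only the FullStateDictConfig arguments mirror the diff, where offload_to_cpu=True streams gathered shards to CPU to limit GPU peak memory and rank0_only=False keeps a complete copy on every rank (the in-code TODO notes that rank0_only=True would save memory but can drop parameters when tensor parallelism is enabled).

# Standalone sketch of full-state-dict gathering with FSDP (not part of the patch).
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def gather_full_state_dict(fsdp_model: FSDP) -> dict:
    # Inside this context, state_dict() returns the reassembled, unsharded
    # parameters instead of the local shards.
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, save_policy):
        return fsdp_model.state_dict()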