diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py
index f8b70b6..79ce106 100644
--- a/internlm/core/context/process_group_initializer.py
+++ b/internlm/core/context/process_group_initializer.py
@@ -392,7 +392,7 @@ class Initializer_Nettest(ProcessGroupInitializer):
             ranks_in_group = ranks
 
         return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
-
+
 
 class Initializer_Zero3_dp(ProcessGroupInitializer):
     """A ProcessGroupInitializer for data parallelism.
@@ -421,7 +421,6 @@ class Initializer_Zero3_dp(ProcessGroupInitializer):
 
         assert self.world_size % self.data_parallel_size == 0
 
-
     def init_dist_group(self, use_cpu: bool = False):
         """Initialize data parallel groups, and assign local_ranks and groups to each gpu.
 
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index db776a6..e903342 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -90,13 +90,13 @@ class FSDPadaptOptimizer(BaseOptimizer):
     '''
 
     def __init__(
-            self,
-            optimizer: Optimizer,
-            grad_scal_cfg: Config = None,
-            zero_cfg: Config = None,
-            ):
+        self,
+        optimizer: Optimizer,
+        grad_scal_cfg: Config = None,
+        zero_cfg: Config = None,
+    ):
         super().__init__(optim=optimizer)
-
+
         # gradient scaler
         self.grad_scaler = DynamicGradScaler(
             initial_scale=grad_scal_cfg.fp16.initial_scale,
@@ -113,7 +113,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.use_fsdp = gpc.config.parallel.use_fsdp
 
         # mark whether a module is part of TP or not
-        is_tensor_parallel_dict = dict()
+        # TODO: is_tensor_parallel_dict = dict()
 
         # fp16 and fp32 params
         # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
@@ -150,7 +150,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
                 parameters=params,
                 last_stage=True
             )
-
+
         return norm_group
 
     def zero_grad(self):
@@ -187,12 +187,12 @@ class FSDPadaptOptimizer(BaseOptimizer):
             logger.warning("Overflow occurs, please check it.")
             self.zero_grad()
             return False, None
-
+
         # get the global norm
         global_norm_groups = {}
         if self._clip_grad_norm > 0:
             for group_name, norm in norm_groups.items():
-                global_norm_groups[group_name] = norm**0.5 
+                global_norm_groups[group_name] = norm**0.5
 
         # create gradient for fp32 params
         for group_idx in range(len(self.param_groups)):
@@ -207,7 +207,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
 
         # unscale
         self._unscale_and_clip_grads(list(global_norm_groups.values()), loss_scale)
-        self.optim.step() 
+        self.optim.step()
         self.zero_grad()
 
         # update fp16 param
@@ -221,7 +221,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
             global_norm_groups[group_name] = global_norm / loss_scale
         return True, global_norm_groups
 
-
    def clip_grad_norm(self, model, max_norm):
         # will conduct in the step()
         pass
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 4fc7465..a0dd913 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -42,18 +42,11 @@ from internlm.utils.registry import MODEL_INITIALIZER
 
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    CPUOffload,
-    BackwardPrefetch,
     ShardingStrategy,
     MixedPrecision,
     BackwardPrefetch,
 )
-from torch.distributed.fsdp.wrap import (
-    size_based_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-    enable_wrap,
-    wrap,
-)
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 import functools
 from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlashInternLm1D
 
@@ -107,23 +100,17 @@ def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.use_fsdp:
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
-            transformer_layer_cls = {PackedFlashBaseLayer1D, PackedFlashInternLm1D}
+            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
         )
-        mx = MixedPrecision(
-            param_dtype=gpc.config.model.dtype, reduce_dtype=gpc.config.model.dtype,
-            buffer_dtype=gpc.config.model.dtype, keep_low_precision_grads=True)
         grp = gpc.get_group(ParallelMode.ZERO1)
-        model = FSDP(module=model, 
+        model = FSDP(module=model,
                      process_group=grp,
                      sharding_strategy=ShardingStrategy.FULL_SHARD,
                      auto_wrap_policy=transformer_wrap_policy,
                      forward_prefetch=True,
                      backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
-                     #cpu_offload=CPUOfload(offload_params=True)
-                     #mixed_precision=mx,
-                     #device_id=torch.cuda.current_device()
-                     )
-
+                     )
+
     return model
 
 
@@ -159,9 +146,9 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         )
     else:
         optimizer = FSDPadaptOptimizer(
-                naive_optimizer,
-                grad_scal_cfg=gpc.config.grad_scaler,
-                zero_cfg=gpc.config.hybrid_zero_optimizer,
+            naive_optimizer,
+            grad_scal_cfg=gpc.config.grad_scaler,
+            zero_cfg=gpc.config.hybrid_zero_optimizer,
         )
 
     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 120df0c..b36afec 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -57,14 +57,14 @@ def get_model_topology(model):
 
 def get_state_dict(model):
     """
-    Only used for FSDP module saving. 
-    It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter 
+    Only used for FSDP module saving.
+    It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter
     (saved as model.flat_param_xx in sharded FSDP module) will be gathered at every gpu. 'offload_to_cpu' means
     that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu
-    
+
     """
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp import FullStateDictConfig, StateDictType# , FullOptimStateDictConfig
+    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
 
     # TODO: rank0_only can save memory for non-rank0 gpu, but when tp is enabled, model saving will left some parameters
     save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
@@ -91,9 +91,9 @@ def save_model_checkpoint(folder, model):
 
     if gpc.config.parallel.use_fsdp:
         states = get_state_dict(model)
-    else: 
+    else:
         states = model.state_dict()
-    
+
     topo = get_model_topology(model)
 
     if folder is not None:
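For reference, the FSDP wrapping and full-state-dict saving pattern touched by this diff can be exercised outside InternLM. The sketch below is a minimal, standalone illustration, not part of the patch: ToyBlock, ToyModel, wrap_with_fsdp, and save_full_state_dict are hypothetical names, and the process group defaults to the world group instead of gpc.get_group(ParallelMode.ZERO1). It uses only the public torch.distributed.fsdp APIs the diff imports (transformer_auto_wrap_policy, ShardingStrategy, BackwardPrefetch, FullStateDictConfig, StateDictType) and mirrors the options chosen in warp_FSDP_model and get_state_dict.

# fsdp_sketch.py -- illustrative only, assumed names; launch with:
#   torchrun --nproc_per_node=2 fsdp_sketch.py
import functools
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyBlock(nn.Module):
    # Stand-in for a transformer layer such as PackedFlashBaseLayer1D.
    def __init__(self, dim: int = 256):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.mlp(self.norm(x))


class ToyModel(nn.Module):
    # Stand-in for the top-level model (PackedFlashInternLm1D in the patch).
    def __init__(self, num_layers: int = 4, dim: int = 256):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock(dim) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def wrap_with_fsdp(model: nn.Module) -> FSDP:
    # Same wrap policy and sharding options as the patched warp_FSDP_model,
    # minus the InternLM-specific ZERO1 process group.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={ToyBlock},
    )
    return FSDP(
        module=model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=wrap_policy,
        forward_prefetch=True,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    )


def save_full_state_dict(model: FSDP, path: str, rank: int) -> None:
    # Gather the sharded flat parameters into a regular full state dict,
    # offloading to CPU chunk by chunk, as get_state_dict does.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        states = model.state_dict()
    # With rank0_only=False every rank holds the full dict; write once.
    if rank == 0:
        torch.save(states, path)


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    model = wrap_with_fsdp(ToyModel().cuda())
    # One dummy forward/backward so the sharded parameters carry gradients.
    model(torch.randn(8, 16, 256, device="cuda")).sum().backward()
    save_full_state_dict(model, "toy_fsdp_ckpt.pt", rank)
    dist.destroy_process_group()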