diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 1c5f7a7..e994b24 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -63,15 +63,14 @@ def args_sanity_check():
         gpc.config.parallel._add_item("tensor", 1)

     if isinstance(gpc.config.parallel.pipeline, int):
         pp = gpc.config.parallel.pipeline
     else:
         pp = gpc.config.parallel.pipeline.size
-    tp = gpc.config.parallel.tensor

     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-    elif gpc.config.parallel.use_fsdp and (pp > 1 or tp > 1):
-        logger.warning("FSDP not support when pipeline/tensor parallel is enabled, auto-close FSDP")
+    elif gpc.config.parallel.use_fsdp and pp > 1:
+        logger.warning("FSDP is not supported when pipeline parallel is enabled, auto-close FSDP")
         gpc.config.parallel._add_item("use_fsdp", False)

     # processing the data config in gpc
diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py
index a19de0e..463ed91 100644
--- a/internlm/solver/optimizer/fsdp_optimizer.py
+++ b/internlm/solver/optimizer/fsdp_optimizer.py
@@ -50,9 +50,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self._clip_grad_norm = zero_cfg.clip_grad_norm
         self.use_fsdp = gpc.config.parallel.use_fsdp

-        # mark whether a module is part of TP or not
-        # TODO: is_tensor_parallel_dict = dict()
-
         # fp16 and fp32 params
         # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
         self._fp16_param_groups = dict()
@@ -66,7 +63,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
             self._fp16_param_groups[group_idx] = group_params

             # create copy of fp32 weight
-            fp32_tensor_param = [param.data.float().requires_grad_(True) for param in group_params]
+            fp32_tensor_param = [param.data.float() for param in group_params]
             self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param

             # replace
@@ -81,8 +78,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         loss.backward(retain_graph=retain_graph)

     def _compute_norm_with_fsdp_flatten(self, group_id):
-        params = self._fp16_param_groups[group_id]
-        gradients = [p.grad for p in params]
+        params = [p for p in self._fp16_param_groups[group_id] if p.storage().size() != 0]
+        gradients = [p.grad for p in params if p.storage().size() != 0]
         norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)

         return norm_group
@@ -99,7 +96,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
         for group_idx in range(len(self.param_groups)):
             params = self._fp16_param_groups[group_idx]
             for param in params:
-                if param.requires_grad:
+                if param.requires_grad and param.grad is not None:
                     handle = reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
                     handle.wait()
@@ -131,11 +128,12 @@ class FSDPadaptOptimizer(BaseOptimizer):
         # create gradient for fp32 params
         for group_idx in range(len(self.param_groups)):
             dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
-            fp16_params = self._fp16_param_groups[group_idx]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
             grad_fp32 = [p.grad.to(dtype) for p in fp16_params]

             device = self._fp32_param_tensor_groups[group_idx][0].device
-            for p, g in zip(self._fp32_param_tensor_groups[group_idx], grad_fp32):
+            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
+            for p, g in zip(nonzero_fp32, grad_fp32):
                 p.grad = g.to(device)

         # unscale
@@ -145,8 +143,8 @@ class FSDPadaptOptimizer(BaseOptimizer):

         self.zero_grad()
         for group_idx in range(len(self._fp16_param_groups)):
-            fp16_params = self._fp16_param_groups[group_idx]
-            fp32_tensor_params = self._fp32_param_tensor_groups[group_idx]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
+            fp32_tensor_params = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
             # release fp32 grad
             release_param_grad(fp32_tensor_params)
             # update fp16 param
@@ -177,9 +175,10 @@ class FSDPadaptOptimizer(BaseOptimizer):
             if clip > 1.0:
                 combined_scale_groups[group_id] = clip * loss_scale

-        for group_id, grads in self._fp32_param_tensor_groups.items():
-            for g in grads:
-                g.grad.data.mul_(1.0 / combined_scale_groups[group_id])
+        for group_id, param in self._fp32_param_tensor_groups.items():
+            for p in param:
+                if p.storage().size() != 0:
+                    p.grad.data.mul_(1.0 / combined_scale_groups[group_id])

     def state_dict(self):
         states = {}
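The `p.storage().size() != 0` filters that now appear throughout the optimizer follow from wrapping the model with `use_orig_params=True` (see the training change below): every rank still iterates over the full list of original parameters, but a parameter whose shard lives on another rank has empty storage locally and carries no gradient, so it must be skipped when computing norms, unscaling, or stepping. A minimal sketch of that idiom, with illustrative helper names (`non_empty_params`, `local_grad_norm`) that are not part of the patch:

```python
# Sketch only: assumes parameters coming from an FSDP(use_orig_params=True) module,
# where parameters sharded onto other ranks have storage size 0 on this rank.
import torch


def non_empty_params(params):
    """Keep only the parameters whose data is actually materialized on this rank."""
    return [p for p in params if p.storage().size() != 0]


def local_grad_norm(params):
    """L2 norm over the gradients this rank holds (no cross-rank reduction)."""
    grads = [p.grad for p in non_empty_params(params) if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    return torch.linalg.vector_norm(torch.stack([g.detach().float().norm(2) for g in grads]))
```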
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index e9f508b..6380ed8 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -17,7 +17,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader

-from internlm.core.context import ParallelMode
+from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.random import set_mode
 from internlm.core.naive_amp import NaiveAMPModel
@@ -36,16 +36,8 @@ from internlm.model.modeling_internlm import (
     PackedFlashBaseLayer1D,
     PackedFlashInternLm1D,
 )
-
 from internlm.model.multi_head_attention import MHA
-from flash_attn.modules.mha import (
-    CrossAttention,
-    FlashCrossAttention,
-    FlashSelfAttention,
-    SelfAttention,
-    _update_kv_cache,
-)
-
+from internlm.model.utils import try_import_RMSNorm
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -117,18 +109,23 @@ def initialize_model():

 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
-    from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
     RMSNorm = try_import_RMSNorm()

     if gpc.config.parallel.use_fsdp:
+        # pre-save info for tensor parallel
+        tp_dict = dict()
+        for name, param in model.named_parameters():
+            if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
+                tp_dict.update({name.replace("model.", ""): True})
+            else:
+                tp_dict.update({name.replace("model.", ""): False})
+
+        # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy, transformer_layer_cls={
-                PackedFlashBaseLayer1D,
-                PackedFlashInternLm1D,
-                MHA,
-                FlashCrossAttention,
-                FlashSelfAttention,
-                RMSNorm}
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D, MHA, RMSNorm},
         )
+
+        # wrap the model
         grp = gpc.get_group(ParallelMode.ZERO1)
         model = FSDP(
             module=model,
@@ -138,8 +135,14 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             forward_prefetch=True,
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             limit_all_gathers=True,
+            use_orig_params=True,
         )

+        # re-set attribute for fsdp module
+        for (name, param), pre in zip(model.named_parameters(), tp_dict):
+            if pre in name and tp_dict[pre]:
+                setattr(param, IS_TENSOR_PARALLEL, True)
+
     return model
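`wrap_FSDP_model` now records each parameter's `IS_TENSOR_PARALLEL` flag before wrapping and re-applies it afterwards, which only works because `use_orig_params=True` keeps the original parameters reachable through `named_parameters()` on the wrapped module. A minimal, self-contained sketch of the same wrap-and-retag pattern, using placeholder names (`MyBlock`, `my_flag`, `wrap_and_retag`) instead of the InternLM classes, and assuming `torch.distributed` has already been initialized:

```python
# Sketch only: MyBlock stands in for a transformer layer and "my_flag" for a
# per-parameter attribute such as IS_TENSOR_PARALLEL; in the patch the policy
# covers PackedFlashBaseLayer1D, PackedFlashInternLm1D, MHA and RMSNorm, and the
# process group comes from gpc.get_group(ParallelMode.ZERO1).
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class MyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)


def wrap_and_retag(model: nn.Module) -> FSDP:
    # remember per-parameter flags before wrapping, keyed by parameter name
    flags = {name: getattr(p, "my_flag", False) for name, p in model.named_parameters()}

    policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={MyBlock})
    model = FSDP(model, auto_wrap_policy=policy, use_orig_params=True)

    # use_orig_params=True keeps the original parameters visible after wrapping,
    # so the saved flags can be re-applied by matching name suffixes
    for name, p in model.named_parameters():
        for orig_name, flag in flags.items():
            if flag and name.endswith(orig_name):
                setattr(p, "my_flag", True)
    return model
```

Matching by name suffix sidesteps the wrapper prefixes (such as `_fsdp_wrapped_module.`) that FSDP inserts into parameter names after auto-wrapping.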
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 20a7d49..ff0ce67 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -171,8 +171,8 @@ def get_shard_state_dict(shard_model):

     # with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
     #     states = model.state_dict()
-    # in this version, FSDP model can only save with sharded shape
-    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+    # in this version, FSDP model can only be saved with sharded shape, i.e. SHARDED_STATE_DICT
+    with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
         shard_states = shard_model.state_dict()

     return shard_states
@@ -184,7 +184,7 @@ def load_shard_state_dict(shard_model, shard_state, **kwargs):

     """

-    with FSDP.state_dict_type(shard_model, StateDictType.LOCAL_STATE_DICT):
+    with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT):
         missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, kwargs)

     return (missing_k, unexpected_keys)
diff --git a/train.py b/train.py
index b9fe6af..26222e2 100644
--- a/train.py
+++ b/train.py
@@ -289,6 +289,8 @@ def main(args):


 if __name__ == "__main__":
+    assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
+
     args = parse_args()
     hostname = socket.gethostname()
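Switching from `LOCAL_STATE_DICT` (which exposes the raw flattened `FlatParameter` shards) to `SHARDED_STATE_DICT` keeps one entry per original parameter, stored as a `ShardedTensor`, so each rank still writes only its own shard while the checkpoint retains recognizable parameter names. A minimal sketch of the resulting save/load pattern; `save_sharded`, `load_sharded`, and the per-rank file naming are illustrative, not the repo's checkpoint backend:

```python
# Sketch only: `wrapped` is assumed to be an FSDP-wrapped module and
# torch.distributed to be initialized; file paths are placeholders.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def save_sharded(wrapped: FSDP, prefix: str) -> None:
    # SHARDED_STATE_DICT yields one ShardedTensor per original parameter,
    # so each rank serializes only the shard it owns
    with FSDP.state_dict_type(wrapped, StateDictType.SHARDED_STATE_DICT):
        shard = wrapped.state_dict()
    torch.save(shard, f"{prefix}.rank{dist.get_rank()}.pt")


def load_sharded(wrapped: FSDP, prefix: str) -> None:
    shard = torch.load(f"{prefix}.rank{dist.get_rank()}.pt", map_location="cpu")
    with FSDP.state_dict_type(wrapped, StateDictType.SHARDED_STATE_DICT):
        wrapped.load_state_dict(shard)
```

As sketched, each rank reads back its own file, so saving and loading must run with the same sharding layout.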