diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index ba08abf..e9f508b 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -36,6 +36,16 @@ from internlm.model.modeling_internlm import (
     PackedFlashBaseLayer1D,
     PackedFlashInternLm1D,
 )
+
+from internlm.model.multi_head_attention import MHA
+from flash_attn.modules.mha import (
+    CrossAttention,
+    FlashCrossAttention,
+    FlashSelfAttention,
+    SelfAttention,
+    _update_kv_cache,
+)
+
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -107,9 +117,17 @@ def initialize_model():
 
 
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+    from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
+    RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
         transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy, transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
+            transformer_auto_wrap_policy, transformer_layer_cls={
+                PackedFlashBaseLayer1D,
+                PackedFlashInternLm1D,
+                MHA,
+                FlashCrossAttention,
+                FlashSelfAttention,
+                RMSNorm}
         )
         grp = gpc.get_group(ParallelMode.ZERO1)
         model = FSDP(