more wrap

pull/293/head
zaglc 2023-09-27 17:35:28 +08:00
parent c703938fb3
commit 80f1eb9a36
1 changed file with 19 additions and 1 deletion


@@ -36,6 +36,16 @@ from internlm.model.modeling_internlm import (
     PackedFlashBaseLayer1D,
     PackedFlashInternLm1D,
 )
+from internlm.model.multi_head_attention import MHA
+from flash_attn.modules.mha import (
+    CrossAttention,
+    FlashCrossAttention,
+    FlashSelfAttention,
+    SelfAttention,
+    _update_kv_cache,
+)
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -107,9 +117,17 @@ def initialize_model():
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+    from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
+
+    RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
         transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy, transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={
+                PackedFlashBaseLayer1D,
+                PackedFlashInternLm1D,
+                MHA,
+                FlashCrossAttention,
+                FlashSelfAttention,
+                RMSNorm,
+            },
         )
         grp = gpc.get_group(ParallelMode.ZERO1)
         model = FSDP(