mirror of https://github.com/InternLM/InternLM
more wrap

parent c703938fb3
commit 80f1eb9a36
@@ -36,6 +36,16 @@ from internlm.model.modeling_internlm import (
     PackedFlashBaseLayer1D,
     PackedFlashInternLm1D,
 )
+
+from internlm.model.multi_head_attention import MHA
+from flash_attn.modules.mha import (
+    CrossAttention,
+    FlashCrossAttention,
+    FlashSelfAttention,
+    SelfAttention,
+    _update_kv_cache,
+)
+
 from internlm.monitor import send_heartbeat, set_env_var
 from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -107,9 +117,17 @@ def initialize_model():
 
 
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+    from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
+    RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
         transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy, transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
+            transformer_auto_wrap_policy, transformer_layer_cls={
+                PackedFlashBaseLayer1D,
+                PackedFlashInternLm1D,
+                MHA,
+                FlashCrossAttention,
+                FlashSelfAttention,
+                RMSNorm}
         )
         grp = gpc.get_group(ParallelMode.ZERO1)
         model = FSDP(
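For context, the pattern this commit extends is PyTorch FSDP's `transformer_auto_wrap_policy`: the policy is partially applied with a set of module classes, and FSDP then turns every submodule of one of those classes into its own wrapping unit. The sketch below is a minimal, self-contained illustration of that mechanism only; `ToyAttention`, `ToyBlock`, `toy_wrap_policy`, and `wrap_model` are placeholder names, not InternLM code, and the real commit passes InternLM's `PackedFlashBaseLayer1D`, `PackedFlashInternLm1D`, `MHA`, `FlashCrossAttention`, `FlashSelfAttention`, and `RMSNorm` instead.

```python
# Minimal sketch of FSDP's class-based auto-wrap policy.
# ToyAttention/ToyBlock/wrap_model are illustrative placeholders, not InternLM code.
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return self.proj(attn @ v)


class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = ToyAttention(dim)

    def forward(self, x):
        return x + self.attn(self.norm(x))


# The policy is partially applied with the set of module classes that should each
# become their own FSDP unit, mirroring how the diff adds the attention and norm
# classes next to the transformer layer classes.
toy_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={ToyBlock, ToyAttention, nn.LayerNorm},
)


def wrap_model(model: nn.Module) -> FSDP:
    # Assumes torch.distributed is already initialized (e.g. via torchrun);
    # FSDP walks the module tree and wraps every class named in the policy.
    return FSDP(model, auto_wrap_policy=toy_wrap_policy)


if __name__ == "__main__":
    # The policy itself can be inspected without a process group: it is the
    # predicate FSDP evaluates per module ("should this become its own unit?").
    print(toy_wrap_policy(ToyBlock(64), False, 0))     # True: class is in the set
    print(toy_wrap_policy(nn.Linear(8, 8), False, 0))  # False: class is not in the set
```

Calling the policy directly, as in the `__main__` block, only exercises the predicate; actually constructing the `FSDP` wrapper via `wrap_model` additionally requires an initialized `torch.distributed` process group, which is why the repo's `wrap_FSDP_model` pulls its group from `gpc.get_group(ParallelMode.ZERO1)`.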