mirror of https://github.com/InternLM/InternLM
more wrap
parent
c703938fb3
commit
80f1eb9a36
|
@ -36,6 +36,16 @@ from internlm.model.modeling_internlm import (
|
||||||
PackedFlashBaseLayer1D,
|
PackedFlashBaseLayer1D,
|
||||||
PackedFlashInternLm1D,
|
PackedFlashInternLm1D,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from internlm.model.multi_head_attention import MHA
|
||||||
|
from flash_attn.modules.mha import (
|
||||||
|
CrossAttention,
|
||||||
|
FlashCrossAttention,
|
||||||
|
FlashSelfAttention,
|
||||||
|
SelfAttention,
|
||||||
|
_update_kv_cache,
|
||||||
|
)
|
||||||
|
|
||||||
from internlm.monitor import send_heartbeat, set_env_var
|
from internlm.monitor import send_heartbeat, set_env_var
|
||||||
from internlm.monitor.monitor import monitor_manager as mm
|
from internlm.monitor.monitor import monitor_manager as mm
|
||||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||||
|
@ -107,9 +117,17 @@ def initialize_model():
|
||||||
|
|
||||||
|
|
||||||
def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
|
def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
|
||||||
|
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
|
||||||
|
RMSNorm = try_import_RMSNorm()
|
||||||
if gpc.config.parallel.use_fsdp:
|
if gpc.config.parallel.use_fsdp:
|
||||||
transformer_wrap_policy = functools.partial(
|
transformer_wrap_policy = functools.partial(
|
||||||
transformer_auto_wrap_policy, transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D}
|
transformer_auto_wrap_policy, transformer_layer_cls={
|
||||||
|
PackedFlashBaseLayer1D,
|
||||||
|
PackedFlashInternLm1D,
|
||||||
|
MHA,
|
||||||
|
FlashCrossAttention,
|
||||||
|
FlashSelfAttention,
|
||||||
|
RMSNorm}
|
||||||
)
|
)
|
||||||
grp = gpc.get_group(ParallelMode.ZERO1)
|
grp = gpc.get_group(ParallelMode.ZERO1)
|
||||||
model = FSDP(
|
model = FSDP(
|
||||||
|
|
Loading…
Reference in New Issue