mirror of https://github.com/InternLM/InternLM
fix(internlm/train/training_internlm.py): update wrap class and fix lint error
parent 2e94870967
commit 5bca32e4dc
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -8,6 +8,8 @@ from typing import Callable, Iterable, Union
 
 import torch
 import torch.distributed as dist
+from flash_attn.modules.embedding import ParallelGPT2Embeddings
+from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -17,7 +19,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.random import set_mode
 from internlm.core.naive_amp import NaiveAMPModel
@@ -32,9 +34,11 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
-from internlm.model.modeling_internlm import (
-    PackedFlashBaseLayer1D,
-    PackedFlashInternLm1D,
+from internlm.model.embedding import Embedding1D
+from internlm.model.linear import (
+    FeedForward,
+    RewardModelLinear,
+    ScaleColumnParallelLinear,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import try_import_RMSNorm
@@ -52,6 +56,7 @@ from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout
 
+RMSNorm = try_import_RMSNorm()
 logger = get_logger(__file__)
 
 
@@ -106,12 +111,20 @@ def initialize_model():
 
 
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
-    RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
-            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D, MHA, RMSNorm},
+            transformer_layer_cls={
+                Embedding1D,
+                ParallelGPT2Embeddings,
+                MHA,
+                RMSNorm,
+                FeedForward,
+                ParallelFusedMLP,
+                RewardModelLinear,
+                ScaleColumnParallelLinear,
+            },
         )
 
         # wrap the model
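
For context, a minimal sketch (not InternLM code) of how a policy built this way is consumed by FSDP: every module whose class appears in transformer_layer_cls becomes its own FSDP unit, which is why the commit lists embedding, attention, MLP, norm and output-head classes rather than only whole PackedFlash layers. TinyBlock, TinyModel and wrap_with_fsdp below are hypothetical stand-ins for the InternLM classes named in the diff.

# Minimal sketch of the transformer_auto_wrap_policy pattern used above.
import functools

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class TinyBlock(nn.Module):
    # Hypothetical stand-in for the wrapped classes in the diff (Embedding1D, MHA, ...).
    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class TinyModel(nn.Module):
    def __init__(self, dim: int = 64, num_layers: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList([TinyBlock(dim) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x


def wrap_with_fsdp(model: nn.Module) -> FSDP:
    # Same pattern as the diff: partially apply transformer_auto_wrap_policy with the
    # set of module classes that should each be sharded as a separate FSDP unit.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TinyBlock},
    )
    return FSDP(model, auto_wrap_policy=wrap_policy)


if __name__ == "__main__":
    # FSDP needs an initialized process group; gloo/CPU is used here only so the sketch
    # runs without GPUs (e.g. `torchrun --nproc_per_node=1 sketch.py`). Real training
    # would use the nccl backend on CUDA devices.
    dist.init_process_group(backend="gloo")
    sharded = wrap_with_fsdp(TinyModel())
    print(sharded)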