fix(internlm/train/training_internlm.py): update wrap class and fix lint error

pull/293/head
huangting4201 2023-10-09 11:11:04 +08:00
parent 2e94870967
commit 5bca32e4dc
1 changed file with 19 additions and 6 deletions


@@ -8,6 +8,8 @@ from typing import Callable, Iterable, Union
 
 import torch
 import torch.distributed as dist
+from flash_attn.modules.embedding import ParallelGPT2Embeddings
+from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -17,7 +19,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.random import set_mode
 from internlm.core.naive_amp import NaiveAMPModel
@@ -32,9 +34,11 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
-from internlm.model.modeling_internlm import (
-    PackedFlashBaseLayer1D,
-    PackedFlashInternLm1D,
+from internlm.model.embedding import Embedding1D
+from internlm.model.linear import (
+    FeedForward,
+    RewardModelLinear,
+    ScaleColumnParallelLinear,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import try_import_RMSNorm
@@ -52,6 +56,7 @@ from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout
 
+RMSNorm = try_import_RMSNorm()
 
 logger = get_logger(__file__)
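
Note on the new module-level RMSNorm = try_import_RMSNorm(): hoisting the lookup out of wrap_FSDP_model lets the class be referenced directly in the wrap policy below. As a rough sketch only, this is the try-import-with-fallback pattern such a helper typically follows; the fused-kernel import path, the helper name try_import_rmsnorm_sketch, and the fallback class body are illustrative assumptions, not the actual code in internlm.model.utils:

# Illustrative only: the real try_import_RMSNorm lives in internlm.model.utils
# and may select a different fused kernel; this shows the fallback pattern.
import torch
from torch import nn


class _NaiveRMSNorm(nn.Module):
    """Pure-PyTorch fallback used when no fused RMSNorm kernel is available."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Root-mean-square normalization followed by a learned scale.
        norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return norm * self.weight


def try_import_rmsnorm_sketch():
    """Return a fused RMSNorm class if its extension is installed, else the fallback."""
    try:
        # Hypothetical fused-kernel import; the real helper may rely on apex or flash-attn.
        from apex.normalization import FusedRMSNorm  # type: ignore

        return FusedRMSNorm
    except ImportError:
        return _NaiveRMSNorm

Resolving the class once at import time also means the same RMSNorm symbol is seen by both the model code and the FSDP wrap policy in this file.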
@@ -106,12 +111,20 @@ def initialize_model():
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
-    RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
-            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D, MHA, RMSNorm},
+            transformer_layer_cls={
+                Embedding1D,
+                ParallelGPT2Embeddings,
+                MHA,
+                RMSNorm,
+                FeedForward,
+                ParallelFusedMLP,
+                RewardModelLinear,
+                ScaleColumnParallelLinear,
+            },
         )
         # wrap the model
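
For context on how a wrap policy built this way is consumed, below is a minimal, self-contained sketch showing a functools.partial of transformer_auto_wrap_policy being passed to FSDP via auto_wrap_policy. The ToyBlock class, build_fsdp_model helper, and single-layer-class policy are assumptions for illustration, not InternLM's actual wrapping code:

# Minimal sketch (not InternLM's code) of the transformer_auto_wrap_policy + FSDP flow.
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyBlock(nn.Module):
    """Stand-in for a transformer layer class listed in transformer_layer_cls."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def build_fsdp_model() -> FSDP:
    # Every submodule whose class appears in transformer_layer_cls becomes its
    # own FSDP unit; remaining parameters fall into the root FSDP wrapper.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={ToyBlock},
    )
    model = nn.Sequential(ToyBlock(), ToyBlock())
    # Requires torch.distributed to be initialized first (e.g. via torchrun).
    return FSDP(model, auto_wrap_policy=wrap_policy)

Because each class in transformer_layer_cls becomes its own FSDP unit, the commit enlarges the set from the coarse PackedFlash* blocks to the individual embedding, attention, MLP, norm, and output-head classes used by the model.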