From 5bca32e4dc424a69b3ad62bad86cdac3e66c1cae Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Mon, 9 Oct 2023 11:11:04 +0800
Subject: [PATCH] fix(internlm/train/training_internlm.py): update wrap class
 and fix lint error

---
 internlm/train/training_internlm.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 154ec04..53a5711 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -8,6 +8,8 @@ from typing import Callable, Iterable, Union
 
 import torch
 import torch.distributed as dist
+from flash_attn.modules.embedding import ParallelGPT2Embeddings
+from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -17,7 +19,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.random import set_mode
 from internlm.core.naive_amp import NaiveAMPModel
@@ -32,9 +34,11 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
-from internlm.model.modeling_internlm import (
-    PackedFlashBaseLayer1D,
-    PackedFlashInternLm1D,
+from internlm.model.embedding import Embedding1D
+from internlm.model.linear import (
+    FeedForward,
+    RewardModelLinear,
+    ScaleColumnParallelLinear,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import try_import_RMSNorm
@@ -52,6 +56,7 @@ from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout
 
+RMSNorm = try_import_RMSNorm()
 logger = get_logger(__file__)
 
@@ -106,12 +111,20 @@ def initialize_model():
 
 
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
-    RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
-            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D, MHA, RMSNorm},
+            transformer_layer_cls={
+                Embedding1D,
+                ParallelGPT2Embeddings,
+                MHA,
+                RMSNorm,
+                FeedForward,
+                ParallelFusedMLP,
+                RewardModelLinear,
+                ScaleColumnParallelLinear,
+            },
         )
 
         # wrap the model
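
Note (not part of the patch): a minimal sketch of how the expanded transformer_layer_cls set above is consumed by FSDP's transformer_auto_wrap_policy. The stand-in layer classes (nn.Embedding, nn.TransformerEncoderLayer) and the build_fsdp_model helper are illustrative substitutes for the InternLM/flash_attn classes listed in the patch, so the snippet stays runnable without those packages; it assumes torch.distributed has been initialized (e.g. via torchrun) before FSDP construction.

import functools

import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def build_fsdp_model(model: nn.Module) -> FSDP:
    # transformer_layer_cls names every module class that becomes its own FSDP
    # unit; the patch widens this set from whole decoder blocks to the
    # finer-grained embedding / attention / MLP / norm / output-head classes.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={nn.Embedding, nn.TransformerEncoderLayer},
    )
    return FSDP(model, auto_wrap_policy=wrap_policy)


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")  # assumes launch via torchrun
    model = nn.Sequential(
        nn.Embedding(1000, 64),
        nn.TransformerEncoderLayer(d_model=64, nhead=4),
        nn.Linear(64, 1000),
    )
    fsdp_model = build_fsdp_model(model)

Listing the leaf classes directly (rather than the composite PackedFlash* blocks) lets the auto-wrap policy shard each submodule's parameters as a separate FSDP unit, which is what the replaced transformer_layer_cls line in wrap_FSDP_model achieves.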