mirror of https://github.com/InternLM/InternLM
fix(internlm/train/training_internlm.py): update wrap class and fix lint error
parent 2e94870967
commit 5bca32e4dc
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -8,6 +8,8 @@ from typing import Callable, Iterable, Union
 
 import torch
 import torch.distributed as dist
+from flash_attn.modules.embedding import ParallelGPT2Embeddings
+from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -17,7 +19,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import ConcatDataset, DataLoader
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.random import set_mode
 from internlm.core.naive_amp import NaiveAMPModel
@@ -32,9 +34,11 @@ from internlm.data.packed_dataset import (
     get_packed_dataset_without_short_length,
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
-from internlm.model.modeling_internlm import (
-    PackedFlashBaseLayer1D,
-    PackedFlashInternLm1D,
+from internlm.model.embedding import Embedding1D
+from internlm.model.linear import (
+    FeedForward,
+    RewardModelLinear,
+    ScaleColumnParallelLinear,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import try_import_RMSNorm
@@ -52,6 +56,7 @@ from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout
 
+RMSNorm = try_import_RMSNorm()
 logger = get_logger(__file__)
 
 
@@ -106,12 +111,20 @@ def initialize_model():
 
 
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
-    RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
-            transformer_layer_cls={PackedFlashBaseLayer1D, PackedFlashInternLm1D, MHA, RMSNorm},
+            transformer_layer_cls={
+                Embedding1D,
+                ParallelGPT2Embeddings,
+                MHA,
+                RMSNorm,
+                FeedForward,
+                ParallelFusedMLP,
+                RewardModelLinear,
+                ScaleColumnParallelLinear,
+            },
         )
 
         # wrap the model
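The commit swaps the two monolithic InternLM block classes in transformer_layer_cls for the individual embedding, attention, norm, MLP, and output-projection classes, so FSDP wraps each of those module types as its own sharded unit. Below is a minimal, self-contained sketch of the same transformer_auto_wrap_policy pattern; ToyBlock, ToyModel, and the single-process gloo setup are illustrative assumptions and not part of InternLM or this commit.

# Minimal sketch (not InternLM code): how transformer_auto_wrap_policy picks
# which submodule classes become their own FSDP units, mirroring the
# transformer_layer_cls change in the diff above.
import functools
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyBlock(nn.Module):
    """Stand-in for a transformer layer (MHA / FeedForward in the real diff)."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))


class ToyModel(nn.Module):
    """Stand-in for the full model (embedding + blocks + head)."""

    def __init__(self, dim: int = 64, depth: int = 4):
        super().__init__()
        self.embed = nn.Linear(dim, dim)
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(depth))
        self.head = nn.Linear(dim, dim)

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        return self.head(x)


def main():
    # Single-process gloo group so the sketch runs on CPU without a launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    # Same pattern as the commit: list every class that should be wrapped as
    # its own FSDP unit (the real set includes Embedding1D, MHA, RMSNorm,
    # FeedForward, ParallelFusedMLP, RewardModelLinear, ...).
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={ToyBlock},
    )
    model = FSDP(ToyModel(), auto_wrap_policy=wrap_policy)

    # Each ToyBlock now appears as a nested FSDP instance; the forward pass
    # gathers and frees its shards independently of the other blocks.
    out = model(torch.randn(2, 64))
    print(model)
    print(out.shape)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Listing each class separately, as the commit does, gives FSDP finer-grained wrap units than wrapping only whole decoder layers, which is what the wider transformer_layer_cls set in the diff achieves.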