feat(initialize/launch.py): refactor config for fstp

pull/456/head
huangting4201 2023-10-20 15:59:40 +08:00
parent 815a584930
commit d91a5d9d9e
7 changed files with 63 additions and 44 deletions

View File

@ -152,19 +152,19 @@ zero1 parallel (dict):
2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
tensor parallel (dict): tensor parallel (dict):
1. size: int, the size of tensor parallel. 1. size: int, the size of tensor parallel.
2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp', 2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'],
the sequence_parallel should be True. defaults to 'none', means the sequence parallel will be disabled.
3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using 'intern' mode sp,
defaults to False.
pipeline parallel (dict): pipeline parallel (dict):
1. size: int, the size of pipeline parallel. 1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False. defaults to False.
sequence parallel (bool): enable/disable sequence parallel, defaults to False.
""" """
parallel = dict( parallel = dict(
zero1=dict(size=-1, fsdp=False), zero1=dict(size=-1, fsdp=False),
tensor=dict(size=8, mode="fstp", overlap=True), tensor=dict(size=8, sp="intern", intern_overlap=True),
pipeline=dict(size=1, interleaved_overlap=True), pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True,
) )
cudnn_deterministic = False cudnn_deterministic = False

View File

@ -306,15 +306,20 @@ def args_sanity_check():
), "sequence parallel does not support use_flash_attn=False" ), "sequence parallel does not support use_flash_attn=False"
if isinstance(gpc.config.parallel["tensor"], int): if isinstance(gpc.config.parallel["tensor"], int):
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp") gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False)
if gpc.config.parallel["tensor"].get("sp", None) is None:
if gpc.config.parallel["tensor"].get("mode", None) is None: gpc.config.parallel["tensor"]["sp"] = "none"
gpc.config.parallel["tensor"]["mode"] = "origin_tp" if gpc.config.parallel["tensor"].get("intern_overlap", None) is None:
gpc.config.parallel["tensor"]["intern_overlap"] = False
if gpc.config.parallel["tensor"].get("mode", None) == "fstp": assert gpc.config.parallel["tensor"].get("sp", None) in [
assert ( "none",
gpc.config.parallel.sequence_parallel is True "megatron",
), "when the tp_mode is fstp, the sequence_parallel should be True." "flash-attn",
"intern",
], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported"
# adapt to old version's sequence parallel config
if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]:
gpc.config.parallel.sequence_parallel = True
# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1: if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:

View File

@ -77,7 +77,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_scaled_init: bool = True, use_scaled_init: bool = True,
use_swiglu: bool = True, use_swiglu: bool = True,
use_flash_attn: bool = True, use_flash_attn: bool = True,
tp_mode: str = "origin_tp", sp_mode: str = "none",
): ):
super().__init__() super().__init__()
self.checkpoint = checkpoint self.checkpoint = checkpoint
@ -102,7 +102,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_flash_attn=use_flash_attn, use_flash_attn=use_flash_attn,
device=device, device=device,
dtype=dtype, dtype=dtype,
tp_mode=tp_mode, sp_mode=sp_mode,
) )
self.dropout1 = nn.Dropout(drop_rate) self.dropout1 = nn.Dropout(drop_rate)
@ -114,7 +114,7 @@ class PackedFlashBaseLayer1D(nn.Module):
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
if use_swiglu: if use_swiglu:
mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward mlp_cls = FSTPFeedForward if sp_mode == "intern" else FeedForward
self.mlp = mlp_cls( self.mlp = mlp_cls(
hidden_size, hidden_size,
int(hidden_size * mlp_ratio), int(hidden_size * mlp_ratio),
@ -297,7 +297,7 @@ class PackedFlashInternLm1D(nn.Module):
super().__init__() super().__init__()
checkpoint_layer_num = int(num_layers * checkpoint) checkpoint_layer_num = int(num_layers * checkpoint)
self.tp_mode = gpc.config.parallel["tensor"]["mode"] self.sp_mode = gpc.config.parallel["tensor"]["sp"]
if is_reward: if is_reward:
head_cls = RewardModelLinear head_cls = RewardModelLinear
@ -343,7 +343,7 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init=use_scaled_init, use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu, use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn, use_flash_attn=use_flash_attn,
tp_mode=self.tp_mode, sp_mode=self.sp_mode,
) )
for lid in range(num_layers) for lid in range(num_layers)
] ]
@ -389,8 +389,8 @@ class PackedFlashInternLm1D(nn.Module):
assert len(indexes) == 1 assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input. # The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0] indexes = indexes[0]
# if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension. # if the sequence parallel mode is 'intern', the indexes should also be split in sequence dimension.
if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp": if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

View File

@ -175,7 +175,7 @@ class MHA(nn.Module):
use_flash_attn: bool = True, use_flash_attn: bool = True,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
tp_mode: str = "origin_tp", sp_mode: str = "none",
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
@ -203,7 +203,7 @@ class MHA(nn.Module):
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device) self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
# notice here should change bias=True # notice here should change bias=True
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear Wqkv_cls = FSTPLinear if sp_mode == "intern" else ColumnParallelLinearTorch
self.Wqkv = Wqkv_cls( self.Wqkv = Wqkv_cls(
embed_dim, embed_dim,
3 * embed_dim, 3 * embed_dim,
@ -219,12 +219,12 @@ class MHA(nn.Module):
self.inner_cross_attn = inner_cross_attn_cls( self.inner_cross_attn = inner_cross_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
) )
if tp_mode == "fstp": if sp_mode == "intern":
self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group) self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group) self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)
# output projection always have the bias (for now) # output projection always have the bias (for now)
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear out_proj_cls = FSTPLinear if sp_mode == "intern" else RowParallelLinearTorch
self.out_proj = out_proj_cls( self.out_proj = out_proj_cls(
embed_dim, embed_dim,
embed_dim, embed_dim,

View File

@ -10,7 +10,10 @@ from torch.optim import Optimizer
from internlm.core.context import Config, ParallelMode from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc from internlm.core.context import global_context as gpc
from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool from internlm.model.utils import (
release_reduce_scatter_memory_pool,
split_forward_gather_backward,
)
from internlm.monitor import send_alert_message from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import ( from internlm.solver.optimizer.store import (
BucketStore, BucketStore,
@ -40,8 +43,20 @@ from .utils import compute_norm
inf = math.inf inf = math.inf
logger = get_logger(__file__) logger = get_logger(__file__)
def print_memory(msg): def print_memory(msg):
print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True) print(
msg,
" rank = ",
gpc.get_global_rank(),
" memory allocated: ",
torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
" reverved memory: ",
torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
" max memory: ",
torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
flush=True,
)
print("===========================================") print("===========================================")
@ -70,7 +85,7 @@ class HybridZeroOptimizer(BaseOptimizer):
hysteresis = grad_scal_cfg.hysteresis hysteresis = grad_scal_cfg.hysteresis
max_scale = grad_scal_cfg.max_scale max_scale = grad_scal_cfg.max_scale
if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
self._fstp_handler = gpc.config.fstp_handler self._fstp_handler = gpc.config.fstp_handler
# Zero related args # Zero related args

View File

@ -110,9 +110,8 @@ def initialize_model():
gpc.config.fstp_handler = None gpc.config.fstp_handler = None
if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
# handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler._register_sync_parameters_hook() handler._register_sync_parameters_hook()
gpc.config.fstp_handler = handler gpc.config.fstp_handler = handler

View File

@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
def switch_sequence_parallel_mode(): def switch_sequence_parallel_mode():
prev_mode = gpc.config.parallel.sequence_parallel prev_mode = gpc.config.parallel.sequence_parallel
try: try:
if gpc.config.parallel["tensor"]["mode"] == "fstp": if gpc.config.parallel["tensor"]["sp"] == "intern":
gpc.config.parallel.sequence_parallel = True gpc.config.parallel.sequence_parallel = True
else: else:
gpc.config.parallel.sequence_parallel = False gpc.config.parallel.sequence_parallel = False
@ -106,7 +106,7 @@ def evaluate_on_val_dls(
total_val_bsz = len(batch[1]) total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0 assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz num_microbatches = total_val_bsz // data_cfg.micro_bsz
if gpc.config.parallel["tensor"]["mode"] == "fstp": if gpc.config.parallel["tensor"]["sp"] == "intern":
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR) sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = torch.Size( tensor_shape = torch.Size(
[ [