feat(*): support sequence_parallel (#180)

* support sequence_parallel for the non-pipeline scheduler

* sequence_parallel is not supported without flash-attn (use_flash_attn must be True)

* support sequence parallel for the pipeline scheduler

* add memory profiler

* Update 13B.py

* add memory profiler

* fix evaluation bug

* remove some unnecessary code

* remove some unnecessary code

* Update parallel_context.py

* modify the config

* remove memory profiler

* modify the config

* support selective dropout
ytxiong 2023-08-07 16:42:52 +08:00 committed by GitHub
parent 853becfb6e
commit c219065348
11 changed files with 95 additions and 37 deletions


@@ -118,6 +118,7 @@ model = dict(
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
     num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
+    sequence_parallel=False,
 )
 """
 zero1 parallel:
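
For reference, a minimal sketch of the model config with the new option turned on, limited to the fields shown in the hunk above; sequence_parallel=True additionally requires use_flash_attn=True (see the args_sanity_check hunk below):

model = dict(
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,  # required when sequence_parallel=True
    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
    sequence_parallel=True,
)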


@@ -464,7 +464,6 @@ class ParallelContext(metaclass=SingletonMeta):
             initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
         if self.pipeline_parallel_size > 1:
             initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
-
         for initializer in initializers:
             parallel_setting = initializer.init_dist_group()
             if isinstance(parallel_setting, list):


@@ -30,10 +30,17 @@ def get_tensor_shape():
     if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
         if gpc.config.model.use_flash_attn:
-            tensor_shape = (
-                gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
-                gpc.config.HIDDEN_SIZE,
-            )
+            if gpc.config.model.sequence_parallel:
+                sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
+                tensor_shape = (
+                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
+                    gpc.config.HIDDEN_SIZE,
+                )
+            else:
+                tensor_shape = (
+                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
+                    gpc.config.HIDDEN_SIZE,
+                )
         else:
             tensor_shape = (
                 gpc.config.data["micro_bsz"],
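
A worked example of the shape change above, using hypothetical numbers: with sequence parallel enabled, the tensor exchanged between pipeline stages shrinks along the flattened token dimension by the tensor-parallel world size.

# Hypothetical values, only to illustrate the arithmetic in get_tensor_shape().
SEQ_LEN, micro_bsz, HIDDEN_SIZE = 2048, 2, 5120
sequence_world_size = 4  # tensor-parallel world size

dense_shape = (SEQ_LEN * micro_bsz, HIDDEN_SIZE)                      # (4096, 5120) without sequence parallel
sp_shape = (SEQ_LEN * micro_bsz // sequence_world_size, HIDDEN_SIZE)  # (1024, 5120) with sequence parallel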
@@ -132,6 +139,9 @@ class PipelineScheduler(BaseScheduler):
             and gpc.is_initialized(ParallelMode.TENSOR)
             and gpc.get_world_size(ParallelMode.TENSOR) > 1
         )
+        if gpc.config.model.sequence_parallel:
+            self.scatter_gather_tensors = False
+
         # cache for the batch data
         self.batch_data = None
@@ -254,7 +264,6 @@ class PipelineScheduler(BaseScheduler):
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             self._call_hooks("post_helper_func", output_obj, label)
-
             if return_output_label:
                 return_tensors.append((output_obj, label))
             if accum_loss is not None:


@@ -38,6 +38,7 @@ def get_default_parser():
     parser.add_argument("--local_rank", type=int, help="local rank on the node")
     parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
     parser.add_argument("--seed", type=int, default=1024)
+    parser.add_argument("--profiling", default=True, action="store_true", help="enable/disable profiling.")
     return parser
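
One behavioral note on the new flag: because default=True is combined with action="store_true", argparse reports profiling as True whether or not --profiling is passed on the command line. A small standalone sketch of that behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--profiling", default=True, action="store_true")

print(parser.parse_args([]).profiling)               # True, even without the flag
print(parser.parse_args(["--profiling"]).profiling)  # True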
@@ -198,6 +199,10 @@ def args_sanity_check():
     # process the model config
     if "use_flash_attn" not in gpc.config.model:
         gpc.config.model._add_item("use_flash_attn", True)
 
+    if "sequence_parallel" not in gpc.config.model:
+        gpc.config.model._add_item("sequence_parallel", False)
+    else:
+        assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
 
 
 def launch(
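
A small illustration of when the new check fires, assuming for the sketch that gpc.config.model behaves like a plain dict:

model_cfg = {"use_flash_attn": False, "sequence_parallel": True}

# Mirrors the assertion above: this combination is rejected with an AssertionError.
assert not (model_cfg["sequence_parallel"] is True and model_cfg["use_flash_attn"] is False), \
    "sequence parallel does not support use_flash_attn=False"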


@@ -13,7 +13,7 @@ from torch import Tensor, nn
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from .utils import gather_forward_split_backward
+from .utils import gather_forward_split_backward, split_forward_gather_backward
 
 
 class Embedding1D(nn.Module):
@@ -55,7 +55,10 @@ class Embedding1D(nn.Module):
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
 
         output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
 
+        if gpc.config.model.sequence_parallel:
+            output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
+
         return output
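
A toy sketch of what the dim=1 split above does to the embedding output, assuming each tensor-parallel rank keeps one contiguous chunk of the sequence dimension (sizes are hypothetical):

import torch

tp_world_size, rank = 4, 1                        # hypothetical tensor-parallel group of 4, this rank is 1
output = torch.randn(2, 2048, 4096)               # (batch, seq_len, hidden) after the dim=-1 gather
local = output.chunk(tp_world_size, dim=1)[rank]  # (2, 512, 4096): this rank's sequence shard
print(local.shape)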


@@ -38,7 +38,6 @@ class ScaleColumnParallelLinear(nn.Linear):
         out_features: int,
         process_group: Optional[torch.distributed.ProcessGroup],
         bias: bool = True,
-        sequence_parallel: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         weight_scale: int = 1,
@@ -48,7 +47,6 @@ class ScaleColumnParallelLinear(nn.Linear):
             raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
         super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
         self.process_group = process_group
-        self.sequence_parallel = sequence_parallel
         self.weight_scale = weight_scale
 
     def forward(self, input):  # pylint: disable=W0622
@@ -60,7 +58,7 @@ class ScaleColumnParallelLinear(nn.Linear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
+            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
         )
@@ -87,12 +85,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         out_features: int,
         process_group: Optional[torch.distributed.ProcessGroup],
         bias: bool = True,
-        sequence_parallel: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         weight_scale: int = 1,
     ) -> None:
-        super().__init__(in_features, out_features, process_group, bias, sequence_parallel, device, dtype, weight_scale)
+        super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
         torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
         if bias:
             torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
@@ -106,7 +103,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
         else:
             weight = self.weight
         return fused_dense_func_torch(
-            input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
+            input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
         )
@@ -168,19 +165,19 @@ class FeedForward(nn.Module):
             hidden_features,
             process_group,
             bias,
-            sequence_parallel=False,
+            sequence_parallel=gpc.config.model.sequence_parallel,
             device=device,
             dtype=dtype,
         )
         self.w2 = ColumnParallelLinearTorch(
-            in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
+            in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
         )
         self.w3 = RowParallelLinearTorch(
             hidden_features,
             out_features,
             process_group,
             bias=bias,
-            sequence_parallel=False,
+            sequence_parallel=gpc.config.model.sequence_parallel,
             device=device,
             dtype=dtype,
         )
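
The sequence_parallel flag forwarded to ColumnParallelLinearTorch / RowParallelLinearTorch is assumed here to follow the Megatron-style scheme used by flash-attn's fused dense ops: inputs arrive sharded along the sequence dimension and are all-gathered before the column-parallel matmul, and the row-parallel output is reduce-scattered instead of all-reduced. A rough sketch of that last step, not the actual implementation:

import torch
import torch.distributed as dist

def combine_row_parallel_output(partial_out, group, sequence_parallel):
    # Each rank holds a partial sum computed from its shard of the weight.
    world_size = dist.get_world_size(group)
    if sequence_parallel:
        # Keep only this rank's sequence shard: reduce-scatter instead of all-reduce.
        shard = torch.empty_like(partial_out.chunk(world_size, dim=0)[0])
        dist.reduce_scatter_tensor(shard, partial_out.contiguous(), group=group)
        return shard
    dist.all_reduce(partial_out, group=group)
    return partial_out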


@@ -89,7 +89,6 @@ class PackedFlashBaseLayer1D(nn.Module):
             rotary_emb_dim=head_dim,
             rotary_emb_scale_base=0,
             use_flash_attn=use_flash_attn,
-            sequence_parallel=False,
             device=device,
             dtype=dtype,
         )
@@ -121,7 +120,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             process_group=gpc.get_group(ParallelMode.TENSOR),
             bias1=False,
             bias2=False,
-            sequence_parallel=False,
+            sequence_parallel=gpc.config.model.sequence_parallel,
             checkpoint_lvl=0,
             heuristic="auto",
             device=device,
@@ -300,7 +299,7 @@ class PackedFlashInternLm1D(nn.Module):
                     max_position_embeddings=-1,
                     process_group=gpc.get_group(ParallelMode.TENSOR),
                     padding_idx=None,
-                    sequence_parallel=False,
+                    sequence_parallel=gpc.config.model.sequence_parallel,
                     device=device,
                     dtype=dtype,
                 )
@@ -342,7 +341,6 @@ class PackedFlashInternLm1D(nn.Module):
                 out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
                 process_group=gpc.get_group(ParallelMode.TENSOR),
                 bias=False,
-                sequence_parallel=False,
                 device=device,
                 dtype=dtype,
                 weight_scale=embed_grad_scale,
@@ -463,6 +461,7 @@ def build_model_with_cfg(
     use_scaled_init: bool = True,
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
+    sequence_parallel: bool = False,
 ):
     """
     Builde model with config


@@ -59,7 +59,6 @@ class MHA(nn.Module):
         rotary_emb_dim: int = 0,
         rotary_emb_scale_base: int = 0,
         use_flash_attn: bool = True,
-        sequence_parallel: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ) -> None:
@@ -83,7 +82,7 @@ class MHA(nn.Module):
             3 * embed_dim,
             process_group,
             bias=True,
-            sequence_parallel=sequence_parallel,
+            sequence_parallel=gpc.config.model.sequence_parallel,
             **factory_kwargs,
         )  # according to https://spaces.ac.cn/archives/9577
@@ -96,7 +95,7 @@ class MHA(nn.Module):
         # output projection always have the bias (for now)
         self.out_proj = RowParallelLinearTorch(
-            embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
+            embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs
         )
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:


@@ -157,6 +157,35 @@ def fused_dense_func_torch(
     else:
         return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
 
 
+class _SplitForwardGatherBackward(torch.autograd.Function):
+    """
+    Split the input and keep only the corresponding chunk for this rank.
+
+    Args:
+        input_: input matrix.
+        parallel_mode: parallel mode.
+        dim: dimension
+    """
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
+    @staticmethod
+    def forward(ctx, input_, parallel_mode, dim):
+        ctx.mode = parallel_mode
+        ctx.dim = dim
+        return _split(input_, parallel_mode, dim)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather(grad_output, ctx.mode, ctx.dim), None, None
+
+
+def split_forward_gather_backward(input_, parallel_mode, dim):
+    return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
+
+
 def try_import_RMSNorm():
     """
     Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
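
_split and _gather are referenced above but are not part of this diff; the following is only a plausible sketch of their semantics (keep one chunk along dim in forward, all-gather the chunks back in backward), written against torch.distributed directly rather than the real ParallelMode-based helpers:

import torch
import torch.distributed as dist

def _split_sketch(input_, group, dim):
    # Keep only this rank's chunk along `dim` (assumes the size is evenly divisible).
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return input_
    rank = dist.get_rank(group)
    return input_.chunk(world_size, dim=dim)[rank].contiguous()

def _gather_sketch(input_, group, dim):
    # All-gather the per-rank chunks and concatenate them back along `dim`.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return input_
    parts = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(parts, input_.contiguous(), group=group)
    return torch.cat(parts, dim=dim)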


@@ -90,9 +90,15 @@ def evaluate_on_val_dls(
                 total_val_bsz = len(batch[1])
                 assert total_val_bsz % data_cfg.micro_bsz == 0
                 num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                tensor_shape = torch.Size(
-                    [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                )
+                if gpc.config.model.sequence_parallel:
+                    sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
+                    tensor_shape = torch.Size(
+                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
+                    )
+                else:
+                    tensor_shape = torch.Size(
+                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
+                    )
 
                 with switch_evaluation_pipeline_scheduler(
                     trainer=trainer,
@@ -108,7 +114,7 @@ def evaluate_on_val_dls(
                 assert total_val_bsz % data_cfg.micro_bsz == 0
                 grad_accum_size = total_val_bsz // data_cfg.micro_bsz
                 grad_accum_batch_size = data_cfg.micro_bsz
-                # import pdb; pdb.set_trace()
+
                 with switch_evaluation_no_pipeline_scheduler(
                     trainer=trainer,
                     grad_accum_size=grad_accum_size,
@@ -155,3 +161,13 @@ def evaluate_on_val_dls(
         trainer.train()
     torch.cuda.empty_cache()
     dist.barrier()
+
+
+@contextmanager
+def switch_sequence_parallel_mode():
+    prev_mode = gpc.config.model.sequence_parallel
+    try:
+        gpc.config.model.sequence_parallel = False
+        yield
+    finally:
+        gpc.config.model.sequence_parallel = prev_mode


@@ -41,7 +41,7 @@ from internlm.utils.common import (
     launch_time,
     parse_args,
 )
-from internlm.utils.evaluation import evaluate_on_val_dls
+from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode
 from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.model_checkpoint import (
@@ -618,14 +618,15 @@ def main(args):
         # evaluate on validation data loaders
         if valid_every > 0 and train_state.step_count % valid_every == 0:
-            evaluate_on_val_dls(
-                trainer=trainer,
-                val_dls=val_dls,
-                writer=writer,
-                logger=logger,
-                step_count=train_state.step_count,
-                update_panel=uniscale_logger is not None,
-            )
+            with switch_sequence_parallel_mode():
+                evaluate_on_val_dls(
+                    trainer=trainer,
+                    val_dls=val_dls,
+                    writer=writer,
+                    logger=logger,
+                    step_count=train_state.step_count,
+                    update_panel=uniscale_logger is not None,
+                )
 
         # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
         # save batch sampler that tracks the true consumed samples