feat(*): support sequence_parallel (#180)

* support sequence_parallel for no pipeline

* sequence_parallel does not support no-flash-attn

* support sequence parallel for pipeline

* add memory profiler

* Update 13B.py

* add memory profiler

* fix evaluation bug

* remove some unnecessary code

* remove some unnecessary code

* Update parallel_context.py

* modify the config

* remove memory profiler

* modify the config

* support selective dropout
pull/190/head
ytxiong 2023-08-07 16:42:52 +08:00 committed by GitHub
parent 853becfb6e
commit c219065348
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 95 additions and 37 deletions

View File

@ -118,6 +118,7 @@ model = dict(
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
sequence_parallel=False,
)
"""
zero1 parallel:

View File

@ -464,7 +464,6 @@ class ParallelContext(metaclass=SingletonMeta):
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
if self.pipeline_parallel_size > 1:
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
for initializer in initializers:
parallel_setting = initializer.init_dist_group()
if isinstance(parallel_setting, list):

View File

@ -30,10 +30,17 @@ def get_tensor_shape():
if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
if gpc.config.model.use_flash_attn:
tensor_shape = (
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
gpc.config.HIDDEN_SIZE,
)
if gpc.config.model.sequence_parallel:
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = (
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
gpc.config.HIDDEN_SIZE,
)
else:
tensor_shape = (
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
gpc.config.HIDDEN_SIZE,
)
else:
tensor_shape = (
gpc.config.data["micro_bsz"],
@ -132,6 +139,9 @@ class PipelineScheduler(BaseScheduler):
and gpc.is_initialized(ParallelMode.TENSOR)
and gpc.get_world_size(ParallelMode.TENSOR) > 1
)
if gpc.config.model.sequence_parallel:
self.scatter_gather_tensors = False
# cache for the batch data
self.batch_data = None
@ -254,7 +264,6 @@ class PipelineScheduler(BaseScheduler):
if gpc.is_last_rank(ParallelMode.PIPELINE):
self._call_hooks("post_helper_func", output_obj, label)
if return_output_label:
return_tensors.append((output_obj, label))
if accum_loss is not None:

View File

@ -38,6 +38,7 @@ def get_default_parser():
parser.add_argument("--local_rank", type=int, help="local rank on the node")
parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
parser.add_argument("--seed", type=int, default=1024)
parser.add_argument("--profiling", default=True, action="store_true", help="enable/diable profiling.")
return parser
@ -198,6 +199,10 @@ def args_sanity_check():
# process the model config
if "use_flash_attn" not in gpc.config.model:
gpc.config.model._add_item("use_flash_attn", True)
if "sequence_parallel" not in gpc.config.model:
gpc.config.model._add_item("sequence_parallel", False)
else:
assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
def launch(

View File

@ -13,7 +13,7 @@ from torch import Tensor, nn
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from .utils import gather_forward_split_backward
from .utils import gather_forward_split_backward, split_forward_gather_backward
class Embedding1D(nn.Module):
@ -55,7 +55,10 @@ class Embedding1D(nn.Module):
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
if gpc.config.model.sequence_parallel:
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
return output

View File

@ -38,7 +38,6 @@ class ScaleColumnParallelLinear(nn.Linear):
out_features: int,
process_group: Optional[torch.distributed.ProcessGroup],
bias: bool = True,
sequence_parallel: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
weight_scale: int = 1,
@ -48,7 +47,6 @@ class ScaleColumnParallelLinear(nn.Linear):
raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.weight_scale = weight_scale
def forward(self, input): # pylint: disable=W0622
@ -60,7 +58,7 @@ class ScaleColumnParallelLinear(nn.Linear):
else:
weight = self.weight
return fused_dense_func_torch(
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
)
@ -87,12 +85,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
out_features: int,
process_group: Optional[torch.distributed.ProcessGroup],
bias: bool = True,
sequence_parallel: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
weight_scale: int = 1,
) -> None:
super().__init__(in_features, out_features, process_group, bias, sequence_parallel, device, dtype, weight_scale)
super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
if bias:
torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
@ -106,7 +103,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
else:
weight = self.weight
return fused_dense_func_torch(
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
)
@ -168,19 +165,19 @@ class FeedForward(nn.Module):
hidden_features,
process_group,
bias,
sequence_parallel=False,
sequence_parallel=gpc.config.model.sequence_parallel,
device=device,
dtype=dtype,
)
self.w2 = ColumnParallelLinearTorch(
in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
)
self.w3 = RowParallelLinearTorch(
hidden_features,
out_features,
process_group,
bias=bias,
sequence_parallel=False,
sequence_parallel=gpc.config.model.sequence_parallel,
device=device,
dtype=dtype,
)

View File

@ -89,7 +89,6 @@ class PackedFlashBaseLayer1D(nn.Module):
rotary_emb_dim=head_dim,
rotary_emb_scale_base=0,
use_flash_attn=use_flash_attn,
sequence_parallel=False,
device=device,
dtype=dtype,
)
@ -121,7 +120,7 @@ class PackedFlashBaseLayer1D(nn.Module):
process_group=gpc.get_group(ParallelMode.TENSOR),
bias1=False,
bias2=False,
sequence_parallel=False,
sequence_parallel=gpc.config.model.sequence_parallel,
checkpoint_lvl=0,
heuristic="auto",
device=device,
@ -300,7 +299,7 @@ class PackedFlashInternLm1D(nn.Module):
max_position_embeddings=-1,
process_group=gpc.get_group(ParallelMode.TENSOR),
padding_idx=None,
sequence_parallel=False,
sequence_parallel=gpc.config.model.sequence_parallel,
device=device,
dtype=dtype,
)
@ -342,7 +341,6 @@ class PackedFlashInternLm1D(nn.Module):
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
sequence_parallel=False,
device=device,
dtype=dtype,
weight_scale=embed_grad_scale,
@ -463,6 +461,7 @@ def build_model_with_cfg(
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
sequence_parallel: bool = False,
):
"""
Builde model with config

View File

@ -59,7 +59,6 @@ class MHA(nn.Module):
rotary_emb_dim: int = 0,
rotary_emb_scale_base: int = 0,
use_flash_attn: bool = True,
sequence_parallel: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
@ -83,7 +82,7 @@ class MHA(nn.Module):
3 * embed_dim,
process_group,
bias=True,
sequence_parallel=sequence_parallel,
sequence_parallel=gpc.config.model.sequence_parallel,
**factory_kwargs,
) # according to https://spaces.ac.cn/archives/9577
@ -96,7 +95,7 @@ class MHA(nn.Module):
# output projection always have the bias (for now)
self.out_proj = RowParallelLinearTorch(
embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs
)
# need to assign tp attribute so that internlm know it is tensor parallel module
if gpc.get_world_size(ParallelMode.TENSOR) > 1:

View File

@ -157,6 +157,35 @@ def fused_dense_func_torch(
else:
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the corresponding chuck to the rank.
Args:
input_: input matrix.
parallel_mode: parallel mode.
dim: dimension
"""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _split(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.mode, ctx.dim), None, None
def split_forward_gather_backward(input_, parallel_mode, dim):
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
def try_import_RMSNorm():
"""
Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm

View File

@ -90,9 +90,15 @@ def evaluate_on_val_dls(
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
)
if gpc.config.model.sequence_parallel:
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
)
else:
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
)
with switch_evaluation_pipeline_scheduler(
trainer=trainer,
@ -108,7 +114,7 @@ def evaluate_on_val_dls(
assert total_val_bsz % data_cfg.micro_bsz == 0
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
grad_accum_batch_size = data_cfg.micro_bsz
# import pdb; pdb.set_trace()
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer,
grad_accum_size=grad_accum_size,
@ -155,3 +161,13 @@ def evaluate_on_val_dls(
trainer.train()
torch.cuda.empty_cache()
dist.barrier()
@contextmanager
def switch_sequence_parallel_mode():
prev_mode = gpc.config.model.sequence_parallel
try:
gpc.config.model.sequence_parallel = False
yield
finally:
gpc.config.model.sequence_parallel = prev_mode

View File

@ -41,7 +41,7 @@ from internlm.utils.common import (
launch_time,
parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import (
@ -618,14 +618,15 @@ def main(args):
# evaluate on validation data loaders
if valid_every > 0 and train_state.step_count % valid_every == 0:
evaluate_on_val_dls(
trainer=trainer,
val_dls=val_dls,
writer=writer,
logger=logger,
step_count=train_state.step_count,
update_panel=uniscale_logger is not None,
)
with switch_sequence_parallel_mode():
evaluate_on_val_dls(
trainer=trainer,
val_dls=val_dls,
writer=writer,
logger=logger,
step_count=train_state.step_count,
update_panel=uniscale_logger is not None,
)
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
# save batch sampler that tracks the true consumed samples