mirror of https://github.com/InternLM/InternLM
feat(*): support sequence_parallel (#180)
* support sequence_parallel for no pipeline * sequence_parallel does not support no-flash-attn * support sequence parallel for pipeline * add memory profiler * Update 13B.py * add memory profiler * fix evaluation bug * remove some unnecessary code * remove some unnecessary code * Update parallel_context.py * modify the config * remove memory profiler * modify the config * support selective dropoutpull/190/head
parent
853becfb6e
commit
c219065348
|
@ -118,6 +118,7 @@ model = dict(
|
|||
layer_norm_epsilon=1e-5,
|
||||
use_flash_attn=True,
|
||||
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
|
||||
sequence_parallel=False,
|
||||
)
|
||||
"""
|
||||
zero1 parallel:
|
||||
|
|
|
@ -464,7 +464,6 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
|
||||
if self.pipeline_parallel_size > 1:
|
||||
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
|
||||
|
||||
for initializer in initializers:
|
||||
parallel_setting = initializer.init_dist_group()
|
||||
if isinstance(parallel_setting, list):
|
||||
|
|
|
@ -30,10 +30,17 @@ def get_tensor_shape():
|
|||
|
||||
if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
|
||||
if gpc.config.model.use_flash_attn:
|
||||
tensor_shape = (
|
||||
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
|
||||
gpc.config.HIDDEN_SIZE,
|
||||
)
|
||||
if gpc.config.model.sequence_parallel:
|
||||
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
tensor_shape = (
|
||||
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
|
||||
gpc.config.HIDDEN_SIZE,
|
||||
)
|
||||
else:
|
||||
tensor_shape = (
|
||||
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
|
||||
gpc.config.HIDDEN_SIZE,
|
||||
)
|
||||
else:
|
||||
tensor_shape = (
|
||||
gpc.config.data["micro_bsz"],
|
||||
|
@ -132,6 +139,9 @@ class PipelineScheduler(BaseScheduler):
|
|||
and gpc.is_initialized(ParallelMode.TENSOR)
|
||||
and gpc.get_world_size(ParallelMode.TENSOR) > 1
|
||||
)
|
||||
|
||||
if gpc.config.model.sequence_parallel:
|
||||
self.scatter_gather_tensors = False
|
||||
|
||||
# cache for the batch data
|
||||
self.batch_data = None
|
||||
|
@ -254,7 +264,6 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
self._call_hooks("post_helper_func", output_obj, label)
|
||||
|
||||
if return_output_label:
|
||||
return_tensors.append((output_obj, label))
|
||||
if accum_loss is not None:
|
||||
|
|
|
@ -38,6 +38,7 @@ def get_default_parser():
|
|||
parser.add_argument("--local_rank", type=int, help="local rank on the node")
|
||||
parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
|
||||
parser.add_argument("--seed", type=int, default=1024)
|
||||
parser.add_argument("--profiling", default=True, action="store_true", help="enable/diable profiling.")
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -198,6 +199,10 @@ def args_sanity_check():
|
|||
# process the model config
|
||||
if "use_flash_attn" not in gpc.config.model:
|
||||
gpc.config.model._add_item("use_flash_attn", True)
|
||||
if "sequence_parallel" not in gpc.config.model:
|
||||
gpc.config.model._add_item("sequence_parallel", False)
|
||||
else:
|
||||
assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
|
||||
|
||||
|
||||
def launch(
|
||||
|
|
|
@ -13,7 +13,7 @@ from torch import Tensor, nn
|
|||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
from .utils import gather_forward_split_backward
|
||||
from .utils import gather_forward_split_backward, split_forward_gather_backward
|
||||
|
||||
|
||||
class Embedding1D(nn.Module):
|
||||
|
@ -55,7 +55,10 @@ class Embedding1D(nn.Module):
|
|||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
|
||||
|
||||
|
||||
if gpc.config.model.sequence_parallel:
|
||||
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
|
|
@ -38,7 +38,6 @@ class ScaleColumnParallelLinear(nn.Linear):
|
|||
out_features: int,
|
||||
process_group: Optional[torch.distributed.ProcessGroup],
|
||||
bias: bool = True,
|
||||
sequence_parallel: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
weight_scale: int = 1,
|
||||
|
@ -48,7 +47,6 @@ class ScaleColumnParallelLinear(nn.Linear):
|
|||
raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
|
||||
super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = sequence_parallel
|
||||
self.weight_scale = weight_scale
|
||||
|
||||
def forward(self, input): # pylint: disable=W0622
|
||||
|
@ -60,7 +58,7 @@ class ScaleColumnParallelLinear(nn.Linear):
|
|||
else:
|
||||
weight = self.weight
|
||||
return fused_dense_func_torch(
|
||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
|
||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
|
||||
)
|
||||
|
||||
|
||||
|
@ -87,12 +85,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
|
|||
out_features: int,
|
||||
process_group: Optional[torch.distributed.ProcessGroup],
|
||||
bias: bool = True,
|
||||
sequence_parallel: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
weight_scale: int = 1,
|
||||
) -> None:
|
||||
super().__init__(in_features, out_features, process_group, bias, sequence_parallel, device, dtype, weight_scale)
|
||||
super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
|
||||
torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
|
||||
if bias:
|
||||
torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
|
||||
|
@ -106,7 +103,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
|
|||
else:
|
||||
weight = self.weight
|
||||
return fused_dense_func_torch(
|
||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
|
||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
|
||||
)
|
||||
|
||||
|
||||
|
@ -168,19 +165,19 @@ class FeedForward(nn.Module):
|
|||
hidden_features,
|
||||
process_group,
|
||||
bias,
|
||||
sequence_parallel=False,
|
||||
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.w2 = ColumnParallelLinearTorch(
|
||||
in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
|
||||
in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
|
||||
)
|
||||
self.w3 = RowParallelLinearTorch(
|
||||
hidden_features,
|
||||
out_features,
|
||||
process_group,
|
||||
bias=bias,
|
||||
sequence_parallel=False,
|
||||
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
|
|
@ -89,7 +89,6 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
rotary_emb_dim=head_dim,
|
||||
rotary_emb_scale_base=0,
|
||||
use_flash_attn=use_flash_attn,
|
||||
sequence_parallel=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -121,7 +120,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias1=False,
|
||||
bias2=False,
|
||||
sequence_parallel=False,
|
||||
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||
checkpoint_lvl=0,
|
||||
heuristic="auto",
|
||||
device=device,
|
||||
|
@ -300,7 +299,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
max_position_embeddings=-1,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
padding_idx=None,
|
||||
sequence_parallel=False,
|
||||
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -342,7 +341,6 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias=False,
|
||||
sequence_parallel=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
weight_scale=embed_grad_scale,
|
||||
|
@ -463,6 +461,7 @@ def build_model_with_cfg(
|
|||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
sequence_parallel: bool = False,
|
||||
):
|
||||
"""
|
||||
Builde model with config
|
||||
|
|
|
@ -59,7 +59,6 @@ class MHA(nn.Module):
|
|||
rotary_emb_dim: int = 0,
|
||||
rotary_emb_scale_base: int = 0,
|
||||
use_flash_attn: bool = True,
|
||||
sequence_parallel: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
|
@ -83,7 +82,7 @@ class MHA(nn.Module):
|
|||
3 * embed_dim,
|
||||
process_group,
|
||||
bias=True,
|
||||
sequence_parallel=sequence_parallel,
|
||||
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||
**factory_kwargs,
|
||||
) # according to https://spaces.ac.cn/archives/9577
|
||||
|
||||
|
@ -96,7 +95,7 @@ class MHA(nn.Module):
|
|||
|
||||
# output projection always have the bias (for now)
|
||||
self.out_proj = RowParallelLinearTorch(
|
||||
embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
|
||||
embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs
|
||||
)
|
||||
# need to assign tp attribute so that internlm know it is tensor parallel module
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
|
|
|
@ -157,6 +157,35 @@ def fused_dense_func_torch(
|
|||
else:
|
||||
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
"""
|
||||
Split the input and keep only the corresponding chuck to the rank.
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
parallel_mode: parallel mode.
|
||||
dim: dimension
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _split(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode, dim):
|
||||
ctx.mode = parallel_mode
|
||||
ctx.dim = dim
|
||||
return _split(input_, parallel_mode, dim)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _gather(grad_output, ctx.mode, ctx.dim), None, None
|
||||
|
||||
|
||||
def split_forward_gather_backward(input_, parallel_mode, dim):
|
||||
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
|
||||
|
||||
|
||||
def try_import_RMSNorm():
|
||||
"""
|
||||
Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
|
||||
|
|
|
@ -90,9 +90,15 @@ def evaluate_on_val_dls(
|
|||
total_val_bsz = len(batch[1])
|
||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
||||
tensor_shape = torch.Size(
|
||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
|
||||
)
|
||||
if gpc.config.model.sequence_parallel:
|
||||
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
tensor_shape = torch.Size(
|
||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
|
||||
)
|
||||
else:
|
||||
tensor_shape = torch.Size(
|
||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
|
||||
)
|
||||
|
||||
with switch_evaluation_pipeline_scheduler(
|
||||
trainer=trainer,
|
||||
|
@ -108,7 +114,7 @@ def evaluate_on_val_dls(
|
|||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
|
||||
grad_accum_batch_size = data_cfg.micro_bsz
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
with switch_evaluation_no_pipeline_scheduler(
|
||||
trainer=trainer,
|
||||
grad_accum_size=grad_accum_size,
|
||||
|
@ -155,3 +161,13 @@ def evaluate_on_val_dls(
|
|||
trainer.train()
|
||||
torch.cuda.empty_cache()
|
||||
dist.barrier()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_sequence_parallel_mode():
|
||||
prev_mode = gpc.config.model.sequence_parallel
|
||||
try:
|
||||
gpc.config.model.sequence_parallel = False
|
||||
yield
|
||||
finally:
|
||||
gpc.config.model.sequence_parallel = prev_mode
|
19
train.py
19
train.py
|
@ -41,7 +41,7 @@ from internlm.utils.common import (
|
|||
launch_time,
|
||||
parse_args,
|
||||
)
|
||||
from internlm.utils.evaluation import evaluate_on_val_dls
|
||||
from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode
|
||||
from internlm.utils.logger import get_logger, initialize_uniscale_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.model_checkpoint import (
|
||||
|
@ -618,14 +618,15 @@ def main(args):
|
|||
|
||||
# evaluate on validation data loaders
|
||||
if valid_every > 0 and train_state.step_count % valid_every == 0:
|
||||
evaluate_on_val_dls(
|
||||
trainer=trainer,
|
||||
val_dls=val_dls,
|
||||
writer=writer,
|
||||
logger=logger,
|
||||
step_count=train_state.step_count,
|
||||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
with switch_sequence_parallel_mode():
|
||||
evaluate_on_val_dls(
|
||||
trainer=trainer,
|
||||
val_dls=val_dls,
|
||||
writer=writer,
|
||||
logger=logger,
|
||||
step_count=train_state.step_count,
|
||||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
|
||||
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
|
||||
# save batch sampler that tracks the true consumed samples
|
||||
|
|
Loading…
Reference in New Issue