Merge pull request #4 from yingtongxiong/fstp/refactor-config

feat(initialize/launch.py): refactor config for fstp
pull/456/head
ytxiong 2023-10-20 17:48:20 +08:00 committed by GitHub
commit f22e5b3b28
10 changed files with 41 additions and 38 deletions

View File

@@ -152,19 +152,19 @@ zero1 parallel (dict):
     2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
-    2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp',
-        the sequence_parallel should be True.
+    2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'],
+        defaults to 'none', means the sequence parallel will be disabled.
+    3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using 'intern' mode sp,
+        defaults to False.
 pipeline parallel (dict):
     1. size: int, the size of pipeline parallel.
     2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
         defaults to False.
-sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="fstp", overlap=True),
+    tensor=dict(size=8, sp="intern", intern_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
-    sequence_parallel=True,
 )
 cudnn_deterministic = False
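
For reference, a minimal sketch of a new-style `parallel` config after this change (field names taken from the diff above; the surrounding keys such as `zero1` and `pipeline` keep their existing meaning):

    # Sketch of the new-style parallel config introduced by this PR.
    # 'sp' replaces the old tensor 'mode' plus the top-level 'sequence_parallel' flag;
    # 'intern_overlap' only takes effect when sp == "intern".
    parallel = dict(
        zero1=dict(size=-1, fsdp=False),
        # sp must be one of: "none", "megatron", "flash-attn", "intern"
        tensor=dict(size=8, sp="intern", intern_overlap=True),
        pipeline=dict(size=1, interleaved_overlap=True),
    )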

View File

@@ -306,15 +306,20 @@ def args_sanity_check():
         ), "sequence parallel does not support use_flash_attn=False"

     if isinstance(gpc.config.parallel["tensor"], int):
-        gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp")
-    if gpc.config.parallel["tensor"].get("mode", None) is None:
-        gpc.config.parallel["tensor"]["mode"] = "origin_tp"
-
-    if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
-        assert (
-            gpc.config.parallel.sequence_parallel is True
-        ), "when the tp_mode is fstp, the sequence_parallel should be True."
+        gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False)
+    if gpc.config.parallel["tensor"].get("sp", None) is None:
+        gpc.config.parallel["tensor"]["sp"] = "none"
+    if gpc.config.parallel["tensor"].get("intern_overlap", None) is None:
+        gpc.config.parallel["tensor"]["intern_overlap"] = False
+    assert gpc.config.parallel["tensor"].get("sp", None) in [
+        "none",
+        "megatron",
+        "flash-attn",
+        "intern",
+    ], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported"
+    # adapt to old version's sequence parallel config
+    if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]:
+        gpc.config.parallel.sequence_parallel = True

     # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
     if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
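
The sanity check above also maps old-style configs onto the new schema. A standalone sketch of that normalization (a hypothetical helper, not part of the PR, shown only to illustrate the behaviour):

    # Hypothetical helper mirroring the normalization in args_sanity_check(); illustration only.
    VALID_SP_MODES = ("none", "megatron", "flash-attn", "intern")

    def normalize_tensor_parallel(tensor_cfg):
        # An int-only tensor config becomes a dict with the new defaults.
        if isinstance(tensor_cfg, int):
            tensor_cfg = dict(size=tensor_cfg, sp="none", intern_overlap=False)
        tensor_cfg.setdefault("sp", "none")
        tensor_cfg.setdefault("intern_overlap", False)
        assert tensor_cfg["sp"] in VALID_SP_MODES, f"invalid sp mode: {tensor_cfg['sp']}"
        # Any sp mode other than 'none' implies sequence parallelism is enabled.
        sequence_parallel = tensor_cfg["sp"] != "none"
        return tensor_cfg, sequence_parallel

    # e.g. normalize_tensor_parallel(8)
    # -> ({'size': 8, 'sp': 'none', 'intern_overlap': False}, False)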

View File

@@ -77,7 +77,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
-        tp_mode: str = "origin_tp",
+        sp_mode: str = "none",
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -102,7 +102,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             use_flash_attn=use_flash_attn,
             device=device,
             dtype=dtype,
-            tp_mode=tp_mode,
+            sp_mode=sp_mode,
         )
         self.dropout1 = nn.Dropout(drop_rate)
@@ -114,7 +114,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
         if use_swiglu:
-            mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
+            mlp_cls = FSTPFeedForward if sp_mode == "intern" else FeedForward
             self.mlp = mlp_cls(
                 hidden_size,
                 int(hidden_size * mlp_ratio),
@@ -297,7 +297,7 @@ class PackedFlashInternLm1D(nn.Module):
         super().__init__()
         checkpoint_layer_num = int(num_layers * checkpoint)
-        self.tp_mode = gpc.config.parallel["tensor"]["mode"]
+        self.sp_mode = gpc.config.parallel["tensor"]["sp"]
         if is_reward:
             head_cls = RewardModelLinear
@@ -343,7 +343,7 @@
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
-                    tp_mode=self.tp_mode,
+                    sp_mode=self.sp_mode,
                 )
                 for lid in range(num_layers)
             ]
@@ -389,8 +389,8 @@
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-            # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
-            if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
+            # if the sequence parallel mode is 'intern', the indexes should also be split in sequence dimension.
+            if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
                 indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
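
When sp_mode is 'intern', each tensor-parallel rank keeps only its slice of the packed sequence, so the position indexes are sliced the same way. A rough sketch of that forward-pass slicing (plain torch.chunk standing in for split_forward_gather_backward, which additionally all-gathers the gradient in backward):

    import torch

    # Illustration only: what splitting indexes along the sequence dimension means per rank.
    def split_indexes_for_rank(indexes: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
        return torch.chunk(indexes, world_size, dim=0)[rank]

    indexes = torch.arange(8)  # position ids of a packed sequence of length 8
    print(split_indexes_for_rank(indexes, world_size=4, rank=1))  # tensor([2, 3])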

View File

@@ -175,7 +175,7 @@ class MHA(nn.Module):
         use_flash_attn: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-        tp_mode: str = "origin_tp",
+        sp_mode: str = "none",
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -203,7 +203,7 @@ class MHA(nn.Module):
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
         # notice here should change bias=True
-        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+        Wqkv_cls = FSTPLinear if sp_mode == "intern" else ColumnParallelLinearTorch
         self.Wqkv = Wqkv_cls(
             embed_dim,
             3 * embed_dim,
@@ -219,12 +219,12 @@ class MHA(nn.Module):
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-        if tp_mode == "fstp":
+        if sp_mode == "intern":
             self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
             self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)
         # output projection always have the bias (for now)
-        out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+        out_proj_cls = FSTPLinear if sp_mode == "intern" else RowParallelLinearTorch
         self.out_proj = out_proj_cls(
             embed_dim,
             embed_dim,
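
The same dispatch pattern now keys off sp_mode throughout MHA: 'intern' selects FSTPLinear for both projections and wraps the attention modules in DistributedAttention, every other mode keeps the Megatron-style column/row parallel linears. A condensed sketch of just the selection logic (class names as strings; the real constructors take the arguments shown in the diff):

    # Illustration only: which linear classes MHA.__init__ picks per sp_mode.
    def select_linear_classes(sp_mode: str):
        if sp_mode == "intern":
            return "FSTPLinear", "FSTPLinear"  # Wqkv_cls, out_proj_cls
        return "ColumnParallelLinearTorch", "RowParallelLinearTorch"

    assert select_linear_classes("intern") == ("FSTPLinear", "FSTPLinear")
    assert select_linear_classes("none") == ("ColumnParallelLinearTorch", "RowParallelLinearTorch")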

View File

@@ -1,12 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from typing import Any, Optional, Union
+from typing import Optional

 import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn.functional as F
-from flash_attn.utils.distributed import all_reduce_raw  # , reduce_scatter_raw
+from flash_attn.utils.distributed import all_reduce_raw
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
@@ -397,7 +397,6 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         all_gather_handler = ctx.all_gather_handler
-        module = ctx.module
         block_index = ctx.block_index
         module_name = ctx.module_name

View File

@@ -2,8 +2,8 @@
 # -*- encoding: utf-8 -*-

 import math
-from typing import Optional, List
 from functools import partial
+from typing import List, Optional

 import torch
 import torch.distributed as dist
@@ -11,7 +11,7 @@ from torch.optim import Optimizer
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
+from internlm.model.utils import release_reduce_scatter_memory_pool
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -83,7 +83,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale
-        if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
             self._fstp_handler = gpc.config.fstp_handler

         # Zero related args
@@ -366,8 +366,8 @@ class HybridZeroOptimizer(BaseOptimizer):
                         _param.grad.add_(_grad)

                     # release cuda memory.
+                    release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
                     self._fstp_handler.reduce_scatter_handlers[_key] = None
-                    _grad = None

                 bucket.reset_by_rank(reduce_rank)

View File

@@ -110,9 +110,8 @@ def initialize_model():
     gpc.config.fstp_handler = None

-    if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+    if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
         handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
-        # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
         gpc.config.fstp_handler = handler

View File

@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
 def switch_sequence_parallel_mode():
     prev_mode = gpc.config.parallel.sequence_parallel
     try:
-        if gpc.config.parallel["tensor"]["mode"] == "fstp":
+        if gpc.config.parallel["tensor"]["sp"] == "intern":
             gpc.config.parallel.sequence_parallel = True
         else:
             gpc.config.parallel.sequence_parallel = False
@@ -106,7 +106,7 @@ def evaluate_on_val_dls(
                 total_val_bsz = len(batch[1])
                 assert total_val_bsz % data_cfg.micro_bsz == 0
                 num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                if gpc.config.parallel["tensor"]["mode"] == "fstp":
+                if gpc.config.parallel["tensor"]["sp"] == "intern":
                     sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                     tensor_shape = torch.Size(
                         [

View File

@@ -45,7 +45,7 @@ def empty_cache_and_diag(batch_count, interval=50):
             # # import time
             # # time.sleep(10)
             # print(e, "rank = ", gpc.get_global_rank(), flush=True)
             # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
     # do empty_cache after the bench
     torch.cuda.empty_cache()