mirror of https://github.com/InternLM/InternLM
feat(initialize/launch.py): refactor config for fstp
parent 815a584930
commit d91a5d9d9e
@@ -152,19 +152,19 @@ zero1 parallel (dict):
     2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
-    2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp',
-        the sequence_parallel should be True.
+    2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'],
+        defaults to 'none', means the sequence parallel will be disabled.
+    3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using 'intern' mode sp,
+        defaults to False.
 pipeline parallel (dict):
     1. size: int, the size of pipeline parallel.
     2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
         defaults to False.
-sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="fstp", overlap=True),
+    tensor=dict(size=8, sp="intern", intern_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
-    sequence_parallel=True,
 )

 cudnn_deterministic = False
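For orientation, this hunk renames the tensor-parallel keys: 'mode'/'overlap' become 'sp'/'intern_overlap', and the standalone sequence_parallel flag is dropped from the sample config because the sanity check now derives it from 'sp'. A minimal sketch of the two config styles, illustrative only and not an additional file in the commit:

# old-style parallel config (before this commit)
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=8, mode="fstp", overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
)

# new-style parallel config (after this commit); sequence parallelism is implied by sp != "none"
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=8, sp="intern", intern_overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
)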
@@ -306,15 +306,20 @@ def args_sanity_check():
         ), "sequence parallel does not support use_flash_attn=False"

     if isinstance(gpc.config.parallel["tensor"], int):
-        gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp")
-    if gpc.config.parallel["tensor"].get("mode", None) is None:
-        gpc.config.parallel["tensor"]["mode"] = "origin_tp"
-    if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
-        assert (
-            gpc.config.parallel.sequence_parallel is True
-        ), "when the tp_mode is fstp, the sequence_parallel should be True."
+        gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False)
+    if gpc.config.parallel["tensor"].get("sp", None) is None:
+        gpc.config.parallel["tensor"]["sp"] = "none"
+    if gpc.config.parallel["tensor"].get("intern_overlap", None) is None:
+        gpc.config.parallel["tensor"]["intern_overlap"] = False
+    assert gpc.config.parallel["tensor"].get("sp", None) in [
+        "none",
+        "megatron",
+        "flash-attn",
+        "intern",
+    ], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported"
+    # adapt to old version's sequence parallel config
+    if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]:
+        gpc.config.parallel.sequence_parallel = True

     # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
     if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
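As a worked example of the normalization above (hypothetical values, assuming gpc.config.parallel accepts item assignment as in the code):

# a user config that still uses the shorthand integer form
parallel = dict(zero1=dict(size=-1, fsdp=False), tensor=4, pipeline=dict(size=1), sequence_parallel=False)

# after args_sanity_check, the tensor entry is expanded with defaults:
#   parallel["tensor"] == dict(size=4, sp="none", intern_overlap=False)
# and sequence_parallel is left untouched because sp == "none".
# With tensor=dict(size=4, sp="flash-attn") instead, intern_overlap would default to False
# and sequence_parallel would be forced to True by the adaptation branch.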
@@ -77,7 +77,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
-        tp_mode: str = "origin_tp",
+        sp_mode: str = "none",
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -102,7 +102,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             use_flash_attn=use_flash_attn,
             device=device,
             dtype=dtype,
-            tp_mode=tp_mode,
+            sp_mode=sp_mode,
         )

         self.dropout1 = nn.Dropout(drop_rate)
@@ -114,7 +114,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

         if use_swiglu:
-            mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
+            mlp_cls = FSTPFeedForward if sp_mode == "intern" else FeedForward
             self.mlp = mlp_cls(
                 hidden_size,
                 int(hidden_size * mlp_ratio),
@@ -297,7 +297,7 @@ class PackedFlashInternLm1D(nn.Module):
         super().__init__()

         checkpoint_layer_num = int(num_layers * checkpoint)
-        self.tp_mode = gpc.config.parallel["tensor"]["mode"]
+        self.sp_mode = gpc.config.parallel["tensor"]["sp"]

         if is_reward:
             head_cls = RewardModelLinear
@@ -343,7 +343,7 @@ class PackedFlashInternLm1D(nn.Module):
                 use_scaled_init=use_scaled_init,
                 use_swiglu=use_swiglu,
                 use_flash_attn=use_flash_attn,
-                tp_mode=self.tp_mode,
+                sp_mode=self.sp_mode,
             )
             for lid in range(num_layers)
         ]
@@ -389,8 +389,8 @@ class PackedFlashInternLm1D(nn.Module):
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-            # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
-            if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
+            # if the sequence parallel mode is 'intern', the indexes should also be split in sequence dimension.
+            if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
                 indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
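A small, hypothetical shape example for the index split above, assuming a packed batch of 8 tokens and a 4-way tensor-parallel group:

# indexes carries the position id of every token in the packed input, e.g.
#   indexes = tensor([0, 1, 2, 3, 0, 1, 2, 3])   # shape (8,)
# under sp == "intern", split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
# leaves each rank with its contiguous 1/4 slice, e.g. rank 1 keeps
#   tensor([2, 3])                                # shape (2,)
# so the position ids match the sequence chunk that rank actually holds.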
@@ -175,7 +175,7 @@ class MHA(nn.Module):
         use_flash_attn: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-        tp_mode: str = "origin_tp",
+        sp_mode: str = "none",
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -203,7 +203,7 @@ class MHA(nn.Module):
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

         # notice here should change bias=True
-        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+        Wqkv_cls = FSTPLinear if sp_mode == "intern" else ColumnParallelLinearTorch
         self.Wqkv = Wqkv_cls(
             embed_dim,
             3 * embed_dim,
@@ -219,12 +219,12 @@ class MHA(nn.Module):
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-        if tp_mode == "fstp":
+        if sp_mode == "intern":
             self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
             self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

         # output projection always have the bias (for now)
-        out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+        out_proj_cls = FSTPLinear if sp_mode == "intern" else RowParallelLinearTorch
         self.out_proj = out_proj_cls(
             embed_dim,
             embed_dim,
@@ -10,7 +10,10 @@ from torch.optim import Optimizer

 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
+from internlm.model.utils import (
+    release_reduce_scatter_memory_pool,
+    split_forward_gather_backward,
+)
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -40,8 +43,20 @@ from .utils import compute_norm
 inf = math.inf
 logger = get_logger(__file__)


 def print_memory(msg):
-    print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
+    print(
+        msg,
+        " rank = ",
+        gpc.get_global_rank(),
+        " memory allocated: ",
+        torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+        " reverved memory: ",
+        torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
+        " max memory: ",
+        torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
+        flush=True,
+    )
     print("===========================================")

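print_memory is a per-rank debugging helper carried along with the overlap work; a trivial usage sketch with hypothetical call sites:

# bracket a training step to watch allocator growth on every rank
print_memory("before step")
# ... forward/backward/optimizer step here ...
print_memory("after step")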
@@ -70,7 +85,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale

-        if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
             self._fstp_handler = gpc.config.fstp_handler

         # Zero related args
@@ -110,9 +110,8 @@ def initialize_model():

     gpc.config.fstp_handler = None

-    if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+    if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
         handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
-        # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
         gpc.config.fstp_handler = handler

@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
 def switch_sequence_parallel_mode():
     prev_mode = gpc.config.parallel.sequence_parallel
     try:
-        if gpc.config.parallel["tensor"]["mode"] == "fstp":
+        if gpc.config.parallel["tensor"]["sp"] == "intern":
             gpc.config.parallel.sequence_parallel = True
         else:
             gpc.config.parallel.sequence_parallel = False
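The switch above forces sequence parallelism on during evaluation only for the 'intern' mode and restores the previous setting afterwards. A hedged usage sketch, assuming switch_sequence_parallel_mode is used as a context manager (its try/restore structure suggests contextlib.contextmanager; run_validation is a hypothetical placeholder):

with switch_sequence_parallel_mode():
    run_validation()  # any evaluation code; the prior sequence_parallel value is restored on exit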
|
@ -106,7 +106,7 @@ def evaluate_on_val_dls(
|
||||||
total_val_bsz = len(batch[1])
|
total_val_bsz = len(batch[1])
|
||||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||||
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
||||||
if gpc.config.parallel["tensor"]["mode"] == "fstp":
|
if gpc.config.parallel["tensor"]["sp"] == "intern":
|
||||||
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||||
tensor_shape = torch.Size(
|
tensor_shape = torch.Size(
|
||||||
[
|
[
|
||||||
|
|