mirror of https://github.com/InternLM/InternLM
feat(initialize/launch.py): refactor config for fstp
parent 815a584930
commit d91a5d9d9e
@@ -152,19 +152,19 @@ zero1 parallel (dict):
    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
tensor parallel (dict):
    1. size: int, the size of tensor parallel.
-   2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp',
-       the sequence_parallel should be True.
+   2. sp: str, the sequence parallel mode, should be in ['none', 'megatron', 'flash-attn', 'intern'],
+       defaults to 'none', which means sequence parallel is disabled.
+   3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using the 'intern' sp mode,
+       defaults to False.
pipeline parallel (dict):
    1. size: int, the size of pipeline parallel.
    2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
       defaults to False.
sequence parallel (bool): enable/disable sequence parallel, defaults to False.
"""
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
-   tensor=dict(size=8, mode="fstp", overlap=True),
+   tensor=dict(size=8, sp="intern", intern_overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
)

cudnn_deterministic = False
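
For anyone updating an existing training config, a minimal before/after sketch of this change (only the tensor entry differs; the other keys are taken verbatim from the example in this hunk):

    # old-style config: tensor parallel mode plus an overlap flag
    parallel = dict(
        zero1=dict(size=-1, fsdp=False),
        tensor=dict(size=8, mode="fstp", overlap=True),
        pipeline=dict(size=1, interleaved_overlap=True),
        sequence_parallel=True,  # had to be set explicitly when mode was 'fstp'
    )

    # new-style config: sequence parallel mode plus the intern_overlap flag
    parallel = dict(
        zero1=dict(size=-1, fsdp=False),
        tensor=dict(size=8, sp="intern", intern_overlap=True),
        pipeline=dict(size=1, interleaved_overlap=True),
        sequence_parallel=True,  # also forced to True by the sanity check whenever sp is not 'none'
    )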
@@ -306,15 +306,20 @@ def args_sanity_check():
    ), "sequence parallel does not support use_flash_attn=False"

    if isinstance(gpc.config.parallel["tensor"], int):
-       gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp")
-
-   if gpc.config.parallel["tensor"].get("mode", None) is None:
-       gpc.config.parallel["tensor"]["mode"] = "origin_tp"
-
-   if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
-       assert (
-           gpc.config.parallel.sequence_parallel is True
-       ), "when the tp_mode is fstp, the sequence_parallel should be True."
+       gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False)
+   if gpc.config.parallel["tensor"].get("sp", None) is None:
+       gpc.config.parallel["tensor"]["sp"] = "none"
+   if gpc.config.parallel["tensor"].get("intern_overlap", None) is None:
+       gpc.config.parallel["tensor"]["intern_overlap"] = False
+   assert gpc.config.parallel["tensor"].get("sp", None) in [
+       "none",
+       "megatron",
+       "flash-attn",
+       "intern",
+   ], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported"
+   # adapt to old version's sequence parallel config
+   if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]:
+       gpc.config.parallel.sequence_parallel = True

    # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
    if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
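
To make the new sanity check easier to follow, here is a small, self-contained sketch of the normalization it applies to the tensor entry (the helper name normalize_tensor_cfg is hypothetical, and plain dicts stand in for gpc.config):

    def normalize_tensor_cfg(tensor_cfg):
        # a bare int is expanded to a dict with conservative defaults
        if isinstance(tensor_cfg, int):
            tensor_cfg = dict(size=tensor_cfg, sp="none", intern_overlap=False)
        # missing keys are filled with the same defaults
        tensor_cfg.setdefault("sp", "none")
        tensor_cfg.setdefault("intern_overlap", False)
        assert tensor_cfg["sp"] in ["none", "megatron", "flash-attn", "intern"]
        return tensor_cfg

    print(normalize_tensor_cfg(8))
    # {'size': 8, 'sp': 'none', 'intern_overlap': False}
    print(normalize_tensor_cfg(dict(size=8, sp="intern")))
    # {'size': 8, 'sp': 'intern', 'intern_overlap': False}

In the real check, any sp mode other than 'none' additionally flips gpc.config.parallel.sequence_parallel to True, so older configs keep working.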
@@ -77,7 +77,7 @@ class PackedFlashBaseLayer1D(nn.Module):
        use_scaled_init: bool = True,
        use_swiglu: bool = True,
        use_flash_attn: bool = True,
-       tp_mode: str = "origin_tp",
+       sp_mode: str = "none",
    ):
        super().__init__()
        self.checkpoint = checkpoint
@@ -102,7 +102,7 @@ class PackedFlashBaseLayer1D(nn.Module):
            use_flash_attn=use_flash_attn,
            device=device,
            dtype=dtype,
-           tp_mode=tp_mode,
+           sp_mode=sp_mode,
        )

        self.dropout1 = nn.Dropout(drop_rate)
@@ -114,7 +114,7 @@ class PackedFlashBaseLayer1D(nn.Module):
        self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

        if use_swiglu:
-           mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
+           mlp_cls = FSTPFeedForward if sp_mode == "intern" else FeedForward
            self.mlp = mlp_cls(
                hidden_size,
                int(hidden_size * mlp_ratio),
@@ -297,7 +297,7 @@ class PackedFlashInternLm1D(nn.Module):
        super().__init__()

        checkpoint_layer_num = int(num_layers * checkpoint)
-       self.tp_mode = gpc.config.parallel["tensor"]["mode"]
+       self.sp_mode = gpc.config.parallel["tensor"]["sp"]

        if is_reward:
            head_cls = RewardModelLinear
@@ -343,7 +343,7 @@ class PackedFlashInternLm1D(nn.Module):
                use_scaled_init=use_scaled_init,
                use_swiglu=use_swiglu,
                use_flash_attn=use_flash_attn,
-               tp_mode=self.tp_mode,
+               sp_mode=self.sp_mode,
            )
            for lid in range(num_layers)
        ]
@@ -389,8 +389,8 @@ class PackedFlashInternLm1D(nn.Module):
            assert len(indexes) == 1
            # The indexes are used to indicate the actual position IDs of each token in the packed input.
            indexes = indexes[0]
-           # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
-           if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
+           # if the sequence parallel mode is 'intern', the indexes should also be split in sequence dimension.
+           if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
                indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
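
The comment above is the behavioral point of this hunk: under the 'intern' sequence parallel mode each tensor-parallel rank holds only its slice of the packed sequence, so the position indexes must be sliced the same way before they reach rotary embedding. A rough, self-contained illustration of the intended semantics (torch.chunk stands in for split_forward_gather_backward; the sizes and rank are assumptions):

    import torch

    tp_world_size, tp_rank = 4, 1            # assumed tensor/sequence parallel group and this rank
    indexes = torch.arange(8192)             # position ids of the packed sequence
    hidden_states = torch.randn(8192, 4096)  # full packed activations, split per rank below

    local_indexes = torch.chunk(indexes, tp_world_size, dim=0)[tp_rank]
    local_hidden = torch.chunk(hidden_states, tp_world_size, dim=0)[tp_rank]
    assert local_indexes.shape[0] == local_hidden.shape[0] == 8192 // tp_world_size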
@@ -175,7 +175,7 @@ class MHA(nn.Module):
        use_flash_attn: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
-       tp_mode: str = "origin_tp",
+       sp_mode: str = "none",
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
@@ -203,7 +203,7 @@ class MHA(nn.Module):
        self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

        # notice here should change bias=True
-       Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+       Wqkv_cls = FSTPLinear if sp_mode == "intern" else ColumnParallelLinearTorch
        self.Wqkv = Wqkv_cls(
            embed_dim,
            3 * embed_dim,
@@ -219,12 +219,12 @@ class MHA(nn.Module):
            self.inner_cross_attn = inner_cross_attn_cls(
                causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
            )
-           if tp_mode == "fstp":
+           if sp_mode == "intern":
                self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
                self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

        # output projection always have the bias (for now)
-       out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
+       out_proj_cls = FSTPLinear if sp_mode == "intern" else RowParallelLinearTorch
        self.out_proj = out_proj_cls(
            embed_dim,
            embed_dim,
@@ -10,7 +10,10 @@ from torch.optim import Optimizer

from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
+from internlm.model.utils import (
+    release_reduce_scatter_memory_pool,
+    split_forward_gather_backward,
+)
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
    BucketStore,
@@ -40,8 +43,20 @@ from .utils import compute_norm
inf = math.inf
logger = get_logger(__file__)


def print_memory(msg):
-   print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
+   print(
+       msg,
+       " rank = ",
+       gpc.get_global_rank(),
+       " memory allocated: ",
+       torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+       " reserved memory: ",
+       torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
+       " max memory: ",
+       torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
+       flush=True,
+   )
    print("===========================================")
@@ -69,8 +84,8 @@ class HybridZeroOptimizer(BaseOptimizer):
        backoff_factor = grad_scal_cfg.backoff_factor
        hysteresis = grad_scal_cfg.hysteresis
        max_scale = grad_scal_cfg.max_scale

-       if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+       if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
            self._fstp_handler = gpc.config.fstp_handler

        # Zero related args
@@ -306,7 +321,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                param=param,
                reduce_rank=reduce_rank,
            )

            reduce_scatter_checker = partial(
                self._wait_reduce_scatter_and_accumulate_grad,
                param=param,
@@ -354,7 +369,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                    _param.grad.add_(_grad)
                    # self._fstp_handler.reduce_scatter_handlers[key] = None
                    # del _grad
-                   release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
+                   release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
                    del self._fstp_handler.reduce_scatter_handlers[key]
                    self._fstp_handler.reduce_scatter_handlers[key] = None
                    assert key in self._fstp_handler.reduce_scatter_handlers
@@ -374,7 +389,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                # assert key in self._fstp_handler.all_reduce_handlers

                bucket.reset_by_rank(rank)

    def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None):
        param_size = param.numel()

|
|||
_param.grad.add_(_grad)
|
||||
# self._fstp_handler.reduce_scatter_handlers[key] = None
|
||||
# del _grad
|
||||
release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
|
||||
release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
|
||||
del self._fstp_handler.reduce_scatter_handlers[key]
|
||||
self._fstp_handler.reduce_scatter_handlers[key] = None
|
||||
assert key in self._fstp_handler.reduce_scatter_handlers
|
||||
|
||||
|
||||
# if not hasattr(_param, "_fstp_all_reduce_str"):
|
||||
# continue
|
||||
|
||||
|
@@ -418,7 +433,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                # assert key in self._fstp_handler.all_reduce_handlers

                current_bucket.reset_by_rank(reduce_rank)

            current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
            current_bucket.add_param(param, reduce_rank)

@@ -685,16 +700,16 @@ class HybridZeroOptimizer(BaseOptimizer):
        timer("sync_grad").start()
        self._sync_grad()
        timer("sync_grad").stop()

        print_memory("No 4")

        try:
-           res = self._step(closure=closure, norms=total_norms)
+           res = self._step(closure=closure, norms=total_norms)
        except torch.cuda.OutOfMemoryError as e:
            print(e, flush=True)
            print(torch.cuda.memory_summary(), flush=True)
            torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")

        return res

    def _step(self, closure=None, norms=None):
@@ -822,7 +837,7 @@ class HybridZeroOptimizer(BaseOptimizer):
            torch.cuda.synchronize()
            with torch.cuda.stream(self._comm_bcast_stream):
                self.broadcast_params()

            timer("step").stop()

            # update gradients may not be needed here, because the sync_params function is used in initialization,
@@ -110,9 +110,8 @@ def initialize_model():

    gpc.config.fstp_handler = None

-   if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
+   if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
        handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
        # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
        handler._register_sync_parameters_hook()
        gpc.config.fstp_handler = handler
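
The internals of CoarseGrainedFSTPAllGatherSyncHandler are not part of this diff; conceptually, a handler of this kind registers module hooks so the all-gather for an upcoming block's sharded weights can be launched while the current block is still computing, which is what intern_overlap=True enables. A generic, self-contained illustration of the hook mechanism only, not the actual handler:

    import torch
    from torch import nn

    blocks = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16), nn.Linear(16, 16))

    def prefetch_hook(module, inputs):
        # a real handler would issue an async all-gather for the next block's sharded weights here
        print(f"running {module.__class__.__name__}; prefetch for the next block could start now")

    for block in blocks:
        block.register_forward_pre_hook(prefetch_hook)

    blocks(torch.randn(2, 16))  # the hook fires before each block's forward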
@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
def switch_sequence_parallel_mode():
    prev_mode = gpc.config.parallel.sequence_parallel
    try:
-       if gpc.config.parallel["tensor"]["mode"] == "fstp":
+       if gpc.config.parallel["tensor"]["sp"] == "intern":
            gpc.config.parallel.sequence_parallel = True
        else:
            gpc.config.parallel.sequence_parallel = False
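
switch_sequence_parallel_mode saves the flag, overrides it for the duration of evaluation, and presumably restores it in the part of the function not shown in this hunk. A self-contained sketch of the same save/override/restore pattern, with a plain dict standing in for gpc.config.parallel:

    from contextlib import contextmanager

    parallel_cfg = {"sequence_parallel": True, "tensor": {"sp": "intern"}}

    @contextmanager
    def switch_sequence_parallel_mode(cfg):
        prev_mode = cfg["sequence_parallel"]
        try:
            # only the 'intern' sp mode keeps sequence parallel enabled during evaluation
            cfg["sequence_parallel"] = cfg["tensor"]["sp"] == "intern"
            yield
        finally:
            cfg["sequence_parallel"] = prev_mode

    with switch_sequence_parallel_mode(parallel_cfg):
        assert parallel_cfg["sequence_parallel"] is True  # 'intern' keeps it on
    assert parallel_cfg["sequence_parallel"] is True      # restored afterwards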
@@ -106,7 +106,7 @@ def evaluate_on_val_dls(
                total_val_bsz = len(batch[1])
                assert total_val_bsz % data_cfg.micro_bsz == 0
                num_microbatches = total_val_bsz // data_cfg.micro_bsz
-               if gpc.config.parallel["tensor"]["mode"] == "fstp":
+               if gpc.config.parallel["tensor"]["sp"] == "intern":
                    sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                    tensor_shape = torch.Size(
                        [