Merge pull request #4 from yingtongxiong/fstp/refactor-config

feat(initialize/launch.py): refactor config for fstp
pull/456/head
ytxiong 2023-10-20 17:48:20 +08:00 committed by GitHub
commit f22e5b3b28
10 changed files with 41 additions and 38 deletions

View File

@ -152,19 +152,19 @@ zero1 parallel (dict):
2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
tensor parallel (dict):
1. size: int, the size of tensor parallel.
2. mode: str, the mode should be 'origin_tp' or 'fstp', defaults to 'origin_tp'. If the mode is 'fstp',
the sequence_parallel should be True.
2. sp: str, the sequence parallel mode, should be one of ['none', 'megatron', 'flash-attn', 'intern'],
defaults to 'none', which means sequence parallel is disabled.
3. intern_overlap: bool, enable/disable all_gather/reduce_scatter communication overlap when using the
'intern' sequence parallel mode, defaults to False.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
sequence parallel (bool): enable/disable sequence parallel, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=8, mode="fstp", overlap=True),
tensor=dict(size=8, sp="intern", intern_overlap=True),
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True,
)
cudnn_deterministic = False
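The net effect of this hunk is a field rename plus a split of the old `overlap` flag. A minimal sketch of the mapping, assuming the semantics shown in the diff (`upgrade_tensor_parallel_cfg` is a hypothetical helper, not part of this PR):

```python
# Hypothetical helper illustrating how the pre-refactor tensor-parallel fields
# map onto the new schema; not part of this PR.
def upgrade_tensor_parallel_cfg(old: dict) -> dict:
    new = dict(size=old.get("size", 1))
    if old.get("mode", "origin_tp") == "fstp":
        # the old 'fstp' mode is now expressed as the 'intern' sequence parallel mode
        new["sp"] = "intern"
        new["intern_overlap"] = bool(old.get("overlap", False))
    else:
        # 'origin_tp' keeps plain tensor parallelism; no FSTP-style overlap
        new["sp"] = "none"
        new["intern_overlap"] = False
    return new

# upgrade_tensor_parallel_cfg(dict(size=8, mode="fstp", overlap=True))
# -> {'size': 8, 'sp': 'intern', 'intern_overlap': True}
```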

View File

@ -306,15 +306,20 @@ def args_sanity_check():
), "sequence parallel does not support use_flash_attn=False"
if isinstance(gpc.config.parallel["tensor"], int):
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode="origin_tp")
if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = "origin_tp"
if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
assert (
gpc.config.parallel.sequence_parallel is True
), "when the tp_mode is fstp, the sequence_parallel should be True."
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], sp="none", intern_overlap=False)
if gpc.config.parallel["tensor"].get("sp", None) is None:
gpc.config.parallel["tensor"]["sp"] = "none"
if gpc.config.parallel["tensor"].get("intern_overlap", None) is None:
gpc.config.parallel["tensor"]["intern_overlap"] = False
assert gpc.config.parallel["tensor"].get("sp", None) in [
"none",
"megatron",
"flash-attn",
"intern",
], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported"
# adapt to old version's sequence parallel config
if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]:
gpc.config.parallel.sequence_parallel = True
# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
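Taken together, the new branch normalizes any legacy config into the new schema. A self-contained sketch of the same logic, written against a plain dict instead of gpc.config (the function name is illustrative):

```python
VALID_SP_MODES = ("none", "megatron", "flash-attn", "intern")

def normalize_parallel_cfg(parallel: dict) -> dict:
    """Sketch of the tensor-parallel normalization done in args_sanity_check."""
    tensor = parallel["tensor"]
    if isinstance(tensor, int):
        # very old configs only give the tensor parallel size
        tensor = dict(size=tensor, sp="none", intern_overlap=False)
    tensor.setdefault("sp", "none")
    tensor.setdefault("intern_overlap", False)
    assert tensor["sp"] in VALID_SP_MODES, f"invalid sp mode {tensor['sp']!r}"
    parallel["tensor"] = tensor
    # any real sequence parallel mode implies sequence_parallel=True
    if tensor["sp"] != "none":
        parallel["sequence_parallel"] = True
    return parallel
```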

View File

@ -77,7 +77,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
tp_mode: str = "origin_tp",
sp_mode: str = "none",
):
super().__init__()
self.checkpoint = checkpoint
@ -102,7 +102,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_flash_attn=use_flash_attn,
device=device,
dtype=dtype,
tp_mode=tp_mode,
sp_mode=sp_mode,
)
self.dropout1 = nn.Dropout(drop_rate)
@ -114,7 +114,7 @@ class PackedFlashBaseLayer1D(nn.Module):
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
if use_swiglu:
mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
mlp_cls = FSTPFeedForward if sp_mode == "intern" else FeedForward
self.mlp = mlp_cls(
hidden_size,
int(hidden_size * mlp_ratio),
@ -297,7 +297,7 @@ class PackedFlashInternLm1D(nn.Module):
super().__init__()
checkpoint_layer_num = int(num_layers * checkpoint)
self.tp_mode = gpc.config.parallel["tensor"]["mode"]
self.sp_mode = gpc.config.parallel["tensor"]["sp"]
if is_reward:
head_cls = RewardModelLinear
@ -343,7 +343,7 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
tp_mode=self.tp_mode,
sp_mode=self.sp_mode,
)
for lid in range(num_layers)
]
@ -389,8 +389,8 @@ class PackedFlashInternLm1D(nn.Module):
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
# if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
# if the sequence parallel mode is 'intern', the indexes should also be split along the sequence dimension.
if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
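To make the comment above concrete: with the 'intern' mode, each tensor-parallel rank keeps only its local slice of the packed position ids. A forward-only sketch (the real code uses split_forward_gather_backward so gradients are gathered on the backward pass; local_sequence_slice is a hypothetical helper):

```python
import torch

def local_sequence_slice(indexes: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # split the packed position ids evenly along the sequence dimension and
    # keep only the chunk belonging to this tensor-parallel rank
    return torch.chunk(indexes, world_size, dim=0)[rank]

# with 4 ranks and a packed sequence of length 8:
# local_sequence_slice(torch.arange(8), rank=1, world_size=4) -> tensor([2, 3])
```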

View File

@ -175,7 +175,7 @@ class MHA(nn.Module):
use_flash_attn: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tp_mode: str = "origin_tp",
sp_mode: str = "none",
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
@ -203,7 +203,7 @@ class MHA(nn.Module):
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
# note: bias is intentionally set to True here
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
Wqkv_cls = FSTPLinear if sp_mode == "intern" else ColumnParallelLinearTorch
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
@ -219,12 +219,12 @@ class MHA(nn.Module):
self.inner_cross_attn = inner_cross_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
if tp_mode == "fstp":
if sp_mode == "intern":
self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)
# the output projection always has a bias (for now)
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
out_proj_cls = FSTPLinear if sp_mode == "intern" else RowParallelLinearTorch
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
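Stepping back from the hunk above: only the 'intern' mode swaps in the fully sharded linears and wraps the attention so it accepts sequence-sharded inputs. The table below is purely illustrative (class names are taken from the diff):

```python
# sp mode -> ((Wqkv class, out_proj class), wrap inner attention with DistributedAttention?)
MHA_DISPATCH = {
    "intern":     (("FSTPLinear", "FSTPLinear"), True),
    "none":       (("ColumnParallelLinearTorch", "RowParallelLinearTorch"), False),
    "megatron":   (("ColumnParallelLinearTorch", "RowParallelLinearTorch"), False),
    "flash-attn": (("ColumnParallelLinearTorch", "RowParallelLinearTorch"), False),
}
```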

View File

@ -1,12 +1,12 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Optional, Union
from typing import Optional
import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn.functional as F
from flash_attn.utils.distributed import all_reduce_raw # , reduce_scatter_raw
from flash_attn.utils.distributed import all_reduce_raw
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
@ -397,7 +397,6 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
grad_input = grad_input.contiguous()
process_group = ctx.process_group
all_gather_handler = ctx.all_gather_handler
module = ctx.module
block_index = ctx.block_index
module_name = ctx.module_name

View File

@ -2,8 +2,8 @@
# -*- encoding: utf-8 -*-
import math
from typing import Optional, List
from functools import partial
from typing import List, Optional
import torch
import torch.distributed as dist
@ -11,7 +11,7 @@ from torch.optim import Optimizer
from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
from internlm.model.utils import release_reduce_scatter_memory_pool
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore,
@ -83,7 +83,7 @@ class HybridZeroOptimizer(BaseOptimizer):
hysteresis = grad_scal_cfg.hysteresis
max_scale = grad_scal_cfg.max_scale
if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
self._fstp_handler = gpc.config.fstp_handler
# Zero related args
@ -366,8 +366,8 @@ class HybridZeroOptimizer(BaseOptimizer):
_param.grad.add_(_grad)
# release cuda memory.
release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
self._fstp_handler.reduce_scatter_handlers[_key] = None
_grad = None
bucket.reset_by_rank(reduce_rank)

View File

@ -45,7 +45,7 @@ class BucketStore(BaseStore):
def num_elements_in_bucket(self, reduce_rank: int = None):
return self._num_elements_in_bucket[reduce_rank]
def num_params_in_bucket(self, reduce_rank: int = None):
return len(self._params[reduce_rank])

View File

@ -110,9 +110,8 @@ def initialize_model():
gpc.config.fstp_handler = None
if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
# handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler._register_sync_parameters_hook()
gpc.config.fstp_handler = handler
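The same guard appears here and in HybridZeroOptimizer, so it can be read as one rule: create and register the overlap handler only for 'intern' sp with intern_overlap enabled. A hedged sketch with the handler class passed in (maybe_create_overlap_handler is a hypothetical helper):

```python
def maybe_create_overlap_handler(model, tensor_cfg: dict, tp_group, handler_cls):
    """Return a registered all-gather sync handler, or None when overlap is off."""
    if tensor_cfg.get("sp") == "intern" and tensor_cfg.get("intern_overlap", False):
        handler = handler_cls(model, tp_group)
        handler._register_sync_parameters_hook()  # hook registration as in the diff
        return handler
    return None
```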

View File

@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
def switch_sequence_parallel_mode():
prev_mode = gpc.config.parallel.sequence_parallel
try:
if gpc.config.parallel["tensor"]["mode"] == "fstp":
if gpc.config.parallel["tensor"]["sp"] == "intern":
gpc.config.parallel.sequence_parallel = True
else:
gpc.config.parallel.sequence_parallel = False
@ -106,7 +106,7 @@ def evaluate_on_val_dls(
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
if gpc.config.parallel["tensor"]["mode"] == "fstp":
if gpc.config.parallel["tensor"]["sp"] == "intern":
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = torch.Size(
[
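For reference, the evaluation-time switch boils down to a context manager that forces sequence_parallel on only for the 'intern' mode and restores the previous value afterwards. A self-contained sketch assuming a plain dict config (the real code operates on gpc.config.parallel):

```python
from contextlib import contextmanager

@contextmanager
def sequence_parallel_for_eval(parallel_cfg: dict):
    prev = parallel_cfg["sequence_parallel"]
    try:
        # evaluation keeps sequence parallelism only for the 'intern' sp mode
        parallel_cfg["sequence_parallel"] = parallel_cfg["tensor"]["sp"] == "intern"
        yield
    finally:
        parallel_cfg["sequence_parallel"] = prev
```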

View File

@ -45,7 +45,7 @@ def empty_cache_and_diag(batch_count, interval=50):
# # import time
# # time.sleep(10)
# print(e, "rank = ", gpc.get_global_rank(), flush=True)
# torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
# torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
# do empty_cache after the bench
torch.cuda.empty_cache()