refactor linear

pull/456/head
yingtongxiong 2023-10-20 17:50:56 +08:00
parent ed7232777a
commit dcd89ed304
8 changed files with 356 additions and 299 deletions

View File

@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=8, mode="fstp", overlap=True),
tensor=dict(size=8, sp="intern", intern_overlap=True),
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True,
)

View File

@ -19,25 +19,26 @@ from internlm.model.utils import (
all_gather_raw_memory_pool,
fstp_fused_dense_func,
fused_dense_func_torch,
megatron_fused_dense_func_torch,
)
class ScaleColumnParallelLinear(nn.Linear):
class BaseScaleColumnParallelLinear(nn.Linear):
"""
ScaleColumnParallelLinear.
Base class for ScaleColumnParallelLinear.
Args:
in_features (int): size of each input sample
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul.
If not, then the input is already gathered.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
weight_scale (int): For training stability. 1 by default.
Args:
in_features (int): size of each input sample
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul.
If not, then the input is already gathered.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
weight_scale (int): For training stability. 1 by default.
"""
def __init__(
@ -57,6 +58,10 @@ class ScaleColumnParallelLinear(nn.Linear):
self.process_group = process_group
self.weight_scale = weight_scale
class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
"""
ScaleColumnParallelLinear in flash implementation.
"""
def forward(self, input, gather_dim=0): # pylint: disable=W0622
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
@ -74,6 +79,27 @@ class ScaleColumnParallelLinear(nn.Linear):
gather_dim=gather_dim,
)
class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
"""
ScaleColumnParallelLinear in megatron implementation.
"""
def forward(self, input, gather_dim=0): # pylint: disable=W0622
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
if self.weight_scale != 1:
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
else:
weight = self.weight
return megatron_fused_dense_func_torch(
input,
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.parallel.sequence_parallel,
gather_dim=gather_dim,
)
class RewardModelLinear(ScaleColumnParallelLinear):
"""
@ -129,7 +155,6 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
return fused_dense_func_torch(
x,
self.weight,
@ -139,6 +164,19 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
gather_dim=gather_dim,
)
class MegatronColumnParallelLinearTorch(ColumnParallelLinear):
def forward(self, x, gather_dim=0):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
return megatron_fused_dense_func_torch(
x,
self.weight,
self.bias,
process_group=self.process_group,
sequence_parallel=self.sequence_parallel,
gather_dim=gather_dim,
)
class RowParallelLinearTorch(RowParallelLinear):
def forward(self, x):
@ -150,10 +188,20 @@ class RowParallelLinearTorch(RowParallelLinear):
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group)
class MegatronRowParallelLinearTorch(RowParallelLinear):
def forward(self, x):
"""
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
a reduce_scatter of the result.
"""
out = megatron_fused_dense_func_torch(x, self.weight, self.bias)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group)
class FeedForward(nn.Module):
class BaseFeedForward(nn.Module):
"""
FeedForward.
Base FeedForward in flash implementation.
Args:
in_features (int): size of each input sample
@ -177,13 +225,13 @@ class FeedForward(nn.Module):
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
multiple_of: int = 256,
block_idx: int = 0,
colum_cls = None,
row_cls = None,
):
super().__init__()
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinearTorch(
self.w1 = colum_cls(
in_features,
hidden_features,
process_group,
@ -192,7 +240,7 @@ class FeedForward(nn.Module):
device=device,
dtype=dtype,
)
self.w2 = ColumnParallelLinearTorch(
self.w2 = colum_cls(
in_features,
hidden_features,
process_group,
@ -201,7 +249,7 @@ class FeedForward(nn.Module):
device=device,
dtype=dtype,
)
self.w3 = RowParallelLinearTorch(
self.w3 = row_cls(
hidden_features,
out_features,
process_group,
@ -217,21 +265,9 @@ class FeedForward(nn.Module):
out = self.w3(Silu(w1_o, w2_o))
return out
class FSTPLinear(ColumnParallelLinear):
def forward(self, x):
block_index = gpc.config.fstp_handler.module_to_index[self]
name_index = gpc.config.fstp_handler.module_name_index[self]
name = gpc.config.fstp_handler.module_name[name_index]
return fstp_fused_dense_func(
x, self.weight, self.bias, process_group=self.process_group,
module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name
)
class FSTPFeedForward(nn.Module):
class FeedForward(BaseFeedForward):
"""
FeedForward.
FeedForward in flash implementation.
Args:
in_features (int): size of each input sample
@ -255,169 +291,106 @@ class FSTPFeedForward(nn.Module):
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
multiple_of: int = 256,
block_idx: int = 0,
):
super().__init__()
super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
dtype, multiple_of, ColumnParallelLinearTorch, RowParallelLinearTorch)
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
self.w1 = FSTPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w2 = FSTPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w3 = FSTPLinear(
hidden_features,
out_features,
process_group,
bias=bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
class MegatronFeedForward(BaseFeedForward):
"""
FeedForward in megatron implementation.
Args:
in_features (int): size of each input sample
hidden_features (int): size of hidden state of FFN
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
"""
def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int = None,
process_group: Optional[torch.distributed.ProcessGroup] = None,
bias: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
multiple_of: int = 256,
):
super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
dtype, multiple_of, MegatronColumnParallelLinearTorch, MegatronRowParallelLinearTorch)
class FSTPLinear(ColumnParallelLinear):
def forward(self, x):
w1_o = self.w1(x)
w2_o = self.w2(x)
out = self.w3(F.silu(w1_o) * w2_o)
return out
block_index = gpc.config.fstp_handler.module_to_index[self]
name_index = gpc.config.fstp_handler.module_name_index[self]
name = gpc.config.fstp_handler.module_name[name_index]
return fstp_fused_dense_func(
x, self.weight, self.bias, process_group=self.process_group,
module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name
)
class FSTPAllGatherSyncHandler:
class FSTPFeedForward(BaseFeedForward):
"""
All-gather handler for overlapping the all-gather in adjcent FSTP linear.
FeedForward in FSTP.
Args:
in_features (int): size of each input sample
hidden_features (int): size of hidden state of FFN
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
"""
def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
# import pdb; pdb.set_trace()
self.process_group = process_group
self.FSTP_modules = []
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward
self.module_handler = dict() # key: FSTP module; value: all-gather handler
self.module_block = dict() # key: FSTP module; value: transformer block index
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int = None,
process_group: Optional[torch.distributed.ProcessGroup] = None,
bias: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
multiple_of: int = 256,
):
super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
dtype, multiple_of, FSTPLinear, FSTPLinear)
self.reduce_scatter_handlers = {}
self.all_reduce_handlers = {}
# just want to share same for loop for ModuleList and Module
if not isinstance(model, nn.ModuleList):
model = [model]
for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model
for _chunk_name, children in _chunk.named_children():
if isinstance(children, nn.ModuleList):
for idx, block in enumerate(children):
index = 0
self.block_module[idx] = {}
for _sub_name, sub in block.named_children():
sub_modules = list(sub.children())
if len(sub_modules) > 0:
for name, child in sub.named_children():
if isinstance(child, FSTPLinear):
_full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
if child.bias is not None:
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
self.FSTP_modules.append(child)
self.module_block[child] = idx
self.block_module[idx][index] = child
self.module_name_index[child] = index
index = index + 1
else:
continue
def _register_sync_parameters_hook(self) -> None:
"""
register pre_forward_hook and pre_backward_hook for FSTPLinear.
"""
def _pre_forward_hook(module: nn.Module, inputs: Any):
block_index = self.module_block[module]
name_index = self.module_name_index[module]
if name_index == 0:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight
# start the all-gather for next module
next_module = self.block_module[block_index][name_index + 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
else:
handler = self.module_handler[module]
handler.wait()
if name_index != 4:
next_module = self.block_module[block_index][name_index + 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
def _post_forward_hook(module: nn.Module, input, output):
if module in self.FSTP_global_weights:
del self.FSTP_global_weights[module]
if module in self.module_handler:
del self.module_handler[module]
def _pre_backward_hook(module: nn.Module, grad_output):
block_index = self.module_block[module]
name_index = self.module_name_index[module]
if name_index == 4:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight
# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
else:
handler = self.module_handler[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
def _post_backward_hook(module, grad_input, grad_output):
del self.FSTP_global_weights[module]
for module in self.FSTP_modules:
# import pdb; pdb.set_trace()
module.register_forward_pre_hook(_pre_forward_hook)
module.register_forward_hook(_post_forward_hook)
# module.register_backward_pre_hook(_pre_backward_hook)
# module.register_backward_hook(_post_backward_hook)
module.register_full_backward_pre_hook(_pre_backward_hook)
module.register_full_backward_hook(_post_backward_hook)
def get_mlp_cls(sp_mode: str):
if sp_mode in ["none", "flash-attn"]:
mlp_cls = FeedForward
elif sp_mode == "megatron":
mlp_cls = MegatronFeedForward
else:
mlp_cls = FSTPFeedForward
return mlp_cls
def get_linear_cls(sp_mode: str, parallel_mode: str):
if parallel_mode == "column":
if sp_mode in ["none", "flash-attn"]:
cls = ColumnParallelLinearTorch
elif sp_mode == "megatron":
cls = MegatronColumnParallelLinearTorch
else:
cls = FSTPLinear
elif parallel_mode == 'row':
if sp_mode in ["none", "flash-attn"]:
cls = RowParallelLinearTorch
elif sp_mode == "megatron":
cls = MegatronRowParallelLinearTorch
else:
cls = FSTPLinear
return cls
class CoarseGrainedFSTPAllGatherSyncHandler:
"""
@ -468,7 +441,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
sub_modules = list(sub.children())
if len(sub_modules) > 0:
for name, child in sub.named_children():
# print(f"name: {name}", flush=True)
if name == "out_proj":
self.FSTP_outs.append(child)
self.module_to_index[child] = idx

View File

@ -15,9 +15,12 @@ from internlm.initialize.initialize_tensor import normal_, scaled_init_method_no
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
FeedForward,
MegatronFeedForward,
FSTPFeedForward,
RewardModelLinear,
ScaleColumnParallelLinear,
MegatronScaleColumnParallelLinear,
get_mlp_cls,
)
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import (
@ -77,8 +80,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
tp_mode: str = "origin_tp",
block_idx: int = 0,
sp_mode: str = "none",
):
super().__init__()
self.checkpoint = checkpoint
@ -103,8 +105,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_flash_attn=use_flash_attn,
device=device,
dtype=dtype,
tp_mode=tp_mode,
block_idx=block_idx,
sp_mode=sp_mode,
)
self.dropout1 = nn.Dropout(drop_rate)
@ -116,7 +117,7 @@ class PackedFlashBaseLayer1D(nn.Module):
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
if use_swiglu:
mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
mlp_cls = get_mlp_cls(sp_mode)
self.mlp = mlp_cls(
hidden_size,
int(hidden_size * mlp_ratio),
@ -299,12 +300,16 @@ class PackedFlashInternLm1D(nn.Module):
super().__init__()
checkpoint_layer_num = int(num_layers * checkpoint)
self.tp_mode = gpc.config.parallel["tensor"]["mode"]
self.sp_mode = gpc.config.parallel["tensor"]["sp"]
if self.sp_mode == "none":
gpc.config.parallel.sequence_parallel = False
else:
gpc.config.parallel.sequence_parallel = True
if is_reward:
head_cls = RewardModelLinear
else:
head_cls = ScaleColumnParallelLinear
head_cls = ScaleColumnParallelLinear if self.sp_mode in ["flash-attn", "none", "intern"] else MegatronScaleColumnParallelLinear
if first:
if embed_split_hidden:
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@ -345,8 +350,7 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
tp_mode=self.tp_mode,
block_idx=lid,
sp_mode=self.sp_mode,
)
for lid in range(num_layers)
]
@ -393,7 +397,7 @@ class PackedFlashInternLm1D(nn.Module):
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
# if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

View File

@ -42,6 +42,9 @@ from internlm.model.linear import (
ColumnParallelLinearTorch,
FSTPLinear,
RowParallelLinearTorch,
MegatronColumnParallelLinearTorch,
MegatronRowParallelLinearTorch,
get_linear_cls,
)
@ -175,8 +178,7 @@ class MHA(nn.Module):
use_flash_attn: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tp_mode: str = "origin_tp",
block_idx: int = 0,
sp_mode: str = "none",
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
@ -204,7 +206,7 @@ class MHA(nn.Module):
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
# notice here should change bias=True
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
Wqkv_cls = get_linear_cls(sp_mode, "column")
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
@ -220,12 +222,12 @@ class MHA(nn.Module):
self.inner_cross_attn = inner_cross_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
if tp_mode == "fstp":
if sp_mode == "intern":
self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)
# output projection always have the bias (for now)
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
out_proj_cls = get_linear_cls(sp_mode, 'row')
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,

View File

@ -164,7 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFunc(torch.autograd.Function):
"tp fused dense function"
"FusedDenseFunc for tensor parallel in flash-attn implementation."
@staticmethod
@custom_fwd
@ -255,9 +255,96 @@ class FusedDenseFunc(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None, None
class MegatronFusedDenseFunc(torch.autograd.Function):
'''
FusedDenseFunc for tensor parallel in megatron implementation.
The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron,
so that the all-gather in backward is ommited.
'''
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True, gather_dim=0):
"""
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
"""
ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual
ctx.process_group = process_group
ctx.sequence_parallel = sequence_parallel
if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
x = x.contiguous()
if process_group is not None and sequence_parallel:
# We want to kick off the all_gather early, before weight dtype conversion
total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim)
else:
total_x = x
if torch.is_autocast_enabled():
weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
weight = weight.contiguous()
if process_group is not None and sequence_parallel:
handle_x.wait()
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
batch_dim = batch_shape.numel()
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
if min(batch_dim, n, *weight.shape) > 65535 * 32:
raise RuntimeError("fused_dense only supports matrix dims <= 2M")
output = F.linear(total_x, weight, bias)
if ctx.compute_weight_gradient:
ctx.save_for_backward(total_x, weight)
else:
ctx.save_for_backward(weight)
return output if not return_residual else (output, x)
@staticmethod
@custom_bwd
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
(grad_input,) = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
if ctx.compute_weight_gradient:
total_x, weight = ctx.saved_tensors
else:
(weight,) = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_output, weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
else:
grad_input = None
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
)
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFuncTorch(FusedDenseFunc):
"""A custom PyTorch module extending FusedDenseFunc."""
'''FusedDenseFunc in flash implementation for supporting torch.float32'''
@staticmethod
@custom_bwd
@ -307,17 +394,61 @@ class FusedDenseFuncTorch(FusedDenseFunc):
handle_grad_input.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
class MegatronFusedDenseFuncTorch(FusedDenseFunc):
'''FusedDenseFunc in megatron implementation for supporting torch.float32'''
@staticmethod
@custom_bwd
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
(grad_input,) = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
gather_dim = ctx.gather_dim
if ctx.compute_weight_gradient:
total_x, weight = ctx.saved_tensors
else:
(weight,) = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_output, weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
else:
grad_input = None
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
# we remove the cuda independence, which is different from flash_attn.
grad_weight, grad_bias = linear_bias_wgrad_torch(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
)
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
class FSTPFusedDenseFunc(torch.autograd.Function):
"FSTP fused dense function"
"FusedDenseFunc for FSTP, which is optimized based on flash implementation."
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None, block_index=None, module_name=None):
def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, overlap_handler=None, block_index=None, module_name=None):
ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual
ctx.process_group = process_group
ctx.all_gather_handler = all_gather_handler
ctx.overlap_handler = overlap_handler
ctx.module = module
ctx.block_index = block_index
ctx.module_name = module_name
@ -329,13 +460,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
world_size = gpc.get_world_size(ParallelMode.TENSOR)
if world_size > 1:
# do all_gather for weight and bias before actual computation
if all_gather_handler is not None:# and module in all_gather_handler.FSTP_global_weights:
# total_weight = all_gather_handler.FSTP_global_weights[module]
total_weight = gpc.config.block_memory[block_index % 2][module_name]
if overlap_handler is not None:
total_weight = gpc.config.block_memory[block_index % 2][module_name]
else:
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_weight.wait()
# TODO memory pool for bias
if bias is not None:
total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
handle_bias.wait()
@ -356,6 +486,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
raise RuntimeError("fused_dense only supports matrix dims <= 2M")
output = F.linear(total_x, total_weight, total_bias)
# release memory
del total_weight
del total_bias
if ctx.compute_weight_gradient:
@ -372,8 +503,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
(grad_input,) = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
all_gather_handler = ctx.all_gather_handler
module = ctx.module
overlap_handler = ctx.overlap_handler
block_index = ctx.block_index
module_name = ctx.module_name
@ -389,51 +519,35 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
world_size = gpc.get_world_size(ParallelMode.TENSOR)
if world_size > 1:
total_weight = gpc.config.block_memory[block_index % 2][module_name]
# # do all-gather for weight before backward
# if module in all_gather_handler.FSTP_global_weights:
# total_weight = all_gather_handler.FSTP_global_weights[module]
# else:
# total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
# handle_weight.wait()
if overlap_handler is not None:
total_weight = gpc.config.block_memory[block_index % 2][module_name]
else:
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_weight.wait()
else:
total_weight = weight
# compute weight grad
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
)
if world_size > 1:
if gpc.config.fstp_handler is not None:
# grad_weight_async, handle_grad_weight = all_reduce_raw(grad_weight, process_group, async_op=True)
# assert hasattr(weight, "_fstp_all_reduce_str")
# all_gather_handler.all_reduce_handlers[weight._fstp_all_reduce_str] = (handle_grad_weight, grad_weight_async)
# grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
# if grad_bias is not None:
# grad_bias_async, handle_grad_bias = all_reduce_raw(grad_bias, process_group, async_op=True)
# assert hasattr(bias, "_fstp_all_reduce_str")
# all_gather_handler.all_reduce_handlers[bias._fstp_all_reduce_str] = (handle_grad_bias, grad_bias_async)
# grad_bias = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
if overlap_handler is not None:
grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True)
assert hasattr(weight, "_fstp_reduce_scatter_str")
all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
if grad_bias is not None:
grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True)
assert hasattr(bias, "_fstp_reduce_scatter_str")
all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
grad_bias = all_gather_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
grad_bias = overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
else:
grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
if grad_bias is not None:
grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
# grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
# if grad_bias is not None:
# grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
@ -449,7 +563,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
del total_weight
if ctx.needs_input_grad[1]:
if world_size > 1 and gpc.config.fstp_handler is None:
if world_size > 1 and overlap_handler is None:
handle_grad_weight.wait()
if grad_bias is not None:
handle_grad_bias.wait()
@ -473,6 +587,22 @@ def fused_dense_func_torch(
else:
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
def megatron_fused_dense_func_torch(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
return_residual: bool = False,
process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True,
gather_dim: int = 0,
):
dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
x.dtype == torch.float32 and torch.is_autocast_enabled()
)
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return MegatronFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
else:
return MegatronFusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
def fstp_fused_dense_func(
x: Tensor,

View File

@ -40,11 +40,6 @@ from .utils import compute_norm
inf = math.inf
logger = get_logger(__file__)
def print_memory(msg):
print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
print("===========================================")
class HybridZeroOptimizer(BaseOptimizer):
"""
Hybrid Zero Optimizer.
@ -70,7 +65,7 @@ class HybridZeroOptimizer(BaseOptimizer):
hysteresis = grad_scal_cfg.hysteresis
max_scale = grad_scal_cfg.max_scale
if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
self._fstp_handler = gpc.config.fstp_handler
# Zero related args
@ -358,20 +353,7 @@ class HybridZeroOptimizer(BaseOptimizer):
del self._fstp_handler.reduce_scatter_handlers[key]
self._fstp_handler.reduce_scatter_handlers[key] = None
assert key in self._fstp_handler.reduce_scatter_handlers
# if not hasattr(_param, "_fstp_all_reduce_str"):
# continue
# key = getattr(_param, "_fstp_all_reduce_str")
# comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
# comm_handle.wait()
# with torch.no_grad():
# _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
# _param.grad.add_(_grad)
# # self._fstp_handler.reduce_scatter_handlers[key] = None
# del _grad
# del self._fstp_handler.all_reduce_handlers[key]
# self._fstp_handler.all_reduce_handlers[key] = None
# assert key in self._fstp_handler.all_reduce_handlers
bucket.reset_by_rank(rank)
@ -402,21 +384,6 @@ class HybridZeroOptimizer(BaseOptimizer):
self._fstp_handler.reduce_scatter_handlers[key] = None
assert key in self._fstp_handler.reduce_scatter_handlers
# if not hasattr(_param, "_fstp_all_reduce_str"):
# continue
# key = getattr(_param, "_fstp_all_reduce_str")
# comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
# comm_handle.wait()
# with torch.no_grad():
# _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
# _param.grad.add_(_grad)
# # self._fstp_handler.reduce_scatter_handlers[key] = None
# del _grad
# del self._fstp_handler.all_reduce_handlers[key]
# self._fstp_handler.all_reduce_handlers[key] = None
# assert key in self._fstp_handler.all_reduce_handlers
current_bucket.reset_by_rank(reduce_rank)
current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
@ -634,7 +601,6 @@ class HybridZeroOptimizer(BaseOptimizer):
# if not overlapping communication (no reduction hook is attached)
# we need to manually reduce these gradients
print_memory("No 1")
if not self._overlap_sync_grad:
for group_id in range(len(self._fp16_param_groups)):
for param in self._fp16_param_groups[group_id]:
@ -659,7 +625,6 @@ class HybridZeroOptimizer(BaseOptimizer):
bucket.empty()
self._bucket_in_progress = []
self._param_store.clear_grads_of_previous_reduced_params()
print_memory("No 2")
# compute norm for gradients in the last bucket
total_norms = {}
for group_id in range(self.num_param_groups):
@ -681,19 +646,11 @@ class HybridZeroOptimizer(BaseOptimizer):
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=pg)
total_norms[group_name] = scaled_norm_tensor.item()
print_memory("No 3")
timer("sync_grad").start()
self._sync_grad()
timer("sync_grad").stop()
print_memory("No 4")
try:
res = self._step(closure=closure, norms=total_norms)
except torch.cuda.OutOfMemoryError as e:
print(e, flush=True)
print(torch.cuda.memory_summary(), flush=True)
torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
res = self._step(closure=closure, norms=total_norms)
return res
@ -740,7 +697,6 @@ class HybridZeroOptimizer(BaseOptimizer):
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, norms
print_memory("No 5")
# copy the grad of fp16 param to fp32 param
single_grad_partition_groups = []
for group_id in range(self.num_param_groups):
@ -781,7 +737,6 @@ class HybridZeroOptimizer(BaseOptimizer):
single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
print_memory("No 6")
# unscale and clip grads
# get the global norm
global_norm_groups = {}
@ -804,12 +759,9 @@ class HybridZeroOptimizer(BaseOptimizer):
# For those ranks that are not assigned parameters, we just wait for other ranks
# to send them updated their own parameters.
if self.has_params:
print_memory("No 7")
self.optim.step()
print_memory("No 8")
# release the fp32 grad
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
print_memory("No 9")
# update fp16 partition updated by the current rank
for group_id in range(len(self._fp16_param_groups)):
if self.param_group_has_params[group_id]:
@ -818,7 +770,6 @@ class HybridZeroOptimizer(BaseOptimizer):
)
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
print_memory("No 10")
torch.cuda.synchronize()
with torch.cuda.stream(self._comm_bcast_stream):
self.broadcast_params()
@ -829,7 +780,6 @@ class HybridZeroOptimizer(BaseOptimizer):
# so synchronization is maintained
for group_name, global_norm in global_norm_groups.items():
global_norm_groups[group_name] = global_norm / loss_scale
print_memory("No 11")
return True, global_norm_groups
def broadcast_params(self):

View File

@ -38,7 +38,6 @@ from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
CoarseGrainedFSTPAllGatherSyncHandler,
FeedForward,
FSTPAllGatherSyncHandler,
RewardModelLinear,
ScaleColumnParallelLinear,
)
@ -111,7 +110,7 @@ def initialize_model():
gpc.config.fstp_handler = None
if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
# handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler._register_sync_parameters_hook()

View File

@ -195,7 +195,7 @@ def main(args):
# start iterating the train data and begin training
for batch_count in range(train_state.batch_count, total_steps):
empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
torch.cuda.memory._record_memory_history()
# torch.cuda.memory._record_memory_history()
start_time = time.time()
timer("one-batch").start()
@ -300,7 +300,7 @@ def main(args):
if gpc.config.fstp_handler is not None:
gpc.config.fstp_handler.zero_const_pool = {}
gpc.config.fstp_handler.reduce_scatter_memory = {}
torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
# torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
torch.cuda.reset_peak_memory_stats()
ckpt_manager.wait_async_upload_finish()