refactor linear

pull/456/head
yingtongxiong 2023-10-20 17:50:56 +08:00
parent ed7232777a
commit dcd89ed304
8 changed files with 356 additions and 299 deletions

View File

@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
""" """
parallel = dict( parallel = dict(
zero1=dict(size=-1, fsdp=False), zero1=dict(size=-1, fsdp=False),
tensor=dict(size=8, mode="fstp", overlap=True), tensor=dict(size=8, sp="intern", intern_overlap=True),
pipeline=dict(size=1, interleaved_overlap=True), pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True, sequence_parallel=True,
) )

View File

@ -19,25 +19,26 @@ from internlm.model.utils import (
all_gather_raw_memory_pool, all_gather_raw_memory_pool,
fstp_fused_dense_func, fstp_fused_dense_func,
fused_dense_func_torch, fused_dense_func_torch,
megatron_fused_dense_func_torch,
) )
class ScaleColumnParallelLinear(nn.Linear): class BaseScaleColumnParallelLinear(nn.Linear):
""" """
ScaleColumnParallelLinear. Base class for ScaleColumnParallelLinear.
Args: Args:
in_features (int): size of each input sample in_features (int): size of each input sample
out_features (int): size of each output sample out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config. in the config.
sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul. we do an all_gather of x before doing the matmul.
If not, then the input is already gathered. If not, then the input is already gathered.
device (Optional[Union[str, torch.device]]): The device will be used. device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data. dtype (Optional[torch.dtype]): The type of data.
weight_scale (int): For training stability. 1 by default. weight_scale (int): For training stability. 1 by default.
""" """
def __init__( def __init__(
@ -57,6 +58,10 @@ class ScaleColumnParallelLinear(nn.Linear):
self.process_group = process_group self.process_group = process_group
self.weight_scale = weight_scale self.weight_scale = weight_scale
class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
"""
ScaleColumnParallelLinear in flash implementation.
"""
def forward(self, input, gather_dim=0): # pylint: disable=W0622 def forward(self, input, gather_dim=0): # pylint: disable=W0622
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul. # we do an all_gather of x before doing the matmul.
@ -74,6 +79,27 @@ class ScaleColumnParallelLinear(nn.Linear):
gather_dim=gather_dim, gather_dim=gather_dim,
) )
class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
"""
ScaleColumnParallelLinear in megatron implementation.
"""
def forward(self, input, gather_dim=0): # pylint: disable=W0622
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
if self.weight_scale != 1:
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
else:
weight = self.weight
return megatron_fused_dense_func_torch(
input,
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.parallel.sequence_parallel,
gather_dim=gather_dim,
)
class RewardModelLinear(ScaleColumnParallelLinear): class RewardModelLinear(ScaleColumnParallelLinear):
""" """
@ -129,7 +155,6 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul. # we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered. # If not, then the input is already gathered.
return fused_dense_func_torch( return fused_dense_func_torch(
x, x,
self.weight, self.weight,
@ -139,6 +164,19 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
gather_dim=gather_dim, gather_dim=gather_dim,
) )
class MegatronColumnParallelLinearTorch(ColumnParallelLinear):
def forward(self, x, gather_dim=0):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
return megatron_fused_dense_func_torch(
x,
self.weight,
self.bias,
process_group=self.process_group,
sequence_parallel=self.sequence_parallel,
gather_dim=gather_dim,
)
class RowParallelLinearTorch(RowParallelLinear): class RowParallelLinearTorch(RowParallelLinear):
def forward(self, x): def forward(self, x):
@ -150,10 +188,20 @@ class RowParallelLinearTorch(RowParallelLinear):
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group) return reduce_fn(out, self.process_group)
class MegatronRowParallelLinearTorch(RowParallelLinear):
def forward(self, x):
"""
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
a reduce_scatter of the result.
"""
out = megatron_fused_dense_func_torch(x, self.weight, self.bias)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group)
class FeedForward(nn.Module):
class BaseFeedForward(nn.Module):
""" """
FeedForward. Base FeedForward in flash implementation.
Args: Args:
in_features (int): size of each input sample in_features (int): size of each input sample
@ -177,13 +225,13 @@ class FeedForward(nn.Module):
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
multiple_of: int = 256, multiple_of: int = 256,
block_idx: int = 0, colum_cls = None,
row_cls = None,
): ):
super().__init__() super().__init__()
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinearTorch( self.w1 = colum_cls(
in_features, in_features,
hidden_features, hidden_features,
process_group, process_group,
@ -192,7 +240,7 @@ class FeedForward(nn.Module):
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
self.w2 = ColumnParallelLinearTorch( self.w2 = colum_cls(
in_features, in_features,
hidden_features, hidden_features,
process_group, process_group,
@ -201,7 +249,7 @@ class FeedForward(nn.Module):
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
self.w3 = RowParallelLinearTorch( self.w3 = row_cls(
hidden_features, hidden_features,
out_features, out_features,
process_group, process_group,
@ -217,21 +265,9 @@ class FeedForward(nn.Module):
out = self.w3(Silu(w1_o, w2_o)) out = self.w3(Silu(w1_o, w2_o))
return out return out
class FeedForward(BaseFeedForward):
class FSTPLinear(ColumnParallelLinear):
def forward(self, x):
block_index = gpc.config.fstp_handler.module_to_index[self]
name_index = gpc.config.fstp_handler.module_name_index[self]
name = gpc.config.fstp_handler.module_name[name_index]
return fstp_fused_dense_func(
x, self.weight, self.bias, process_group=self.process_group,
module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name
)
class FSTPFeedForward(nn.Module):
""" """
FeedForward. FeedForward in flash implementation.
Args: Args:
in_features (int): size of each input sample in_features (int): size of each input sample
@ -255,169 +291,106 @@ class FSTPFeedForward(nn.Module):
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
multiple_of: int = 256, multiple_of: int = 256,
block_idx: int = 0,
): ):
super().__init__() super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
dtype, multiple_of, ColumnParallelLinearTorch, RowParallelLinearTorch)
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
self.w1 = FSTPLinear( class MegatronFeedForward(BaseFeedForward):
in_features, """
hidden_features, FeedForward in megatron implementation.
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w2 = FSTPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w3 = FSTPLinear(
hidden_features,
out_features,
process_group,
bias=bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
Args:
in_features (int): size of each input sample
hidden_features (int): size of hidden state of FFN
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
"""
def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int = None,
process_group: Optional[torch.distributed.ProcessGroup] = None,
bias: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
multiple_of: int = 256,
):
super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
dtype, multiple_of, MegatronColumnParallelLinearTorch, MegatronRowParallelLinearTorch)
class FSTPLinear(ColumnParallelLinear):
def forward(self, x): def forward(self, x):
w1_o = self.w1(x) block_index = gpc.config.fstp_handler.module_to_index[self]
w2_o = self.w2(x) name_index = gpc.config.fstp_handler.module_name_index[self]
out = self.w3(F.silu(w1_o) * w2_o) name = gpc.config.fstp_handler.module_name[name_index]
return out return fstp_fused_dense_func(
x, self.weight, self.bias, process_group=self.process_group,
module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name
)
class FSTPFeedForward(BaseFeedForward):
class FSTPAllGatherSyncHandler:
""" """
All-gather handler for overlapping the all-gather in adjcent FSTP linear. FeedForward in FSTP.
Args:
in_features (int): size of each input sample
hidden_features (int): size of hidden state of FFN
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
""" """
def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None: def __init__(
# import pdb; pdb.set_trace() self,
self.process_group = process_group in_features: int,
self.FSTP_modules = [] hidden_features: int,
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] out_features: int = None,
self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward process_group: Optional[torch.distributed.ProcessGroup] = None,
self.module_handler = dict() # key: FSTP module; value: all-gather handler bias: bool = True,
self.module_block = dict() # key: FSTP module; value: transformer block index device: Optional[torch.device] = None,
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module} dtype: Optional[torch.dtype] = None,
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name multiple_of: int = 256,
):
super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
dtype, multiple_of, FSTPLinear, FSTPLinear)
self.reduce_scatter_handlers = {} def get_mlp_cls(sp_mode: str):
self.all_reduce_handlers = {} if sp_mode in ["none", "flash-attn"]:
mlp_cls = FeedForward
# just want to share same for loop for ModuleList and Module elif sp_mode == "megatron":
if not isinstance(model, nn.ModuleList): mlp_cls = MegatronFeedForward
model = [model] else:
mlp_cls = FSTPFeedForward
for _chunk in model: return mlp_cls
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model
for _chunk_name, children in _chunk.named_children():
if isinstance(children, nn.ModuleList):
for idx, block in enumerate(children):
index = 0
self.block_module[idx] = {}
for _sub_name, sub in block.named_children():
sub_modules = list(sub.children())
if len(sub_modules) > 0:
for name, child in sub.named_children():
if isinstance(child, FSTPLinear):
_full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
if child.bias is not None:
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
self.FSTP_modules.append(child)
self.module_block[child] = idx
self.block_module[idx][index] = child
self.module_name_index[child] = index
index = index + 1
else:
continue
def _register_sync_parameters_hook(self) -> None:
"""
register pre_forward_hook and pre_backward_hook for FSTPLinear.
"""
def _pre_forward_hook(module: nn.Module, inputs: Any):
block_index = self.module_block[module]
name_index = self.module_name_index[module]
if name_index == 0:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight
# start the all-gather for next module
next_module = self.block_module[block_index][name_index + 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
else:
handler = self.module_handler[module]
handler.wait()
if name_index != 4:
next_module = self.block_module[block_index][name_index + 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
def _post_forward_hook(module: nn.Module, input, output):
if module in self.FSTP_global_weights:
del self.FSTP_global_weights[module]
if module in self.module_handler:
del self.module_handler[module]
def _pre_backward_hook(module: nn.Module, grad_output):
block_index = self.module_block[module]
name_index = self.module_name_index[module]
if name_index == 4:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight
# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
else:
handler = self.module_handler[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.module_handler[next_module] = weights_handler
def _post_backward_hook(module, grad_input, grad_output):
del self.FSTP_global_weights[module]
for module in self.FSTP_modules:
# import pdb; pdb.set_trace()
module.register_forward_pre_hook(_pre_forward_hook)
module.register_forward_hook(_post_forward_hook)
# module.register_backward_pre_hook(_pre_backward_hook)
# module.register_backward_hook(_post_backward_hook)
module.register_full_backward_pre_hook(_pre_backward_hook)
module.register_full_backward_hook(_post_backward_hook)
def get_linear_cls(sp_mode: str, parallel_mode: str):
if parallel_mode == "column":
if sp_mode in ["none", "flash-attn"]:
cls = ColumnParallelLinearTorch
elif sp_mode == "megatron":
cls = MegatronColumnParallelLinearTorch
else:
cls = FSTPLinear
elif parallel_mode == 'row':
if sp_mode in ["none", "flash-attn"]:
cls = RowParallelLinearTorch
elif sp_mode == "megatron":
cls = MegatronRowParallelLinearTorch
else:
cls = FSTPLinear
return cls
class CoarseGrainedFSTPAllGatherSyncHandler: class CoarseGrainedFSTPAllGatherSyncHandler:
""" """
@ -468,7 +441,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
sub_modules = list(sub.children()) sub_modules = list(sub.children())
if len(sub_modules) > 0: if len(sub_modules) > 0:
for name, child in sub.named_children(): for name, child in sub.named_children():
# print(f"name: {name}", flush=True)
if name == "out_proj": if name == "out_proj":
self.FSTP_outs.append(child) self.FSTP_outs.append(child)
self.module_to_index[child] = idx self.module_to_index[child] = idx

View File

@ -15,9 +15,12 @@ from internlm.initialize.initialize_tensor import normal_, scaled_init_method_no
from internlm.model.embedding import Embedding1D from internlm.model.embedding import Embedding1D
from internlm.model.linear import ( from internlm.model.linear import (
FeedForward, FeedForward,
MegatronFeedForward,
FSTPFeedForward, FSTPFeedForward,
RewardModelLinear, RewardModelLinear,
ScaleColumnParallelLinear, ScaleColumnParallelLinear,
MegatronScaleColumnParallelLinear,
get_mlp_cls,
) )
from internlm.model.multi_head_attention import MHA from internlm.model.multi_head_attention import MHA
from internlm.model.utils import ( from internlm.model.utils import (
@ -77,8 +80,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_scaled_init: bool = True, use_scaled_init: bool = True,
use_swiglu: bool = True, use_swiglu: bool = True,
use_flash_attn: bool = True, use_flash_attn: bool = True,
tp_mode: str = "origin_tp", sp_mode: str = "none",
block_idx: int = 0,
): ):
super().__init__() super().__init__()
self.checkpoint = checkpoint self.checkpoint = checkpoint
@ -103,8 +105,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_flash_attn=use_flash_attn, use_flash_attn=use_flash_attn,
device=device, device=device,
dtype=dtype, dtype=dtype,
tp_mode=tp_mode, sp_mode=sp_mode,
block_idx=block_idx,
) )
self.dropout1 = nn.Dropout(drop_rate) self.dropout1 = nn.Dropout(drop_rate)
@ -116,7 +117,7 @@ class PackedFlashBaseLayer1D(nn.Module):
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
if use_swiglu: if use_swiglu:
mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward mlp_cls = get_mlp_cls(sp_mode)
self.mlp = mlp_cls( self.mlp = mlp_cls(
hidden_size, hidden_size,
int(hidden_size * mlp_ratio), int(hidden_size * mlp_ratio),
@ -299,12 +300,16 @@ class PackedFlashInternLm1D(nn.Module):
super().__init__() super().__init__()
checkpoint_layer_num = int(num_layers * checkpoint) checkpoint_layer_num = int(num_layers * checkpoint)
self.tp_mode = gpc.config.parallel["tensor"]["mode"] self.sp_mode = gpc.config.parallel["tensor"]["sp"]
if self.sp_mode == "none":
gpc.config.parallel.sequence_parallel = False
else:
gpc.config.parallel.sequence_parallel = True
if is_reward: if is_reward:
head_cls = RewardModelLinear head_cls = RewardModelLinear
else: else:
head_cls = ScaleColumnParallelLinear head_cls = ScaleColumnParallelLinear if self.sp_mode in ["flash-attn", "none", "intern"] else MegatronScaleColumnParallelLinear
if first: if first:
if embed_split_hidden: if embed_split_hidden:
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@ -345,8 +350,7 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init=use_scaled_init, use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu, use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn, use_flash_attn=use_flash_attn,
tp_mode=self.tp_mode, sp_mode=self.sp_mode,
block_idx=lid,
) )
for lid in range(num_layers) for lid in range(num_layers)
] ]
@ -393,7 +397,7 @@ class PackedFlashInternLm1D(nn.Module):
# The indexes are used to indicate the actual position IDs of each token in the packed input. # The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0] indexes = indexes[0]
# if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension. # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp": if gpc.config.parallel.sequence_parallel and self.sp_mode == "intern":
indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

View File

@ -42,6 +42,9 @@ from internlm.model.linear import (
ColumnParallelLinearTorch, ColumnParallelLinearTorch,
FSTPLinear, FSTPLinear,
RowParallelLinearTorch, RowParallelLinearTorch,
MegatronColumnParallelLinearTorch,
MegatronRowParallelLinearTorch,
get_linear_cls,
) )
@ -175,8 +178,7 @@ class MHA(nn.Module):
use_flash_attn: bool = True, use_flash_attn: bool = True,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
tp_mode: str = "origin_tp", sp_mode: str = "none",
block_idx: int = 0,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
@ -204,7 +206,7 @@ class MHA(nn.Module):
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device) self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
# notice here should change bias=True # notice here should change bias=True
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear Wqkv_cls = get_linear_cls(sp_mode, "column")
self.Wqkv = Wqkv_cls( self.Wqkv = Wqkv_cls(
embed_dim, embed_dim,
3 * embed_dim, 3 * embed_dim,
@ -220,12 +222,12 @@ class MHA(nn.Module):
self.inner_cross_attn = inner_cross_attn_cls( self.inner_cross_attn = inner_cross_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
) )
if tp_mode == "fstp": if sp_mode == "intern":
self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group) self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group) self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)
# output projection always have the bias (for now) # output projection always have the bias (for now)
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear out_proj_cls = get_linear_cls(sp_mode, 'row')
self.out_proj = out_proj_cls( self.out_proj = out_proj_cls(
embed_dim, embed_dim,
embed_dim, embed_dim,

View File

@ -164,7 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFunc(torch.autograd.Function): class FusedDenseFunc(torch.autograd.Function):
"tp fused dense function" "FusedDenseFunc for tensor parallel in flash-attn implementation."
@staticmethod @staticmethod
@custom_fwd @custom_fwd
@ -255,9 +255,96 @@ class FusedDenseFunc(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None, None return grad_input, grad_weight, grad_bias, None, None, None, None
class MegatronFusedDenseFunc(torch.autograd.Function):
'''
FusedDenseFunc for tensor parallel in megatron implementation.
The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron,
so that the all-gather in backward is ommited.
'''
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True, gather_dim=0):
"""
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
"""
ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual
ctx.process_group = process_group
ctx.sequence_parallel = sequence_parallel
if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
x = x.contiguous()
if process_group is not None and sequence_parallel:
# We want to kick off the all_gather early, before weight dtype conversion
total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim)
else:
total_x = x
if torch.is_autocast_enabled():
weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
weight = weight.contiguous()
if process_group is not None and sequence_parallel:
handle_x.wait()
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
batch_dim = batch_shape.numel()
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
if min(batch_dim, n, *weight.shape) > 65535 * 32:
raise RuntimeError("fused_dense only supports matrix dims <= 2M")
output = F.linear(total_x, weight, bias)
if ctx.compute_weight_gradient:
ctx.save_for_backward(total_x, weight)
else:
ctx.save_for_backward(weight)
return output if not return_residual else (output, x)
@staticmethod
@custom_bwd
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
(grad_input,) = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
if ctx.compute_weight_gradient:
total_x, weight = ctx.saved_tensors
else:
(weight,) = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_output, weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
else:
grad_input = None
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
)
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFuncTorch(FusedDenseFunc): class FusedDenseFuncTorch(FusedDenseFunc):
"""A custom PyTorch module extending FusedDenseFunc.""" '''FusedDenseFunc in flash implementation for supporting torch.float32'''
@staticmethod @staticmethod
@custom_bwd @custom_bwd
@ -307,17 +394,61 @@ class FusedDenseFuncTorch(FusedDenseFunc):
handle_grad_input.wait() handle_grad_input.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None return grad_input, grad_weight, grad_bias, None, None, None, None
class MegatronFusedDenseFuncTorch(FusedDenseFunc):
'''FusedDenseFunc in megatron implementation for supporting torch.float32'''
@staticmethod
@custom_bwd
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
(grad_input,) = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
gather_dim = ctx.gather_dim
if ctx.compute_weight_gradient:
total_x, weight = ctx.saved_tensors
else:
(weight,) = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_output, weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
else:
grad_input = None
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
# we remove the cuda independence, which is different from flash_attn.
grad_weight, grad_bias = linear_bias_wgrad_torch(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
)
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
class FSTPFusedDenseFunc(torch.autograd.Function): class FSTPFusedDenseFunc(torch.autograd.Function):
"FSTP fused dense function" "FusedDenseFunc for FSTP, which is optimized based on flash implementation."
@staticmethod @staticmethod
@custom_fwd @custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None, block_index=None, module_name=None): def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, overlap_handler=None, block_index=None, module_name=None):
ctx.compute_weight_gradient = weight.requires_grad ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual ctx.return_residual = return_residual
ctx.process_group = process_group ctx.process_group = process_group
ctx.all_gather_handler = all_gather_handler ctx.overlap_handler = overlap_handler
ctx.module = module ctx.module = module
ctx.block_index = block_index ctx.block_index = block_index
ctx.module_name = module_name ctx.module_name = module_name
@ -329,13 +460,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
world_size = gpc.get_world_size(ParallelMode.TENSOR) world_size = gpc.get_world_size(ParallelMode.TENSOR)
if world_size > 1: if world_size > 1:
# do all_gather for weight and bias before actual computation # do all_gather for weight and bias before actual computation
if all_gather_handler is not None:# and module in all_gather_handler.FSTP_global_weights: if overlap_handler is not None:
# total_weight = all_gather_handler.FSTP_global_weights[module] total_weight = gpc.config.block_memory[block_index % 2][module_name]
total_weight = gpc.config.block_memory[block_index % 2][module_name]
else: else:
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_weight.wait() handle_weight.wait()
# TODO memory pool for bias
if bias is not None: if bias is not None:
total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True) total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
handle_bias.wait() handle_bias.wait()
@ -356,6 +486,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
if min(batch_dim, n, *total_weight.shape) > 65535 * 32: if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
raise RuntimeError("fused_dense only supports matrix dims <= 2M") raise RuntimeError("fused_dense only supports matrix dims <= 2M")
output = F.linear(total_x, total_weight, total_bias) output = F.linear(total_x, total_weight, total_bias)
# release memory
del total_weight del total_weight
del total_bias del total_bias
if ctx.compute_weight_gradient: if ctx.compute_weight_gradient:
@ -372,8 +503,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
(grad_input,) = args (grad_input,) = args
grad_input = grad_input.contiguous() grad_input = grad_input.contiguous()
process_group = ctx.process_group process_group = ctx.process_group
all_gather_handler = ctx.all_gather_handler overlap_handler = ctx.overlap_handler
module = ctx.module
block_index = ctx.block_index block_index = ctx.block_index
module_name = ctx.module_name module_name = ctx.module_name
@ -389,51 +519,35 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
world_size = gpc.get_world_size(ParallelMode.TENSOR) world_size = gpc.get_world_size(ParallelMode.TENSOR)
if world_size > 1: if world_size > 1:
total_weight = gpc.config.block_memory[block_index % 2][module_name] if overlap_handler is not None:
# # do all-gather for weight before backward total_weight = gpc.config.block_memory[block_index % 2][module_name]
# if module in all_gather_handler.FSTP_global_weights: else:
# total_weight = all_gather_handler.FSTP_global_weights[module] total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
# else: handle_weight.wait()
# total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
# handle_weight.wait()
else: else:
total_weight = weight total_weight = weight
# compute weight grad # compute weight grad
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient assert ctx.compute_weight_gradient
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
) )
if world_size > 1: if world_size > 1:
if gpc.config.fstp_handler is not None: if overlap_handler is not None:
# grad_weight_async, handle_grad_weight = all_reduce_raw(grad_weight, process_group, async_op=True)
# assert hasattr(weight, "_fstp_all_reduce_str")
# all_gather_handler.all_reduce_handlers[weight._fstp_all_reduce_str] = (handle_grad_weight, grad_weight_async)
# grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
# if grad_bias is not None:
# grad_bias_async, handle_grad_bias = all_reduce_raw(grad_bias, process_group, async_op=True)
# assert hasattr(bias, "_fstp_all_reduce_str")
# all_gather_handler.all_reduce_handlers[bias._fstp_all_reduce_str] = (handle_grad_bias, grad_bias_async)
# grad_bias = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True) grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True)
assert hasattr(weight, "_fstp_reduce_scatter_str") assert hasattr(weight, "_fstp_reduce_scatter_str")
all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async) overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device) grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
if grad_bias is not None: if grad_bias is not None:
grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True) grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True)
assert hasattr(bias, "_fstp_reduce_scatter_str") assert hasattr(bias, "_fstp_reduce_scatter_str")
all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async) overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
grad_bias = all_gather_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device) grad_bias = overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
else: else:
grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True) grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
if grad_bias is not None: if grad_bias is not None:
grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True) grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
# grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
# if grad_bias is not None:
# grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
else: else:
grad_weight = None grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None grad_bias = grad_output if ctx.needs_input_grad[2] else None
@ -449,7 +563,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
del total_weight del total_weight
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
if world_size > 1 and gpc.config.fstp_handler is None: if world_size > 1 and overlap_handler is None:
handle_grad_weight.wait() handle_grad_weight.wait()
if grad_bias is not None: if grad_bias is not None:
handle_grad_bias.wait() handle_grad_bias.wait()
@ -473,6 +587,22 @@ def fused_dense_func_torch(
else: else:
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim) return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
def megatron_fused_dense_func_torch(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
return_residual: bool = False,
process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True,
gather_dim: int = 0,
):
dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
x.dtype == torch.float32 and torch.is_autocast_enabled()
)
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return MegatronFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
else:
return MegatronFusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
def fstp_fused_dense_func( def fstp_fused_dense_func(
x: Tensor, x: Tensor,

View File

@ -40,11 +40,6 @@ from .utils import compute_norm
inf = math.inf inf = math.inf
logger = get_logger(__file__) logger = get_logger(__file__)
def print_memory(msg):
print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reverved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
print("===========================================")
class HybridZeroOptimizer(BaseOptimizer): class HybridZeroOptimizer(BaseOptimizer):
""" """
Hybrid Zero Optimizer. Hybrid Zero Optimizer.
@ -70,7 +65,7 @@ class HybridZeroOptimizer(BaseOptimizer):
hysteresis = grad_scal_cfg.hysteresis hysteresis = grad_scal_cfg.hysteresis
max_scale = grad_scal_cfg.max_scale max_scale = grad_scal_cfg.max_scale
if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
self._fstp_handler = gpc.config.fstp_handler self._fstp_handler = gpc.config.fstp_handler
# Zero related args # Zero related args
@ -358,20 +353,7 @@ class HybridZeroOptimizer(BaseOptimizer):
del self._fstp_handler.reduce_scatter_handlers[key] del self._fstp_handler.reduce_scatter_handlers[key]
self._fstp_handler.reduce_scatter_handlers[key] = None self._fstp_handler.reduce_scatter_handlers[key] = None
assert key in self._fstp_handler.reduce_scatter_handlers assert key in self._fstp_handler.reduce_scatter_handlers
# if not hasattr(_param, "_fstp_all_reduce_str"):
# continue
# key = getattr(_param, "_fstp_all_reduce_str")
# comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
# comm_handle.wait()
# with torch.no_grad():
# _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
# _param.grad.add_(_grad)
# # self._fstp_handler.reduce_scatter_handlers[key] = None
# del _grad
# del self._fstp_handler.all_reduce_handlers[key]
# self._fstp_handler.all_reduce_handlers[key] = None
# assert key in self._fstp_handler.all_reduce_handlers
bucket.reset_by_rank(rank) bucket.reset_by_rank(rank)
@ -402,21 +384,6 @@ class HybridZeroOptimizer(BaseOptimizer):
self._fstp_handler.reduce_scatter_handlers[key] = None self._fstp_handler.reduce_scatter_handlers[key] = None
assert key in self._fstp_handler.reduce_scatter_handlers assert key in self._fstp_handler.reduce_scatter_handlers
# if not hasattr(_param, "_fstp_all_reduce_str"):
# continue
# key = getattr(_param, "_fstp_all_reduce_str")
# comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
# comm_handle.wait()
# with torch.no_grad():
# _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
# _param.grad.add_(_grad)
# # self._fstp_handler.reduce_scatter_handlers[key] = None
# del _grad
# del self._fstp_handler.all_reduce_handlers[key]
# self._fstp_handler.all_reduce_handlers[key] = None
# assert key in self._fstp_handler.all_reduce_handlers
current_bucket.reset_by_rank(reduce_rank) current_bucket.reset_by_rank(reduce_rank)
current_bucket.add_num_elements_in_bucket(param_size, reduce_rank) current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
@ -634,7 +601,6 @@ class HybridZeroOptimizer(BaseOptimizer):
# if not overlapping communication (no reduction hook is attached) # if not overlapping communication (no reduction hook is attached)
# we need to manually reduce these gradients # we need to manually reduce these gradients
print_memory("No 1")
if not self._overlap_sync_grad: if not self._overlap_sync_grad:
for group_id in range(len(self._fp16_param_groups)): for group_id in range(len(self._fp16_param_groups)):
for param in self._fp16_param_groups[group_id]: for param in self._fp16_param_groups[group_id]:
@ -659,7 +625,6 @@ class HybridZeroOptimizer(BaseOptimizer):
bucket.empty() bucket.empty()
self._bucket_in_progress = [] self._bucket_in_progress = []
self._param_store.clear_grads_of_previous_reduced_params() self._param_store.clear_grads_of_previous_reduced_params()
print_memory("No 2")
# compute norm for gradients in the last bucket # compute norm for gradients in the last bucket
total_norms = {} total_norms = {}
for group_id in range(self.num_param_groups): for group_id in range(self.num_param_groups):
@ -681,19 +646,11 @@ class HybridZeroOptimizer(BaseOptimizer):
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=pg) dist.all_reduce(scaled_norm_tensor, group=pg)
total_norms[group_name] = scaled_norm_tensor.item() total_norms[group_name] = scaled_norm_tensor.item()
print_memory("No 3")
timer("sync_grad").start() timer("sync_grad").start()
self._sync_grad() self._sync_grad()
timer("sync_grad").stop() timer("sync_grad").stop()
print_memory("No 4") res = self._step(closure=closure, norms=total_norms)
try:
res = self._step(closure=closure, norms=total_norms)
except torch.cuda.OutOfMemoryError as e:
print(e, flush=True)
print(torch.cuda.memory_summary(), flush=True)
torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
return res return res
@ -740,7 +697,6 @@ class HybridZeroOptimizer(BaseOptimizer):
self._grad_store._averaged_gradients = dict() self._grad_store._averaged_gradients = dict()
self.zero_grad() self.zero_grad()
return False, norms return False, norms
print_memory("No 5")
# copy the grad of fp16 param to fp32 param # copy the grad of fp16 param to fp32 param
single_grad_partition_groups = [] single_grad_partition_groups = []
for group_id in range(self.num_param_groups): for group_id in range(self.num_param_groups):
@ -781,7 +737,6 @@ class HybridZeroOptimizer(BaseOptimizer):
single_grad_partition_groups.append(flat_fp32_avg_grads) single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device) self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
print_memory("No 6")
# unscale and clip grads # unscale and clip grads
# get the global norm # get the global norm
global_norm_groups = {} global_norm_groups = {}
@ -804,12 +759,9 @@ class HybridZeroOptimizer(BaseOptimizer):
# For those ranks that are not assigned parameters, we just wait for other ranks # For those ranks that are not assigned parameters, we just wait for other ranks
# to send them updated their own parameters. # to send them updated their own parameters.
if self.has_params: if self.has_params:
print_memory("No 7")
self.optim.step() self.optim.step()
print_memory("No 8")
# release the fp32 grad # release the fp32 grad
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values()) release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
print_memory("No 9")
# update fp16 partition updated by the current rank # update fp16 partition updated by the current rank
for group_id in range(len(self._fp16_param_groups)): for group_id in range(len(self._fp16_param_groups)):
if self.param_group_has_params[group_id]: if self.param_group_has_params[group_id]:
@ -818,7 +770,6 @@ class HybridZeroOptimizer(BaseOptimizer):
) )
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param) fp16_param.data.copy_(fp32_param)
print_memory("No 10")
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.stream(self._comm_bcast_stream): with torch.cuda.stream(self._comm_bcast_stream):
self.broadcast_params() self.broadcast_params()
@ -829,7 +780,6 @@ class HybridZeroOptimizer(BaseOptimizer):
# so synchronization is maintained # so synchronization is maintained
for group_name, global_norm in global_norm_groups.items(): for group_name, global_norm in global_norm_groups.items():
global_norm_groups[group_name] = global_norm / loss_scale global_norm_groups[group_name] = global_norm / loss_scale
print_memory("No 11")
return True, global_norm_groups return True, global_norm_groups
def broadcast_params(self): def broadcast_params(self):

View File

@ -38,7 +38,6 @@ from internlm.model.embedding import Embedding1D
from internlm.model.linear import ( from internlm.model.linear import (
CoarseGrainedFSTPAllGatherSyncHandler, CoarseGrainedFSTPAllGatherSyncHandler,
FeedForward, FeedForward,
FSTPAllGatherSyncHandler,
RewardModelLinear, RewardModelLinear,
ScaleColumnParallelLinear, ScaleColumnParallelLinear,
) )
@ -111,7 +110,7 @@ def initialize_model():
gpc.config.fstp_handler = None gpc.config.fstp_handler = None
if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True: if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] == True:
handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
# handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler._register_sync_parameters_hook() handler._register_sync_parameters_hook()

View File

@ -195,7 +195,7 @@ def main(args):
# start iterating the train data and begin training # start iterating the train data and begin training
for batch_count in range(train_state.batch_count, total_steps): for batch_count in range(train_state.batch_count, total_steps):
empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
torch.cuda.memory._record_memory_history() # torch.cuda.memory._record_memory_history()
start_time = time.time() start_time = time.time()
timer("one-batch").start() timer("one-batch").start()
@ -300,7 +300,7 @@ def main(args):
if gpc.config.fstp_handler is not None: if gpc.config.fstp_handler is not None:
gpc.config.fstp_handler.zero_const_pool = {} gpc.config.fstp_handler.zero_const_pool = {}
gpc.config.fstp_handler.reduce_scatter_memory = {} gpc.config.fstp_handler.reduce_scatter_memory = {}
torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
ckpt_manager.wait_async_upload_finish() ckpt_manager.wait_async_upload_finish()