feat(model/overlap_handler.py): refactor overlap hook handle

branch: pull/456/head
author: huangting4201, 2023-10-20 21:50:32 +08:00
parent: 1804d01bb3
commit: 85ad917ae4
9 changed files with 392 additions and 349 deletions


@@ -163,7 +163,7 @@ pipeline parallel (dict):
"""
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, sp="megatron", intern_overlap=True),
+    tensor=dict(size=8, sp="intern", intern_overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
)
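The key change here is `sp` switching from "megatron" to "intern" while keeping `intern_overlap=True`; that combination routes the model to the FSTP linear/MLP classes and, in initialize_model later in this diff, creates the FSTPOverlapHandler. A minimal standalone sketch of that dispatch (toy code, not part of this commit; the real logic lives in get_mlp_cls/get_linear_cls and initialize_model):

tensor_cfg = dict(size=8, sp="intern", intern_overlap=True)

def pick_mlp(sp_mode: str) -> str:
    # mirrors get_mlp_cls below: "none"/"flash-attn" keep the flash classes,
    # "megatron" picks the megatron variants, anything else falls through to FSTP
    if sp_mode in ["none", "flash-attn"]:
        return "FeedForward"
    elif sp_mode == "megatron":
        return "MegatronFeedForward"
    return "FSTPFeedForward"

print(pick_mlp(tensor_cfg["sp"]))                                     # FSTPFeedForward
print(tensor_cfg["sp"] == "intern" and tensor_cfg["intern_overlap"])  # True -> overlap handler is built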


@@ -1,22 +1,17 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

-from typing import Any, Optional, Union
+from typing import Optional

import torch
-import torch.nn.functional as F
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter
from torch import nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.core.naive_amp import NaiveAMPModel
-from internlm.model.embedding import Embedding1D
from internlm.model.utils import (
    Silu,
-    all_gather_raw,
-    all_gather_raw_memory_pool,
    fstp_fused_dense_func,
    fused_dense_func_torch,
    megatron_fused_dense_func_torch,

@@ -25,20 +20,20 @@ from internlm.model.utils import (
class BaseScaleColumnParallelLinear(nn.Linear):
    """
    Base class for ScaleColumnParallelLinear.

    Args:
        in_features (int): size of each input sample
        out_features (int): size of each output sample
        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
                     in the config.
        sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
                                  we do an all_gather of x before doing the matmul.
                                  If not, then the input is already gathered.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        weight_scale (int): For training stability. 1 by default.
    """

    def __init__(

@@ -58,10 +53,12 @@ class BaseScaleColumnParallelLinear(nn.Linear):
        self.process_group = process_group
        self.weight_scale = weight_scale

+
class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
    """
    ScaleColumnParallelLinear in flash implementation.
    """

    def forward(self, input, gather_dim=0):  # pylint: disable=W0622
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.

@@ -79,6 +76,7 @@ class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
            gather_dim=gather_dim,
        )

+
class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
    """
    ScaleColumnParallelLinear in megatron implementation.

@@ -101,6 +99,7 @@ class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
            gather_dim=gather_dim,
        )

+
class RewardModelLinear(ScaleColumnParallelLinear):
    """
    RewardModelLinear.

@@ -164,6 +163,7 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
            gather_dim=gather_dim,
        )

+
class MegatronColumnParallelLinearTorch(ColumnParallelLinear):
    def forward(self, x, gather_dim=0):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:

@@ -178,6 +178,7 @@ class MegatronColumnParallelLinearTorch(ColumnParallelLinear):
            gather_dim=gather_dim,
        )

+
class RowParallelLinearTorch(RowParallelLinear):
    def forward(self, x):
        """

@@ -188,6 +189,7 @@ class RowParallelLinearTorch(RowParallelLinear):
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)

+
class MegatronRowParallelLinearTorch(RowParallelLinear):
    def forward(self, x):
        """

@@ -225,8 +227,8 @@ class BaseFeedForward(nn.Module):
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
-        colum_cls = None,
-        row_cls = None,
+        colum_cls=None,
+        row_cls=None,
    ):
        super().__init__()
        hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

@@ -265,6 +267,7 @@ class BaseFeedForward(nn.Module):
        out = self.w3(Silu(w1_o, w2_o))
        return out

+
class FeedForward(BaseFeedForward):
    """
    FeedForward in flash implementation.

@@ -292,9 +295,19 @@ class FeedForward(BaseFeedForward):
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
    ):
-        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
-                         dtype, multiple_of, ColumnParallelLinearTorch, RowParallelLinearTorch)
+        super().__init__(
+            in_features,
+            hidden_features,
+            out_features,
+            process_group,
+            bias,
+            device,
+            dtype,
+            multiple_of,
+            ColumnParallelLinearTorch,
+            RowParallelLinearTorch,
+        )


class MegatronFeedForward(BaseFeedForward):
    """

@@ -323,19 +336,35 @@ class MegatronFeedForward(BaseFeedForward):
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
    ):
-        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
-                         dtype, multiple_of, MegatronColumnParallelLinearTorch, MegatronRowParallelLinearTorch)
+        super().__init__(
+            in_features,
+            hidden_features,
+            out_features,
+            process_group,
+            bias,
+            device,
+            dtype,
+            multiple_of,
+            MegatronColumnParallelLinearTorch,
+            MegatronRowParallelLinearTorch,
+        )


class FSTPLinear(ColumnParallelLinear):
    def forward(self, x):
        block_index = gpc.config.fstp_handler.module_to_index[self]
-        name_index = gpc.config.fstp_handler.module_name_index[self]
-        name = gpc.config.fstp_handler.module_name[name_index]
        return fstp_fused_dense_func(
-            x, self.weight, self.bias, process_group=self.process_group,
-            module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name
+            x,
+            self.weight,
+            self.bias,
+            process_group=self.process_group,
+            module=self,
+            handler=gpc.config.fstp_handler,
+            block_index=block_index,
+            module_name=self._fstp_name,
        )


class FSTPFeedForward(BaseFeedForward):
    """
    FeedForward in FSTP.

@@ -363,8 +392,19 @@ class FSTPFeedForward(BaseFeedForward):
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
    ):
-        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
-                         dtype, multiple_of, FSTPLinear, FSTPLinear)
+        super().__init__(
+            in_features,
+            hidden_features,
+            out_features,
+            process_group,
+            bias,
+            device,
+            dtype,
+            multiple_of,
+            FSTPLinear,
+            FSTPLinear,
+        )


def get_mlp_cls(sp_mode: str):
    if sp_mode in ["none", "flash-attn"]:

@@ -375,6 +415,7 @@ def get_mlp_cls(sp_mode: str):
        mlp_cls = FSTPFeedForward
    return mlp_cls

+
def get_linear_cls(sp_mode: str, parallel_mode: str):
    if parallel_mode == "column":
        if sp_mode in ["none", "flash-attn"]:

@@ -383,7 +424,7 @@ def get_linear_cls(sp_mode: str, parallel_mode: str):
            cls = MegatronColumnParallelLinearTorch
        else:
            cls = FSTPLinear
-    elif parallel_mode == 'row':
+    elif parallel_mode == "row":
        if sp_mode in ["none", "flash-attn"]:
            cls = RowParallelLinearTorch
        elif sp_mode == "megatron":

@@ -391,192 +432,3 @@ def get_linear_cls(sp_mode: str, parallel_mode: str):
        else:
            cls = FSTPLinear
    return cls
class CoarseGrainedFSTPAllGatherSyncHandler:
"""
All-gather handler for overlapping the all-gather in adjcent FSTP block.
"""
def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
# import pdb; pdb.set_trace()
self.process_group = process_group
self.FSTP_blocks = []
self.FSTP_outs = []
self.FSTP_modules = []
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
self.FSTP_global_handle = dict() # key: FSTP module; value: module global all-gather op handle
self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward
self.block_handles = dict() # key: transformer block; value: all-gather handles
self.module_to_index = dict() # key: FSTP module; value: transformer block index
self.block_to_index = dict() # key: transformer block; value: transformer block index
self.index_to_block = dict() # key: transformer block index; value: transformer block
self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
self.head = []
self.embedding = []
self.reduce_scatter_handlers = {}
self.all_reduce_handlers = {}
self.zero_const_pool = {}
# just want to share same for loop for ModuleList and Module
if not isinstance(model, nn.ModuleList):
model = [model]
for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model
for _chunk_name, children in _chunk.named_children():
if isinstance(children, nn.ModuleList):
for idx, block in enumerate(children):
index = 0
self.block_module[idx] = {}
self.FSTP_blocks.append(block)
self.block_to_index[block] = idx
self.index_to_block[idx] = block
self.index_to_fsdp_modules[idx] = []
for _sub_name, sub in block.named_children():
sub_modules = list(sub.children())
if len(sub_modules) > 0:
for name, child in sub.named_children():
if name == "out_proj":
self.FSTP_outs.append(child)
self.module_to_index[child] = idx
if isinstance(child, FSTPLinear):
self.module_to_index[child] = idx
self.block_module[idx][index] = child
self.FSTP_modules.append(child)
self.index_to_fsdp_modules[idx].append(child)
self.module_name_index[child] = index
index = index + 1
_full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
if child.bias is not None:
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
else:
continue
elif isinstance(children, ScaleColumnParallelLinear):
self.head.append(children)
elif isinstance(children, Embedding1D):
self.embedding.append(children)
def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor:
if size not in self.zero_const_pool:
self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()
return self.zero_const_pool[size]
def _all_gather_block_weight_memory_pool(self, block_index: int):
fsdp_modules = self.index_to_fsdp_modules[block_index]
for module in fsdp_modules:
module_index = self.module_name_index[module]
name = self.module_name[module_index]
weight_handle = all_gather_raw_memory_pool(
module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name
)
self.FSTP_global_handle[module] = weight_handle
def _register_sync_parameters_hook(self) -> None:
"""
register pre_forward_hook and pre_backward_hook for FSTP block.
Notice that next block's all_gather op should be after current block's all_to_all op, so we
1. register pre_forward_hook @out_proj module to prefetch for next block
2. register pre_forward_hook @block module to wait handles for next block
3. register pre_backward_hook @wqkv module to prefetch for next block
4. register pre_backward_hook @block module to wait handles for next block
"""
def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
block_index = self.module_to_index[module]
# start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER:
self._all_gather_block_weight_memory_pool(block_index + 1)
def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
self._all_gather_block_weight_memory_pool(0)
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
handle = self.FSTP_global_handle[module]
handle.wait()
def _post_forward_hook_for_module(module: nn.Module, input, output):
if module in self.FSTP_global_weights:
del self.FSTP_global_weights[module]
if module in self.FSTP_global_handle:
del self.FSTP_global_handle[module]
def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
first_module = self.block_module[gpc.config.NUM_LAYER - 1][4]
total_weight, weight_handler = all_gather_raw(first_module.weight, self.process_group, async_op=True)
self.FSTP_global_handle[first_module] = weight_handler
self.FSTP_global_weights[first_module] = total_weight
def _pre_backward_hook_for_module_memory_pool(module: nn.Module, grad_output):
block_index = self.module_to_index[module]
name_index = self.module_name_index[module]
if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
weight_handler = self.FSTP_global_handle[module]
weight_handler.wait()
# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
next_name = self.module_name[name_index - 1]
weights_handler = all_gather_raw_memory_pool(
next_module.weight,
self.process_group,
async_op=True,
block_index=block_index,
module_name=next_name,
)
self.FSTP_global_handle[next_module] = weights_handler
elif name_index == 0:
handler = self.FSTP_global_handle[module]
handler.wait()
if block_index - 1 >= 0:
next_module = self.block_module[block_index - 1][4]
name = self.module_name[4]
weights_handler = all_gather_raw_memory_pool(
next_module.weight,
self.process_group,
async_op=True,
block_index=block_index - 1,
module_name=name,
)
self.FSTP_global_handle[next_module] = weights_handler
else:
handler = self.FSTP_global_handle[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
name = self.module_name[name_index - 1]
weights_handler = all_gather_raw_memory_pool(
next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name
)
self.FSTP_global_handle[next_module] = weights_handler
def _post_backward_hook_for_module(module, grad_input, grad_output):
if module in self.FSTP_global_weights:
del self.FSTP_global_weights[module]
if module in self.FSTP_global_handle:
del self.FSTP_global_handle[module]
for embedding in self.embedding:
embedding.register_forward_hook(_post_forward_hook_for_embedding)
for head in self.head:
head.register_full_backward_hook(_post_backward_hook_for_head)
for out_proj in self.FSTP_outs:
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
for module in self.FSTP_modules:
module.register_forward_pre_hook(_pre_forward_hook_for_module)
module.register_forward_hook(_post_forward_hook_for_module)
module.register_full_backward_pre_hook(_pre_backward_hook_for_module_memory_pool)
module.register_full_backward_hook(_post_backward_hook_for_module)
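FSTPLinear.forward above now reads the module's name from a `_fstp_name` attribute that the new handler attaches with setattr, instead of the old module_name_index lookup in the removed class. A small standalone sketch of that tagging pattern (toy modules under that assumption, not internlm code):

import torch
from torch import nn

class ToyLinear(nn.Linear):
    def forward(self, x):
        # the tag is attached from outside, before the first forward call
        return super().forward(x), getattr(self, "_fstp_name", None)

block = nn.ModuleDict({"Wqkv": ToyLinear(4, 12), "out_proj": ToyLinear(4, 4)})
for name, child in block.items():
    setattr(child, "_fstp_name", name)  # same trick as FSTPOverlapHandler.__init__

_, tag = block["out_proj"](torch.randn(2, 4))
print(tag)  # out_proj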


@@ -14,12 +14,9 @@ from internlm.core.context.parallel_context import global_context as gpc
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
-    FeedForward,
-    MegatronFeedForward,
-    FSTPFeedForward,
+    MegatronScaleColumnParallelLinear,
    RewardModelLinear,
    ScaleColumnParallelLinear,
-    MegatronScaleColumnParallelLinear,
    get_mlp_cls,
)
from internlm.model.multi_head_attention import MHA

@@ -309,7 +306,11 @@ class PackedFlashInternLm1D(nn.Module):
        if is_reward:
            head_cls = RewardModelLinear
        else:
-            head_cls = ScaleColumnParallelLinear if self.sp_mode in ["flash-attn", "none", "intern"] else MegatronScaleColumnParallelLinear
+            head_cls = (
+                ScaleColumnParallelLinear
+                if self.sp_mode in ["flash-attn", "none", "intern"]
+                else MegatronScaleColumnParallelLinear
+            )
        if first:
            if embed_split_hidden:
                self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)


@@ -38,14 +38,7 @@ from torch.nn import Module
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
-from internlm.model.linear import (
-    ColumnParallelLinearTorch,
-    FSTPLinear,
-    RowParallelLinearTorch,
-    MegatronColumnParallelLinearTorch,
-    MegatronRowParallelLinearTorch,
-    get_linear_cls,
-)
+from internlm.model.linear import get_linear_cls

# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py

@@ -227,7 +220,7 @@ class MHA(nn.Module):
        self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

        # output projection always have the bias (for now)
-        out_proj_cls = get_linear_cls(sp_mode, 'row')
+        out_proj_cls = get_linear_cls(sp_mode, "row")
        self.out_proj = out_proj_cls(
            embed_dim,
            embed_dim,


@@ -0,0 +1,253 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Any, Union

import torch
from torch import nn

from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.embedding import Embedding1D
from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
from internlm.model.utils import all_gather_raw_memory_pool
from internlm.utils.common import get_current_device


class FSTPOverlapHandler:
    """
    FSTP overlap handler for managing the all-gather and reduce_scatter overlapping.
    """

    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
        self.process_group = process_group
        self.fstp_outs = []
        self.fstp_modules = []
        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
        self.fstp_global_handle = dict()  # key: fstp module; value: module global all-gather op handle
        self.module_to_index = dict()  # key: fstp module; value: transformer block index
        self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
        self.head = []
        self.embedding = []

        self.reduce_scatter_handlers = {}
        self.zero_const_pool = {}

        # just want to share same for loop for ModuleList and Module
        if not isinstance(model, nn.ModuleList):
            model = [model]

        for _chunk in model:
            if isinstance(_chunk, NaiveAMPModel):
                _chunk = _chunk.model

            for _chunk_name, children in _chunk.named_children():
                if isinstance(children, ScaleColumnParallelLinear):
                    self.head.append(children)
                elif isinstance(children, Embedding1D):
                    self.embedding.append(children)
                elif isinstance(children, nn.ModuleList):
                    for idx, block in enumerate(children):
                        self.index_to_fstp_modules[idx] = []
                        for _sub_name, sub in block.named_children():
                            sub_modules = list(sub.children())
                            if len(sub_modules) > 0:
                                for name, child in sub.named_children():
                                    if name == "out_proj":
                                        self.fstp_outs.append(child)
                                        self.module_to_index[child] = idx
                                    if isinstance(child, FSTPLinear):
                                        self.module_to_index[child] = idx
                                        self.fstp_modules.append(child)
                                        self.index_to_fstp_modules[idx].append(child)

                                        setattr(child, "_fstp_name", name)

                                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
                                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
                                        if child.bias is not None:
                                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")

        self._initialize_memory_pool()
        self._register_sync_parameters_hook()

    def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor:
        if size not in self.zero_const_pool:
            self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()

        return self.zero_const_pool[size]

    def _initialize_memory_pool(self) -> None:
        # allocate memory pool
        hidden_size = gpc.config.HIDDEN_SIZE
        mlp_ratio = gpc.config.MLP_RATIO
        mlp_hidden_size = int(hidden_size * mlp_ratio)
        mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)

        self.all_gather_memory_pool = []
        self.reduce_scatter_memory_pool = {}

        for _ in range(2):
            weight = {}
            for name in self.module_name:
                if name == "Wqkv":
                    weight[name] = torch.zeros(
                        (3 * hidden_size, hidden_size),
                        dtype=gpc.config.model.get("dtype", torch.half),
                        device=get_current_device(),
                    ).contiguous()
                elif name == "out_proj":
                    weight[name] = torch.zeros(
                        (hidden_size, hidden_size),
                        dtype=gpc.config.model.get("dtype", torch.half),
                        device=get_current_device(),
                    ).contiguous()
                elif name == "w1" or name == "w2":
                    weight[name] = torch.zeros(
                        (mlp_hidden_size, hidden_size),
                        dtype=gpc.config.model.get("dtype", torch.half),
                        device=get_current_device(),
                    ).contiguous()
                else:
                    weight[name] = torch.zeros(
                        (hidden_size, mlp_hidden_size),
                        dtype=gpc.config.model.get("dtype", torch.half),
                        device=get_current_device(),
                    ).contiguous()

            self.all_gather_memory_pool.append(weight)  # containing two groups of block weight

    def get_all_gather_memory(self, index, module_name):
        return self.all_gather_memory_pool[index % 2][module_name]

    def get_reduce_scatter_memory(self, key):
        return_idx = 0

        # if key not in dict
        if key not in self.reduce_scatter_memory_pool:
            self.reduce_scatter_memory_pool[key] = {"data": [], "used": []}

        # if the data is empty
        if len(self.reduce_scatter_memory_pool[key]["data"]) == 0:
            self.reduce_scatter_memory_pool[key]["data"].append(
                torch.zeros(
                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                ).contiguous()
            )
            self.reduce_scatter_memory_pool[key]["used"].append(True)
            return_idx = 0
            return return_idx
        else:  # if not empty
            for index, used in enumerate(self.reduce_scatter_memory_pool[key]["used"]):
                if used is False:
                    self.reduce_scatter_memory_pool[key]["used"][index] = True
                    return_idx = index
                    return return_idx
            # if the memory pool is all used
            length = len(self.reduce_scatter_memory_pool[key]["data"])
            self.reduce_scatter_memory_pool[key]["data"].append(
                torch.zeros(
                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                ).contiguous()
            )
            self.reduce_scatter_memory_pool[key]["used"].append(True)
            return_idx = length
            return return_idx

    def release_reduce_scatter_memory(self, size, index):
        self.reduce_scatter_memory_pool[size]["used"][index] = False

    def _all_gather_block_weight_memory_pool(self, block_index: int):
        fstp_modules = self.index_to_fstp_modules[block_index]
        for module in fstp_modules:
            weight_handle = all_gather_raw_memory_pool(
                module.weight,
                self.process_group,
                async_op=True,
                block_index=block_index,
                module_name=getattr(module, "_fstp_name"),
            )
            self.fstp_global_handle[module] = weight_handle

    def _register_sync_parameters_hook(self) -> None:
        """
        register forward hooks and backward hooks for fstp modules.
        """

        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):
            self._all_gather_block_weight_memory_pool(0)

        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
            block_index = self.module_to_index[module]
            # start the all-gather for next block
            if block_index + 1 < gpc.config.NUM_LAYER:
                self._all_gather_block_weight_memory_pool(block_index + 1)

        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
            handle = self.fstp_global_handle[module]
            handle.wait()

        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):
            if module in self.fstp_global_handle:
                del self.fstp_global_handle[module]

        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
            first_backward_module = self.fstp_modules[-1]
            block_index = self.module_to_index[first_backward_module]
            weight_handle = all_gather_raw_memory_pool(
                first_backward_module.weight,
                self.process_group,
                async_op=True,
                block_index=block_index,
                module_name=getattr(first_backward_module, "_fstp_name"),
            )
            self.fstp_global_handle[first_backward_module] = weight_handle

        def _pre_backward_hook_for_module(module: nn.Module, grad_output):
            # wait handle for current module
            weight_handle = self.fstp_global_handle[module]
            weight_handle.wait()

            # start the all-gather for next module
            module_index = self.fstp_modules.index(module)
            if module_index - 1 >= 0:
                next_module = self.fstp_modules[module_index - 1]
                block_index = self.module_to_index[next_module]
                weight_handle = all_gather_raw_memory_pool(
                    next_module.weight,
                    self.process_group,
                    async_op=True,
                    block_index=block_index,
                    module_name=getattr(next_module, "_fstp_name"),
                )
                self.fstp_global_handle[next_module] = weight_handle

        def _post_backward_hook_for_module(module, grad_input, grad_output):
            if module in self.fstp_global_handle:
                del self.fstp_global_handle[module]

        # register forward hooks
        # 1. register post_forward_hook @embedding module to prefetch for block 0
        # 2. register pre_forward_hook @out_proj module to prefetch for next block,
        #    notice that next block's all_gather op should be after current block's all_to_all op
        # 3. register pre_forward_hook @fstp_module to wait handle for current module
        # 4. register post_forward_hook @fstp_module to release resource
        for embedding in self.embedding:
            embedding.register_forward_hook(_post_forward_hook_for_embedding)

        for out_proj in self.fstp_outs:
            out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)

        for module in self.fstp_modules:
            module.register_forward_pre_hook(_pre_forward_hook_for_module)
            module.register_forward_hook(_post_forward_hook_for_module)

        # register backward hooks
        # 1. register post_backward_hook @head module to prefetch for the last block's last module
        # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
        # 3. register post_backward_hook @fstp_module to release resource
        for head in self.head:
            head.register_full_backward_hook(_post_backward_hook_for_head)

        for module in self.fstp_modules:
            module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
            module.register_full_backward_hook(_post_backward_hook_for_module)
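The hooks registered above implement a one-step prefetch pipeline: the embedding's post-forward hook all-gathers block 0, each block's out_proj pre-forward hook prefetches the next block, every FSTP module waits on its own handle before computing, and the backward hooks run the same scheme in reverse. Below is a single-process sketch of that pattern, with a thread pool standing in for the asynchronous all-gather (toy code under that assumption, not the real distributed path):

from concurrent.futures import ThreadPoolExecutor

import torch
from torch import nn

pool = ThreadPoolExecutor(max_workers=1)
handles = {}      # module -> Future, plays the role of fstp_global_handle
next_layer = {}   # module -> module whose weight should be prefetched next

def prefetch(module):
    # stand-in for all_gather_raw_memory_pool(module.weight, ..., async_op=True)
    handles[module] = pool.submit(lambda w: w.clone(), module.weight)

def pre_forward_prefetch_next(module, inputs):
    nxt = next_layer.get(module)
    if nxt is not None:
        prefetch(nxt)  # overlap the next layer's fetch with this layer's compute

def pre_forward_wait(module, inputs):
    handles[module].result()  # handle.wait() in the real hook

layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))
for i, layer in enumerate(layers):
    if i + 1 < len(layers):
        next_layer[layer] = layers[i + 1]
    layer.register_forward_pre_hook(pre_forward_prefetch_next)
    layer.register_forward_pre_hook(pre_forward_wait)

prefetch(layers[0])  # the embedding post-forward hook plays this role for block 0
x = torch.randn(2, 8)
for layer in layers:
    x = layer(x)
print(x.shape)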


@@ -135,7 +135,7 @@ def all_gather_raw_memory_pool(
    module_name: str = None,
):
    handle = torch.distributed.all_gather_into_tensor(
-        gpc.config.block_memory[block_index % 2][module_name],
+        gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name),
        input_.contiguous(),
        group=process_group,
        async_op=async_op,

@@ -166,8 +166,8 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    size = (input_.shape[0] // world_size, *input_.shape[1:])
-    index = check_reduce_scatter_memory_pool(size)
-    output = gpc.config.reduce_scatter_memory[size]["data"][index]
+    index = gpc.config.fstp_handler.get_reduce_scatter_memory(size)
+    output = gpc.config.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
    setattr(output, "index", index)
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op

@@ -269,11 +269,11 @@ class FusedDenseFunc(torch.autograd.Function):
class MegatronFusedDenseFunc(torch.autograd.Function):
-    '''
+    """
    FusedDenseFunc for tensor parallel in megatron implementation.
    The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron,
    so that the all-gather in backward is ommited.
-    '''
+    """

    @staticmethod
    @custom_fwd

@@ -355,9 +355,10 @@ class MegatronFusedDenseFunc(torch.autograd.Function):
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None, None

+
# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFuncTorch(FusedDenseFunc):
-    '''FusedDenseFunc in flash implementation for supporting torch.float32'''
+    """FusedDenseFunc in flash implementation for supporting torch.float32"""

    @staticmethod
    @custom_bwd

@@ -407,8 +408,9 @@ class FusedDenseFuncTorch(FusedDenseFunc):
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None, None

+
class MegatronFusedDenseFuncTorch(FusedDenseFunc):
-    '''FusedDenseFunc in megatron implementation for supporting torch.float32'''
+    """FusedDenseFunc in megatron implementation for supporting torch.float32"""

    @staticmethod
    @custom_bwd

@@ -452,6 +454,7 @@ class MegatronFusedDenseFuncTorch(FusedDenseFunc):
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None, None

+
class FSTPFusedDenseFunc(torch.autograd.Function):
    "FusedDenseFunc for FSTP, which is optimized based on flash implementation."

@@ -485,7 +488,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
        if world_size > 1:
            # do all_gather for weight and bias before actual computation
            if overlap_handler is not None:
-                total_weight = gpc.config.block_memory[block_index % 2][module_name]
+                total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name)
            else:
                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                handle_weight.wait()

@@ -544,7 +547,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
        world_size = gpc.get_world_size(ParallelMode.TENSOR)
        if world_size > 1:
            if overlap_handler is not None:
-                total_weight = gpc.config.block_memory[block_index % 2][module_name]
+                total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name)
            else:
                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                handle_weight.wait()

@@ -559,17 +562,39 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                )
        if world_size > 1:
            if overlap_handler is not None:
-                grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True)
+                grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
+                    grad_weight, process_group, async_op=True
+                )
                assert hasattr(weight, "_fstp_reduce_scatter_str")
-                overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
-                grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
+                overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (
+                    handle_grad_weight,
+                    grad_weight_async,
+                )
+                grad_weight = overlap_handler.get_zero_by_shape(
+                    (
+                        grad_weight.shape[0] // torch.distributed.get_world_size(process_group),
+                        *grad_weight.shape[1:],
+                    ),
+                    dtype=grad_weight.dtype,
+                    device=grad_weight.device,
+                )
                if grad_bias is not None:
                    grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(
                        grad_bias, process_group, async_op=True
                    )
                    assert hasattr(bias, "_fstp_reduce_scatter_str")
-                    overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
-                    grad_bias = overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
+                    overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (
+                        handle_grad_bias,
+                        grad_bias_async,
+                    )
+                    grad_bias = overlap_handler.get_zero_by_shape(
+                        (
+                            grad_bias.shape[0] // torch.distributed.get_world_size(process_group),
+                            *grad_bias.shape[1:],
+                        ),
+                        dtype=grad_bias.dtype,
+                        device=grad_bias.device,
+                    )
            else:
                grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
                if grad_bias is not None:

@@ -613,6 +638,7 @@ def fused_dense_func_torch(
    else:
        return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)

+
def megatron_fused_dense_func_torch(
    x: Tensor,
    weight: Tensor,

@@ -626,9 +652,14 @@ def megatron_fused_dense_func_torch(
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return MegatronFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
+        return MegatronFusedDenseFunc.apply(
+            x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim
+        )
    else:
-        return MegatronFusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
+        return MegatronFusedDenseFuncTorch.apply(
+            x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim
+        )

+
def fstp_fused_dense_func(
    x: Tensor,

@@ -693,38 +724,3 @@ def Silu(w1_o, w2_o):
Silu = torch.jit.script(Silu)

def check_reduce_scatter_memory_pool(key):
return_idx = 0
# if key not in dict
if key not in gpc.config.reduce_scatter_memory:
gpc.config.reduce_scatter_memory[key] = {"data": [], "used": []}
# if the data is empty
if len(gpc.config.reduce_scatter_memory[key]["data"]) == 0:
gpc.config.reduce_scatter_memory[key]["data"].append(
torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous()
)
gpc.config.reduce_scatter_memory[key]["used"].append(True)
return_idx = 0
return return_idx
else: # if not empty
for index, used in enumerate(gpc.config.reduce_scatter_memory[key]["used"]):
if used is False:
gpc.config.reduce_scatter_memory[key]["used"][index] = True
return_idx = index
return return_idx
# if the memory pool is all used
length = len(gpc.config.reduce_scatter_memory[key]["data"])
gpc.config.reduce_scatter_memory[key]["data"].append(
torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous()
)
gpc.config.reduce_scatter_memory[key]["used"].append(True)
return_idx = length
return return_idx
def release_reduce_scatter_memory_pool(size, index):
gpc.config.reduce_scatter_memory[size]["used"][index] = False
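The backward path above never waits for its own reduce-scatter: it stashes the (handle, asynchronous result) pair in reduce_scatter_handlers keyed by the parameter's _fstp_reduce_scatter_str and hands autograd a zero placeholder from get_zero_by_shape; the optimizer later waits, accumulates the real gradient, and releases the pooled buffer. A hedged single-process sketch of that hand-off (thread pool instead of torch.distributed, simplified "reduce", hypothetical tag string):

from concurrent.futures import ThreadPoolExecutor

import torch

pool = ThreadPoolExecutor(max_workers=1)
reduce_scatter_handlers = {}  # key: parameter tag; value: (handle, async result)

def backward_side(tag, grad_weight, shards):
    # stand-in for reduce_scatter_raw_memory_pool(grad_weight, ..., async_op=True)
    handle = pool.submit(lambda g: g.reshape(shards, -1).sum(0), grad_weight)
    reduce_scatter_handlers[tag] = (handle, None)
    # like get_zero_by_shape(...): give autograd a correctly-shaped placeholder now
    return torch.zeros(grad_weight.numel() // shards)

def optimizer_side(param_grad, tag):
    handle, _ = reduce_scatter_handlers[tag]
    reduced = handle.result()   # handle.wait() in the real code
    param_grad.add_(reduced)    # accumulate, then the pool slot would be released
    reduce_scatter_handlers[tag] = None
    return param_grad

tag = "blocks.0.mlp.w1.weight"  # illustrative key in the _fstp_reduce_scatter_str format
placeholder = backward_side(tag, torch.ones(8), shards=4)
print(optimizer_side(placeholder.clone(), tag))  # tensor([4., 4.])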


@@ -11,7 +11,6 @@ from torch.optim import Optimizer

from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.model.utils import release_reduce_scatter_memory_pool
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
    BucketStore,

@@ -41,6 +40,7 @@ from .utils import compute_norm
inf = math.inf
logger = get_logger(__file__)

+
class HybridZeroOptimizer(BaseOptimizer):
    """
    Hybrid Zero Optimizer.

@@ -65,7 +65,7 @@ class HybridZeroOptimizer(BaseOptimizer):
        backoff_factor = grad_scal_cfg.backoff_factor
        hysteresis = grad_scal_cfg.hysteresis
        max_scale = grad_scal_cfg.max_scale

        self._fstp_handler = None
        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
            self._fstp_handler = gpc.config.fstp_handler

@@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                        _param.grad.add_(_grad)

                        # release cuda memory.
-                        release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
+                        gpc.config.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
                        self._fstp_handler.reduce_scatter_handlers[_key] = None

                bucket.reset_by_rank(reduce_rank)

@@ -635,9 +635,9 @@ class HybridZeroOptimizer(BaseOptimizer):
        timer("sync_grad").start()
        self._sync_grad()
        timer("sync_grad").stop()

        res = self._step(closure=closure, norms=total_norms)

        return res

    def _step(self, closure=None, norms=None):
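release_reduce_scatter_memory above only flips a "used" flag, so the pooled buffer of that shape is handed out again on the next get_reduce_scatter_memory call rather than re-allocated. A small standalone sketch of that get/release bookkeeping (toy code mirroring the handler's pool methods, not the internlm implementation itself):

import torch

pool = {}  # key: shape tuple; value: {"data": [tensors], "used": [bools]}

def get_buffer(shape):
    slot = pool.setdefault(shape, {"data": [], "used": []})
    for i, used in enumerate(slot["used"]):
        if not used:
            slot["used"][i] = True
            return i
    # all buffers of this shape are busy: grow the pool
    slot["data"].append(torch.zeros(shape))
    slot["used"].append(True)
    return len(slot["data"]) - 1

def release_buffer(shape, index):
    pool[shape]["used"][index] = False

i0 = get_buffer((4, 4))
i1 = get_buffer((4, 4))   # first slot busy, so a second buffer is allocated
release_buffer((4, 4), i0)
i2 = get_buffer((4, 4))   # the released slot is reused
print(i0, i1, i2)         # 0 1 0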


@@ -36,12 +36,12 @@ from internlm.data.packed_dataset import (
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
-    CoarseGrainedFSTPAllGatherSyncHandler,
    FeedForward,
    RewardModelLinear,
    ScaleColumnParallelLinear,
)
from internlm.model.multi_head_attention import MHA
+from internlm.model.overlap_handler import FSTPOverlapHandler
from internlm.model.utils import try_import_RMSNorm
from internlm.monitor import send_heartbeat, set_env_var
from internlm.monitor.monitor import monitor_manager as mm

@@ -109,60 +109,8 @@ def initialize_model():
        model = wrap_FSDP_model(model)

    gpc.config.fstp_handler = None
    if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-        handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
-        handler._register_sync_parameters_hook()
-        gpc.config.fstp_handler = handler
+        gpc.config.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))
# allocate memory pool
block_memory = {} # containing two groups of block weight
hidden_size = gpc.config.HIDDEN_SIZE
mlp_ratio = gpc.config.MLP_RATIO
mlp_hidden_size = int(hidden_size * mlp_ratio)
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
world_size = gpc.get_world_size(ParallelMode.TENSOR)
size_key = [
(3 * hidden_size // world_size, hidden_size),
(mlp_hidden_size // world_size, hidden_size),
(hidden_size // world_size, mlp_hidden_size),
(hidden_size // world_size, hidden_size),
]
module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
for i in range(2):
weight = {}
for name in module_name:
if name == "Wqkv":
weight[name] = torch.zeros(
(3 * hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()
elif name == "out_proj":
weight[name] = torch.zeros(
(hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()
elif name == "w1" or name == "w2":
weight[name] = torch.zeros(
(mlp_hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()
else:
weight[name] = torch.zeros(
(hidden_size, mlp_hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()
block_memory[i] = weight
reduce_scatter_memory = {}
for key in size_key:
reduce_scatter_memory[key] = {"data": [], "used": []}
gpc.config.block_memory = block_memory
gpc.config.reduce_scatter_memory = reduce_scatter_memory
    return model


@@ -299,7 +299,7 @@ def main(args):
        if gpc.config.fstp_handler is not None:
            gpc.config.fstp_handler.zero_const_pool = {}
-            gpc.config.fstp_handler.reduce_scatter_memory = {}
+            gpc.config.fstp_handler.reduce_scatter_memory_pool = {}

        # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
        torch.cuda.reset_peak_memory_stats()