mirror of https://github.com/InternLM/InternLM
Merge pull request #5 from yingtongxiong/fstp/refactor-hook-handle
feat(model/overlap_handler.py): refactor overlap hook handle (pull/456/head)
commit b48687a7ff
@@ -163,7 +163,7 @@ pipeline parallel (dict):
"""
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=8, sp="megatron", intern_overlap=True),
    tensor=dict(size=8, sp="intern", intern_overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
)
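For orientation, a minimal sketch (not part of the diff) of how the sp/intern_overlap settings above are consumed; it simply mirrors the initialize_model() hunk further down in this same PR:

    # Sketch only, based on the train-side hunk below.
    gpc.fstp_handler = None
    if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
        # FSTPOverlapHandler registers its prefetch hooks and allocates its
        # memory pools inside __init__ (see overlap_handler.py below).
        gpc.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))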
@@ -1,22 +1,17 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Any, Optional, Union
from typing import Optional

import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter
from torch import nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.embedding import Embedding1D
from internlm.model.utils import (
    Silu,
    all_gather_raw,
    all_gather_raw_memory_pool,
    fstp_fused_dense_func,
    fused_dense_func_torch,
    megatron_fused_dense_func_torch,
@@ -58,10 +53,12 @@ class BaseScaleColumnParallelLinear(nn.Linear):
        self.process_group = process_group
        self.weight_scale = weight_scale


class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
    """
    ScaleColumnParallelLinear in flash implementation.
    """

    def forward(self, input, gather_dim=0):  # pylint: disable=W0622
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
@@ -79,6 +76,7 @@ class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
            gather_dim=gather_dim,
        )


class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
    """
    ScaleColumnParallelLinear in megatron implementation.
@@ -101,6 +99,7 @@ class MegatronScaleColumnParallelLinear(BaseScaleColumnParallelLinear):
            gather_dim=gather_dim,
        )


class RewardModelLinear(ScaleColumnParallelLinear):
    """
    RewardModelLinear.
@@ -164,6 +163,7 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
            gather_dim=gather_dim,
        )


class MegatronColumnParallelLinearTorch(ColumnParallelLinear):
    def forward(self, x, gather_dim=0):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
@@ -178,6 +178,7 @@ class MegatronColumnParallelLinearTorch(ColumnParallelLinear):
            gather_dim=gather_dim,
        )


class RowParallelLinearTorch(RowParallelLinear):
    def forward(self, x):
        """
@@ -188,6 +189,7 @@ class RowParallelLinearTorch(RowParallelLinear):
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class MegatronRowParallelLinearTorch(RowParallelLinear):
    def forward(self, x):
        """
@@ -225,8 +227,8 @@ class BaseFeedForward(nn.Module):
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
        colum_cls = None,
        row_cls = None,
        colum_cls=None,
        row_cls=None,
    ):
        super().__init__()
        hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
@@ -265,6 +267,7 @@ class BaseFeedForward(nn.Module):
        out = self.w3(Silu(w1_o, w2_o))
        return out


class FeedForward(BaseFeedForward):
    """
    FeedForward in flash implementation.
@@ -292,8 +295,18 @@ class FeedForward(BaseFeedForward):
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
    ):
        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
                         dtype, multiple_of, ColumnParallelLinearTorch, RowParallelLinearTorch)
        super().__init__(
            in_features,
            hidden_features,
            out_features,
            process_group,
            bias,
            device,
            dtype,
            multiple_of,
            ColumnParallelLinearTorch,
            RowParallelLinearTorch,
        )


class MegatronFeedForward(BaseFeedForward):
@@ -323,19 +336,32 @@ class MegatronFeedForward(BaseFeedForward):
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
    ):
        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
                         dtype, multiple_of, MegatronColumnParallelLinearTorch, MegatronRowParallelLinearTorch)
        super().__init__(
            in_features,
            hidden_features,
            out_features,
            process_group,
            bias,
            device,
            dtype,
            multiple_of,
            MegatronColumnParallelLinearTorch,
            MegatronRowParallelLinearTorch,
        )


class FSTPLinear(ColumnParallelLinear):
    def forward(self, x):
        block_index = gpc.config.fstp_handler.module_to_index[self]
        name_index = gpc.config.fstp_handler.module_name_index[self]
        name = gpc.config.fstp_handler.module_name[name_index]
        return fstp_fused_dense_func(
            x, self.weight, self.bias, process_group=self.process_group,
            module=self, handler=gpc.config.fstp_handler, block_index=block_index, module_name=name
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            module=self,
            handler=gpc.fstp_handler,
        )


class FSTPFeedForward(BaseFeedForward):
    """
    FeedForward in FSTP.
@@ -363,8 +389,19 @@ class FSTPFeedForward(BaseFeedForward):
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
    ):
        super().__init__(in_features, hidden_features, out_features, process_group, bias, device,
                         dtype, multiple_of, FSTPLinear, FSTPLinear)
        super().__init__(
            in_features,
            hidden_features,
            out_features,
            process_group,
            bias,
            device,
            dtype,
            multiple_of,
            FSTPLinear,
            FSTPLinear,
        )


def get_mlp_cls(sp_mode: str):
    if sp_mode in ["none", "flash-attn"]:
@@ -375,6 +412,7 @@ def get_mlp_cls(sp_mode: str):
        mlp_cls = FSTPFeedForward
    return mlp_cls


def get_linear_cls(sp_mode: str, parallel_mode: str):
    if parallel_mode == "column":
        if sp_mode in ["none", "flash-attn"]:
@@ -383,7 +421,7 @@ def get_linear_cls(sp_mode: str, parallel_mode: str):
            cls = MegatronColumnParallelLinearTorch
        else:
            cls = FSTPLinear
    elif parallel_mode == 'row':
    elif parallel_mode == "row":
        if sp_mode in ["none", "flash-attn"]:
            cls = RowParallelLinearTorch
        elif sp_mode == "megatron":
@@ -391,192 +429,3 @@ def get_linear_cls(sp_mode: str, parallel_mode: str):
        else:
            cls = FSTPLinear
    return cls


class CoarseGrainedFSTPAllGatherSyncHandler:
    """
    All-gather handler for overlapping the all-gather in adjcent FSTP block.
    """

    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
        # import pdb; pdb.set_trace()
        self.process_group = process_group
        self.FSTP_blocks = []
        self.FSTP_outs = []
        self.FSTP_modules = []
        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
        self.FSTP_global_handle = dict()  # key: FSTP module; value: module global all-gather op handle
        self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
        self.block_handles = dict()  # key: transformer block; value: all-gather handles
        self.module_to_index = dict()  # key: FSTP module; value: transformer block index
        self.block_to_index = dict()  # key: transformer block; value: transformer block index
        self.index_to_block = dict()  # key: transformer block index; value: transformer block
        self.index_to_fsdp_modules = dict()  # key: transformer block index; value: fsdp modules
        self.module_name_index = dict()  # key: FSTP module; value: the name in index in self.module_name
        self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
        self.head = []
        self.embedding = []

        self.reduce_scatter_handlers = {}
        self.all_reduce_handlers = {}
        self.zero_const_pool = {}

        # just want to share same for loop for ModuleList and Module
        if not isinstance(model, nn.ModuleList):
            model = [model]

        for _chunk in model:
            if isinstance(_chunk, NaiveAMPModel):
                _chunk = _chunk.model

            for _chunk_name, children in _chunk.named_children():
                if isinstance(children, nn.ModuleList):
                    for idx, block in enumerate(children):
                        index = 0
                        self.block_module[idx] = {}
                        self.FSTP_blocks.append(block)
                        self.block_to_index[block] = idx
                        self.index_to_block[idx] = block
                        self.index_to_fsdp_modules[idx] = []
                        for _sub_name, sub in block.named_children():
                            sub_modules = list(sub.children())
                            if len(sub_modules) > 0:
                                for name, child in sub.named_children():
                                    if name == "out_proj":
                                        self.FSTP_outs.append(child)
                                        self.module_to_index[child] = idx
                                    if isinstance(child, FSTPLinear):
                                        self.module_to_index[child] = idx
                                        self.block_module[idx][index] = child
                                        self.FSTP_modules.append(child)
                                        self.index_to_fsdp_modules[idx].append(child)
                                        self.module_name_index[child] = index
                                        index = index + 1

                                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
                                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
                                        if child.bias is not None:
                                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
                            else:
                                continue
                elif isinstance(children, ScaleColumnParallelLinear):
                    self.head.append(children)
                elif isinstance(children, Embedding1D):
                    self.embedding.append(children)

    def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor:
        if size not in self.zero_const_pool:
            self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()

        return self.zero_const_pool[size]

    def _all_gather_block_weight_memory_pool(self, block_index: int):
        fsdp_modules = self.index_to_fsdp_modules[block_index]
        for module in fsdp_modules:
            module_index = self.module_name_index[module]
            name = self.module_name[module_index]
            weight_handle = all_gather_raw_memory_pool(
                module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name
            )
            self.FSTP_global_handle[module] = weight_handle

    def _register_sync_parameters_hook(self) -> None:
        """
        register pre_forward_hook and pre_backward_hook for FSTP block.

        Notice that next block's all_gather op should be after current block's all_to_all op, so we
        1. register pre_forward_hook @out_proj module to prefetch for next block
        2. register pre_forward_hook @block module to wait handles for next block
        3. register pre_backward_hook @wqkv module to prefetch for next block
        4. register pre_backward_hook @block module to wait handles for next block
        """

        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
            block_index = self.module_to_index[module]
            # start the all-gather for next block
            if block_index + 1 < gpc.config.NUM_LAYER:
                self._all_gather_block_weight_memory_pool(block_index + 1)

        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
            self._all_gather_block_weight_memory_pool(0)

        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
            handle = self.FSTP_global_handle[module]
            handle.wait()

        def _post_forward_hook_for_module(module: nn.Module, input, output):
            if module in self.FSTP_global_weights:
                del self.FSTP_global_weights[module]
            if module in self.FSTP_global_handle:
                del self.FSTP_global_handle[module]

        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
            first_module = self.block_module[gpc.config.NUM_LAYER - 1][4]
            total_weight, weight_handler = all_gather_raw(first_module.weight, self.process_group, async_op=True)
            self.FSTP_global_handle[first_module] = weight_handler
            self.FSTP_global_weights[first_module] = total_weight

        def _pre_backward_hook_for_module_memory_pool(module: nn.Module, grad_output):
            block_index = self.module_to_index[module]
            name_index = self.module_name_index[module]

            if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
                weight_handler = self.FSTP_global_handle[module]
                weight_handler.wait()

                # start the all-gather for next module
                next_module = self.block_module[block_index][name_index - 1]
                next_name = self.module_name[name_index - 1]
                weights_handler = all_gather_raw_memory_pool(
                    next_module.weight,
                    self.process_group,
                    async_op=True,
                    block_index=block_index,
                    module_name=next_name,
                )
                self.FSTP_global_handle[next_module] = weights_handler
            elif name_index == 0:
                handler = self.FSTP_global_handle[module]
                handler.wait()

                if block_index - 1 >= 0:
                    next_module = self.block_module[block_index - 1][4]
                    name = self.module_name[4]
                    weights_handler = all_gather_raw_memory_pool(
                        next_module.weight,
                        self.process_group,
                        async_op=True,
                        block_index=block_index - 1,
                        module_name=name,
                    )
                    self.FSTP_global_handle[next_module] = weights_handler
            else:
                handler = self.FSTP_global_handle[module]
                handler.wait()
                if name_index != 0:
                    next_module = self.block_module[block_index][name_index - 1]
                    name = self.module_name[name_index - 1]
                    weights_handler = all_gather_raw_memory_pool(
                        next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name
                    )
                    self.FSTP_global_handle[next_module] = weights_handler

        def _post_backward_hook_for_module(module, grad_input, grad_output):
            if module in self.FSTP_global_weights:
                del self.FSTP_global_weights[module]
            if module in self.FSTP_global_handle:
                del self.FSTP_global_handle[module]

        for embedding in self.embedding:
            embedding.register_forward_hook(_post_forward_hook_for_embedding)

        for head in self.head:
            head.register_full_backward_hook(_post_backward_hook_for_head)

        for out_proj in self.FSTP_outs:
            out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)

        for module in self.FSTP_modules:
            module.register_forward_pre_hook(_pre_forward_hook_for_module)
            module.register_forward_hook(_post_forward_hook_for_module)
            module.register_full_backward_pre_hook(_pre_backward_hook_for_module_memory_pool)
            module.register_full_backward_hook(_post_backward_hook_for_module)
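The removed CoarseGrainedFSTPAllGatherSyncHandler above and the FSTPOverlapHandler introduced later in this PR drive the same hook-based prefetch. A compressed timeline, summarizing the hooks shown in this diff (annotation only, not code from the PR):

    # forward pass
    #   embedding post-forward          -> async all-gather of block 0's weights
    #   out_proj pre-forward (block i)  -> async all-gather of block i+1's weights
    #   each FSTP linear pre-forward    -> wait on its own all-gather handle
    #   each FSTP linear post-forward   -> drop the handle (the old handler also drops the gathered weight)
    # backward pass
    #   head post-backward              -> async all-gather of the last FSTP linear's weight
    #   each FSTP linear pre-backward   -> wait on its handle, then prefetch the previous module's weight
    #   each FSTP linear post-backward  -> drop the handle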
@@ -14,12 +14,9 @@ from internlm.core.context.parallel_context import global_context as gpc
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
    FeedForward,
    MegatronFeedForward,
    FSTPFeedForward,
    MegatronScaleColumnParallelLinear,
    RewardModelLinear,
    ScaleColumnParallelLinear,
    MegatronScaleColumnParallelLinear,
    get_mlp_cls,
)
from internlm.model.multi_head_attention import MHA
@@ -309,7 +306,11 @@ class PackedFlashInternLm1D(nn.Module):
        if is_reward:
            head_cls = RewardModelLinear
        else:
            head_cls = ScaleColumnParallelLinear if self.sp_mode in ["flash-attn", "none", "intern"] else MegatronScaleColumnParallelLinear
            head_cls = (
                ScaleColumnParallelLinear
                if self.sp_mode in ["flash-attn", "none", "intern"]
                else MegatronScaleColumnParallelLinear
            )
        if first:
            if embed_split_hidden:
                self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -53,7 +53,6 @@ class MoE(torch.nn.Module):
        device=None,
        dtype=None,
    ):

        super().__init__()

        assert (
@@ -38,14 +38,7 @@ from torch.nn import Module
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
from internlm.model.linear import (
    ColumnParallelLinearTorch,
    FSTPLinear,
    RowParallelLinearTorch,
    MegatronColumnParallelLinearTorch,
    MegatronRowParallelLinearTorch,
    get_linear_cls,
)
from internlm.model.linear import get_linear_cls


# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
@@ -227,7 +220,7 @@ class MHA(nn.Module):
            self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)

        # output projection always have the bias (for now)
        out_proj_cls = get_linear_cls(sp_mode, 'row')
        out_proj_cls = get_linear_cls(sp_mode, "row")
        self.out_proj = out_proj_cls(
            embed_dim,
            embed_dim,
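# Note: the following hunk (+283 lines) is the new file internlm/model/overlap_handler.py, named in the PR title
# and imported below as `from internlm.model.overlap_handler import FSTPOverlapHandler`.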
@@ -0,0 +1,283 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Any, Union

import torch
from torch import nn

from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.embedding import Embedding1D
from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
from internlm.model.utils import (
    all_gather_raw_bias_memory_pool,
    all_gather_raw_memory_pool,
)
from internlm.utils.common import get_current_device


class FSTPOverlapHandler:
    """
    FSTP overlap handler for managing the all-gather and reduce_scatter overlapping.
    """

    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
        self.process_group = process_group
        self.fstp_outs = []
        self.fstp_modules = []
        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
        self.fstp_global_handle = dict()  # key: fstp module; value: module global all-gather op handle
        self.bias_global_handle = dict()  # key: fstp module; value: module bias global all-gather op handle
        self.module_to_index = dict()  # key: fstp module; value: transformer block index
        self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
        self.head = []
        self.embedding = []

        self.reduce_scatter_handlers = {}
        self.zero_const_pool = {}

        # just want to share same for loop for ModuleList and Module
        if not isinstance(model, nn.ModuleList):
            model = [model]

        for _chunk in model:
            if isinstance(_chunk, NaiveAMPModel):
                _chunk = _chunk.model

            for _chunk_name, children in _chunk.named_children():
                if isinstance(children, ScaleColumnParallelLinear):
                    self.head.append(children)
                elif isinstance(children, Embedding1D):
                    self.embedding.append(children)
                elif isinstance(children, nn.ModuleList):
                    for idx, block in enumerate(children):
                        self.index_to_fstp_modules[idx] = []
                        for _sub_name, sub in block.named_children():
                            sub_modules = list(sub.children())
                            if len(sub_modules) > 0:
                                for name, child in sub.named_children():
                                    if name == "out_proj":
                                        self.fstp_outs.append(child)
                                        self.module_to_index[child] = idx
                                    if isinstance(child, FSTPLinear):
                                        self.module_to_index[child] = idx
                                        self.fstp_modules.append(child)
                                        self.index_to_fstp_modules[idx].append(child)

                                        setattr(child, "_fstp_name", name)

                                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
                                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
                                        if child.bias is not None:
                                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")

        self._initialize_memory_pool()
        self._register_sync_parameters_hook()

    def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor:
        if size not in self.zero_const_pool:
            self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()

        return self.zero_const_pool[size]
    def _initialize_module_shape(self):
        hidden_size = gpc.config.HIDDEN_SIZE
        mlp_ratio = gpc.config.MLP_RATIO
        mlp_hidden_size = int(hidden_size * mlp_ratio)
        mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)

        self.module_shape["Wqkv"] = (3 * hidden_size, hidden_size)
        self.module_shape["out_proj"] = (hidden_size, hidden_size)
        self.module_shape["w1"] = (mlp_hidden_size, hidden_size)
        self.module_shape["w2"] = (mlp_hidden_size, hidden_size)
        self.module_shape["w3"] = (hidden_size, mlp_hidden_size)

    def _initialize_memory_pool(self) -> None:
        # allocate memory pool
        self.all_gather_memory_pool = []
        self.all_gather_bias_memory_pool = []
        self.reduce_scatter_memory_pool = {}
        self.module_shape = {}

        self._initialize_module_shape()
        dtype = gpc.config.model.get("dtype", torch.half)
        device = get_current_device()

        for _ in range(2):
            weight = {}
            for name in self.module_name:
                weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous()
            self.all_gather_memory_pool.append(weight)  # containing two groups of block weight

    def clear_memory_pool(self) -> None:
        self.zero_const_pool = {}
        self.reduce_scatter_memory_pool = {}

    def get_all_gather_memory(self, module):
        block_index = self.module_to_index[module]
        return self.all_gather_memory_pool[block_index % 2][module._fstp_name]

    def get_bias_memory(self, module: nn.Module):
        block_index = self.module_to_index[module]
        # if the bias memory pool is empty or module has been not allocated memory
        # import pdb; pdb.set_trace()
        if len(self.all_gather_bias_memory_pool) == 0:
            for _ in range(2):
                weight = {}
                weight[module._fstp_name] = torch.zeros(
                    self.module_shape[module._fstp_name][0],
                    dtype=gpc.config.model.get("dtype", torch.half),
                    device=get_current_device(),
                ).contiguous()
                self.all_gather_bias_memory_pool.append(weight)
        elif module._fstp_name not in self.all_gather_bias_memory_pool[0]:
            for i in range(2):
                self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros(
                    self.module_shape[module._fstp_name][0],
                    dtype=gpc.config.model.get("dtype", torch.half),
                    device=get_current_device(),
                ).contiguous()

        return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]

    def get_reduce_scatter_memory(self, key):
        return_idx = 0

        # if key not in dict
        if key not in self.reduce_scatter_memory_pool:
            self.reduce_scatter_memory_pool[key] = []

        # if the data is empty
        if len(self.reduce_scatter_memory_pool[key]) == 0:
            self.reduce_scatter_memory_pool[key].append(
                torch.zeros(
                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                ).contiguous()
            )
            setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
            return self.reduce_scatter_memory_pool[key][return_idx]
        else:  # if not empty
            for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
                if mem_item.idle is True:
                    self.reduce_scatter_memory_pool[key][index].idle = False
                    return_idx = index
                    return self.reduce_scatter_memory_pool[key][return_idx]
            # if the memory pool is all used
            cur_len = len(self.reduce_scatter_memory_pool[key])
            self.reduce_scatter_memory_pool[key].append(
                torch.zeros(
                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                ).contiguous()
            )
            setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
            return_idx = cur_len
            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
            return self.reduce_scatter_memory_pool[key][return_idx]

    def release_reduce_scatter_memory(self, key, index):
        self.reduce_scatter_memory_pool[key][index].idle = True

    def _all_gather_block_weight_memory_pool(self, block_index: int):
        fstp_modules = self.index_to_fstp_modules[block_index]
        for module in fstp_modules:
            if module.bias is not None:
                bias_handle = all_gather_raw_bias_memory_pool(
                    module.bias,
                    self.process_group,
                    async_op=True,
                    module=module,
                )
                self.bias_global_handle[module] = bias_handle

            weight_handle = all_gather_raw_memory_pool(
                module.weight,
                self.process_group,
                async_op=True,
                module=module,
            )
            self.fstp_global_handle[module] = weight_handle
    def _register_sync_parameters_hook(self) -> None:
        """
        register forward hooks and backward hooks for fstp modules.
        """

        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
            self._all_gather_block_weight_memory_pool(0)

        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):  # pylint: disable=W0613
            block_index = self.module_to_index[module]
            # start the all-gather for next block
            if block_index + 1 < gpc.config.NUM_LAYER:
                self._all_gather_block_weight_memory_pool(block_index + 1)

        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):  # pylint: disable=W0613
            handle = self.fstp_global_handle[module]
            handle.wait()
            if module.bias is not None:
                bias_handle = self.bias_global_handle[module]
                bias_handle.wait()

        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
            if module in self.fstp_global_handle:
                del self.fstp_global_handle[module]

        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):  # pylint: disable=W0613
            first_backward_module = self.fstp_modules[-1]
            weight_handle = all_gather_raw_memory_pool(
                first_backward_module.weight,
                self.process_group,
                async_op=True,
                module=first_backward_module,
            )
            self.fstp_global_handle[first_backward_module] = weight_handle

        def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
            # wait handle for current module
            weight_handle = self.fstp_global_handle[module]
            weight_handle.wait()

            # start the all-gather for next module
            module_index = self.fstp_modules.index(module)
            if module_index - 1 >= 0:
                next_module = self.fstp_modules[module_index - 1]
                weight_handle = all_gather_raw_memory_pool(
                    next_module.weight,
                    self.process_group,
                    async_op=True,
                    module=next_module,
                )
                self.fstp_global_handle[next_module] = weight_handle

        def _post_backward_hook_for_module(module, grad_input, grad_output):  # pylint: disable=W0613
            if module in self.fstp_global_handle:
                del self.fstp_global_handle[module]

        # register forward hooks
        # 1. register post_forward_hook @embedding module to prefetch for block 0
        # 2. register pre_forward_hook @out_proj module to prefetch for next block,
        #    notice that next block's all_gather op should be after current block's all_to_all op
        # 3. register pre_forward_hook @fstp_module to wait handle for current module
        # 4. register post_forward_hook @fstp_module to release resource
        for embedding in self.embedding:
            embedding.register_forward_hook(_post_forward_hook_for_embedding)

        for out_proj in self.fstp_outs:
            out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)

        for module in self.fstp_modules:
            module.register_forward_pre_hook(_pre_forward_hook_for_module)
            module.register_forward_hook(_post_forward_hook_for_module)

        # register backward hooks
        # 1. register post_backward_hook @head module to prefetch for the last block's last module
        # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
        # 3. register post_backward_hook @fstp_module to release resource
        for head in self.head:
            head.register_full_backward_hook(_post_backward_hook_for_head)

        for module in self.fstp_modules:
            module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
            module.register_full_backward_hook(_post_backward_hook_for_module)
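Why a two-slot pool is enough: weights are prefetched at most one block ahead (out_proj of block i prefetches block i + 1), so buffers indexed by block_index % 2 alternate between the block being computed and the block being gathered. A minimal sketch of that indexing, reusing the names defined above (`destination_buffer` is an illustrative helper, not part of the PR):

    # Sketch only: how get_all_gather_memory() resolves a destination buffer.
    # handler.all_gather_memory_pool is a list of 2 dicts, one per slot, each mapping
    # {"Wqkv", "out_proj", "w1", "w2", "w3"} -> a preallocated full-weight tensor.
    def destination_buffer(handler, module):
        block_index = handler.module_to_index[module]  # which transformer block owns this linear
        slot = block_index % 2                         # even blocks use slot 0, odd blocks slot 1
        return handler.all_gather_memory_pool[slot][module._fstp_name]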
@@ -7,13 +7,12 @@ import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn.functional as F
from flash_attn.utils.distributed import all_reduce_raw
from torch import Tensor
from torch import Tensor, nn
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger

logger = get_logger(__file__)
@@ -130,12 +129,25 @@ def all_gather_raw_memory_pool(
    input_: Tensor,
    process_group: ProcessGroup,
    async_op: bool = False,
    gather_dim: int = 0,
    block_index: int = None,
    module_name: str = None,
    module: nn.Module = None,
):
    handle = torch.distributed.all_gather_into_tensor(
        gpc.config.block_memory[block_index % 2][module_name],
        gpc.fstp_handler.get_all_gather_memory(module=module),
        input_.contiguous(),
        group=process_group,
        async_op=async_op,
    )
    return handle


def all_gather_raw_bias_memory_pool(
    input_: Tensor,
    process_group: ProcessGroup,
    async_op: bool = False,
    module: nn.Module = None,
):
    handle = torch.distributed.all_gather_into_tensor(
        gpc.fstp_handler.get_bias_memory(module=module),
        input_.contiguous(),
        group=process_group,
        async_op=async_op,
@@ -166,9 +178,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    size = (input_.shape[0] // world_size, *input_.shape[1:])
    index = check_reduce_scatter_memory_pool(size)
    output = gpc.config.reduce_scatter_memory[size]["data"][index]
    setattr(output, "index", index)
    output = gpc.fstp_handler.get_reduce_scatter_memory(size)
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
@@ -269,11 +279,11 @@ class FusedDenseFunc(torch.autograd.Function):


class MegatronFusedDenseFunc(torch.autograd.Function):
    '''
    """
    FusedDenseFunc for tensor parallel in megatron implementation.
    The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron,
    so that the all-gather in backward is ommited.
    '''
    The diffenrence between the implementation of flash-attn and megatron is that the total_x could be
    saved for backward in megatron, so that the all-gather in backward is ommited.
    """

    @staticmethod
    @custom_fwd
@@ -355,9 +365,10 @@ class MegatronFusedDenseFunc(torch.autograd.Function):
        handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None, None


# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFuncTorch(FusedDenseFunc):
    '''FusedDenseFunc in flash implementation for supporting torch.float32'''
    """FusedDenseFunc in flash implementation for supporting torch.float32"""

    @staticmethod
    @custom_bwd
@@ -407,8 +418,9 @@ class FusedDenseFuncTorch(FusedDenseFunc):
        handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None, None


class MegatronFusedDenseFuncTorch(FusedDenseFunc):
    '''FusedDenseFunc in megatron implementation for supporting torch.float32'''
    """FusedDenseFunc in megatron implementation for supporting torch.float32"""

    @staticmethod
    @custom_bwd
@@ -419,7 +431,6 @@ class MegatronFusedDenseFuncTorch(FusedDenseFunc):
        grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        gather_dim = ctx.gather_dim
        if ctx.compute_weight_gradient:
            total_x, weight = ctx.saved_tensors
        else:
@@ -452,6 +463,7 @@ class MegatronFusedDenseFuncTorch(FusedDenseFunc):
        handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None, None


class FSTPFusedDenseFunc(torch.autograd.Function):
    "FusedDenseFunc for FSTP, which is optimized based on flash implementation."
@@ -466,16 +478,12 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
        process_group=None,
        module=None,
        overlap_handler=None,
        block_index=None,
        module_name=None,
    ):
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.overlap_handler = overlap_handler
        ctx.module = module
        ctx.block_index = block_index
        ctx.module_name = module_name

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
@@ -485,12 +493,15 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
        if world_size > 1:
            # do all_gather for weight and bias before actual computation
            if overlap_handler is not None:
                total_weight = gpc.config.block_memory[block_index % 2][module_name]
                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
            else:
                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                handle_weight.wait()
            # TODO memory pool for bias
            if bias is not None:
                if overlap_handler is not None:
                    total_bias = gpc.fstp_handler.get_bias_memory(module=module)
                else:
                    total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
                    handle_bias.wait()
            else:
@@ -528,8 +539,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
        grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        overlap_handler = ctx.overlap_handler
        block_index = ctx.block_index
        module_name = ctx.module_name
        module = ctx.module

        if ctx.compute_weight_gradient:
            x, weight, bias = ctx.saved_tensors
@@ -544,7 +554,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
        world_size = gpc.get_world_size(ParallelMode.TENSOR)
        if world_size > 1:
            if overlap_handler is not None:
                total_weight = gpc.config.block_memory[block_index % 2][module_name]
                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
            else:
                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                handle_weight.wait()
@@ -559,17 +569,39 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
        )
        if world_size > 1:
            if overlap_handler is not None:
                grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True)
                grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
                    grad_weight, process_group, async_op=True
                )
                assert hasattr(weight, "_fstp_reduce_scatter_str")
                overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
                grad_weight = overlap_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
                overlap_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (
                    handle_grad_weight,
                    grad_weight_async,
                )
                grad_weight = overlap_handler.get_zero_by_shape(
                    (
                        grad_weight.shape[0] // torch.distributed.get_world_size(process_group),
                        *grad_weight.shape[1:],
                    ),
                    dtype=grad_weight.dtype,
                    device=grad_weight.device,
                )
                if grad_bias is not None:
                    grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(
                        grad_bias, process_group, async_op=True
                    )
                    assert hasattr(bias, "_fstp_reduce_scatter_str")
                    overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
                    grad_bias = overlap_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
                    overlap_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (
                        handle_grad_bias,
                        grad_bias_async,
                    )
                    grad_bias = overlap_handler.get_zero_by_shape(
                        (
                            grad_bias.shape[0] // torch.distributed.get_world_size(process_group),
                            *grad_bias.shape[1:],
                        ),
                        dtype=grad_bias.dtype,
                        device=grad_bias.device,
                    )
            else:
                grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
                if grad_bias is not None:
@@ -613,6 +645,7 @@ def fused_dense_func_torch(
    else:
        return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)


def megatron_fused_dense_func_torch(
    x: Tensor,
    weight: Tensor,
@@ -626,9 +659,14 @@ def megatron_fused_dense_func_torch(
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
        return MegatronFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
        return MegatronFusedDenseFunc.apply(
            x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim
        )
    else:
        return MegatronFusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
        return MegatronFusedDenseFuncTorch.apply(
            x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim
        )


def fstp_fused_dense_func(
    x: Tensor,
@@ -638,16 +676,12 @@ def fstp_fused_dense_func(
    process_group=None,
    module=None,
    handler=None,
    block_index=None,
    module_name=None,
):
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
        return FSTPFusedDenseFunc.apply(
            x, weight, bias, return_residual, process_group, module, handler, block_index, module_name
        )
        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
    else:
        assert process_group is None
        out = F.linear(x, weight, bias)
@@ -693,38 +727,3 @@ def Silu(w1_o, w2_o):


Silu = torch.jit.script(Silu)


def check_reduce_scatter_memory_pool(key):
    return_idx = 0

    # if key not in dict
    if key not in gpc.config.reduce_scatter_memory:
        gpc.config.reduce_scatter_memory[key] = {"data": [], "used": []}

    # if the data is empty
    if len(gpc.config.reduce_scatter_memory[key]["data"]) == 0:
        gpc.config.reduce_scatter_memory[key]["data"].append(
            torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous()
        )
        gpc.config.reduce_scatter_memory[key]["used"].append(True)
        return_idx = 0
        return return_idx
    else:  # if not empty
        for index, used in enumerate(gpc.config.reduce_scatter_memory[key]["used"]):
            if used is False:
                gpc.config.reduce_scatter_memory[key]["used"][index] = True
                return_idx = index
                return return_idx
        # if the memory pool is all used
        length = len(gpc.config.reduce_scatter_memory[key]["data"])
        gpc.config.reduce_scatter_memory[key]["data"].append(
            torch.zeros(key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()).contiguous()
        )
        gpc.config.reduce_scatter_memory[key]["used"].append(True)
        return_idx = length
        return return_idx


def release_reduce_scatter_memory_pool(size, index):
    gpc.config.reduce_scatter_memory[size]["used"][index] = False
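Taken together, the pool methods above and the optimizer change below give each asynchronous reduce-scatter a reusable output buffer. A minimal sketch of the life cycle under the assumption that torch.distributed is initialized and gpc.fstp_handler exists; the two helper names are illustrative, not part of the PR, and the wait on the handle is shown here explicitly even though it is not part of the hunks:

    # Sketch only: acquire -> async reduce-scatter -> (later, optimizer side) accumulate and release.
    def reduce_scatter_grad_async(grad, process_group):
        world_size = torch.distributed.get_world_size(process_group)
        key = (grad.shape[0] // world_size, *grad.shape[1:])       # per-rank output shape
        output = gpc.fstp_handler.get_reduce_scatter_memory(key)   # marks the pooled buffer idle=False
        handle = torch.distributed.reduce_scatter_tensor(
            output, grad.contiguous(), group=process_group, async_op=True
        )
        return output, handle

    def accumulate_and_release(param, output, handle):
        handle.wait()                                              # make sure the async op has finished
        param.grad.add_(output)
        gpc.fstp_handler.release_reduce_scatter_memory(key=tuple(output.size()), index=output.index)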
@@ -11,7 +11,6 @@ from torch.optim import Optimizer

from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import release_reduce_scatter_memory_pool
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
    BucketStore,
@@ -41,6 +40,7 @@ from .utils import compute_norm
inf = math.inf
logger = get_logger(__file__)


class HybridZeroOptimizer(BaseOptimizer):
    """
    Hybrid Zero Optimizer.
@@ -68,7 +68,7 @@ class HybridZeroOptimizer(BaseOptimizer):

        self._fstp_handler = None
        if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
            self._fstp_handler = gpc.config.fstp_handler
            self._fstp_handler = gpc.fstp_handler

        # Zero related args
        reduce_bucket_size = zero_cfg.reduce_bucket_size
@@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                    _param.grad.add_(_grad)

                    # release cuda memory.
                    release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
                    self._fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index)
                    self._fstp_handler.reduce_scatter_handlers[_key] = None

        bucket.reset_by_rank(reduce_rank)
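A cross-file summary of how a pending reduce-scatter is keyed and consumed (my annotation, derived from the hunks above):

    # overlap handler __init__     : child.weight._fstp_reduce_scatter_str = "<chunk>.<block_idx>.<sub>.<name>.weight"
    # FSTPFusedDenseFunc.backward  : handler.reduce_scatter_handlers[key] = (handle, sharded_grad) for that key
    # HybridZeroOptimizer          : accumulates the sharded grad into _param.grad, releases its pool slot,
    #                                then sets reduce_scatter_handlers[_key] = None so the key can be reused next step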
@@ -36,12 +36,12 @@ from internlm.data.packed_dataset(
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
    CoarseGrainedFSTPAllGatherSyncHandler,
    FeedForward,
    RewardModelLinear,
    ScaleColumnParallelLinear,
)
from internlm.model.multi_head_attention import MHA
from internlm.model.overlap_handler import FSTPOverlapHandler
from internlm.model.utils import try_import_RMSNorm
from internlm.monitor import send_heartbeat, set_env_var
from internlm.monitor.monitor import monitor_manager as mm
@@ -50,7 +50,7 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.train.utils import create_param_groups
from internlm.utils.common import DummyProfile, get_current_device
from internlm.utils.common import DummyProfile
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
@@ -108,61 +108,9 @@ def initialize_model():
    # if fsdp enabled, wrap the model
    model = wrap_FSDP_model(model)

    gpc.config.fstp_handler = None

    gpc.fstp_handler = None
    if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
        handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
        handler._register_sync_parameters_hook()
        gpc.config.fstp_handler = handler

        # allocate memory pool
        block_memory = {}  # containing two groups of block weight
        hidden_size = gpc.config.HIDDEN_SIZE
        mlp_ratio = gpc.config.MLP_RATIO
        mlp_hidden_size = int(hidden_size * mlp_ratio)
        mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
        world_size = gpc.get_world_size(ParallelMode.TENSOR)
        size_key = [
            (3 * hidden_size // world_size, hidden_size),
            (mlp_hidden_size // world_size, hidden_size),
            (hidden_size // world_size, mlp_hidden_size),
            (hidden_size // world_size, hidden_size),
        ]
        module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
        for i in range(2):
            weight = {}
            for name in module_name:
                if name == "Wqkv":
                    weight[name] = torch.zeros(
                        (3 * hidden_size, hidden_size),
                        dtype=gpc.config.model.get("dtype", torch.half),
                        device=get_current_device(),
                    ).contiguous()
                elif name == "out_proj":
                    weight[name] = torch.zeros(
                        (hidden_size, hidden_size),
                        dtype=gpc.config.model.get("dtype", torch.half),
                        device=get_current_device(),
                    ).contiguous()
                elif name == "w1" or name == "w2":
                    weight[name] = torch.zeros(
                        (mlp_hidden_size, hidden_size),
                        dtype=gpc.config.model.get("dtype", torch.half),
                        device=get_current_device(),
                    ).contiguous()
                else:
                    weight[name] = torch.zeros(
                        (hidden_size, mlp_hidden_size),
                        dtype=gpc.config.model.get("dtype", torch.half),
                        device=get_current_device(),
                    ).contiguous()
            block_memory[i] = weight
        reduce_scatter_memory = {}
        for key in size_key:
            reduce_scatter_memory[key] = {"data": [], "used": []}

        gpc.config.block_memory = block_memory
        gpc.config.reduce_scatter_memory = reduce_scatter_memory
        gpc.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))

    return model
train.py
@@ -297,9 +297,8 @@ def main(args):

        prof.step()

        if gpc.config.fstp_handler is not None:
            gpc.config.fstp_handler.zero_const_pool = {}
            gpc.config.fstp_handler.reduce_scatter_memory = {}
        if gpc.fstp_handler is not None:
            gpc.fstp_handler.clear_memory_pool()
        # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
        torch.cuda.reset_peak_memory_stats()