mirror of https://github.com/InternLM/InternLM
feat(model/linear.py): support block allgather overlap
parent 5fd5a8a32b
commit d0b1346993
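For orientation: in FSTP mode each Linear keeps only a shard of its weight and all-gathers the full weight just before use; this commit hides that communication behind computation, prefetching one linear ahead (the existing FSTPAllGatherSyncHandler) and now one transformer block ahead (the new CoarseGrainedFSTPAllGatherSyncHandler). A minimal standalone sketch of the overlap pattern (editor's illustration, assuming an initialized torch.distributed process group; none of these helper names come from the commit):

import torch
import torch.distributed as dist

def all_gather_async(shard: torch.Tensor, group=None):
    # gather a dim-0-sharded weight; returns (full_weight, async work handle)
    world_size = dist.get_world_size(group)
    full = torch.empty(world_size * shard.shape[0], *shard.shape[1:],
                       dtype=shard.dtype, device=shard.device)
    handle = dist.all_gather_into_tensor(full, shard.contiguous(), group=group, async_op=True)
    return full, handle

def forward_two_layers(x, w0_full, w1_shard, group=None):
    # kick off the gather for layer 1 while layer 0's matmul runs
    w1_full, handle = all_gather_async(w1_shard, group)
    y0 = x @ w0_full.t()  # compute overlaps the in-flight all-gather
    handle.wait()         # w1_full is now safe to read
    return y0 @ w1_full.t()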
internlm/model/linear.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from typing import Optional, Union, Any
+from typing import Any, Optional, Union

 import torch
 import torch.nn.functional as F
@@ -12,7 +12,12 @@ from torch import nn
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
-from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch, all_gather_raw
+from internlm.model.utils import (
+    Silu,
+    all_gather_raw,
+    fstp_fused_dense_func,
+    fused_dense_func_torch,
+)


 class ScaleColumnParallelLinear(nn.Linear):
@@ -212,7 +217,9 @@ class FeedForward(nn.Module):

 class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler)
+        return fstp_fused_dense_func(
+            x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler
+        )


 class FSTPFeedForward(nn.Module):
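Note the coupling this forward introduces: FSTPLinear reads the shared handler from gpc.config.fstp_handler, which the initialize_model hunk at the end of this diff populates after registering the hooks, so FSTPFusedDenseFunc can look up prefetched weights by module.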
@@ -280,22 +287,22 @@ class FSTPFeedForward(nn.Module):
         out = self.w3(F.silu(w1_o) * w2_o)
         return out


 class FSTPAllGatherSyncHandler:
     """
     All-gather handler for overlapping the all-gather in adjacent FSTP linear.
     """

     def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
         # import pdb; pdb.set_trace()
         self.process_group = process_group
         self.FSTP_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
         self.module_handler = dict()  # key: FSTP module; value: all-gather handler
         self.module_block = dict()  # key: FSTP module; value: transformer block index
         self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
         self.module_name_index = dict()  # key: FSTP module; value: the index of the name in self.module_name

         # just want to share same for loop for ModuleList and Module
         if not isinstance(model, nn.ModuleList):
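To make the bookkeeping concrete, an illustrative picture of the maps for one block (hypothetical module objects; values follow the key/value comments above):

# block_module[0] == {0: wqkv_0, 1: out_proj_0, 2: w1_0, 3: w2_0, 4: w3_0}
# module_name_index[out_proj_0] == 1 and module_block[out_proj_0] == 0,
# so a hook firing on out_proj_0 can find "the next linear" as block_module[0][2].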
@@ -323,7 +330,6 @@ class FSTPAllGatherSyncHandler:
                            else:
                                continue

-
    def _register_sync_parameters_hook(self) -> None:
        """
        register pre_forward_hook and pre_backward_hook for FSTPLinear.
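The hooks registered below build on two standard per-module PyTorch hook points; a toy refresher, independent of InternLM (full-backward pre-hooks need torch >= 2.0):

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)

def pre_forward(module, inputs):
    # runs just before module.forward: the place to start a prefetch
    print("pre-forward on", type(module).__name__)

def pre_backward(module, grad_output):
    # runs just before gradients flow back through the module
    print("pre-backward on", type(module).__name__)

lin.register_forward_pre_hook(pre_forward)
lin.register_full_backward_pre_hook(pre_backward)
lin(torch.randn(2, 4)).sum().backward()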
@@ -339,14 +345,18 @@ class FSTPAllGatherSyncHandler:

                # start the all-gather for next module
                next_module = self.block_module[block_index][name_index + 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
                self.module_handler[next_module] = weights_handler
            else:
                handler = self.module_handler[module]
                handler.wait()
                if name_index != 4:
                    next_module = self.block_module[block_index][name_index + 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
                    self.module_handler[next_module] = weights_handler

        def _post_forward_hook(module: nn.Module, input, output):
@@ -363,14 +373,18 @@ class FSTPAllGatherSyncHandler:

                # start the all-gather for next module
                next_module = self.block_module[block_index][name_index - 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
                self.module_handler[next_module] = weights_handler
            else:
                handler = self.module_handler[module]
                handler.wait()
                if name_index != 0:
                    next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
                    self.module_handler[next_module] = weights_handler

        def _post_backward_hook(module, grad_input, grad_output):
@@ -384,3 +398,144 @@ class FSTPAllGatherSyncHandler:
            # module.register_backward_hook(_post_backward_hook)
            module.register_module_full_backward_pre_hook(_pre_backward_hook)

+
+class CoarseGrainedFSTPAllGatherSyncHandler:
+    """
+    All-gather handler for overlapping the all-gather in adjacent FSTP block.
+    """
+
+    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
+        # import pdb; pdb.set_trace()
+        self.process_group = process_group
+        self.FSTP_blocks = []
+        self.FSTP_outs = []
+        self.FSTP_wqkvs = []
+        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
+        self.FSTP_global_handle = dict()  # key: FSTP module; value: module global all-gather op handle
+        self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
+        self.block_handles = dict()  # key: transformer block; value: all-gather handles
+        self.module_to_index = dict()  # key: FSTP module; value: transformer block index
+        self.block_to_index = dict()  # key: transformer block; value: transformer block index
+        self.index_to_block = dict()  # key: transformer block index; value: transformer block
+        self.index_to_fsdp_modules = dict()  # key: transformer block index; value: fsdp modules
+
+        # just want to share same for loop for ModuleList and Module
+        if not isinstance(model, nn.ModuleList):
+            model = [model]
+
+        for _chunk in model:
+            if isinstance(_chunk, NaiveAMPModel):
+                _chunk = _chunk.model
+
+            for _, children in _chunk.named_children():
+                if isinstance(children, nn.ModuleList):
+                    for idx, block in enumerate(children):
+                        self.FSTP_blocks.append(block)
+                        self.block_to_index[block] = idx
+                        self.index_to_block[idx] = block
+                        self.index_to_fsdp_modules[idx] = []
+                        for _, sub in block.named_children():
+                            sub_modules = list(sub.children())
+                            if len(sub_modules) > 0:
+                                for name, child in sub.named_children():
+                                    # print(f"name: {name}", flush=True)
+                                    if name == "out_proj":
+                                        self.FSTP_outs.append(child)
+                                        self.module_to_index[child] = idx
+                                    if name == "Wqkv":
+                                        self.FSTP_wqkvs.append(child)
+                                        self.module_to_index[child] = idx
+                                    if isinstance(child, FSTPLinear):
+                                        self.index_to_fsdp_modules[idx].append(child)
+                            else:
+                                continue
+
+    def _all_gather_block_weight(self, block_index: int):
+        block = self.index_to_block[block_index]
+        fsdp_modules = self.index_to_fsdp_modules[block_index]
+        self.block_handles[block] = []
+        for module in fsdp_modules:
+            total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+            self.FSTP_global_weights[module] = total_weight
+            self.block_handles[block].append(weight_handle)
+
+    def _register_sync_parameters_hook(self) -> None:
+        """
+        register pre_forward_hook and pre_backward_hook for FSTP block.
+
+        Notice that next block's all_gather op should be after current block's all_to_all op, so we:
+        1. register pre_forward_hook @out_proj module to prefetch for next block
+        2. register pre_forward_hook @block module to wait handles for next block
+        3. register pre_backward_hook @wqkv module to prefetch for next block
+        4. register pre_backward_hook @block module to wait handles for next block
+        """
+
+        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
+            block_index = self.module_to_index[module]
+            # start the all-gather for next block
+            if block_index + 1 < gpc.config.NUM_LAYER:
+                self._all_gather_block_weight(block_index + 1)
+
+        def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
+            block_index = self.block_to_index[block]
+            if block_index == 0:
+                # all gather weight for block 0
+                fsdp_modules = self.index_to_fsdp_modules[block_index]
+                for module in fsdp_modules:
+                    total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+                    weight_handle.wait()
+                    self.FSTP_global_weights[module] = total_weight
+            else:
+                # wait handle for current block
+                handles = self.block_handles[block]
+                for handle in handles:
+                    handle.wait()
+
+        def _post_forward_hook_for_block(block: nn.Module, input, output):
+            block_index = self.block_to_index[block]
+            fsdp_modules = self.index_to_fsdp_modules[block_index]
+            if block in self.block_handles:
+                del self.block_handles[block]
+            for module in fsdp_modules:
+                del self.FSTP_global_weights[module]
+
+        def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
+            block_index = self.module_to_index[module]
+            # start the all-gather for next block
+            if block_index - 1 >= 0:
+                self._all_gather_block_weight(block_index - 1)
+
+        def _pre_backward_hook_for_block(block: nn.Module, grad_output):
+            block_index = self.block_to_index[block]
+            if block_index == gpc.config.NUM_LAYER - 1:
+                # all gather weight for the last block
+                fsdp_modules = self.index_to_fsdp_modules[block_index]
+                for module in fsdp_modules:
+                    total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+                    weight_handle.wait()
+                    self.FSTP_global_weights[module] = total_weight
+            else:
+                # wait handle for current block
+                handles = self.block_handles[block]
+                for handle in handles:
+                    handle.wait()
+
+        def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
+            block_index = self.block_to_index[block]
+            fsdp_modules = self.index_to_fsdp_modules[block_index]
+            if block in self.block_handles:
+                del self.block_handles[block]
+            for module in fsdp_modules:
+                del self.FSTP_global_weights[module]
+
+        for block in self.FSTP_blocks:
+            block.register_forward_pre_hook(_pre_forward_hook_for_block)
+            block.register_forward_hook(_post_forward_hook_for_block)
+            block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
+            block.register_full_backward_hook(_post_backward_hook_for_block)
+
+        for out_proj in self.FSTP_outs:
+            out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
+
+        for wqkv in self.FSTP_wqkvs:
+            wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
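As wired up in the last hunk of this diff, the coarse-grained handler is constructed once over the whole model and registered before training:

handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler._register_sync_parameters_hook()
gpc.config.fstp_handler = handler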
internlm/model/utils.py
@@ -284,7 +284,6 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None):
-
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.process_group = process_group
@@ -297,16 +296,18 @@ class FSTPFusedDenseFunc(torch.autograd.Function):

        world_size = gpc.get_world_size(ParallelMode.TENSOR)
        if world_size > 1:
-            total_weight = all_gather_handler.FSTP_global_weights[module]
-            total_bias = bias
-            # # do all_gather for weight and bias before actual computation
-            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            # if bias is not None:
-            #     total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
-            #     handle_bias.wait()
-            # else:
-            #     total_bias = bias
-            # handle_weight.wait()
+            # do all_gather for weight and bias before actual computation
+            if module in all_gather_handler.FSTP_global_weights:
+                total_weight = all_gather_handler.FSTP_global_weights[module]
+            else:
+                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+                handle_weight.wait()
+
+            if bias is not None:
+                total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
+                handle_bias.wait()
+            else:
+                total_bias = bias
        else:
            total_weight = weight
            total_bias = bias
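The forward path above now tolerates a missing prefetch: when no hook has populated FSTP_global_weights for this module (e.g., the first linear touched, or a run without the handler), it falls back to a blocking gather. The same cache-or-gather pattern in isolation (editor's sketch; cache stands in for FSTP_global_weights):

def get_full_weight(module, weight, cache, process_group):
    # prefer the asynchronously prefetched copy; otherwise gather synchronously
    if module in cache:
        return cache[module]
    full, handle = all_gather_raw(weight, process_group, async_op=True)
    handle.wait()  # blocking fallback: correctness over overlap
    return full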
@@ -351,9 +352,11 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
        world_size = gpc.get_world_size(ParallelMode.TENSOR)
        if world_size > 1:
            # do all-gather for weight before backward
-            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            # handle_weight.wait()
-            total_weight = all_gather_handler.FSTP_global_weights[module]
+            if module in all_gather_handler.FSTP_global_weights:
+                total_weight = all_gather_handler.FSTP_global_weights[module]
+            else:
+                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+                handle_weight.wait()
        else:
            total_weight = weight

@@ -408,7 +411,13 @@ def fused_dense_func_torch(


 def fstp_fused_dense_func(
-    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None, module=None, handler=None
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    return_residual: bool = False,
+    process_group=None,
+    module=None,
+    handler=None,
 ):
     dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
         x.dtype == torch.float32 and torch.is_autocast_enabled()
@@ -460,5 +469,3 @@ def Silu(w1_o, w2_o):


 Silu = torch.jit.script(Silu)
-
-
internlm/train/training_internlm.py
@@ -36,10 +36,11 @@ from internlm.data.packed_dataset import (
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
+    CoarseGrainedFSTPAllGatherSyncHandler,
     FeedForward,
+    FSTPAllGatherSyncHandler,
     RewardModelLinear,
     ScaleColumnParallelLinear,
-    FSTPAllGatherSyncHandler,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import try_import_RMSNorm
@@ -109,11 +110,12 @@ def initialize_model():
     model = wrap_FSDP_model(model)

     if gpc.config.parallel["tensor"]["mode"] == "fstp":
-        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
         gpc.config.fstp_handler = handler
     return model

+
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.zero1.fsdp:
         # set wrap_policy for fsdp wrap
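For completeness, the branch above keys off the training config; a hedged sketch of the fragment involved (an assumption about the schema, which lives in the repo's config files; only the two keys this diff reads are shown):

parallel = dict(
    zero1=dict(size=-1, fsdp=False),   # wrap_FSDP_model checks parallel.zero1.fsdp
    tensor=dict(size=8, mode="fstp"),  # initialize_model checks parallel["tensor"]["mode"]
)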