mirror of https://github.com/InternLM/InternLM

feat(model/linear.py): support block allgather overlap

parent 5fd5a8a32b
commit d0b1346993
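Block all-gather overlap hides the weight-gather latency of fully sharded tensor parallelism (FSTP): while the current linear or transformer block computes, the next one's sharded weights are all-gathered asynchronously, and the handle is awaited only when that weight is actually consumed. A minimal sketch of the pattern the diff below implements, hedged to a single-process gloo group so it runs anywhere; the diff itself uses all_gather_raw over the tensor-parallel process group:

# Sketch of the overlap pattern: launch an async all-gather for the *next*
# weight shard, compute with the current weight, and wait on the handle only
# when the gathered weight is needed. Single-process gloo group for
# illustration only; the real code gathers over the tensor-parallel group.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
world_size = dist.get_world_size()

x = torch.randn(2, 8)
shard_a = torch.randn(8, 8)  # this rank's shard of module A's weight
shard_b = torch.randn(8, 8)  # this rank's shard of module B's weight

# prefetch: start gathering B's full weight before A's compute
gathered_b = [torch.empty_like(shard_b) for _ in range(world_size)]
handle = dist.all_gather(gathered_b, shard_b, async_op=True)

out_a = x @ shard_a.t()  # module A computes while the all-gather is in flight

handle.wait()  # block only at the point B's weight is required
total_b = torch.cat(gathered_b, dim=0)
out_b = out_a @ total_b.t()
print(out_b.shape)
dist.destroy_process_group()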
internlm/model/linear.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Optional, Union, Any
+from typing import Any, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -12,7 +12,12 @@ from torch import nn
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
-from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch, all_gather_raw
+from internlm.model.utils import (
+    Silu,
+    all_gather_raw,
+    fstp_fused_dense_func,
+    fused_dense_func_torch,
+)
 
 
 class ScaleColumnParallelLinear(nn.Linear):
@@ -212,7 +217,9 @@ class FeedForward(nn.Module):
 
 class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler)
+        return fstp_fused_dense_func(
+            x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler
+        )
 
 
 class FSTPFeedForward(nn.Module):
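FSTPLinear defers the actual matmul to fstp_fused_dense_func, backed by the FSTPFusedDenseFunc autograd function further down: forward computes with a gathered copy of the weight but saves only the shard, and backward gathers again (or reuses a prefetched copy). A hedged toy version of that round trip, with world-size-1 stand-ins so the "gather" is the identity:

# Toy stand-in for the FSTPFusedDenseFunc idea: forward computes with the
# gathered ("total") weight but saves only the local shard; backward gathers
# again, or reuses a handler-prefetched copy. World size 1 here.
import torch


class ToyFSTPDense(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight_shard):
        total_weight = weight_shard  # stand-in for all_gather_raw(...) + wait()
        ctx.save_for_backward(x, weight_shard)
        return x @ total_weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight_shard = ctx.saved_tensors
        total_weight = weight_shard  # re-gather, or reuse a prefetched weight
        grad_x = grad_output @ total_weight
        grad_w = grad_output.t() @ x  # reduce-scattered across ranks in the real code
        return grad_x, grad_w


x = torch.randn(2, 8, requires_grad=True)
w = torch.randn(4, 8, requires_grad=True)
ToyFSTPDense.apply(x, w).sum().backward()
print(x.grad.shape, w.grad.shape)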
@@ -280,31 +287,31 @@ class FSTPFeedForward(nn.Module):
         out = self.w3(F.silu(w1_o) * w2_o)
         return out
 
 
 class FSTPAllGatherSyncHandler:
     """
     All-gather handler for overlapping the all-gather in adjacent FSTP linear modules.
     """
 
     def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
         # import pdb; pdb.set_trace()
         self.process_group = process_group
         self.FSTP_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
-        self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward
-        self.module_handler = dict() # key: FSTP module; value: all-gather handler
-        self.module_block = dict() # key: FSTP module; value: transformer block index
-        self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
-        self.module_name_index = dict() # key: FSTP module; value: the index of the name in self.module_name
+        self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
+        self.module_handler = dict()  # key: FSTP module; value: all-gather handler
+        self.module_block = dict()  # key: FSTP module; value: transformer block index
+        self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
+        self.module_name_index = dict()  # key: FSTP module; value: the index of the name in self.module_name
 
         # just want to share same for loop for ModuleList and Module
         if not isinstance(model, nn.ModuleList):
             model = [model]
 
         for _chunk in model:
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model
 
             for _, children in _chunk.named_children():
                 if isinstance(children, nn.ModuleList):
                     for idx, block in enumerate(children):
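The dicts above drive the fine-grained prefetch: within a block the five linears are visited in the fixed module_name order, so the next module to prefetch is simply name_index + 1 in forward and name_index - 1 in backward. A toy illustration of the lookup, with strings standing in for real modules:

# Toy illustration of the fine-grained handler's bookkeeping: block_module
# maps name_index -> module in the fixed order of module_name, and
# module_name_index inverts it, so "prefetch next" is just index +/- 1.
module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
block_module = {0: {i: f"block0.{name}" for i, name in enumerate(module_name)}}
module_name_index = {mod: i for i, mod in block_module[0].items()}

current = block_module[0][1]                           # "block0.out_proj"
nxt = block_module[0][module_name_index[current] + 1]  # forward prefetch target
prev = block_module[0][module_name_index[current] - 1] # backward prefetch target
print(current, "->", nxt, "|", prev)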
@@ -322,13 +329,12 @@ class FSTPAllGatherSyncHandler:
                                     index = index + 1
                 else:
                     continue
 
-
     def _register_sync_parameters_hook(self) -> None:
         """
         register pre_forward_hook and pre_backward_hook for FSTPLinear.
         """
 
         def _pre_forward_hook(module: nn.Module, inputs: Any):
             block_index = self.module_block[module]
             name_index = self.module_name_index[module]
@@ -336,19 +342,23 @@ class FSTPAllGatherSyncHandler:
                 total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
                 weight_handler.wait()
                 self.FSTP_global_weights[module] = total_weight
 
                 # start the all-gather for next module
                 next_module = self.block_module[block_index][name_index + 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
                 self.module_handler[next_module] = weights_handler
             else:
                 handler = self.module_handler[module]
                 handler.wait()
                 if name_index != 4:
                     next_module = self.block_module[block_index][name_index + 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
                     self.module_handler[next_module] = weights_handler
 
         def _post_forward_hook(module: nn.Module, input, output):
             del self.FSTP_global_weights[module]
             del self.module_handler[module]
@@ -360,22 +370,26 @@ class FSTPAllGatherSyncHandler:
                 total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
                 weight_handler.wait()
                 self.FSTP_global_weights[module] = total_weight
 
                 # start the all-gather for next module
                 next_module = self.block_module[block_index][name_index - 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
                 self.module_handler[next_module] = weights_handler
             else:
                 handler = self.module_handler[module]
                 handler.wait()
                 if name_index != 0:
                     next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
                     self.module_handler[next_module] = weights_handler
 
         def _post_backward_hook(module, grad_input, grad_output):
             del self.FSTP_global_weights[module]
 
         for module in self.FSTP_modules:
             # import pdb; pdb.set_trace()
             module.register_forward_pre_hook(_pre_forward_hook)
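Forward pre-hooks fire just before a module's forward runs, which is what makes the chaining above work: module i waits on its own handle and immediately launches the gather for module i + 1. A hedged toy version, with detached weight copies standing in for all_gather_raw:

# Toy version of the chained prefetch: each linear's forward pre-hook
# "prefetches" the next linear's weight, in the handler's fixed visit order.
# A detached copy stands in for the async all-gather of a weight shard.
import torch
from torch import nn

names = ["Wqkv", "out_proj", "w1", "w2", "w3"]
modules = {name: nn.Linear(4, 4) for name in names}
prefetched = {}


def make_hook(idx):
    def hook(module, inputs):
        if idx == 0:  # first module in the block gathers its own weight
            prefetched[names[0]] = modules[names[0]].weight.detach().clone()
        if idx + 1 < len(names):  # launch the "gather" for the next module
            nxt = names[idx + 1]
            prefetched[nxt] = modules[nxt].weight.detach().clone()

    return hook


for i, name in enumerate(names):
    modules[name].register_forward_pre_hook(make_hook(i))

x = torch.randn(2, 4)
for name in names:
    x = modules[name](x)
print(sorted(prefetched))  # all five weights were "prefetched" in order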
@@ -383,4 +397,145 @@ class FSTPAllGatherSyncHandler:
             # module.register_backward_pre_hook(_pre_backward_hook)
             # module.register_backward_hook(_post_backward_hook)
             module.register_module_full_backward_pre_hook(_pre_backward_hook)
+
+
+class CoarseGrainedFSTPAllGatherSyncHandler:
+    """
+    All-gather handler for overlapping the all-gather in adjacent FSTP blocks.
+    """
+
+    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
+        # import pdb; pdb.set_trace()
+        self.process_group = process_group
+        self.FSTP_blocks = []
+        self.FSTP_outs = []
+        self.FSTP_wqkvs = []
+        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
+        self.FSTP_global_handle = dict()  # key: FSTP module; value: module global all-gather op handle
+        self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
+        self.block_handles = dict()  # key: transformer block; value: all-gather handles
+        self.module_to_index = dict()  # key: FSTP module; value: transformer block index
+        self.block_to_index = dict()  # key: transformer block; value: transformer block index
+        self.index_to_block = dict()  # key: transformer block index; value: transformer block
+        self.index_to_fsdp_modules = dict()  # key: transformer block index; value: fsdp modules
+
+        # just want to share same for loop for ModuleList and Module
+        if not isinstance(model, nn.ModuleList):
+            model = [model]
+
+        for _chunk in model:
+            if isinstance(_chunk, NaiveAMPModel):
+                _chunk = _chunk.model
+
+            for _, children in _chunk.named_children():
+                if isinstance(children, nn.ModuleList):
+                    for idx, block in enumerate(children):
+                        self.FSTP_blocks.append(block)
+                        self.block_to_index[block] = idx
+                        self.index_to_block[idx] = block
+                        self.index_to_fsdp_modules[idx] = []
+                        for _, sub in block.named_children():
+                            sub_modules = list(sub.children())
+                            if len(sub_modules) > 0:
+                                for name, child in sub.named_children():
+                                    # print(f"name: {name}", flush=True)
+                                    if name == "out_proj":
+                                        self.FSTP_outs.append(child)
+                                        self.module_to_index[child] = idx
+                                    if name == "Wqkv":
+                                        self.FSTP_wqkvs.append(child)
+                                        self.module_to_index[child] = idx
+                                    if isinstance(child, FSTPLinear):
+                                        self.index_to_fsdp_modules[idx].append(child)
+                else:
+                    continue
+
+    def _all_gather_block_weight(self, block_index: int):
+        block = self.index_to_block[block_index]
+        fsdp_modules = self.index_to_fsdp_modules[block_index]
+        self.block_handles[block] = []
+        for module in fsdp_modules:
+            total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+            self.FSTP_global_weights[module] = total_weight
+            self.block_handles[block].append(weight_handle)
+
+    def _register_sync_parameters_hook(self) -> None:
+        """
+        register pre_forward_hook and pre_backward_hook for FSTP block.
+
+        Notice that next block's all_gather op should be after current block's all_to_all op, so we
+        1. register pre_forward_hook @out_proj module to prefetch for next block
+        2. register pre_forward_hook @block module to wait handles for next block
+        3. register pre_backward_hook @wqkv module to prefetch for next block
+        4. register pre_backward_hook @block module to wait handles for next block
+        """
+
+        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
+            block_index = self.module_to_index[module]
+            # start the all-gather for next block
+            if block_index + 1 < gpc.config.NUM_LAYER:
+                self._all_gather_block_weight(block_index + 1)
+
+        def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
+            block_index = self.block_to_index[block]
+            if block_index == 0:
+                # all gather weight for block 0
+                fsdp_modules = self.index_to_fsdp_modules[block_index]
+                for module in fsdp_modules:
+                    total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+                    weight_handle.wait()
+                    self.FSTP_global_weights[module] = total_weight
+            else:
+                # wait handle for current block
+                handles = self.block_handles[block]
+                for handle in handles:
+                    handle.wait()
+
+        def _post_forward_hook_for_block(block: nn.Module, input, output):
+            block_index = self.block_to_index[block]
+            fsdp_modules = self.index_to_fsdp_modules[block_index]
+            if block in self.block_handles:
+                del self.block_handles[block]
+            for module in fsdp_modules:
+                del self.FSTP_global_weights[module]
+
+        def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
+            block_index = self.module_to_index[module]
+            # start the all-gather for next block
+            if block_index - 1 >= 0:
+                self._all_gather_block_weight(block_index - 1)
+
+        def _pre_backward_hook_for_block(block: nn.Module, grad_output):
+            block_index = self.block_to_index[block]
+            if block_index == gpc.config.NUM_LAYER - 1:
+                # all gather weight for the last block
+                fsdp_modules = self.index_to_fsdp_modules[block_index]
+                for module in fsdp_modules:
+                    total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+                    weight_handle.wait()
+                    self.FSTP_global_weights[module] = total_weight
+            else:
+                # wait handle for current block
+                handles = self.block_handles[block]
+                for handle in handles:
+                    handle.wait()
+
+        def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
+            block_index = self.block_to_index[block]
+            fsdp_modules = self.index_to_fsdp_modules[block_index]
+            if block in self.block_handles:
+                del self.block_handles[block]
+            for module in fsdp_modules:
+                del self.FSTP_global_weights[module]
+
+        for block in self.FSTP_blocks:
+            block.register_forward_pre_hook(_pre_forward_hook_for_block)
+            block.register_forward_hook(_post_forward_hook_for_block)
+            block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
+            block.register_full_backward_hook(_post_backward_hook_for_block)
+
+        for out_proj in self.FSTP_outs:
+            out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
+
+        for wqkv in self.FSTP_wqkvs:
+            wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
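The resulting schedule: block i's out_proj pre-hook prefetches every linear weight of block i + 1 during forward (Wqkv's backward pre-hook mirrors this for block i - 1), while each block's own pre-hook only waits on handles already gathered for it; block 0 and the last block gather synchronously since nothing runs before them. A hedged toy run of the forward half, with dict entries standing in for all-gather handles:

# Toy forward schedule of the coarse-grained handler: the last linear of
# block i (playing "out_proj") prefetches block i + 1; each block's own
# pre-hook waits on the handles gathered for it, and block 0 gathers
# synchronously. Dict entries stand in for all-gather work handles.
import torch
from torch import nn

NUM_LAYER = 3
blocks = nn.ModuleList(nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)) for _ in range(NUM_LAYER))
block_handles = {}


def all_gather_block_weight(idx):
    block_handles[idx] = [f"handle:{idx}.{j}" for j, _ in enumerate(blocks[idx])]


def pre_forward_block(idx):
    def hook(block, inputs):
        if idx == 0:
            all_gather_block_weight(0)  # no earlier block could prefetch block 0
        for _handle in block_handles[idx]:
            pass  # handle.wait() in the real handler
    return hook


def pre_forward_out_proj(idx):
    def hook(module, inputs):
        if idx + 1 < NUM_LAYER:
            all_gather_block_weight(idx + 1)  # prefetch the next block's weights
    return hook


for i, block in enumerate(blocks):
    block.register_forward_pre_hook(pre_forward_block(i))
    block[-1].register_forward_pre_hook(pre_forward_out_proj(i))  # block[-1] plays out_proj

x = torch.randn(2, 4)
for block in blocks:
    x = block(x)
print(sorted(block_handles))  # [0, 1, 2]: every block was gathered before use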

internlm/model/utils.py
@@ -284,7 +284,6 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None):
-
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
@@ -297,16 +296,18 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
-            total_weight = all_gather_handler.FSTP_global_weights[module]
-            total_bias = bias
-            # # do all_gather for weight and bias before actual computation
-            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            # if bias is not None:
-            #     total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
-            #     handle_bias.wait()
-            # else:
-            #     total_bias = bias
-            # handle_weight.wait()
+            # do all_gather for weight and bias before actual computation
+            if module in all_gather_handler.FSTP_global_weights:
+                total_weight = all_gather_handler.FSTP_global_weights[module]
+            else:
+                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+                handle_weight.wait()
+
+            if bias is not None:
+                total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
+                handle_bias.wait()
+            else:
+                total_bias = bias
         else:
             total_weight = weight
             total_bias = bias
 
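Both forward and backward now tolerate a missing prefetch: if a pre-hook already stashed the gathered weight in FSTP_global_weights it is used directly, otherwise the op falls back to a blocking all-gather. A sketch of that lookup with world-size-1 stubs (the helper name get_total_weight is illustrative, not from the diff):

# Sketch of the prefetch-or-gather fallback used in forward and backward.
# The _Handle/all_gather_raw stubs model world size 1, where the "gathered"
# weight is just the local shard.
import torch


class _Handle:
    def wait(self):
        pass  # stand-in for a torch.distributed work handle


def all_gather_raw(weight, process_group=None, async_op=True):
    return weight.clone(), _Handle()


def get_total_weight(module, weight, process_group, handler):
    if handler is not None and module in handler.FSTP_global_weights:
        return handler.FSTP_global_weights[module]  # prefetched by a pre-hook, no wait
    total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
    handle_weight.wait()  # blocking fallback when nothing was prefetched
    return total_weight


class _Handler:
    FSTP_global_weights = {}


print(get_total_weight(object(), torch.randn(4, 4), None, _Handler()).shape)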
@@ -351,12 +352,14 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             # do all-gather for weight before backward
-            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            # handle_weight.wait()
-            total_weight = all_gather_handler.FSTP_global_weights[module]
+            if module in all_gather_handler.FSTP_global_weights:
+                total_weight = all_gather_handler.FSTP_global_weights[module]
+            else:
+                total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+                handle_weight.wait()
         else:
             total_weight = weight
 
         # compute weight grad
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
@@ -380,7 +383,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
         else:
             grad_input = None
 
         if ctx.needs_input_grad[1]:
             if world_size > 1:
                 handle_grad_weight.wait()
@@ -408,7 +411,13 @@ def fused_dense_func_torch(
 
 
 def fstp_fused_dense_func(
-    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None, module=None, handler=None
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    return_residual: bool = False,
+    process_group=None,
+    module=None,
+    handler=None,
 ):
     dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
         x.dtype == torch.float32 and torch.is_autocast_enabled()
@@ -460,5 +469,3 @@ def Silu(w1_o, w2_o):
 
 
 Silu = torch.jit.script(Silu)
-
-

internlm/train/training_internlm.py
@@ -36,10 +36,11 @@ from internlm.data.packed_dataset import (
 from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
+    CoarseGrainedFSTPAllGatherSyncHandler,
     FeedForward,
+    FSTPAllGatherSyncHandler,
     RewardModelLinear,
     ScaleColumnParallelLinear,
-    FSTPAllGatherSyncHandler,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import try_import_RMSNorm
@@ -107,13 +108,14 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
+
     if gpc.config.parallel["tensor"]["mode"] == "fstp":
-        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
         gpc.config.fstp_handler = handler
     return model
 
 
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.zero1.fsdp:
         # set wrap_policy for fsdp wrap
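For reference, the gating above reads two config fields: parallel["tensor"]["mode"] must be "fstp" for the handler to be built, and parallel.zero1.fsdp controls FSDP wrapping. A hypothetical config fragment matching those lookups; key names follow the diff, sizes are illustrative only:

# Hypothetical config fragment matching the lookups in initialize_model and
# wrap_FSDP_model above. "size" values are illustrative, not from this diff.
parallel = dict(
    zero1=dict(size=8, fsdp=False),    # wrap_FSDP_model checks parallel.zero1.fsdp
    tensor=dict(size=4, mode="fstp"),  # "fstp" enables CoarseGrainedFSTPAllGatherSyncHandler
)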