mirror of https://github.com/InternLM/InternLM
support fine-grained overlap
parent 792b066f15
commit 5fd5a8a32b
@@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=1, fsdp=False),
-    tensor=dict(size=2, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    tensor=dict(size=8, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
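
Note: as the inline comment above states, the 'fstp' tensor mode only makes sense together with sequence_parallel=True. A minimal, hypothetical startup check over this dict (not InternLM code) could look like:

parallel = dict(
    zero1=dict(size=1, fsdp=False),
    tensor=dict(size=8, mode='fstp'),
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
)

if parallel["tensor"]["mode"] == "fstp":
    # the config comment above requires sequence parallelism whenever fstp is used
    assert parallel["sequence_parallel"], "tensor mode 'fstp' requires sequence_parallel=True"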
@@ -11,7 +11,8 @@ from torch import nn

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch
+from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch, all_gather_raw


 class ScaleColumnParallelLinear(nn.Linear):
@@ -211,8 +212,7 @@ class FeedForward(nn.Module):

 class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        import pdb; pdb.set_trace()
-        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
+        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler)


 class FSTPFeedForward(nn.Module):
@@ -287,6 +287,7 @@ class FSTPAllGatherSyncHandler:

     def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:

+        # import pdb; pdb.set_trace()
         self.process_group = process_group
         self.FSTP_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
@@ -306,15 +307,17 @@ class FSTPAllGatherSyncHandler:

         for _, children in _chunk.named_children():
             if isinstance(children, nn.ModuleList):
-                for _, block in enumerate(children):
+                for idx, block in enumerate(children):
                     index = 0
-                    sub_modules = list(block.children())
+                    self.block_module[idx] = {}
+                    for _, sub in block.named_children():
+                        sub_modules = list(sub.children())
                         if len(sub_modules) > 0:
-                            for name, child in block.named_children():
+                            for name, child in sub.named_children():
                                 if isinstance(child, FSTPLinear):
                                     self.FSTP_modules.append(child)
-                                    self.module_block[child] = _
-                                    self.block_module[_][index] = child
+                                    self.module_block[child] = idx
+                                    self.block_module[idx][index] = child
                                     self.module_name_index[child] = index
                                     index = index + 1
                         else:
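
For reference, the bookkeeping built above can be reproduced on a toy block. The sketch below is illustrative only: ToyBlock and its plain nn.Linear members stand in for InternLM's transformer block and its five per-block FSTPLinear projections ("Wqkv", "out_proj", "w1", "w2", "w3"), and the three dicts mirror self.module_block, self.block_module and self.module_name_index.

from torch import nn


class ToyBlock(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.Wqkv = nn.Linear(dim, 3 * dim)
        self.out_proj = nn.Linear(dim, dim)
        self.w1 = nn.Linear(dim, 4 * dim)
        self.w2 = nn.Linear(dim, 4 * dim)
        self.w3 = nn.Linear(4 * dim, dim)


blocks = nn.ModuleList(ToyBlock() for _ in range(2))

module_block = {}        # linear -> index of the block that owns it
block_module = {}        # block index -> {position within block -> linear}
module_name_index = {}   # linear -> position within its block (0 runs first)

for idx, block in enumerate(blocks):
    block_module[idx] = {}
    for pos, (name, child) in enumerate(block.named_children()):
        module_block[child] = idx
        block_module[idx][pos] = child
        module_name_index[child] = pos

assert module_name_index[blocks[0].Wqkv] == 0     # first projection of block 0
assert block_module[1][4] is blocks[1].w3         # last projection (index 4) of block 1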
@@ -326,27 +329,58 @@ class FSTPAllGatherSyncHandler:
         register pre_forward_hook and pre_backward_hook for FSTPLinear.
         """

-        def _hook(module: nn.Module):
+        def _pre_forward_hook(module: nn.Module, inputs: Any):
             block_index = self.module_block[module]
             name_index = self.module_name_index[module]
             if name_index == 0:
                 total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
                 weight_handler.wait()
                 self.FSTP_global_weights[module] = total_weight

                 # start the all-gather for next module
                 next_module = self.block_module[block_index][name_index + 1]
-                self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
                 self.module_handler[next_module] = weights_handler
             else:
                 handler = self.module_handler[module]
                 handler.wait()
                 if name_index != 4:
                     next_module = self.block_module[block_index][name_index + 1]
-                    self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
                     self.module_handler[next_module] = weights_handler

-        def _pre_forward_hook(module: nn.Module, inputs: Any):
-            _hook(module)
+        def _post_forward_hook(module: nn.Module, input, output):
+            del self.FSTP_global_weights[module]
+            del self.module_handler[module]

         def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
-            _hook(module)
+            block_index = self.module_block[module]
+            name_index = self.module_name_index[module]
+            if name_index == 4:
+                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                weight_handler.wait()
+                self.FSTP_global_weights[module] = total_weight
+
+                # start the all-gather for next module
+                next_module = self.block_module[block_index][name_index - 1]
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.module_handler[next_module] = weights_handler
+            else:
+                handler = self.module_handler[module]
+                handler.wait()
+                if name_index != 0:
+                    next_module = self.block_module[block_index][name_index - 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.module_handler[next_module] = weights_handler
+
+        def _post_backward_hook(module, grad_input, grad_output):
+            del self.FSTP_global_weights[module]

         for module in self.FSTP_modules:
             # import pdb; pdb.set_trace()
             module.register_forward_pre_hook(_pre_forward_hook)
-            module.register_backward_pre_hook(_pre_backward_hook)
+            module.register_forward_hook(_post_forward_hook)
+            # module.register_backward_pre_hook(_pre_backward_hook)
+            # module.register_backward_hook(_post_backward_hook)
+            module.register_module_full_backward_pre_hook(_pre_backward_hook)
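
The hooks above are the heart of the fine-grained overlap: the first linear of a block gathers its own weight, waits, and immediately launches the all-gather for the next linear, so every later linear only waits on a gather that was started one layer earlier; the backward pre-hook mirrors this in reverse order (index 4 down to 0), and the post-hooks free the gathered copies. Below is a minimal, self-contained sketch of that scheduling; it uses plain torch.distributed.all_gather_into_tensor in place of InternLM's all_gather_raw, assumes an already initialized process group, and only shows the prefetch bookkeeping for the forward pass, not the fused matmul that actually consumes the gathered weight.

import torch
import torch.distributed as dist
from torch import nn


class PrefetchHandler:
    """Keep one all-gather in flight: gather linear i+1's full weight while linear i computes."""

    def __init__(self, linears, group):
        self.linears = list(linears)   # FSTP-style linears in execution order
        self.group = group
        self.full_weights = {}         # module -> gathered (flattened) weight
        self.handles = {}              # module -> pending async all-gather work

    def _launch(self, module):
        world_size = dist.get_world_size(self.group)
        out = torch.empty(world_size * module.weight.numel(),
                          dtype=module.weight.dtype, device=module.weight.device)
        self.handles[module] = dist.all_gather_into_tensor(
            out, module.weight.detach().contiguous(), group=self.group, async_op=True
        )
        self.full_weights[module] = out

    def _pre_forward(self, module, inputs):
        idx = self.linears.index(module)
        if idx == 0:
            self._launch(module)                  # nothing earlier to overlap with
        self.handles[module].wait()               # full weight must be ready before compute
        if idx + 1 < len(self.linears):
            self._launch(self.linears[idx + 1])   # overlap the next gather with this forward

    def _post_forward(self, module, inputs, output):
        self.full_weights.pop(module, None)       # free the gathered copy right after use
        self.handles.pop(module, None)

    def register(self):
        for m in self.linears:
            m.register_forward_pre_hook(self._pre_forward)
            m.register_forward_hook(self._post_forward)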
@@ -210,7 +210,7 @@ class MHA(nn.Module):
             embed_dim,
             3 * embed_dim,
             process_group,
-            bias=True,
+            bias=False,
             sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )  # according to https://spaces.ac.cn/archives/9577
@@ -231,6 +231,7 @@ class MHA(nn.Module):
             embed_dim,
             embed_dim,
             process_group,
+            bias=False,
             sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )
@@ -283,11 +283,13 @@ class FSTPFusedDenseFunc(torch.autograd.Function):

     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None):
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None):

         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.all_gather_handler = all_gather_handler
+        ctx.module = module

         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
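
The widened forward signature relies on the standard torch.autograd.Function contract: extra non-tensor arguments (here module and all_gather_handler) can be passed to apply and stashed on ctx, but backward must then return one value per forward input, using None for the non-differentiable ones. That is why the return statement further down in this diff grows from five to seven values. A toy Function illustrating the contract:

import torch


class ScaleWithState(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, handler=None, module=None):
        # non-tensor state is simply kept on ctx, like the FSTP handler and module
        ctx.handler = handler
        ctx.module = module
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # one return per forward input: x, scale, handler, module
        return grad_out * ctx.scale, None, None, None


x = torch.randn(4, requires_grad=True)
y = ScaleWithState.apply(x, 2.0, object(), None)
y.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))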
@@ -295,14 +297,16 @@ class FSTPFusedDenseFunc(torch.autograd.Function):

         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
-            # do all_gather for weight and bias before actual computation
-            total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            if bias is not None:
-                total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
-                handle_bias.wait()
-            else:
-                total_bias = bias
-            handle_weight.wait()
+            total_weight = all_gather_handler.FSTP_global_weights[module]
+            total_bias = bias
+            # # do all_gather for weight and bias before actual computation
+            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+            # if bias is not None:
+            #     total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
+            #     handle_bias.wait()
+            # else:
+            #     total_bias = bias
+            # handle_weight.wait()
         else:
             total_weight = weight
             total_bias = bias
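
With this change the forward pass reads the pre-gathered weight from the handler instead of issuing its own blocking all-gather (the old pattern is kept as the commented-out block). For context, the generic async pattern being replaced looks roughly like the sketch below, with plain torch.distributed standing in for all_gather_raw and an already initialized process group assumed; the point is that wait() is deferred until the gathered tensor is actually needed, so independent work can overlap with communication.

import torch
import torch.distributed as dist


def gather_weight_async(weight, group):
    # returns (gathered_tensor, work_handle), similar in spirit to all_gather_raw(..., async_op=True)
    world_size = dist.get_world_size(group)
    full = torch.empty(world_size * weight.numel(), dtype=weight.dtype, device=weight.device)
    handle = dist.all_gather_into_tensor(full, weight.contiguous(), group=group, async_op=True)
    return full, handle


def overlapped_linear(x, weight_shard, group):
    full_weight, handle = gather_weight_async(weight_shard, group)
    # ... any computation that does not need full_weight can run here and overlap ...
    handle.wait()
    full_weight = full_weight.view(-1, x.shape[-1])   # assumes the shard is split along the output dim
    return x @ full_weight.t()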
@@ -332,6 +336,8 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         (grad_input,) = args
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
+        all_gather_handler = ctx.all_gather_handler
+        module = ctx.module
         if ctx.compute_weight_gradient:
             x, weight = ctx.saved_tensors
             total_x = x
@@ -345,8 +351,9 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             # do all-gather for weight before backward
-            total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            handle_weight.wait()
+            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+            # handle_weight.wait()
+            total_weight = all_gather_handler.FSTP_global_weights[module]
         else:
             total_weight = weight

@@ -379,7 +386,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         handle_grad_weight.wait()
         if grad_bias is not None:
             handle_grad_bias.wait()
-        return grad_input, grad_weight, grad_bias, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None


 def fused_dense_func_torch(
@@ -401,13 +408,13 @@ def fused_dense_func_torch(


 def fstp_fused_dense_func(
-    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None
+    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None, module=None, handler=None
 ):
     dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
         x.dtype == torch.float32 and torch.is_autocast_enabled()
     )
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
+        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
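
The else branch is the CPU / ineligible-dtype fallback, which skips the custom Function entirely. As a quick reminder of what that branch computes, F.linear(x, w, b) is just x @ w.T + b:

import torch
import torch.nn.functional as F

x, w, b = torch.randn(2, 3), torch.randn(5, 3), torch.randn(5)
assert torch.allclose(F.linear(x, w, b), x @ w.t() + b, atol=1e-6)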
@@ -39,6 +39,7 @@ from internlm.model.linear import (
     FeedForward,
     RewardModelLinear,
     ScaleColumnParallelLinear,
+    FSTPAllGatherSyncHandler,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import try_import_RMSNorm
@@ -107,9 +108,12 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)

+    if gpc.config.parallel["tensor"]["mode"] == "fstp":
+        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        handler._register_sync_parameters_hook()
+        gpc.config.fstp_handler = handler
     return model


 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.zero1.fsdp:
         # set wrap_policy for fsdp wrap
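
Taken together, the wiring added to initialize_model is: when the tensor mode is 'fstp', build the all-gather handler over the tensor-parallel group, register its hooks, and stash it on the global config so FSTPLinear can hand it to fstp_fused_dense_func. A generic sketch of the same wiring, reusing the toy PrefetchHandler from the note after the hook-registration hunk above (config and the group lookup are stand-ins for gpc.config and gpc.get_group(ParallelMode.TENSOR)):

import torch.distributed as dist
from torch import nn


def setup_fstp_overlap(model, config):
    if config["parallel"]["tensor"]["mode"] == "fstp":
        tp_group = dist.group.WORLD   # stand-in for the real tensor-parallel group
        # in InternLM these would be the FSTPLinear modules; plain nn.Linear keeps the sketch generic
        linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
        handler = PrefetchHandler(linears, tp_group)   # toy class from the earlier sketch
        handler.register()
        config["fstp_handler"] = handler               # mirrors gpc.config.fstp_handler = handler
    return model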