support fine-grained overlap

pull/407/head
yingtongxiong 2023-10-11 17:36:41 +08:00
parent 792b066f15
commit 5fd5a8a32b
5 changed files with 86 additions and 40 deletions

View File

@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
""" """
parallel = dict( parallel = dict(
zero1=dict(size=1, fsdp=False), zero1=dict(size=1, fsdp=False),
tensor=dict(size=2, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True tensor=dict(size=8, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
pipeline=dict(size=1, interleaved_overlap=True), pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True, sequence_parallel=True,
) )

View File

@ -11,7 +11,8 @@ from torch import nn
from internlm.core.context import ParallelMode from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc from internlm.core.context import global_context as gpc
from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch, all_gather_raw
class ScaleColumnParallelLinear(nn.Linear): class ScaleColumnParallelLinear(nn.Linear):
@ -211,8 +212,7 @@ class FeedForward(nn.Module):
class FSTPLinear(ColumnParallelLinear): class FSTPLinear(ColumnParallelLinear):
def forward(self, x): def forward(self, x):
import pdb; pdb.set_trace() return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler)
return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
class FSTPFeedForward(nn.Module): class FSTPFeedForward(nn.Module):
@ -287,6 +287,7 @@ class FSTPAllGatherSyncHandler:
def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None: def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
# import pdb; pdb.set_trace()
self.process_group = process_group self.process_group = process_group
self.FSTP_modules = [] self.FSTP_modules = []
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
@ -306,19 +307,21 @@ class FSTPAllGatherSyncHandler:
for _, children in _chunk.named_children(): for _, children in _chunk.named_children():
if isinstance(children, nn.ModuleList): if isinstance(children, nn.ModuleList):
for _, block in enumerate(children): for idx, block in enumerate(children):
index = 0 index = 0
sub_modules = list(block.children()) self.block_module[idx] = {}
if len(sub_modules) > 0: for _, sub in block.named_children():
for name, child in block.named_children(): sub_modules = list(sub.children())
if isinstance(child, FSTPLinear): if len(sub_modules) > 0:
self.FSTP_modules.append(child) for name, child in sub.named_children():
self.module_block[child] = _ if isinstance(child, FSTPLinear):
self.block_module[_][index] = child self.FSTP_modules.append(child)
self.module_name_index[child] = index self.module_block[child] = idx
index = index + 1 self.block_module[idx][index] = child
else: self.module_name_index[child] = index
continue index = index + 1
else:
continue
def _register_sync_parameters_hook(self) -> None: def _register_sync_parameters_hook(self) -> None:
@ -326,27 +329,58 @@ class FSTPAllGatherSyncHandler:
register pre_forward_hook and pre_backward_hook for FSTPLinear. register pre_forward_hook and pre_backward_hook for FSTPLinear.
""" """
def _hook(module: nn.Module): def _pre_forward_hook(module: nn.Module, inputs: Any):
block_index = self.module_block[module] block_index = self.module_block[module]
name_index = self.module_name_index[module] name_index = self.module_name_index[module]
if name_index == 0: if name_index == 0:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight
# start the all-gather for next module
next_module = self.block_module[block_index][name_index + 1] next_module = self.block_module[block_index][name_index + 1]
self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True) self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
self.module_handler[next_module] = weights_handler self.module_handler[next_module] = weights_handler
else: else:
handler = self.module_handler[module] handler = self.module_handler[module]
handler.wait() handler.wait()
if name_index != 4: if name_index != 4:
next_module = self.block_module[block_index][name_index + 1] next_module = self.block_module[block_index][name_index + 1]
self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True) self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
self.module_handler[next_module] = weights_handler self.module_handler[next_module] = weights_handler
def _pre_forward_hook(module: nn.Module, inputs: Any): def _post_forward_hook(module: nn.Module, input, output):
_hook(module) del self.FSTP_global_weights[module]
del self.module_handler[module]
def _pre_backward_hook(module: nn.Module, grad_input, grad_output): def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
_hook(module) block_index = self.module_block[module]
name_index = self.module_name_index[module]
if name_index == 4:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight
# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
self.module_handler[next_module] = weights_handler
else:
handler = self.module_handler[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
self.module_handler[next_module] = weights_handler
def _post_backward_hook(module, grad_input, grad_output):
del self.FSTP_global_weights[module]
for module in self.FSTP_modules: for module in self.FSTP_modules:
# import pdb; pdb.set_trace()
module.register_forward_pre_hook(_pre_forward_hook) module.register_forward_pre_hook(_pre_forward_hook)
module.register_backward_pre_hook(_pre_backward_hook) module.register_forward_hook(_post_forward_hook)
# module.register_backward_pre_hook(_pre_backward_hook)
# module.register_backward_hook(_post_backward_hook)
module.register_module_full_backward_pre_hook(_pre_backward_hook)

View File

@ -210,7 +210,7 @@ class MHA(nn.Module):
embed_dim, embed_dim,
3 * embed_dim, 3 * embed_dim,
process_group, process_group,
bias=True, bias=False,
sequence_parallel=gpc.config.parallel.sequence_parallel, sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs, **factory_kwargs,
) # according to https://spaces.ac.cn/archives/9577 ) # according to https://spaces.ac.cn/archives/9577
@ -231,6 +231,7 @@ class MHA(nn.Module):
embed_dim, embed_dim,
embed_dim, embed_dim,
process_group, process_group,
bias=False,
sequence_parallel=gpc.config.parallel.sequence_parallel, sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs, **factory_kwargs,
) )

View File

@ -283,11 +283,13 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd @custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None): def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None):
ctx.compute_weight_gradient = weight.requires_grad ctx.compute_weight_gradient = weight.requires_grad
ctx.return_residual = return_residual ctx.return_residual = return_residual
ctx.process_group = process_group ctx.process_group = process_group
ctx.all_gather_handler = all_gather_handler
ctx.module = module
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype()) x = x.to(dtype=torch.get_autocast_gpu_dtype())
@ -295,14 +297,16 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
world_size = gpc.get_world_size(ParallelMode.TENSOR) world_size = gpc.get_world_size(ParallelMode.TENSOR)
if world_size > 1: if world_size > 1:
# do all_gather for weight and bias before actual computation total_weight = all_gather_handler.FSTP_global_weights[module]
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) total_bias = bias
if bias is not None: # # do all_gather for weight and bias before actual computation
total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True) # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_bias.wait() # if bias is not None:
else: # total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
total_bias = bias # handle_bias.wait()
handle_weight.wait() # else:
# total_bias = bias
# handle_weight.wait()
else: else:
total_weight = weight total_weight = weight
total_bias = bias total_bias = bias
@ -332,6 +336,8 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
(grad_input,) = args (grad_input,) = args
grad_input = grad_input.contiguous() grad_input = grad_input.contiguous()
process_group = ctx.process_group process_group = ctx.process_group
all_gather_handler = ctx.all_gather_handler
module = ctx.module
if ctx.compute_weight_gradient: if ctx.compute_weight_gradient:
x, weight = ctx.saved_tensors x, weight = ctx.saved_tensors
total_x = x total_x = x
@ -345,8 +351,9 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
world_size = gpc.get_world_size(ParallelMode.TENSOR) world_size = gpc.get_world_size(ParallelMode.TENSOR)
if world_size > 1: if world_size > 1:
# do all-gather for weight before backward # do all-gather for weight before backward
total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
handle_weight.wait() # handle_weight.wait()
total_weight = all_gather_handler.FSTP_global_weights[module]
else: else:
total_weight = weight total_weight = weight
@ -379,7 +386,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
handle_grad_weight.wait() handle_grad_weight.wait()
if grad_bias is not None: if grad_bias is not None:
handle_grad_bias.wait() handle_grad_bias.wait()
return grad_input, grad_weight, grad_bias, None, None return grad_input, grad_weight, grad_bias, None, None, None, None
def fused_dense_func_torch( def fused_dense_func_torch(
@ -401,13 +408,13 @@ def fused_dense_func_torch(
def fstp_fused_dense_func( def fstp_fused_dense_func(
x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None, module=None, handler=None
): ):
dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
x.dtype == torch.float32 and torch.is_autocast_enabled() x.dtype == torch.float32 and torch.is_autocast_enabled()
) )
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group) return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
else: else:
assert process_group is None assert process_group is None
out = F.linear(x, weight, bias) out = F.linear(x, weight, bias)

View File

@ -39,6 +39,7 @@ from internlm.model.linear import (
FeedForward, FeedForward,
RewardModelLinear, RewardModelLinear,
ScaleColumnParallelLinear, ScaleColumnParallelLinear,
FSTPAllGatherSyncHandler,
) )
from internlm.model.multi_head_attention import MHA from internlm.model.multi_head_attention import MHA
from internlm.model.utils import try_import_RMSNorm from internlm.model.utils import try_import_RMSNorm
@ -106,10 +107,13 @@ def initialize_model():
# if fsdp enabled, wrap the model # if fsdp enabled, wrap the model
model = wrap_FSDP_model(model) model = wrap_FSDP_model(model)
if gpc.config.parallel["tensor"]["mode"] == "fstp":
handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler._register_sync_parameters_hook()
gpc.config.fstp_handler = handler
return model return model
def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
if gpc.config.parallel.zero1.fsdp: if gpc.config.parallel.zero1.fsdp:
# set wrap_policy for fsdp wrap # set wrap_policy for fsdp wrap