mirror of https://github.com/InternLM/InternLM

commit 5fd5a8a32b: support fine-grained overlap
parent 792b066f15
@@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=1, fsdp=False),
-    tensor=dict(size=2, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    tensor=dict(size=8, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
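The comment in this hunk states the constraint that 'fstp' tensor mode requires sequence parallelism. A minimal sketch of a config sanity check for that constraint (the helper name is mine, not part of the commit):

```python
# Hypothetical helper: enforce the rule from the comment above,
# i.e. tensor mode 'fstp' is assumed to require sequence_parallel=True.
def check_parallel_config(parallel: dict) -> None:
    if parallel["tensor"].get("mode") == "fstp":
        assert parallel.get("sequence_parallel", False), (
            "tensor mode 'fstp' requires sequence_parallel=True"
        )

check_parallel_config(
    dict(
        zero1=dict(size=1, fsdp=False),
        tensor=dict(size=8, mode="fstp"),
        pipeline=dict(size=1, interleaved_overlap=True),
        sequence_parallel=True,
    )
)
```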
@@ -11,7 +11,8 @@ from torch import nn

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch
+from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch, all_gather_raw


 class ScaleColumnParallelLinear(nn.Linear):
@@ -211,8 +212,7 @@ class FeedForward(nn.Module):

 class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        import pdb; pdb.set_trace()
-        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
+        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler)


 class FSTPFeedForward(nn.Module):
@@ -287,6 +287,7 @@ class FSTPAllGatherSyncHandler:

     def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:

+        # import pdb; pdb.set_trace()
         self.process_group = process_group
         self.FSTP_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
@@ -306,19 +307,21 @@ class FSTPAllGatherSyncHandler:

             for _, children in _chunk.named_children():
                 if isinstance(children, nn.ModuleList):
-                    for _, block in enumerate(children):
+                    for idx, block in enumerate(children):
                         index = 0
-                        sub_modules = list(block.children())
-                        if len(sub_modules) > 0:
-                            for name, child in block.named_children():
-                                if isinstance(child, FSTPLinear):
-                                    self.FSTP_modules.append(child)
-                                    self.module_block[child] = _
-                                    self.block_module[_][index] = child
-                                    self.module_name_index[child] = index
-                                    index = index + 1
-                        else:
-                            continue
+                        self.block_module[idx] = {}
+                        for _, sub in block.named_children():
+                            sub_modules = list(sub.children())
+                            if len(sub_modules) > 0:
+                                for name, child in sub.named_children():
+                                    if isinstance(child, FSTPLinear):
+                                        self.FSTP_modules.append(child)
+                                        self.module_block[child] = idx
+                                        self.block_module[idx][index] = child
+                                        self.module_name_index[child] = index
+                                        index = index + 1
+                            else:
+                                continue


     def _register_sync_parameters_hook(self) -> None:
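The rewritten indexing loop now initializes a per-block table, walks one level deeper (block, then sub-module, then linear), and keys the lookup tables by the named index idx rather than the underscore placeholder. A standalone sketch of the same bookkeeping on a toy model (module names and sizes are made up for illustration):

```python
import torch.nn as nn

# Toy stand-ins: a "block" holds an attention-like and an MLP-like sub-module,
# each of which owns plain nn.Linear layers (FSTPLinear in the real code).
class ToyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.Wqkv = nn.Linear(8, 24)
        self.out_proj = nn.Linear(8, 8)

class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(8, 16)
        self.w2 = nn.Linear(8, 16)
        self.w3 = nn.Linear(16, 8)

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.mixer = ToyAttn()
        self.mlp = ToyMLP()

blocks = nn.ModuleList([ToyBlock() for _ in range(2)])

module_block = {}       # linear module -> block index
block_module = {}       # block index -> {position in block -> linear module}
module_name_index = {}  # linear module -> position in block

for idx, block in enumerate(blocks):
    index = 0
    block_module[idx] = {}
    for _, sub in block.named_children():
        for name, child in sub.named_children():
            if isinstance(child, nn.Linear):
                module_block[child] = idx
                block_module[idx][index] = child
                module_name_index[child] = index
                index += 1

assert len(block_module[0]) == 5  # Wqkv, out_proj, w1, w2, w3 per block
```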
@@ -326,27 +329,58 @@ class FSTPAllGatherSyncHandler:
         register pre_forward_hook and pre_backward_hook for FSTPLinear.
         """

-        def _hook(module: nn.Module):
+        def _pre_forward_hook(module: nn.Module, inputs: Any):
             block_index = self.module_block[module]
             name_index = self.module_name_index[module]
             if name_index == 0:
+                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                weight_handler.wait()
+                self.FSTP_global_weights[module] = total_weight
+
+                # start the all-gather for next module
                 next_module = self.block_module[block_index][name_index + 1]
-                self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
                 self.module_handler[next_module] = weights_handler
             else:
                 handler = self.module_handler[module]
                 handler.wait()
                 if name_index != 4:
                     next_module = self.block_module[block_index][name_index + 1]
-                    self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
                     self.module_handler[next_module] = weights_handler

-        def _pre_forward_hook(module: nn.Module, inputs: Any):
-            _hook(module)
+        def _post_forward_hook(module: nn.Module, input, output):
+            del self.FSTP_global_weights[module]
+            del self.module_handler[module]

         def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
-            _hook(module)
+            block_index = self.module_block[module]
+            name_index = self.module_name_index[module]
+            if name_index == 4:
+                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                weight_handler.wait()
+                self.FSTP_global_weights[module] = total_weight
+
+                # start the all-gather for next module
+                next_module = self.block_module[block_index][name_index - 1]
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.module_handler[next_module] = weights_handler
+            else:
+                handler = self.module_handler[module]
+                handler.wait()
+                if name_index != 0:
+                    next_module = self.block_module[block_index][name_index - 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.module_handler[next_module] = weights_handler
+
+        def _post_backward_hook(module, grad_input, grad_output):
+            del self.FSTP_global_weights[module]

         for module in self.FSTP_modules:
+            # import pdb; pdb.set_trace()
             module.register_forward_pre_hook(_pre_forward_hook)
-            module.register_backward_pre_hook(_pre_backward_hook)
+            module.register_forward_hook(_post_forward_hook)
+            # module.register_backward_pre_hook(_pre_backward_hook)
+            # module.register_backward_hook(_post_backward_hook)
+            module.register_module_full_backward_pre_hook(_pre_backward_hook)
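Taken together, the hooks implement a one-module look-ahead: the pre-forward hook of module i waits on (or performs) the all-gather of its own weight and immediately launches the asynchronous all-gather for module i+1, while the post-forward hook frees the gathered copy; the full-backward pre-hook does the same walk in reverse. A simplified, self-contained sketch of that pattern with plain torch.distributed calls (class and method names, and the prefetch-of-one policy, are assumptions; the real code uses internlm's all_gather_raw and feeds the gathered weight into the fused matmul):

```python
import torch
import torch.distributed as dist
import torch.nn as nn

class PrefetchOneHandler:
    """Overlap the all-gather of layer i+1's weight with layer i's compute."""

    def __init__(self, layers, process_group=None):
        self.layers = list(layers)   # ordered shard-parallel linear layers
        self.group = process_group
        self.gathered = {}           # module -> gathered (flattened) weight
        self.handles = {}            # module -> async work handle

    def _launch(self, module):
        world = dist.get_world_size(self.group)
        out = torch.empty(world * module.weight.numel(),
                          dtype=module.weight.dtype, device=module.weight.device)
        self.handles[module] = dist.all_gather_into_tensor(
            out, module.weight.contiguous().view(-1), group=self.group, async_op=True
        )
        self.gathered[module] = out

    def register(self):
        def pre_forward(module, inputs):
            i = self.layers.index(module)
            if module not in self.handles:   # first layer: gather right now
                self._launch(module)
            self.handles.pop(module).wait()  # this layer's weight is ready
            if i + 1 < len(self.layers):     # prefetch the next layer's weight
                self._launch(self.layers[i + 1])
            # self.gathered[module] would be consumed by the layer's matmul
            # in the real implementation.

        def post_forward(module, inputs, output):
            self.gathered.pop(module, None)  # free the full weight immediately

        for layer in self.layers:
            layer.register_forward_pre_hook(pre_forward)
            layer.register_forward_hook(post_forward)
```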
@@ -210,7 +210,7 @@ class MHA(nn.Module):
             embed_dim,
             3 * embed_dim,
             process_group,
-            bias=True,
+            bias=False,
             sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         ) # according to https://spaces.ac.cn/archives/9577
@@ -231,6 +231,7 @@ class MHA(nn.Module):
             embed_dim,
             embed_dim,
             process_group,
+            bias=False,
             sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )
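Both MHA projections now drop their bias terms (the spaces.ac.cn post cited in the comment above is the stated rationale). For reference, a bias-free linear simply carries no bias parameter, so nothing else needs to account for it (the 1024 width below is only an example):

```python
import torch.nn as nn

proj = nn.Linear(1024, 3 * 1024, bias=False)
assert proj.bias is None                                   # no bias parameter at all
assert sum(p.numel() for p in proj.parameters()) == 1024 * 3 * 1024
```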
@@ -283,11 +283,13 @@ class FSTPFusedDenseFunc(torch.autograd.Function):

     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None):
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None):

         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.all_gather_handler = all_gather_handler
+        ctx.module = module

         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
@@ -295,14 +297,16 @@ class FSTPFusedDenseFunc(torch.autograd.Function):

         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
-            # do all_gather for weight and bias before actual computation
-            total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            if bias is not None:
-                total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
-                handle_bias.wait()
-            else:
-                total_bias = bias
-            handle_weight.wait()
+            total_weight = all_gather_handler.FSTP_global_weights[module]
+            total_bias = bias
+            # # do all_gather for weight and bias before actual computation
+            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+            # if bias is not None:
+            #     total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
+            #     handle_bias.wait()
+            # else:
+            #     total_bias = bias
+            # handle_weight.wait()
         else:
             total_weight = weight
             total_bias = bias
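The blocking gather (now commented out) is replaced by a lookup into the weights that the handler's pre-forward hook gathered ahead of time. A small sketch of a defensive variant of that lookup, falling back to a blocking gather when nothing was prefetched (the guard is hypothetical and not in the commit; all_gather_raw is the helper imported from internlm.model.utils in this diff):

```python
from internlm.model.utils import all_gather_raw  # helper used throughout this diff

def resolve_total_weight(module, weight, all_gather_handler, process_group):
    # Prefer the copy the hook handler prefetched for this module.
    total_weight = None
    if all_gather_handler is not None:
        total_weight = all_gather_handler.FSTP_global_weights.get(module)
    if total_weight is None:
        # Hypothetical fallback: gather synchronously, as the old code did.
        total_weight, handle = all_gather_raw(weight, process_group, async_op=True)
        handle.wait()
    return total_weight
```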
@@ -332,6 +336,8 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         (grad_input,) = args
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
+        all_gather_handler = ctx.all_gather_handler
+        module = ctx.module
         if ctx.compute_weight_gradient:
             x, weight = ctx.saved_tensors
             total_x = x
@@ -345,8 +351,9 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             # do all-gather for weight before backward
-            total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-            handle_weight.wait()
+            # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+            # handle_weight.wait()
+            total_weight = all_gather_handler.FSTP_global_weights[module]
         else:
             total_weight = weight

@@ -379,7 +386,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             handle_grad_weight.wait()
             if grad_bias is not None:
                 handle_grad_bias.wait()
-        return grad_input, grad_weight, grad_bias, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None


 def fused_dense_func_torch(
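The two extra trailing Nones follow from the torch.autograd.Function contract: backward() must return one gradient (or None) per argument of forward(), and forward() gained module and all_gather_handler. A minimal, unrelated example of the same rule:

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, tag=None):   # three forward arguments
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # one slot per forward argument: grad for x, None for factor and tag
        return grad_out * ctx.factor, None, None

x = torch.randn(4, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
assert torch.allclose(x.grad, torch.full((4,), 2.0))
```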
@@ -401,13 +408,13 @@ def fused_dense_func_torch(


 def fstp_fused_dense_func(
-    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None
+    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None, module=None, handler=None
 ):
     dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
         x.dtype == torch.float32 and torch.is_autocast_enabled()
     )
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
+        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
@@ -39,6 +39,7 @@ from internlm.model.linear import (
     FeedForward,
     RewardModelLinear,
     ScaleColumnParallelLinear,
+    FSTPAllGatherSyncHandler,
 )
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import try_import_RMSNorm
@@ -106,10 +107,13 @@ def initialize_model():

     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)

+    if gpc.config.parallel["tensor"]["mode"] == "fstp":
+        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        handler._register_sync_parameters_hook()
+        gpc.config.fstp_handler = handler
     return model


 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     if gpc.config.parallel.zero1.fsdp:
         # set wrap_policy for fsdp wrap