mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] Add overlap support for gpt2 (#4535)
* add overlap support for gpt2
* remove unused code
* remove unused code
branch  pull/4544/head
parent  0387a47e63
commit  e241b74f24
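Summary of the change, as reflected in the hunks below: a new `overlap` flag is threaded from the shard config (`enable_sequence_overlap`) through the GPT-2 policy into `GPT2FusedLinearConv1D_Col` and down to `_MatmulWithGatherForwardReduceScatterBackward`, whose backward pass can now overlap the sequence-parallel all-gather of the input with gradient computation instead of gathering synchronously up front.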
@@ -291,12 +291,13 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
         ctx.save_for_backward(input_, weight)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
+        ctx.overlap = overlap

         input_parallel = _gather(input_, dim, process_group)

@@ -312,37 +313,70 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         use_bias = ctx.use_bias
         dim = ctx.dim
         process_group = ctx.process_group
+        overlap = ctx.overlap

-        # TODO: overlap SP input with gradient computation
-        input_parallel = _gather(input_, dim, process_group)
-
-        total_input = input_parallel
-        grad_input = grad_output.matmul(weight.T)
-        grad_output = grad_output.contiguous()
-        # Convert the tensor shapes to 2D for execution compatibility
-        if len(grad_output.shape) > 2:
-            grad_output = grad_output.view(-1, grad_output.shape[-1])
-            total_input = total_input.view(-1, total_input.shape[-1])
-
-        # TODO: overlap SP input with gradient computation
-        if ctx.async_grad_reduce_scatter:
-            # Asynchronous reduce-scatter
-            input_list = [
-                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
-            ]
-            output = torch.empty(input_.shape, dtype=input_parallel.dtype,
-                                 device=input_parallel.device).contiguous()
-            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            # Delay the start of weight gradient computation shortly (3us) to have
-            # reduce-scatter scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
-
-        grad_weight = total_input.t().matmul(grad_output)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None
-
-        if ctx.async_grad_reduce_scatter:
-            handle.wait()
-
-        return output, grad_weight, grad_bias, None, None, None
+        if not overlap:
+            input_parallel = _gather(input_, dim, process_group)
+
+            total_input = input_parallel
+            grad_input = grad_output.matmul(weight.T)
+            grad_output = grad_output.contiguous()
+            # Convert the tensor shapes to 2D for execution compatibility
+            if len(grad_output.shape) > 2:
+                grad_output = grad_output.view(-1, grad_output.shape[-1])
+                total_input = total_input.view(-1, total_input.shape[-1])
+
+            if ctx.async_grad_reduce_scatter:
+                # Asynchronous reduce-scatter
+                input_list = [
+                    item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+                ]
+                output = torch.empty(input_.shape, dtype=input_parallel.dtype,
+                                     device=input_parallel.device).contiguous()
+                handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+                # Delay the start of weight gradient computation shortly (3us) to have
+                # reduce-scatter scheduled first and have GPU resources allocated
+                _ = torch.empty(1, device=grad_output.device) + 1
+
+            grad_weight = total_input.t().matmul(grad_output)
+            grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+            if ctx.async_grad_reduce_scatter:
+                handle.wait()
+
+        else:
+            world_size = dist.get_world_size(process_group)
+            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+
+            # do the all-gather in an async way
+            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
+            # calculate gradient and prepare data asynchronously with all-gather
+            # calculate
+            grad_input = grad_output.matmul(weight.T)
+            grad_output = grad_output.contiguous()
+            # Convert the tensor shapes to 2D for execution compatibility
+            if len(grad_output.shape) > 2:
+                grad_output = grad_output.view(-1, grad_output.shape[-1])
+            grad_bias = grad_output.sum(dim=0) if use_bias else None
+            # prepare data
+            input_list = [
+                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+            # wait until all-gather finished
+            gather_handle.wait()
+
+            # do the reduce-scatter in an async way
+            reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
+            # calculate gradient
+            if len(input_parallel.shape) > 2:
+                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+            grad_weight = input_parallel.t().matmul(grad_output)
+            # wait until reduce-scatter finished
+            reducescatter_handle.wait()
+
+        return output, grad_weight, grad_bias, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
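The new `overlap` branch above follows a common communication/computation overlap pattern: launch the all-gather of the saved input asynchronously, compute `grad_input` and `grad_bias` (which need only `grad_output` and `weight`) while it runs, then wait, launch the reduce-scatter of `grad_input` asynchronously, and compute `grad_weight` from the gathered input while that communication is in flight. A minimal standalone sketch of the same pattern (not part of this diff; the function name and shapes are illustrative, and an initialized `torch.distributed` process group, e.g. under `torchrun`, is assumed):

import torch
import torch.distributed as dist


def overlapped_backward_sketch(input_, weight, grad_output, process_group=None, dim=1):
    """Illustrative only: overlap collectives with gradient math, as the diff above does."""
    world_size = dist.get_world_size(process_group)

    # 1) Launch the all-gather of the sequence-sharded input asynchronously.
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)

    # 2) While it runs, compute the gradients that do not need the gathered input.
    grad_input = grad_output.matmul(weight.T)
    grad_output_2d = grad_output.contiguous().view(-1, grad_output.shape[-1])
    grad_bias = grad_output_2d.sum(dim=0)

    # 3) Prepare the reduce-scatter buffers, then wait for the all-gather to finish.
    input_list = [t.contiguous() for t in torch.chunk(grad_input, world_size, dim=dim)]
    output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device)
    gather_handle.wait()

    # 4) Launch the reduce-scatter of grad_input asynchronously and overlap it with
    #    the weight-gradient matmul, which is the part that needs the gathered input.
    rs_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
    input_parallel = torch.cat(tensor_list, dim=dim).reshape(-1, input_.shape[-1])
    grad_weight = input_parallel.t().matmul(grad_output_2d)
    rs_handle.wait()

    return output, grad_weight, grad_bias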
@@ -510,9 +544,10 @@ def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
     return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)


-def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
+def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
+                                                 overlap):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
-                                                               async_grad_reduce_scatter, dim)
+                                                               async_grad_reduce_scatter, dim, overlap)


 def gather_forward_split_backward(input_, dim, process_group):
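For reference, the new seventh argument is exercised the same way the GPT-2 column layer calls this wrapper in a later hunk (`self.process_group, True, 1, self.overlap`). A hedged usage sketch, assuming the wrapper lives in `colossalai.shardformer.layer._operation` as in the upstream repository and that the script runs under `torchrun` with NCCL available:

import torch
import torch.distributed as dist

# assumed import path; adjust to wherever the wrapper is defined in your checkout
from colossalai.shardformer.layer._operation import matmul_gather_forward_reducescatter_backward

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

batch, sub_seq, hidden, out = 2, 8, 48, 192    # illustrative shapes; sub_seq is the per-rank sequence shard
x = torch.randn(batch, sub_seq, hidden, device="cuda", requires_grad=True)
w = torch.randn(hidden, out, device="cuda", requires_grad=True)
b = torch.zeros(out, device="cuda", requires_grad=True)

# gather along dim=1 (sequence) in forward, reduce-scatter in backward;
# the final True enables the new overlapped backward path
y = matmul_gather_forward_reducescatter_backward(x, w, b, None, True, 1, True)
y.sum().backward()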
@@ -177,6 +177,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
                  async_communication: bool = False,
                  gather_output: bool = False,
                  seq_parallel: bool = False,
+                 overlap: bool = False,
                  skip_bias_add: bool = False,
                  n_fused: int = 3,
                  weight: Optional[Parameter] = None,
@@ -190,6 +191,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         self.out_features = out_features
         self.gather_output = gather_output
         self.seq_parallel = seq_parallel
+        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.n_fused = n_fused
@@ -308,7 +310,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         if self.seq_parallel:
             input_parallel = input_
             output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
-                                                                            self.process_group, True, 1)
+                                                                            self.process_group, True, 1, self.overlap)
         else:
             # Set up backprop all-reduce.
             input_parallel = reduce_backward(input_, self.process_group)
@@ -226,22 +226,3 @@ class Policy(ABC):
         end_idx = num_layers_per_stage_accumulated[stage + 1]

         return [start_idx, end_idx]
-
-    def append_seq_parallel_to_policy(
-        self,
-        suffix_list: List[str],
-        module_policy_description: ModulePolicyDescription,
-    ):
-        r"""
-        Append the sequence parallel policy to the policy for the given key.
-
-        Args:
-            suffix_list (List[str]): the suffix list of the module to be parallelized
-            policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
-        """
-
-        for sub_description in module_policy_description.sub_module_replacement:
-            if (sub_description.suffix in suffix_list):
-                if sub_description.kwargs is None:
-                    sub_description.kwargs = {}
-                sub_description.kwargs["seq_parallel"] = True
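Note that with this helper removed from the base `Policy`, sequence-parallel kwargs are no longer injected generically; the GPT-2 policy in the following hunks passes `seq_parallel` (and the new `overlap`) directly in each `SubModuleReplacementDescription`'s kwargs instead.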
@@ -37,7 +37,8 @@ class GPT2Policy(Policy):
         from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model

         policy = {}
-
+        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+        overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
             policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[
                 SubModuleReplacementDescription(
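The two flags read here come from the shard configuration. A minimal sketch of how they would typically be enabled (not part of the diff; it assumes `ShardConfig` is importable from `colossalai.shardformer` and exposes the `enable_sequence_parallelism` / `enable_sequence_overlap` fields referenced above):

from colossalai.shardformer import ShardConfig

# illustrative configuration; the kwargs below only apply when tensor parallelism is enabled
shard_config = ShardConfig(
    enable_tensor_parallelism=True,
    enable_sequence_parallelism=True,    # propagated as "seq_parallel": True
    enable_sequence_overlap=True,        # propagated as the new "overlap": True
)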
@@ -50,47 +51,54 @@ class GPT2Policy(Policy):
                 ),
             ])

-            policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
-                "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-            },
-                sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="attn.c_attn",
-                        target_module=col_nn.GPT2FusedLinearConv1D_Col,
-                        kwargs={
-                            "n_fused": 3,
-                        },
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attn.c_proj",
-                        target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.c_fc",
-                        target_module=col_nn.GPT2FusedLinearConv1D_Col,
-                        kwargs={
-                            "n_fused": 1,
-                        },
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.c_proj",
-                        target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attn.attn_dropout",
-                        target_module=col_nn.DropoutForParallelInput,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attn.resid_dropout",
-                        target_module=col_nn.DropoutForParallelInput,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.dropout",
-                        target_module=col_nn.DropoutForParallelInput,
-                    ),
-                ])
+            policy[GPT2Block] = ModulePolicyDescription(
+                attribute_replacement={
+                    "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                    "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                    "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                },
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="attn.c_attn",
+                        target_module=col_nn.GPT2FusedLinearConv1D_Col,
+                        kwargs={
+                            "n_fused": 3,
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap
+                        },
+                    ),
+                    SubModuleReplacementDescription(suffix="attn.c_proj",
+                                                    target_module=col_nn.GPT2FusedLinearConv1D_Row,
+                                                    kwargs={
+                                                        "seq_parallel": use_sequence_parallel,
+                                                    }),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.c_fc",
+                        target_module=col_nn.GPT2FusedLinearConv1D_Col,
+                        kwargs={
+                            "n_fused": 1,
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap
+                        },
+                    ),
+                    SubModuleReplacementDescription(suffix="mlp.c_proj",
+                                                    target_module=col_nn.GPT2FusedLinearConv1D_Row,
+                                                    kwargs={
+                                                        "seq_parallel": use_sequence_parallel,
+                                                    }),
+                    SubModuleReplacementDescription(
+                        suffix="attn.attn_dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attn.resid_dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                ])

             # optimization configuration
             if self.shard_config.enable_fused_normalization:
@@ -126,8 +134,6 @@ class GPT2Policy(Policy):

         if self.shard_config.enable_sequence_parallelism:
             policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
-            suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
-            self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block])

         return policy

@@ -53,7 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor


-def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
+def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
     with ctx:
@@ -62,7 +62,8 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
                                                            process_group=None,
                                                            gather_output=True,
                                                            seq_parallel=seq_parallel,
-                                                           n_fused=3)
+                                                           n_fused=3,
+                                                           overlap=overlap)

     assert linear.weight.shape == torch.Size([48, 192])
     assert linear.bias.shape == torch.Size([192])
@@ -129,8 +130,9 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):


 @parameterize('lazy_init', [False, True])
 @parameterize('seq_parallel', [False, True])
-def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool):
-    check_linear_conv_1d_col(lazy_init, seq_parallel)
+@parameterize('overlap', [True])
+def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool):
+    check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
     check_linear_conv_1d_row(lazy_init, seq_parallel)
