fix format parallel_2p5d (#357)

pull/394/head
Yuer867 2022-03-09 21:42:30 +08:00 committed by Frank Lee
parent 7eb87f516d
commit 4a0f8c2c50
3 changed files with 81 additions and 49 deletions

View File

@ -26,20 +26,20 @@ class _Classifier2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
ctx: Any,
A: Tensor,
B: Tensor,
bias,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
ctx: Any,
A: Tensor,
B: Tensor,
bias,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
A_shape = A.shape
@ -166,6 +166,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
@ -197,10 +198,14 @@ class Matmul_AB_2p5D(torch.autograd.Function):
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_a = \
tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_b = \
col_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2
opb = [None] * 2
@ -295,6 +300,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
@ -323,10 +329,14 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_b = \
col_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_c = \
tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
opb = [None] * 2
opr = [None] * 2
@ -429,6 +439,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
@ -457,10 +468,14 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_a = \
tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_c = \
col_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2
opr = [None] * 2
@ -540,8 +555,10 @@ class _Add_Bias_2p5D(torch.autograd.Function):
bias_temp = bias.clone()
else:
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_rank = \
col_rank + dep_rank * tesseract_dim ** 2 + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
ctx.row_rank = row_rank
@ -575,27 +592,37 @@ class _Add_Bias_2p5D(torch.autograd.Function):
tensor_parallel_size = ctx.tensor_parallel_size
if ctx.bias:
dst_rank = col_rank + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dst_rank = \
col_rank + dep_rank * (tesseract_dim ** 2) + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
if row_rank == 0:
return None, output_grad, None, None, None, None, None, None, None, None, None, None, None, None, None, None
return \
None, output_grad, None, None, None, None, None, None, \
None, None, None, None, None, None, None, None
else:
grad_tmp = torch.zeros_like(output_grad)
return None, grad_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None
return \
None, grad_tmp, None, None, None, None, None, None, \
None, None, None, None, None, None, None, None
else:
reduce_dim = tuple(range(output_grad.ndim - 1))
reduce = torch.sum(output_grad, dim=reduce_dim)
dst_rank = col_rank + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dst_rank = \
col_rank + dep_rank * (tesseract_dim ** 2) + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
if row_rank == 0:
return output_grad, reduce, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
return \
output_grad, reduce, None, None, None, None, None, None, None, \
None, None, None, None, None, None, None, None
else:
reduce_tmp = torch.zeros_like(reduce)
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
return \
output_grad, reduce_tmp, None, None, None, None, None, None, \
None, None, None, None, None, None, None, None, None
def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int,
@ -621,7 +648,8 @@ def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, t
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion
:type skip_bias_add: bool
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
@ -652,6 +680,7 @@ class _Layernorm2p5D(torch.autograd.Function):
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
@ -748,6 +777,7 @@ class SplitFirst(torch.autograd.Function):
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
@ -762,7 +792,7 @@ class SplitFirst(torch.autograd.Function):
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad_shape = (ctx.batch_size, ) + output_grad.shape[1:]
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)),
output_grad.contiguous(),
@ -775,10 +805,10 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
:param input_: Input tensor
:param dim: Specified dimension in which to split
:type input_: torch.Tensor
:type dim: int, optional
:return output: Splitted tensor
:rtype output: torch.Tensor
"""
@ -801,7 +831,7 @@ class _ReduceTensor2p5D(torch.autograd.Function):
def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
"""
All-reduce the input.
:param input_: input tensor
:param parallel_mode: parallel mode
"""
@ -823,7 +853,7 @@ class _ReduceScatterTensor2p5D(torch.autograd.Function):
def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
"""
Reduce-scatter the input.
:param input_: input tensor
:param parallel_mode: parallel mode
"""
@ -868,4 +898,4 @@ def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor:
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: bool, optional
"""
return _RreduceByBatch2p5D.apply(input_, reduce_mean)
return _RreduceByBatch2p5D.apply(input_, reduce_mean)

View File

@ -21,4 +21,5 @@ def assert_tesseract_initialization():
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \
'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ must be initialized by the process group initializer'
'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \
'must be initialized by the process group initializer'

View File

@ -134,8 +134,9 @@ class LayerNorm2p5D(ParallelLayer):
r"""
Layer Normalization for 2.5D parallelism
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
:param normalized_shape: input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
@ -431,7 +432,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None and \
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
self.vocab_start_index <= self.padding_idx < self.vocab_end_index:
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)