fix format parallel_2p5d (#357)

pull/394/head
Yuer867 2022-03-09 21:42:30 +08:00 committed by Frank Lee
parent 7eb87f516d
commit 4a0f8c2c50
3 changed files with 81 additions and 49 deletions

View File

@@ -26,20 +26,20 @@ class _Classifier2p5D(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
        ctx: Any,
        A: Tensor,
        B: Tensor,
        bias,
        tesseract_dim: int,
        out_shape: Tuple[int, ...],
        row_rank: int,
        col_rank: int,
        row_parallel_mode: ParallelMode,
        col_parallel_mode: ParallelMode,
        data_parallel_rank: int,
        pipeline_parallel_rank: int,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
    ) -> Tensor:
        A_shape = A.shape
@ -166,6 +166,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
:param tensor_parallel_size: tensor parallel size :param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int :type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
@@ -197,10 +198,14 @@ class Matmul_AB_2p5D(torch.autograd.Function):
        row_group = gpc.get_group(row_parallel_mode)
        col_group = gpc.get_group(col_parallel_mode)
-        src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
-        src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
+        src_a = \
+            tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_b = \
+            col_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
        opa = [None] * 2
        opb = [None] * 2
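Note (not part of the patch): the src_a/src_b expressions above, like the other src_*/dst_* expressions later in this diff, compose a global rank from the rank's position inside the q x q tesseract grid, a per-depth-plane offset of q**2, and the offsets contributed by pipeline and data parallelism. A minimal sketch of that arithmetic, with purely hypothetical values, is:

    # Illustration only; every value below is hypothetical.
    tesseract_dim = 2                       # q: each depth plane is a q x q grid
    row_rank, col_rank, dep_rank = 1, 0, 1  # position inside the tesseract
    data_parallel_rank = 0
    pipeline_parallel_rank = 0
    pipeline_parallel_size = 1
    tensor_parallel_size = 8                # assumed: q * q * depth ranks per tensor-parallel group

    # offset of this tensor-parallel group inside the global rank space
    group_offset = data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
        pipeline_parallel_rank * tensor_parallel_size

    src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + group_offset  # -> 6
    src_b = col_rank + tesseract_dim ** 2 * dep_rank + group_offset                  # -> 4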
@ -295,6 +300,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
:param tensor_parallel_size: tensor parallel size :param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int :type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
@@ -323,10 +329,14 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
        row_group = gpc.get_group(row_parallel_mode)
        col_group = gpc.get_group(col_parallel_mode)
-        src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
-        src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
+        src_b = \
+            col_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = \
+            tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
        opb = [None] * 2
        opr = [None] * 2
@@ -429,6 +439,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
    :param tensor_parallel_size: tensor parallel size
    :type tensor_parallel_size: int
    """
+
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
@@ -457,10 +468,14 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
        row_group = gpc.get_group(row_parallel_mode)
        col_group = gpc.get_group(col_parallel_mode)
-        src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
-        src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
+        src_a = \
+            tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = \
+            col_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
        opa = [None] * 2
        opr = [None] * 2
@ -540,8 +555,10 @@ class _Add_Bias_2p5D(torch.autograd.Function):
bias_temp = bias.clone() bias_temp = bias.clone()
else: else:
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ src_rank = \
pipeline_parallel_rank * tensor_parallel_size col_rank + dep_rank * tesseract_dim ** 2 + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode)) dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
ctx.row_rank = row_rank ctx.row_rank = row_rank
@@ -575,27 +592,37 @@ class _Add_Bias_2p5D(torch.autograd.Function):
        tensor_parallel_size = ctx.tensor_parallel_size
        if ctx.bias:
-            dst_rank = col_rank + dep_rank * (
-                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            dst_rank = \
+                col_rank + dep_rank * (tesseract_dim ** 2) + \
+                data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
            if row_rank == 0:
-                return None, output_grad, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    None, output_grad, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None
            else:
                grad_tmp = torch.zeros_like(output_grad)
-                return None, grad_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    None, grad_tmp, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None
        else:
            reduce_dim = tuple(range(output_grad.ndim - 1))
            reduce = torch.sum(output_grad, dim=reduce_dim)
-            dst_rank = col_rank + dep_rank * (
-                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            dst_rank = \
+                col_rank + dep_rank * (tesseract_dim ** 2) + \
+                data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
            if row_rank == 0:
-                return output_grad, reduce, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    output_grad, reduce, None, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None
            else:
                reduce_tmp = torch.zeros_like(reduce)
-                return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    output_grad, reduce_tmp, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None, None

def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int,
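Note (not part of the patch): the reflowed return statements in backward above only change line wrapping; the number of returned values stays the same, because a torch.autograd.Function must return one gradient (or None) from backward for every argument that was passed to forward. A minimal, self-contained illustration of that rule:

    import torch

    class ScaleByTwo(torch.autograd.Function):
        """Toy function: backward returns one entry per forward argument."""

        @staticmethod
        def forward(ctx, x, unused_flag):
            return x * 2

        @staticmethod
        def backward(ctx, grad_output):
            # forward took (x, unused_flag), so backward returns two values;
            # non-differentiable arguments get None.
            return grad_output * 2, None

    y = ScaleByTwo.apply(torch.ones(3, requires_grad=True), True)
    y.sum().backward()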
@@ -621,7 +648,8 @@ def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, t
    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
    :param col_parallel_mode: column parallel mode
    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
-    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
+    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer,
+        which is preserved for kernel fusion
    :type skip_bias_add: bool
    :param data_parallel_rank: data parallel rank
    :type data_parallel_rank: int
@ -652,6 +680,7 @@ class _Layernorm2p5D(torch.autograd.Function):
:param row_parallel_mode: row parallel mode :param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
@@ -748,6 +777,7 @@ class SplitFirst(torch.autograd.Function):
    :param col_parallel_mode: column parallel mode
    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
    """
+
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
@@ -762,7 +792,7 @@ class SplitFirst(torch.autograd.Function):
    @staticmethod
    @custom_bwd
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
-        grad_shape = (ctx.batch_size, ) + output_grad.shape[1:]
+        grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
        grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
        dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)),
                        output_grad.contiguous(),
@@ -775,10 +805,10 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
    :param input_: Input tensor
    :param dim: Specified dimension in which to split
    :type input_: torch.Tensor
    :type dim: int, optional
    :return output: Splitted tensor
    :rtype output: torch.Tensor
    """
@@ -801,7 +831,7 @@ class _ReduceTensor2p5D(torch.autograd.Function):
def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
    """
    All-reduce the input.
    :param input_: input tensor
    :param parallel_mode: parallel mode
    """
@@ -823,7 +853,7 @@ class _ReduceScatterTensor2p5D(torch.autograd.Function):
def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
    """
    Reduce-scatter the input.
    :param input_: input tensor
    :param parallel_mode: parallel mode
    """
@@ -868,4 +898,4 @@ def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor:
    :param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
    :type reduce_mean: bool, optional
    """
    return _RreduceByBatch2p5D.apply(input_, reduce_mean)

View File

@@ -21,4 +21,5 @@ def assert_tesseract_initialization():
        gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \
        gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \
        gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \
-        'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ must be initialized by the process group initializer'
+        'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \
+        'must be initialized by the process group initializer'
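Note (not part of the patch): the reformatted assertion message relies on Python joining adjacent string literals, so splitting it across two lines does not change the text that gets raised. A minimal sketch:

    # Adjacent string literals are concatenated at compile time, so the split
    # message is identical to the original single-line literal.
    message = 'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \
              'must be initialized by the process group initializer'
    assert message == ('Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and '
                       'PARALLEL_2P5D_XZ must be initialized by the process group initializer')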

View File

@@ -134,8 +134,9 @@ class LayerNorm2p5D(ParallelLayer):
    r"""
    Layer Normalization for 2.5D parallelism

-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+    :param normalized_shape: input shape from an expected input of size.
+        :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+        \times \ldots \times \text{normalized_shape}[-1]]`
        If a single integer is used, it is treated as a singleton list, and this module will
        normalize over the last dimension which is expected to be of that specific size.
    :type normalized_shape: int
@@ -431,7 +432,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None and \
-                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
+                self.vocab_start_index <= self.padding_idx < self.vocab_end_index:
            with torch.no_grad():
                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
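Note (not part of the patch): the rewritten condition uses Python's chained comparison, which is equivalent to the pair of comparisons it replaces. A minimal sketch with hypothetical vocabulary shard boundaries:

    # Hypothetical shard boundaries for one vocabulary partition.
    vocab_start_index, vocab_end_index = 1000, 2000
    padding_idx = 1500

    # The chained form is equivalent to the explicit 'and' used before the patch.
    assert (vocab_start_index <= padding_idx < vocab_end_index) == \
           (padding_idx >= vocab_start_index and padding_idx < vocab_end_index)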