mirror of https://github.com/hpcaitech/ColossalAI
fix format parallel_2p5d (#357)
parent 7eb87f516d
commit 4a0f8c2c50
@@ -26,20 +26,20 @@ class _Classifier2p5D(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(
         ctx: Any,
         A: Tensor,
         B: Tensor,
         bias,
         tesseract_dim: int,
         out_shape: Tuple[int, ...],
         row_rank: int,
         col_rank: int,
         row_parallel_mode: ParallelMode,
         col_parallel_mode: ParallelMode,
         data_parallel_rank: int,
         pipeline_parallel_rank: int,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
     ) -> Tensor:

         A_shape = A.shape
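
Note on the @custom_fwd(cast_inputs=torch.float16) / @custom_bwd pair that appears on these forward/backward methods: the decorators come from torch.cuda.amp, where custom_fwd casts floating-point tensor arguments to the given dtype when autocast is active and custom_bwd runs backward under the same autocast state as forward. A minimal standalone sketch of the pattern (a toy Function, not the ColossalAI classifier):

import torch
from torch.cuda.amp import custom_fwd, custom_bwd


class ScaleByTwo(torch.autograd.Function):
    """Toy autograd Function using the same AMP decorators as the 2.5D ops."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x):
        # Under torch.autocast('cuda'), x arrives here already cast to float16;
        # outside autocast (e.g. on CPU) no cast happens and this still runs.
        return x * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # One gradient per forward input (besides ctx).
        return grad_output * 2


x = torch.randn(4, requires_grad=True)
y = ScaleByTwo.apply(x)
y.sum().backward()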
|
@@ -166,6 +166,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
     :param tensor_parallel_size: tensor parallel size
     :type tensor_parallel_size: int
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
|
@@ -197,10 +198,14 @@ class Matmul_AB_2p5D(torch.autograd.Function):
         row_group = gpc.get_group(row_parallel_mode)
         col_group = gpc.get_group(col_parallel_mode)

-        src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
-        src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
+        src_a = \
+            tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_b = \
+            col_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size

         opa = [None] * 2
         opb = [None] * 2
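
The rewrapped expressions compute a global process rank from tesseract coordinates: a row or column term, plus tesseract_dim ** 2 * dep_rank for the depth slice, offset by the data-parallel and pipeline-parallel contributions. A hedged, self-contained sketch of just that arithmetic (the helper name and example sizes are illustrative, not from the repository; how the ranks are then used is defined by the surrounding forward()):

# Illustrative only: recomputes src_a / src_b style global ranks for a tesseract of
# side 2 and depth 2, assuming no data or pipeline parallelism (both ranks 0,
# pipeline size 1, tensor parallel size = tesseract_dim ** 2 * depth = 8).
def tesseract_src_rank(row_or_col_term: int, dep_rank: int, tesseract_dim: int,
                       data_parallel_rank: int = 0, pipeline_parallel_rank: int = 0,
                       pipeline_parallel_size: int = 1, tensor_parallel_size: int = 8) -> int:
    return row_or_col_term + tesseract_dim ** 2 * dep_rank + \
        data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
        pipeline_parallel_rank * tensor_parallel_size


tesseract_dim, dep_rank = 2, 1
row_rank, col_rank = 1, 0
src_a = tesseract_src_rank(tesseract_dim * row_rank, dep_rank, tesseract_dim)  # row-derived term
src_b = tesseract_src_rank(col_rank, dep_rank, tesseract_dim)                  # column-derived term
print(src_a, src_b)  # 6 4 for this toy layout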
|
@@ -295,6 +300,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
     :param tensor_parallel_size: tensor parallel size
     :type tensor_parallel_size: int
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
|
@@ -323,10 +329,14 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
         row_group = gpc.get_group(row_parallel_mode)
         col_group = gpc.get_group(col_parallel_mode)

-        src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
-        src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
+        src_b = \
+            col_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = \
+            tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size

         opb = [None] * 2
         opr = [None] * 2
|
@@ -429,6 +439,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
     :param tensor_parallel_size: tensor parallel size
     :type tensor_parallel_size: int
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
|
@@ -457,10 +468,14 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
         row_group = gpc.get_group(row_parallel_mode)
         col_group = gpc.get_group(col_parallel_mode)

-        src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
-        src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
+        src_a = \
+            tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = \
+            col_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size

         opa = [None] * 2
         opr = [None] * 2
|
@@ -540,8 +555,10 @@ class _Add_Bias_2p5D(torch.autograd.Function):
             bias_temp = bias.clone()
         else:
             bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
-        src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
+        src_rank = \
+            col_rank + dep_rank * tesseract_dim ** 2 + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
         dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))

         ctx.row_rank = row_rank
|
@@ -575,27 +592,37 @@ class _Add_Bias_2p5D(torch.autograd.Function):
         tensor_parallel_size = ctx.tensor_parallel_size

         if ctx.bias:
-            dst_rank = col_rank + dep_rank * (
-                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            dst_rank = \
+                col_rank + dep_rank * (tesseract_dim ** 2) + \
+                data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
             dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
             if row_rank == 0:
-                return None, output_grad, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    None, output_grad, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None
             else:
                 grad_tmp = torch.zeros_like(output_grad)
-                return None, grad_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    None, grad_tmp, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None
         else:
             reduce_dim = tuple(range(output_grad.ndim - 1))
             reduce = torch.sum(output_grad, dim=reduce_dim)
-            dst_rank = col_rank + dep_rank * (
-                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            dst_rank = \
+                col_rank + dep_rank * (tesseract_dim ** 2) + \
+                data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
             dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
             if row_rank == 0:
-                return output_grad, reduce, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    output_grad, reduce, None, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None
             else:
                 reduce_tmp = torch.zeros_like(reduce)
-                return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    output_grad, reduce_tmp, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None, None


 def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int,
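
The long tuples of None exist because torch.autograd.Function.backward must return exactly one value per argument of forward, and only the tensor inputs receive real gradients; the wrapped form merely keeps that arity readable. A minimal sketch of the rule with a toy Function (all names here are illustrative):

import torch


class AddBiasToy(torch.autograd.Function):
    """Toy version of the pattern: many config arguments, few real gradients."""

    @staticmethod
    def forward(ctx, x, bias, some_rank: int, some_size: int):
        return x + bias

    @staticmethod
    def backward(ctx, grad_output):
        # backward returns as many values as forward took inputs:
        # (x, bias, some_rank, some_size) -> a gradient or None for each.
        return grad_output, grad_output, None, None


x = torch.randn(3, requires_grad=True)
bias = torch.randn(3, requires_grad=True)
out = AddBiasToy.apply(x, bias, 0, 1)
out.sum().backward()
print(x.grad.shape, bias.grad.shape)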
|
@@ -621,7 +648,8 @@ def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, t
     :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
     :param col_parallel_mode: column parallel mode
     :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
-    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
+    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer,
+        which is preserved for kernel fusion
     :type skip_bias_add: bool
     :param data_parallel_rank: data parallel rank
     :type data_parallel_rank: int
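
The skip_bias_add flag described above defers the bias addition so a later fused kernel can apply it together with another op such as an activation. A hedged sketch of the calling pattern, using a hypothetical plain-PyTorch helper rather than the 2.5D layer itself:

import torch
import torch.nn.functional as F


# Hypothetical stand-in for a layer constructed with skip_bias_add=True:
# it returns the un-biased matmul output together with the bias tensor.
def linear_skip_bias_add(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
    return F.linear(x, weight), bias


x = torch.randn(4, 8)
weight = torch.randn(16, 8)
bias = torch.randn(16)

out, b = linear_skip_bias_add(x, weight, bias)
# The caller (or a fused kernel) applies the bias later, e.g. together with GELU.
y = F.gelu(out + b)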
|
@@ -652,6 +680,7 @@ class _Layernorm2p5D(torch.autograd.Function):
     :param row_parallel_mode: row parallel mode
     :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
|
@@ -748,6 +777,7 @@ class SplitFirst(torch.autograd.Function):
     :param col_parallel_mode: column parallel mode
     :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
|
@@ -762,7 +792,7 @@ class SplitFirst(torch.autograd.Function):
     @staticmethod
     @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
-        grad_shape = (ctx.batch_size, ) + output_grad.shape[1:]
+        grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
         grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
         dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)),
                         output_grad.contiguous(),
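
This backward preallocates the full-size gradient and passes list(grad.chunk(...)) to dist.all_gather, so each gathered shard is written straight into a view of grad and no extra concatenation is needed afterwards. A small local sketch of the view behaviour it relies on (no process group involved, sizes are illustrative):

import torch

batch_size, tesseract_dim = 8, 2
output_grad = torch.randn(batch_size // tesseract_dim, 4)

# Same shape arithmetic as SplitFirst.backward: restore the full batch dimension.
grad_shape = (batch_size,) + output_grad.shape[1:]
grad = torch.empty(grad_shape, dtype=output_grad.dtype)

# chunk() returns views into `grad`, so writing a shard fills the parent tensor
# in place, which is what dist.all_gather does with the chunk list in the real code.
shards = list(grad.chunk(tesseract_dim, dim=0))
shards[0].copy_(output_grad)
shards[1].copy_(output_grad)
assert grad[:batch_size // tesseract_dim].equal(output_grad)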
|
@@ -775,10 +805,10 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:

     :param input_: Input tensor
     :param dim: Specified dimension in which to split

     :type input_: torch.Tensor
     :type dim: int, optional

     :return output: Splitted tensor
     :rtype output: torch.Tensor
     """
|
@@ -801,7 +831,7 @@ class _ReduceTensor2p5D(torch.autograd.Function):
 def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
     """
     All-reduce the input.

     :param input_: input tensor
     :param parallel_mode: parallel mode
     """
|
@@ -823,7 +853,7 @@ class _ReduceScatterTensor2p5D(torch.autograd.Function):
 def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     Reduce-scatter the input.

     :param input_: input tensor
     :param parallel_mode: parallel mode
     """
|
@@ -868,4 +898,4 @@ def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor:
     :param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
     :type reduce_mean: bool, optional
     """
     return _RreduceByBatch2p5D.apply(input_, reduce_mean)
|
@@ -21,4 +21,5 @@ def assert_tesseract_initialization():
         gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \
         gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \
         gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \
-        'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ must be initialized by the process group initializer'
+        'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \
+        'must be initialized by the process group initializer'
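
The reformatted assert message relies on Python's implicit concatenation of adjacent string literals, so splitting the line does not change the text at runtime. A quick standalone check (not repository code):

single = 'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ must be initialized by the process group initializer'
split = 'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \
        'must be initialized by the process group initializer'
assert single == split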
|
@@ -134,8 +134,9 @@ class LayerNorm2p5D(ParallelLayer):
     r"""
     Layer Normalization for 2.5D parallelism

-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+    :param normalized_shape: input shape from an expected input of size.
+        :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+        \times \ldots \times \text{normalized_shape}[-1]]`
         If a single integer is used, it is treated as a singleton list, and this module will
         normalize over the last dimension which is expected to be of that specific size.
     :type normalized_shape: int
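
The rewrapped docstring describes the standard LayerNorm contract: normalized_shape names the trailing dimensions that get normalized, and a single integer means just the last dimension. A short illustration with the stock torch.nn.LayerNorm (not the 2.5D-parallel layer):

import torch
import torch.nn as nn

hidden_size = 64
x = torch.randn(2, 10, hidden_size)     # [* x normalized_shape[-1]]

ln = nn.LayerNorm(hidden_size)          # single int: normalize over the last dim only
y = ln(x)

# Mean ~0 and variance ~1 along the normalized dimension.
print(y.mean(dim=-1).abs().max(), y.var(dim=-1, unbiased=False).mean())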
|
@@ -431,7 +432,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):

     def _fill_padding_idx_with_zero(self) -> None:
         if self.padding_idx is not None and \
-                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
+                self.vocab_start_index <= self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
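
The rewritten condition uses a chained comparison, which Python evaluates as the conjunction of the two inequalities (with the middle operand evaluated once), so it is equivalent to the old form. A tiny standalone check:

vocab_start_index, vocab_end_index = 100, 200

for padding_idx in (99, 100, 150, 199, 200):
    old_style = padding_idx >= vocab_start_index and padding_idx < vocab_end_index
    new_style = vocab_start_index <= padding_idx < vocab_end_index
    assert old_style == new_style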
|
|