updated tp layers

pull/1790/head
kurisusnowdeng 2022-10-26 20:54:39 +08:00 committed by アマデウス
parent cb5a587e9a
commit 0b8161fab8
13 changed files with 645 additions and 293 deletions

View File

@ -23,6 +23,8 @@ INITIALIZER_MAPPING = {
INPUT_GROUP_3D = 'input_group_3d'
WEIGHT_GROUP_3D = 'weight_group_3d'
OUTPUT_GROUP_3D = 'output_group_3d'
INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d'
OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d'
# Attributes of tensor parallel parameters
IS_TENSOR_PARALLEL = 'is_tensor_parallel'

View File

@ -39,6 +39,8 @@ class ParallelMode(Enum):
PARALLEL_3D_INPUT = '3d_input'
PARALLEL_3D_WEIGHT = '3d_weight'
PARALLEL_3D_OUTPUT = '3d_output'
PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight"
PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight"
# 2.5D parallel
PARALLEL_2P5D_ROW = '2p5d_row'

View File

@ -176,6 +176,112 @@ class Initializer_3D_Output(ProcessGroupInitializer):
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_3D_InputxWeight(ProcessGroupInitializer):
"""3D tensor parallel initialization among input.
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth
def init_dist_group(self):
"""Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among input in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT_X_WEIGHT
env.input_x_weight_group_3d = mode
for h in range(self.num_group):
for k in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)
for i in range(self.depth)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_3D_OutputxWeight(ProcessGroupInitializer):
"""3D tensor parallel initialization among input.
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth
def init_dist_group(self):
"""Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among input in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT_X_WEIGHT
env.output_x_weight_group_3d = mode
for h in range(self.num_group):
for j in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)
for i in range(self.depth)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_3D(ProcessGroupInitializer):
"""Serve as the single entry point to 3D parallel initialization.
@ -200,6 +306,8 @@ class Initializer_3D(ProcessGroupInitializer):
self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args)
self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args)
self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args)
self.input_x_weight_initializer = Initializer_3D_InputxWeight(self.num_group, self.depth, *args)
self.output_x_weight_initializer = Initializer_3D_OutputxWeight(self.num_group, self.depth, *args)
def init_dist_group(self):
"""Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.
@ -211,6 +319,8 @@ class Initializer_3D(ProcessGroupInitializer):
parallel_setting = [
self.input_initializer.init_dist_group(),
self.weight_initializer.init_dist_group(),
self.output_initializer.init_dist_group()
self.output_initializer.init_dist_group(),
self.input_x_weight_initializer.init_dist_group(),
self.output_x_weight_initializer.init_dist_group()
]
return parallel_setting

View File

@ -22,7 +22,9 @@ class TensorParallelEnv(object):
depth_3d: int = None,
input_group_3d=None,
weight_group_3d=None,
output_group_3d=None):
output_group_3d=None,
input_x_weight_group_3d=None,
output_x_weight_group_3d=None):
self.mode = mode
self.vocab_parallel = vocab_parallel
self.parallel_input_1d = parallel_input_1d
@ -33,6 +35,8 @@ class TensorParallelEnv(object):
self.input_group_3d = input_group_3d
self.weight_group_3d = weight_group_3d
self.output_group_3d = output_group_3d
self.input_x_weight_group_3d = input_x_weight_group_3d
self.output_x_weight_group_3d = output_x_weight_group_3d
def save(self):
return dict(mode=self.mode,
@ -44,7 +48,9 @@ class TensorParallelEnv(object):
depth_3d=self.depth_3d,
input_group_3d=self.input_group_3d,
weight_group_3d=self.weight_group_3d,
output_group_3d=self.output_group_3d)
output_group_3d=self.output_group_3d,
input_x_weight_group_3d=self.input_x_weight_group_3d,
output_x_weight_group_3d=self.output_x_weight_group_3d)
tensor_parallel_env = TensorParallelEnv()

View File

@ -1,4 +1,6 @@
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
try:
import fused_mix_prec_layer_norm_cuda
@ -43,3 +45,52 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""
@staticmethod
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.parallel_mode = parallel_mode
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input_, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
total_input = input
grad_input = grad_output.matmul(weight)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)

View File

@ -20,12 +20,12 @@ from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn.parameter import Parameter
from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
split_forward_gather_backward)
from ._operation import linear_with_async_comm
@LAYERS.register_module
@ -96,8 +96,25 @@ class LayerNorm1D(ColossalaiModule):
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
_fast_ln_supported_sizes = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536
]
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
from apex.normalization import FusedLayerNorm
fast_ln_installed = False
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
fast_ln_installed = True
except ImportError:
pass
if fast_ln_installed and normalized_shape in self._fast_ln_supported_sizes:
norm = FastLayerNorm(normalized_shape, eps=eps).to(dtype)
else:
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
super().__init__(norm)
def _load_from_state_dict(self, state_dict, prefix, *args):
@ -519,11 +536,12 @@ class Linear1D_Col(ParallelLayer):
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
# output_parallel = F.linear(input_parallel, self.weight, bias)
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
@ -665,6 +683,7 @@ class Linear1D_Row(ParallelLayer):
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
if not self.skip_bias_add:

View File

@ -9,7 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from ._utils import get_parallel_mode_from_env
from ._utils import get_parallel_mode_from_env, push_async_grad
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
@ -17,34 +17,27 @@ class _Linear3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
ctx.use_bias = bias is not None
input_ = all_gather(input_, input_dim, input_parallel_mode)
weight = all_gather(weight, weight_dim, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
output = reduce_scatter(output, output_dim, output_parallel_mode)
if bias is not None:
output += bias
def forward(
ctx,
input_: Tensor,
weight: Tensor,
weight_id: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.weight_id = weight_id
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
ctx.input_dim = input_dim
ctx.weight_dim = weight_dim
ctx.output_dim = output_dim
input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight, -1, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
output = reduce_scatter(output, 0, output_parallel_mode)
return output
@staticmethod
@ -52,73 +45,70 @@ class _Linear3D(torch.autograd.Function):
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
output_grad = all_gather(output_grad, ctx.output_dim, ctx.output_parallel_mode)
async_ops = list()
output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, op = reduce_scatter(input_grad, ctx.input_dim, ctx.input_parallel_mode, async_op=True)
async_ops.append(op)
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
else:
bias_grad = None
input_op.wait()
for op in async_ops:
if op is not None:
op.wait()
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
return input_grad, weight_grad, None, None, None, None
def linear_3d(input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
def linear_3d(
input_: Tensor,
weight: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""Linear layer for 3D parallelism.
Args:
input_ (:class:`torch.tensor`): input matrix.
weight (:class:`torch.tensor`): matrix of weight.
bias (:class:`torch.tensor`): matrix of bias.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
input_dim (int, optional): dimension of input, defaults to 0.
weight_dim (int, optional): dimension of weight, defaults to -1.
output_dim (int, optional): dimension of output, defaults to 0.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode,
input_dim, weight_dim, output_dim)
return _Linear3D.apply(
input_,
weight,
id(weight),
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _Classifier3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
weight_id: int,
bias_id: Optional[int],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.use_bias = bias is not None
ctx.weight_id = weight_id
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)]
weight = broadcast(weight, src_rank, input_parallel_mode)
ctx.save_for_backward(input_, weight)
@ -126,6 +116,7 @@ class _Classifier3D(torch.autograd.Function):
output = all_reduce(output, output_parallel_mode)
if bias is not None:
ctx.bias_id = bias_id
output += bias
ctx.src_rank = src_rank
@ -139,14 +130,12 @@ class _Classifier3D(torch.autograd.Function):
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
async_ops = list()
weight_grad = torch.matmul(
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
else:
weight_grad = None
@ -154,21 +143,23 @@ class _Classifier3D(torch.autograd.Function):
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
else:
bias_grad = None
input_grad = torch.matmul(output_grad, weight)
for op in async_ops:
if op is not None:
op.wait()
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
return input_grad, weight_grad, bias_grad, None, None, None, None, None
def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
def classifier_3d(
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D parallel classifier.
Args:
@ -183,16 +174,134 @@ def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode)
return _Classifier3D.apply(
input_,
weight,
bias,
id(weight),
id(bias) if bias is not None else None,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _VocabParallelClassifier3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
weight_id: int,
bias_id: Optional[int],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.use_bias = bias is not None
ctx.weight_id = weight_id
input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
output = reduce_scatter(output, 0, output_parallel_mode)
if bias is not None:
ctx.bias_id = bias_id
output += bias
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
else:
bias_grad = None
input_op.wait()
return input_grad, weight_grad, bias_grad, None, None, None, None, None
def vocab_parallel_classifier_3d(
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D vocab parallel classifier.
Args:
input_ (:class:`torch.tensor`): input matrix.
weight (:class:`torch.tensor`): matrix of weight.
bias (:class:`torch.tensor`): matrix of bias.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _VocabParallelClassifier3D.apply(
input_,
weight,
bias,
id(weight),
id(bias) if bias is not None else None,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _Layernorm3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Tensor,
weight_id: int,
bias_id: int,
normalized_shape: int,
eps: float,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_x_weight_parallel_mode: ParallelMode,
) -> Tensor:
ctx.weight_id = weight_id
ctx.bias_id = bias_id
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
mu = input_ - mean
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
@ -201,15 +310,13 @@ class _Layernorm3D(torch.autograd.Function):
ctx.save_for_backward(mu, sigma, weight)
z = mu / sigma
output = weight * z
if bias is not None:
output = output + bias
output = weight * z + bias
ctx.use_bias = bias is not None
ctx.normalized_shape = normalized_shape
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode
return output
@ -218,17 +325,14 @@ class _Layernorm3D(torch.autograd.Function):
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
mu, sigma, weight = ctx.saved_tensors
with torch.no_grad():
weight_grad = output_grad * mu / sigma
if ctx.use_bias:
bias_grad = output_grad
weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
else:
bias_grad = None
weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
if ctx.use_bias:
bias_grad, weight_grad = weight_grad[0], weight_grad[1]
bias_grad, weight_grad = output_grad, output_grad * mu / sigma
bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
dz = output_grad * weight
dvar = dz * mu * (-0.5) * sigma**(-3)
@ -236,15 +340,22 @@ class _Layernorm3D(torch.autograd.Function):
dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
input_grad = dz / sigma + dvar * 2 * mu / \
ctx.normalized_shape + dmean / ctx.normalized_shape
input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
return input_grad, weight_grad, bias_grad, None, None, None, None, None
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
def layernorm_3d(
input_: Tensor,
weight: Tensor,
bias: Tensor,
normalized_shape: int,
eps: float,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_x_weight_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D parallel Layernorm.
Args:
@ -265,8 +376,19 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normali
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)
return _Layernorm3D.apply(
input_,
weight,
bias,
id(weight),
id(bias),
normalized_shape,
eps,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
input_x_weight_parallel_mode,
)
def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
@ -315,17 +437,12 @@ def split_batch_3d(input_: Tensor,
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
dim_size = input_.size(dim)
if input_.size(dim) <= 1:
return input_
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_world_size = gpc.get_world_size(weight_parallel_mode)
input_world_size = gpc.get_world_size(input_parallel_mode)
assert dim_size % (input_world_size*weight_world_size) == 0, \
f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).'
if input_.size(dim) <= 1:
return input_
output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
return output
@ -464,47 +581,3 @@ def reduce_by_batch_3d(tensor: Tensor,
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function):
r"""broadcast weight from diagonal.
Args:
input_ (:class:`torch.tensor`): input matrix.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
output = broadcast(input_, src_rank, input_parallel_mode)
ctx.src_rank = src_rank
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_grad = reduce(output_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
input_grad = all_reduce(input_grad, ctx.weight_parallel_mode)
else:
input_grad = None
return input_grad, None, None, None
def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
return _BroadcastWeight3D_FromDiagonal.apply(tensor, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)

View File

@ -1,8 +1,13 @@
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from collections import OrderedDict
from functools import partial
import torch
from torch import Tensor
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from torch import Tensor
def get_depth_from_env() -> int:
@ -17,30 +22,17 @@ def get_depth_from_env() -> int:
def get_parallel_mode_from_env(group):
assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D], \
assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \
f'{group} is not valid for 3D tensor parallelism.'
return getattr(env, group)
def get_last_group(a, b):
mapping = {
ParallelMode.PARALLEL_3D_INPUT: 'A',
ParallelMode.PARALLEL_3D_WEIGHT: 'B',
ParallelMode.PARALLEL_3D_OUTPUT: 'C',
}
res = chr(ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b]))
if res == 'A':
return ParallelMode.PARALLEL_3D_INPUT
elif res == 'B':
return ParallelMode.PARALLEL_3D_WEIGHT
elif res == 'C':
return ParallelMode.PARALLEL_3D_OUTPUT
def swap_in_out_group():
env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d
env.input_x_weight_group_3d, env.output_x_weight_group_3d = (
env.output_x_weight_group_3d,
env.input_x_weight_group_3d,
)
def dbg_check_shape(tensor: Tensor, shape: tuple):
@ -49,3 +41,60 @@ def dbg_check_shape(tensor: Tensor, shape: tuple):
print(tensor.shape)
assert tensor.shape == shape, \
'{} does not match {}'.format(tensor.shape, shape)
class AsyncGradientBucket(object):
def __init__(self):
self.bucket = OrderedDict()
def __len__(self):
return len(self.bucket)
def push(self, async_op, grad_tensor, param_id):
self.bucket[param_id] = tuple((async_op, grad_tensor))
return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device)
def pop(self, param_id):
grad = None
if param_id in self.bucket:
op, grad = self.bucket.pop(param_id)
if op is not None:
op.wait()
return grad
def synchronize(self, params):
for p in params:
i = id(p)
if i in self.bucket:
op, grad = self.bucket.pop(i)
if op is not None:
op.wait()
p.grad.add_(grad)
_async_grad_bucket = AsyncGradientBucket()
def push_async_grad(op, grad, param_id):
return _async_grad_bucket.push(op, grad, param_id)
def pop_async_grad(param_id):
return _async_grad_bucket.pop(param_id)
def _async_grad_hook(grad, param_id):
grad.add_(pop_async_grad(param_id))
return grad
def register_async_grad_hook(param):
param.register_hook(partial(_async_grad_hook, param_id=id(param)))
def synchronize(params=list()):
_async_grad_bucket.synchronize(params)
torch.cuda.default_stream().synchronize()
if len(_async_grad_bucket) > 0:
raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.")

View File

@ -6,7 +6,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.communication import all_reduce, broadcast
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
@ -20,9 +20,9 @@ from torch import Tensor
from torch.nn import Parameter
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d,
linear_3d, reduce_scatter_tensor_3d, split_tensor_3d)
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d,
reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d)
from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook
@LAYERS.register_module
@ -45,7 +45,8 @@ class LayerNorm3D(ParallelLayer):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
@ -58,6 +59,7 @@ class LayerNorm3D(ParallelLayer):
else:
self.bias = None
self.variance_epsilon = eps
self.reset_parameters()
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self) -> None:
@ -67,8 +69,10 @@ class LayerNorm3D(ParallelLayer):
def reset_parameters(self) -> None:
init.ones_()(self.weight)
register_async_grad_hook(self.weight)
if self.bias is not None:
init.zeros_()(self.bias)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@ -134,8 +138,17 @@ class LayerNorm3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
return layernorm_3d(
input_,
self.weight,
self.bias,
self.normalized_shape,
self.variance_epsilon,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
self.input_x_weight_parallel_mode,
)
@LAYERS.register_module
@ -161,6 +174,7 @@ class Linear3D(ParallelLayer):
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
@ -168,8 +182,10 @@ class Linear3D(ParallelLayer):
self.out_features = out_features
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.skip_bias_add = skip_bias_add
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(out_features, self.depth**2)
self.bias_features_per_partition = divide(out_features, self.depth)
@ -194,18 +210,23 @@ class Linear3D(ParallelLayer):
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
register_async_grad_hook(self.weight)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
self.bias.register_hook(self._sync_grad_hook)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@ -324,8 +345,20 @@ class Linear3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
output = linear_3d(
input_,
self.weight,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
@LAYERS.register_module
@ -360,7 +393,7 @@ class Classifier3D(ParallelLayer):
self.num_classes = num_classes
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
@ -386,19 +419,17 @@ class Classifier3D(ParallelLayer):
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.num_classes
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode)
register_async_grad_hook(self.weight)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@ -468,8 +499,14 @@ class Classifier3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
return classifier_3d(
input_,
self.weight,
self.bias,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
@LAYERS.register_module
@ -504,7 +541,8 @@ class VocabParallelClassifier3D(ParallelLayer):
self.num_classes = num_classes
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(num_classes, self.depth**2)
@ -544,12 +582,14 @@ class VocabParallelClassifier3D(ParallelLayer):
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
register_async_grad_hook(self.weight)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@ -668,8 +708,14 @@ class VocabParallelClassifier3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
return vocab_parallel_classifier_3d(
input_,
self.weight,
self.bias,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
@LAYERS.register_module
@ -708,12 +754,16 @@ class PatchEmbedding3D(ParallelLayer):
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.patch_size = to_2tuple(patch_size)
grid_size = to_2tuple(img_size // patch_size)
num_patches = grid_size[0] * grid_size[1]
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_size = embed_size
embed_size_per_partition = divide(embed_size, self.depth)
embed_size_per_partition = embed_size // self.depth
self.flatten = flatten
self.weight = nn.Parameter(
@ -725,7 +775,7 @@ class PatchEmbedding3D(ParallelLayer):
self.cls_token = nn.Parameter(
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter(
torch.zeros((1, num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
self._set_tensor_parallel_attributes()
@ -737,8 +787,7 @@ class PatchEmbedding3D(ParallelLayer):
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.input_parallel_mode)
grad = all_reduce(grad, self.weight_parallel_mode)
grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None:
@ -749,14 +798,10 @@ class PatchEmbedding3D(ParallelLayer):
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight, input_src_rank, self.input_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode)
src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0]
broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode)
broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode)
broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode)
self.weight.register_hook(self._sync_grad_hook)
self.bias.register_hook(self._sync_grad_hook)
@ -850,11 +895,12 @@ class PatchEmbedding3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
@ -906,7 +952,8 @@ class Embedding3D(ParallelLayer):
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
@ -924,13 +971,18 @@ class Embedding3D(ParallelLayer):
def _set_tensor_parallel_attributes(self) -> None:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight,
gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode)
self.weight.register_hook(self._sync_grad_hook)
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
@ -981,11 +1033,10 @@ class Embedding3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output
@ -1039,7 +1090,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2)
self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)

View File

@ -6,12 +6,12 @@ RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
# install apex
RUN git clone https://github.com/NVIDIA/apex && \
cd apex && \
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./
# install colossalai
RUN git clone https://github.com/hpcaitech/ColossalAI.git \
&& cd ./ColossalAI \
&& pip install -v --no-cache-dir .
&& cd ./ColossalAI \
&& pip install -v --no-cache-dir .
# install titans
RUN pip install --no-cache-dir titans

View File

@ -20,7 +20,6 @@ def check_linear():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
@ -32,12 +31,12 @@ def check_linear():
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True)
layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, bias=True)
layer = layer.to(device)
layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data.transpose(0, 1)
weight_master = layer_master.weight.data.transpose(0, 1).contiguous()
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
@ -49,7 +48,7 @@ def check_linear():
layer.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
@ -72,7 +71,7 @@ def check_linear():
logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@ -108,7 +107,6 @@ def check_layernorm():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@ -119,7 +117,7 @@ def check_layernorm():
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype)
norm = LayerNorm3D(INPUT_SIZE, eps=1e-6)
norm = norm.to(device)
norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6)
norm_master = norm_master.to(device)
@ -134,7 +132,7 @@ def check_layernorm():
norm.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
@ -159,7 +157,7 @@ def check_layernorm():
logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@ -193,7 +191,6 @@ def check_classifier_no_given_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@ -204,10 +201,10 @@ def check_classifier_no_given_weight():
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True)
layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, bias=True)
layer = layer.to(device)
layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True, dtype=dtype)
layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data
@ -219,7 +216,7 @@ def check_classifier_no_given_weight():
layer.bias.data.copy_(bias_master)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
@ -242,7 +239,7 @@ def check_classifier_no_given_weight():
logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
@ -283,7 +280,6 @@ def check_vocab_parallel_classifier_no_given_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@ -295,10 +291,10 @@ def check_vocab_parallel_classifier_no_given_weight():
k = global_context.get_local_rank(output_parallel_mode)
layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True)
layer = layer.to(dtype).to(device)
layer = layer.to(device)
layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(dtype).to(device)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
@ -312,7 +308,7 @@ def check_vocab_parallel_classifier_no_given_weight():
layer.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
@ -336,7 +332,7 @@ def check_vocab_parallel_classifier_no_given_weight():
logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@ -455,7 +451,6 @@ def check_vocab_parallel_classifier_given_embed_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@ -466,10 +461,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
k = global_context.get_local_rank(output_parallel_mode)
embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed = embed.to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
embed_master = embed_master.to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
@ -479,10 +474,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
embed.weight.data.copy_(weight)
layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer = layer.to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
layer_master = layer_master.to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
@ -504,7 +499,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@ -546,12 +541,12 @@ def check_patch_embed():
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE)
torch.nn.init.ones_(layer.cls_token)
torch.nn.init.ones_(layer.pos_embed)
layer = layer.to(device)
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE)
torch.nn.init.ones_(layer_master.cls_token)
torch.nn.init.ones_(layer_master.pos_embed)
layer_master = layer_master.to(device)
@ -566,7 +561,7 @@ def check_patch_embed():
layer.bias.data.copy_(proj_bias)
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
@ -586,7 +581,7 @@ def check_patch_embed():
logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@ -639,9 +634,9 @@ def check_embed():
k = global_context.get_local_rank(output_parallel_mode)
layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
layer = layer.to(dtype).to(device)
layer = layer.to(device)
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
layer_master = layer_master.to(dtype).to(device)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
@ -669,7 +664,7 @@ def check_embed():
logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@ -686,10 +681,7 @@ def check_embed():
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
if j == k:
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
else:
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, layer.weight.grad is None))
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
@ -709,9 +701,9 @@ def check_vocab_parallel_embed():
k = global_context.get_local_rank(output_parallel_mode)
layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
layer = layer.to(dtype).to(device)
layer = layer.to(device)
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
layer_master = layer_master.to(dtype).to(device)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
@ -741,7 +733,7 @@ def check_vocab_parallel_embed():
logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@ -771,7 +763,6 @@ def check_loss():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@ -783,8 +774,8 @@ def check_loss():
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
out_master = torch.randn(out_shape, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
@ -836,8 +827,8 @@ def check_vocab_parallel_loss():
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
out_master = torch.randn(out_shape, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]

View File

@ -12,8 +12,8 @@ NUM_BLOCKS = 2
IMG_SIZE = 16
VOCAB_SIZE = 16
def check_equal(A, B):
eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
assert eq
return eq
assert eq, f"\nA = {A}\nB = {B}"
return eq

View File

@ -10,9 +10,8 @@ from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus
from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
from checks_3d.check_layer_3d import (check_classifier_no_given_weight, check_embed, check_layernorm, check_linear,
check_loss, check_patch_embed, check_vocab_parallel_classifier_given_embed_weight,
check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
check_vocab_parallel_loss)
@ -30,7 +29,6 @@ def check_layer():
check_layernorm()
check_classifier_no_given_weight()
check_vocab_parallel_classifier_no_given_weight()
check_classifier_given_embed_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_embed()
check_patch_embed()