mirror of https://github.com/hpcaitech/ColossalAI

updated tp layers

parent cb5a587e9a
commit 0b8161fab8
@@ -23,6 +23,8 @@ INITIALIZER_MAPPING = {

INPUT_GROUP_3D = 'input_group_3d'
WEIGHT_GROUP_3D = 'weight_group_3d'
OUTPUT_GROUP_3D = 'output_group_3d'
INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d'
OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d'

# Attributes of tensor parallel parameters
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
@@ -39,6 +39,8 @@ class ParallelMode(Enum):

PARALLEL_3D_INPUT = '3d_input'
PARALLEL_3D_WEIGHT = '3d_weight'
PARALLEL_3D_OUTPUT = '3d_output'
PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight"
PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight"

# 2.5D parallel
PARALLEL_2P5D_ROW = '2p5d_row'
@@ -176,6 +176,112 @@ class Initializer_3D_Output(ProcessGroupInitializer):

return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_3D_InputxWeight(ProcessGroupInitializer):
"""3D tensor parallel initialization among input and weight.

Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""

def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth

def init_dist_group(self):
"""Initialize 3D tensor parallel groups among input and weight, and assign local_ranks and groups to each gpu.

Returns:
Tuple (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode):
3D tensor parallelism's information among input and weight in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT_X_WEIGHT
env.input_x_weight_group_3d = mode

for h in range(self.num_group):
for k in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)
for i in range(self.depth)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group

if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks

return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
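For intuition, a minimal sketch (not part of the diff) of what the rank comprehension above produces, assuming a single tensor group (h = 0) and depth = 2, i.e. one 2x2x2 cube of 8 ranks:

# Hypothetical illustration only: enumerate the input-x-weight groups for depth = 2.
# With ranks laid out as h * depth**3 + i + depth * (j + depth * k), fixing k
# collects the depth**2 ranks that share the same output index.
depth, h = 2, 0
for k in range(depth):
    ranks = [
        h * depth**3 + i + depth * (j + depth * k)
        for j in range(depth)
        for i in range(depth)
    ]
    print(f"k={k}: ranks={ranks}")
# k=0: ranks=[0, 1, 2, 3]
# k=1: ranks=[4, 5, 6, 7]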
class Initializer_3D_OutputxWeight(ProcessGroupInitializer):
"""3D tensor parallel initialization among output and weight.

Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""

def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
self.depth = depth

def init_dist_group(self):
"""Initialize 3D tensor parallel groups among output and weight, and assign local_ranks and groups to each gpu.

Returns:
Tuple (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode):
3D tensor parallelism's information among output and weight in a tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT_X_WEIGHT
env.output_x_weight_group_3d = mode

for h in range(self.num_group):
for j in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)
for i in range(self.depth)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group

if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks

return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


@DIST_GROUP_INITIALIZER.register_module
class Initializer_3D(ProcessGroupInitializer):
"""Serve as the single entry point to 3D parallel initialization.

@@ -200,6 +306,8 @@ class Initializer_3D(ProcessGroupInitializer):

self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args)
self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args)
self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args)
self.input_x_weight_initializer = Initializer_3D_InputxWeight(self.num_group, self.depth, *args)
self.output_x_weight_initializer = Initializer_3D_OutputxWeight(self.num_group, self.depth, *args)

def init_dist_group(self):
"""Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.

@@ -211,6 +319,8 @@ class Initializer_3D(ProcessGroupInitializer):

parallel_setting = [
self.input_initializer.init_dist_group(),
self.weight_initializer.init_dist_group(),
self.output_initializer.init_dist_group()
self.output_initializer.init_dist_group(),
self.input_x_weight_initializer.init_dist_group(),
self.output_x_weight_initializer.init_dist_group()
]
return parallel_setting
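As a hedged usage sketch (not part of the diff): enabling 3D tensor parallelism, which makes Initializer_3D build all five groups above, is typically done through the launch configuration. The exact keys below follow the project's documented `parallel` config and are an assumption; the tensor-parallel size must be a perfect cube (here 8 GPUs, depth 2).

# Hypothetical config.py passed to colossalai.launch
parallel = dict(
    data=1,
    pipeline=1,
    tensor=dict(size=8, mode='3d'),
)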
@@ -22,7 +22,9 @@ class TensorParallelEnv(object):

depth_3d: int = None,
input_group_3d=None,
weight_group_3d=None,
output_group_3d=None):
output_group_3d=None,
input_x_weight_group_3d=None,
output_x_weight_group_3d=None):
self.mode = mode
self.vocab_parallel = vocab_parallel
self.parallel_input_1d = parallel_input_1d

@@ -33,6 +35,8 @@ class TensorParallelEnv(object):

self.input_group_3d = input_group_3d
self.weight_group_3d = weight_group_3d
self.output_group_3d = output_group_3d
self.input_x_weight_group_3d = input_x_weight_group_3d
self.output_x_weight_group_3d = output_x_weight_group_3d

def save(self):
return dict(mode=self.mode,

@@ -44,7 +48,9 @@ class TensorParallelEnv(object):

depth_3d=self.depth_3d,
input_group_3d=self.input_group_3d,
weight_group_3d=self.weight_group_3d,
output_group_3d=self.output_group_3d)
output_group_3d=self.output_group_3d,
input_x_weight_group_3d=self.input_x_weight_group_3d,
output_x_weight_group_3d=self.output_x_weight_group_3d)


tensor_parallel_env = TensorParallelEnv()
@@ -1,4 +1,6 @@

import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc

try:
import fused_mix_prec_layer_norm_cuda

@@ -43,3 +45,52 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):

weight_, bias_, ctx.eps)

return grad_input, grad_weight, grad_bias, None, None


class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""

@staticmethod
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.parallel_mode = parallel_mode
ctx.async_grad_allreduce = async_grad_allreduce

output = torch.matmul(input_, weight.t())
if bias is not None:
output = output + bias
return output

@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias

total_input = input
grad_input = grad_output.matmul(weight)

# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])

if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_allreduce:
handle.wait()

return grad_input, grad_weight, grad_bias, None, None, None


def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
    return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
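A minimal single-process sketch (not part of the diff, no ColossalAI context needed) of the overlap pattern used in backward(): launch the all-reduce of grad_input asynchronously, compute grad_weight while it is in flight, then wait. A world-size-1 gloo group is used only so the snippet runs anywhere.

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)

grad_output = torch.randn(4, 8)      # (tokens, out_features)
total_input = torch.randn(4, 16)     # (tokens, in_features)
weight = torch.randn(8, 16)          # (out_features, in_features)

grad_input = grad_output.matmul(weight)               # result that needs reducing
handle = dist.all_reduce(grad_input, async_op=True)   # start communication
grad_weight = grad_output.t().matmul(total_input)     # overlap computation
handle.wait()                                         # grad_input now reduced

dist.destroy_process_group()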
@@ -20,12 +20,12 @@ from colossalai.utils.cuda import get_current_device

from torch import Tensor
from torch.nn.parameter import Parameter
from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm

from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
split_forward_gather_backward)
from ._operation import linear_with_async_comm


@LAYERS.register_module

@@ -96,8 +96,25 @@ class LayerNorm1D(ColossalaiModule):

dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""

_fast_ln_supported_sizes = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536
]

def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
from apex.normalization import FusedLayerNorm

fast_ln_installed = False
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
fast_ln_installed = True
except ImportError:
pass

if fast_ln_installed and normalized_shape in self._fast_ln_supported_sizes:
norm = FastLayerNorm(normalized_shape, eps=eps).to(dtype)
else:
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
super().__init__(norm)

def _load_from_state_dict(self, state_dict, prefix, *args):
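A tiny sketch (not part of the diff) of the selection logic above, with availability passed in as a flag so it runs without apex installed; the returned strings are illustrative labels, not real class names:

def pick_norm(normalized_shape, fast_ln_installed, supported=(1024, 2048, 4096, 8192)):
    if fast_ln_installed and normalized_shape in supported:
        return "apex.contrib FastLayerNorm"
    return "apex FusedLayerNorm"

print(pick_norm(4096, True))     # fast path
print(pick_norm(4097, True))     # unsupported size -> FusedLayerNorm fallback
print(pick_norm(4096, False))    # fast LN not built -> FusedLayerNorm fallback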
@@ -519,11 +536,12 @@ class Linear1D_Col(ParallelLayer):

'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply.

bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
# output_parallel = F.linear(input_parallel, self.weight, bias)
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)

@@ -665,6 +683,7 @@ class Linear1D_Row(ParallelLayer):

input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)

output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)

if not self.skip_bias_add:
@@ -9,7 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode

from colossalai.core import global_context as gpc
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from ._utils import get_parallel_mode_from_env
from ._utils import get_parallel_mode_from_env, push_async_grad
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D

@@ -17,34 +17,27 @@ class _Linear3D(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
ctx.use_bias = bias is not None

input_ = all_gather(input_, input_dim, input_parallel_mode)
weight = all_gather(weight, weight_dim, weight_parallel_mode)
ctx.save_for_backward(input_, weight)

output = torch.matmul(input_, weight)
output = reduce_scatter(output, output_dim, output_parallel_mode)

if bias is not None:
output += bias

def forward(
ctx,
input_: Tensor,
weight: Tensor,
weight_id: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.weight_id = weight_id
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
ctx.input_dim = input_dim
ctx.weight_dim = weight_dim
ctx.output_dim = output_dim

input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight, -1, weight_parallel_mode)
ctx.save_for_backward(input_, weight)

output = torch.matmul(input_, weight)
output = reduce_scatter(output, 0, output_parallel_mode)

return output

@staticmethod

@@ -52,73 +45,70 @@ class _Linear3D(torch.autograd.Function):

def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
output_grad = all_gather(output_grad, ctx.output_dim, ctx.output_parallel_mode)

async_ops = list()
output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)

input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, op = reduce_scatter(input_grad, ctx.input_dim, ctx.input_parallel_mode, async_op=True)
async_ops.append(op)
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)

weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
else:
bias_grad = None
input_op.wait()

for op in async_ops:
if op is not None:
op.wait()

return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
return input_grad, weight_grad, None, None, None, None


def linear_3d(input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
def linear_3d(
input_: Tensor,
weight: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""Linear layer for 3D parallelism.

Args:
input_ (:class:`torch.tensor`): input matrix.
weight (:class:`torch.tensor`): matrix of weight.
bias (:class:`torch.tensor`): matrix of bias.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
input_dim (int, optional): dimension of input, defaults to 0.
weight_dim (int, optional): dimension of weight, defaults to -1.
output_dim (int, optional): dimension of output, defaults to 0.

Note:
The parallel_mode should be one of the modes defined in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode,
input_dim, weight_dim, output_dim)
return _Linear3D.apply(
input_,
weight,
id(weight),
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
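For intuition, a single-process sketch (not part of the diff, plain torch, no process groups) of the algebraic identity the forward relies on: after the two all-gathers, each rank still holds only a slice of the hidden (contraction) dimension, so its local matmul is a partial sum, and the reduce_scatter over the output group restores the full product. The sizes below are hypothetical.

import torch

q = 2                            # assumed cube depth
X = torch.randn(4, 6)            # full (batch, hidden)
W = torch.randn(6, 8)            # full (hidden, out)

x_slices = torch.chunk(X, q, dim=-1)   # hidden dim split across the output group
w_slices = torch.chunk(W, q, dim=0)

partial = [x_slices[k] @ w_slices[k] for k in range(q)]    # per-rank local matmuls
assert torch.allclose(sum(partial), X @ W, atol=1e-5)      # the reduce restores X @ W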
class _Classifier3D(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
weight_id: int,
bias_id: Optional[int],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.use_bias = bias is not None
ctx.weight_id = weight_id

ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)]
weight = broadcast(weight, src_rank, input_parallel_mode)
ctx.save_for_backward(input_, weight)

@@ -126,6 +116,7 @@ class _Classifier3D(torch.autograd.Function):

output = all_reduce(output, output_parallel_mode)

if bias is not None:
ctx.bias_id = bias_id
output += bias

ctx.src_rank = src_rank

@@ -139,14 +130,12 @@ class _Classifier3D(torch.autograd.Function):

def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
async_ops = list()

weight_grad = torch.matmul(
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
else:
weight_grad = None

@@ -154,21 +143,23 @@ class _Classifier3D(torch.autograd.Function):

bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
else:
bias_grad = None

input_grad = torch.matmul(output_grad, weight)

for op in async_ops:
if op is not None:
op.wait()

return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
return input_grad, weight_grad, bias_grad, None, None, None, None, None


def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
def classifier_3d(
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D parallel classifier.

Args:

@@ -183,16 +174,134 @@ def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_

The parallel_mode should be one of the modes defined in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode)
return _Classifier3D.apply(
input_,
weight,
bias,
id(weight),
id(bias) if bias is not None else None,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _VocabParallelClassifier3D(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
weight_id: int,
bias_id: Optional[int],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.use_bias = bias is not None
ctx.weight_id = weight_id

input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode)
ctx.save_for_backward(input_, weight)

output = torch.matmul(input_, weight)
output = reduce_scatter(output, 0, output_parallel_mode)

if bias is not None:
ctx.bias_id = bias_id
output += bias

ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output

@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)

input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)

weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
else:
bias_grad = None

input_op.wait()

return input_grad, weight_grad, bias_grad, None, None, None, None, None


def vocab_parallel_classifier_3d(
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D vocab parallel classifier.

Args:
input_ (:class:`torch.tensor`): input matrix.
weight (:class:`torch.tensor`): matrix of weight.
bias (:class:`torch.tensor`): matrix of bias.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.

Note:
The parallel_mode should be one of the modes defined in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _VocabParallelClassifier3D.apply(
input_,
weight,
bias,
id(weight),
id(bias) if bias is not None else None,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _Layernorm3D(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Tensor,
weight_id: int,
bias_id: int,
normalized_shape: int,
eps: float,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_x_weight_parallel_mode: ParallelMode,
) -> Tensor:
ctx.weight_id = weight_id
ctx.bias_id = bias_id

mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
mu = input_ - mean
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape

@@ -201,15 +310,13 @@ class _Layernorm3D(torch.autograd.Function):

ctx.save_for_backward(mu, sigma, weight)

z = mu / sigma
output = weight * z
if bias is not None:
output = output + bias
output = weight * z + bias

ctx.use_bias = bias is not None
ctx.normalized_shape = normalized_shape
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode

return output

@@ -218,17 +325,14 @@ class _Layernorm3D(torch.autograd.Function):

def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
mu, sigma, weight = ctx.saved_tensors
with torch.no_grad():
weight_grad = output_grad * mu / sigma
if ctx.use_bias:
bias_grad = output_grad
weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
else:
bias_grad = None
weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
if ctx.use_bias:
bias_grad, weight_grad = weight_grad[0], weight_grad[1]

bias_grad, weight_grad = output_grad, output_grad * mu / sigma
bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

dz = output_grad * weight
dvar = dz * mu * (-0.5) * sigma**(-3)

@@ -236,15 +340,22 @@ class _Layernorm3D(torch.autograd.Function):

dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)

input_grad = dz / sigma + dvar * 2 * mu / \
ctx.normalized_shape + dmean / ctx.normalized_shape
input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape

return input_grad, weight_grad, bias_grad, None, None, None, None, None
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None


def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
def layernorm_3d(
input_: Tensor,
weight: Tensor,
bias: Tensor,
normalized_shape: int,
eps: float,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_x_weight_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D parallel Layernorm.

Args:

@@ -265,8 +376,19 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normali

The parallel_mode should be one of the modes defined in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)
return _Layernorm3D.apply(
input_,
weight,
bias,
id(weight),
id(bias),
normalized_shape,
eps,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
input_x_weight_parallel_mode,
)
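A single-process check (not part of the diff, plain torch, no parallel groups) that the arithmetic in the forward above — per-row mean, variance, then weight * (x - mean) / sigma + bias — matches torch's reference layer norm. The definition sigma = sqrt(var + eps) is an assumption here, since that line falls outside the hunk shown.

import torch
import torch.nn.functional as F

H = 16
x = torch.randn(4, H)
weight, bias, eps = torch.randn(H), torch.randn(H), 1e-5

mean = x.sum(dim=-1, keepdim=True) / H          # what the all_reduce of partial sums yields
mu = x - mean
var = (mu ** 2).sum(dim=-1, keepdim=True) / H
sigma = torch.sqrt(var + eps)                   # assumed, standard definition
out = weight * (mu / sigma) + bias

ref = F.layer_norm(x, (H,), weight, bias, eps)
assert torch.allclose(out, ref, atol=1e-5)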

def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:

@@ -315,17 +437,12 @@ def split_batch_3d(input_: Tensor,

The parallel_mode should be one of the modes defined in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
dim_size = input_.size(dim)
if input_.size(dim) <= 1:
return input_
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_world_size = gpc.get_world_size(weight_parallel_mode)
input_world_size = gpc.get_world_size(input_parallel_mode)

assert dim_size % (input_world_size*weight_world_size) == 0, \
f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).'

if input_.size(dim) <= 1:
return input_
output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
return output
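A single-process sketch (not part of the diff, plain torch, local ranks passed in as hypothetical ints) of the two chunkings above: the batch is first split across the weight group, then the chosen piece is split again across the input group.

import torch

depth = 2                                  # assumed weight/input world size
batch = torch.arange(8).reshape(8, 1)      # batch size 8, divisible by depth**2
weight_rank, input_rank = 1, 0             # hypothetical local ranks

piece = torch.chunk(batch, depth, dim=0)[weight_rank].contiguous()
piece = torch.chunk(piece, depth, dim=0)[input_rank].contiguous()
print(piece.flatten().tolist())            # [4, 5] -- this rank's share of the batch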

@@ -464,47 +581,3 @@ def reduce_by_batch_3d(tensor: Tensor,

in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)


class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function):
r"""Broadcast weight from diagonal.

Args:
input_ (:class:`torch.tensor`): input matrix.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.

Note:
The parallel_mode should be one of the modes defined in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""

@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
output = broadcast(input_, src_rank, input_parallel_mode)
ctx.src_rank = src_rank
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output

@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_grad = reduce(output_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
input_grad = all_reduce(input_grad, ctx.weight_parallel_mode)
else:
input_grad = None
return input_grad, None, None, None


def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
return _BroadcastWeight3D_FromDiagonal.apply(tensor, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)
@@ -1,8 +1,13 @@

from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from collections import OrderedDict
from functools import partial

import torch
from torch import Tensor

from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from torch import Tensor


def get_depth_from_env() -> int:

@@ -17,30 +22,17 @@ def get_depth_from_env() -> int:


def get_parallel_mode_from_env(group):
assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D], \
assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \
f'{group} is not valid for 3D tensor parallelism.'
return getattr(env, group)


def get_last_group(a, b):
mapping = {
ParallelMode.PARALLEL_3D_INPUT: 'A',
ParallelMode.PARALLEL_3D_WEIGHT: 'B',
ParallelMode.PARALLEL_3D_OUTPUT: 'C',
}

res = chr(ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b]))

if res == 'A':
return ParallelMode.PARALLEL_3D_INPUT
elif res == 'B':
return ParallelMode.PARALLEL_3D_WEIGHT
elif res == 'C':
return ParallelMode.PARALLEL_3D_OUTPUT


def swap_in_out_group():
env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d
env.input_x_weight_group_3d, env.output_x_weight_group_3d = (
env.output_x_weight_group_3d,
env.input_x_weight_group_3d,
)


def dbg_check_shape(tensor: Tensor, shape: tuple):

@@ -49,3 +41,60 @@ def dbg_check_shape(tensor: Tensor, shape: tuple):

print(tensor.shape)
assert tensor.shape == shape, \
'{} does not match {}'.format(tensor.shape, shape)


class AsyncGradientBucket(object):

def __init__(self):
self.bucket = OrderedDict()

def __len__(self):
return len(self.bucket)

def push(self, async_op, grad_tensor, param_id):
self.bucket[param_id] = tuple((async_op, grad_tensor))
return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device)

def pop(self, param_id):
grad = None
if param_id in self.bucket:
op, grad = self.bucket.pop(param_id)
if op is not None:
op.wait()
return grad

def synchronize(self, params):
for p in params:
i = id(p)
if i in self.bucket:
op, grad = self.bucket.pop(i)
if op is not None:
op.wait()
p.grad.add_(grad)


_async_grad_bucket = AsyncGradientBucket()


def push_async_grad(op, grad, param_id):
return _async_grad_bucket.push(op, grad, param_id)


def pop_async_grad(param_id):
return _async_grad_bucket.pop(param_id)


def _async_grad_hook(grad, param_id):
grad.add_(pop_async_grad(param_id))
return grad


def register_async_grad_hook(param):
param.register_hook(partial(_async_grad_hook, param_id=id(param)))


def synchronize(params=list()):
_async_grad_bucket.synchronize(params)
torch.cuda.default_stream().synchronize()
if len(_async_grad_bucket) > 0:
raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.")
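A single-process sketch (not part of the diff) of how a parameter's gradient travels through this bucket: the autograd Function parks the reduced gradient under id(param) and hands zeros back to autograd, and the hook registered on the parameter adds the parked gradient back when backward reaches it. No collective is launched here, so the async op is simply None; the import path is an assumption and may differ by version.

import torch
from colossalai.nn.layer.parallel_3d._utils import push_async_grad, register_async_grad_hook  # assumed path

param = torch.nn.Parameter(torch.ones(3))
register_async_grad_hook(param)                   # installs _async_grad_hook for this param

produced_grad = torch.full((3,), 2.0)             # pretend this came from a reduce_scatter
returned = push_async_grad(None, produced_grad, id(param))
assert torch.equal(returned, torch.zeros(3))      # the Function returns zeros to autograd

loss = (param * 0).sum()                          # autograd delivers a zero gradient...
loss.backward()
assert torch.equal(param.grad, produced_grad)     # ...and the hook adds the parked one back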
@@ -6,7 +6,7 @@ import torch

import torch.nn as nn
import torch.nn.functional as F
from colossalai.communication import all_reduce, broadcast
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env

@@ -20,9 +20,9 @@ from torch import Tensor

from torch.nn import Parameter

from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d,
linear_3d, reduce_scatter_tensor_3d, split_tensor_3d)
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d,
reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d)
from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook


@LAYERS.register_module

@@ -45,7 +45,8 @@ class LayerNorm3D(ParallelLayer):

super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)

@@ -58,6 +59,7 @@ class LayerNorm3D(ParallelLayer):

else:
self.bias = None
self.variance_epsilon = eps
self.reset_parameters()
self._set_tensor_parallel_attributes()

def _set_tensor_parallel_attributes(self) -> None:

@@ -67,8 +69,10 @@ class LayerNorm3D(ParallelLayer):

def reset_parameters(self) -> None:
init.ones_()(self.weight)
register_async_grad_hook(self.weight)
if self.bias is not None:
init.zeros_()(self.bias)
register_async_grad_hook(self.bias)

def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@@ -134,8 +138,17 @@ class LayerNorm3D(ParallelLayer):

destination.update(local_state)

def forward(self, input_: Tensor) -> Tensor:
return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
return layernorm_3d(
input_,
self.weight,
self.bias,
self.normalized_shape,
self.variance_epsilon,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
self.input_x_weight_parallel_mode,
)


@LAYERS.register_module

@@ -161,6 +174,7 @@ class Linear3D(ParallelLayer):

out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()

@@ -168,8 +182,10 @@ class Linear3D(ParallelLayer):

self.out_features = out_features
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.skip_bias_add = skip_bias_add
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(out_features, self.depth**2)
self.bias_features_per_partition = divide(out_features, self.depth)

@@ -194,18 +210,23 @@ class Linear3D(ParallelLayer):

if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)

def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode)
return grad

def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.out_features

weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
register_async_grad_hook(self.weight)

if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
self.bias.register_hook(self._sync_grad_hook)

def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@@ -324,8 +345,20 @@ class Linear3D(ParallelLayer):

destination.update(local_state)

def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
output = linear_3d(
input_,
self.weight,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)

if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
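A minimal single-process sketch (not part of the diff, plain torch, no parallel context) of what skip_bias_add enables: the matmul result and the bias are returned separately so the caller can fuse the addition into a later op instead of applying it inside the layer.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
weight = torch.randn(16, 8)
bias = torch.randn(16)

out_no_bias = F.linear(x, weight)          # what forward() returns as `output`
fused = F.gelu(out_no_bias + bias)         # caller applies the bias where convenient

assert torch.allclose(fused, F.gelu(F.linear(x, weight, bias)), atol=1e-6)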
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
|
@ -360,7 +393,7 @@ class Classifier3D(ParallelLayer):
|
|||
self.num_classes = num_classes
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
|
||||
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
self.depth = get_depth_from_env()
|
||||
self.in_features_per_partition = divide(in_features, self.depth)
|
||||
|
||||
|
@ -386,19 +419,17 @@ class Classifier3D(ParallelLayer):
|
|||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, fan_out = self.in_features, self.num_classes
|
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
|
||||
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
|
||||
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
|
||||
|
||||
if self.has_weight:
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode)
|
||||
|
||||
register_async_grad_hook(self.weight)
|
||||
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
||||
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
|
||||
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR)
|
||||
register_async_grad_hook(self.bias)
|
||||
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
|
@ -468,8 +499,14 @@ class Classifier3D(ParallelLayer):
|
|||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
|
||||
self.output_parallel_mode)
|
||||
return classifier_3d(
|
||||
input_,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
self.output_parallel_mode,
|
||||
)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
|
@ -504,7 +541,8 @@ class VocabParallelClassifier3D(ParallelLayer):
|
|||
self.num_classes = num_classes
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
|
||||
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
|
||||
self.depth = get_depth_from_env()
|
||||
self.in_features_per_partition = divide(in_features, self.depth)
|
||||
self.out_features_per_partition = divide(num_classes, self.depth**2)
|
||||
|
@ -544,12 +582,14 @@ class VocabParallelClassifier3D(ParallelLayer):
|
|||
if self.has_weight:
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
|
||||
register_async_grad_hook(self.weight)
|
||||
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
|
||||
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
|
||||
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
||||
broadcast(self.bias,
|
||||
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
|
||||
self.output_x_weight_parallel_mode)
|
||||
register_async_grad_hook(self.bias)
|
||||
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
|
@ -668,8 +708,14 @@ class VocabParallelClassifier3D(ParallelLayer):
|
|||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
|
||||
self.weight_parallel_mode, self.output_parallel_mode)
|
||||
return vocab_parallel_classifier_3d(
|
||||
input_,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
self.output_parallel_mode,
|
||||
)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
|
@ -708,12 +754,16 @@ class PatchEmbedding3D(ParallelLayer):
|
|||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
|
||||
self.patch_size = to_2tuple(patch_size)
|
||||
grid_size = to_2tuple(img_size // patch_size)
|
||||
num_patches = grid_size[0] * grid_size[1]
|
||||
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.embed_size = embed_size
|
||||
embed_size_per_partition = divide(embed_size, self.depth)
|
||||
embed_size_per_partition = embed_size // self.depth
|
||||
self.flatten = flatten
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
|
@ -725,7 +775,7 @@ class PatchEmbedding3D(ParallelLayer):
|
|||
self.cls_token = nn.Parameter(
|
||||
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros((1, num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
|
||||
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
@ -737,8 +787,7 @@ class PatchEmbedding3D(ParallelLayer):
|
|||
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth)
|
||||
|
||||
def _sync_grad_hook(self, grad) -> Tensor:
|
||||
grad = all_reduce(grad.clone(), self.input_parallel_mode)
|
||||
grad = all_reduce(grad, self.weight_parallel_mode)
|
||||
grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
|
||||
return grad
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None:
|
||||
|
@ -749,14 +798,10 @@ class PatchEmbedding3D(ParallelLayer):
|
|||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
position_embed_initializer(self.pos_embed)
|
||||
|
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
|
||||
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
|
||||
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.weight, input_src_rank, self.input_parallel_mode)
|
||||
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
|
||||
broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode)
|
||||
src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0]
|
||||
broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode)
|
||||
broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode)
|
||||
broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode)
|
||||
|
||||
self.weight.register_hook(self._sync_grad_hook)
|
||||
self.bias.register_hook(self._sync_grad_hook)
|
||||
|
@ -850,11 +895,12 @@ class PatchEmbedding3D(ParallelLayer):
|
|||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
|
||||
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
|
||||
input_ = split_batch_3d(input_,
|
||||
input_parallel_mode=self.input_parallel_mode,
|
||||
weight_parallel_mode=self.weight_parallel_mode)
|
||||
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
|
||||
if self.flatten:
|
||||
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
|
||||
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
|
||||
output = torch.cat((cls_token, output), dim=1)
|
||||
|
@ -906,7 +952,8 @@ class Embedding3D(ParallelLayer):
|
|||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
|
||||
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
|
||||
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embed_dim = embedding_dim
|
||||
|
@@ -924,13 +971,18 @@ class Embedding3D(ParallelLayer):
def _set_tensor_parallel_attributes(self) -> None:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)

def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
return grad

def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight,
gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode)
self.weight.register_hook(self._sync_grad_hook)

def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
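The `_sync_grad_hook` above exists because the embedding weight is replicated across the input-x-weight group while each rank only sees its own batch shard: the full weight gradient is the sum of the per-shard gradients. A self-contained, single-process toy check of that identity (shapes and sizes are made up for illustration):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4, requires_grad=True)
full_batch = torch.randint(0, 10, (8,))

# gradient from the full, unsharded batch
F.embedding(full_batch, weight).sum().backward()
full_grad = weight.grad.clone()

# gradients from two batch shards, accumulated (what the all_reduce reproduces)
weight.grad = None
for shard in (full_batch[:4], full_batch[4:]):
    F.embedding(shard, weight).sum().backward()

assert torch.allclose(weight.grad, full_grad)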
@@ -981,11 +1033,10 @@ class Embedding3D(ParallelLayer):
destination.update(local_state)

def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

return output
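The trade-off in this hunk: the old forward paid for a `broadcast_weight_3d_from_diagonal` on every call, while the new layer replicates the weight once at initialization over the input-x-weight group and pays only the gradient all-reduce in backward, so the forward reduces to a plain lookup on batch-sharded indices. A hedged restatement of the resulting data path (`embed_forward_sketch` is a hypothetical free function, not library code):

def embed_forward_sketch(layer, token_ids):
    # shard only the indices; the local weight replica is used as-is
    token_ids = split_batch_3d(token_ids,
                               input_parallel_mode=layer.input_parallel_mode,
                               weight_parallel_mode=layer.weight_parallel_mode)
    return F.embedding(token_ids, layer.weight, layer.padding_idx,
                       *layer.embed_args, **layer.embed_kwargs)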
@@ -1039,7 +1090,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2)
self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)
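For reference, the partition sizes above follow directly from the cube depth: with depth d the tensor-parallel cube has d^3 ranks, the vocabulary is split d^2 ways and the embedding dimension d ways. A small worked example with assumed sizes (depth 2, i.e. 8 GPUs; `divide` is exact integer division):

depth = 2                    # 2 x 2 x 2 cube -> 8 tensor-parallel ranks
num_embeddings = 50304       # hypothetical vocabulary size
embed_dim = 1024             # hypothetical hidden size

num_embeddings_per_partition = num_embeddings // depth**2    # 12576 rows per rank
embed_dim_per_partition = embed_dim // depth                 # 512 columns per rank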
@@ -6,12 +6,12 @@ RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
# install apex
RUN git clone https://github.com/NVIDIA/apex && \
cd apex && \
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./

# install colossalai
RUN git clone https://github.com/hpcaitech/ColossalAI.git \
&& cd ./ColossalAI \
&& pip install -v --no-cache-dir .
&& cd ./ColossalAI \
&& pip install -v --no-cache-dir .

# install titans
RUN pip install --no-cache-dir titans
@@ -20,7 +20,6 @@ def check_linear():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE

@@ -32,12 +31,12 @@ def check_linear():
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)

layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True)
layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, bias=True)
layer = layer.to(device)
layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE)
layer_master = layer_master.to(device)

weight_master = layer_master.weight.data.transpose(0, 1)
weight_master = layer_master.weight.data.transpose(0, 1).contiguous()
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
weight = torch.chunk(weight, DEPTH, dim=-1)[j]

@@ -49,7 +48,7 @@ def check_linear():
layer.bias.data.copy_(bias)

A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]

@@ -72,7 +71,7 @@ def check_linear():
logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C)))

grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
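These test hunks all follow the same pattern, and the change running through them is that explicit `dtype=torch.float32` arguments are dropped so the layers and reference tensors use the default dtype. The recurring check itself: build a reference tensor once, broadcast it so every rank agrees, take this rank's shard, and compare against the matching shard of the reference output. A hedged sketch of that sharding step (`make_shard` is a hypothetical helper, not part of the test suite):

import torch

def make_shard(master, depth, dims, ranks):
    # successively chunk the reference tensor along each parallel axis and
    # keep the piece owned by this rank
    shard = master
    for dim, r in zip(dims, ranks):
        shard = torch.chunk(shard, depth, dim=dim)[r]
    return shard

# e.g. the input shard used above: A = make_shard(A_master, DEPTH, dims=(0, -1), ranks=(i, k))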
@@ -108,7 +107,6 @@ def check_layernorm():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE

input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)

@@ -119,7 +117,7 @@ def check_layernorm():
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)

norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype)
norm = LayerNorm3D(INPUT_SIZE, eps=1e-6)
norm = norm.to(device)
norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6)
norm_master = norm_master.to(device)

@@ -134,7 +132,7 @@ def check_layernorm():
norm.bias.data.copy_(bias)

A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]

@@ -159,7 +157,7 @@ def check_layernorm():
logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C)))

grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -193,7 +191,6 @@ def check_classifier_no_given_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE

input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)

@@ -204,10 +201,10 @@ def check_classifier_no_given_weight():
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)

layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True)
layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, bias=True)
layer = layer.to(device)

layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True, dtype=dtype)
layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True)
layer_master = layer_master.to(device)

weight_master = layer_master.weight.data

@@ -219,7 +216,7 @@ def check_classifier_no_given_weight():
layer.bias.data.copy_(bias_master)

A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]

@@ -242,7 +239,7 @@ def check_classifier_no_given_weight():
logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))

grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
@@ -283,7 +280,6 @@ def check_vocab_parallel_classifier_no_given_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE

input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)

@@ -295,10 +291,10 @@ def check_vocab_parallel_classifier_no_given_weight():
k = global_context.get_local_rank(output_parallel_mode)

layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True)
layer = layer.to(dtype).to(device)
layer = layer.to(device)

layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(dtype).to(device)
layer_master = layer_master.to(device)

weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)

@@ -312,7 +308,7 @@ def check_vocab_parallel_classifier_no_given_weight():
layer.bias.data.copy_(bias)

A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]

@@ -336,7 +332,7 @@ def check_vocab_parallel_classifier_no_given_weight():
logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))

grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -455,7 +451,6 @@ def check_vocab_parallel_classifier_given_embed_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32

input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)

@@ -466,10 +461,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
k = global_context.get_local_rank(output_parallel_mode)

embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed = embed.to(device)

embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
embed_master = embed_master.to(device)

weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)

@@ -479,10 +474,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
embed.weight.data.copy_(weight)

layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer = layer.to(device)

layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
layer_master = layer_master.to(device)

A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)

@@ -504,7 +499,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))

grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -546,12 +541,12 @@ def check_patch_embed():
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)

layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE)
torch.nn.init.ones_(layer.cls_token)
torch.nn.init.ones_(layer.pos_embed)
layer = layer.to(device)

layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE)
torch.nn.init.ones_(layer_master.cls_token)
torch.nn.init.ones_(layer_master.pos_embed)
layer_master = layer_master.to(device)

@@ -566,7 +561,7 @@ def check_patch_embed():
layer.bias.data.copy_(proj_bias)

A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()

@@ -586,7 +581,7 @@ def check_patch_embed():
logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C)))

grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -639,9 +634,9 @@ def check_embed():
k = global_context.get_local_rank(output_parallel_mode)

layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
layer = layer.to(dtype).to(device)
layer = layer.to(device)
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
layer_master = layer_master.to(dtype).to(device)
layer_master = layer_master.to(device)

weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)

@@ -669,7 +664,7 @@ def check_embed():
logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))

grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]

@@ -686,10 +681,7 @@ def check_embed():

B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
if j == k:
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
else:
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, layer.weight.grad is None))
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))

return fwd_end - fwd_start, bwd_end - bwd_start
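The backward check above no longer special-cases `j == k`: since `Embedding3D` now all-reduces its weight gradient over the input-x-weight group via `_sync_grad_hook`, every rank ends up holding the gradient for its embedding-dimension shard, so the comparison can run unconditionally. A hedged restatement of the invariant the test relies on (names taken from the test, tolerances from `check_equal`):

# after backward, every rank's replica gradient should match the reference shard
expected = torch.chunk(layer_master.weight.grad, DEPTH, dim=-1)[k]
assert torch.allclose(expected, layer.weight.grad, rtol=1e-3, atol=1e-2)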
@@ -709,9 +701,9 @@ def check_vocab_parallel_embed():
k = global_context.get_local_rank(output_parallel_mode)

layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
layer = layer.to(dtype).to(device)
layer = layer.to(device)
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
layer_master = layer_master.to(dtype).to(device)
layer_master = layer_master.to(device)

weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)

@@ -741,7 +733,7 @@ def check_vocab_parallel_embed():
logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C)))

grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -771,7 +763,6 @@ def check_loss():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32

input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)

@@ -783,8 +774,8 @@ def check_loss():
criterion_master = torch.nn.CrossEntropyLoss()

out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
out_master = torch.randn(out_shape, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]

@@ -836,8 +827,8 @@ def check_vocab_parallel_loss():
criterion_master = torch.nn.CrossEntropyLoss()

out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
out_master = torch.randn(out_shape, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
@@ -12,8 +12,8 @@ NUM_BLOCKS = 2
IMG_SIZE = 16
VOCAB_SIZE = 16


def check_equal(A, B):
eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
assert eq
return eq
assert eq, f"\nA = {A}\nB = {B}"
return eq
@@ -10,9 +10,8 @@ from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus
from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
from checks_3d.check_layer_3d import (check_classifier_no_given_weight, check_embed, check_layernorm, check_linear,
check_loss, check_patch_embed, check_vocab_parallel_classifier_given_embed_weight,
check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
check_vocab_parallel_loss)

@@ -30,7 +29,6 @@ def check_layer():
check_layernorm()
check_classifier_no_given_weight()
check_vocab_parallel_classifier_no_given_weight()
check_classifier_given_embed_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_embed()
check_patch_embed()