Fixed the padding index handling in the vocab-parallel embedding layers; updated the 3D linear layer to be compatible with the examples in the tutorial

pull/394/head
zbian 2022-02-17 22:03:39 +08:00 committed by Frank Lee
parent 24f8583cc4
commit 3dba070580
6 changed files with 50 additions and 40 deletions

View File

@@ -18,8 +18,8 @@ from torch.nn.parameter import Parameter
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition
-from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad,
-                     reduce_input, set_parallel_input, split_forward_gather_backward)
+from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
+                     split_forward_gather_backward)


 @LAYERS.register_module
@@ -551,9 +551,10 @@ class VocabParallelEmbedding1D(torch.nn.Module):
         self._fill_padding_idx_with_zero()

     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
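The fix in this hunk (and in the identical hunks for the 2D, 2.5D and 3D layers below) follows the usual vocab-parallel rule: each rank only stores rows [vocab_start_index, vocab_end_index) of the embedding table, so the padding row must be zeroed only on the rank that owns it, and at its local offset. A minimal self-contained sketch of that logic; only the attribute names mirror the layer above, the tensor sizes are made up:

```python
import torch

def fill_padding_idx_with_zero(weight: torch.Tensor, padding_idx: int,
                               vocab_start_index: int, vocab_end_index: int) -> None:
    """Zero the padding row only on the rank whose vocab shard contains it."""
    if padding_idx is not None and vocab_start_index <= padding_idx < vocab_end_index:
        with torch.no_grad():
            # The weight tensor is a local shard, so index with the local offset.
            weight[padding_idx - vocab_start_index].fill_(0)

# Example: a rank owning vocab rows [8, 16) zeroes row 10 at local index 2.
local_weight = torch.randn(8, 4)
fill_padding_idx_with_zero(local_weight, padding_idx=10,
                           vocab_start_index=8, vocab_end_index=16)
assert torch.all(local_weight[2] == 0)
```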

View File

@@ -415,9 +415,10 @@ class VocabParallelEmbedding2D(torch.nn.Module):
         self._fill_padding_idx_with_zero()

     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)

View File

@@ -430,9 +430,10 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
         self._fill_padding_idx_with_zero()

     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.

View File

@@ -5,7 +5,6 @@ from typing import Optional, Tuple

 import torch
 from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
-from colossalai.context import parallel_mode
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from torch import Tensor
@@ -13,8 +12,6 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from ._utils import get_parallel_mode_from_env
 from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
-from colossalai.nn.layer.base_layer import ParallelLayer


 class _Linear3D(torch.autograd.Function):
@@ -33,6 +30,7 @@ class _Linear3D(torch.autograd.Function):
         ctx.use_bias = bias is not None

         input_ = all_gather(input_, input_dim, input_parallel_mode)
+        weight = all_gather(weight, weight_dim, weight_parallel_mode)
         ctx.save_for_backward(input_, weight)

         output = torch.matmul(input_, weight)
@@ -64,7 +62,7 @@ class _Linear3D(torch.autograd.Function):
         weight_grad = torch.matmul(
             input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
-        weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
+        weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True)
         async_ops.append(op)

         if ctx.use_bias:
@@ -343,27 +341,29 @@ def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
     return _ReduceTensor3D.apply(tensor, parallel_mode)


-class _ReduceGrad3D(torch.autograd.Function):
+class _AllGatherTensor3D(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, input_, parallel_mode):
+    def forward(ctx, input_, dim, parallel_mode):
+        ctx.dim = dim
         ctx.parallel_mode = parallel_mode
-        return input_
+        output = all_gather(input_, dim, parallel_mode)
+        return output

     @staticmethod
     def backward(ctx, output_grad):
-        input_grad = all_reduce(output_grad, ctx.parallel_mode)
-        return input_grad, None
+        input_grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode)
+        return input_grad, None, None


-def reduce_grad_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
+def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     All-reduce the gradient in backward pass.

     :param tensor: Input tensor
     :param parallel_mode: Parallel mode
     """
-    return _ReduceGrad3D.apply(tensor, parallel_mode)
+    return _AllGatherTensor3D.apply(tensor, dim, parallel_mode)


 class _ReduceScatterTensor3D(torch.autograd.Function):
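The replacement of `_ReduceGrad3D` by `_AllGatherTensor3D` pairs an all-gather in the forward pass with a reduce-scatter of the gradient along the same dimension in the backward pass, which is the adjoint of an all-gather. A hedged, standalone sketch of that pairing using plain `torch.distributed` collectives on the default (already initialized) process group; the class name is made up, and the real code above uses `colossalai.communication` with an explicit parallel mode:

```python
import torch
import torch.distributed as dist


class AllGatherWithReduceScatterGrad(torch.autograd.Function):
    """All-gather along `dim` in forward; reduce-scatter the gradient in backward."""

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, dim: int) -> torch.Tensor:
        ctx.dim = dim
        world_size = dist.get_world_size()
        shards = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(shards, tensor.contiguous())
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        world_size = dist.get_world_size()
        # Split the full gradient back into per-rank shards along the gathered dim...
        shards = [s.contiguous() for s in grad_output.chunk(world_size, dim=ctx.dim)]
        grad_input = torch.empty_like(shards[dist.get_rank()])
        # ...and reduce-scatter so each rank keeps the summed gradient of its own shard.
        dist.reduce_scatter(grad_input, shards)
        return grad_input, None
```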

View File

@@ -100,7 +100,8 @@ class Linear3D(ParallelLayer):
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
         self.depth = get_depth_from_env()
         self.in_features_per_partition = divide(in_features, self.depth)
-        self.out_features_per_partition = divide(out_features, self.depth)
+        self.out_features_per_partition = divide(out_features, self.depth**2)
+        self.bias_features_per_partition = divide(out_features, self.depth)

         self.weight = Parameter(
             torch.empty(self.in_features_per_partition,
@@ -108,8 +109,8 @@ class Linear3D(ParallelLayer):
                         device=get_current_device(),
                         dtype=dtype))
         if bias:
-            self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
-                                              dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
         else:
             self.bias = None
@@ -118,21 +119,20 @@ class Linear3D(ParallelLayer):
         swap_in_out_group()

     def _set_tensor_parallel_attributes(self) -> None:
-        set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
+        set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3)
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, self.depth)

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         with seed(ParallelMode.TENSOR):
             fan_in, fan_out = self.in_features, self.out_features
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]

             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-            broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)

             if self.bias is not None:
                 bias_initializer(self.bias, fan_in=fan_in)
+                weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
+                output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
                 broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
                 broadcast(self.bias, output_src_rank, self.output_parallel_mode)
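The new partition sizes reflect the 3D layout: with cube depth q there are q³ ranks, each holding a weight block of shape (in_features/q, out_features/q²), while the bias is only split q ways, hence the separate `bias_features_per_partition` and the `depth**3` passed to `set_tensor_parallel_attribute_by_partition`. A quick sanity check of that arithmetic with made-up sizes:

```python
# Hypothetical sizes: a 2x2x2 mesh (depth q = 2), i.e. 8 ranks in total.
depth = 2
in_features, out_features = 1024, 4096

in_per_partition = in_features // depth            # 512
out_per_partition = out_features // depth ** 2     # 1024
bias_per_partition = out_features // depth         # 2048

# Each rank's weight block covers 1/q**3 of the full weight matrix.
full_weight_elems = in_features * out_features
local_weight_elems = in_per_partition * out_per_partition
assert local_weight_elems * depth ** 3 == full_weight_elems

# The bias is only split q ways, matching set_tensor_parallel_attribute_by_partition(bias, depth).
assert bias_per_partition * depth == out_features
```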
@@ -257,7 +257,8 @@ class VocabParallelClassifier3D(ParallelLayer):
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
         self.depth = get_depth_from_env()
         self.in_features_per_partition = divide(in_features, self.depth)
-        self.out_features_per_partition = divide(num_classes, self.depth)
+        self.out_features_per_partition = divide(num_classes, self.depth**2)
+        self.bias_features_per_partition = divide(num_classes, self.depth)

         if weight is not None:
             self.weight = weight
@@ -270,8 +271,8 @@ class VocabParallelClassifier3D(ParallelLayer):
                             dtype=dtype))
             self.has_weight = True
         if bias:
-            self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
-                                              dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
         else:
             self.bias = None
@@ -289,15 +290,14 @@ class VocabParallelClassifier3D(ParallelLayer):
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         with seed(ParallelMode.TENSOR):
             fan_in, fan_out = self.in_features, self.num_classes
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]

             if self.has_weight:
                 weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-                broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)

             if self.bias is not None:
                 bias_initializer(self.bias, fan_in=fan_in)
+                weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
+                output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
                 broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
                 broadcast(self.bias, output_src_rank, self.output_parallel_mode)
@@ -523,11 +523,11 @@ class VocabParallelEmbedding3D(torch.nn.Module):
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
         self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
-        self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth)
+        self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2)
         self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
         vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)
-        self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition
-        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
+        self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition * self.depth
+        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth

         self.weight = Parameter(
             torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
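The embedding table is now split over depth² partitions along the vocabulary axis, while each input-parallel rank remains responsible for a contiguous vocab slice of depth * num_embeddings_per_partition rows: its own shard plus the shards it gathers from the weight-parallel group in `forward`. A small worked example with made-up sizes:

```python
# Hypothetical example: a vocabulary of 64 tokens on a depth-2 (2x2x2) mesh.
num_embeddings, depth = 64, 2

num_embeddings_per_partition = num_embeddings // depth ** 2   # 16 rows stored locally
for vocab_parallel_rank in range(depth):                      # rank in the input-parallel group
    start = vocab_parallel_rank * num_embeddings_per_partition * depth
    end = start + num_embeddings_per_partition * depth
    print(f"input-parallel rank {vocab_parallel_rank}: vocab range [{start}, {end})")
# input-parallel rank 0: vocab range [0, 32)
# input-parallel rank 1: vocab range [32, 64)
# Each range spans depth * num_embeddings_per_partition rows; the local shard is the
# 1/depth of that range that gets re-assembled by the all_gather in forward().
```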
@@ -546,13 +546,12 @@ class VocabParallelEmbedding3D(torch.nn.Module):
             fan_in, fan_out = self.num_embeddings, self.embed_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)

     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
         input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
@@ -561,7 +560,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
         masked_input = input_.clone() - self.vocab_start_index
         masked_input[input_mask] = 0

-        weight = reduce_grad_3d(self.weight, self.weight_parallel_mode)
+        weight = all_gather_tensor_3d(self.weight, 0, self.weight_parallel_mode)

         output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
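For context, the surrounding forward pass uses the standard vocab-parallel lookup recipe: out-of-range token ids are redirected to a valid local row, the lookup runs against the (gathered) local weight, and the rows for out-of-range tokens are zeroed so that a later reduction across the vocab-parallel group sums exactly one contribution per token. A self-contained sketch of that masking step; the function name is illustrative and the cross-rank reduce is omitted:

```python
import torch
import torch.nn.functional as F

def masked_local_lookup(input_ids: torch.Tensor, local_weight: torch.Tensor,
                        vocab_start_index: int, vocab_end_index: int) -> torch.Tensor:
    """Local part of a vocab-parallel embedding lookup (before the cross-rank reduce)."""
    # Tokens outside this rank's vocab range are redirected to local row 0...
    input_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
    masked_input = input_ids.clone() - vocab_start_index
    masked_input[input_mask] = 0
    output = F.embedding(masked_input, local_weight)
    # ...and their rows are zeroed, so an all-reduce over the vocab-parallel group
    # would sum exactly one non-zero contribution per token.
    output[input_mask, :] = 0.0
    return output

# Example: this rank owns vocabulary rows [8, 16).
ids = torch.tensor([3, 9, 15, 20])
out = masked_local_lookup(ids, torch.randn(8, 4), 8, 16)
assert torch.all(out[0] == 0) and torch.all(out[3] == 0)
```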

View File

@@ -41,6 +41,7 @@ def check_linear():
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
     weight = torch.chunk(weight, DEPTH, dim=-1)[j]
+    weight = torch.chunk(weight, DEPTH, dim=-1)[i]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data
     torch.distributed.broadcast(bias_master, src=0)
@@ -93,6 +94,7 @@ def check_linear():
     B_grad = layer_master.weight.grad.transpose(0, 1)
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
     logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))

     bias_grad = layer_master.bias.grad
@@ -301,6 +303,7 @@ def check_vocab_parallel_classifier_no_given_weight():
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data
@@ -358,6 +361,7 @@ def check_vocab_parallel_classifier_no_given_weight():
     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format(
         rank, check_equal(B_grad, layer.weight.grad)))
@@ -470,6 +474,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
     weight_master = embed_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     embed.weight.data.copy_(weight)
@@ -518,6 +523,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
     B_grad = embed_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
                                                                                  check_equal(B_grad,
@@ -710,6 +716,7 @@ def check_vocab_parallel_embed():
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     layer.weight.data.copy_(weight)
@@ -751,6 +758,7 @@ def check_vocab_parallel_embed():
     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
                                                                                  check_equal(B_grad,
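All of the test changes follow the same pattern: since the weight now lives on depth³ partitions, the reference ("master") tensor needs one extra `torch.chunk` before it can be compared with a rank's local shard. A hedged sketch of that mapping for the linear weight, with the (i, j, k) indices mirroring the test's loop variables and the chunk order taken from the hunks above; the sizes are made up:

```python
import torch

DEPTH = 2  # example cube depth

def local_linear_weight(weight_master: torch.Tensor, i: int, j: int, k: int) -> torch.Tensor:
    """Slice the full weight down to the shard held by rank (i, j, k) of the 3D mesh."""
    weight = torch.chunk(weight_master, DEPTH, dim=0)[k]    # input-feature split
    weight = torch.chunk(weight, DEPTH, dim=-1)[j]          # first output-feature split
    weight = torch.chunk(weight, DEPTH, dim=-1)[i]          # extra split added in this commit
    return weight

weight_master = torch.randn(1024, 4096)
shard = local_linear_weight(weight_master, i=0, j=1, k=1)
assert shard.shape == (1024 // DEPTH, 4096 // DEPTH ** 2)
```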