diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py
index daf54c126..2c1314f19 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/nn/layer/parallel_1d/layers.py
@@ -18,8 +18,8 @@ from torch.nn.parameter import Parameter
 
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition
-from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad,
-                     reduce_input, set_parallel_input, split_forward_gather_backward)
+from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
+                     split_forward_gather_backward)
 
 
 @LAYERS.register_module
@@ -551,9 +551,10 @@ class VocabParallelEmbedding1D(torch.nn.Module):
         self._fill_padding_idx_with_zero()
 
     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py
index b6adbcecd..4b7dfb529 100644
--- a/colossalai/nn/layer/parallel_2d/layers.py
+++ b/colossalai/nn/layer/parallel_2d/layers.py
@@ -415,9 +415,10 @@ class VocabParallelEmbedding2D(torch.nn.Module):
         self._fill_padding_idx_with_zero()
 
     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
     def forward(self, input_: Tensor) -> Tensor:
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py
index 7dd17f21b..a803f331d 100644
--- a/colossalai/nn/layer/parallel_2p5d/layers.py
+++ b/colossalai/nn/layer/parallel_2p5d/layers.py
@@ -430,9 +430,10 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
         self._fill_padding_idx_with_zero()
 
     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
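Note on the three hunks above: the 1D, 2D, and 2.5D vocab-parallel embeddings now zero the padding row only on the rank whose vocabulary shard actually contains padding_idx, and they index the local weight with the offset padding_idx - vocab_start_index rather than the global index. A minimal self-contained sketch of that check follows; the shard bounds, embedding size, and padding index are illustrative values, not taken from the patch.

    import torch

    # Illustrative shard of a vocab-parallel embedding table: this rank owns
    # global rows [vocab_start_index, vocab_end_index) of the full vocabulary.
    vocab_start_index, vocab_end_index = 128, 256
    embed_dim = 64
    weight = torch.randn(vocab_end_index - vocab_start_index, embed_dim)
    padding_idx = 130  # global index of the padding token

    # Only the owning rank touches its shard, and it must use the local row
    # offset, mirroring the patched _fill_padding_idx_with_zero methods.
    if padding_idx is not None and vocab_start_index <= padding_idx < vocab_end_index:
        with torch.no_grad():
            weight[padding_idx - vocab_start_index].fill_(0)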
diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py
index 26e30d8cf..6ad442788 100644
--- a/colossalai/nn/layer/parallel_3d/_operation.py
+++ b/colossalai/nn/layer/parallel_3d/_operation.py
@@ -5,7 +5,6 @@ from typing import Optional, Tuple
 
 import torch
 from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
-from colossalai.context import parallel_mode
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from torch import Tensor
@@ -13,8 +12,6 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 
 from ._utils import get_parallel_mode_from_env
 from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
 
-from colossalai.nn.layer.base_layer import ParallelLayer
-
 
 class _Linear3D(torch.autograd.Function):
@@ -33,6 +30,7 @@ class _Linear3D(torch.autograd.Function):
         ctx.use_bias = bias is not None
 
         input_ = all_gather(input_, input_dim, input_parallel_mode)
+        weight = all_gather(weight, weight_dim, weight_parallel_mode)
         ctx.save_for_backward(input_, weight)
 
         output = torch.matmul(input_, weight)
@@ -64,7 +62,7 @@
             weight_grad = torch.matmul(
                 input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
                 output_grad.reshape(-1, output_grad.shape[-1]))
-            weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
+            weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True)
             async_ops.append(op)
 
             if ctx.use_bias:
@@ -343,27 +341,29 @@ def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
     return _ReduceTensor3D.apply(tensor, parallel_mode)
 
 
-class _ReduceGrad3D(torch.autograd.Function):
+class _AllGatherTensor3D(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input_, parallel_mode):
+    def forward(ctx, input_, dim, parallel_mode):
+        ctx.dim = dim
         ctx.parallel_mode = parallel_mode
-        return input_
+        output = all_gather(input_, dim, parallel_mode)
+        return output
 
     @staticmethod
     def backward(ctx, output_grad):
-        input_grad = all_reduce(output_grad, ctx.parallel_mode)
-        return input_grad, None
+        input_grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode)
+        return input_grad, None, None
 
 
-def reduce_grad_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
+def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     All-reduce the gradient in backward pass.
 
     :param tensor: Input tensor
     :param parallel_mode: Parallel mode
     """
-    return _ReduceGrad3D.apply(tensor, parallel_mode)
+    return _AllGatherTensor3D.apply(tensor, dim, parallel_mode)
 
 
 class _ReduceScatterTensor3D(torch.autograd.Function):
diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py
index da8a50995..5164bc69a 100644
--- a/colossalai/nn/layer/parallel_3d/layers.py
+++ b/colossalai/nn/layer/parallel_3d/layers.py
@@ -100,7 +100,8 @@ class Linear3D(ParallelLayer):
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
         self.depth = get_depth_from_env()
         self.in_features_per_partition = divide(in_features, self.depth)
-        self.out_features_per_partition = divide(out_features, self.depth)
+        self.out_features_per_partition = divide(out_features, self.depth**2)
+        self.bias_features_per_partition = divide(out_features, self.depth)
 
         self.weight = Parameter(
             torch.empty(self.in_features_per_partition,
@@ -108,8 +109,8 @@
                         device=get_current_device(),
                         dtype=dtype))
         if bias:
-            self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
-                                              dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
         else:
             self.bias = None
 
@@ -118,21 +119,20 @@
         swap_in_out_group()
 
     def _set_tensor_parallel_attributes(self) -> None:
-        set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
+        set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3)
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         with seed(ParallelMode.TENSOR):
             fan_in, fan_out = self.in_features, self.out_features
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
 
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-            broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
 
             if self.bias is not None:
                 bias_initializer(self.bias, fan_in=fan_in)
+                weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
+                output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
                 broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
                 broadcast(self.bias, output_src_rank, self.output_parallel_mode)
 
@@ -257,7 +257,8 @@ class VocabParallelClassifier3D(ParallelLayer):
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
         self.depth = get_depth_from_env()
         self.in_features_per_partition = divide(in_features, self.depth)
-        self.out_features_per_partition = divide(num_classes, self.depth)
+        self.out_features_per_partition = divide(num_classes, self.depth**2)
+        self.bias_features_per_partition = divide(num_classes, self.depth)
 
         if weight is not None:
             self.weight = weight
@@ -270,8 +271,8 @@
                             dtype=dtype))
             self.has_weight = True
         if bias:
-            self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
-                                              dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
         else:
             self.bias = None
 
@@ -289,15 +290,14 @@
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         with seed(ParallelMode.TENSOR):
             fan_in, fan_out = self.in_features, self.num_classes
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
 
             if self.has_weight:
                 weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-                broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
 
             if self.bias is not None:
                 bias_initializer(self.bias, fan_in=fan_in)
+                weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
+                output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
                 broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
                 broadcast(self.bias, output_src_rank, self.output_parallel_mode)
@@ -523,11 +523,11 @@ class VocabParallelEmbedding3D(torch.nn.Module):
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
         self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
-        self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth)
+        self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2)
         self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
         vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)
-        self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition
-        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
+        self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition * self.depth
+        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth
 
         self.weight = Parameter(
             torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
@@ -546,13 +546,12 @@
             fan_in, fan_out = self.num_embeddings, self.embed_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
 
     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
     def forward(self, input_: Tensor) -> Tensor:
         input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
@@ -561,7 +560,7 @@
         masked_input = input_.clone() - self.vocab_start_index
         masked_input[input_mask] = 0
 
-        weight = reduce_grad_3d(self.weight, self.weight_parallel_mode)
+        weight = all_gather_tensor_3d(self.weight, 0, self.weight_parallel_mode)
 
         output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
 
diff --git a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py
index 087bb0781..b3cfa60bd 100644
--- a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py
+++ b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py
@@ -41,6 +41,7 @@ def check_linear():
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
     weight = torch.chunk(weight, DEPTH, dim=-1)[j]
+    weight = torch.chunk(weight, DEPTH, dim=-1)[i]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data
     torch.distributed.broadcast(bias_master, src=0)
@@ -93,6 +94,7 @@ def check_linear():
     B_grad = layer_master.weight.grad.transpose(0, 1)
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
     logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
 
     bias_grad = layer_master.bias.grad
@@ -301,6 +303,7 @@ def check_vocab_parallel_classifier_no_given_weight():
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data
@@ -358,6 +361,7 @@ def check_vocab_parallel_classifier_no_given_weight():
 
     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format(
         rank, check_equal(B_grad, layer.weight.grad)))
@@ -470,6 +474,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
     weight_master = embed_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     embed.weight.data.copy_(weight)
 
@@ -518,6 +523,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
 
     B_grad = embed_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, check_equal(B_grad,
@@ -710,6 +716,7 @@ def check_vocab_parallel_embed():
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     layer.weight.data.copy_(weight)
 
@@ -751,6 +758,7 @@ def check_vocab_parallel_embed():
 
     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, check_equal(B_grad,
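The 3D changes pair an all-gather of the sharded weight in the forward pass with a reduce-scatter of its gradient in the backward pass (_AllGatherTensor3D, and the weight path of _Linear3D), so both the parameter and its gradient stay sharded across the weight-parallel group. Below is a generic sketch of that forward/backward pairing written against plain torch.distributed; it only illustrates the pattern and is not ColossalAI's implementation, which goes through the colossalai.communication wrappers and its ParallelMode process groups.

    import torch
    import torch.distributed as dist

    class AllGatherWithReduceScatterGrad(torch.autograd.Function):
        # Forward: all-gather a sharded tensor along `dim`.
        # Backward: reduce-scatter the gradient so each rank keeps the summed
        # slice that corresponds to its original shard.

        @staticmethod
        def forward(ctx, tensor, dim, group):
            ctx.dim, ctx.group = dim, group
            world_size = dist.get_world_size(group)
            shards = [torch.empty_like(tensor) for _ in range(world_size)]
            dist.all_gather(shards, tensor.contiguous(), group=group)
            return torch.cat(shards, dim=dim)

        @staticmethod
        def backward(ctx, grad_output):
            world_size = dist.get_world_size(ctx.group)
            shards = [s.contiguous() for s in grad_output.chunk(world_size, dim=ctx.dim)]
            grad_input = torch.empty_like(shards[0])
            dist.reduce_scatter(grad_input, shards, group=ctx.group)
            # One gradient slot per forward argument (tensor, dim, group).
            return grad_input, None, None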