mirror of https://github.com/hpcaitech/ColossalAI
fixed the padding index issue for vocab parallel embedding layers; updated the 3D linear layer to be compatible with the examples in the tutorial
parent 24f8583cc4
commit 3dba070580
@@ -18,8 +18,8 @@ from torch.nn.parameter import Parameter
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition
-from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad,
-                     reduce_input, set_parallel_input, split_forward_gather_backward)
+from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
+                     split_forward_gather_backward)


 @LAYERS.register_module
@@ -551,9 +551,10 @@ class VocabParallelEmbedding1D(torch.nn.Module):
         self._fill_padding_idx_with_zero()

     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
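The hunk above only zeroes the padding row when the global padding index actually falls inside this rank's vocabulary shard, and it indexes the local weight with the offset padding_idx - vocab_start_index instead of the global index (the same fix is applied to the 2D, 2.5D and 3D variants below). A minimal single-process sketch of that bookkeeping, with made-up shard bounds and sizes:

import torch

# Hypothetical shard of a vocab-parallel embedding: this rank owns rows
# [vocab_start_index, vocab_end_index) of the global embedding table.
vocab_start_index, vocab_end_index = 128, 192
weight = torch.randn(vocab_end_index - vocab_start_index, 16)

padding_idx = 130  # global id of the padding token

# Zero the padding row only if it lives on this shard, using the local offset.
if padding_idx is not None and vocab_start_index <= padding_idx < vocab_end_index:
    with torch.no_grad():
        weight[padding_idx - vocab_start_index].fill_(0)

assert torch.all(weight[padding_idx - vocab_start_index] == 0)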
@@ -415,9 +415,10 @@ class VocabParallelEmbedding2D(torch.nn.Module):
         self._fill_padding_idx_with_zero()

     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
@@ -430,9 +430,10 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
         self._fill_padding_idx_with_zero()

     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
@@ -5,7 +5,6 @@ from typing import Optional, Tuple

 import torch
 from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
-from colossalai.context import parallel_mode
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from torch import Tensor
@@ -13,8 +12,6 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from ._utils import get_parallel_mode_from_env
 from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
-
-from colossalai.nn.layer.base_layer import ParallelLayer


 class _Linear3D(torch.autograd.Function):
@@ -33,6 +30,7 @@ class _Linear3D(torch.autograd.Function):
         ctx.use_bias = bias is not None

         input_ = all_gather(input_, input_dim, input_parallel_mode)
+        weight = all_gather(weight, weight_dim, weight_parallel_mode)
         ctx.save_for_backward(input_, weight)

         output = torch.matmul(input_, weight)
@@ -64,7 +62,7 @@ class _Linear3D(torch.autograd.Function):

         weight_grad = torch.matmul(
             input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
-        weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
+        weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True)
         async_ops.append(op)

         if ctx.use_bias:
@@ -343,27 +341,29 @@ def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
     return _ReduceTensor3D.apply(tensor, parallel_mode)


-class _ReduceGrad3D(torch.autograd.Function):
+class _AllGatherTensor3D(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, input_, parallel_mode):
+    def forward(ctx, input_, dim, parallel_mode):
+        ctx.dim = dim
         ctx.parallel_mode = parallel_mode
-        return input_
+        output = all_gather(input_, dim, parallel_mode)
+        return output

     @staticmethod
     def backward(ctx, output_grad):
-        input_grad = all_reduce(output_grad, ctx.parallel_mode)
-        return input_grad, None
+        input_grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode)
+        return input_grad, None, None


-def reduce_grad_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
+def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     All-reduce the gradient in backward pass.

     :param tensor: Input tensor
     :param parallel_mode: Parallel mode
     """
-    return _ReduceGrad3D.apply(tensor, parallel_mode)
+    return _AllGatherTensor3D.apply(tensor, dim, parallel_mode)


 class _ReduceScatterTensor3D(torch.autograd.Function):
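The new _AllGatherTensor3D replaces _ReduceGrad3D: the forward pass now all-gathers the sharded tensor along dim, and the backward pass reduce-scatters the incoming gradient along the same dim, because the two collectives are adjoint to each other. A single-process sketch of that adjointness, simulating ranks with a plain list of shards (names and sizes are illustrative only):

import torch

# Four "ranks" each hold one shard of a vector along dim 0.
world_size, shard_len = 4, 3
shards = [torch.randn(shard_len, requires_grad=True) for _ in range(world_size)]

# all_gather: every rank ends up with the concatenation of all shards.
gathered = torch.cat(shards, dim=0)

# Each rank applies its own upstream gradient to the gathered tensor.
upstream = [torch.randn_like(gathered) for _ in range(world_size)]
loss = sum((g * gathered).sum() for g in upstream)
loss.backward()

# reduce_scatter of the upstream gradients: sum over ranks, then keep own slice.
summed = torch.stack(upstream).sum(dim=0)
for rank, shard in enumerate(shards):
    expected = summed[rank * shard_len:(rank + 1) * shard_len]
    assert torch.allclose(shard.grad, expected)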
@@ -100,7 +100,8 @@ class Linear3D(ParallelLayer):
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
         self.depth = get_depth_from_env()
         self.in_features_per_partition = divide(in_features, self.depth)
-        self.out_features_per_partition = divide(out_features, self.depth)
+        self.out_features_per_partition = divide(out_features, self.depth**2)
+        self.bias_features_per_partition = divide(out_features, self.depth)

         self.weight = Parameter(
             torch.empty(self.in_features_per_partition,
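With this change the Linear3D weight is sharded by depth along the input dimension and by depth**2 along the output dimension (depth**3 partitions in total, matching the _set_tensor_parallel_attributes update below), while the bias is only sharded by depth. A quick shape check; divide here is a stand-in for colossalai's helper and the sizes are made up:

def divide(numerator: int, denominator: int) -> int:
    # Stand-in for colossalai's divide(): exact integer division with a check.
    assert numerator % denominator == 0, f'{numerator} not divisible by {denominator}'
    return numerator // denominator

depth = 2                       # cube root of the tensor-parallel world size (8 GPUs)
in_features, out_features = 1024, 4096

in_per_partition = divide(in_features, depth)         # 512
out_per_partition = divide(out_features, depth**2)    # 1024
bias_per_partition = divide(out_features, depth)      # 2048

# Each rank holds a (512, 1024) weight shard and a 2048-element bias shard.
print((in_per_partition, out_per_partition), bias_per_partition)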
@@ -108,8 +109,8 @@ class Linear3D(ParallelLayer):
                         device=get_current_device(),
                         dtype=dtype))
         if bias:
-            self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
-                                              dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
         else:
             self.bias = None
@@ -118,21 +119,20 @@ class Linear3D(ParallelLayer):
         swap_in_out_group()

     def _set_tensor_parallel_attributes(self) -> None:
-        set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
+        set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3)
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, self.depth)

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         with seed(ParallelMode.TENSOR):
             fan_in, fan_out = self.in_features, self.out_features
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]

             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-            broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)

             if self.bias is not None:
                 bias_initializer(self.bias, fan_in=fan_in)
+                weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
+                output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
                 broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
                 broadcast(self.bias, output_src_rank, self.output_parallel_mode)
@@ -257,7 +257,8 @@ class VocabParallelClassifier3D(ParallelLayer):
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
         self.depth = get_depth_from_env()
         self.in_features_per_partition = divide(in_features, self.depth)
-        self.out_features_per_partition = divide(num_classes, self.depth)
+        self.out_features_per_partition = divide(num_classes, self.depth**2)
+        self.bias_features_per_partition = divide(num_classes, self.depth)

         if weight is not None:
             self.weight = weight
@@ -270,8 +271,8 @@ class VocabParallelClassifier3D(ParallelLayer):
                             dtype=dtype))
             self.has_weight = True
         if bias:
-            self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
-                                              dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
         else:
             self.bias = None
@@ -289,15 +290,14 @@ class VocabParallelClassifier3D(ParallelLayer):
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         with seed(ParallelMode.TENSOR):
             fan_in, fan_out = self.in_features, self.num_classes
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]

             if self.has_weight:
                 weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-                broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)

             if self.bias is not None:
                 bias_initializer(self.bias, fan_in=fan_in)
+                weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
+                output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
                 broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
                 broadcast(self.bias, output_src_rank, self.output_parallel_mode)
@@ -523,11 +523,11 @@ class VocabParallelEmbedding3D(torch.nn.Module):
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
         self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
-        self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth)
+        self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2)
         self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
         vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)
-        self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition
-        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
+        self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition * self.depth
+        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth

         self.weight = Parameter(
             torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
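Because the embedding table is now split depth**2 ways along the vocabulary dimension, each input-parallel rank is responsible for a block of num_embeddings_per_partition * depth consecutive vocabulary ids, and that block is further sharded across the weight-parallel group (which is why the forward pass below switches to all_gather_tensor_3d on the weight). A small numeric sketch of the index bookkeeping, with illustrative sizes:

depth = 2
num_embeddings = 32
num_embeddings_per_partition = num_embeddings // depth**2   # 8 rows stored per rank

for vocab_parallel_rank in range(depth):
    vocab_start_index = vocab_parallel_rank * num_embeddings_per_partition * depth
    vocab_end_index = vocab_start_index + num_embeddings_per_partition * depth
    print(vocab_parallel_rank, (vocab_start_index, vocab_end_index))
# rank 0 covers ids [0, 16), rank 1 covers ids [16, 32); inside each block the
# rows are split across the weight-parallel group.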
@@ -546,13 +546,12 @@ class VocabParallelEmbedding3D(torch.nn.Module):
             fan_in, fan_out = self.num_embeddings, self.embed_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)

     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
         input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
@@ -561,7 +560,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
         masked_input = input_.clone() - self.vocab_start_index
         masked_input[input_mask] = 0

-        weight = reduce_grad_3d(self.weight, self.weight_parallel_mode)
+        weight = all_gather_tensor_3d(self.weight, 0, self.weight_parallel_mode)

         output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@@ -41,6 +41,7 @@ def check_linear():
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
     weight = torch.chunk(weight, DEPTH, dim=-1)[j]
+    weight = torch.chunk(weight, DEPTH, dim=-1)[i]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data
     torch.distributed.broadcast(bias_master, src=0)
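The test now chunks the master weight a second time along the output dimension so that the local copy matches the new depth**2 output sharding of Linear3D; the remaining hunks apply the same extra chunk to the other 3D checks. The slicing can be sanity-checked shape-wise on a single process (sizes and the i, j, k cube coordinates below are made up):

import torch

DEPTH = 2
in_features, out_features = 8, 16
weight_master = torch.randn(in_features, out_features)

i, j, k = 1, 0, 1   # hypothetical coordinates of this rank in the DEPTH^3 cube
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[i]

# (in_features / DEPTH, out_features / DEPTH**2) == (4, 4)
assert weight.shape == (in_features // DEPTH, out_features // DEPTH**2)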
@@ -93,6 +94,7 @@ def check_linear():
     B_grad = layer_master.weight.grad.transpose(0, 1)
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
     logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))

     bias_grad = layer_master.bias.grad

@@ -301,6 +303,7 @@ def check_vocab_parallel_classifier_no_given_weight():
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data

@@ -358,6 +361,7 @@ def check_vocab_parallel_classifier_no_given_weight():

     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format(
         rank, check_equal(B_grad, layer.weight.grad)))

@@ -470,6 +474,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
     weight_master = embed_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     embed.weight.data.copy_(weight)

@@ -518,6 +523,7 @@ def check_vocab_parallel_classifier_given_embed_weight():

     B_grad = embed_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
                                                                                  check_equal(B_grad,

@@ -710,6 +716,7 @@ def check_vocab_parallel_embed():
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     layer.weight.data.copy_(weight)

@@ -751,6 +758,7 @@ def check_vocab_parallel_embed():

     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
                                                                                  check_equal(B_grad,