mirror of https://github.com/hpcaitech/ColossalAI

fixed padding index issue for vocab parallel embedding layers; updated 3D linear to be compatible with examples in the tutorial

parent 24f8583cc4
commit 3dba070580
@@ -18,8 +18,8 @@ from torch.nn.parameter import Parameter
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition
-from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad,
-                     reduce_input, set_parallel_input, split_forward_gather_backward)
+from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
+                     split_forward_gather_backward)
 
 
 @LAYERS.register_module
@@ -551,9 +551,10 @@ class VocabParallelEmbedding1D(torch.nn.Module):
         self._fill_padding_idx_with_zero()
 
     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
@@ -415,9 +415,10 @@ class VocabParallelEmbedding2D(torch.nn.Module):
         self._fill_padding_idx_with_zero()
 
     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
     def forward(self, input_: Tensor) -> Tensor:
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
@@ -430,9 +430,10 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
         self._fill_padding_idx_with_zero()
 
     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
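Note on the padding fix above: each vocab-parallel rank stores only the rows for vocab ids in [vocab_start_index, vocab_end_index), so the padding row is zeroed only when padding_idx falls inside that range and is indexed relative to the shard start. A minimal single-rank sketch of the remapping (shard bounds and sizes below are made-up example values, not taken from this commit):

import torch

# this rank owns global vocab ids [vocab_start_index, vocab_end_index)
vocab_start_index, vocab_end_index = 256, 512
weight = torch.randn(vocab_end_index - vocab_start_index, 64)
padding_idx = 300    # a global vocab id

# mirrors the fixed condition: only touch the weight if the padding id is local
if padding_idx is not None and vocab_start_index <= padding_idx < vocab_end_index:
    with torch.no_grad():
        # global id -> local row offset inside this shard
        weight[padding_idx - vocab_start_index].fill_(0)

assert weight[padding_idx - vocab_start_index].abs().sum() == 0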
@@ -5,7 +5,6 @@ from typing import Optional, Tuple
 
 import torch
 from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
-from colossalai.context import parallel_mode
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from torch import Tensor
@@ -13,8 +12,6 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from ._utils import get_parallel_mode_from_env
 from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
 
-from colossalai.nn.layer.base_layer import ParallelLayer
-
 
 class _Linear3D(torch.autograd.Function):
 
@@ -33,6 +30,7 @@ class _Linear3D(torch.autograd.Function):
         ctx.use_bias = bias is not None
 
         input_ = all_gather(input_, input_dim, input_parallel_mode)
+        weight = all_gather(weight, weight_dim, weight_parallel_mode)
         ctx.save_for_backward(input_, weight)
 
         output = torch.matmul(input_, weight)
@@ -64,7 +62,7 @@ class _Linear3D(torch.autograd.Function):
 
             weight_grad = torch.matmul(
                 input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
-            weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
+            weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True)
             async_ops.append(op)
 
             if ctx.use_bias:
@@ -343,27 +341,29 @@ def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
     return _ReduceTensor3D.apply(tensor, parallel_mode)
 
 
-class _ReduceGrad3D(torch.autograd.Function):
+class _AllGatherTensor3D(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, input_, parallel_mode):
+    def forward(ctx, input_, dim, parallel_mode):
+        ctx.dim = dim
         ctx.parallel_mode = parallel_mode
-        return input_
+        output = all_gather(input_, dim, parallel_mode)
+        return output
 
     @staticmethod
     def backward(ctx, output_grad):
-        input_grad = all_reduce(output_grad, ctx.parallel_mode)
-        return input_grad, None
+        input_grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode)
+        return input_grad, None, None
 
 
-def reduce_grad_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
+def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     All-reduce the gradient in backward pass.
 
     :param tensor: Input tensor
     :param parallel_mode: Parallel mode
     """
-    return _ReduceGrad3D.apply(tensor, parallel_mode)
+    return _AllGatherTensor3D.apply(tensor, dim, parallel_mode)
 
 
 class _ReduceScatterTensor3D(torch.autograd.Function):
 
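The _AllGatherTensor3D change above relies on reduce-scatter being the gradient of all-gather along the same dimension: the forward pass concatenates every rank's shard, so the gradient of the local shard is the sum over ranks of the matching slice of the output gradient. A minimal single-process sketch of that pairing, with the collectives simulated on a Python list of shards (the helper names are illustrative only, not part of the ColossalAI API):

import torch

def all_gather_sim(shards, dim):
    # forward: every simulated rank sees the concatenation of all shards
    return torch.cat(shards, dim=dim)

def reduce_scatter_sim(output_grads, dim, num_shards, rank):
    # backward: sum the per-rank output gradients, keep only this rank's slice
    total = torch.stack(output_grads).sum(dim=0)
    return torch.chunk(total, num_shards, dim=dim)[rank]

shards = [torch.randn(2, 3) for _ in range(2)]      # one weight shard per simulated rank
gathered = all_gather_sim(shards, dim=0)            # shape [4, 3] on every rank

output_grad = torch.ones_like(gathered)             # same output gradient on both ranks
local_grad = reduce_scatter_sim([output_grad, output_grad], dim=0, num_shards=2, rank=0)
print(local_grad.shape)                             # torch.Size([2, 3]), entries all 2.0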
@@ -100,7 +100,8 @@ class Linear3D(ParallelLayer):
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
         self.depth = get_depth_from_env()
         self.in_features_per_partition = divide(in_features, self.depth)
-        self.out_features_per_partition = divide(out_features, self.depth)
+        self.out_features_per_partition = divide(out_features, self.depth**2)
+        self.bias_features_per_partition = divide(out_features, self.depth)
 
         self.weight = Parameter(
             torch.empty(self.in_features_per_partition,
@@ -108,8 +109,8 @@ class Linear3D(ParallelLayer):
                         device=get_current_device(),
                         dtype=dtype))
         if bias:
-            self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
-                                              dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
         else:
             self.bias = None
 
@@ -118,21 +119,20 @@ class Linear3D(ParallelLayer):
         swap_in_out_group()
 
     def _set_tensor_parallel_attributes(self) -> None:
-        set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
+        set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3)
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         with seed(ParallelMode.TENSOR):
             fan_in, fan_out = self.in_features, self.out_features
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
 
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-            broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
 
             if self.bias is not None:
                 bias_initializer(self.bias, fan_in=fan_in)
+                weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
+                output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
                 broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
                 broadcast(self.bias, output_src_rank, self.output_parallel_mode)
 
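A quick sanity check on the new Linear3D partition sizes (the depth and feature counts below are arbitrary examples): the weight is split once along the input dimension and twice along the output dimension, giving depth**3 shards in total, while the bias is split only once, which matches the depth**3 / depth arguments now passed to set_tensor_parallel_attribute_by_partition.

depth = 2                                        # 3D parallel depth, depth**3 = 8 ranks
in_features, out_features = 64, 128

in_per_partition = in_features // depth          # 32
out_per_partition = out_features // depth**2     # 32
bias_per_partition = out_features // depth       # 64

# each rank holds a [32, 32] weight shard; the full [64, 128] weight
# therefore consists of (64*128) / (32*32) = 8 = depth**3 shards
assert (in_features * out_features) // (in_per_partition * out_per_partition) == depth**3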
@@ -257,7 +257,8 @@ class VocabParallelClassifier3D(ParallelLayer):
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
         self.depth = get_depth_from_env()
         self.in_features_per_partition = divide(in_features, self.depth)
-        self.out_features_per_partition = divide(num_classes, self.depth)
+        self.out_features_per_partition = divide(num_classes, self.depth**2)
+        self.bias_features_per_partition = divide(num_classes, self.depth)
 
         if weight is not None:
             self.weight = weight
@@ -270,8 +271,8 @@ class VocabParallelClassifier3D(ParallelLayer):
                             dtype=dtype))
             self.has_weight = True
         if bias:
-            self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
-                                              dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
         else:
             self.bias = None
 
@@ -289,15 +290,14 @@ class VocabParallelClassifier3D(ParallelLayer):
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         with seed(ParallelMode.TENSOR):
             fan_in, fan_out = self.in_features, self.num_classes
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
 
             if self.has_weight:
                 weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-                broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
 
             if self.bias is not None:
                 bias_initializer(self.bias, fan_in=fan_in)
+                weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
+                output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
                 broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
                 broadcast(self.bias, output_src_rank, self.output_parallel_mode)
 
@@ -523,11 +523,11 @@ class VocabParallelEmbedding3D(torch.nn.Module):
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
         self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
         self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
-        self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth)
+        self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2)
         self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
         vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)
-        self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition
-        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
+        self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition * self.depth
+        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth
 
         self.weight = Parameter(
             torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
@@ -546,13 +546,12 @@ class VocabParallelEmbedding3D(torch.nn.Module):
             fan_in, fan_out = self.num_embeddings, self.embed_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
-            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
-            broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
 
     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None:
+        if self.padding_idx is not None and \
+                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
+                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
     def forward(self, input_: Tensor) -> Tensor:
         input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
@@ -561,7 +560,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
         masked_input = input_.clone() - self.vocab_start_index
         masked_input[input_mask] = 0
 
-        weight = reduce_grad_3d(self.weight, self.weight_parallel_mode)
+        weight = all_gather_tensor_3d(self.weight, 0, self.weight_parallel_mode)
 
         output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
 
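Worked example (arbitrary sizes) of the new vocab range arithmetic in VocabParallelEmbedding3D above: each weight shard now holds num_embeddings / depth**2 rows, while the vocab range owned by an input-parallel rank spans depth of those shards, which is why the start/end indices are scaled by depth and why the forward pass all-gathers the weight along dim 0 before the embedding lookup.

depth = 2
num_embeddings = 1600

num_embeddings_per_partition = num_embeddings // depth**2     # 400 rows stored per rank

for vocab_parallel_rank in range(depth):
    start = vocab_parallel_rank * num_embeddings_per_partition * depth   # 0, 800
    end = start + num_embeddings_per_partition * depth                   # 800, 1600
    # the range covers depth shards of 400 rows each; all-gathering the
    # weight over the weight-parallel group reassembles exactly this range
    print(vocab_parallel_rank, start, end)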
@@ -41,6 +41,7 @@ def check_linear():
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
     weight = torch.chunk(weight, DEPTH, dim=-1)[j]
+    weight = torch.chunk(weight, DEPTH, dim=-1)[i]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data
     torch.distributed.broadcast(bias_master, src=0)
@@ -93,6 +94,7 @@ def check_linear():
     B_grad = layer_master.weight.grad.transpose(0, 1)
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
     logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
 
     bias_grad = layer_master.bias.grad
@@ -301,6 +303,7 @@ def check_vocab_parallel_classifier_no_given_weight():
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data
@@ -358,6 +361,7 @@ def check_vocab_parallel_classifier_no_given_weight():
 
     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format(
         rank, check_equal(B_grad, layer.weight.grad)))
@@ -470,6 +474,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
     weight_master = embed_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     embed.weight.data.copy_(weight)
 
@@ -518,6 +523,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
 
     B_grad = embed_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
                                                                                  check_equal(B_grad,
@@ -710,6 +716,7 @@ def check_vocab_parallel_embed():
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     weight = torch.chunk(weight, DEPTH, dim=-1)[k]
     layer.weight.data.copy_(weight)
 
@@ -751,6 +758,7 @@ def check_vocab_parallel_embed():
 
    B_grad = layer_master.weight.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+   B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
    logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
                                                                                 check_equal(B_grad,