From 3893fa1a8d0cdcc3df3c6e4eb11956e688541cbd Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 16 Jun 2023 15:00:26 +0800 Subject: [PATCH] [shardformer] refactored embedding and dropout to parallel module (#4013) * [shardformer] refactored embedding and dropout to parallel module * polish code --- .../shardformer/layer/dist_crossentropy.py | 20 +- colossalai/shardformer/layer/dropout.py | 36 +- colossalai/shardformer/layer/layers.py | 467 +++--------------- .../test_layer/test_dropout.py | 53 ++ .../test_layer/test_embedding.py | 43 ++ .../test_layer/test_linear_1d.py | 2 +- 6 files changed, 198 insertions(+), 423 deletions(-) create mode 100644 tests/test_shardformer/test_layer/test_dropout.py create mode 100644 tests/test_shardformer/test_layer/test_embedding.py diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py index ff05209fe..7840c2f2e 100644 --- a/colossalai/shardformer/layer/dist_crossentropy.py +++ b/colossalai/shardformer/layer/dist_crossentropy.py @@ -3,6 +3,7 @@ import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from torch.autograd import Function +from torch.distributed import ProcessGroup class DistCrossEntropy(Function): @@ -14,7 +15,7 @@ class DistCrossEntropy(Function): """ @staticmethod - def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int): + def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup): r""" Calculate the cross entropy loss before gather, the origin loss function is as follows: loss = -log(exp(x[class])/sum(exp(x[i])) @@ -34,15 +35,15 @@ class DistCrossEntropy(Function): """ # get the max logits_max = torch.max(vocab_logits, dim=-1)[0] - dist.all_reduce(logits_max, op=dist.ReduceOp.MAX) + dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group) # minus the max to avoid the result of sum of exp is too large and the log is nan vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) # mask the target in the local device partition_vocab_size = vocab_logits.size()[-1] - rank = dist.get_rank() - world_size = dist.get_world_size() + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) global_vocab_size = partition_vocab_size * world_size # [down, up) => false, other device and -100 => true @@ -67,11 +68,11 @@ class DistCrossEntropy(Function): pred_logits[mask] = 0.0 # allreduce the get all x(i,y) - dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM) + dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group) exp_logits = vocab_logits torch.exp(vocab_logits, out=exp_logits) sum_exp_logits = torch.sum(exp_logits, dim=-1) - dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM) + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) # calculate the loss # loss = log(sum(exp(x[i]))) - x[class] @@ -101,5 +102,8 @@ class DistCrossEntropy(Function): return grad_logits, None, None -def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels, ignore_index) +def cross_entropy_1d(vocab_logits: torch.Tensor, + labels: torch.Tensor, + ignore_index: int = -100, + process_group: ProcessGroup = None) -> torch.Tensor: + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group) diff --git a/colossalai/shardformer/layer/dropout.py 
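The reworked loss entry point above is now exposed as cross_entropy_1d and threads an explicit process group through every collective. A minimal usage sketch, assuming torch.distributed is already initialized and the logits are already sharded along the vocabulary dimension (the shapes and the per-rank vocabulary size are illustrative only, not taken from the patch):

import torch
import torch.distributed as dist

from colossalai.shardformer.layer.dist_crossentropy import cross_entropy_1d

local_vocab = 1000                                  # vocabulary shard held by this rank
global_vocab = local_vocab * dist.get_world_size()  # full vocabulary across the group

# flattened (batch * seq, vocab_shard) logits; labels index the *global* vocabulary
vocab_logits = torch.randn(8 * 32, local_vocab, device='cuda', requires_grad=True)
labels = torch.randint(0, global_vocab, (8 * 32,), device='cuda')

# process_group=None falls back to the default group, mirroring the previous behaviour
loss = cross_entropy_1d(vocab_logits, labels, ignore_index=-100, process_group=None)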
b/colossalai/shardformer/layer/dropout.py index 5d295be6b..ec08d072f 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -1,19 +1,43 @@ -import torch -import torch.distributed as dist -import torch.nn as nn +from typing import List, Union +import torch +import torch.nn as nn +from torch.distributed import ProcessGroup + +from .layers import ParallelModule from .utils import create_randomizer_with_offset -class Dropout1D(nn.Dropout): +class Dropout1D(ParallelModule, nn.Dropout): + """ + The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with + randomness on different ranks of the given process group. This can avoid the same dropout mask is generated + and applied on the same position of different ranks, leading to poor convergence performance. - def __init__(self, p=0.5, inplace=False, process_group=None): - super().__init__(p, inplace) + Args: + p (float): probability of an element to be zeroed. Defaults to 0.5. + inplace (bool): If set to True, will do this operation in-place. Defaults to False. + process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None): + # init with nn.Dropout + super(nn.Dropout, self).__init__(p=p, inplace=inplace) # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=process_group) + @staticmethod + def from_native_module(module: nn.Dropout, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Dropout1D": + """ + Create a Dropout1D layer from a native dropout layer. + """ + p = module.p + inplace = module.inplace + return Dropout1D(p=p, inplace=inplace, process_group=process_group) + def forward(self, input): with self.randomizer.fork_rng(): input = super().forward(input) diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py index 2ad6523c9..87d24f18e 100644 --- a/colossalai/shardformer/layer/layers.py +++ b/colossalai/shardformer/layer/layers.py @@ -22,12 +22,7 @@ from colossalai.kernel import LayerNorm from colossalai.nn import init as init from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule -from colossalai.nn.layer.parallel_1d._utils import ( - gather_forward_split_backward, - get_parallel_input, - reduce_grad, - set_parallel_input, -) +from colossalai.nn.layer.parallel_1d._utils import get_parallel_input, reduce_grad, set_parallel_input from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise @@ -432,279 +427,7 @@ class LayerNorm1D(ColossalaiModule): super()._save_to_state_dict(destination, prefix, keep_vars) -class Classifier1D(ParallelLayer): - r"""RowLinear with given weight. Classifier of 1D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. 
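Dropout1D now derives from ParallelModule and takes its randomness offset from the given process group rather than the global parallel context. A minimal conversion sketch, assuming the distributed environment has already been launched (for example via colossalai.launch):

import torch.nn as nn

from colossalai.shardformer.layer.dropout import Dropout1D

native_dropout = nn.Dropout(p=0.1)

# keep p/inplace from the native layer; process_group=None falls back to the default group
parallel_dropout = Dropout1D.from_native_module(native_dropout, process_group=None)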
- weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - self.parallel_input = get_parallel_input() - - # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs)) - self.has_weight = True - if bias: - self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - env.vocab_parallel = False - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.num_classes - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) - - def _set_tensor_parallel_attributes(self): - if self.has_weight: - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. 
- if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) - input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) - - output_parallel = F.linear(input_, self.weight) - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - if self.bias is not None: - output = output + self.bias - return output - - -class VocabParallelClassifier1D(ParallelLayer): - r"""ColLinear with given weight. Classifier of 1D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - self.gather_output = gather_output - self.parallel_input = get_parallel_input() - - # Divide the weight matrix along the last dimension. - self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size) - - # Parameters. - # Initialize weight. 
- factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs)) - self.has_weight = True - if bias: - self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - env.vocab_parallel = True - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.num_classes - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - # Matrix multiply. - output_parallel = F.linear(input_parallel, self.weight, self.bias) - if self.gather_output: - # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - else: - output = output_parallel - return output - - -# @LAYERS.register_module - - -class Embedding1D(ParallelLayer): +class Embedding1D(ParallelModule): r"""Embedding for 1D parallelism. 
Args: @@ -739,7 +462,8 @@ class Embedding1D(ParallelLayer): embedding_dim: int, padding_idx: int = None, dtype: torch.dtype = None, - gather_output: bool = True, + device: torch.device = None, + process_group: ProcessGroup = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -747,66 +471,79 @@ class Embedding1D(ParallelLayer): self.num_embeddings = num_embeddings self.embed_dim = embedding_dim - embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size) + self.process_group = process_group + self.num_partitions = dist.get_world_size(process_group) + self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs self.gather_output = gather_output - self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + if device is None: + device = get_current_device() - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) + self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D": + r""" + Build a 1D parallelized Embedding from a native nn.Embedding module. 
+ """ + # get the attributes + num_embedding = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + max_norm = module.max_norm + norm_type = module.norm_type + scale_grad_by_freq = module.scale_grad_by_freq + sparse = module.sparse + dtype = module.weight.dtype + device = module.weight.device + + # sparse is not support yet + if sparse: + raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") + + embedding = Embedding1D(num_embeddings=num_embedding, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + process_group=process_group, + dtype=dtype, + device=device, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + + # copy the weight + with torch.no_grad(): + sharded_weight = shard_colwise(module.weight.data, process_group) + embedding.weight.copy_(sharded_weight) + + return embedding def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with torch.no_grad(): self.weight[self.padding_idx].fill_(0) - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars) - destination.update(local_state) - def forward(self, input_: Tensor) -> Tensor: - output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - if self.gather_output: - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - else: - output = output_parallel + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) return output @@ -926,89 +663,3 @@ class VocabParallelEmbedding1D(ParallelLayer): # Reduce across all the model parallel GPUs. output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) return output - - -class Dropout1D(ParallelLayer): - """Dropout layer of 1D parallelism. - - Args: - p (float, optional): probability of an element to be zeroed, defaults 0.5. - inplace (bool, optional): whether to do dropout in-place, default to be False. 
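Embedding1D keeps a column-wise shard of the weight (via shard_colwise) and gathers the partial embeddings along the last dimension in forward, so callers still see the full embedding width. A small sketch mirroring the new test, assuming a two-rank group is already initialized:

import torch
import torch.nn as nn

from colossalai.shardformer.layer.layers import Embedding1D

native_embedding = nn.Embedding(num_embeddings=32, embedding_dim=128).cuda()
parallel_embedding = Embedding1D.from_native_module(native_embedding, process_group=None)

# with 2 ranks, each rank stores a (32, 64) weight shard,
# but the forward output is still 128-wide after the all-gather
tokens = torch.randint(low=0, high=32, size=(4, 32)).cuda()
full_width_output = parallel_embedding(tokens)    # shape (4, 32, 128)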
- """ - - def __init__(self, p: float = 0.5, inplace: bool = False): - super().__init__() - self.parallel_input = get_parallel_input() - self.p = p - self.inplace = inplace - - def forward(self, input_: Tensor) -> Tensor: - if self.parallel_input: - with seed(ParallelMode.TENSOR): - output = F.dropout(input_, self.p, self.training, self.inplace) - else: - output = F.dropout(input_, self.p, self.training, self.inplace) - return output - - -# @LAYERS.register_module -class PatchEmbedding1D(ColossalaiModule): - """ - 2D Image to Patch Embedding - - :param img_size: image size - :type img_size: int - :param patch_size: patch size - :type patch_size: int - :param in_chans: number of channels of input image - :type in_chans: int - :param embed_size: size of embedding - :type embed_size: int - :param dtype: The dtype of parameters, defaults to None - :type dtype: torch.dtype, optional - :param flatten: whether to flatten output tensor, defaults to True - :type flatten: bool, optional - :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer - :type weight_initializer: typing.Callable, optional - :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer - :type bias_initializer: typing.Callable, optional - :param position_embed_initializer: The initializer of position embedding, defaults to zero - :type position_embed_initializer: typing.Callable, optional - """ - - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - dtype: torch.dtype = None, - flatten: bool = True, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): - embed = VanillaPatchEmbedding(img_size, - patch_size, - in_chans, - embed_size, - dtype=dtype, - flatten=flatten, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - position_embed_initializer=position_embed_initializer) - super().__init__(embed) - - def _load_from_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed'] - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - for key in param_keys: - param = state_dict.pop(key, None) - if param is not None: - local_state[key] = param - - local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) - super()._load_from_state_dict(local_state, prefix, *args) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py new file mode 100644 index 000000000..c48c11b36 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_dropout.py @@ -0,0 +1,53 @@ +import torch +import torch.distributed as dist +import torch.nn as nn + +import colossalai +from colossalai.shardformer.layer.dropout import Dropout1D +from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn + + +def check_dropout(): + dropout = nn.Dropout().cuda() + dropout_1d = Dropout1D.from_native_module(dropout, process_group=None) + + # check computation correctness + x = torch.rand(4, 128).cuda() + + # we set seed so that dropout will generate the same mask + torch.cuda.manual_seed(1024) + out = dropout(x) + + # 
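The dropout test below hinges on create_randomizer_with_offset: even with an identical torch.cuda.manual_seed on every rank, each rank forks its RNG state with a rank-dependent offset, so the Dropout1D masks diverge across the group while the plain nn.Dropout masks stay identical. A rough illustration of the idea (hypothetical helper, not the actual implementation):

import torch
import torch.distributed as dist

def rank_offset_generator(base_seed: int) -> torch.Generator:
    # give every rank its own reproducible RNG stream by offsetting the shared seed
    generator = torch.Generator(device='cuda')
    generator.manual_seed(base_seed + dist.get_rank())
    return generator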
we set seed to simulate the same scenario + # but expect the dropout mask to be different + # due to the internal randomness control + torch.cuda.manual_seed(1024) + out_1d = dropout_1d(x) + + # ensure out is the same across all ranks + world_size = dist.get_world_size() + out_all = [torch.empty_like(out) for _ in range(world_size)] + dist.all_gather(out_all, out) + + for i in range(world_size): + assert_equal(out_all[i], out_all[0]) + + # ensure out_1d is different across ranks + out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)] + dist.all_gather(out_1d_all, out_1d) + for i in range(1, world_size): + assert_not_equal(out_1d_all[i], out_1d_all[0]) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_dropout() + + +@rerun_if_address_is_in_use() +def test_dropout(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_dropout() diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py new file mode 100644 index 000000000..462349ecb --- /dev/null +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -0,0 +1,43 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer.layers import Embedding1D +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_embedding_1d(): + embedding = nn.Embedding(32, 128).cuda() + embedding_1d = Embedding1D.from_native_module(embedding, process_group=None) + + assert embedding_1d.weight.shape == torch.Size([32, 64]) + + # check computation correctness + x = torch.randint(low=0, high=32, size=(4, 32)).cuda() + out = embedding(x) + gather_out = embedding_1d(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(embedding.weight.grad, 2, dim=1)[rank] + assert_close(target_grad, embedding_1d.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_embedding_1d() + + +@rerun_if_address_is_in_use() +def test_embedding_1d(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_embedding_1d() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 449522c64..2a3ce9938 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -5,7 +5,7 @@ from torch.testing import assert_close import colossalai from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_linear_1d_col():
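Taken together, the from_native_module constructors let existing submodules be swapped for their tensor-parallel counterparts in place. A minimal end-to-end sketch with a hypothetical toy model, again assuming the default process group has been set up by colossalai.launch:

import torch.nn as nn

from colossalai.shardformer.layer.dropout import Dropout1D
from colossalai.shardformer.layer.layers import Embedding1D

toy_model = nn.Sequential(nn.Embedding(1000, 256), nn.Dropout(p=0.1)).cuda()

# replace the native layers while keeping their hyperparameters
toy_model[0] = Embedding1D.from_native_module(toy_model[0], process_group=None)
toy_model[1] = Dropout1D.from_native_module(toy_model[1], process_group=None)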