# ColossalAI/colossalai/shardformer/layer/dropout.py
#
# 84 lines
# 3.5 KiB
# Python

from typing import List, Union
import torch
import torch.nn as nn
from torch.distributed import ProcessGroup
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
# Public names exported by this module: the two dropout layers defined below.
__all__ = [
    "DropoutForParallelInput",
    "DropoutForReplicatedInput",
]
class DropoutForParallelInput(ParallelModule, nn.Dropout):
    """
    Dropout layer that draws its mask from a module-owned randomizer.

    The randomizer is built from torch's initial seed, which is offset with the randomizer
    index and the rank inside ``create_randomizer_with_offset``. Ranks of the given process
    group therefore generate different dropout masks, which avoids applying the same mask to
    the same position on different ranks (that would lead to poor convergence performance).

    Args:
        p (float): probability of an element to be zeroed. Defaults to 0.5.
        inplace (bool): If set to True, will do this operation in-place. Defaults to False.
        process_group (ProcessGroup): the process group to be used for generating randomness.
            Defaults to None.
    """

    def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None):
        # Start the MRO lookup *after* nn.Dropout so that the dropout base initializer
        # is the one receiving ``p`` and ``inplace``.
        super(nn.Dropout, self).__init__(p=p, inplace=inplace)

        # The seed is offset with the randomizer index and the rank by the helper.
        self.randomizer = create_randomizer_with_offset(
            torch.random.initial_seed(),
            process_group=process_group,
        )

    @staticmethod
    def from_native_module(
        module: nn.Dropout,
        process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
    ) -> "DropoutForParallelInput":
        """
        Create a DropoutForParallelInput layer from a native dropout layer.

        The new layer copies ``p`` and ``inplace`` from ``module``; ``process_group`` is
        forwarded to the constructor unchanged.
        """
        return DropoutForParallelInput(
            p=module.p,
            inplace=module.inplace,
            process_group=process_group,
        )

    def forward(self, input):
        """
        Apply dropout to ``input`` inside the randomizer's ``fork_rng()`` context.
        """
        with self.randomizer.fork_rng():
            output = super().forward(input)
        return output
class DropoutForReplicatedInput(ParallelModule, nn.Dropout):
    """
    Dropout layer that draws its mask from a module-owned randomizer, for replicated input.

    The randomizer is built from torch's initial seed with ``offset_by_rank=False``, i.e. the
    seed is offset with the randomizer index only and not with the rank.
    NOTE(review): presumably this makes every rank of the process group generate the same
    dropout mask, so that a replicated input stays identical across ranks -- confirm against
    ``create_randomizer_with_offset``.

    Args:
        p (float): probability of an element to be zeroed. Defaults to 0.5.
        inplace (bool): If set to True, will do this operation in-place. Defaults to False.
        process_group (ProcessGroup): the process group to be used for generating randomness.
            Defaults to None.
    """

    def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None):
        # Start the MRO lookup *after* nn.Dropout so that the dropout base initializer
        # is the one receiving ``p`` and ``inplace``.
        super(nn.Dropout, self).__init__(p=p, inplace=inplace)

        # Rank offsetting is disabled: the seed is offset with the randomizer index only.
        self.randomizer = create_randomizer_with_offset(
            torch.random.initial_seed(),
            process_group=process_group,
            offset_by_rank=False,
        )

    @staticmethod
    def from_native_module(
        module: nn.Dropout,
        process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
    ) -> "DropoutForReplicatedInput":
        """
        Create a DropoutForReplicatedInput layer from a native dropout layer.

        The new layer copies ``p`` and ``inplace`` from ``module``; ``process_group`` is
        forwarded to the constructor unchanged.
        """
        return DropoutForReplicatedInput(
            p=module.p,
            inplace=module.inplace,
            process_group=process_group,
        )

    def forward(self, input):
        """
        Apply dropout to ``input`` inside the randomizer's ``fork_rng()`` context.
        """
        with self.randomizer.fork_rng():
            output = super().forward(input)
        return output