You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/shardformer/layer/dropout.py

59 lines
1.5 KiB

import os
from contextlib import contextmanager
import torch
import torch.nn as nn
class SeedManager:
    """
    Random-state manager that keeps a separate CUDA RNG state for dropout.

    On construction, the CUDA RNG of the current device is seeded with the
    process id (``os.getpid()``) to produce ``self.dropout_state``. The RNG
    state that was active before construction is then restored.

    :meth:`dropout_mode` temporarily swaps ``self.dropout_state`` in, so random
    draws made inside the context advance the dropout stream rather than the
    caller's stream.
    """

    def __init__(self):
        # Snapshot the caller's RNG state so constructing the manager leaves it untouched.
        original_state = torch.cuda.get_rng_state()
        try:
            # TODO: unify this seed manager with the colossalai.context.random
            seed = os.getpid()
            torch.cuda.manual_seed(int(seed))
            self.dropout_state = torch.cuda.get_rng_state()
        finally:
            # Restore the caller's state even if seeding or the snapshot fails.
            torch.cuda.set_rng_state(original_state)

    def set_mode(self, rng_state):
        """Make ``rng_state`` the current CUDA RNG state."""
        torch.cuda.set_rng_state(rng_state)

    def get_current_mode(self):
        """Return a snapshot of the current CUDA RNG state."""
        current_state = torch.cuda.get_rng_state()
        return current_state

    @contextmanager
    def dropout_mode(self):
        """
        Context manager that switches to the dropout RNG state and switches back on exit.

        On exit, whether the body finished normally or raised:

        - the advanced dropout state is saved back into ``self.dropout_state``,
          so the next use continues the same random stream;
        - the RNG state that was active on entry is restored.

        The context yields ``None``.

        Usage:
        ::
            >>> with _seed_manager.dropout_mode():
            ...     input = super().forward(input)
        """
        # Capture and switch *before* entering the try. If either call fails,
        # there is nothing to undo. Without this, the finally clause would
        # overwrite self.dropout_state with the caller's state and corrupt the
        # dropout stream.
        current_mode = self.get_current_mode()
        self.set_mode(self.dropout_state)
        try:
            yield
        finally:
            self.dropout_state = self.get_current_mode()
            self.set_mode(current_mode)
# Module-level SeedManager instance used by Dropout1D below.
# It is constructed at import time, so importing this module runs
# SeedManager.__init__, which calls torch.cuda.get_rng_state() / manual_seed().
# Presumably this requires a CUDA-capable environment at import — TODO confirm.
_seed_manager: SeedManager = SeedManager()
class Dropout1D(nn.Dropout):
    """
    ``nn.Dropout`` whose forward pass runs inside ``_seed_manager.dropout_mode()``.

    While the parent ``forward`` executes, the CUDA RNG state is swapped to the
    manager's dedicated dropout state, and the previous state is restored
    afterwards. For tensors on the current CUDA device, dropout therefore does
    not consume the caller's RNG stream.

    Args:
        p: probability of an element being zeroed; forwarded to ``nn.Dropout``.
        inplace: whether to apply dropout in place; forwarded to ``nn.Dropout``.
    """

    def __init__(self, p=0.5, inplace=False):
        super().__init__(p=p, inplace=inplace)

    def forward(self, input):
        with _seed_manager.dropout_mode():
            dropped = super().forward(input)
        return dropped