|
|
|
import os
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
class SeedManager:
|
|
|
|
"""
|
|
|
|
This class is a random state manager to change random state for different random seed.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
original_state = torch.cuda.get_rng_state()
|
|
|
|
# TODO: unify this seed manager with the colossalai.context.random
|
|
|
|
seed = os.getpid()
|
|
|
|
torch.cuda.manual_seed(int(seed))
|
|
|
|
self.dropout_state = torch.cuda.get_rng_state()
|
|
|
|
torch.cuda.set_rng_state(original_state)
|
|
|
|
|
|
|
|
def set_mode(self, rng_state):
|
|
|
|
torch.cuda.set_rng_state(rng_state)
|
|
|
|
|
|
|
|
def get_current_mode(self):
|
|
|
|
current_state = torch.cuda.get_rng_state()
|
|
|
|
return current_state
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
def dropout_mode(self):
|
|
|
|
"""
|
|
|
|
This is a context manager to change the dropout state and recover the original state.
|
|
|
|
|
|
|
|
Usage:
|
|
|
|
::
|
|
|
|
>>> with _seed_manager.dropout_mode():
|
|
|
|
>>> input = super().forward(input)
|
|
|
|
"""
|
|
|
|
try:
|
|
|
|
current_mode = self.get_current_mode()
|
|
|
|
yield self.set_mode(self.dropout_state)
|
|
|
|
finally:
|
|
|
|
self.dropout_state = self.get_current_mode()
|
|
|
|
self.set_mode(current_mode)
|
|
|
|
|
|
|
|
|
|
|
|
_seed_manager = SeedManager()
|
|
|
|
|
|
|
|
|
|
|
|
class Dropout1D(nn.Dropout):
|
|
|
|
|
|
|
|
def __init__(self, p=0.5, inplace=False):
|
|
|
|
super().__init__(p, inplace)
|
|
|
|
|
|
|
|
def forward(self, input):
|
|
|
|
with _seed_manager.dropout_mode():
|
|
|
|
input = super().forward(input)
|
|
|
|
return input
|