mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
3.0 KiB
81 lines
3.0 KiB
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
|
|
class SeedManager:
    """This class is a manager of all random seeds involved in the system.

    Every registered :class:`colossalai.context.ParallelMode` owns its own seed and its own
    CUDA RNG state. Switching modes with :meth:`set_mode` saves the live CUDA RNG state under
    the current mode and loads the stored state of the new mode.
    """

    def __init__(self):
        # Mode whose RNG state is currently loaded into the CUDA generator; None until set_mode is called.
        self._current_mode = None
        # ParallelMode -> integer seed that was passed to add_seed.
        self._seeds = dict()
        # ParallelMode -> RNG state tensor captured with torch.cuda.get_rng_state().
        self._seed_states = dict()

    @property
    def current_mode(self):
        """The parallel mode whose RNG state is currently active, or None if no mode has been set."""
        return self._current_mode

    @property
    def seeds(self):
        """Mapping from parallel mode to the seed registered for it."""
        return self._seeds

    @property
    def seed_states(self):
        """Mapping from parallel mode to its stored CUDA RNG state."""
        return self._seed_states

    def set_state(self, parallel_mode: ParallelMode, state: Tensor):
        """Sets the state of the seed manager for `parallel_mode`.

        Only the stored state is replaced; the live CUDA RNG state is not touched.
        NOTE: if `parallel_mode` is the current mode, the stored state will be overwritten
        with the live CUDA RNG state on the next call to :meth:`set_mode`.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :param state: the state to be set
        :type state: :class:`torch.Tensor`
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
        """
        assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager'
        self._seed_states[parallel_mode] = state

    def set_mode(self, parallel_mode: ParallelMode):
        """Sets the current mode of the seed manager.

        The live CUDA RNG state is saved under the current mode (if any), then the stored
        state of `parallel_mode` is loaded into the CUDA generator.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :raises KeyError: Raises a KeyError if no seed has been added for `parallel_mode`;
            in that case neither the manager nor the CUDA RNG state is modified
        """
        # Validate before mutating anything, so an unknown mode cannot leave `_current_mode`
        # pointing at a mode that has no stored state.
        if parallel_mode not in self._seed_states:
            raise KeyError(f'Parallel mode {parallel_mode} is not found in the seed manager')

        if self._current_mode is not None:
            # save the current state for current mode
            self._seed_states[self._current_mode] = torch.cuda.get_rng_state()

        # Load the new state only after saving, so switching to the current mode keeps the
        # just-saved live state instead of rewinding to an older one.
        torch.cuda.set_rng_state(self._seed_states[parallel_mode])
        self._current_mode = parallel_mode

    def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrtie: bool = False):
        """Adds a seed to the seed manager for `parallel_mode`.

        The CUDA generator is seeded temporarily to capture the RNG state that `seed`
        produces; the previous live RNG state is restored afterwards.
        NOTE: overwriting the seed of the current mode only changes the stored state, which
        the next call to :meth:`set_mode` replaces with the live CUDA RNG state.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :param seed: The seed to be added
        :type seed: int
        :param overwrtie: Whether allows to overwrite the seed that has been set already
            (the misspelled name is kept for backward compatibility with keyword callers)
        :type overwrtie: bool, optional
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
            :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
        """
        assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
        if not overwrtie:
            assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
        elif parallel_mode in self._seed_states:
            print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True)

        current_state = torch.cuda.get_rng_state()
        try:
            torch.cuda.manual_seed(seed)
            self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
            self._seeds[parallel_mode] = seed
        finally:
            # Always restore the caller's RNG stream, even if capturing the new state fails.
            torch.cuda.set_rng_state(current_state)
|