#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from torch import Tensor

from colossalai.context.parallel_mode import ParallelMode


class SeedManager:
    """This class is a manager of all random seeds involved in the system.

    For every registered parallel mode it keeps the seed that was supplied
    and a CUDA RNG state. Switching modes saves the live CUDA RNG state under
    the mode being left and loads the stored state of the mode being entered.
    """

    def __init__(self):
        # Mode whose RNG state is currently loaded; None until `set_mode` is called.
        self._current_mode = None
        # parallel mode -> seed passed to `add_seed`
        self._seeds = dict()
        # parallel mode -> stored CUDA RNG state for that mode
        self._seed_states = dict()

    @property
    def current_mode(self):
        """The parallel mode selected by the last `set_mode` call, or None."""
        return self._current_mode

    @property
    def seeds(self):
        """The internal mapping from parallel mode to seed (not a copy)."""
        return self._seeds

    @property
    def seed_states(self):
        """The internal mapping from parallel mode to RNG state (not a copy)."""
        return self._seed_states

    def set_state(self, parallel_mode: ParallelMode, state: Tensor) -> None:
        """Sets the state of the seed manager for `parallel_mode`.

        Only the stored state is replaced; the live CUDA RNG state is not touched.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :param state: the state to be set
        :type state: :class:`torch.Tensor`
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
        """
        assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager'
        self._seed_states[parallel_mode] = state

    def set_mode(self, parallel_mode: ParallelMode) -> None:
        """Sets the current mode of the seed manager.

        The live CUDA RNG state is saved under the mode being left (if any) and
        the stored state of `parallel_mode` is loaded in its place.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :raises KeyError: Raises a KeyError if no seed has been added for `parallel_mode`;
            the manager and the CUDA RNG state are left untouched in that case
        """
        # Validate before mutating anything: failing after `_current_mode` was
        # reassigned would leave the manager pointing at an unregistered mode.
        if parallel_mode not in self._seed_states:
            raise KeyError(parallel_mode)

        if self._current_mode is not None:
            # save the current state for current mode
            self._seed_states[self._current_mode] = torch.cuda.get_rng_state()

        # set the new state for new mode; the lookup happens after the save so
        # that re-selecting the current mode reloads the state just saved
        self._current_mode = parallel_mode
        torch.cuda.set_rng_state(self._seed_states[parallel_mode])

    def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrtie: bool = False) -> None:
        """Adds a seed to the seed manager for `parallel_mode`.

        The CUDA generator is seeded temporarily to capture the matching RNG
        state; the RNG state that was live on entry is restored before returning.

        :param parallel_mode: The chosen parallel mode
        :type parallel_mode: :class:`colossalai.context.ParallelMode`
        :param seed: The seed to be added
        :type seed: int
        :param overwrtie: Whether to allow overwriting a seed that has already been set
            (the misspelt name is kept for backward compatibility)
        :type overwrtie: bool, optional
        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
            :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
        """
        assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
        if overwrtie is False:
            assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
        elif parallel_mode in self._seed_states:
            print(f"Warnning: {parallel_mode} seed has been overwritten.", flush=True)

        current_state = torch.cuda.get_rng_state()
        try:
            torch.cuda.manual_seed(seed)
            new_state = torch.cuda.get_rng_state()
        finally:
            # Always put the caller's RNG state back, even if seeding failed.
            torch.cuda.set_rng_state(current_state)

        # Record the pair only once the new state has been captured successfully.
        self._seed_states[parallel_mode] = new_state
        self._seeds[parallel_mode] = seed