From 1e4bf85cdb51d54d7d2285a35dd92ef7da78e3a6 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Fri, 11 Mar 2022 14:48:11 +0800
Subject: [PATCH] fixed bug in activation checkpointing test (#387)

---
 colossalai/context/random/__init__.py             | 10 ++++------
 colossalai/context/random/_helper.py              |  7 ++++++-
 colossalai/context/random/seed_manager.py         |  8 ++++++--
 tests/test_utils/test_activation_checkpointing.py | 14 +++++++++-----
 4 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/colossalai/context/random/__init__.py b/colossalai/context/random/__init__.py
index 675fea5aa..422c3676c 100644
--- a/colossalai/context/random/__init__.py
+++ b/colossalai/context/random/__init__.py
@@ -1,9 +1,7 @@
-from ._helper import (seed, set_mode, with_seed, add_seed,
-                      get_seeds, get_states, get_current_mode,
-                      set_seed_states, sync_states, moe_set_seed)
+from ._helper import (seed, set_mode, with_seed, add_seed, get_seeds, get_states, get_current_mode, set_seed_states,
+                      sync_states, moe_set_seed, reset_seeds)
 
 __all__ = [
-    'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds',
-    'get_states', 'get_current_mode', 'set_seed_states', 'sync_states',
-    'moe_set_seed'
+    'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states',
+    'sync_states', 'moe_set_seed', 'reset_seeds'
 ]
diff --git a/colossalai/context/random/_helper.py b/colossalai/context/random/_helper.py
index 456731192..107bd04b9 100644
--- a/colossalai/context/random/_helper.py
+++ b/colossalai/context/random/_helper.py
@@ -154,4 +154,9 @@ def moe_set_seed(seed):
         global_rank = gpc.get_global_rank()
         add_seed(ParallelMode.TENSOR, global_rank, True)
         print(f"moe seed condition: {global_rank} with moe seed {moe_mp_seed}, ",
-              f"tensor seed {global_rank}", flush=True)
+              f"tensor seed {global_rank}",
+              flush=True)
+
+
+def reset_seeds():
+    _SEED_MANAGER.reset()
diff --git a/colossalai/context/random/seed_manager.py b/colossalai/context/random/seed_manager.py
index fae1ce6f2..90bce8b46 100644
--- a/colossalai/context/random/seed_manager.py
+++ b/colossalai/context/random/seed_manager.py
@@ -66,8 +66,7 @@ class SeedManager:
         :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
             :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
         """
-        assert isinstance(
-            parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
+        assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
         if overwrtie is False:
             assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
         elif parallel_mode in self._seed_states:
@@ -78,3 +77,8 @@ class SeedManager:
         self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
         self._seeds[parallel_mode] = seed
         torch.cuda.set_rng_state(current_state)
+
+    def reset(self):
+        self._current_mode = None
+        self._seeds = dict()
+        self._seed_states = dict()
diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py
index f4127228d..619ab4bdc 100644
--- a/tests/test_utils/test_activation_checkpointing.py
+++ b/tests/test_utils/test_activation_checkpointing.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.context.random import add_seed, seed, set_mode
+from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
 from colossalai.utils import checkpoint
 
 
@@ -17,12 +17,12 @@ def forward(x, weight):
     out_ = F.dropout(out, p=0.4, training=True)
     return out_
 
+
 @pytest.mark.gpu
 @pytest.mark.parametrize("cpu_offload", [True, False])
 def test_activation_checkpointing(cpu_offload):
-    if cpu_offload:
-        add_seed(ParallelMode.GLOBAL, 1024)
-        add_seed(ParallelMode.DATA, 1026)
+    add_seed(ParallelMode.GLOBAL, 1024)
+    add_seed(ParallelMode.DATA, 1026)
     set_mode(ParallelMode.GLOBAL)
     global_cuda_rng_state = torch.cuda.get_rng_state()
     set_mode(ParallelMode.DATA)
@@ -56,4 +56,8 @@ def test_activation_checkpointing(cpu_offload):
     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
 
     torch.cuda.empty_cache()
-
+    # as seed manager is singleton
+    # if we don't reset seeds here,
+    # other tests will fail if running together with this test
+    # as other tests can't overwrite the seed set by this test
+    reset_seeds()
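Illustrative note (not part of the commit): the fix relies on the fact that colossalai.context.random wraps a single process-wide SeedManager, so any test that registers seeds via add_seed must clear that state before other tests run. The commit does this by calling reset_seeds() at the end of the test body; a sketch of an equivalent pattern, assuming pytest is the runner and using a hypothetical fixture name, is to place the cleanup in an autouse fixture so it also runs when an assertion fails partway through the test:

    import pytest

    from colossalai.context.random import reset_seeds


    @pytest.fixture(autouse=True)
    def _reset_seed_manager():
        # hypothetical fixture, not in the patch: run the test body first ...
        yield
        # ... then always clear the singleton SeedManager so that later
        # tests can register their own seeds with add_seed()
        reset_seeds()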