
fixed bug in activation checkpointing test (#387)

Frank Lee, 3 years ago
commit 1e4bf85cdb (pull/383/head^2)
1. colossalai/context/random/__init__.py (10 lines changed)
2. colossalai/context/random/_helper.py (7 lines changed)
3. colossalai/context/random/seed_manager.py (8 lines changed)
4. tests/test_utils/test_activation_checkpointing.py (14 lines changed)

colossalai/context/random/__init__.py

@@ -1,9 +1,7 @@
-from ._helper import (seed, set_mode, with_seed, add_seed,
-                      get_seeds, get_states, get_current_mode,
-                      set_seed_states, sync_states, moe_set_seed)
+from ._helper import (seed, set_mode, with_seed, add_seed, get_seeds, get_states, get_current_mode, set_seed_states,
+                      sync_states, moe_set_seed, reset_seeds)

 __all__ = [
-    'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds',
-    'get_states', 'get_current_mode', 'set_seed_states', 'sync_states',
-    'moe_set_seed'
+    'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states',
+    'sync_states', 'moe_set_seed', 'reset_seeds'
 ]

colossalai/context/random/_helper.py

@@ -154,4 +154,9 @@ def moe_set_seed(seed):
         global_rank = gpc.get_global_rank()
         add_seed(ParallelMode.TENSOR, global_rank, True)
-        print(f"moe seed condition: {global_rank} with moe seed {moe_mp_seed}, ",
-              f"tensor seed {global_rank}", flush=True)
+        print(f"moe seed condition: {global_rank} with moe seed {moe_mp_seed}, ",
+              f"tensor seed {global_rank}",
+              flush=True)
+
+
+def reset_seeds():
+    _SEED_MANAGER.reset()

colossalai/context/random/seed_manager.py

@@ -66,8 +66,7 @@ class SeedManager:
         :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
             :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
         """
-        assert isinstance(
-            parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
+        assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
         if overwrtie is False:
             assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
         elif parallel_mode in self._seed_states:

@@ -78,3 +77,8 @@ class SeedManager:
             self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
             self._seeds[parallel_mode] = seed
             torch.cuda.set_rng_state(current_state)
+
+    def reset(self):
+        self._current_mode = None
+        self._seeds = dict()
+        self._seed_states = dict()
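
The new SeedManager.reset() matters because the manager is a process-wide singleton: once a test registers a seed for a ParallelMode, any later add_seed call for that mode raises until the recorded state is wiped. The sketch below is illustrative only, not the ColossalAI class itself; the real add_seed also snapshots and restores the CUDA RNG state, which is omitted here.

class _SketchSeedManager:
    """Toy stand-in for the singleton seed manager touched by this commit."""

    def __init__(self):
        self.reset()

    def add_seed(self, parallel_mode, seed, overwrite=False):
        # Mirrors the assertion in seed_manager.py: registering the same mode
        # twice without overwrite=True is an error.
        if not overwrite:
            assert parallel_mode not in self._seed_states, \
                f'The seed for {parallel_mode} has been added'
        self._seeds[parallel_mode] = seed
        self._seed_states[parallel_mode] = seed  # placeholder for a CUDA RNG state

    def reset(self):
        # Same shape as the reset() added in this commit: drop everything.
        self._current_mode = None
        self._seeds = dict()
        self._seed_states = dict()


# One shared instance per process, analogous to _SEED_MANAGER in _helper.py.
_MANAGER = _SketchSeedManager()

_MANAGER.add_seed('data', 1026)
_MANAGER.reset()                 # without this, the next call would raise
_MANAGER.add_seed('data', 1026)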

tests/test_utils/test_activation_checkpointing.py

@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.context.random import add_seed, seed, set_mode
+from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
 from colossalai.utils import checkpoint
@@ -17,12 +17,12 @@ def forward(x, weight):
         out_ = F.dropout(out, p=0.4, training=True)
     return out_

 @pytest.mark.gpu
 @pytest.mark.parametrize("cpu_offload", [True, False])
 def test_activation_checkpointing(cpu_offload):
-    if cpu_offload:
-        add_seed(ParallelMode.GLOBAL, 1024)
-        add_seed(ParallelMode.DATA, 1026)
+    add_seed(ParallelMode.GLOBAL, 1024)
+    add_seed(ParallelMode.DATA, 1026)
     set_mode(ParallelMode.GLOBAL)
     global_cuda_rng_state = torch.cuda.get_rng_state()
     set_mode(ParallelMode.DATA)
@@ -56,4 +56,8 @@ def test_activation_checkpointing(cpu_offload):
     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()
+
+    # as seed manager is singleton
+    # if we don't reset seeds here,
+    # other tests will fail if running together with this test
+    # as other tests can't overwrite the seed set by this test
+    reset_seeds()
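
The comment above is the core of the fix: the singleton seed manager outlives each test, so its state has to be cleared explicitly. A hypothetical alternative to calling reset_seeds() at the end of the test body, not part of this commit, is an autouse pytest fixture that performs the same cleanup after every test in the module; it relies only on the reset_seeds helper exported by this commit.

import pytest

from colossalai.context.random import reset_seeds


@pytest.fixture(autouse=True)
def _clean_seed_manager():
    # Run the test first, then clear the singleton seed manager so that
    # seeds registered here cannot leak into other tests in the same session.
    yield
    reset_seeds()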
