mirror of https://github.com/hpcaitech/ColossalAI
fixed bug in activation checkpointing test (#387)
parent
3af13a2c3e
commit
1e4bf85cdb
|
@ -1,9 +1,7 @@
|
|||
from ._helper import (seed, set_mode, with_seed, add_seed,
|
||||
get_seeds, get_states, get_current_mode,
|
||||
set_seed_states, sync_states, moe_set_seed)
|
||||
from ._helper import (seed, set_mode, with_seed, add_seed, get_seeds, get_states, get_current_mode, set_seed_states,
|
||||
sync_states, moe_set_seed, reset_seeds)
|
||||
|
||||
__all__ = [
|
||||
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds',
|
||||
'get_states', 'get_current_mode', 'set_seed_states', 'sync_states',
|
||||
'moe_set_seed'
|
||||
'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states',
|
||||
'sync_states', 'moe_set_seed', 'reset_seeds'
|
||||
]
|
||||
|
|
|
@ -154,4 +154,9 @@ def moe_set_seed(seed):
|
|||
global_rank = gpc.get_global_rank()
|
||||
add_seed(ParallelMode.TENSOR, global_rank, True)
|
||||
print(f"moe seed condition: {global_rank} with moe seed {moe_mp_seed}, ",
|
||||
f"tensor seed {global_rank}", flush=True)
|
||||
f"tensor seed {global_rank}",
|
||||
flush=True)
|
||||
|
||||
|
||||
def reset_seeds():
|
||||
_SEED_MANAGER.reset()
|
||||
|
|
|
@ -66,8 +66,7 @@ class SeedManager:
|
|||
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
|
||||
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
|
||||
"""
|
||||
assert isinstance(
|
||||
parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
|
||||
assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
|
||||
if overwrtie is False:
|
||||
assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
|
||||
elif parallel_mode in self._seed_states:
|
||||
|
@ -78,3 +77,8 @@ class SeedManager:
|
|||
self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
|
||||
self._seeds[parallel_mode] = seed
|
||||
torch.cuda.set_rng_state(current_state)
|
||||
|
||||
def reset(self):
|
||||
self._current_mode = None
|
||||
self._seeds = dict()
|
||||
self._seed_states = dict()
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.context.random import add_seed, seed, set_mode
|
||||
from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
|
||||
from colossalai.utils import checkpoint
|
||||
|
||||
|
||||
|
@ -17,12 +17,12 @@ def forward(x, weight):
|
|||
out_ = F.dropout(out, p=0.4, training=True)
|
||||
return out_
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
@pytest.mark.parametrize("cpu_offload", [True, False])
|
||||
def test_activation_checkpointing(cpu_offload):
|
||||
if cpu_offload:
|
||||
add_seed(ParallelMode.GLOBAL, 1024)
|
||||
add_seed(ParallelMode.DATA, 1026)
|
||||
add_seed(ParallelMode.GLOBAL, 1024)
|
||||
add_seed(ParallelMode.DATA, 1026)
|
||||
set_mode(ParallelMode.GLOBAL)
|
||||
global_cuda_rng_state = torch.cuda.get_rng_state()
|
||||
set_mode(ParallelMode.DATA)
|
||||
|
@ -56,4 +56,8 @@ def test_activation_checkpointing(cpu_offload):
|
|||
|
||||
assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# as seed manager is singleton
|
||||
# if we don't reset seeds here,
|
||||
# other tests will fail if running together with this test
|
||||
# as other tests can't overwrite the seed set by this test
|
||||
reset_seeds()
|
||||
|
|
Loading…
Reference in New Issue