mirror of https://github.com/hpcaitech/ColossalAI
fixed bug in activation checkpointing test (#387)
parent 3af13a2c3e
commit 1e4bf85cdb
colossalai/context/random/__init__.py
@@ -1,9 +1,7 @@
-from ._helper import (seed, set_mode, with_seed, add_seed,
-                      get_seeds, get_states, get_current_mode,
-                      set_seed_states, sync_states, moe_set_seed)
+from ._helper import (seed, set_mode, with_seed, add_seed, get_seeds, get_states, get_current_mode, set_seed_states,
+                      sync_states, moe_set_seed, reset_seeds)
 
 __all__ = [
-    'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds',
-    'get_states', 'get_current_mode', 'set_seed_states', 'sync_states',
-    'moe_set_seed'
+    'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states',
+    'sync_states', 'moe_set_seed', 'reset_seeds'
 ]
colossalai/context/random/_helper.py
@@ -154,4 +154,9 @@ def moe_set_seed(seed):
         global_rank = gpc.get_global_rank()
         add_seed(ParallelMode.TENSOR, global_rank, True)
         print(f"moe seed condition: {global_rank} with moe seed {moe_mp_seed}, ",
-              f"tensor seed {global_rank}", flush=True)
+              f"tensor seed {global_rank}",
+              flush=True)
+
+
+def reset_seeds():
+    _SEED_MANAGER.reset()
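
The helper added above works because the random module holds a single process-wide manager instance. A minimal runnable sketch of that pattern, with internals abbreviated (only _SEED_MANAGER, reset_seeds and the fields cleared by reset() are names taken from this diff):

    class SeedManager:
        # abbreviated stand-in for colossalai's seed manager
        def __init__(self):
            self.reset()

        def reset(self):
            # same fields the reset() added in seed_manager.py clears
            self._current_mode = None
            self._seeds = dict()
            self._seed_states = dict()

    _SEED_MANAGER = SeedManager()   # constructed once at import time

    def reset_seeds():
        _SEED_MANAGER.reset()

Because every importer shares the one _SEED_MANAGER object, a test that registers seeds and then calls reset_seeds() leaves no state behind for the next test.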
colossalai/context/random/seed_manager.py
@@ -66,8 +66,7 @@ class SeedManager:
         :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
             :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
         """
-        assert isinstance(
-            parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
+        assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
         if overwrtie is False:
             assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added'
         elif parallel_mode in self._seed_states:
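
The assert above enforces add_seed's contract: re-registering a mode fails unless the overwrite flag (spelled overwrtie in this file) is passed, the way moe_set_seed does. A hedged usage sketch, requiring colossalai and a CUDA device; the seed values are illustrative:

    from colossalai.context.parallel_mode import ParallelMode
    from colossalai.context.random import add_seed

    add_seed(ParallelMode.DATA, 1026)        # first registration succeeds
    # add_seed(ParallelMode.DATA, 1027)      # would raise: 'The seed for ... has been added'
    add_seed(ParallelMode.DATA, 1027, True)  # third argument allows overwriting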
@@ -78,3 +77,8 @@ class SeedManager:
         self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
         self._seeds[parallel_mode] = seed
         torch.cuda.set_rng_state(current_state)
+
+    def reset(self):
+        self._current_mode = None
+        self._seeds = dict()
+        self._seed_states = dict()
tests/test_utils/test_activation_checkpointing.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.context.random import add_seed, seed, set_mode
+from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
 from colossalai.utils import checkpoint
 
 
@@ -17,12 +17,12 @@ def forward(x, weight):
     out_ = F.dropout(out, p=0.4, training=True)
     return out_
 
 
 @pytest.mark.gpu
 @pytest.mark.parametrize("cpu_offload", [True, False])
 def test_activation_checkpointing(cpu_offload):
-    if cpu_offload:
-        add_seed(ParallelMode.GLOBAL, 1024)
-        add_seed(ParallelMode.DATA, 1026)
+    add_seed(ParallelMode.GLOBAL, 1024)
+    add_seed(ParallelMode.DATA, 1026)
+
     set_mode(ParallelMode.GLOBAL)
     global_cuda_rng_state = torch.cuda.get_rng_state()
     set_mode(ParallelMode.DATA)
@@ -56,4 +56,8 @@ def test_activation_checkpointing(cpu_offload):
 
     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()
+
+    # as seed manager is singleton
+    # if we don't reset seeds here,
+    # other tests will fail if running together with this test
+    # as other tests can't overwrite the seed set by this test
+    reset_seeds()
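
The comment block above is the core of the fix: the seed manager is a singleton, so seeds registered by the first parametrized run persist into the second and trip the 'has been added' assert. The commit clears state with a trailing reset_seeds() call; a hedged alternative would be a pytest teardown fixture, sketched below (the fixture name is illustrative, not part of this commit):

    import pytest
    from colossalai.context.random import reset_seeds

    @pytest.fixture(autouse=True)
    def clean_seed_state():
        # run the test body first, then clear the singleton's registry,
        # even if the test raised
        yield
        reset_seeds()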