diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py
index 4e62c77e9..5a5bee27b 100644
--- a/tests/test_moe/test_moe_zero_init.py
+++ b/tests/test_moe/test_moe_zero_init.py
@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from colossalai.nn import CheckpointModule
 from colossalai.logging import get_dist_logger
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
@@ -18,10 +19,10 @@ from colossalai.utils import get_current_device
 from tests.test_zero_data_parallel.common import CONFIG
 
 
-class MoeModel(nn.Module):
+class MoeModel(CheckpointModule):
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, checkpoint: bool = False):
+        super().__init__(checkpoint)
         self.proj1 = nn.Linear(4, 16)
         expert_cls = nn.Linear
         expert_args_dict = dict(in_features=16, out_features=16)
@@ -52,7 +53,7 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
                          shard_strategy=shard_strategy_class(),
                          shard_param=True,
                          model_numel_tensor=model_numel_tensor):
-        model = MoeModel()
+        model = MoeModel(checkpoint=True)
 
     for name, param in model.named_parameters():
         assert hasattr(param, 'colo_attr')
diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py
index d90294adf..a2268ea5c 100644
--- a/tests/test_moe/test_moe_zero_model.py
+++ b/tests/test_moe/test_moe_zero_model.py
@@ -31,7 +31,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
     with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                          shard_strategy=shard_strategy,
                          shard_param=True):
-        zero_model = MoeModel()
+        zero_model = MoeModel(checkpoint=True)
         zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
 
     # check whether parameters are identical in ddp
@@ -39,7 +39,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
         if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
             assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)
 
-    model = MoeModel().half()
+    model = MoeModel(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
     model = model.cuda()
     grad_handler = MoeGradientHandler(model)
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index 7956f86f4..8fbf655f8 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -65,7 +65,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
     with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,
                          shard_param=True):
-        zero_model = MoeModel()
+        zero_model = MoeModel(checkpoint=True)
 
     zero_model = ShardedModelV2(zero_model,
                                 shard_strategy,
@@ -78,7 +78,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
         if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
             assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))
 
-    model = MoeModel().half()
+    model = MoeModel(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
     model = model.cuda().float()
 
@@ -129,4 +129,4 @@ def test_moe_zero_optim(world_size):
 
 
 if __name__ == '__main__':
-    test_moe_zero_optim(world_size=2)
+    test_moe_zero_optim(world_size=4)
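
Note on the change above: swapping nn.Module for CheckpointModule is what makes the new checkpoint=True argument meaningful in these tests, since the base class can rerun the wrapped forward pass under activation checkpointing instead of keeping every intermediate activation alive. The sketch below illustrates that pattern with a plain torch.utils.checkpoint wrapper; it is a minimal sketch under that assumption, not ColossalAI's actual CheckpointModule implementation, and the MoE expert wiring from the test model is elided.

# Illustrative sketch only: a CheckpointModule-like wrapper built on
# torch.utils.checkpoint. ColossalAI's real CheckpointModule may differ.
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointModuleSketch(nn.Module):
    """Optionally recomputes the wrapped forward pass during backward."""

    def __init__(self, checkpoint: bool = False):
        super().__init__()
        self.use_checkpoint = checkpoint

    def _forward(self, x):
        # Subclasses put their real computation here.
        raise NotImplementedError

    def forward(self, x):
        if self.use_checkpoint and self.training:
            # Drop intermediate activations now and recompute them in backward,
            # trading extra compute for a smaller activation footprint.
            return checkpoint(self._forward, x)
        return self._forward(x)

Under that reading, MoeModel(checkpoint=True) at each call site exercises the recompute path together with the sharded ZeRO model, which is presumably why the flag is now passed everywhere the tests construct the model.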