diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index 07bac1511..6721dc8b8 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -1,10 +1,8 @@
-import imp
 from functools import partial
 
 import torch
 import torch.distributed as dist
 from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import CPUAdam
 from colossalai.utils import checkpoint
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
@@ -20,23 +18,22 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           use_memory_tracer=False,
                           shard_strategy=TensorShardStrategy)
 
-_ZERO_OPTIMIZER_CONFIG = dict(
-    cpu_offload=False,
-    initial_scale=2**5,
-    min_scale=1,
-    growth_factor=2,
-    backoff_factor=0.5,
-    growth_interval=1000,
-    hysteresis=2,
-    max_scale=2**32,
-    lr=1e-3)
+_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
+                              initial_scale=2**5,
+                              min_scale=1,
+                              growth_factor=2,
+                              backoff_factor=0.5,
+                              growth_interval=1000,
+                              hysteresis=2,
+                              max_scale=2**32,
+                              lr=1e-3)
 
 ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
                             zero=dict(
                                 model_config=_ZERO_MODEL_CONFIG,
                                 optimizer_config=_ZERO_OPTIMIZER_CONFIG,
-),
-                            parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
+                            ),
+                            parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
 
 CONFIG = dict(fp16=dict(mode=None,),
               zero=dict(level=3,
diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py
index 2cf22c063..cab8de7d6 100644
--- a/tests/test_zero_data_parallel/test_shard_model_v2.py
+++ b/tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -10,8 +10,7 @@ import torch.multiprocessing as mp
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy,
-                                         TensorShardStrategy)
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
@@ -22,10 +21,10 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
 
 
 @parameterize("enable_autocast", [True])
-@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_model_test(enable_autocast, shard_strategy):
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_model_test(enable_autocast, shard_strategy_class):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    shard_strategy = shard_strategy()
+    shard_strategy = shard_strategy_class()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index aa552b062..7382d879a 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -9,8 +9,7 @@ from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy,
-                                         TensorShardStrategy)
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
@@ -41,10 +40,10 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
 
 
 @parameterize("cpu_offload", [True, False])
 @parameterize("use_cpuadam", [True, False])
-@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
-def _run_test_sharded_optim_v2(cpu_offload, shard_strategy, use_cpuadam):
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    shard_strategy = shard_strategy()
+    shard_strategy = shard_strategy_class()
     if use_cpuadam and cpu_offload is False:
         return
diff --git a/tests/test_zero_data_parallel/test_state_dict.py b/tests/test_zero_data_parallel/test_state_dict.py
index d434a53e5..091f711d1 100644
--- a/tests/test_zero_data_parallel/test_state_dict.py
+++ b/tests/test_zero_data_parallel/test_state_dict.py
@@ -8,20 +8,21 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+from colossalai.testing import parameterize
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import parameterize
+
 from common import CONFIG
 
 
-@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_zero_state_dict(shard_strategy):
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_zero_state_dict(shard_strategy_class):
     test_models = ['repeated_computed_layers', 'resnet18']
-    shard_strategy = shard_strategy()
+    shard_strategy = shard_strategy_class()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()