diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index bded5084e..5ae36cd04 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -12,11 +12,12 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
+from colossalai.nn.optimizer import CPUAdam

 from common import CONFIG, check_sharded_params_padding


-def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
+def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     model.train()
     optimizer.zero_grad()
     with torch.cuda.amp.autocast(enabled=enable_autocast):
@@ -34,13 +35,17 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     optimizer.step()


-def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
+def _run_dist(rank, world_size, port, cpu_offload, shard_strategy, use_cpuadam):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = shard_strategy()
+
+    if use_cpuadam and cpu_offload is False:
+        return
+
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+        model, train_dataloader, _, optimizer_class, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model),
                                     shard_strategy,
@@ -48,33 +53,59 @@ def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
         if dist.get_world_size() > 1:
             model = DDP(model)
         lr = 1e-3
-        optim = optimizer_class(model.parameters(), lr=lr)
-        sharded_optim = ShardedOptimizerV2(zero_model,
-                                           optimizer_class,
-                                           cpu_offload=cpu_offload,
-                                           initial_scale=2**5,
-                                           lr=lr)
+        if use_cpuadam:
+            optim = torch.optim.Adam(model.parameters(), lr=lr)
+            sharded_optim = ShardedOptimizerV2(zero_model, CPUAdam, cpu_offload=cpu_offload, initial_scale=2**5, lr=lr)
+        else:
+            optim = optimizer_class(model.parameters(), lr=lr)
+            sharded_optim = ShardedOptimizerV2(zero_model,
+                                               optimizer_class,
+                                               cpu_offload=cpu_offload,
+                                               initial_scale=2**5,
+                                               lr=lr)
         for i, (data, label) in enumerate(train_dataloader):
-            if i > 2:
+            # FIXME: the unit test fails if more steps are run (i > 5 breaks it)
+            if i > 3:
                 break
             data, label = data.cuda(), label.cuda()
-            run_step(model, optim, data, label, criterion, False)
-            run_step(zero_model, sharded_optim, data, label, criterion, False)
+            _run_step(model, optim, data, label, criterion, False)
+            _run_step(zero_model, sharded_optim, data, label, criterion, False)
             check_sharded_params_padding(model, zero_model, loose=True)


+# use_cpuadam=True cannot be combined with cpu_offload=False (see the guard in _run_dist), so it is fixed to False here
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("cpu_offload", [False])
+@pytest.mark.parametrize("use_cpuadam", [False])
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy, use_cpuadam):
+    run_func = partial(_run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       cpu_offload=cpu_offload,
+                       shard_strategy=shard_strategy,
+                       use_cpuadam=use_cpuadam)
+    mp.spawn(run_func, nprocs=world_size)
+
+
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@pytest.mark.parametrize("cpu_offload", [True, False])
+@pytest.mark.parametrize("cpu_offload", [True])
 @pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
-def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
-    run_func = partial(run_dist,
+@pytest.mark.parametrize("use_cpuadam", [True, False])
+def test_sharded_optim_v2_cpu_adam(world_size, cpu_offload, shard_strategy, use_cpuadam):
+    run_func = partial(_run_dist,
                        world_size=world_size,
                        port=free_port(),
                        cpu_offload=cpu_offload,
-                       shard_strategy=shard_strategy)
+                       shard_strategy=shard_strategy,
+                       use_cpuadam=use_cpuadam)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)
\ No newline at end of file
+    test_sharded_optim_v2_cpu_adam(world_size=2,
+                                   cpu_offload=True,
+                                   shard_strategy=TensorShardStrategy,
+                                   use_cpuadam=True)
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py b/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
deleted file mode 100644
index 424ca3a65..000000000
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
+++ /dev/null
@@ -1,85 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import copy
-from functools import partial
-
-import colossalai
-import pytest
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-from colossalai.nn.optimizer import CPUAdam
-from colossalai.utils import free_port
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
-from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.sharded_optim import ShardedOptimizerV2
-from tests.components_to_test.registry import non_distributed_component_funcs
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import Adam
-
-from common import CONFIG, check_sharded_params_padding
-
-
-def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
-    model.train()
-    optimizer.zero_grad()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(data)
-        loss = criterion(y, label)
-    loss = loss.float()
-    if isinstance(model, ShardedModelV2):
-        optimizer.backward(loss)
-    else:
-        loss.backward()
-    optimizer.step()
-
-
-def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
-    model.train()
-    optimizer.zero_grad()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        loss = model(data, label)
-    if isinstance(model, ShardedModelV2):
-        optimizer.backward(loss)
-    else:
-        loss.backward()
-    optimizer.step()
-
-
-def run_dist(rank, world_size, port, shard_strategy):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    shard_strategy = shard_strategy()
-    for model_name in test_models:
-        get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
-        model = model(checkpoint=True).cuda()
-        zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy, offload_config={'device': 'cpu'})
-        if dist.get_world_size() > 1:
-            model = DDP(model)
-        optim = Adam(model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(zero_model, CPUAdam, initial_scale=2**5, cpu_offload=True, lr=1e-3)
-        for i, (data, label) in enumerate(train_dataloader):
-            if i > 2:
-                break
-            data, label = data.cuda(), label.cuda()
-            if criterion is None:
-                run_step_no_criterion(model, optim, data, label, False)
-                run_step_no_criterion(zero_model, sharded_optim, data, label, False)
-            else:
-                run_step(model, optim, data, label, criterion, False)
-                run_step(zero_model, sharded_optim, data, label, criterion, False)
-            check_sharded_params_padding(model, zero_model, loose=True)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
-def test_sharded_optim_v2(world_size, shard_strategy):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2, shard_strategy=TensorShardStrategy)
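
Note: the surviving test now covers the CPUAdam path that the deleted file used to exercise. For reference, a minimal sketch of that wiring, built only from calls visible in this diff; model, criterion, data, and label are hypothetical placeholders, and a ColossalAI distributed context is assumed to have been launched as in _run_dist:

    import copy
    from colossalai.nn.optimizer import CPUAdam
    from colossalai.zero.shard_utils import TensorShardStrategy
    from colossalai.zero.sharded_model import ShardedModelV2
    from colossalai.zero.sharded_optim import ShardedOptimizerV2

    # Shard the model and keep offloaded parameters on the CPU.
    zero_model = ShardedModelV2(copy.deepcopy(model), TensorShardStrategy(),
                                offload_config={'device': 'cpu'})
    # CPUAdam is only valid together with cpu_offload=True
    # (the guard in _run_dist skips the other combination).
    sharded_optim = ShardedOptimizerV2(zero_model, CPUAdam, cpu_offload=True,
                                       initial_scale=2**5, lr=1e-3)

    loss = criterion(zero_model(data), label)
    sharded_optim.backward(loss)    # the sharded optimizer owns backward (loss scaling)
    sharded_optim.step()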