From 3af13a2c3e917ce23d44050cfeea71cfa9f23e81 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 11 Mar 2022 14:40:01 +0800
Subject: [PATCH] [zero] polish ShardedOptimV2 unittest (#385)

* place params on cpu after zero init context

* polish code

* bucketzed cpu gpu tensor transter

* find a bug in sharded optim unittest

* add offload unittest for ShardedOptimV2.

* polish code and make it more robust
---
 .../zero/sharded_model/sharded_model_v2.py   |  4 ++
 .../zero/sharded_optim/sharded_optim_v2.py   |  4 ++
 .../test_sharded_optim_v2.py                 | 44 ++++++++-----------
 3 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 55e7b26f0..7510cb68e 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -79,6 +79,10 @@ class ShardedModelV2(nn.Module):
         self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
         self._require_backward_grad_sync: bool = True
 
+    @property
+    def cpu_offload(self):
+        return self._cpu_offload
+
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index d5fb44648..b9be80fed 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -44,6 +44,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         super().__init__(optimizer)
         self.shard_strategy = shard_strategy
         self.model: ShardedModelV2 = sharded_model
+        if cpu_offload and not sharded_model.cpu_offload:
+            raise RuntimeError(
+                f"ShardedOptimizerV2 using cpu_offload, but the sharded_model used to initialize it dose not use cpu_offload"
+            )
         self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index d9a003458..aa8735c26 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -24,8 +24,12 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     model.train()
     optimizer.zero_grad()
     with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(data)
-        loss = criterion(y, label)
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+
         loss = loss.float()
     if isinstance(model, ShardedModelV2):
         optimizer.backward(loss)
@@ -34,19 +38,7 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     optimizer.step()
 
 
-def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
-    model.train()
-    optimizer.zero_grad()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        loss = model(data, label)
-    if isinstance(model, ShardedModelV2):
-        optimizer.backward(loss)
-    else:
-        loss.backward()
-    optimizer.step()
-
-
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, cpu_offload):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     for model_name in test_models:
@@ -54,33 +46,33 @@ def run_dist(rank, world_size, port):
         shard_strategy = TensorShardStrategy()
         model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
-        zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
+        zero_model = ShardedModelV2(copy.deepcopy(model),
+                                    shard_strategy,
+                                    offload_config=dict(device='cpu') if cpu_offload else None)
         if dist.get_world_size() > 1:
             model = DDP(model)
         optim = Adam(model.parameters(), lr=1e-3)
         sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
                                            zero_model,
                                            shard_strategy,
+                                           cpu_offload=cpu_offload,
                                            initial_scale=2**5)
         for i, (data, label) in enumerate(train_dataloader):
             if i > 2:
                 break
             data, label = data.cuda(), label.cuda()
-            if criterion is None:
-                run_step_no_criterion(model, optim, data, label, False)
-                run_step_no_criterion(zero_model, sharded_optim, data, label, False)
-            else:
-                run_step(model, optim, data, label, criterion, False)
-                run_step(zero_model, sharded_optim, data, label, criterion, False)
+            run_step(model, optim, data, label, criterion, False)
+            run_step(zero_model, sharded_optim, data, label, criterion, False)
         check_sharded_params_padding(model, zero_model, loose=True)
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
-def test_sharded_optim_v2(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("cpu_offload", [True, False])
+def test_sharded_optim_v2(world_size, cpu_offload):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), cpu_offload=cpu_offload)
    mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2)
+    test_sharded_optim_v2(world_size=2, cpu_offload=True)
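
Usage note (not part of the patch): the construction path exercised by the updated unittest can be summarized as a minimal single-process sketch. The model side opts into offloading via offload_config=dict(device='cpu'), the optimizer side via cpu_offload=True, and the new check in ShardedOptimizerV2 raises a RuntimeError when only the optimizer requests offloading. The import paths, the empty launch config, and the nn.Linear toy model below are assumptions made for illustration; constructor signatures follow the diff above.

# Minimal sketch of the cpu_offload wiring introduced by this patch (assumed imports).
import copy

import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.utils import free_port                        # import path assumed
from colossalai.zero.shard_utils import TensorShardStrategy   # import path assumed
from colossalai.zero.sharded_model import ShardedModelV2      # import path assumed
from colossalai.zero.sharded_optim import ShardedOptimizerV2  # import path assumed


def main():
    # Single-process launch with an empty config (assumption for illustration).
    colossalai.launch(config=dict(), rank=0, world_size=1,
                      host='localhost', port=free_port(), backend='nccl')

    model = nn.Linear(8, 8).cuda()        # hypothetical toy model
    shard_strategy = TensorShardStrategy()

    # Model and optimizer both opt into CPU offloading: the new consistency check passes.
    zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy,
                                offload_config=dict(device='cpu'))
    sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
                                       zero_model, shard_strategy,
                                       cpu_offload=True, initial_scale=2**5)
    print(zero_model.cpu_offload)  # True, via the property added in this patch

    # Optimizer-only offloading is now rejected with a RuntimeError.
    plain_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
    try:
        ShardedOptimizerV2(Adam(plain_model.parameters(), lr=1e-3),
                           plain_model, shard_strategy, cpu_offload=True)
    except RuntimeError as exc:
        print(exc)


if __name__ == '__main__':
    main()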