From d0ae0f2215750a158dd66d4428e0ed6b5f5141a9 Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 9 Mar 2022 16:09:36 +0800
Subject: [PATCH] [zero] update sharded optim v2 (#334)

---
 .../zero/sharded_model/sharded_model_v2.py |  22 ++++--
 .../zero/sharded_optim/sharded_optim_v2.py |  73 +++++++++++-------
 tests/test_zero_data_parallel/common.py    |   6 +-
 .../test_shard_model_v2.py                 |   7 +-
 .../test_sharded_optim_v2.py               |  75 ++++++++++++-------
 5 files changed, 115 insertions(+), 68 deletions(-)

diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 3531488db..f6c5e10f5 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -102,6 +102,11 @@ class ShardedModelV2(nn.Module):
         # Wait for the non-blocking GPU -> CPU grad transfers to finish.
         torch.cuda.current_stream().synchronize()
         self.reducer.free()
+        # In case some post bwd hook is not fired
+        if self.shard_param:
+            for p in self.module.parameters():
+                if not p.col_attr.param_is_sharded:
+                    self.shard_strategy.shard([p.col_attr.data])
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
@@ -113,13 +118,12 @@ class ShardedModelV2(nn.Module):
             if not self._require_backward_grad_sync:
                 continue
             # Write grad back to p.grad and set p.col_attr.grad to None
-            p.grad.data = p.col_attr.grad
+            # We have to make sure grad and param have the same shape.
+            # If world size > 1 and the param is sharded, `.view()` may not be needed.
+            # If world size == 1 and the param is sharded, `data` is a flattened tensor,
+            # but `grad` has the same shape as the unsharded param.
+            p.grad.data = p.col_attr.grad.view(p.col_attr.data.shape)
             p.col_attr.grad = None
-            # In case some post bwd hook is not fired
-            if self.shard_param:
-                for p in self.module.parameters():
-                    if not p.col_attr.param_is_sharded:
-                        self.shard_strategy.shard([p.col_attr.data])

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -180,7 +184,11 @@ class ShardedModelV2(nn.Module):
             if param.col_attr.grad is None:
                 param.col_attr.grad = reduced_grad.data
             else:
-                param.col_attr.grad.add_(reduced_grad.data)
+                # When dp size = 1,
+                # param.col_attr.grad is the locally accumulated grad shard (full but flattened),
+                # while reduced_grad here is the full-shape grad,
+                # so we have to call `view_as`.
+                param.col_attr.grad.add_(reduced_grad.data.view_as(param.col_attr.grad))

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
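
Note on the gradient shape handling above: ShardedModelV2 keeps gradients as flat buffers, while autograd expects `p.grad` to match the parameter's shape, hence the `.view()` on write-back and the `view_as` on accumulation. A minimal sketch of that shape bookkeeping, using plain tensors as stand-ins for `p.col_attr.grad` and `p.col_attr.data` (not the ShardedModelV2 code path itself):

import torch

# Stand-ins for a fp16 parameter and for the flattened full grad that the
# sharded model keeps in p.col_attr.grad when the world size is 1.
param = torch.randn(4, 8, dtype=torch.float16)
flat_grad_buffer = torch.randn(param.numel(), dtype=torch.float16)

# p.grad must have the same shape as the parameter, so the flat buffer is
# reshaped before being handed back (the `.view(p.col_attr.data.shape)` above).
grad_for_param = flat_grad_buffer.view(param.shape)
assert grad_for_param.shape == param.shape

# Accumulating a freshly reduced, full-shape grad into the flat buffer is the
# mirror case and needs `view_as` (the `add_(...view_as(...))` above).
reduced_grad = torch.randn_like(param)
flat_grad_buffer.add_(reduced_grad.view_as(flat_grad_buffer))
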
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index ac8b60033..36330c5f6 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, Optional, Union
+from typing import Dict, Optional

 import torch
 import torch.distributed as dist
@@ -8,7 +8,9 @@ from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -26,7 +28,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

     def __init__(self,
                  optimizer: Optimizer,
-                 sharded_model: Union[nn.Module, ShardedModelV2],
+                 sharded_model: ShardedModelV2,
+                 shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -37,9 +40,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
+        assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModelV2'
         super().__init__(optimizer)
-        self.model: Union[nn.Module, ShardedModelV2] = sharded_model
-        self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
+        self.shard_strategy = shard_strategy
+        self.model: ShardedModelV2 = sharded_model
         self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
@@ -52,20 +56,25 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                                              growth_interval=growth_interval,
                                              hysteresis=hysteresis,
                                              max_scale=max_scale)
-        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.device)
+        self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())

-        # Store fp32 params
+        # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}
         for group in optimizer.param_groups:
             for p in group['params']:
-                if hasattr(p, 'ca_attr'):
-                    assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    self.master_params[p] = p.ca_attr.payload(self.device)
-                else:
-                    self.master_params[p] = p.data.to(device=self.device)
-                if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
-                    self.master_params[p] = self.master_params[p].to(torch.float)
+                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.col_attr.data.is_sharded
+                if not is_param_sharded:
+                    # TODO (ver217): we may not use shard / gather here
+                    # The param is not sharded, which means we use ZeRO-2 here.
+                    # As we only store param shards, we shard it here.
+                    self.shard_strategy.shard([p.col_attr.data])
+                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
+                if not is_param_sharded:
+                    # In this branch, there's no need to keep the param sharded,
+                    # so we gather it back here.
+                    self.shard_strategy.gather([p.col_attr.data])

     def step(self, *args, **kwargs):
         # unscale grads if scaled
@@ -83,28 +92,36 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 p.data = self.master_params[p]
+                # Now p.data is sharded,
+                # so optimizer states are sharded naturally.

         ret = self.optim.step(*args, **kwargs)

         # Write master param to payload
         for group in self.optim.param_groups:
             for p in group['params']:
-                if hasattr(p, 'ca_attr'):
-                    p.ca_attr.set_payload(p.data)
-                    p.data = p.ca_attr.payload()
+                is_param_sharded = p.col_attr.data.is_sharded
+                if not is_param_sharded:
+                    # We use ZeRO-2 here:
+                    # `p.col_attr.data` holds the full fp16 param,
+                    # but we only have the updated fp32 param shard here,
+                    # so we first shard the full fp16 param, copy the fp32 shard into it,
+                    # and then gather it back.
+                    self.shard_strategy.shard([p.col_attr.data])
+                # We have to use `copy_payload` instead of `reset_payload`,
+                # since p.data is fp32 and p.col_attr.data is fp16.
+                p.col_attr.data.copy_payload(p.data)
+                if not is_param_sharded:
+                    # We gather the full fp16 param here.
+                    self.shard_strategy.gather([p.col_attr.data])
+                p.data = p.col_attr.data.payload
         return ret

     def backward(self, loss: Tensor) -> None:
         loss = self.loss_scale * loss
         self.optim_state = OptimState.SCALED
-        if self.model_is_sharded:
-            self.model.backward(loss)
-        else:
-            super().backward(loss)
+        self.model.backward(loss)

     def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
-        if self.model_is_sharded:
-            self.model.backward_by_grad(tensor, grad)
-        else:
-            super().backward_by_grad(tensor, grad)
+        self.model.backward_by_grad(tensor, grad)

     def clip_grad_norm(self, model: nn.Module, max_norm: float):
         if self.optim_state == OptimState.SCALED:
@@ -113,7 +130,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

     @property
     def loss_scale(self):
-        return self.grad_scaler.scale
+        return self.grad_scaler.scale.item()

     def _check_overflow(self):
         # clear previous overflow record
@@ -141,3 +158,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             if p.grad is not None:
                 p.grad.data.div_(self.loss_scale)
         self.optim_state = OptimState.UNSCALED
+
+    def zero_grad(self, *args, **kwargs):
+        # We must set grads to None,
+        # because we judge whether local grad accumulation
+        # is enabled by whether grad is None.
+        self.optim.zero_grad(set_to_none=True)
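
With the new signature, callers pass the ShardedModelV2 instance and the shard strategy explicitly, and the optimizer keeps fp32 master shards internally. A usage sketch mirroring the updated test further below; it assumes `colossalai.launch()` has already initialized the distributed context, and `build_zero` plus its `model` argument are placeholders:

import torch.nn as nn
from torch.optim import Adam

from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2


def build_zero(model: nn.Module):
    shard_strategy = TensorShardStrategy()
    # Wrap the module first; the optimizer now asserts it receives a ShardedModelV2.
    zero_model = ShardedModelV2(model.cuda(), shard_strategy)
    # The same strategy is handed to the optimizer so it can shard/gather the
    # fp16 payloads around step() while holding fp32 master shards.
    sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
                                       zero_model,
                                       shard_strategy,
                                       initial_scale=2**5)
    return zero_model, sharded_optim
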
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index ff5bc5902..b4677f06f 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -95,12 +95,12 @@ def check_params_padding(model, zero_model, loose=False):
 def check_sharded_params_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.ca_attr.payload(p.device)
+        zero_p = zero_p.col_attr.data.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
-        p = chunks[rank]
+        p = chunks[rank].float()
         if zero_p.size(0) > p.size(0):
             zero_p = zero_p[:p.size(0)]
         assert p.dtype == zero_p.dtype
-        assert allclose(p, zero_p, loose=loose)
+        assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py
index af54adfaf..919b74ed3 100644
--- a/tests/test_zero_data_parallel/test_shard_model_v2.py
+++ b/tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -17,7 +17,7 @@ from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

-from common import CONFIG, check_grads, check_grads_padding
+from common import CONFIG, check_grads_padding


 def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
@@ -69,10 +69,7 @@ def run_dist(rank, world_size, port):

         run_fwd_bwd(model, data, label, criterion, False)
         run_fwd_bwd(zero_model, data, label, criterion, False)
-        if dist.get_world_size() > 1:
-            check_grads_padding(model, zero_model, loose=True)
-        else:
-            check_grads(model, zero_model, loose=True)
+        check_grads_padding(model, zero_model, loose=True)


 @pytest.mark.dist
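
For reference, the comparison that `check_sharded_params_padding` performs can be sketched with plain tensors; the real helper reads `dist.get_rank()` and the `col_attr` payload, and `compare_shard` plus its tolerances here are illustrative only:

import torch


def compare_shard(full_param: torch.Tensor, shard_payload: torch.Tensor,
                  rank: int, world_size: int) -> bool:
    # Flatten and chunk the reference param the same way the shard strategy
    # splits it across ranks.
    chunks = torch.flatten(full_param).chunk(world_size)
    if rank >= len(chunks):
        return True
    ref = chunks[rank].float()
    shard = shard_payload.float()
    # The last shard may carry zero padding, so truncate before comparing.
    if shard.size(0) > ref.size(0):
        shard = shard[:ref.size(0)]
    return torch.allclose(ref, shard, rtol=1e-3, atol=1e-3)
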
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index 6f80e2dd3..d9a003458 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -9,22 +9,23 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
+from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
+from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Adam

-from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_sharded_params_padding)
+from common import CONFIG, check_sharded_params_padding


-def run_step(model, optimizer, x, enable_autocast=False):
+def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     model.train()
     optimizer.zero_grad()
     with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(x)
-        loss = y.sum()
+        y = model(data)
+        loss = criterion(y, label)
         loss = loss.float()
         if isinstance(model, ShardedModelV2):
             optimizer.backward(loss)
@@ -33,35 +34,53 @@
     optimizer.step()


-def run_dist(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
+    model.train()
+    optimizer.zero_grad()
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        loss = model(data, label)
+    if isinstance(model, ShardedModelV2):
+        optimizer.backward(loss)
+    else:
+        loss.backward()
+    optimizer.step()

-    model = Net(checkpoint=True).cuda()
-    zero_model = copy.deepcopy(model)
-    zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
-    for n, p in zero_model.named_parameters():
-        p._name = n
-    optim = Adam(model.parameters(), lr=1e-3)
-    sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3), zero_model)
-    for _ in range(2):
-        x = torch.rand(2, 5).cuda()
-        run_step(zero_model, sharded_optim, x, False)
-        run_step(model, optim, x, False)
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    for model_name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(model_name)
+        shard_strategy = TensorShardStrategy()
+        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model = model(checkpoint=True).cuda()
+        zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
         if dist.get_world_size() > 1:
-            check_grads_padding(model, zero_model)
-            check_sharded_params_padding(model, zero_model)
-        else:
-            check_grads(model, zero_model)
-            check_params(model, zero_model)
+            model = DDP(model)
+        optim = Adam(model.parameters(), lr=1e-3)
+        sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
+                                           zero_model,
+                                           shard_strategy,
+                                           initial_scale=2**5)
+        for i, (data, label) in enumerate(train_dataloader):
+            if i > 2:
+                break
+            data, label = data.cuda(), label.cuda()
+            if criterion is None:
+                run_step_no_criterion(model, optim, data, label, False)
+                run_step_no_criterion(zero_model, sharded_optim, data, label, False)
+            else:
+                run_step(model, optim, data, label, criterion, False)
+                run_step(zero_model, sharded_optim, data, label, criterion, False)
+            check_sharded_params_padding(model, zero_model, loose=True)


-@pytest.mark.skip
-def test_sharded_optim_v2():
-    world_size = 2
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 2, 4])
+def test_sharded_optim_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_sharded_optim_v2()
+    test_sharded_optim_v2(world_size=2)
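
A single training step against the wrapped model, mirroring `run_step` above; `zero_model` and `sharded_optim` are assumed to come from a setup like `run_dist`, and `data`, `label`, `criterion`, and `train_step` itself are placeholders:

import torch


def train_step(zero_model, sharded_optim, data, label, criterion):
    zero_model.train()
    # zero_grad() sets grads to None; a non-None grad would signal that local
    # gradient accumulation is still in progress.
    sharded_optim.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(zero_model(data), label).float()
    # Loss scaling and the sharded backward both live behind the optimizer.
    sharded_optim.backward(loss)
    sharded_optim.step()
    return loss
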