[zero] update sharded optim v2 (#334)

pull/394/head
ver217 2022-03-09 16:09:36 +08:00 committed by Frank Lee
parent 2b8cddd40e
commit d0ae0f2215
5 changed files with 115 additions and 68 deletions


@@ -102,6 +102,11 @@ class ShardedModelV2(nn.Module):
         # Wait for the non-blocking GPU -> CPU grad transfers to finish.
         torch.cuda.current_stream().synchronize()
         self.reducer.free()
+        # In case some post bwd hook is not fired
+        if self.shard_param:
+            for p in self.module.parameters():
+                if not p.col_attr.param_is_sharded:
+                    self.shard_strategy.shard([p.col_attr.data])
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
@@ -113,13 +118,12 @@ class ShardedModelV2(nn.Module):
             if not self._require_backward_grad_sync:
                 continue
             # Write grad back to p.grad and set p.col_attr.grad to None
-            p.grad.data = p.col_attr.grad
+            # We have to make sure grad and param have the same shape.
+            # If world size > 1 and the param is sharded, `.view()` may not be needed.
+            # If world size == 1 and the param is sharded, `data` is a flattened tensor,
+            # but the shape of `grad` is the same as the unsharded param.
+            p.grad.data = p.col_attr.grad.view(p.col_attr.data.shape)
             p.col_attr.grad = None
-        # In case some post bwd hook is not fired
-        if self.shard_param:
-            for p in self.module.parameters():
-                if not p.col_attr.param_is_sharded:
-                    self.shard_strategy.shard([p.col_attr.data])

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -180,7 +184,11 @@ class ShardedModelV2(nn.Module):
         if param.col_attr.grad is None:
             param.col_attr.grad = reduced_grad.data
         else:
-            param.col_attr.grad.add_(reduced_grad.data)
+            # When dp size = 1,
+            # param.col_attr.grad is the locally accumulated grad shard (full but flattened),
+            # but reduced_grad here is the full grad,
+            # so we have to call `view_as`.
+            param.col_attr.grad.add_(reduced_grad.data.view_as(param.col_attr.grad))

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
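The two shape fixes above are easier to see outside of ColossalAI. Below is a minimal plain-PyTorch sketch (all names are illustrative, not part of this commit): with a single data-parallel rank the accumulated grad is kept as a flattened buffer, so an incoming full-shaped grad has to be viewed as flat before add_(), and the buffer has to be viewed back to the param shape before it is written to p.grad.

import torch

param_shape = (4, 3)
flat_grad_buffer = torch.zeros(12)            # accumulated grad kept flattened (dp size = 1 case)
reduced_grad = torch.randn(param_shape)       # freshly reduced grad, still param-shaped

# Accumulation path: reshape the incoming full grad to match the flat buffer
flat_grad_buffer.add_(reduced_grad.view_as(flat_grad_buffer))

# Write-back path: the optimizer expects p.grad to match the param's shape
grad_for_optimizer = flat_grad_buffer.view(param_shape)
assert grad_for_optimizer.shape == param_shape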


@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, Optional, Union
+from typing import Dict, Optional

 import torch
 import torch.distributed as dist
@@ -8,7 +8,9 @@ from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -26,7 +28,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def __init__(self,
                  optimizer: Optimizer,
-                 sharded_model: Union[nn.Module, ShardedModelV2],
+                 sharded_model: ShardedModelV2,
+                 shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -37,9 +40,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
+        assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
         super().__init__(optimizer)
-        self.model: Union[nn.Module, ShardedModelV2] = sharded_model
-        self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
+        self.shard_strategy = shard_strategy
+        self.model: ShardedModelV2 = sharded_model
         self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
@@ -52,20 +56,25 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                                              growth_interval=growth_interval,
                                              hysteresis=hysteresis,
                                              max_scale=max_scale)
-        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.device)
+        self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())

-        # Store fp32 params
+        # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}

         for group in optimizer.param_groups:
             for p in group['params']:
-                if hasattr(p, 'ca_attr'):
-                    assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    self.master_params[p] = p.ca_attr.payload(self.device)
-                else:
-                    self.master_params[p] = p.data.to(device=self.device)
-                if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
-                    self.master_params[p] = self.master_params[p].to(torch.float)
+                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.col_attr.data.is_sharded
+                if not is_param_sharded:
+                    # TODO (ver217): we may not use shard / gather here
+                    # The param is not sharded, which means we use ZeRO-2 here.
+                    # As we only store the param shard, we shard it here.
+                    self.shard_strategy.shard([p.col_attr.data])
+                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
+                if not is_param_sharded:
+                    # In this branch, there's no need to keep the param sharded,
+                    # so we gather it back here.
+                    self.shard_strategy.gather([p.col_attr.data])

     def step(self, *args, **kwargs):
         # unscale grads if scaled
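The constructor above now keeps an fp32 master copy of each fp16 shard payload, temporarily sharding unsharded (ZeRO-2 style) params just to grab this rank's slice. A small stand-alone sketch of that pattern, assuming plain torch tensors in place of ColossalAI's payload and shard-strategy machinery:

import torch

# Sharded (ZeRO-3 style) case: the payload is already this rank's fp16 shard.
fp16_shard = torch.randn(8, dtype=torch.half)     # stands in for p.col_attr.data.payload
master_shard = fp16_shard.detach().float()        # fp32 master copy kept in self.master_params

# Unsharded (ZeRO-2 style) case: shard temporarily, copy this rank's chunk, gather back.
world_size, rank = 2, 0
full_fp16 = torch.randn(8, dtype=torch.half)
my_chunk = full_fp16.chunk(world_size)[rank]      # roughly what shard_strategy.shard keeps
master_shard_zero2 = my_chunk.detach().float()    # fp32 master of just this rank's slice
# shard_strategy.gather would then restore the full fp16 param for the module.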
@@ -83,28 +92,36 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 p.data = self.master_params[p]
+                # Now p.data is sharded,
+                # so optimizer states are sharded naturally.

         ret = self.optim.step(*args, **kwargs)

         # Write master param to payload
         for group in self.optim.param_groups:
             for p in group['params']:
-                if hasattr(p, 'ca_attr'):
-                    p.ca_attr.set_payload(p.data)
-                    p.data = p.ca_attr.payload()
+                is_param_sharded = p.col_attr.data.is_sharded
+                if not is_param_sharded:
+                    # We use ZeRO-2 here:
+                    # `p.col_attr.data` saves the full fp16 param,
+                    # but we only have the updated fp32 param shard here,
+                    # so we first shard the full fp16 param and copy the fp32 param shard into it,
+                    # then we gather them back.
+                    self.shard_strategy.shard([p.col_attr.data])
+                # We have to use `copy_payload` instead of `reset_payload`,
+                # since p.data is fp32 and p.col_attr.data is fp16.
+                p.col_attr.data.copy_payload(p.data)
+                if not is_param_sharded:
+                    # We gather the full fp16 param here
+                    self.shard_strategy.gather([p.col_attr.data])
+                p.data = p.col_attr.data.payload
         return ret

     def backward(self, loss: Tensor) -> None:
         loss = self.loss_scale * loss
         self.optim_state = OptimState.SCALED
-        if self.model_is_sharded:
-            self.model.backward(loss)
-        else:
-            super().backward(loss)
+        self.model.backward(loss)

     def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
-        if self.model_is_sharded:
-            self.model.backward_by_grad(tensor, grad)
-        else:
-            super().backward_by_grad(tensor, grad)
+        self.model.backward_by_grad(tensor, grad)

     def clip_grad_norm(self, model: nn.Module, max_norm: float):
         if self.optim_state == OptimState.SCALED:
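For the ZeRO-2 branch of step() above, the write-back is a round trip: shard the full fp16 param, copy the updated fp32 shard into that slice, then gather back to the full fp16 param. A toy single-process sketch, with torch.chunk and torch.cat standing in for the shard strategy (illustrative names, not the actual API):

import torch

world_size, rank = 2, 1
full_fp16 = torch.zeros(8, dtype=torch.half)     # full fp16 param kept under ZeRO-2
updated_fp32_shard = torch.randn(4)              # fp32 shard produced by optimizer.step()

# "shard": split the full param and take this rank's slice
chunks = [c.clone() for c in full_fp16.chunk(world_size)]

# copy_payload analogue: copy the fp32 shard into the fp16 slice (dtype cast included)
chunks[rank].copy_(updated_fp32_shard)

# "gather": reassemble the full fp16 param so the module sees it whole again
full_fp16 = torch.cat(chunks)
assert full_fp16.dtype == torch.half and full_fp16.numel() == 8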
@@ -113,7 +130,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     @property
     def loss_scale(self):
-        return self.grad_scaler.scale
+        return self.grad_scaler.scale.item()

     def _check_overflow(self):
         # clear previous overflow record
@@ -141,3 +158,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 if p.grad is not None:
                     p.grad.data.div_(self.loss_scale)
         self.optim_state = OptimState.UNSCALED
+
+    def zero_grad(self, *args, **kwargs):
+        # We must set grad to None,
+        # because we judge whether local grad accumulation
+        # is enabled by whether grad is None.
+        self.optim.zero_grad(set_to_none=True)
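The new zero_grad override ties into the accumulation check in ShardedModelV2._grad_post_backward_hook above: a grad of None means "start fresh", anything else means "accumulate locally". A minimal sketch of that convention in plain torch (illustrative names only):

import torch

stored_grad = None                     # stands in for param.col_attr.grad after zero_grad()
reduced = torch.randn(4)               # grad arriving from the reduce step

if stored_grad is None:
    stored_grad = reduced.clone()      # fresh iteration: adopt the reduced grad
else:
    stored_grad.add_(reduced.view_as(stored_grad))   # local grad accumulation

# If zero_grad() filled grads with zeros instead of setting them to None, every
# step would take the accumulation branch, which is why set_to_none=True is used.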


@@ -95,12 +95,12 @@ def check_params_padding(model, zero_model, loose=False):
 def check_sharded_params_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.ca_attr.payload(p.device)
+        zero_p = zero_p.col_attr.data.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
-        p = chunks[rank]
+        p = chunks[rank].float()
         if zero_p.size(0) > p.size(0):
             zero_p = zero_p[:p.size(0)]
         assert p.dtype == zero_p.dtype
-        assert allclose(p, zero_p, loose=loose)
+        assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
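The padding that check_sharded_params_padding trims comes from chunking a flattened param across ranks: the last chunk can be shorter than the shard payload. A quick numeric illustration in plain torch:

import math
import torch

world_size = 4
p = torch.arange(10.0)                        # 10 elements split across 4 ranks
chunks = torch.flatten(p).chunk(world_size)
print([c.numel() for c in chunks])            # [3, 3, 3, 1]: the last rank's chunk is short

shard_size = math.ceil(p.numel() / world_size)
# Shard payloads are padded up to shard_size (3 here), which is why the test
# trims zero_p to the chunk length before comparing.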


@@ -17,7 +17,7 @@ from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

-from common import CONFIG, check_grads, check_grads_padding
+from common import CONFIG, check_grads_padding


 def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
@@ -69,10 +69,7 @@ def run_dist(rank, world_size, port):
         run_fwd_bwd(model, data, label, criterion, False)
         run_fwd_bwd(zero_model, data, label, criterion, False)
-        if dist.get_world_size() > 1:
-            check_grads_padding(model, zero_model, loose=True)
-        else:
-            check_grads(model, zero_model, loose=True)
+        check_grads_padding(model, zero_model, loose=True)


 @pytest.mark.dist


@@ -9,22 +9,23 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
+from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
+from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Adam

-from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_sharded_params_padding)
+from common import CONFIG, check_sharded_params_padding


-def run_step(model, optimizer, x, enable_autocast=False):
+def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     model.train()
     optimizer.zero_grad()
     with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(x)
-        loss = y.sum()
+        y = model(data)
+        loss = criterion(y, label)
     loss = loss.float()
     if isinstance(model, ShardedModelV2):
         optimizer.backward(loss)
@@ -33,35 +34,53 @@ def run_step(model, optimizer, x, enable_autocast=False):
     optimizer.step()


+def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
+    model.train()
+    optimizer.zero_grad()
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        loss = model(data, label)
+    if isinstance(model, ShardedModelV2):
+        optimizer.backward(loss)
+    else:
+        loss.backward()
+    optimizer.step()
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    model = Net(checkpoint=True).cuda()
-    zero_model = copy.deepcopy(model)
-    zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
-    for n, p in zero_model.named_parameters():
-        p._name = n
-    optim = Adam(model.parameters(), lr=1e-3)
-    sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3), zero_model)
-    for _ in range(2):
-        x = torch.rand(2, 5).cuda()
-        run_step(zero_model, sharded_optim, x, False)
-        run_step(model, optim, x, False)
-        if dist.get_world_size() > 1:
-            check_grads_padding(model, zero_model)
-            check_sharded_params_padding(model, zero_model)
-        else:
-            check_grads(model, zero_model)
-            check_params(model, zero_model)
+    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    for model_name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(model_name)
+        shard_strategy = TensorShardStrategy()
+        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model = model(checkpoint=True).cuda()
+        zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
+        if dist.get_world_size() > 1:
+            model = DDP(model)
+        optim = Adam(model.parameters(), lr=1e-3)
+        sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
+                                           zero_model,
+                                           shard_strategy,
+                                           initial_scale=2**5)
+        for i, (data, label) in enumerate(train_dataloader):
+            if i > 2:
+                break
+            data, label = data.cuda(), label.cuda()
+            if criterion is None:
+                run_step_no_criterion(model, optim, data, label, False)
+                run_step_no_criterion(zero_model, sharded_optim, data, label, False)
+            else:
+                run_step(model, optim, data, label, criterion, False)
+                run_step(zero_model, sharded_optim, data, label, criterion, False)
+            check_sharded_params_padding(model, zero_model, loose=True)


-@pytest.mark.skip
-def test_sharded_optim_v2():
-    world_size = 2
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 2, 4])
+def test_sharded_optim_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_sharded_optim_v2()
+    test_sharded_optim_v2(world_size=2)
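Putting the pieces together, the updated API is used roughly as in the test above. The condensed sketch below is based only on the signatures visible in this diff; it assumes a CUDA device and a distributed context already initialised via colossalai.launch, and the tiny Linear model and toy loss are placeholders:

import copy
import torch
import torch.nn as nn
from torch.optim import Adam
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2

# Assumes colossalai.launch(...) has already set up the process group.
model = nn.Linear(8, 2).cuda()                       # placeholder module
shard_strategy = TensorShardStrategy()
zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
                                   zero_model,
                                   shard_strategy,
                                   initial_scale=2**5)

data = torch.randn(4, 8).cuda()
sharded_optim.zero_grad()
loss = zero_model(data).sum().float()                # toy loss, mirroring run_step above
sharded_optim.backward(loss)
sharded_optim.step()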