Browse Source

[zero] update sharded optim v2 (#334)

pull/394/head
ver217 3 years ago committed by Frank Lee
parent
commit
d0ae0f2215
  1. 22
      colossalai/zero/sharded_model/sharded_model_v2.py
  2. 73
      colossalai/zero/sharded_optim/sharded_optim_v2.py
  3. 6
      tests/test_zero_data_parallel/common.py
  4. 7
      tests/test_zero_data_parallel/test_shard_model_v2.py
  5. 75
      tests/test_zero_data_parallel/test_sharded_optim_v2.py

22
colossalai/zero/sharded_model/sharded_model_v2.py

@ -102,6 +102,11 @@ class ShardedModelV2(nn.Module):
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
torch.cuda.current_stream().synchronize()
self.reducer.free()
# In case some post bwd hook is not fired
if self.shard_param:
for p in self.module.parameters():
if not p.col_attr.param_is_sharded:
self.shard_strategy.shard([p.col_attr.data])
for p in self.module.parameters():
p.col_attr.bwd_count = 0
if not p.requires_grad:
@ -113,13 +118,12 @@ class ShardedModelV2(nn.Module):
if not self._require_backward_grad_sync:
continue
# Write grad back to p.grad and set p.col_attr.grad to None
p.grad.data = p.col_attr.grad
# We have to make sure grad and param have the same shape
# If world size > 1, and sharded param, `.view()` may be not needed
# If world size == 1, and sharded param, `data` is a flatten tensor
# But the shape `grad` is the same as unsharded param
p.grad.data = p.col_attr.grad.view(p.col_attr.data.shape)
p.col_attr.grad = None
# In case some post bwd hook is not fired
if self.shard_param:
for p in self.module.parameters():
if not p.col_attr.param_is_sharded:
self.shard_strategy.shard([p.col_attr.data])
@torch.no_grad()
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@ -180,7 +184,11 @@ class ShardedModelV2(nn.Module):
if param.col_attr.grad is None:
param.col_attr.grad = reduced_grad.data
else:
param.col_attr.grad.add_(reduced_grad.data)
# When dp size = 1
# param.col_attr.grad is local accumulated grad shard (full but flatten)
# But reduced_grad here is full grad
# We should call `view_as`
param.col_attr.grad.add_(reduced_grad.data.view_as(param.col_attr.grad))
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])

73
colossalai/zero/sharded_optim/sharded_optim_v2.py

@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, Optional, Union
from typing import Dict, Optional
import torch
import torch.distributed as dist
@ -8,7 +8,9 @@ from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@ -26,7 +28,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def __init__(self,
optimizer: Optimizer,
sharded_model: Union[nn.Module, ShardedModelV2],
sharded_model: ShardedModelV2,
shard_strategy: BaseShardStrategy,
cpu_offload: bool = False,
initial_scale: float = 2**32,
min_scale: float = 1,
@ -37,9 +40,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
max_scale: int = 2**32,
dp_process_group: Optional[ProcessGroup] = None,
mp_process_group: Optional[ProcessGroup] = None) -> None:
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
super().__init__(optimizer)
self.model: Union[nn.Module, ShardedModelV2] = sharded_model
self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
self.shard_strategy = shard_strategy
self.model: ShardedModelV2 = sharded_model
self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
self.optim_state: OptimState = OptimState.UNSCALED
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
@ -52,20 +56,25 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.device)
self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
# Store fp32 params
# Store fp32 param shards
self.master_params: Dict[Parameter, Tensor] = {}
for group in optimizer.param_groups:
for p in group['params']:
if hasattr(p, 'ca_attr'):
assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
self.master_params[p] = p.ca_attr.payload(self.device)
else:
self.master_params[p] = p.data.to(device=self.device)
if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
self.master_params[p] = self.master_params[p].to(torch.float)
assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
is_param_sharded = p.col_attr.data.is_sharded
if not is_param_sharded:
# TODO (ver217): we may not use shard / gather here
# Param is no sharded, which means we use ZeRO-2 here
# As we only store param shard, we shard it here
self.shard_strategy.shard([p.col_attr.data])
self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
if not is_param_sharded:
# In this branch, there's no need to shard param
# So we gather here
self.shard_strategy.gather([p.col_attr.data])
def step(self, *args, **kwargs):
# unscale grads if scaled
@ -83,28 +92,36 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
for group in self.optim.param_groups:
for p in group['params']:
p.data = self.master_params[p]
# Now p.data is sharded
# So optimizer states are sharded naturally
ret = self.optim.step(*args, **kwargs)
# Write master param to payload
for group in self.optim.param_groups:
for p in group['params']:
if hasattr(p, 'ca_attr'):
p.ca_attr.set_payload(p.data)
p.data = p.ca_attr.payload()
is_param_sharded = p.col_attr.data.is_sharded
if not is_param_sharded:
# We use ZeRO-2 here
# The `p.col_attr.data` saves full fp16 param
# But we only have updated fp32 param shard here
# So we first shard full fp16 param and copy fp32 param shard to it
# Then we will gather them
self.shard_strategy.shard([p.col_attr.data])
# We have to use `copy_payload` instead of `reset_payload`
# Since p.data is fp32 and p.col_attr.data is fp16
p.col_attr.data.copy_payload(p.data)
if not is_param_sharded:
# We gather full fp16 param here
self.shard_strategy.gather([p.col_attr.data])
p.data = p.col_attr.data.payload
return ret
def backward(self, loss: Tensor) -> None:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
if self.model_is_sharded:
self.model.backward(loss)
else:
super().backward(loss)
self.model.backward(loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
if self.model_is_sharded:
self.model.backward_by_grad(tensor, grad)
else:
super().backward_by_grad(tensor, grad)
self.model.backward_by_grad(tensor, grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim_state == OptimState.SCALED:
@ -113,7 +130,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
@property
def loss_scale(self):
return self.grad_scaler.scale
return self.grad_scaler.scale.item()
def _check_overflow(self):
# clear previous overflow record
@ -141,3 +158,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if p.grad is not None:
p.grad.data.div_(self.loss_scale)
self.optim_state = OptimState.UNSCALED
def zero_grad(self, *args, **kwargs):
# We must set grad to None
# Because we will judge whether local grad accumulation
# is enabled by wheter grad is None
self.optim.zero_grad(set_to_none=True)

6
tests/test_zero_data_parallel/common.py

@ -95,12 +95,12 @@ def check_params_padding(model, zero_model, loose=False):
def check_sharded_params_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_p = zero_p.ca_attr.payload(p.device)
zero_p = zero_p.col_attr.data.payload.to(p.device).float()
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
p = chunks[rank]
p = chunks[rank].float()
if zero_p.size(0) > p.size(0):
zero_p = zero_p[:p.size(0)]
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose)
assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'

7
tests/test_zero_data_parallel/test_shard_model_v2.py

@ -17,7 +17,7 @@ from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_grads, check_grads_padding
from common import CONFIG, check_grads_padding
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
@ -69,10 +69,7 @@ def run_dist(rank, world_size, port):
run_fwd_bwd(model, data, label, criterion, False)
run_fwd_bwd(zero_model, data, label, criterion, False)
if dist.get_world_size() > 1:
check_grads_padding(model, zero_model, loose=True)
else:
check_grads(model, zero_model, loose=True)
check_grads_padding(model, zero_model, loose=True)
@pytest.mark.dist

75
tests/test_zero_data_parallel/test_sharded_optim_v2.py

@ -9,22 +9,23 @@ import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_sharded_params_padding)
from common import CONFIG, check_sharded_params_padding
def run_step(model, optimizer, x, enable_autocast=False):
def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
model.train()
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=enable_autocast):
y = model(x)
loss = y.sum()
y = model(data)
loss = criterion(y, label)
loss = loss.float()
if isinstance(model, ShardedModelV2):
optimizer.backward(loss)
@ -33,35 +34,53 @@ def run_step(model, optimizer, x, enable_autocast=False):
optimizer.step()
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
model.train()
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=enable_autocast):
loss = model(data, label)
if isinstance(model, ShardedModelV2):
optimizer.backward(loss)
else:
loss.backward()
optimizer.step()
model = Net(checkpoint=True).cuda()
zero_model = copy.deepcopy(model)
zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
for n, p in zero_model.named_parameters():
p._name = n
optim = Adam(model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3), zero_model)
for _ in range(2):
x = torch.rand(2, 5).cuda()
run_step(zero_model, sharded_optim, x, False)
run_step(model, optim, x, False)
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
shard_strategy = TensorShardStrategy()
model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
model = model(checkpoint=True).cuda()
zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
if dist.get_world_size() > 1:
check_grads_padding(model, zero_model)
check_sharded_params_padding(model, zero_model)
else:
check_grads(model, zero_model)
check_params(model, zero_model)
model = DDP(model)
optim = Adam(model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
zero_model,
shard_strategy,
initial_scale=2**5)
for i, (data, label) in enumerate(train_dataloader):
if i > 2:
break
data, label = data.cuda(), label.cuda()
if criterion is None:
run_step_no_criterion(model, optim, data, label, False)
run_step_no_criterion(zero_model, sharded_optim, data, label, False)
else:
run_step(model, optim, data, label, criterion, False)
run_step(zero_model, sharded_optim, data, label, criterion, False)
check_sharded_params_padding(model, zero_model, loose=True)
@pytest.mark.skip
def test_sharded_optim_v2():
world_size = 2
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4])
def test_sharded_optim_v2(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_sharded_optim_v2()
test_sharded_optim_v2(world_size=2)

Loading…
Cancel
Save