[zero] add zero optimizer for ColoTensor (#1046)

* add zero optimizer

* torch ok

* unit test ok

* polish code

* fix bugs

* polish unit test

* polish zero optim

* polish colo ddp v2

* refactor folder structure

* add comment

* polish unit test

* polish zero optim

* polish unit test
ver217 2022-06-02 12:13:15 +08:00 committed by GitHub
parent e32470b6de
commit 51b9a49655
7 changed files with 245 additions and 15 deletions


@@ -8,5 +8,6 @@ from .lars import Lars
 from .cpu_adam import CPUAdam
 from .hybrid_adam import HybridAdam

-__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD',
-           'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT']
+__all__ = [
+    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT'
+]


@@ -4,7 +4,8 @@ from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
-from colossalai.tensor import ChunkManager, use_param_op_hooks, TensorState
+from colossalai.tensor.chunk import ChunkManager, TensorState
+from colossalai.tensor.param_op_hook import use_param_op_hooks

 __all__ = ['ColoDDP', 'ColoDDPV2']
@@ -87,27 +88,23 @@ class ColoDDPV2(ColoDDP):
         self.chunk_manager = chunk_manager
         self.param_op_hook = ZeROHookV2(chunk_manager)
         self.fp32_params = []
+        self.overflow_counter = 0
         # TODO: get param order and filter unused params
         for p in module.parameters():
             assert p.dtype == torch.half
-            fp32_p = p.float()
+            fp32_p = p.float().detach()
             self.chunk_manager.append_tensor(p, 'fp16_param')
             self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
             self.fp32_params.append(fp32_p)

     def forward(self, *args, **kwargs):
         self.module.zero_grad(set_to_none=True)
-        for p, fp32_p in zip(self.module.parameters(), self.fp32_params):
-            if not self.chunk_manager.is_chunk_free(p):
-                self.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)
         with use_param_op_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
         return outputs

-    def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
-            loss.backward()
+    def _post_backward(self):
         self.chunk_manager.exec_lazy_release()
         for p in self.module.parameters():
             if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
@@ -115,6 +112,16 @@ class ColoDDPV2(ColoDDP):
             else:
                 p.grad = p.data

+    def backward(self, loss: torch.Tensor):
+        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+            loss.backward()
+        self._post_backward()
+
+    def backward_by_grad(self, tensor, grad):
+        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+            torch.autograd.backward(tensor, grad)
+        self._post_backward()
+
     def grad_handle(self, p, grad):
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
@@ -123,8 +130,11 @@ class ColoDDPV2(ColoDDP):
             if self.dp_world_size > 1:
                 grad = grad / self.dp_world_size
             self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
-            self.chunk_manager.reduce_chunk(p)
+            chunk = self.chunk_manager.get_chunk(p)
+            reduced = self.chunk_manager.reduce_chunk(p)
             self.chunk_manager.release_chunk(p)
+            if reduced and not chunk.is_free:
+                self.overflow_counter += chunk.has_inf_or_nan
         return empty_grad

     def zero_grad(self, set_to_none: bool = False) -> None:
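Note: the new backward_by_grad entry point starts backpropagation from an arbitrary output tensor with an externally supplied upstream gradient instead of a scalar loss. A minimal, standalone illustration of the torch.autograd.backward(tensor, grad) call it wraps (plain PyTorch only; the toy tensors are not from this commit):

import torch

x = torch.ones(3, requires_grad=True)
y = x * 2                                  # intermediate output, not a scalar loss

# Supply dL/dy explicitly; this is the call ColoDDPV2.backward_by_grad wraps
# with its param-op hooks before running _post_backward().
torch.autograd.backward(y, torch.ones(3))

print(x.grad)                              # tensor([2., 2., 2.]) == 2 * dL/dy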


@@ -153,6 +153,11 @@ class Chunk:

     def __repr__(self) -> str:
         return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_free}, tensor states={[info.state.name for info in self.tensors_info.values()]}'

+    @property
+    def has_inf_or_nan(self) -> bool:
+        return torch.isinf(self.data[:self.utilized_size]).any().item() or \
+            torch.isnan(self.data[:self.utilized_size]).any().item()
+

 class ChunkManager:
@@ -230,11 +235,12 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)

-    def reduce_chunk(self, tensor: torch.Tensor) -> None:
+    def reduce_chunk(self, tensor: torch.Tensor) -> bool:
         chunk = self.tensor_chunk_map[tensor]
         if not chunk.can_reduce:
-            return
+            return False
         chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
+        return True

     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
         chunk = self.tensor_chunk_map[tensor]
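Note: overflow detection now lives on the chunk itself. A small self-contained illustration of the check Chunk.has_inf_or_nan performs: only the utilized slice of the flat chunk buffer is scanned, so garbage in the unused tail cannot trigger a false positive (the buffer and sizes below are made up for the example):

import torch

data = torch.empty(8)                      # flat chunk buffer; tail beyond index 5 is uninitialized
data[:5] = torch.tensor([1.0, 2.0, float('inf'), 4.0, 5.0])
utilized_size = 5

has_inf_or_nan = torch.isinf(data[:utilized_size]).any().item() or \
    torch.isnan(data[:utilized_size]).any().item()
print(has_inf_or_nan)                      # True: the inf inside the utilized slice is caught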


@@ -5,6 +5,7 @@ import torch.nn as nn
 from colossalai.logging import get_dist_logger
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
+from .zero_optimizer import ZeroOptimizer


 def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
@@ -35,4 +36,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
     return zero_model, zero_optimizer

-__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2']
+__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']


@@ -1,5 +1,6 @@
 import torch
-from colossalai.tensor import ParamOpHook, ChunkManager, TensorState
+from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.chunk import ChunkManager, TensorState
 from enum import Enum
 from typing import List
 from contextlib import contextmanager


@@ -0,0 +1,118 @@
import torch
import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from colossalai.nn.parallel import ColoDDPV2
from typing import Dict
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer


class OptimState(Enum):
    SCALED = 0
    UNSCALED = 1


class ZeroOptimizer(ColossalaiOptimizer):

    def __init__(self,
                 optim: Optimizer,
                 module: ColoDDPV2,
                 initial_scale: float = 2**32,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32):
        super().__init__(optim)
        assert isinstance(module, ColoDDPV2)
        self.module = module
        self.optim_state = OptimState.UNSCALED
        self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {}
        for p, fp32_p in zip(module.parameters(), module.fp32_params):
            self.fp16_param_to_fp32_param[p] = fp32_p

        # Grad scaler
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale)
        self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
        self._logger = get_dist_logger()

    def _update_params_ptr(self):
        for group in self.optim.param_groups:
            for p in group['params']:
                if not self.module.chunk_manager.is_chunk_free(p):
                    p.data = self.fp16_param_to_fp32_param[p]
                else:
                    assert p.grad is None

    def _update_fp16_params(self):
        for group in self.optim.param_groups:
            for p in group['params']:
                if not self.module.chunk_manager.is_chunk_free(p):
                    # TODO(ver217): copy chunk
                    fp32_p = self.fp16_param_to_fp32_param[p]
                    self.module.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)

    def _check_overflow(self):
        # clear previous overflow record
        self._found_overflow.fill_(self.module.overflow_counter)

        # all-reduce across global group
        dist.all_reduce(self._found_overflow)

        return self._found_overflow.item() > 0

    def _unscale_grads(self):
        assert self.optim_state == OptimState.SCALED
        for group in self.optim.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.data.div_(self.loss_scale)
        self.optim_state = OptimState.UNSCALED

    @property
    def loss_scale(self):
        return self.grad_scaler.scale.item()

    def zero_grad(self, *args, **kwargs):
        self.module.overflow_counter = 0
        return self.optim.zero_grad(set_to_none=True)

    def step(self, *args, **kwargs):
        # unscale grads if scaled
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()
        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)
        if found_inf:
            self._logger.info(f'Found overflow. Skip step')
            self.zero_grad()
            self._update_fp16_params()
            return
        self._update_params_ptr()
        ret = self.optim.step(*args, **kwargs)
        self._update_fp16_params()
        return ret

    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float):
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()
        return super().clip_grad_norm(model, max_norm)

    def backward(self, loss: torch.Tensor):
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        self.module.backward(loss)

    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
        self.module.backward_by_grad(tensor, grad)
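Note: intended usage mirrors the unit test added below. A rough sketch (it assumes an already launched Colossal-AI context, as in the test's run_dist; build_model, loader and criterion are placeholders, not part of this commit):

from colossalai.tensor import ChunkManager
from colossalai.nn.parallel import ColoDDPV2
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer

model = build_model().cuda().half()                                   # ColoDDPV2 asserts fp16 params
chunk_manager = ChunkManager(None, enable_distributed_storage=False)  # None: no fixed chunk size
model = ColoDDPV2(model, chunk_manager)                               # keeps fp32 copies of params in chunks
optim = ZeroOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=32)

for data, label in loader:
    optim.zero_grad()                          # also resets module.overflow_counter
    loss = criterion(model(data), label)
    optim.backward(loss)                       # scales the loss, then runs ColoDDPV2.backward
    optim.step()                               # unscale, overflow check, fp32 update, copy back to fp16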


@@ -0,0 +1,93 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.tensor import ChunkManager
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDPV2
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp


def check_param_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        if p.storage().size() > 0:
            assert p.dtype == torch.half
            assert tensor_equal(torch_p, p), f'{torch_p} vs {p}'


def run_step(model, criterion, optimizer, input_ids, attn_mask):
    optimizer.zero_grad()
    logits = model(input_ids, attn_mask)
    logits = logits.float()
    loss = criterion(logits, input_ids)
    optimizer.backward(loss)
    optimizer.step()
    return logits


@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
def run_gpt(use_chunk, use_zero):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda().half()
    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)

    chunk_size = 38 * 1024**2 if use_chunk else None
    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
    model = ColoDDPV2(model, chunk_manager)
    optim = HybridAdam(model.parameters(), lr=1e-3)
    optim = ZeroOptimizer(optim, model, initial_scale=32)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))

    # print(chunk_manager)
    check_param_equal(model, torch_model)
    model.train()
    torch_model.train()
    set_seed(gpc.get_local_rank(ParallelMode.DATA))
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        if i > 2:
            break
        logits = run_step(model, criterion, optim, input_ids, attn_mask)
        torch_logits = run_step(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert tensor_equal(logits, torch_logits)
        check_param_equal(model, torch_model)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_gpt()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_gpt(4)