mirror of https://github.com/hpcaitech/ColossalAI
[zero] add zero optimizer for ColoTensor (#1046)
* add zero optimizer * torch ok * unit test ok * polish code * fix bugs * polish unit test * polish zero optim * polish colo ddp v2 * refactor folder structure * add comment * polish unit test * polish zero optim * polish unit testpull/1052/head
parent
e32470b6de
commit
51b9a49655
|
@ -8,5 +8,6 @@ from .lars import Lars
|
|||
from .cpu_adam import CPUAdam
|
||||
from .hybrid_adam import HybridAdam
|
||||
|
||||
__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD',
|
||||
'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT']
|
||||
__all__ = [
|
||||
'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT'
|
||||
]
|
||||
|
|
|
@ -4,7 +4,8 @@ from colossalai.core import global_context as gpc
|
|||
from colossalai.context import ParallelMode
|
||||
from functools import partial
|
||||
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
||||
from colossalai.tensor import ChunkManager, use_param_op_hooks, TensorState
|
||||
from colossalai.tensor.chunk import ChunkManager, TensorState
|
||||
from colossalai.tensor.param_op_hook import use_param_op_hooks
|
||||
|
||||
__all__ = ['ColoDDP', 'ColoDDPV2']
|
||||
|
||||
|
@ -87,27 +88,23 @@ class ColoDDPV2(ColoDDP):
|
|||
self.chunk_manager = chunk_manager
|
||||
self.param_op_hook = ZeROHookV2(chunk_manager)
|
||||
self.fp32_params = []
|
||||
self.overflow_counter = 0
|
||||
# TODO: get param order and filter unused params
|
||||
for p in module.parameters():
|
||||
assert p.dtype == torch.half
|
||||
fp32_p = p.float()
|
||||
fp32_p = p.float().detach()
|
||||
self.chunk_manager.append_tensor(p, 'fp16_param')
|
||||
self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
|
||||
self.fp32_params.append(fp32_p)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
for p, fp32_p in zip(self.module.parameters(), self.fp32_params):
|
||||
if not self.chunk_manager.is_chunk_free(p):
|
||||
self.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)
|
||||
with use_param_op_hooks(self.param_op_hook):
|
||||
outputs = self.module(*args, **kwargs)
|
||||
self.chunk_manager.exec_lazy_release()
|
||||
return outputs
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
|
||||
loss.backward()
|
||||
def _post_backward(self):
|
||||
self.chunk_manager.exec_lazy_release()
|
||||
for p in self.module.parameters():
|
||||
if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
|
||||
|
@ -115,6 +112,16 @@ class ColoDDPV2(ColoDDP):
|
|||
else:
|
||||
p.grad = p.data
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
|
||||
loss.backward()
|
||||
self._post_backward()
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
|
||||
torch.autograd.backward(tensor, grad)
|
||||
self._post_backward()
|
||||
|
||||
def grad_handle(self, p, grad):
|
||||
empty_grad = torch.empty_like(grad)
|
||||
free_storage(empty_grad)
|
||||
|
@ -123,8 +130,11 @@ class ColoDDPV2(ColoDDP):
|
|||
if self.dp_world_size > 1:
|
||||
grad = grad / self.dp_world_size
|
||||
self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
|
||||
self.chunk_manager.reduce_chunk(p)
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
reduced = self.chunk_manager.reduce_chunk(p)
|
||||
self.chunk_manager.release_chunk(p)
|
||||
if reduced and not chunk.is_free:
|
||||
self.overflow_counter += chunk.has_inf_or_nan
|
||||
return empty_grad
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||
|
|
|
@ -153,6 +153,11 @@ class Chunk:
|
|||
def __repr__(self) -> str:
|
||||
return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_free}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
|
||||
|
||||
@property
|
||||
def has_inf_or_nan(self) -> bool:
|
||||
return torch.isinf(self.data[:self.utilized_size]).any().item() or \
|
||||
torch.isnan(self.data[:self.utilized_size]).any().item()
|
||||
|
||||
|
||||
class ChunkManager:
|
||||
|
||||
|
@ -230,11 +235,12 @@ class ChunkManager:
|
|||
chunk = self.tensor_chunk_map[tensor]
|
||||
chunk.tensor_trans_state(tensor, state)
|
||||
|
||||
def reduce_chunk(self, tensor: torch.Tensor) -> None:
|
||||
def reduce_chunk(self, tensor: torch.Tensor) -> bool:
|
||||
chunk = self.tensor_chunk_map[tensor]
|
||||
if not chunk.can_reduce:
|
||||
return
|
||||
return False
|
||||
chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
|
||||
return True
|
||||
|
||||
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
|
||||
chunk = self.tensor_chunk_map[tensor]
|
||||
|
|
|
@ -5,6 +5,7 @@ import torch.nn as nn
|
|||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
|
||||
from .zero_optimizer import ZeroOptimizer
|
||||
|
||||
|
||||
def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
|
||||
|
@ -35,4 +36,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
|
|||
return zero_model, zero_optimizer
|
||||
|
||||
|
||||
__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2']
|
||||
__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
from colossalai.tensor import ParamOpHook, ChunkManager, TensorState
|
||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||
from colossalai.tensor.chunk import ChunkManager, TensorState
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from contextlib import contextmanager
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from enum import Enum
|
||||
from torch.optim import Optimizer
|
||||
from colossalai.nn.parallel import ColoDDPV2
|
||||
from typing import Dict
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
SCALED = 0
|
||||
UNSCALED = 1
|
||||
|
||||
|
||||
class ZeroOptimizer(ColossalaiOptimizer):
|
||||
|
||||
def __init__(self,
|
||||
optim: Optimizer,
|
||||
module: ColoDDPV2,
|
||||
initial_scale: float = 2**32,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32):
|
||||
super().__init__(optim)
|
||||
assert isinstance(module, ColoDDPV2)
|
||||
self.module = module
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {}
|
||||
for p, fp32_p in zip(module.parameters(), module.fp32_params):
|
||||
self.fp16_param_to_fp32_param[p] = fp32_p
|
||||
|
||||
# Grad scaler
|
||||
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
def _update_params_ptr(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if not self.module.chunk_manager.is_chunk_free(p):
|
||||
p.data = self.fp16_param_to_fp32_param[p]
|
||||
else:
|
||||
assert p.grad is None
|
||||
|
||||
def _update_fp16_params(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if not self.module.chunk_manager.is_chunk_free(p):
|
||||
# TODO(ver217): copy chunk
|
||||
fp32_p = self.fp16_param_to_fp32_param[p]
|
||||
self.module.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)
|
||||
|
||||
def _check_overflow(self):
|
||||
# clear previous overflow record
|
||||
self._found_overflow.fill_(self.module.overflow_counter)
|
||||
|
||||
# all-reduce across global group
|
||||
dist.all_reduce(self._found_overflow)
|
||||
|
||||
return self._found_overflow.item() > 0
|
||||
|
||||
def _unscale_grads(self):
|
||||
assert self.optim_state == OptimState.SCALED
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is not None:
|
||||
p.grad.data.div_(self.loss_scale)
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
|
||||
@property
|
||||
def loss_scale(self):
|
||||
return self.grad_scaler.scale.item()
|
||||
|
||||
def zero_grad(self, *args, **kwargs):
|
||||
self.module.overflow_counter = 0
|
||||
return self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
# unscale grads if scaled
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
found_inf = self._check_overflow()
|
||||
self.grad_scaler.update(found_inf)
|
||||
if found_inf:
|
||||
self._logger.info(f'Found overflow. Skip step')
|
||||
self.zero_grad()
|
||||
self._update_fp16_params()
|
||||
return
|
||||
self._update_params_ptr()
|
||||
ret = self.optim.step(*args, **kwargs)
|
||||
self._update_fp16_params()
|
||||
return ret
|
||||
|
||||
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float):
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
return super().clip_grad_norm(model, max_norm)
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
loss = self.loss_scale * loss
|
||||
self.optim_state = OptimState.SCALED
|
||||
self.module.backward(loss)
|
||||
|
||||
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
|
||||
self.module.backward_by_grad(tensor, grad)
|
|
@ -0,0 +1,93 @@
|
|||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils import ColoInitContext
|
||||
from colossalai.tensor import ChunkManager
|
||||
from colossalai.core import global_context as gpc
|
||||
from functools import partial
|
||||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from colossalai.nn.parallel import ColoDDPV2
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
if p.storage().size() > 0:
|
||||
assert p.dtype == torch.half
|
||||
assert tensor_equal(torch_p, p), f'{torch_p} vs {p}'
|
||||
|
||||
|
||||
def run_step(model, criterion, optimizer, input_ids, attn_mask):
|
||||
optimizer.zero_grad()
|
||||
logits = model(input_ids, attn_mask)
|
||||
logits = logits.float()
|
||||
loss = criterion(logits, input_ids)
|
||||
optimizer.backward(loss)
|
||||
optimizer.step()
|
||||
return logits
|
||||
|
||||
|
||||
@parameterize('use_chunk', [False, True])
|
||||
@parameterize('use_zero', [False, True])
|
||||
def run_gpt(use_chunk, use_zero):
|
||||
set_seed(42)
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
model = model.cuda().half()
|
||||
torch_model = model_builder().cuda()
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p)
|
||||
|
||||
chunk_size = 38 * 1024**2 if use_chunk else None
|
||||
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
|
||||
model = ColoDDPV2(model, chunk_manager)
|
||||
optim = HybridAdam(model.parameters(), lr=1e-3)
|
||||
optim = ZeroOptimizer(optim, model, initial_scale=32)
|
||||
|
||||
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
|
||||
|
||||
# print(chunk_manager)
|
||||
check_param_equal(model, torch_model)
|
||||
model.train()
|
||||
torch_model.train()
|
||||
set_seed(gpc.get_local_rank(ParallelMode.DATA))
|
||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
logits = run_step(model, criterion, optim, input_ids, attn_mask)
|
||||
torch_logits = run_step(torch_model, criterion, torch_optim, input_ids, attn_mask)
|
||||
assert tensor_equal(logits, torch_logits)
|
||||
check_param_equal(model, torch_model)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_gpt()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gpt(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gpt(4)
|
Loading…
Reference in New Issue