2022-05-21 05:52:04 +00:00
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from colossalai.context import ParallelMode
|
|
|
|
from functools import partial
|
2022-05-31 04:00:12 +00:00
|
|
|
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
2022-06-13 08:11:53 +00:00
|
|
|
from colossalai.tensor.chunk import TensorState, Chunk
|
|
|
|
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
2022-06-10 06:48:28 +00:00
|
|
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
|
|
|
from typing import Dict
|
|
|
|
from colossalai.logging import get_dist_logger
|
2022-05-21 05:52:04 +00:00
|
|
|
|
|
|
|
|
|
|
|
def free_storage(data: torch.Tensor) -> None:
    """Release the memory backing ``data`` by shrinking its storage to zero.

    The tensor object itself survives (same shape/dtype metadata) but no longer
    owns any elements. Tensors whose storage is already empty are left alone.

    Raises:
        AssertionError: if ``data`` does not start at offset 0 of its storage,
            i.e. it may be a view sharing the storage with other tensors.
    """
    storage = data.storage()
    if storage.size() == 0:
        # Nothing allocated; nothing to free.
        return
    # The storage is resized directly, so ``data`` must be its sole occupant.
    assert data.storage_offset() == 0
    storage.resize_(0)
|
|
|
|
|
|
|
|
|
2022-06-15 07:57:04 +00:00
|
|
|
def _cast_float(args, dtype: torch.dtype):
    """Recursively cast every floating-point tensor inside ``args`` to ``dtype``.

    Lists, tuples (including namedtuples) and dicts are traversed recursively.
    Sequence containers keep their type; dicts are rebuilt as plain ``dict``.
    Non-floating tensors and any other objects are returned unchanged.

    Args:
        args: a tensor or an arbitrarily nested list/tuple/dict structure.
        dtype: target floating-point dtype.

    Returns:
        A structure of the same shape with floating tensors converted.
    """
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        args = args.to(dtype)
    elif isinstance(args, tuple) and hasattr(args, '_fields'):
        # namedtuple constructors take their fields positionally, not as an
        # iterable, so ``type(args)(generator)`` would raise TypeError here.
        args = type(args)(*(_cast_float(t, dtype) for t in args))
    elif isinstance(args, (list, tuple)):
        args = type(args)(_cast_float(t, dtype) for t in args)
    elif isinstance(args, dict):
        args = {k: _cast_float(v, dtype) for k, v in args.items()}
    return args
|
|
|
|
|
|
|
|
|
2022-05-21 05:52:04 +00:00
|
|
|
class ColoDDP(torch.nn.Module):
    """Data-parallel wrapper that reduces gradients from per-parameter autograd hooks.

    :meth:`grad_handle` is registered as a tensor hook on every parameter that
    requires grad. For CUDA gradients it averages the gradient over the DATA
    parallel group on ``self.comm_stream`` and accumulates it into
    ``p._saved_grad``, handing autograd a storage-less placeholder instead;
    :meth:`backward` then installs ``p._saved_grad`` as ``p.grad``. For CPU
    gradients it averages over the CPU DATA group and returns the result to
    autograd directly.

    Args:
        module: the model to wrap.
    """

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module
        # Side stream on which gradient all-reduces are issued.
        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
        for p in module.parameters():
            if p.requires_grad:
                p.register_hook(partial(self.grad_handle, p))

    def parameters(self, recurse: bool = True):
        """Return the wrapped module's parameters."""
        return self.module.parameters(recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        """Return the wrapped module's named parameters (no ``module.`` prefix)."""
        return self.module.named_parameters(prefix, recurse)

    def forward(self, *args, **kwargs):
        """Clear ``p.grad`` on the wrapped module, then run it.

        ``p._saved_grad`` buffers are not touched here; they keep accumulating
        until :meth:`zero_grad` is called.
        """
        self.module.zero_grad(set_to_none=True)
        return self.module(*args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Backpropagate ``loss`` and install reduced CUDA gradients as ``p.grad``."""
        loss.backward()
        # All-reduces were queued on comm_stream; wait before reading their results.
        torch.cuda.current_stream().wait_stream(self.comm_stream)
        for p in self.module.parameters():
            # Frozen params, and params unused in this forward pass, have no grad;
            # dereferencing ``p.grad.device`` on them would raise AttributeError.
            if p.grad is None:
                continue
            if p.grad.device.type != "cpu":
                p.grad = p._saved_grad

    def grad_handle(self, p, grad):
        """Autograd hook for ``p``: average ``grad`` across the DATA group.

        Args:
            p: the parameter the hook was registered on.
            grad: the gradient autograd computed for ``p``.

        Returns:
            For CUDA gradients, a same-shaped placeholder whose storage has been
            freed (the real gradient is accumulated into ``p._saved_grad``); for
            CPU gradients, the averaged gradient itself.
        """
        if grad.device.type != "cpu":
            empty_grad = torch.empty_like(grad)
            free_storage(empty_grad)
            if self.dp_world_size > 1:
                # Pre-divide so the all-reduce sum yields the group average.
                grad = grad / self.dp_world_size
                self.comm_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.comm_stream):
                    group = gpc.get_group(ParallelMode.DATA)
                    dist.all_reduce(grad, group=group)
                    ColoDDP._save_grad(p, grad)
                # grad was allocated on the current stream but is used on
                # comm_stream; stop the caching allocator from reusing it early.
                grad.record_stream(self.comm_stream)
            else:
                ColoDDP._save_grad(p, grad)
            return empty_grad
        # CPU path: average exactly like the CUDA path (it previously summed),
        # and skip the collective when there is only one rank.
        if self.dp_world_size > 1:
            grad = grad / self.dp_world_size
            group = gpc.get_cpu_group(ParallelMode.DATA)
            dist.all_reduce(grad, group=group)
        return grad

    @staticmethod
    def _save_grad(p, grad):
        """Accumulate ``grad`` into ``p._saved_grad``, creating it if absent or ``None``."""
        # zero_grad(set_to_none=True) leaves the attribute present but set to None,
        # so a plain hasattr() check would end up calling ``None.add_``.
        if getattr(p, '_saved_grad', None) is not None:
            p._saved_grad.add_(grad)
        else:
            p._saved_grad = grad

    def zero_grad(self, set_to_none: bool = False) -> None:
        """Reset gradients of the wrapped module.

        ``p.grad`` is always set to ``None``. ``p._saved_grad`` is set to ``None``
        when ``set_to_none`` is true; otherwise it is detached from any autograd
        graph and zeroed in place.
        """
        self.module.zero_grad(set_to_none=True)
        for p in self.module.parameters():
            if getattr(p, '_saved_grad', None) is not None:
                if set_to_none:
                    p._saved_grad = None
                else:
                    if p._saved_grad.grad_fn is not None:
                        p._saved_grad.detach_()
                    else:
                        p._saved_grad.requires_grad_(False)
                    p._saved_grad.zero_()
|
2022-05-31 04:00:12 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ColoDDPV2(ColoDDP):
|
|
|
|
|
2022-06-10 06:48:28 +00:00
|
|
|
def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
|
2022-06-15 07:57:04 +00:00
|
|
|
super().__init__(module.half())
|
2022-06-10 06:48:28 +00:00
|
|
|
self.gemini_manager = gemini_manager
|
|
|
|
self.chunk_manager = gemini_manager.chunk_manager
|
|
|
|
self.param_op_hook = ZeROHookV2(gemini_manager)
|
2022-05-31 04:00:12 +00:00
|
|
|
self.fp32_params = []
|
2022-06-02 04:13:15 +00:00
|
|
|
self.overflow_counter = 0
|
2022-06-10 06:48:28 +00:00
|
|
|
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
2022-06-15 07:05:19 +00:00
|
|
|
self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
|
|
|
|
self.chunk_manager.create_group('fp32_param')
|
2022-05-31 04:00:12 +00:00
|
|
|
# TODO: get param order and filter unused params
|
|
|
|
for p in module.parameters():
|
|
|
|
assert p.dtype == torch.half
|
2022-06-02 04:13:15 +00:00
|
|
|
fp32_p = p.float().detach()
|
2022-05-31 04:00:12 +00:00
|
|
|
self.chunk_manager.append_tensor(p, 'fp16_param')
|
|
|
|
self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
|
|
|
|
self.fp32_params.append(fp32_p)
|
2022-06-10 06:48:28 +00:00
|
|
|
self.grads_device[p] = self.gemini_manager.default_device
|
|
|
|
self._logger = get_dist_logger()
|
2022-05-31 04:00:12 +00:00
|
|
|
|
|
|
|
def forward(self, *args, **kwargs):
|
2022-06-15 07:57:04 +00:00
|
|
|
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
2022-05-31 04:00:12 +00:00
|
|
|
self.module.zero_grad(set_to_none=True)
|
2022-06-10 06:48:28 +00:00
|
|
|
self.gemini_manager.pre_iter()
|
2022-06-13 08:11:53 +00:00
|
|
|
with ParamOpHookManager.use_hooks(self.param_op_hook):
|
2022-05-31 04:00:12 +00:00
|
|
|
outputs = self.module(*args, **kwargs)
|
|
|
|
self.chunk_manager.exec_lazy_release()
|
2022-06-15 07:57:04 +00:00
|
|
|
return _cast_float(outputs, torch.float)
|
2022-05-31 04:00:12 +00:00
|
|
|
|
2022-06-10 06:48:28 +00:00
|
|
|
def _setup_grads_ptr(self):
|
2022-05-31 04:00:12 +00:00
|
|
|
for p in self.module.parameters():
|
2022-06-10 07:33:06 +00:00
|
|
|
if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad:
|
2022-05-31 04:00:12 +00:00
|
|
|
p.grad = None
|
|
|
|
else:
|
|
|
|
p.grad = p.data
|
|
|
|
|
2022-06-10 06:48:28 +00:00
|
|
|
def _post_backward(self):
|
|
|
|
self.chunk_manager.exec_lazy_release()
|
|
|
|
self._setup_grads_ptr()
|
|
|
|
self._logger.info(
|
|
|
|
f'layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, PCIE move vol: {self.gemini_manager._cpu_gpu_move_volume}B'
|
|
|
|
)
|
|
|
|
self.gemini_manager.post_iter()
|
|
|
|
|
2022-06-02 04:13:15 +00:00
|
|
|
def backward(self, loss: torch.Tensor):
|
2022-06-13 08:11:53 +00:00
|
|
|
with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
|
2022-06-02 04:13:15 +00:00
|
|
|
loss.backward()
|
|
|
|
self._post_backward()
|
|
|
|
|
|
|
|
def backward_by_grad(self, tensor, grad):
|
2022-06-13 08:11:53 +00:00
|
|
|
with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
|
2022-06-02 04:13:15 +00:00
|
|
|
torch.autograd.backward(tensor, grad)
|
|
|
|
self._post_backward()
|
|
|
|
|
2022-05-31 04:00:12 +00:00
|
|
|
    def grad_handle(self, p, grad):
        """Autograd hook for ``p``: write its gradient into the chunk and reduce it.

        Args:
            p: the fp16 parameter whose gradient was just produced.
            grad: the gradient autograd passed to the hook.

        Returns:
            A tensor shaped like ``grad`` whose storage has been freed. Autograd
            receives this placeholder; the actual gradient is written into the
            chunk owned by ``self.chunk_manager``.
        """
        # Storage-less placeholder handed back to autograd in place of the real grad.
        empty_grad = torch.empty_like(grad)
        free_storage(empty_grad)
        # NOTE(review): presumably disables __torch_function__ overrides so the
        # chunk operations below act on plain tensors — confirm.
        with torch._C.DisableTorchFunction():
            self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
            if self.dp_world_size > 1:
                # Pre-divide; assuming reduce_chunk sums across ranks, this yields the average.
                grad = grad / self.dp_world_size
            self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
            chunk = self.chunk_manager.get_chunk(p)
            # NOTE(review): reduce_chunk presumably returns False until every tensor
            # in the chunk is ready for reduction — confirm against ChunkManager.
            reduced = self.chunk_manager.reduce_chunk(chunk)
            self.chunk_manager.release_chunk(chunk)
            if reduced and not chunk.is_empty:
                # Accumulate the chunk's inf/NaN flag into the overflow counter.
                self.overflow_counter += chunk.has_inf_or_nan
                # Move the reduced chunk to the device recorded for this param's grad.
                self.chunk_manager.move_chunk(chunk, self.grads_device[p])
        return empty_grad
|
|
|
|
|
|
|
|
def zero_grad(self, set_to_none: bool = False) -> None:
|
|
|
|
self.module.zero_grad(set_to_none=True)
|
2022-06-10 06:48:28 +00:00
|
|
|
|
|
|
|
def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
|
|
|
|
for tensor in chunk.get_tensors():
|
|
|
|
self.grads_device[tensor] = device
|