2022-11-02 08:11:34 +00:00
|
|
|
from collections import OrderedDict
|
2022-05-21 05:52:04 +00:00
|
|
|
from functools import partial
|
2023-04-04 05:48:16 +00:00
|
|
|
from typing import Iterable, Optional, Set
|
2022-11-02 08:11:34 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
2022-07-04 10:54:37 +00:00
|
|
|
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
2023-04-04 05:48:16 +00:00
|
|
|
from colossalai.utils import is_ddp_ignored
|
2022-08-02 07:49:13 +00:00
|
|
|
|
2022-11-02 08:11:34 +00:00
|
|
|
from .reducer import Reducer
|
2022-05-21 05:52:04 +00:00
|
|
|
|
|
|
|
|
|
|
|
def free_storage(data: torch.Tensor) -> None:
    """Release the memory backing ``data`` by shrinking its storage to zero elements.

    The tensor object itself (shape/dtype metadata) stays alive, but its storage is
    resized to 0, so its elements must not be read afterwards. A tensor whose
    storage is already empty is left untouched.

    Args:
        data (torch.Tensor): Tensor whose storage should be released.
    """
    storage = data.storage()
    if storage.size() == 0:
        return
    # The storage is resized in place, so the tensor must at least start at the
    # beginning of it (it must not be a view into the middle of a larger buffer).
    assert data.storage_offset() == 0
    storage.resize_(0)
|
|
|
|
|
|
|
|
|
2022-06-15 07:57:04 +00:00
|
|
|
def _cast_float(args, dtype: torch.dtype):
|
|
|
|
if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
|
|
|
|
args = args.to(dtype)
|
|
|
|
elif isinstance(args, (list, tuple)):
|
|
|
|
args = type(args)(_cast_float(t, dtype) for t in args)
|
|
|
|
elif isinstance(args, dict):
|
|
|
|
args = {k: _cast_float(v, dtype) for k, v in args.items()}
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
2022-05-21 05:52:04 +00:00
|
|
|
class ColoDDP(torch.nn.Module):
    """Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now.

    A backward hook is registered on every trainable, non-ignored parameter.
    CUDA gradients are divided by the data parallel world size and all-reduced
    asynchronously on a dedicated CUDA stream through a bucketing ``Reducer``. The
    reduced result is stored in ``p._saved_grad`` and copied into ``p.grad`` by
    :meth:`backward`. CPU gradients are averaged and all-reduced synchronously
    inside the hook. Use :meth:`backward` instead of ``loss.backward()`` so that
    pending reductions are flushed and ``p.grad`` is populated.

    Example:
        >>> from colossalai.tensor import ProcessGroup
        >>> model = torch.nn.Linear(20, 1)
        >>> pg = ProcessGroup(tp_degree = world_size//2)
        >>> model = ColoDDP(model, pg)
        >>> logits = model(x)
        >>> loss = criterion(logits, labels)
        >>> model.backward(loss)

    Args:
        module (torch.nn.Module): Module to apply DDP. Must not already be a ColoDDP.
        process_group (ColoProcessGroup): The process group whose data parallel
            sub-group is used to reduce gradients. Required.
        bucket_cap_mb (int, optional): Bucket capacity passed to the ``Reducer``
            (presumably in MB). Defaults to 25.
        rebuild_bucket (bool, optional): Whether to call ``Reducer.free()`` after
            every backward pass. Defaults to True.
    """

    def __init__(self,
                 module: torch.nn.Module,
                 process_group: ColoProcessGroup,
                 bucket_cap_mb: int = 25,
                 rebuild_bucket: bool = True) -> None:
        assert not isinstance(module, ColoDDP)
        super().__init__()
        self.module = module
        # Separate stream so gradient communication can overlap with backward compute.
        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        assert process_group

        self.process_group = process_group
        self.dp_world_size = self.process_group.dp_world_size()

        self.reducer = Reducer(bucket_cap_mb)
        self.rebuild_bucket = rebuild_bucket
        for p in module.parameters():
            if is_ddp_ignored(p):
                continue
            if p.requires_grad:
                p.register_hook(partial(self.grad_handle, p))

    # The accessors below delegate to the wrapped module, so names and state dict
    # keys carry no extra ``module.`` prefix.

    def parameters(self, recurse: bool = True):
        return self.module.parameters(recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        return self.module.named_parameters(prefix, recurse)

    def named_buffers(self, prefix: str = '', recurse: bool = True):
        return self.module.named_buffers(prefix, recurse)

    def named_children(self):
        return self.module.named_children()

    def named_modules(self,
                      memo: Optional[Set[torch.nn.Module]] = None,
                      prefix: str = '',
                      remove_duplicate: bool = True):
        return self.module.named_modules(memo, prefix, remove_duplicate)

    def forward(self, *args, **kwargs):
        """Run the wrapped module after setting every ``p.grad`` to ``None``.

        ``p._saved_grad`` is not touched here, so saved gradients keep accumulating
        across iterations until :meth:`zero_grad` is called.
        """
        self.module.zero_grad(set_to_none=True)
        return self.module(*args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Back-propagate ``loss``, finish pending reductions and publish gradients.

        Args:
            loss (torch.Tensor): The loss to back-propagate.
        """
        loss.backward()
        # Flush partially filled buckets on the communication stream, then make the
        # compute stream wait for it so the reduced gradients are ready below.
        with torch.cuda.stream(self.comm_stream):
            self.reducer.flush()
        torch.cuda.current_stream().wait_stream(self.comm_stream)
        if self.rebuild_bucket:
            self.reducer.free()
        for p in self.module.parameters():
            if is_ddp_ignored(p):
                continue
            # Frozen parameters, and parameters that did not take part in this
            # forward pass, have no gradient; skip them instead of failing on None.
            if p.grad is None:
                continue
            # For CUDA parameters the hook left a storage-less placeholder in
            # ``p.grad``; swap in the reduced gradient. CPU hooks already returned
            # the reduced gradient, which autograd stored in ``p.grad``.
            if p.grad.device.type != "cpu":
                p.grad = p._saved_grad

    def grad_handle(self, p, grad):
        """Backward hook for ``p`` that reduces ``grad`` across the data parallel group.

        Args:
            p (torch.nn.Parameter): The parameter the gradient belongs to.
            grad (torch.Tensor): The gradient autograd computed for ``p``.

        Returns:
            torch.Tensor: For CUDA gradients, a tensor shaped like ``grad`` whose
            storage has been freed (the real gradient is delivered to
            ``p._saved_grad``). For CPU gradients, the averaged, all-reduced gradient.
        """
        if grad.device.type != "cpu":
            # Autograd stores this storage-less tensor in ``p.grad``; presumably this
            # avoids keeping a second full copy next to ``p._saved_grad``.
            empty_grad = torch.empty_like(grad)
            free_storage(empty_grad)
            if self.dp_world_size > 1:
                grad = grad / self.dp_world_size
                # ``grad`` was produced on the compute stream; order the reduction after it.
                self.comm_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.comm_stream):
                    self.reducer.all_reduce_async(grad,
                                                  group=self.process_group.dp_process_group(),
                                                  callback_fn=partial(self._save_grad, p))
                # Stop the caching allocator from reusing ``grad``'s memory while work
                # queued on the communication stream may still read it.
                grad.record_stream(self.comm_stream)
            else:
                ColoDDP._save_grad(p, grad)
            return empty_grad

        # TODO(jiaruifang) fixme
        self.process_group.set_cpu_groups()
        if self.dp_world_size > 1:
            # Average like the CUDA path does (CPU gradients used to be only summed).
            # Assumes the CPU data parallel group has ``dp_world_size`` ranks.
            grad = grad / self.dp_world_size
        dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
        return grad

    @staticmethod
    def _save_grad(p, grad):
        """Accumulate ``grad`` into ``p._saved_grad``, creating it when absent or None."""
        # ``zero_grad(set_to_none=True)`` leaves ``_saved_grad = None`` behind, so test
        # the value rather than just the attribute's existence.
        if getattr(p, '_saved_grad', None) is not None:
            p._saved_grad.add_(grad)
        else:
            p._saved_grad = grad

    def zero_grad(self, set_to_none: bool = False) -> None:
        """Clear the wrapped module's gradients and the saved (reduced) gradients.

        ``p.grad`` is always set to ``None``. ``p._saved_grad`` is set to ``None``
        when ``set_to_none`` is True; otherwise it is detached and zeroed in place.

        Args:
            set_to_none (bool, optional): Drop saved gradients instead of zeroing
                them. Defaults to False.
        """
        self.module.zero_grad(set_to_none=True)
        for p in self.module.parameters():
            if getattr(p, '_saved_grad', None) is not None:
                if set_to_none:
                    p._saved_grad = None
                else:
                    if p._saved_grad.grad_fn is not None:
                        p._saved_grad.detach_()
                    else:
                        p._saved_grad.requires_grad_(False)
                    p._saved_grad.zero_()

    @staticmethod
    def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
        """Sets parameters to be ignored by DDP.
        This method must be called before initializing ColoDDP.

        Example:
            >>> params_to_ignore = []
            >>> for p in module.parameters():
            >>>     if should_ignore(p):
            >>>         params_to_ignore.append(p)
            >>> ColoDDP.set_params_to_ignore(params_to_ignore)
            >>> module = ColoDDP(module, process_group)

        Args:
            params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
        """
        for p in params_to_ignore:
            p._ddp_to_ignore = True

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
        return self.module.load_state_dict(state_dict, strict)
|