[tensor] impl ColoDDP for ColoTensor (#1009)

* impl ColoDDP for ColoTensor

* polish code
pull/1012/head
ver217 2022-05-21 13:52:04 +08:00 committed by GitHub
parent ae7c338105
commit cefc29ff06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 102 additions and 9 deletions

78
colossalai/nn/parallel.py Normal file
View File

@ -0,0 +1,78 @@
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from functools import partial
__all__ = ['ColoDDP']
def free_storage(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
if data.storage().size() > 0:
# Since we're modifying the Tensor's Storage directly, make sure the Tensor
# is the sole occupant of the Storage.
assert data.storage_offset() == 0
data.storage().resize_(0)
class ColoDDP(torch.nn.Module):
def __init__(self, module: torch.nn.Module) -> None:
super().__init__()
self.module = module
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
for p in module.parameters():
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))
def parameters(self, recurse: bool = True):
return self.module.parameters(recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True):
return self.module.named_parameters(prefix, recurse)
def forward(self, *args, **kwargs):
self.module.zero_grad(set_to_none=True)
return self.module(*args, **kwargs)
def backward(self, loss: torch.Tensor):
loss.backward()
torch.cuda.current_stream().wait_stream(self.comm_stream)
for p in self.module.parameters():
p.grad = p._saved_grad
def grad_handle(self, p, grad):
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
if self.dp_world_size > 1:
grad = grad / self.dp_world_size
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
dist.all_reduce(grad, group=gpc.get_group(ParallelMode.DATA))
ColoDDP._save_grad(p, grad)
grad.record_stream(self.comm_stream)
else:
ColoDDP._save_grad(p, grad)
return empty_grad
@staticmethod
def _save_grad(p, grad):
if hasattr(p, '_saved_grad'):
p._saved_grad.add_(grad)
else:
p._saved_grad = grad
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
for p in self.module.parameters():
if getattr(p, '_saved_grad', None) is not None:
if set_to_none:
p._saved_grad = None
else:
if p._saved_grad.grad_fn is not None:
p._saved_grad.detach_()
else:
p._saved_grad.requires_grad_(False)
p._saved_grad.zero_()

View File

@ -9,8 +9,10 @@ from colossalai.utils import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, tensor_shard_equal
from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDP
def init_1d_row_spec(model):
@ -43,7 +45,7 @@ def check_grad_equal(model, torch_model):
assert tensor_shard_equal(torch_p.grad, p.grad)
def run_gpt(init_spec_func):
def run_gpt(init_spec_func, use_ddp):
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@ -51,37 +53,50 @@ def run_gpt(init_spec_func):
model = model_builder()
model = model.cuda()
torch_model = model_builder().cuda()
if use_ddp:
model = ColoDDP(model)
torch_model = DDP(torch_model,
device_ids=[gpc.get_global_rank()],
process_group=gpc.get_group(ParallelMode.DATA))
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p)
init_spec_func(model)
check_param_equal(model, torch_model)
model.train()
torch_model.train()
set_seed(gpc.get_local_rank(ParallelMode.DATA))
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
logits = model(input_ids, attn_mask)
torch_logits = torch_model(input_ids, attn_mask)
assert tensor_equal(torch_logits, logits)
loss = criterion(logits, input_ids)
torch_loss = criterion(torch_logits, input_ids)
loss.backward()
if use_ddp:
model.backward(loss)
else:
loss.backward()
torch_loss.backward()
check_grad_equal(model, torch_model)
if i > 0:
break
def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
def run_dist(rank, world_size, port, use_ddp):
if use_ddp and world_size == 1:
return
tp_world_size = world_size // 2 if use_ddp else world_size
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt(init_1d_row_spec)
run_gpt(init_1d_col_spec)
run_gpt(init_1d_row_spec, use_ddp)
run_gpt(init_1d_col_spec, use_ddp)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
def test_gpt(world_size, use_ddp):
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
mp.spawn(run_func, nprocs=world_size)