[zero] add unit tests for hybrid parallelism (#2486)

pull/2494/head
HELSON 2023-01-18 10:36:10 +08:00 committed by GitHub
parent fcc6d61d92
commit d565a24849
4 changed files with 188 additions and 98 deletions


@@ -7,7 +7,6 @@ class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
self._grads = dict()
self._params = dict()
self._num_elements_in_bucket = dict()
@@ -19,25 +18,24 @@ class BucketStore(BaseStore):
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
self._num_elements_in_bucket[reduce_rank] += num_elements
def add_grad(self, tensor, reduce_rank: int = None):
self._grads[reduce_rank].append(tensor)
def add_param(self, tensor, reduce_rank: int = None):
self._params[reduce_rank].append(tensor)
def reset(self):
keys = [None] + list(range(self._world_size))
self._grads = {rank: [] for rank in keys}
self._params = {rank: [] for rank in keys}
self._num_elements_in_bucket = {rank: 0 for rank in keys}
def reset_by_rank(self, reduce_rank=None):
self._grads[reduce_rank] = []
self._params[reduce_rank] = []
self._num_elements_in_bucket[reduce_rank] = 0
def get_grad(self, reduce_rank: int = None):
return self._grads[reduce_rank]
param_list = self.get_param(reduce_rank)
for param in param_list:
# the param must have grad for reduction
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
return [param.grad for param in param_list]
def get_param(self, reduce_rank: int = None):
return self._params[reduce_rank]
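In short, after this change BucketStore no longer tracks gradients separately: only parameters are bucketed, and get_grad derives the gradient list from param.grad at reduction time. A minimal sketch of that behavior with a simplified store (the class and attribute names below are illustrative, not the real implementation):

from typing import Dict, List, Optional

import torch


class MiniBucketStore:
    """Illustrative store: params are bucketed per reduce_rank, grads are derived on demand."""

    def __init__(self, world_size: int):
        self._world_size = world_size
        self.reset()

    def reset(self) -> None:
        keys = [None] + list(range(self._world_size))
        self._params: Dict[Optional[int], List[torch.Tensor]] = {k: [] for k in keys}
        self._num_elements: Dict[Optional[int], int] = {k: 0 for k in keys}

    def num_elements_in_bucket(self, reduce_rank: Optional[int] = None) -> int:
        return self._num_elements[reduce_rank]

    def add_param(self, param: torch.Tensor, reduce_rank: Optional[int] = None) -> None:
        self._params[reduce_rank].append(param)
        self._num_elements[reduce_rank] += param.numel()

    def get_grad(self, reduce_rank: Optional[int] = None) -> List[torch.Tensor]:
        # every bucketed param must already carry a gradient before it can be reduced
        params = self._params[reduce_rank]
        for p in params:
            assert p.grad is not None, f'Parameter of size {tuple(p.size())} has None grad'
        return [p.grad for p in params]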


@@ -46,7 +46,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None):
@@ -248,9 +248,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
return params_per_rank
###########################################################
# Backward Reduction Hook
###########################################################
###########################
# Backward Reduction Hook #
###########################
def _grad_handler(self, param, grad, reduce_rank):
self._add_to_reduction_bucket(param, reduce_rank)
return grad
def _attach_reduction_hook(self):
# we iterate over the fp16 params
@@ -268,53 +272,61 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
else:
reduce_rank = None
def _define_and_attach(param, reduce_rank):
# get the AccumulateGrad object of the param itself
accum_grad_obj = get_grad_accumulate_object(param)
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
param=param,
reduce_rank=reduce_rank)
def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
# define hook
# NOT IMPORTANT BUT GOOD TO KNOW:
# args here are not grads, but allow_unreachable and accumulate_grad
def reduce_grad_hook(*args):
reduction_func()
with torch.cuda.stream(stream):
flat = bucket.flatten()
reduce_global_rank = None
if reduce_rank is not None:
reduce_global_rank = self._dp_global_ranks[reduce_rank]
reduced_flat = reduce_tensor_dp_group(tensor=flat,
dtype=self._communication_dtype,
dst_local_rank=reduce_rank,
dst_global_rank=reduce_global_rank,
group=self._dp_torch_group)
accum_grad_obj.register_hook(reduce_grad_hook)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
bucket.unflatten_and_copy(reduced_flat)
_define_and_attach(param, reduce_rank)
def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
param_bucket = TensorBucket(size=bucket_size)
def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
param_size = param.numel()
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._reduce_grads_in_bucket(reduce_rank)
if param_bucket.is_full_or_oversized():
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()
# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+ 'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
if not param_bucket.is_empty():
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
# the param must have grad for reduction
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
def _reduce_grads(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_half_float_double(grads)
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
self._bucket_store.add_grad(param.grad, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)
for tensor_list in grad_buckets_by_dtype:
self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
bucket_size=bucket_size,
reduce_rank=reduce_rank)
def _reduce_grads_in_bucket(self, reduce_rank=None):
#######################
# Reduction Functions #
#######################
def _run_reduction(self, reduce_rank=None):
# reduce grads
self._reduce_grads_by_rank(reduce_rank=reduce_rank,
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
self._reduce_grads(reduce_rank=reduce_rank,
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
# use communication stream if overlapping
# communication with computation
@@ -351,50 +363,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self._bucket_store.reset_by_rank(reduce_rank)
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_half_float_double(grads)
def _add_to_reduction_bucket(self, param, reduce_rank=None):
param_size = param.numel()
for tensor_list in grad_buckets_by_dtype:
self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._run_reduction(reduce_rank)
##############################
# Reduction Utility Function #
##############################
def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
param_bucket = TensorBucket(size=bucket_size)
# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+ 'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
if param_bucket.is_full_or_oversized():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()
if not param_bucket.is_empty():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
flat = bucket.flatten()
reduce_global_rank = None
if reduce_rank is not None:
reduce_global_rank = self._dp_global_ranks[reduce_rank]
reduced_flat = reduce_tensor_dp_group(tensor=flat,
dtype=self._communication_dtype,
dst_local_rank=reduce_rank,
dst_global_rank=reduce_global_rank,
group=self._dp_torch_group)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
bucket.unflatten_and_copy(reduced_flat)
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)
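Reusing the MiniBucketStore sketch from earlier, the policy implemented by _add_to_reduction_bucket can be summarized as "flush before overflow": if adding a parameter would push the bucket past reduce_bucket_size, the buffered gradients are reduced first, and only then is the parameter enqueued. A hedged sketch of that control flow (function and argument names are illustrative):

from typing import Callable, Optional

import torch


def add_to_reduction_bucket_sketch(store,                    # e.g. the MiniBucketStore sketch above
                                   param: torch.Tensor,
                                   reduce_bucket_size: int,
                                   run_reduction: Callable[[Optional[int]], None],
                                   reduce_rank: Optional[int] = None) -> None:
    """Flush-before-overflow policy applied when a gradient hook fires for `param`."""
    param_size = param.numel()
    # if this param would overflow the bucket, reduce whatever is already buffered first
    if store.num_elements_in_bucket(reduce_rank) + param_size > reduce_bucket_size:
        run_reduction(reduce_rank)
    # the real method additionally raises if the param was already reduced, to avoid double reduction
    store.add_param(param, reduce_rank)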
################################
# torch.optim.Optimizer methods
@@ -498,8 +484,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# broadcast the updated model weights
handles = []
for group_id in range(self.num_param_groups):
for rank in range(self._world_size):
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
for index in range(self._world_size):
rank = self._dp_global_ranks[index]
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
handles.append(handle)
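The hunk above is the part that matters for hybrid parallelism: the flat fp16 shards are stored by data-parallel local index, while dist.broadcast expects the owner's global rank as src, so the index has to be translated through the DP group's global rank list. A self-contained sketch of that mapping (the helper name and arguments are illustrative):

from typing import List

import torch
import torch.distributed as dist


def broadcast_updated_shards(flat_shards: List[torch.Tensor],
                             dp_global_ranks: List[int],
                             dp_group: dist.ProcessGroup) -> None:
    """Broadcast each DP-local index's updated flat fp16 shard from its owner to the whole DP group."""
    handles = []
    for index, shard in enumerate(flat_shards):
        # `index` is the DP-local index owning this shard; broadcast needs the owner's global rank
        src_global_rank = dp_global_ranks[index]
        handles.append(dist.broadcast(shard, src=src_global_rank, group=dp_group, async_op=True))
    for handle in handles:
        handle.wait()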
@@ -585,11 +572,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
param_group = self._fp16_param_groups[group_id]
for param in param_group:
if param.grad is not None:
self._reduce_and_remove_grads_by_bucket(param)
self._add_to_reduction_bucket(param)
# we need to reduce the gradients
# left in the communication bucket
self._reduce_grads_in_bucket()
self._run_reduction()
def _reduce_grad_stage2(self):
# when partition_grads is True, reduction hooks
@@ -597,4 +584,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# only need to reduce the gradients
# left in the communication bucket
for reduce_rank in range(self._world_size):
self._reduce_grads_in_bucket(reduce_rank)
self._run_reduction(reduce_rank)


@@ -4,6 +4,7 @@ import random
import numpy as np
import torch
import torch.distributed as dist
from torch.testing import assert_close
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
@@ -41,14 +42,20 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
return tensor_chunk.clone()
def tensor_equal(A, B):
return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1):
assert_close(t_a, t_b, rtol=rtol, atol=atol)
return True
def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
def tensor_shard_equal(tensor: torch.Tensor,
shard: torch.Tensor,
rank: int,
world_size: int,
rtol: float = 1e-3,
atol: float = 1e-1):
assert tensor.ndim == shard.ndim
if tensor.shape == shard.shape:
return tensor_equal(tensor, shard)
return tensor_equal(tensor, shard, rtol, atol)
else:
dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
if dims_not_eq.numel() == 1:
@@ -58,7 +65,7 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
if rank is None:
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol)
else:
raise NotImplementedError
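For context, a small usage sketch of the updated helper: when the shard is a 1D chunk of the full tensor, the comparison reduces to assert_close on the matching chunk under the given tolerances (the shapes below are made up for illustration):

import torch
from torch.testing import assert_close

# full parameter and the shard held by TP rank 1 of 2, split along dim 0
full = torch.randn(128, 32)
rank, world_size, dim = 1, 2, 0
shard = full.chunk(world_size, dim)[rank].clone()

# mirrors what tensor_shard_equal does once it has located the mismatching dimension
assert_close(full.chunk(world_size, dim)[rank], shard, rtol=1e-3, atol=1e-1)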


@@ -0,0 +1,98 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import LowLevelZeroOptimizer
from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal
def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4):
return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol)
class TestModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
self.linear1 = nn.Linear(32, 128)
self.act = nn.GELU()
self.linear2 = nn.Linear(128, 32)
def forward(self, x):
y = self.linear1(x)
y = self.act(y)
y = self.linear2(y)
return x + y
@parameterize("overlap_flag", [False, True])
@parameterize("partition_flag", [False, True])
def exam_zero_with_tp(overlap_flag, partition_flag):
set_seed(233010)
tp_pg = ProcessGroup(tp_degree=2)
with ColoInitContext(device=get_current_device(), default_pg=tp_pg):
hybrid_model = TestModel()
torch_model = TestModel().cuda()
for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()):
pt.data.copy_(ph.data)
for name, param in hybrid_model.named_parameters():
if 'linear1' in name:
split_param_row_tp1d(param, tp_pg)
param.compute_spec.set_output_replicate(False)
if 'linear2.weight' in name:
split_param_col_tp1d(param, tp_pg)
torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group())
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1)
hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1)
hybrid_optim = LowLevelZeroOptimizer(hybrid_optim,
initial_scale=1,
overlap_communication=overlap_flag,
partition_grad=partition_flag)
dp_local_rank = tp_pg.dp_local_rank()
set_seed(255 + dp_local_rank)
data = torch.randn(8, 32, device=get_current_device())
torch_loss = torch_model(data).sum()
hybrid_loss = hybrid_model(data).sum()
assert_close(torch_loss, hybrid_loss)
torch_loss.backward()
hybrid_optim.backward(hybrid_loss)
hybrid_optim.sync_grad()
torch_optim.step()
hybrid_optim.step()
for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()):
assert strict_shard_equal(pt.data, ph.data, tp_pg)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_with_tp()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_with_tp():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_with_tp()
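The two stacked @parameterize decorators make each spawned process exercise every (overlap_communication, partition_grad) combination. A rough stand-alone illustration of that expansion, showing only the observable effect rather than colossalai.testing's implementation:

from itertools import product


def exam_stub(overlap_flag: bool, partition_flag: bool) -> None:
    # stand-in for exam_zero_with_tp(overlap_flag, partition_flag)
    print(f'overlap={overlap_flag}, partition={partition_flag}')


def run_all_configs() -> None:
    # four ZeRO configurations: plain, overlap only, stage-2 partition only, both
    for overlap_flag, partition_flag in product([False, True], repeat=2):
        exam_stub(overlap_flag=overlap_flag, partition_flag=partition_flag)


if __name__ == '__main__':
    run_all_configs()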