mirror of https://github.com/hpcaitech/ColossalAI
[zero] add unit testings for hybrid parallelism (#2486)
parent
fcc6d61d92
commit
d565a24849
|
@ -7,7 +7,6 @@ class BucketStore(BaseStore):
|
|||
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
super().__init__(torch_pg)
|
||||
self._grads = dict()
|
||||
self._params = dict()
|
||||
self._num_elements_in_bucket = dict()
|
||||
|
||||
|
@ -19,25 +18,24 @@ class BucketStore(BaseStore):
|
|||
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
|
||||
self._num_elements_in_bucket[reduce_rank] += num_elements
|
||||
|
||||
def add_grad(self, tensor, reduce_rank: int = None):
|
||||
self._grads[reduce_rank].append(tensor)
|
||||
|
||||
def add_param(self, tensor, reduce_rank: int = None):
|
||||
self._params[reduce_rank].append(tensor)
|
||||
|
||||
def reset(self):
|
||||
keys = [None] + list(range(self._world_size))
|
||||
self._grads = {rank: [] for rank in keys}
|
||||
self._params = {rank: [] for rank in keys}
|
||||
self._num_elements_in_bucket = {rank: 0 for rank in keys}
|
||||
|
||||
def reset_by_rank(self, reduce_rank=None):
|
||||
self._grads[reduce_rank] = []
|
||||
self._params[reduce_rank] = []
|
||||
self._num_elements_in_bucket[reduce_rank] = 0
|
||||
|
||||
def get_grad(self, reduce_rank: int = None):
|
||||
return self._grads[reduce_rank]
|
||||
param_list = self.get_param(reduce_rank)
|
||||
for param in param_list:
|
||||
# the param must have grad for reduction
|
||||
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
|
||||
return [param.grad for param in param_list]
|
||||
|
||||
def get_param(self, reduce_rank: int = None):
|
||||
return self._params[reduce_rank]
|
||||
|
|
|
@ -46,7 +46,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
reduce_bucket_size: int = 1024 * 1024, # communication
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = False,
|
||||
partition_grad: bool = False, # stage 2
|
||||
partition_grad: bool = False, # stage 2 flag
|
||||
cpu_offload: bool = False, # cpu offload
|
||||
forced_dtype: Optional[torch.dtype] = None):
|
||||
|
||||
|
@ -248,9 +248,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
|
||||
return params_per_rank
|
||||
|
||||
###########################################################
|
||||
# Backward Reduction Hook
|
||||
###########################################################
|
||||
###########################
|
||||
# Backward Reduction Hook #
|
||||
###########################
|
||||
|
||||
def _grad_handler(self, param, grad, reduce_rank):
|
||||
self._add_to_reduction_bucket(param, reduce_rank)
|
||||
return grad
|
||||
|
||||
def _attach_reduction_hook(self):
|
||||
# we iterate over the fp16 params
|
||||
|
@ -268,53 +272,61 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
else:
|
||||
reduce_rank = None
|
||||
|
||||
def _define_and_attach(param, reduce_rank):
|
||||
# get the AccumulateGrad object of the param itself
|
||||
accum_grad_obj = get_grad_accumulate_object(param)
|
||||
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
|
||||
param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
|
||||
|
||||
reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
|
||||
param=param,
|
||||
reduce_rank=reduce_rank)
|
||||
def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
# define hook
|
||||
# NOT IMPORTANT BUT GOOD TO KNOW:
|
||||
# args here is not grad, but allow_unreacable and accumulate_grad
|
||||
def reduce_grad_hook(*args):
|
||||
reduction_func()
|
||||
with torch.cuda.stream(stream):
|
||||
flat = bucket.flatten()
|
||||
reduce_global_rank = None
|
||||
if reduce_rank is not None:
|
||||
reduce_global_rank = self._dp_global_ranks[reduce_rank]
|
||||
reduced_flat = reduce_tensor_dp_group(tensor=flat,
|
||||
dtype=self._communication_dtype,
|
||||
dst_local_rank=reduce_rank,
|
||||
dst_global_rank=reduce_global_rank,
|
||||
group=self._dp_torch_group)
|
||||
|
||||
accum_grad_obj.register_hook(reduce_grad_hook)
|
||||
# update the reduced tensor
|
||||
if reduce_rank is None or reduce_rank == self._local_rank:
|
||||
bucket.unflatten_and_copy(reduced_flat)
|
||||
|
||||
_define_and_attach(param, reduce_rank)
|
||||
def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
|
||||
param_bucket = TensorBucket(size=bucket_size)
|
||||
|
||||
def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
|
||||
param_size = param.numel()
|
||||
for tensor in tensor_list:
|
||||
param_bucket.add_to_bucket(tensor, allow_oversize=True)
|
||||
|
||||
# check if the bucket is full
|
||||
# if full, will reduce the grads already in the bucket
|
||||
# after reduction, the bucket will be empty
|
||||
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||
self._reduce_grads_in_bucket(reduce_rank)
|
||||
if param_bucket.is_full_or_oversized():
|
||||
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
param_bucket.empty()
|
||||
|
||||
# the param must not be reduced to ensure correctness
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
if is_param_reduced:
|
||||
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
|
||||
+ 'duplicate reduction will lead to arithmetic incorrectness'
|
||||
raise RuntimeError(msg)
|
||||
if not param_bucket.is_empty():
|
||||
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
|
||||
# the param must have grad for reduction
|
||||
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
|
||||
def _reduce_grads(self, reduce_rank, grads, bucket_size):
|
||||
grad_buckets_by_dtype = split_half_float_double(grads)
|
||||
|
||||
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
|
||||
self._bucket_store.add_grad(param.grad, reduce_rank)
|
||||
self._bucket_store.add_param(param, reduce_rank)
|
||||
for tensor_list in grad_buckets_by_dtype:
|
||||
self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
|
||||
bucket_size=bucket_size,
|
||||
reduce_rank=reduce_rank)
|
||||
|
||||
def _reduce_grads_in_bucket(self, reduce_rank=None):
|
||||
#######################
|
||||
# Reduction Functions #
|
||||
#######################
|
||||
|
||||
def _run_reduction(self, reduce_rank=None):
|
||||
# reduce grads
|
||||
self._reduce_grads_by_rank(reduce_rank=reduce_rank,
|
||||
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
|
||||
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
|
||||
self._reduce_grads(reduce_rank=reduce_rank,
|
||||
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
|
||||
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
|
||||
|
||||
# use communication stream if overlapping
|
||||
# communication with computation
|
||||
|
@ -351,50 +363,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
self._bucket_store.reset_by_rank(reduce_rank)
|
||||
|
||||
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
|
||||
grad_buckets_by_dtype = split_half_float_double(grads)
|
||||
def _add_to_reduction_bucket(self, param, reduce_rank=None):
|
||||
param_size = param.numel()
|
||||
|
||||
for tensor_list in grad_buckets_by_dtype:
|
||||
self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
|
||||
# check if the bucket is full
|
||||
# if full, will reduce the grads already in the bucket
|
||||
# after reduction, the bucket will be empty
|
||||
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||
self._run_reduction(reduce_rank)
|
||||
|
||||
##############################
|
||||
# Reduction Utility Function #
|
||||
##############################
|
||||
def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
|
||||
param_bucket = TensorBucket(size=bucket_size)
|
||||
# the param must not be reduced to ensure correctness
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
if is_param_reduced:
|
||||
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
|
||||
+ 'duplicate reduction will lead to arithmetic incorrectness'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
for tensor in tensor_list:
|
||||
param_bucket.add_to_bucket(tensor, allow_oversize=True)
|
||||
|
||||
if param_bucket.is_full_or_oversized():
|
||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
param_bucket.empty()
|
||||
|
||||
if not param_bucket.is_empty():
|
||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
|
||||
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
flat = bucket.flatten()
|
||||
reduce_global_rank = None
|
||||
if reduce_rank is not None:
|
||||
reduce_global_rank = self._dp_global_ranks[reduce_rank]
|
||||
reduced_flat = reduce_tensor_dp_group(tensor=flat,
|
||||
dtype=self._communication_dtype,
|
||||
dst_local_rank=reduce_rank,
|
||||
dst_global_rank=reduce_global_rank,
|
||||
group=self._dp_torch_group)
|
||||
|
||||
# update the reduced tensor
|
||||
if reduce_rank is None or reduce_rank == self._local_rank:
|
||||
bucket.unflatten_and_copy(reduced_flat)
|
||||
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
|
||||
self._bucket_store.add_param(param, reduce_rank)
|
||||
|
||||
################################
|
||||
# torch.optim.Optimizer methods
|
||||
|
@ -498,8 +484,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
# broadcast the updated model weights
|
||||
handles = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
for rank in range(self._world_size):
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
|
||||
for index in range(self._world_size):
|
||||
rank = self._dp_global_ranks[index]
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
|
||||
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
|
||||
handles.append(handle)
|
||||
|
||||
|
@ -585,11 +572,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
param_group = self._fp16_param_groups[group_id]
|
||||
for param in param_group:
|
||||
if param.grad is not None:
|
||||
self._reduce_and_remove_grads_by_bucket(param)
|
||||
self._add_to_reduction_bucket(param)
|
||||
|
||||
# we need to reduce the gradients
|
||||
# left in the communication bucket
|
||||
self._reduce_grads_in_bucket()
|
||||
self._run_reduction()
|
||||
|
||||
def _reduce_grad_stage2(self):
|
||||
# when partition_grads is True, reduction hooks
|
||||
|
@ -597,4 +584,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
# only need to reduce the gradients
|
||||
# left in the communication bucket
|
||||
for reduce_rank in range(self._world_size):
|
||||
self._reduce_grads_in_bucket(reduce_rank)
|
||||
self._run_reduction(reduce_rank)
|
||||
|
|
|
@ -4,6 +4,7 @@ import random
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
@ -41,14 +42,20 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
|
|||
return tensor_chunk.clone()
|
||||
|
||||
|
||||
def tensor_equal(A, B):
|
||||
return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
|
||||
def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1):
|
||||
assert_close(t_a, t_b, rtol=rtol, atol=atol)
|
||||
return True
|
||||
|
||||
|
||||
def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
|
||||
def tensor_shard_equal(tensor: torch.Tensor,
|
||||
shard: torch.Tensor,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
rtol: float = 1e-3,
|
||||
atol: float = 1e-1):
|
||||
assert tensor.ndim == shard.ndim
|
||||
if tensor.shape == shard.shape:
|
||||
return tensor_equal(tensor, shard)
|
||||
return tensor_equal(tensor, shard, rtol, atol)
|
||||
else:
|
||||
dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
|
||||
if dims_not_eq.numel() == 1:
|
||||
|
@ -58,7 +65,7 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_si
|
|||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
if rank is None:
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
|
||||
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal
|
||||
|
||||
|
||||
def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4):
|
||||
return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol)
|
||||
|
||||
|
||||
class TestModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(TestModel, self).__init__()
|
||||
self.linear1 = nn.Linear(32, 128)
|
||||
self.act = nn.GELU()
|
||||
self.linear2 = nn.Linear(128, 32)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.linear1(x)
|
||||
y = self.act(y)
|
||||
y = self.linear2(y)
|
||||
return x + y
|
||||
|
||||
|
||||
@parameterize("overlap_flag", [False, True])
|
||||
@parameterize("partition_flag", [False, True])
|
||||
def exam_zero_with_tp(overlap_flag, partition_flag):
|
||||
set_seed(233010)
|
||||
tp_pg = ProcessGroup(tp_degree=2)
|
||||
|
||||
with ColoInitContext(device=get_current_device(), default_pg=tp_pg):
|
||||
hybrid_model = TestModel()
|
||||
torch_model = TestModel().cuda()
|
||||
for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()):
|
||||
pt.data.copy_(ph.data)
|
||||
|
||||
for name, param in hybrid_model.named_parameters():
|
||||
if 'linear1' in name:
|
||||
split_param_row_tp1d(param, tp_pg)
|
||||
param.compute_spec.set_output_replicate(False)
|
||||
if 'linear2.weight' in name:
|
||||
split_param_col_tp1d(param, tp_pg)
|
||||
|
||||
torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group())
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1)
|
||||
hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1)
|
||||
hybrid_optim = LowLevelZeroOptimizer(hybrid_optim,
|
||||
initial_scale=1,
|
||||
overlap_communication=overlap_flag,
|
||||
partition_grad=partition_flag)
|
||||
|
||||
dp_local_rank = tp_pg.dp_local_rank()
|
||||
set_seed(255 + dp_local_rank)
|
||||
|
||||
data = torch.randn(8, 32, device=get_current_device())
|
||||
torch_loss = torch_model(data).sum()
|
||||
hybrid_loss = hybrid_model(data).sum()
|
||||
assert_close(torch_loss, hybrid_loss)
|
||||
|
||||
torch_loss.backward()
|
||||
hybrid_optim.backward(hybrid_loss)
|
||||
hybrid_optim.sync_grad()
|
||||
|
||||
torch_optim.step()
|
||||
hybrid_optim.step()
|
||||
|
||||
for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()):
|
||||
assert strict_shard_equal(pt.data, ph.data, tp_pg)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
exam_zero_with_tp()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_with_tp():
|
||||
world_size = 4
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_with_tp()
|
Loading…
Reference in New Issue