diff --git a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py index 9e0c05d89..ec322a78b 100644 --- a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py +++ b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py @@ -7,7 +7,6 @@ class BucketStore(BaseStore): def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) - self._grads = dict() self._params = dict() self._num_elements_in_bucket = dict() @@ -19,25 +18,24 @@ class BucketStore(BaseStore): def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None): self._num_elements_in_bucket[reduce_rank] += num_elements - def add_grad(self, tensor, reduce_rank: int = None): - self._grads[reduce_rank].append(tensor) - def add_param(self, tensor, reduce_rank: int = None): self._params[reduce_rank].append(tensor) def reset(self): keys = [None] + list(range(self._world_size)) - self._grads = {rank: [] for rank in keys} self._params = {rank: [] for rank in keys} self._num_elements_in_bucket = {rank: 0 for rank in keys} def reset_by_rank(self, reduce_rank=None): - self._grads[reduce_rank] = [] self._params[reduce_rank] = [] self._num_elements_in_bucket[reduce_rank] = 0 def get_grad(self, reduce_rank: int = None): - return self._grads[reduce_rank] + param_list = self.get_param(reduce_rank) + for param in param_list: + # the param must have grad for reduction + assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced' + return [param.grad for param in param_list] def get_param(self, reduce_rank: int = None): return self._params[reduce_rank] diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py index 38736d01a..f45b5e200 100644 --- a/colossalai/zero/sharded_optim/low_level_optim.py +++ b/colossalai/zero/sharded_optim/low_level_optim.py @@ -46,7 +46,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): reduce_bucket_size: int = 1024 * 1024, # communication communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = False, - partition_grad: bool = False, # stage 2 + partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload forced_dtype: Optional[torch.dtype] = None): @@ -248,9 +248,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0]) return params_per_rank - ########################################################### - # Backward Reduction Hook - ########################################################### + ########################### + # Backward Reduction Hook # + ########################### + + def _grad_handler(self, param, grad, reduce_rank): + self._add_to_reduction_bucket(param, reduce_rank) + return grad def _attach_reduction_hook(self): # we iterate over the fp16 params @@ -268,53 +272,61 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): else: reduce_rank = None - def _define_and_attach(param, reduce_rank): - # get the AccumulateGrad object of the param itself - accum_grad_obj = get_grad_accumulate_object(param) - self._grad_store.add_accumulate_grad_object(accum_grad_obj) + param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank)) - reduction_func = partial(self._reduce_and_remove_grads_by_bucket, - param=param, - reduce_rank=reduce_rank) + def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank): + if self._overlap_communication: + torch.cuda.synchronize() + 
self._param_store.clear_grads_of_previous_reduced_params() + stream = self._comm_stream + else: + stream = torch.cuda.current_stream() - # define hook - # NOT IMPORTANT BUT GOOD TO KNOW: - # args here is not grad, but allow_unreacable and accumulate_grad - def reduce_grad_hook(*args): - reduction_func() + with torch.cuda.stream(stream): + flat = bucket.flatten() + reduce_global_rank = None + if reduce_rank is not None: + reduce_global_rank = self._dp_global_ranks[reduce_rank] + reduced_flat = reduce_tensor_dp_group(tensor=flat, + dtype=self._communication_dtype, + dst_local_rank=reduce_rank, + dst_global_rank=reduce_global_rank, + group=self._dp_torch_group) - accum_grad_obj.register_hook(reduce_grad_hook) + # update the reduced tensor + if reduce_rank is None or reduce_rank == self._local_rank: + bucket.unflatten_and_copy(reduced_flat) - _define_and_attach(param, reduce_rank) + def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank): + param_bucket = TensorBucket(size=bucket_size) - def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None): - param_size = param.numel() + for tensor in tensor_list: + param_bucket.add_to_bucket(tensor, allow_oversize=True) - # check if the bucket is full - # if full, will reduce the grads already in the bucket - # after reduction, the bucket will be empty - if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: - self._reduce_grads_in_bucket(reduce_rank) + if param_bucket.is_full_or_oversized(): + self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) + param_bucket.empty() - # the param must not be reduced to ensure correctness - is_param_reduced = self._param_store.is_param_reduced(param) - if is_param_reduced: - msg = f'Parameter of size ({param.size()}) has already been reduced, ' \ - + 'duplicate reduction will lead to arithmetic incorrectness' - raise RuntimeError(msg) + if not param_bucket.is_empty(): + self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) - # the param must have grad for reduction - assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced' + def _reduce_grads(self, reduce_rank, grads, bucket_size): + grad_buckets_by_dtype = split_half_float_double(grads) - self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank) - self._bucket_store.add_grad(param.grad, reduce_rank) - self._bucket_store.add_param(param, reduce_rank) + for tensor_list in grad_buckets_by_dtype: + self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list, + bucket_size=bucket_size, + reduce_rank=reduce_rank) - def _reduce_grads_in_bucket(self, reduce_rank=None): + ####################### + # Reduction Functions # + ####################### + + def _run_reduction(self, reduce_rank=None): # reduce grads - self._reduce_grads_by_rank(reduce_rank=reduce_rank, - grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), - bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) + self._reduce_grads(reduce_rank=reduce_rank, + grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), + bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) # use communication stream if overlapping # communication with computation @@ -351,50 +363,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): self._bucket_store.reset_by_rank(reduce_rank) - def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size): - grad_buckets_by_dtype = split_half_float_double(grads) + def 
_add_to_reduction_bucket(self, param, reduce_rank=None): + param_size = param.numel() - for tensor_list in grad_buckets_by_dtype: - self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank) + # check if the bucket is full + # if full, will reduce the grads already in the bucket + # after reduction, the bucket will be empty + if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: + self._run_reduction(reduce_rank) - ############################## - # Reduction Utility Function # - ############################## - def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank): - param_bucket = TensorBucket(size=bucket_size) + # the param must not be reduced to ensure correctness + is_param_reduced = self._param_store.is_param_reduced(param) + if is_param_reduced: + msg = f'Parameter of size ({param.size()}) has already been reduced, ' \ + + 'duplicate reduction will lead to arithmetic incorrectness' + raise RuntimeError(msg) - for tensor in tensor_list: - param_bucket.add_to_bucket(tensor, allow_oversize=True) - - if param_bucket.is_full_or_oversized(): - self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) - param_bucket.empty() - - if not param_bucket.is_empty(): - self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) - - def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank): - if self._overlap_communication: - torch.cuda.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() - stream = self._comm_stream - else: - stream = torch.cuda.current_stream() - - with torch.cuda.stream(stream): - flat = bucket.flatten() - reduce_global_rank = None - if reduce_rank is not None: - reduce_global_rank = self._dp_global_ranks[reduce_rank] - reduced_flat = reduce_tensor_dp_group(tensor=flat, - dtype=self._communication_dtype, - dst_local_rank=reduce_rank, - dst_global_rank=reduce_global_rank, - group=self._dp_torch_group) - - # update the reduced tensor - if reduce_rank is None or reduce_rank == self._local_rank: - bucket.unflatten_and_copy(reduced_flat) + self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank) + self._bucket_store.add_param(param, reduce_rank) ################################ # torch.optim.Optimizer methods @@ -498,8 +484,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # broadcast the updated model weights handles = [] for group_id in range(self.num_param_groups): - for rank in range(self._world_size): - fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id) + for index in range(self._world_size): + rank = self._dp_global_ranks[index] + fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id) handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True) handles.append(handle) @@ -585,11 +572,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): param_group = self._fp16_param_groups[group_id] for param in param_group: if param.grad is not None: - self._reduce_and_remove_grads_by_bucket(param) + self._add_to_reduction_bucket(param) # we need to reduce the gradients # left in the communication bucket - self._reduce_grads_in_bucket() + self._run_reduction() def _reduce_grad_stage2(self): # when partition_grads is True, reduction hooks @@ -597,4 +584,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # only need to reduce the gradients # left in the communication bucket for reduce_rank in range(self._world_size): - 
self._reduce_grads_in_bucket(reduce_rank) + self._run_reduction(reduce_rank) diff --git a/tests/test_tensor/common_utils/_utils.py b/tests/test_tensor/common_utils/_utils.py index 6b58aa801..b405f8cd2 100644 --- a/tests/test_tensor/common_utils/_utils.py +++ b/tests/test_tensor/common_utils/_utils.py @@ -4,6 +4,7 @@ import random import numpy as np import torch import torch.distributed as dist +from torch.testing import assert_close from colossalai.context import ParallelMode from colossalai.core import global_context as gpc @@ -41,14 +42,20 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0): return tensor_chunk.clone() -def tensor_equal(A, B): - return torch.allclose(A, B, rtol=1e-3, atol=1e-1) +def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1): + assert_close(t_a, t_b, rtol=rtol, atol=atol) + return True -def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size): +def tensor_shard_equal(tensor: torch.Tensor, + shard: torch.Tensor, + rank: int, + world_size: int, + rtol: float = 1e-3, + atol: float = 1e-1): assert tensor.ndim == shard.ndim if tensor.shape == shard.shape: - return tensor_equal(tensor, shard) + return tensor_equal(tensor, shard, rtol, atol) else: dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)) if dims_not_eq.numel() == 1: @@ -58,7 +65,7 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_si world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) if rank is None: rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - return tensor_equal(tensor.chunk(world_size, dim)[rank], shard) + return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol) else: raise NotImplementedError diff --git a/tests/test_zero/low_level_zero/test_zero_tp.py b/tests/test_zero/low_level_zero/test_zero_tp.py new file mode 100644 index 000000000..8ba6e3cb6 --- /dev/null +++ b/tests/test_zero/low_level_zero/test_zero_tp.py @@ -0,0 +1,98 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import LowLevelZeroOptimizer +from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal + + +def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4): + return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol) + + +class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + self.linear1 = nn.Linear(32, 128) + self.act = nn.GELU() + self.linear2 = nn.Linear(128, 32) + + def forward(self, x): + y = self.linear1(x) + y = self.act(y) + y = self.linear2(y) + return x + y + + +@parameterize("overlap_flag", [False, True]) +@parameterize("partition_flag", [False, True]) +def exam_zero_with_tp(overlap_flag, partition_flag): + set_seed(233010) + tp_pg = ProcessGroup(tp_degree=2) + + with ColoInitContext(device=get_current_device(), default_pg=tp_pg): + hybrid_model = TestModel() + torch_model = TestModel().cuda() + for pt, ph in zip(torch_model.parameters(), 
hybrid_model.parameters()): + pt.data.copy_(ph.data) + + for name, param in hybrid_model.named_parameters(): + if 'linear1' in name: + split_param_row_tp1d(param, tp_pg) + param.compute_spec.set_output_replicate(False) + if 'linear2.weight' in name: + split_param_col_tp1d(param, tp_pg) + + torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group()) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1) + hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1) + hybrid_optim = LowLevelZeroOptimizer(hybrid_optim, + initial_scale=1, + overlap_communication=overlap_flag, + partition_grad=partition_flag) + + dp_local_rank = tp_pg.dp_local_rank() + set_seed(255 + dp_local_rank) + + data = torch.randn(8, 32, device=get_current_device()) + torch_loss = torch_model(data).sum() + hybrid_loss = hybrid_model(data).sum() + assert_close(torch_loss, hybrid_loss) + + torch_loss.backward() + hybrid_optim.backward(hybrid_loss) + hybrid_optim.sync_grad() + + torch_optim.step() + hybrid_optim.step() + + for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()): + assert strict_shard_equal(pt.data, ph.data, tp_pg) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + exam_zero_with_tp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_zero_with_tp(): + world_size = 4 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_with_tp()
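
Note on the hook mechanism this diff switches to: the AccumulateGrad-object hooks are replaced by plain Tensor.register_hook callbacks (_grad_handler), and BucketStore.get_grad now reads gradients back from the bucketed parameters instead of keeping a separate grad list. Below is a minimal, single-process sketch of that pattern; GradBucket, flush, and max_elements are hypothetical stand-ins for BucketStore, _run_reduction, and _reduce_bucket_size, and the collective reduce_tensor_dp_group call is replaced by a print because no process group is involved.

from functools import partial

import torch
import torch.nn as nn


class GradBucket:
    """Toy stand-in for BucketStore: it only tracks params whose grads are pending reduction."""

    def __init__(self, max_elements: int):
        self.max_elements = max_elements
        self.params = []
        self.num_elements = 0

    def add_param(self, param: torch.Tensor) -> None:
        self.params.append(param)
        self.num_elements += param.numel()

    def get_grads(self):
        # mirrors the new BucketStore.get_grad: grads are read from the
        # bucketed params rather than stored separately
        assert all(p.grad is not None for p in self.params)
        return [p.grad for p in self.params]

    def reset(self) -> None:
        self.params.clear()
        self.num_elements = 0


def flush(bucket: GradBucket) -> None:
    # in the real optimizer this is _run_reduction -> reduce_tensor_dp_group;
    # here we only report what would be reduced
    grads = bucket.get_grads()
    print(f'reducing {len(grads)} grad(s), {bucket.num_elements} elements')
    bucket.reset()


def _grad_handler(param: torch.Tensor, grad: torch.Tensor, bucket: GradBucket) -> torch.Tensor:
    # flush the params already in the bucket if this one would overflow it,
    # then queue this param; its .grad is read later, after the hook returns
    if bucket.num_elements + param.numel() > bucket.max_elements:
        flush(bucket)
    bucket.add_param(param)
    return grad


if __name__ == '__main__':
    model = nn.Sequential(nn.Linear(32, 128), nn.GELU(), nn.Linear(128, 32))
    bucket = GradBucket(max_elements=4096)

    # one hook per parameter, as _attach_reduction_hook now does
    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(partial(_grad_handler, p, bucket=bucket))

    loss = model(torch.randn(8, 32)).sum()
    loss.backward()

    # reduce whatever is left over, cf. _reduce_grad_stage1
    flush(bucket)

The bucket is flushed before an oversized parameter is added, so every gradient read back out of the bucket belongs to a parameter whose hook has already fired and whose .grad has been populated; that is the invariant the new BucketStore.get_grad asserts.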