diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
index 16ba8a6d6..5b09019b9 100644
--- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
+++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
@@ -1,3 +1,7 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 
@@ -6,6 +10,7 @@ class TensorBucket:
         self._max_size = size
         self._current_size = 0
         self._bucket = []
+        self._write_back_pairs = {}
 
     @property
     def max_size(self):
@@ -21,7 +26,7 @@
     def is_empty(self):
         return len(self._bucket) == 0
 
-    def add_to_bucket(self, tensor, allow_oversize=False):
+    def add_to_bucket(self, tensor, allow_oversize=False, write_back_tensor: Optional[torch.Tensor] = None):
         tensor_size = tensor.numel()
 
         if not allow_oversize and self.will_exceed_max_size(tensor_size):
@@ -30,6 +35,8 @@
 
         self._bucket.append(tensor)
         self._current_size += tensor_size
+        write_back_tensor = write_back_tensor if write_back_tensor is not None else tensor
+        self._write_back_pairs[tensor] = write_back_tensor
 
     def will_exceed_max_size(self, tensor_size):
         expected_size = self._current_size + tensor_size
@@ -40,12 +47,30 @@
 
     def empty(self):
         self._bucket = []
-        self._size = 0
+        self._current_size = 0
+        self._write_back_pairs = {}
 
     def flatten(self):
         return _flatten_dense_tensors(self._bucket)
 
+    def unflatten(self, flat_tensor):
+        return _unflatten_dense_tensors(flat_tensor, self._bucket)
+
     def unflatten_and_copy(self, flat_tensor):
-        unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
+        unflattened_tensor_list = self.unflatten(flat_tensor)
         for old, new in zip(self._bucket, unflattened_tensor_list):
             old.copy_(new)
+
+    def all_gather(self, group=None):
+        flat = self.flatten()
+        buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
+        dist.all_gather(buffers, flat, group=group)
+        unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
+        # transpose the list of list
+        unflat_buffers = list(map(list, zip(*unflat_buffers)))
+        for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
+            write_back_tensor = self._write_back_pairs[tensor]
+            write_back_tensor.data.copy_(
+                _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
+            )
+        self.empty()
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 5f7f2a4e2..d19e0a002 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -23,7 +23,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
-from .bookkeeping import BucketStore, GradientStore, ParameterStore
+from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
 
 
 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -694,34 +694,33 @@
         for group_id in range(self.num_param_groups):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
+        tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
+        moe_tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
+
         # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
+                param_to_gather = splited_param.to(device).to(self._dtype)
                 if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._bucket_store.moe_extra_dp_pg_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param,
-                        splited_param.to(device).to(self._dtype),
-                        group=self._bucket_store.moe_extra_dp_pg,
-                    )
+                    try:
+                        moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                    except RuntimeError:
+                        moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
+                        moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
                 else:
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._bucket_store.zero_world_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param,
-                        splited_param.to(device).to(self._dtype),
-                        group=self._bucket_store.torch_pg,
-                    )
-                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
+                    try:
+                        tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                    except RuntimeError:
+                        tensor_bucket.all_gather(self._bucket_store.torch_pg)
+                        tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
+        if not moe_tensor_bucket.is_empty():
+            moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
+        if not tensor_bucket.is_empty():
+            tensor_bucket.all_gather(self._bucket_store.torch_pg)
 
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
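Below is a minimal usage sketch of the bucketed all-gather path introduced above, mirroring the `try / except RuntimeError` flush pattern used in `step()`. It is not part of the patch: the single-rank gloo group, the tensor shapes, and the bucket size of 20 elements are made up so the snippet can run on one CPU process; in the real optimizer each rank would add its own parameter shard and pass `torch_pg` or `moe_extra_dp_pg` as the group. Only the `TensorBucket` import path is taken from the diff.

```python
# Hypothetical, self-contained sketch (not from this PR): exercises
# TensorBucket.add_to_bucket / all_gather on a 1-rank gloo process group.
import os

import torch
import torch.distributed as dist

from colossalai.zero.low_level.bookkeeping.tensor_bucket import TensorBucket

if __name__ == "__main__":
    # single-process "distributed" setup purely so the example runs
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # full working parameters and this rank's (here: whole) flattened shards
    working_params = [torch.zeros(4, 4), torch.zeros(8)]
    shards = [p.flatten().clone() + i for i, p in enumerate(working_params)]

    bucket = TensorBucket(20)  # deliberately small so the overflow path triggers
    for shard, working_param in zip(shards, working_params):
        try:
            bucket.add_to_bucket(shard, write_back_tensor=working_param)
        except RuntimeError:
            # bucket would overflow: gather what is queued, write it back, retry
            bucket.all_gather()
            bucket.add_to_bucket(shard, write_back_tensor=working_param)

    # flush the remainder so every write-back tensor receives its gathered data
    if not bucket.is_empty():
        bucket.all_gather()

    assert torch.equal(working_params[1], shards[1])
    dist.destroy_process_group()
```

The apparent point of the change is that fusing many small shards into one flat `dist.all_gather` call amortizes collective-launch overhead, instead of issuing one all-gather per parameter as the removed code in `step()` did.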