mirror of https://github.com/hpcaitech/ColossalAI
[zero] use bucket during allgather (#5860)
* [zero] use bucket during allgather
* [zero] rename api
parent 8e718a1421
commit 5dfbcd7746
@@ -1,3 +1,7 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 
@@ -6,6 +10,7 @@ class TensorBucket:
         self._max_size = size
         self._current_size = 0
         self._bucket = []
+        self._write_back_pairs = {}
 
     @property
     def max_size(self):
@@ -21,7 +26,7 @@ class TensorBucket:
     def is_empty(self):
         return len(self._bucket) == 0
 
-    def add_to_bucket(self, tensor, allow_oversize=False):
+    def add_to_bucket(self, tensor, allow_oversize=False, write_back_tensor: Optional[torch.Tensor] = None):
         tensor_size = tensor.numel()
 
         if not allow_oversize and self.will_exceed_max_size(tensor_size):
@@ -30,6 +35,8 @@
 
         self._bucket.append(tensor)
         self._current_size += tensor_size
+        write_back_tensor = write_back_tensor if write_back_tensor is not None else tensor
+        self._write_back_pairs[tensor] = write_back_tensor
 
     def will_exceed_max_size(self, tensor_size):
         expected_size = self._current_size + tensor_size
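The `write_back_tensor` bookkeeping above lets a caller add a converted copy of a tensor to the bucket while having the gathered result written into the original tensor; when no target is given, the added tensor itself is used, so existing callers keep their in-place behaviour. A minimal sketch of the two calling forms follows (the import path and names are assumptions for illustration; nothing here touches a process group yet):

import torch

from colossalai.zero.low_level.bookkeeping import TensorBucket  # import path assumed

bucket = TensorBucket(size=1024)

t = torch.ones(8, dtype=torch.float16)
bucket.add_to_bucket(t)  # no write_back_tensor: the gathered result is copied back into t itself

master = torch.ones(8)            # fp32 copy kept by the caller
shard = master.to(torch.float16)  # converted copy that actually enters the bucket
bucket.add_to_bucket(shard, write_back_tensor=master)  # the gathered result lands in master, not shard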
@@ -40,12 +47,30 @@
 
     def empty(self):
         self._bucket = []
-        self._size = 0
+        self._current_size = 0
+        self._write_back_pairs = {}
 
     def flatten(self):
         return _flatten_dense_tensors(self._bucket)
 
+    def unflatten(self, flat_tensor):
+        return _unflatten_dense_tensors(flat_tensor, self._bucket)
+
     def unflatten_and_copy(self, flat_tensor):
-        unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
+        unflattened_tensor_list = self.unflatten(flat_tensor)
         for old, new in zip(self._bucket, unflattened_tensor_list):
             old.copy_(new)
+
+    def all_gather(self, group=None):
+        flat = self.flatten()
+        buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
+        dist.all_gather(buffers, flat, group=group)
+        unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
+        # transpose the list of list
+        unflat_buffers = list(map(list, zip(*unflat_buffers)))
+        for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
+            write_back_tensor = self._write_back_pairs[tensor]
+            write_back_tensor.data.copy_(
+                _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
+            )
+        self.empty()
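For context, a runnable usage sketch of the new bucketed all-gather path: several tensors are queued with their write-back targets, and a single collective writes every result back and empties the bucket. A one-process gloo group is used only so the example runs without a launcher; in the optimizer change below, the groups are the ZeRO and MoE data-parallel groups.

import os

import torch
import torch.distributed as dist

from colossalai.zero.low_level.bookkeeping import TensorBucket  # import path assumed

# single-process default group so the sketch runs standalone
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

working = [torch.zeros(4), torch.zeros(6)]       # tensors that should receive the gathered values
shards = [torch.arange(4.0), torch.arange(6.0)]  # local shards (converted copies in the optimizer)

bucket = TensorBucket(size=1024)
for shard, dst in zip(shards, working):
    bucket.add_to_bucket(shard, write_back_tensor=dst)

bucket.all_gather()  # one flattened all_gather for the whole bucket, then write-back and empty
assert torch.equal(working[0], torch.arange(4.0))

dist.destroy_process_group()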
@@ -23,7 +23,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
-from .bookkeeping import BucketStore, GradientStore, ParameterStore
+from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
 
 
 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -694,34 +694,33 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for group_id in range(self.num_param_groups):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
+        tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
+        moe_tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
+
         # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
+                param_to_gather = splited_param.to(device).to(self._dtype)
                 if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._bucket_store.moe_extra_dp_pg_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param,
-                        splited_param.to(device).to(self._dtype),
-                        group=self._bucket_store.moe_extra_dp_pg,
-                    )
+                    try:
+                        moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                    except RuntimeError:
+                        moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
+                        moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
                 else:
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._bucket_store.zero_world_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param,
-                        splited_param.to(device).to(self._dtype),
-                        group=self._bucket_store.torch_pg,
-                    )
-                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
+                    try:
+                        tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                    except RuntimeError:
+                        tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
+                        tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
+        if not moe_tensor_bucket.is_empty():
+            moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
+        if not tensor_bucket.is_empty():
+            tensor_bucket.all_gather(self._bucket_store.torch_pg)
 
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
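Taken together, the optimizer now issues one all_gather per full bucket instead of one per parameter: each updated shard is queued together with its working-parameter write-back target, a bucket is flushed when adding the next shard would overflow it, and whatever remains is flushed after the loop. A distilled sketch of that pattern with hypothetical names (it assumes, as the hunk above does, that `add_to_bucket` raises `RuntimeError` when the bucket would exceed its maximum size):

from colossalai.zero.low_level.bookkeeping import TensorBucket  # import path assumed


def bucketed_all_gather(pairs, group=None, bucket_size=1024 * 1024):
    """Hypothetical helper mirroring the flush-and-retry pattern in the hunk above.

    pairs: iterable of (local shard to gather, tensor that receives the gathered result)
    group: process group to gather over (defaults to the default group)
    """
    bucket = TensorBucket(size=bucket_size)
    for shard, write_back in pairs:
        try:
            bucket.add_to_bucket(shard, write_back_tensor=write_back)
        except RuntimeError:
            # bucket would overflow: gather what is queued (which also writes back and empties it), then retry
            bucket.all_gather(group)
            bucket.add_to_bucket(shard, write_back_tensor=write_back)
    if not bucket.is_empty():
        bucket.all_gather(group)  # flush the remainder

Fewer, larger collectives amortize launch and communication overhead, which is presumably why the existing reduce_bucket_size is reused as the gather bucket size.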
|
Loading…
Reference in New Issue