@@ -23,7 +23,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
-from .bookkeeping import BucketStore, GradientStore, ParameterStore
+from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
 
 
 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -694,34 +694,33 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for group_id in range(self.num_param_groups):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
+        tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
+        moe_tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
         # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
+                param_to_gather = splited_param.to(device).to(self._dtype)
                 if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._bucket_store.moe_extra_dp_pg_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param,
-                        splited_param.to(device).to(self._dtype),
-                        group=self._bucket_store.moe_extra_dp_pg,
-                    )
+                    try:
+                        moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                    except RuntimeError:
+                        moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
+                        moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
                 else:
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._bucket_store.zero_world_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param,
-                        splited_param.to(device).to(self._dtype),
-                        group=self._bucket_store.torch_pg,
-                    )
-                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
+                    try:
+                        tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                    except RuntimeError:
+                        tensor_bucket.all_gather(self._bucket_store.torch_pg)
+                        tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
+
+        if not moe_tensor_bucket.is_empty():
+            moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
+        if not tensor_bucket.is_empty():
+            tensor_bucket.all_gather(self._bucket_store.torch_pg)
 
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""