|
|
|
@ -90,38 +90,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
self._logger = get_dist_logger()
|
|
|
|
|
self._verbose = verbose
|
|
|
|
|
|
|
|
|
|
# stage 2
|
|
|
|
|
self._partition_grads = partition_grad
|
|
|
|
|
|
|
|
|
|
self._cpu_offload = cpu_offload
|
|
|
|
|
|
|
|
|
|
# grad accumulation
|
|
|
|
|
self.require_grad_sync = True
|
|
|
|
|
|
|
|
|
|
# if process_group is none, will use the default one
|
|
|
|
|
self.dp_pg = dp_process_group
|
|
|
|
|
self._local_rank = dist.get_rank(group=self.dp_pg)
|
|
|
|
|
self._world_size = dist.get_world_size(group=self.dp_pg)
|
|
|
|
|
|
|
|
|
|
# extra dp
|
|
|
|
|
# This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
|
|
|
|
|
# Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
|
|
|
|
|
# Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
|
|
|
|
|
# And moe working and master param are split by extra dp pg.
|
|
|
|
|
self.moe_extra_dp_pg = moe_extra_dp_process_group
|
|
|
|
|
if self.moe_extra_dp_pg is not None:
|
|
|
|
|
self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
|
|
|
|
|
self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
|
|
|
|
|
|
|
|
|
|
# working and master params for mixed precision training
|
|
|
|
|
self._working_param_groups = dict()
|
|
|
|
|
self._master_param_groups_of_current_rank = dict()
|
|
|
|
|
|
|
|
|
|
# communication params
|
|
|
|
|
self._overlap_communication = overlap_communication
|
|
|
|
|
self._reduce_bucket_size = reduce_bucket_size
|
|
|
|
|
self._communication_dtype = communication_dtype
|
|
|
|
|
|
|
|
|
|
# gradient clipping
|
|
|
|
|
self._clip_grad_norm = clip_grad_norm
|
|
|
|
|
|
|
|
|
@ -140,9 +114,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
|
|
|
|
|
# ParameterStore will manage the tensor buffers used for zero
|
|
|
|
|
# it will not manage the tensors used by mixed precision training
|
|
|
|
|
self._param_store = ParameterStore(self.dp_pg)
|
|
|
|
|
self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
|
|
|
|
|
self._bucket_store = BucketStore(self.dp_pg)
|
|
|
|
|
self._param_store = ParameterStore(dp_process_group)
|
|
|
|
|
self._grad_store = GradientStore(dp_process_group, partition_grad=partition_grad, require_grad_sync=True)
|
|
|
|
|
self._bucket_store = BucketStore(
|
|
|
|
|
dp_process_group, reduce_bucket_size, overlap_communication, communication_dtype, moe_extra_dp_process_group
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# moe param should not be stored in working_groups
|
|
|
|
|
# because they have different parallel strategy
|
|
|
|
@ -157,7 +133,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
group_params = list()
|
|
|
|
|
for param in param_group["params"]:
|
|
|
|
|
if param.requires_grad:
|
|
|
|
|
if self.moe_extra_dp_pg is None:
|
|
|
|
|
if self._bucket_store.moe_extra_dp_pg is None:
|
|
|
|
|
# skip moe param
|
|
|
|
|
if is_moe_tensor(param):
|
|
|
|
|
self.working_moe_params.append(param)
|
|
|
|
@ -194,15 +170,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
param_group["params"] = self.master_moe_params
|
|
|
|
|
self.optim.param_groups.append(param_group)
|
|
|
|
|
|
|
|
|
|
# initialize communication stream for
|
|
|
|
|
# communication-computation overlapping
|
|
|
|
|
if self._overlap_communication:
|
|
|
|
|
self._comm_stream = get_accelerator().Stream()
|
|
|
|
|
|
|
|
|
|
# reduction hook is only used if overlapping communication
|
|
|
|
|
# or stage 2 is used
|
|
|
|
|
# if it is stage 1 without overlapping, no hook will be attached
|
|
|
|
|
if self._overlap_communication or self._partition_grads:
|
|
|
|
|
if self._bucket_store._overlap_communication or self._grad_store._partition_grads:
|
|
|
|
|
self._attach_reduction_hook()
|
|
|
|
|
|
|
|
|
|
# initialize mixed precision mixin
|
|
|
|
@ -222,6 +193,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
elif self._dtype is torch.bfloat16:
|
|
|
|
|
self.mixed_precision_mixin = BF16MixedPrecisionMixin()
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
|
self.remove_hooks()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def dtype(self):
|
|
|
|
|
return self._dtype
|
|
|
|
@ -246,7 +220,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()
|
|
|
|
|
|
|
|
|
|
for param in param_list:
|
|
|
|
|
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
|
|
|
|
|
padding_size = (
|
|
|
|
|
self._bucket_store.zero_world_size - param.numel() % self._bucket_store.zero_world_size
|
|
|
|
|
) % self._bucket_store.zero_world_size
|
|
|
|
|
self._param_store.record_param_padding_size(param, padding_size)
|
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
@ -258,12 +234,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
else:
|
|
|
|
|
padding_param = param.data.view(-1)
|
|
|
|
|
|
|
|
|
|
if self.moe_extra_dp_pg is not None and is_moe_tensor(param):
|
|
|
|
|
splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size)
|
|
|
|
|
splited_params = splited_params[self.moe_extra_dp_pg_rank]
|
|
|
|
|
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(param):
|
|
|
|
|
splited_params = padding_param.split(
|
|
|
|
|
padding_param.numel() // self._bucket_store.moe_extra_dp_pg_size
|
|
|
|
|
)
|
|
|
|
|
splited_params = splited_params[self._bucket_store.moe_extra_dp_pg_rank]
|
|
|
|
|
else:
|
|
|
|
|
splited_params = padding_param.split(padding_param.numel() // self._world_size)
|
|
|
|
|
splited_params = splited_params[self._local_rank]
|
|
|
|
|
splited_params = padding_param.split(padding_param.numel() // self._bucket_store.zero_world_size)
|
|
|
|
|
splited_params = splited_params[self._bucket_store.zero_local_rank]
|
|
|
|
|
|
|
|
|
|
# use fp32 when master_weights is True
|
|
|
|
|
if self._master_weights is True:
|
|
|
|
@ -280,10 +258,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
# Backward Reduction Hook #
|
|
|
|
|
###########################
|
|
|
|
|
|
|
|
|
|
def _grad_handler(self, group_id, param):
|
|
|
|
|
@staticmethod
|
|
|
|
|
def grad_handler(
|
|
|
|
|
param: nn.Parameter,
|
|
|
|
|
group_id: int,
|
|
|
|
|
bucket_store: BucketStore,
|
|
|
|
|
param_store: ParameterStore,
|
|
|
|
|
grad_store: GradientStore,
|
|
|
|
|
):
|
|
|
|
|
# if run with no_sync context, would not sync grad when backward
|
|
|
|
|
if self.require_grad_sync:
|
|
|
|
|
self._add_to_bucket(param, group_id)
|
|
|
|
|
if grad_store.require_grad_sync:
|
|
|
|
|
LowLevelZeroOptimizer.add_to_bucket(param, group_id, bucket_store, param_store, grad_store)
|
|
|
|
|
|
|
|
|
|
def _attach_reduction_hook(self):
|
|
|
|
|
# we iterate over the working params
|
|
|
|
@ -292,29 +277,36 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
param_group = self._working_param_groups[group_id]
|
|
|
|
|
for param in param_group:
|
|
|
|
|
if param.requires_grad:
|
|
|
|
|
param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id))
|
|
|
|
|
param._grad_handle = param.register_post_accumulate_grad_hook(
|
|
|
|
|
partial(
|
|
|
|
|
LowLevelZeroOptimizer.grad_handler,
|
|
|
|
|
group_id=group_id,
|
|
|
|
|
bucket_store=self._bucket_store,
|
|
|
|
|
param_store=self._param_store,
|
|
|
|
|
grad_store=self._grad_store,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
#######################
|
|
|
|
|
# Reduction Functions #
|
|
|
|
|
#######################
|
|
|
|
|
|
|
|
|
|
def _run_reduction(self):
|
|
|
|
|
if self._bucket_store.num_elements_in_bucket() > 0:
|
|
|
|
|
self._bucket_store.build_grad_in_bucket()
|
|
|
|
|
|
|
|
|
|
if self.moe_extra_dp_pg is None:
|
|
|
|
|
flat_grads = self._bucket_store.get_flatten_grad()
|
|
|
|
|
flat_grads /= self._world_size
|
|
|
|
|
@staticmethod
|
|
|
|
|
def run_reduction(bucket_store: BucketStore, grad_store: GradientStore):
|
|
|
|
|
if bucket_store.num_elements_in_bucket() > 0:
|
|
|
|
|
bucket_store.build_grad_in_bucket()
|
|
|
|
|
if bucket_store.moe_extra_dp_pg is None:
|
|
|
|
|
flat_grads = bucket_store.get_flatten_grad()
|
|
|
|
|
flat_grads /= bucket_store.zero_world_size
|
|
|
|
|
else:
|
|
|
|
|
# record moe and non moe param
|
|
|
|
|
moe_list = []
|
|
|
|
|
for param in self._bucket_store._param_list:
|
|
|
|
|
for param in bucket_store._param_list:
|
|
|
|
|
moe_list.append(is_moe_tensor(param))
|
|
|
|
|
|
|
|
|
|
# divide them into different groups
|
|
|
|
|
moe_grad_list = []
|
|
|
|
|
non_moe_grad_list = []
|
|
|
|
|
for grad_list in self._bucket_store._grad_in_bucket.values():
|
|
|
|
|
for grad_list in bucket_store._grad_in_bucket.values():
|
|
|
|
|
non_moe_cur_grad = []
|
|
|
|
|
moe_cur_grad = []
|
|
|
|
|
for i in range(len(grad_list)):
|
|
|
|
@ -332,7 +324,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
for grad_list in non_moe_grad_list:
|
|
|
|
|
non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
|
|
|
|
|
non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
|
|
|
|
|
non_moe_flat_grads /= self._world_size
|
|
|
|
|
non_moe_flat_grads /= bucket_store.zero_world_size
|
|
|
|
|
|
|
|
|
|
if len(moe_grad_list) > 0:
|
|
|
|
|
moe_flat_grads = []
|
|
|
|
@ -341,12 +333,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)
|
|
|
|
|
|
|
|
|
|
# ready to add other tensors to bucket
|
|
|
|
|
self._bucket_store.reset_num_elements_in_bucket()
|
|
|
|
|
bucket_store.reset_num_elements_in_bucket()
|
|
|
|
|
|
|
|
|
|
if self._overlap_communication:
|
|
|
|
|
stream = self._comm_stream
|
|
|
|
|
if bucket_store._overlap_communication:
|
|
|
|
|
stream = bucket_store.comm_stream
|
|
|
|
|
# in case of the memory being reused in the default stream
|
|
|
|
|
if self.moe_extra_dp_pg is None:
|
|
|
|
|
if bucket_store.moe_extra_dp_pg is None:
|
|
|
|
|
flat_grads.record_stream(stream)
|
|
|
|
|
else:
|
|
|
|
|
if len(non_moe_grad_list) > 0:
|
|
|
|
@ -359,53 +351,63 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
stream = get_accelerator().current_stream()
|
|
|
|
|
|
|
|
|
|
with get_accelerator().stream(stream):
|
|
|
|
|
group_id = self._bucket_store.current_group_id
|
|
|
|
|
group_id = bucket_store.current_group_id
|
|
|
|
|
|
|
|
|
|
if self.moe_extra_dp_pg is None:
|
|
|
|
|
if bucket_store.moe_extra_dp_pg is None:
|
|
|
|
|
grad_dtype = flat_grads.dtype
|
|
|
|
|
if self._communication_dtype is not None:
|
|
|
|
|
flat_grads = flat_grads.to(self._communication_dtype)
|
|
|
|
|
if bucket_store._communication_dtype is not None:
|
|
|
|
|
flat_grads = flat_grads.to(bucket_store._communication_dtype)
|
|
|
|
|
|
|
|
|
|
if not self._partition_grads:
|
|
|
|
|
if self.moe_extra_dp_pg is None:
|
|
|
|
|
dist.all_reduce(flat_grads, group=self.dp_pg)
|
|
|
|
|
if not grad_store._partition_grads:
|
|
|
|
|
if bucket_store.moe_extra_dp_pg is None:
|
|
|
|
|
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
|
|
|
|
|
if flat_grads.dtype != grad_dtype:
|
|
|
|
|
flat_grads = flat_grads.to(grad_dtype)
|
|
|
|
|
|
|
|
|
|
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
|
|
|
|
|
grad_in_bucket = self._bucket_store.get_grad()
|
|
|
|
|
self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)
|
|
|
|
|
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.zero_world_size)
|
|
|
|
|
grad_in_bucket = bucket_store.get_grad()
|
|
|
|
|
LowLevelZeroOptimizer.update_unpartitoned_grad(
|
|
|
|
|
bucket_store, grad_store, grad_in_bucket.values(), flat_grads_per_rank, group_id
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# sync extra zero group
|
|
|
|
|
else:
|
|
|
|
|
# sync non moe param in global dp group
|
|
|
|
|
if len(non_moe_grad_list) > 0:
|
|
|
|
|
dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
|
|
|
|
|
dist.all_reduce(non_moe_flat_grads, group=bucket_store.torch_pg)
|
|
|
|
|
flat_grads_per_rank = non_moe_flat_grads.split(
|
|
|
|
|
non_moe_flat_grads.numel() // self._world_size
|
|
|
|
|
non_moe_flat_grads.numel() // bucket_store.zero_world_size
|
|
|
|
|
)
|
|
|
|
|
LowLevelZeroOptimizer.update_unpartitoned_grad(
|
|
|
|
|
bucket_store, grad_store, non_moe_grad_list, flat_grads_per_rank, group_id
|
|
|
|
|
)
|
|
|
|
|
self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)
|
|
|
|
|
|
|
|
|
|
# sync moe param only in zero group
|
|
|
|
|
if len(moe_grad_list) > 0:
|
|
|
|
|
dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg)
|
|
|
|
|
flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size)
|
|
|
|
|
self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id)
|
|
|
|
|
dist.all_reduce(moe_flat_grads, group=bucket_store.moe_extra_dp_pg)
|
|
|
|
|
flat_grads_per_rank = moe_flat_grads.split(
|
|
|
|
|
moe_flat_grads.numel() // bucket_store.zero_world_size
|
|
|
|
|
)
|
|
|
|
|
LowLevelZeroOptimizer.update_unpartitoned_grad(
|
|
|
|
|
bucket_store, grad_store, moe_grad_list, flat_grads_per_rank, group_id
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
if self.moe_extra_dp_pg is None:
|
|
|
|
|
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
|
|
|
|
|
if bucket_store.moe_extra_dp_pg is None:
|
|
|
|
|
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.zero_world_size))
|
|
|
|
|
recieved_grad = torch.zeros_like(flat_grads_list[0])
|
|
|
|
|
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
|
|
|
|
|
dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
|
|
|
|
|
|
|
|
|
|
if recieved_grad.dtype != grad_dtype:
|
|
|
|
|
recieved_grad = recieved_grad.to(grad_dtype)
|
|
|
|
|
|
|
|
|
|
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
|
|
|
|
|
self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
|
|
|
|
|
grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank]
|
|
|
|
|
LowLevelZeroOptimizer.update_partitoned_grad(
|
|
|
|
|
bucket_store, grad_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
# categorize moe and non moe param
|
|
|
|
|
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
|
|
|
|
|
grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank]
|
|
|
|
|
moe_grad_in_bucket_current_rank = []
|
|
|
|
|
non_moe_grad_in_bucket_current_rank = []
|
|
|
|
|
for idx, grad in enumerate(grad_in_bucket_current_rank):
|
|
|
|
@ -416,11 +418,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
|
|
|
|
|
if len(non_moe_grad_list) > 0:
|
|
|
|
|
flat_grads_list = list(
|
|
|
|
|
non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
|
|
|
|
|
non_moe_flat_grads.split(len(non_moe_flat_grads) // bucket_store.zero_world_size)
|
|
|
|
|
)
|
|
|
|
|
recieved_grad = torch.zeros_like(flat_grads_list[0])
|
|
|
|
|
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
|
|
|
|
|
self._update_partitoned_grad(
|
|
|
|
|
dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
|
|
|
|
|
LowLevelZeroOptimizer.update_partitoned_grad(
|
|
|
|
|
bucket_store,
|
|
|
|
|
grad_store,
|
|
|
|
|
non_moe_grad_in_bucket_current_rank,
|
|
|
|
|
recieved_grad,
|
|
|
|
|
group_id,
|
|
|
|
@ -429,35 +433,46 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
|
|
|
|
|
if len(moe_grad_list) > 0:
|
|
|
|
|
flat_grads_list = list(
|
|
|
|
|
moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
|
|
|
|
|
moe_flat_grads.split(len(moe_flat_grads) // bucket_store.moe_extra_dp_pg_size)
|
|
|
|
|
)
|
|
|
|
|
recieved_grad = torch.zeros_like(flat_grads_list[0])
|
|
|
|
|
dist.reduce_scatter(
|
|
|
|
|
recieved_grad,
|
|
|
|
|
flat_grads_list,
|
|
|
|
|
group=self.moe_extra_dp_pg,
|
|
|
|
|
group=bucket_store.moe_extra_dp_pg,
|
|
|
|
|
)
|
|
|
|
|
param_slice = self._world_size // self.moe_extra_dp_pg_size
|
|
|
|
|
param_slice = bucket_store.zero_world_size // bucket_store.moe_extra_dp_pg_size
|
|
|
|
|
recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
|
|
|
|
|
for split_recieved_grad in recieved_grad:
|
|
|
|
|
split_recieved_grad = _unflatten_dense_tensors(
|
|
|
|
|
split_recieved_grad, moe_grad_in_bucket_current_rank
|
|
|
|
|
)
|
|
|
|
|
for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank):
|
|
|
|
|
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
|
|
|
|
self._add_grad(real_grad, param_slice, group_id, param_id)
|
|
|
|
|
param_id = bucket_store.get_param_id_of_grad(grad)
|
|
|
|
|
LowLevelZeroOptimizer.add_grad(
|
|
|
|
|
grad_store, real_grad, param_slice, group_id, param_id
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self._bucket_store.reset()
|
|
|
|
|
bucket_store.reset()
|
|
|
|
|
|
|
|
|
|
def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None:
|
|
|
|
|
@staticmethod
|
|
|
|
|
def update_unpartitoned_grad(
|
|
|
|
|
bucket_store: BucketStore,
|
|
|
|
|
grad_store: GradientStore,
|
|
|
|
|
origin_grad_list: List,
|
|
|
|
|
flat_grad_list: List,
|
|
|
|
|
group_id: int,
|
|
|
|
|
) -> None:
|
|
|
|
|
for rank, grad_list in enumerate(origin_grad_list):
|
|
|
|
|
sync_tensor(flat_grad_list[rank], grad_list)
|
|
|
|
|
for grad in grad_list:
|
|
|
|
|
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
|
|
|
|
self._add_grad(grad, self._world_size, group_id, param_id, rank)
|
|
|
|
|
param_id = bucket_store.get_param_id_of_grad(grad)
|
|
|
|
|
LowLevelZeroOptimizer.add_grad(grad_store, grad, bucket_store.zero_world_size, group_id, param_id, rank)
|
|
|
|
|
|
|
|
|
|
def _update_partitoned_grad(
|
|
|
|
|
self,
|
|
|
|
|
@staticmethod
|
|
|
|
|
def update_partitoned_grad(
|
|
|
|
|
bucket_store: BucketStore,
|
|
|
|
|
grad_store: GradientStore,
|
|
|
|
|
origin_grad_list: List,
|
|
|
|
|
flat_grad: torch.Tensor,
|
|
|
|
|
group_id: int,
|
|
|
|
@ -465,23 +480,31 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
) -> None:
|
|
|
|
|
sync_tensor(flat_grad, origin_grad_list)
|
|
|
|
|
for grad in origin_grad_list:
|
|
|
|
|
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
|
|
|
|
self._add_grad(grad, partition_num, group_id, param_id)
|
|
|
|
|
param_id = bucket_store.get_param_id_of_grad(grad)
|
|
|
|
|
LowLevelZeroOptimizer.add_grad(grad_store, grad, partition_num, group_id, param_id)
|
|
|
|
|
|
|
|
|
|
def _add_grad(
|
|
|
|
|
self,
|
|
|
|
|
@staticmethod
|
|
|
|
|
def add_grad(
|
|
|
|
|
grad_store: GradientStore,
|
|
|
|
|
grad: torch.Tensor,
|
|
|
|
|
partition_num: int,
|
|
|
|
|
group_id: int,
|
|
|
|
|
param_id: int,
|
|
|
|
|
rank: int = 0,
|
|
|
|
|
) -> None:
|
|
|
|
|
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
|
|
|
|
|
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
|
|
|
|
if len(grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
|
|
|
|
|
grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
|
|
|
|
else:
|
|
|
|
|
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
|
|
|
|
|
grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
|
|
|
|
|
|
|
|
|
|
def _add_to_bucket(self, param, group_id):
|
|
|
|
|
@staticmethod
|
|
|
|
|
def add_to_bucket(
|
|
|
|
|
param: nn.Parameter,
|
|
|
|
|
group_id: int,
|
|
|
|
|
bucket_store: BucketStore,
|
|
|
|
|
param_store: ParameterStore,
|
|
|
|
|
grad_store: GradientStore,
|
|
|
|
|
):
|
|
|
|
|
param_size = param.numel()
|
|
|
|
|
|
|
|
|
|
# check if the bucket is full
|
|
|
|
@ -489,13 +512,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
# or got a grad of param from another group
|
|
|
|
|
# after reduction, the bucket will be empty
|
|
|
|
|
if (
|
|
|
|
|
self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size
|
|
|
|
|
or group_id != self._bucket_store.current_group_id
|
|
|
|
|
bucket_store.num_elements_in_bucket() + param_size > bucket_store.reduce_bucket_size
|
|
|
|
|
or group_id != bucket_store.current_group_id
|
|
|
|
|
):
|
|
|
|
|
self._run_reduction()
|
|
|
|
|
LowLevelZeroOptimizer.run_reduction(bucket_store, grad_store)
|
|
|
|
|
|
|
|
|
|
padding_size = self._param_store.get_param_padding_size(param)
|
|
|
|
|
self._bucket_store.add_param_grad(group_id, param, padding_size)
|
|
|
|
|
padding_size = param_store.get_param_padding_size(param)
|
|
|
|
|
bucket_store.add_param_grad(group_id, param, padding_size)
|
|
|
|
|
|
|
|
|
|
################################
|
|
|
|
|
# torch.optim.Optimizer methods
|
|
|
|
@ -503,7 +526,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
|
|
|
|
|
def backward(self, loss, retain_graph=False):
|
|
|
|
|
assert not (
|
|
|
|
|
self._partition_grads and not self.require_grad_sync
|
|
|
|
|
self._grad_store._partition_grads and not self._grad_store.require_grad_sync
|
|
|
|
|
), "ZeRO2(partition_grads) and no_sync are not compatible"
|
|
|
|
|
|
|
|
|
|
if self.mixed_precision_mixin is not None:
|
|
|
|
@ -511,31 +534,31 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
|
|
|
|
|
loss.backward(retain_graph=retain_graph)
|
|
|
|
|
|
|
|
|
|
if not self.require_grad_sync:
|
|
|
|
|
if not self._grad_store.require_grad_sync:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self._reduce_grad(self._partition_grads)
|
|
|
|
|
self._reduce_grad(self._grad_store._partition_grads)
|
|
|
|
|
|
|
|
|
|
# clear reduced grads
|
|
|
|
|
if self._overlap_communication:
|
|
|
|
|
if self._bucket_store._overlap_communication:
|
|
|
|
|
get_accelerator().synchronize()
|
|
|
|
|
self.zero_grad()
|
|
|
|
|
|
|
|
|
|
def backward_by_grad(self, tensor, grad):
|
|
|
|
|
assert not (
|
|
|
|
|
self._partition_grads and not self.require_grad_sync
|
|
|
|
|
self._grad_store._partition_grads and not self._grad_store.require_grad_sync
|
|
|
|
|
), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
|
|
|
|
|
|
|
|
|
|
if self.mixed_precision_mixin is not None:
|
|
|
|
|
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
|
|
|
|
|
torch.autograd.backward(tensor, grad)
|
|
|
|
|
|
|
|
|
|
if not self.require_grad_sync:
|
|
|
|
|
if not self._grad_store.require_grad_sync:
|
|
|
|
|
return
|
|
|
|
|
self._reduce_grad(self._partition_grads)
|
|
|
|
|
self._reduce_grad(self._grad_store._partition_grads)
|
|
|
|
|
|
|
|
|
|
# clear reduced grads
|
|
|
|
|
if self._overlap_communication:
|
|
|
|
|
if self._bucket_store._overlap_communication:
|
|
|
|
|
get_accelerator().synchronize()
|
|
|
|
|
|
|
|
|
|
self.zero_grad()
|
|
|
|
@ -566,7 +589,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
|
|
|
|
|
def step(self, closure=None):
|
|
|
|
|
assert closure is None, "closure is not supported by step()"
|
|
|
|
|
if not self.require_grad_sync:
|
|
|
|
|
if not self._grad_store.require_grad_sync:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
|
|
|
|
@ -585,7 +608,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
# and should not be updated
|
|
|
|
|
real_working_params = dict()
|
|
|
|
|
real_master_params = dict()
|
|
|
|
|
grad_index = 0 if self._partition_grads else self._local_rank
|
|
|
|
|
grad_index = 0 if self._grad_store._partition_grads else self._bucket_store.zero_local_rank
|
|
|
|
|
for group_id in range(self.num_param_groups):
|
|
|
|
|
master_params = self._master_param_groups_of_current_rank[group_id]
|
|
|
|
|
real_working_params[group_id] = []
|
|
|
|
@ -598,14 +621,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
|
|
|
|
|
if len(grads) > 0:
|
|
|
|
|
# moe hybrid zero
|
|
|
|
|
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
|
|
|
|
|
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
|
|
|
|
|
real_working_params[group_id].append(working_param)
|
|
|
|
|
if self._partition_grads:
|
|
|
|
|
if self._grad_store._partition_grads:
|
|
|
|
|
grad = grads
|
|
|
|
|
else:
|
|
|
|
|
param_slice = self._world_size // self.moe_extra_dp_pg_size
|
|
|
|
|
param_slice = self._bucket_store.zero_world_size // self._bucket_store.moe_extra_dp_pg_size
|
|
|
|
|
grad = grads[
|
|
|
|
|
self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice
|
|
|
|
|
self._bucket_store.moe_extra_dp_pg_rank
|
|
|
|
|
* param_slice : (self._bucket_store.moe_extra_dp_pg_rank + 1)
|
|
|
|
|
* param_slice
|
|
|
|
|
]
|
|
|
|
|
grad = flatten(grad)
|
|
|
|
|
else:
|
|
|
|
@ -674,25 +699,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
master_working_param = self.optim.param_groups[group_id]["params"]
|
|
|
|
|
for idx, splited_param in enumerate(master_working_param):
|
|
|
|
|
working_param = real_working_params[group_id][idx]
|
|
|
|
|
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
|
|
|
|
|
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
|
|
|
|
|
all_splited_param = [
|
|
|
|
|
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
|
|
|
|
|
for _ in range(self.moe_extra_dp_pg_size)
|
|
|
|
|
for _ in range(self._bucket_store.moe_extra_dp_pg_size)
|
|
|
|
|
]
|
|
|
|
|
dist.all_gather(
|
|
|
|
|
all_splited_param,
|
|
|
|
|
splited_param.to(device).to(self._dtype),
|
|
|
|
|
group=self.moe_extra_dp_pg,
|
|
|
|
|
group=self._bucket_store.moe_extra_dp_pg,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
all_splited_param = [
|
|
|
|
|
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
|
|
|
|
|
for _ in range(self._world_size)
|
|
|
|
|
for _ in range(self._bucket_store.zero_world_size)
|
|
|
|
|
]
|
|
|
|
|
dist.all_gather(
|
|
|
|
|
all_splited_param,
|
|
|
|
|
splited_param.to(device).to(self._dtype),
|
|
|
|
|
group=self.dp_pg,
|
|
|
|
|
group=self._bucket_store.torch_pg,
|
|
|
|
|
)
|
|
|
|
|
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
|
|
|
|
|
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
|
|
|
|
@ -720,7 +745,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
device=get_accelerator().get_current_device(),
|
|
|
|
|
dtype=torch.float,
|
|
|
|
|
)
|
|
|
|
|
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
|
|
|
|
|
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self._bucket_store.torch_pg)
|
|
|
|
|
total_norm = total_norm_cuda.item()
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
@ -738,7 +763,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
torch.distributed.all_reduce(
|
|
|
|
|
total_norm_exponentiated_cuda,
|
|
|
|
|
op=torch.distributed.ReduceOp.SUM,
|
|
|
|
|
group=self.dp_pg,
|
|
|
|
|
group=self._bucket_store.torch_pg,
|
|
|
|
|
)
|
|
|
|
|
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
|
|
|
|
|
|
|
|
|
@ -773,27 +798,33 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
param_group = self._working_param_groups[group_id]
|
|
|
|
|
for param in param_group:
|
|
|
|
|
if param.requires_grad and param.grad is not None:
|
|
|
|
|
self._add_to_bucket(param, group_id)
|
|
|
|
|
LowLevelZeroOptimizer.add_to_bucket(
|
|
|
|
|
param,
|
|
|
|
|
group_id,
|
|
|
|
|
self._bucket_store,
|
|
|
|
|
self._param_store,
|
|
|
|
|
self._grad_store,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self._run_reduction()
|
|
|
|
|
LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store)
|
|
|
|
|
|
|
|
|
|
def _reduce_grad(self, partition_grad):
|
|
|
|
|
# if not overlapping communication (no reduction hook is attached) when zero1
|
|
|
|
|
# we need to manually reduce these gradients
|
|
|
|
|
if not partition_grad and not self._overlap_communication:
|
|
|
|
|
if not partition_grad and not self._bucket_store._overlap_communication:
|
|
|
|
|
self._sync_grad()
|
|
|
|
|
else:
|
|
|
|
|
self._run_reduction()
|
|
|
|
|
LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store)
|
|
|
|
|
|
|
|
|
|
# this context comes from pytorch DDP
|
|
|
|
|
@contextmanager
|
|
|
|
|
def no_sync(self):
|
|
|
|
|
old_require_grad_sync = self.require_grad_sync
|
|
|
|
|
self.require_grad_sync = False
|
|
|
|
|
old_require_grad_sync = self._grad_store.require_grad_sync
|
|
|
|
|
self._grad_store.require_grad_sync = False
|
|
|
|
|
try:
|
|
|
|
|
yield
|
|
|
|
|
finally:
|
|
|
|
|
self.require_grad_sync = old_require_grad_sync
|
|
|
|
|
self._grad_store.require_grad_sync = old_require_grad_sync
|
|
|
|
|
|
|
|
|
|
##############
|
|
|
|
|
# State Dict #
|
|
|
|
@ -833,16 +864,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
for k, v in state.items():
|
|
|
|
|
if isinstance(v, torch.Tensor) and k != "step":
|
|
|
|
|
working_param = self._param_store.master_to_working_param[id(param)]
|
|
|
|
|
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
|
|
|
|
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
|
|
|
|
gather_tensor = [
|
|
|
|
|
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
|
|
|
|
|
torch.zeros(v.shape, device=device, dtype=v.dtype)
|
|
|
|
|
for _ in range(self._bucket_store.moe_extra_dp_pg_size)
|
|
|
|
|
]
|
|
|
|
|
dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg)
|
|
|
|
|
dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg)
|
|
|
|
|
else:
|
|
|
|
|
gather_tensor = [
|
|
|
|
|
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
|
|
|
|
|
torch.zeros(v.shape, device=device, dtype=v.dtype)
|
|
|
|
|
for _ in range(self._bucket_store.zero_world_size)
|
|
|
|
|
]
|
|
|
|
|
dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
|
|
|
|
|
dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.torch_pg)
|
|
|
|
|
param_state = (
|
|
|
|
|
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
|
|
|
|
)
|
|
|
|
@ -862,17 +895,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
for param_idx, state in zero_state_dict["state"].items():
|
|
|
|
|
for k, v in state.items():
|
|
|
|
|
if isinstance(v, torch.Tensor) and k != "step":
|
|
|
|
|
padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
|
|
|
|
|
padding_size = (
|
|
|
|
|
self._bucket_store.zero_world_size - v.numel() % self._bucket_store.zero_world_size
|
|
|
|
|
) % self._bucket_store.zero_world_size
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
v = v.flatten()
|
|
|
|
|
if padding_size > 0:
|
|
|
|
|
v = torch.nn.functional.pad(v, [0, padding_size])
|
|
|
|
|
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
|
|
|
|
v_list = v.split(v.numel() // self.moe_extra_dp_pg_size)
|
|
|
|
|
zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone()
|
|
|
|
|
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
|
|
|
|
v_list = v.split(v.numel() // self._bucket_store.moe_extra_dp_pg_size)
|
|
|
|
|
zero_state_dict["state"][param_idx][k] = (
|
|
|
|
|
v_list[self._bucket_store.moe_extra_dp_pg_rank].detach().clone()
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
v_list = v.split(v.numel() // self._world_size)
|
|
|
|
|
zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
|
|
|
|
|
v_list = v.split(v.numel() // self._bucket_store.zero_world_size)
|
|
|
|
|
zero_state_dict["state"][param_idx][k] = (
|
|
|
|
|
v_list[self._bucket_store.zero_local_rank].detach().clone()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.optim.load_state_dict(zero_state_dict)
|
|
|
|
|
|
|
|
|
@ -904,16 +943,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
|
|
|
|
|
for k, v in states.items():
|
|
|
|
|
if isinstance(v, torch.Tensor) and k != "step":
|
|
|
|
|
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
|
|
|
|
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
|
|
|
|
state_tensor = [
|
|
|
|
|
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
|
|
|
|
|
torch.zeros(v.shape, device=device, dtype=v.dtype)
|
|
|
|
|
for _ in range(self._bucket_store.moe_extra_dp_pg_size)
|
|
|
|
|
]
|
|
|
|
|
dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg)
|
|
|
|
|
dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg)
|
|
|
|
|
else:
|
|
|
|
|
state_tensor = [
|
|
|
|
|
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
|
|
|
|
|
torch.zeros(v.shape, device=device, dtype=v.dtype)
|
|
|
|
|
for _ in range(self._bucket_store.zero_world_size)
|
|
|
|
|
]
|
|
|
|
|
dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
|
|
|
|
|
dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.torch_pg)
|
|
|
|
|
state_tensor = (
|
|
|
|
|
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
|
|
|
|
)
|
|
|
|
@ -944,14 +985,30 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
|
working_param = p.data.view(-1)
|
|
|
|
|
if padding_size > 0:
|
|
|
|
|
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
|
|
|
|
|
if self.moe_extra_dp_pg is not None and is_moe_tensor(p):
|
|
|
|
|
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(p):
|
|
|
|
|
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
|
|
|
|
|
else:
|
|
|
|
|
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
|
|
|
|
|
master_param.copy_(
|
|
|
|
|
working_param.chunk(self._bucket_store.zero_world_size)[self._bucket_store.zero_local_rank]
|
|
|
|
|
)
|
|
|
|
|
if hasattr(self, "master_moe_params"):
|
|
|
|
|
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
|
|
|
|
|
master_moe_param.copy_(working_moe_param)
|
|
|
|
|
|
|
|
|
|
def remove_hooks(self) -> None:
|
|
|
|
|
"""remove the registered hooks
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
plugin (LowLevelZeroPlugin): the plugin to bound this method.
|
|
|
|
|
"""
|
|
|
|
|
for group_id in range(self.num_param_groups):
|
|
|
|
|
param_group = self._working_param_groups[group_id]
|
|
|
|
|
for param in param_group:
|
|
|
|
|
if param.requires_grad:
|
|
|
|
|
assert hasattr(param, "_grad_handle")
|
|
|
|
|
param._grad_handle.remove()
|
|
|
|
|
delattr(param, "_grad_handle")
|
|
|
|
|
|
|
|
|
|
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
|
|
|
|
|
return self._param_store.working_to_master_param
|
|
|
|
|
|
|
|
|
|