[zero] support multiple (partial) backward passes (#5596)

* [zero] support multiple (partial) backward passes

* [misc] update requirements
Hongxin Liu 2024-04-16 17:49:21 +08:00 committed by GitHub
parent 89049b0d89
commit 3788fefc7a
3 changed files with 56 additions and 15 deletions
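
In short: the low-level ZeRO optimizer used to attach its reduction logic with Tensor.register_hook, which fires on every gradient computation, and its bucket state was initialized once and never reset, so a training step effectively had to consist of exactly one full backward pass. Moving the hook to register_post_accumulate_grad_hook and making the bucket store resettable lets callers run several (partial) backward passes before step(). A minimal sketch of the pattern this enables (names are illustrative; in practice the optimizer comes from colossalai's LowLevelZeroPlugin/booster setup):

    # two partial backward passes over different sub-graphs, then one step
    loss_a = model(batch_a).mean()
    loss_b = model(batch_b).mean()
    optimizer.backward(loss_a)   # grads for the params reached by loss_a
    optimizer.backward(loss_b)   # a second (partial) pass accumulates more grads
    optimizer.step()
    optimizer.zero_grad()        # also resets the reduction bucket state (see below)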


@@ -11,7 +11,9 @@ from .base_store import BaseStore
 class BucketStore(BaseStore):
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
+        self.reset_all()

+    def reset_all(self) -> None:
         # init
         self.current_group_id = 0
         self._num_elements_in_bucket = 0
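
Design note: factoring the bucket initialization out of __init__ into reset_all() changes nothing at construction time, but gives the optimizer a way to return the store to a clean state between steps; zero_grad() calls it in a later hunk.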


@@ -40,7 +40,13 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         max_scale: float = 2**32,
     ) -> None:
         super().__init__(
-            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+            initial_scale,
+            min_scale,
+            growth_factor,
+            backoff_factor,
+            growth_interval,
+            hysteresis,
+            max_scale,
         )
         self.num_working_param_groups = num_working_param_groups
         self.grad_store = grad_store
@@ -273,11 +279,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     # Backward Reduction Hook #
     ###########################

-    def _grad_handler(self, param, group_id, grad):
+    def _grad_handler(self, group_id, param):
         # if run with no_sync context, would not sync grad when backward
         if self.require_grad_sync:
             self._add_to_bucket(param, group_id)
-        return grad

     def _attach_reduction_hook(self):
         # we iterate over the working params
@@ -286,7 +291,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         param_group = self._working_param_groups[group_id]
         for param in param_group:
             if param.requires_grad:
-                param.register_hook(partial(self._grad_handler, param, group_id))
+                param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id))

     #######################
     # Reduction Functions #
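
This hook swap is the heart of the change. Tensor.register_hook runs every time a gradient for the tensor is produced, i.e. once per (partial) backward pass, and may return a replacement gradient, which is why the old handler took grad and returned it. register_post_accumulate_grad_hook (available since PyTorch 2.1) runs only after the gradient has been fully accumulated into param.grad and receives the parameter itself, hence the new signature and the dropped return. A self-contained sketch of the difference:

    import torch

    p = torch.nn.Parameter(torch.ones(2))

    # Fires during backward, with the raw gradient, once per backward pass.
    p.register_hook(lambda grad: print("raw grad:", grad))

    # Fires after the gradient has been accumulated into p.grad (PyTorch >= 2.1).
    def post_hook(param):
        print("accumulated:", param.grad)

    p.register_post_accumulate_grad_hook(post_hook)

    (p * 2).sum().backward()
    # raw grad: tensor([2., 2.])
    # accumulated: tensor([2., 2.])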
@@ -415,7 +420,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 recieved_grad = torch.zeros_like(flat_grads_list[0])
                 dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
                 self._update_partitoned_grad(
-                    non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
+                    non_moe_grad_in_bucket_current_rank,
+                    recieved_grad,
+                    group_id,
+                    1,
                 )

             if len(moe_grad_list) > 0:
@@ -423,7 +431,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
                 )
                 recieved_grad = torch.zeros_like(flat_grads_list[0])
-                dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
+                dist.reduce_scatter(
+                    recieved_grad,
+                    flat_grads_list,
+                    group=self.moe_extra_dp_pg,
+                )
                 param_slice = self._world_size // self.moe_extra_dp_pg_size
                 recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
                 for split_recieved_grad in recieved_grad:
@@ -444,14 +456,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             self._add_grad(grad, self._world_size, group_id, param_id, rank)

     def _update_partitoned_grad(
-        self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
+        self,
+        origin_grad_list: List,
+        flat_grad: torch.Tensor,
+        group_id: int,
+        partition_num: int,
     ) -> None:
         sync_tensor(flat_grad, origin_grad_list)
         for grad in origin_grad_list:
             param_id = self._bucket_store.get_param_id_of_grad(grad)
             self._add_grad(grad, partition_num, group_id, param_id)

-    def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
+    def _add_grad(
+        self,
+        grad: torch.Tensor,
+        partition_num: int,
+        group_id: int,
+        param_id: int,
+        rank: int = 0,
+    ) -> None:
         if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
             self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
         else:
@@ -534,6 +557,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if param.grad is not None:
                     param.grad.detach()
                     param.grad.zero_()
+        self._bucket_store.reset_all()

     ####################
     # Update Parameter #
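
Calling reset_all() in zero_grad() pairs with the BucketStore change above: any bucket bookkeeping left over from the previous step's (possibly partial) backward passes is discarded, so the next step starts from a clean slate.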
@@ -655,14 +679,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                         for _ in range(self.moe_extra_dp_pg_size)
                     ]
                     dist.all_gather(
-                        all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
+                        all_splited_param,
+                        splited_param.to(device).to(self._dtype),
+                        group=self.moe_extra_dp_pg,
                     )
                 else:
                     all_splited_param = [
                         torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
                         for _ in range(self._world_size)
                     ]
-                    dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
+                    dist.all_gather(
+                        all_splited_param,
+                        splited_param.to(device).to(self._dtype),
+                        group=self.dp_pg,
+                    )
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))

             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
@@ -685,7 +715,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
             total_norm_cuda = torch.tensor(
-                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+                [float(total_norm)],
+                device=get_accelerator().get_current_device(),
+                dtype=torch.float,
             )
             dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
             total_norm = total_norm_cuda.item()
@@ -698,10 +730,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # Sum across all model parallel GPUs.
             total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
+                [float(total_norm_exponentiated)],
+                device=get_accelerator().get_current_device(),
+                dtype=torch.float,
             )
             torch.distributed.all_reduce(
-                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
+                total_norm_exponentiated_cuda,
+                op=torch.distributed.ReduceOp.SUM,
+                group=self.dp_pg,
             )

             total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
@@ -920,5 +956,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
         if hasattr(self, "moe_master_to_working_map"):
-            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+            return {
+                **self._param_store.master_to_working_param,
+                **self.moe_master_to_working_map,
+            }
         return self._param_store.master_to_working_param


@@ -8,7 +8,7 @@ click
 fabric
 contexttimer
 ninja
-torch>=1.12
+torch>=2.1.0
 safetensors
 einops
 pydantic
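
The torch floor moves from 1.12 to 2.1.0 because Tensor.register_post_accumulate_grad_hook, which the reduction hook above now relies on, first shipped in PyTorch 2.1.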