mirror of https://github.com/hpcaitech/ColossalAI
[zero] support multiple (partial) backward passes (#5596)
* [zero] support multiple (partial) backward passes
* [misc] update requirements

pull/5602/head
parent 89049b0d89
commit 3788fefc7a
@@ -11,7 +11,9 @@ from .base_store import BaseStore
 class BucketStore(BaseStore):
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
+        self.reset_all()

+    def reset_all(self) -> None:
         # init
         self.current_group_id = 0
         self._num_elements_in_bucket = 0
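Moving the field initialization out of `__init__` and into `reset_all()` lets callers restore the bucket to a clean state without rebuilding the object; `zero_grad` takes advantage of this further down. A minimal sketch of the pattern, assuming a simplified store (`SketchBucketStore` and `_grads_in_bucket` are illustrative names, not the real class, which tracks more state):

class SketchBucketStore:
    def __init__(self):
        # All mutable state lives in reset_all(), so construction and
        # re-initialization share one code path.
        self.reset_all()

    def reset_all(self) -> None:
        self.current_group_id = 0
        self._num_elements_in_bucket = 0
        self._grads_in_bucket = []

store = SketchBucketStore()
store._grads_in_bucket.append("grad staged by a partial backward")
store.reset_all()  # discard stale bucket state between iterations
assert store._grads_in_bucket == []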
@@ -40,7 +40,13 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         max_scale: float = 2**32,
     ) -> None:
         super().__init__(
-            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+            initial_scale,
+            min_scale,
+            growth_factor,
+            backoff_factor,
+            growth_interval,
+            hysteresis,
+            max_scale,
         )
         self.num_working_param_groups = num_working_param_groups
         self.grad_store = grad_store
@@ -273,11 +279,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     # Backward Reduction Hook #
     ###########################

-    def _grad_handler(self, param, group_id, grad):
+    def _grad_handler(self, group_id, param):
         # if run with no_sync context, would not sync grad when backward
         if self.require_grad_sync:
             self._add_to_bucket(param, group_id)
-        return grad

     def _attach_reduction_hook(self):
         # we iterate over the working params
@@ -286,7 +291,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             param_group = self._working_param_groups[group_id]
             for param in param_group:
                 if param.requires_grad:
-                    param.register_hook(partial(self._grad_handler, param, group_id))
+                    param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id))

     #######################
     # Reduction Functions #
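The hook migration above is the core of the change. `Tensor.register_hook` fires while autograd is still computing, receives the raw pre-accumulation gradient, and its return value replaces that gradient; `Tensor.register_post_accumulate_grad_hook`, added in PyTorch 2.1, fires only after the gradient has been accumulated into `param.grad` and receives the parameter itself. A runnable comparison outside ColossalAI:

import torch

p = torch.ones(2, requires_grad=True)

# Fires on each backward with the raw, pre-accumulation gradient:
p.register_hook(lambda grad: print("register_hook:", grad))
# Fires on each backward after summation into p.grad (torch >= 2.1):
p.register_post_accumulate_grad_hook(lambda param: print("post-accumulate:", param.grad))

p.sum().backward()        # register_hook sees [1., 1.]; p.grad is [1., 1.]
(2 * p).sum().backward()  # register_hook sees [2., 2.]; p.grad is now [3., 3.]

On the second (partial) backward the post-accumulate hook observes the accumulated total rather than the raw increment, which is what `_grad_handler` needs before bucketing the gradient for reduction.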
@@ -415,7 +420,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             recieved_grad = torch.zeros_like(flat_grads_list[0])
             dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
             self._update_partitoned_grad(
-                non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
+                non_moe_grad_in_bucket_current_rank,
+                recieved_grad,
+                group_id,
+                1,
             )

         if len(moe_grad_list) > 0:
@@ -423,7 +431,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
             )
             recieved_grad = torch.zeros_like(flat_grads_list[0])
-            dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
+            dist.reduce_scatter(
+                recieved_grad,
+                flat_grads_list,
+                group=self.moe_extra_dp_pg,
+            )
             param_slice = self._world_size // self.moe_extra_dp_pg_size
             recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
             for split_recieved_grad in recieved_grad:
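For reference, both reduce_scatter calls implement the same sharding step: every rank chunks its flattened bucket gradients into one shard per rank, and the collective leaves rank r holding the element-wise sum of shard r across all ranks. A plain-tensor sketch of those semantics, using two hypothetical ranks so no process group is required:

import torch

world_size = 2
# Flattened bucket gradients as they would appear on each rank:
flat_grads_rank0 = torch.tensor([1.0, 2.0, 3.0, 4.0])
flat_grads_rank1 = torch.tensor([10.0, 20.0, 30.0, 40.0])

# Each rank splits its flat gradients into world_size shards...
shards_rank0 = list(flat_grads_rank0.chunk(world_size))
shards_rank1 = list(flat_grads_rank1.chunk(world_size))

# ...and reduce_scatter leaves rank r with the sum of shard r over all ranks:
recieved_on_rank0 = shards_rank0[0] + shards_rank1[0]  # tensor([11., 22.])
recieved_on_rank1 = shards_rank0[1] + shards_rank1[1]  # tensor([33., 44.])
assert torch.equal(recieved_on_rank0, torch.tensor([11.0, 22.0]))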
@@ -444,14 +456,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     self._add_grad(grad, self._world_size, group_id, param_id, rank)

     def _update_partitoned_grad(
-        self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
+        self,
+        origin_grad_list: List,
+        flat_grad: torch.Tensor,
+        group_id: int,
+        partition_num: int,
     ) -> None:
         sync_tensor(flat_grad, origin_grad_list)
         for grad in origin_grad_list:
             param_id = self._bucket_store.get_param_id_of_grad(grad)
             self._add_grad(grad, partition_num, group_id, param_id)

-    def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
+    def _add_grad(
+        self,
+        grad: torch.Tensor,
+        partition_num: int,
+        group_id: int,
+        param_id: int,
+        rank: int = 0,
+    ) -> None:
         if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
             self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
         else:
@@ -534,6 +557,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if param.grad is not None:
                     param.grad.detach()
                     param.grad.zero_()
+        self._bucket_store.reset_all()

     ####################
     # Update Parameter #
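Calling `reset_all()` from `zero_grad` means gradients staged into the bucket by an unreduced partial backward (for instance under a `no_sync` context) are discarded together with `param.grad`, rather than leaking into the next iteration. The accumulation behavior this guards is plain autograd; a runnable illustration with no ColossalAI involved:

import torch

w = torch.ones(3, requires_grad=True)

(w * 2).sum().backward()  # first partial backward:  w.grad == [2., 2., 2.]
(w * 3).sum().backward()  # second pass accumulates: w.grad == [5., 5., 5.]
assert torch.equal(w.grad, torch.full((3,), 5.0))

w.grad = None  # the analogue of zero_grad(); the ZeRO optimizer additionally
               # resets its bucket store so no staged gradients survive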
@@ -655,14 +679,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                         for _ in range(self.moe_extra_dp_pg_size)
                     ]
                     dist.all_gather(
-                        all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
+                        all_splited_param,
+                        splited_param.to(device).to(self._dtype),
+                        group=self.moe_extra_dp_pg,
                     )
                 else:
                     all_splited_param = [
                         torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
                         for _ in range(self._world_size)
                     ]
-                    dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
+                    dist.all_gather(
+                        all_splited_param,
+                        splited_param.to(device).to(self._dtype),
+                        group=self.dp_pg,
+                    )
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

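This block rebuilds each full working parameter after the step: every rank holds one flat shard of the updated master weights, the shards are all-gathered, and their concatenation is truncated to the parameter's true element count (flat shards may carry padding) and reshaped in place. A plain-tensor sketch with two hypothetical ranks, standing in for what `dist.all_gather` plus `flatten` produce:

import torch

working_param = torch.empty(2, 3)  # 6 elements, 3 held per rank

shard_rank0 = torch.tensor([0.0, 1.0, 2.0])  # updated master shard, rank 0
shard_rank1 = torch.tensor([3.0, 4.0, 5.0])  # updated master shard, rank 1

all_splited_param = [shard_rank0, shard_rank1]  # what all_gather fills in
flat = torch.cat(all_splited_param)             # the `flatten(...)` step
working_param.data.copy_(flat[: working_param.numel()].reshape_as(working_param))
assert torch.equal(working_param, torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]))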
@@ -685,7 +715,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
             total_norm_cuda = torch.tensor(
-                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+                [float(total_norm)],
+                device=get_accelerator().get_current_device(),
+                dtype=torch.float,
             )
             dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
             total_norm = total_norm_cuda.item()
@@ -698,10 +730,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # Sum across all model parallel GPUs.
         total_norm_exponentiated_cuda = torch.tensor(
-            [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
+            [float(total_norm_exponentiated)],
+            device=get_accelerator().get_current_device(),
+            dtype=torch.float,
         )
         torch.distributed.all_reduce(
-            total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
+            total_norm_exponentiated_cuda,
+            op=torch.distributed.ReduceOp.SUM,
+            group=self.dp_pg,
         )
         total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

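The norm logic reduces to: each rank sums |g|^p over its gradient shards, the per-rank sums are all-reduced with SUM, and the global norm is the reduced total raised to 1/p (with a MAX all-reduce of local maxima instead when p is inf). A single-process check with hypothetical per-rank gradients:

import torch

norm_type = 2.0
grads_rank0 = [torch.tensor([3.0]), torch.tensor([0.0])]
grads_rank1 = [torch.tensor([4.0])]

# Per-rank exponentiated sums, as built before the SUM all-reduce:
local_rank0 = sum(float(g.abs().pow(norm_type).sum()) for g in grads_rank0)  # 9.0
local_rank1 = sum(float(g.abs().pow(norm_type).sum()) for g in grads_rank1)  # 16.0

# all_reduce(SUM) then ** (1 / p) recovers the global gradient norm:
total_norm = (local_rank0 + local_rank1) ** (1.0 / norm_type)
assert total_norm == 5.0  # sqrt(3**2 + 4**2)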
@@ -920,5 +956,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
         if hasattr(self, "moe_master_to_working_map"):
-            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+            return {
+                **self._param_store.master_to_working_param,
+                **self.moe_master_to_working_map,
+            }
         return self._param_store.master_to_working_param
@@ -8,7 +8,7 @@ click
 fabric
 contexttimer
 ninja
-torch>=1.12
+torch>=2.1.0
 safetensors
 einops
 pydantic
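The torch floor rises from 1.12 to 2.1.0 because `Tensor.register_post_accumulate_grad_hook`, which the optimizer now relies on, first shipped in PyTorch 2.1. A quick guard for environments that might still carry an older wheel:

import torch

assert hasattr(torch.Tensor, "register_post_accumulate_grad_hook"), (
    "torch >= 2.1.0 is required for post-accumulate gradient hooks"
)
print("torch", torch.__version__, "supports post-accumulate grad hooks")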