[zero]remove registered gradients hooks (#5687)

* remove registered hooks

fix

fix

fix zero

fix

fix

fix

fix

fix zero

fix zero

fix

fix

fix

* fix

fix

fix
pull/5525/head
flybird11111 2024-05-07 12:01:38 +08:00 committed by GitHub
parent c25f83c85f
commit 77ec773388
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 256 additions and 167 deletions

View File

@ -735,7 +735,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Get all working gradients and gradients to be synchronized. # Get all working gradients and gradients to be synchronized.
all_working_grads = _get_all_working_grads() all_working_grads = _get_all_working_grads()
grads_to_sync = _get_grads_to_sync(all_working_grads) grads_to_sync = _get_grads_to_sync(all_working_grads)
if self.require_grad_sync and grads_to_sync is not None: if self._grad_store.require_grad_sync and grads_to_sync is not None:
# Synchronize sequence parallelism gradients if required. # Synchronize sequence parallelism gradients if required.
SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
else: else:
@ -759,7 +759,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward method to compute gradients. # Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph) super().backward(loss, retain_graph)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients. # If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads() self._sync_sp_grads()
else: else:
@ -784,7 +784,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward_by_grad method to compute gradients. # Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad) super().backward_by_grad(tensor, grad)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients. # If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads() self._sync_sp_grads()
else: else:
@ -1272,7 +1272,7 @@ class HybridParallelPlugin(PipelinePluginBase):
# run with gradients accumulation # run with gradients accumulation
if model.require_grad_sync == False or ( if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
): ):
return outputs return outputs

View File

@ -6,6 +6,7 @@ class BaseStore:
def __init__(self, torch_pg: ProcessGroup): def __init__(self, torch_pg: ProcessGroup):
self._world_size = dist.get_world_size(group=torch_pg) self._world_size = dist.get_world_size(group=torch_pg)
self._local_rank = dist.get_rank(group=torch_pg) self._local_rank = dist.get_rank(group=torch_pg)
self.torch_pg = torch_pg
@property @property
def world_size(self): def world_size(self):

View File

@ -1,16 +1,43 @@
from typing import Dict from typing import Dict, Optional
import torch import torch
import torch.distributed as dist
from torch import Tensor from torch import Tensor
from torch._utils import _flatten_dense_tensors from torch._utils import _flatten_dense_tensors
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from colossalai.accelerator import get_accelerator
from .base_store import BaseStore from .base_store import BaseStore
class BucketStore(BaseStore): class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup): def __init__(
self,
torch_pg: ProcessGroup,
reduce_bucket_size: int,
overlap_communication: bool,
communication_dtype: Optional[torch.dtype] = None,
moe_extra_dp_process_group: ProcessGroup = None,
):
super().__init__(torch_pg) super().__init__(torch_pg)
self.reduce_bucket_size = reduce_bucket_size
# communication params
self._overlap_communication = overlap_communication
self._communication_dtype = communication_dtype
if self._overlap_communication:
self.comm_stream = get_accelerator().Stream()
self.zero_local_rank = dist.get_rank(group=self.torch_pg)
self.zero_world_size = dist.get_world_size(group=self.torch_pg)
# extra dp
# This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
# Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
# Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
# And moe working and master param are split by extra dp pg.
self.moe_extra_dp_pg = moe_extra_dp_process_group
if self.moe_extra_dp_pg is not None:
self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
self.reset_all() self.reset_all()
def reset_all(self) -> None: def reset_all(self) -> None:

View File

@ -6,7 +6,7 @@ from .base_store import BaseStore
class GradientStore(BaseStore): class GradientStore(BaseStore):
def __init__(self, *args, partition_grad: bool = False): def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True):
super().__init__(*args) super().__init__(*args)
""" """
self._grads_of_params mapping the parameter and its gradient slices self._grads_of_params mapping the parameter and its gradient slices
@ -18,9 +18,12 @@ class GradientStore(BaseStore):
} }
""" """
self._grads_of_params = dict() self._grads_of_params = dict()
# for zero2, it's `param_id: [grad_local_rank]` # stage 2
self._partition_grads = partition_grad
# grad accumulation
self.require_grad_sync = require_grad_sync
self._working_index = 0 if partition_grad else self._local_rank self._working_index = 0 if partition_grad else self._local_rank
# for zero2, it's `param_id: [grad_local_rank]`
self.grad_to_param_mapping = dict() self.grad_to_param_mapping = dict()
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:

View File

@ -90,38 +90,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._logger = get_dist_logger() self._logger = get_dist_logger()
self._verbose = verbose self._verbose = verbose
# stage 2
self._partition_grads = partition_grad
self._cpu_offload = cpu_offload self._cpu_offload = cpu_offload
# grad accumulation
self.require_grad_sync = True
# if process_group is none, will use the default one
self.dp_pg = dp_process_group
self._local_rank = dist.get_rank(group=self.dp_pg)
self._world_size = dist.get_world_size(group=self.dp_pg)
# extra dp
# This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
# Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
# Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
# And moe working and master param are split by extra dp pg.
self.moe_extra_dp_pg = moe_extra_dp_process_group
if self.moe_extra_dp_pg is not None:
self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
# working and master params for mixed precision training # working and master params for mixed precision training
self._working_param_groups = dict() self._working_param_groups = dict()
self._master_param_groups_of_current_rank = dict() self._master_param_groups_of_current_rank = dict()
# communication params
self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
# gradient clipping # gradient clipping
self._clip_grad_norm = clip_grad_norm self._clip_grad_norm = clip_grad_norm
@ -140,9 +114,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# ParameterStore will manage the tensor buffers used for zero # ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training # it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(self.dp_pg) self._param_store = ParameterStore(dp_process_group)
self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad) self._grad_store = GradientStore(dp_process_group, partition_grad=partition_grad, require_grad_sync=True)
self._bucket_store = BucketStore(self.dp_pg) self._bucket_store = BucketStore(
dp_process_group, reduce_bucket_size, overlap_communication, communication_dtype, moe_extra_dp_process_group
)
# moe param should not be stored in working_groups # moe param should not be stored in working_groups
# because they have different parallel strategy # because they have different parallel strategy
@ -157,7 +133,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
group_params = list() group_params = list()
for param in param_group["params"]: for param in param_group["params"]:
if param.requires_grad: if param.requires_grad:
if self.moe_extra_dp_pg is None: if self._bucket_store.moe_extra_dp_pg is None:
# skip moe param # skip moe param
if is_moe_tensor(param): if is_moe_tensor(param):
self.working_moe_params.append(param) self.working_moe_params.append(param)
@ -194,15 +170,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
param_group["params"] = self.master_moe_params param_group["params"] = self.master_moe_params
self.optim.param_groups.append(param_group) self.optim.param_groups.append(param_group)
# initialize communication stream for
# communication-computation overlapping
if self._overlap_communication:
self._comm_stream = get_accelerator().Stream()
# reduction hook is only used if overlapping communication # reduction hook is only used if overlapping communication
# or stage 2 is used # or stage 2 is used
# if it is stage 1 without overlapping, no hook will be attached # if it is stage 1 without overlapping, no hook will be attached
if self._overlap_communication or self._partition_grads: if self._bucket_store._overlap_communication or self._grad_store._partition_grads:
self._attach_reduction_hook() self._attach_reduction_hook()
# initialize mixed precision mixin # initialize mixed precision mixin
@ -222,6 +193,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
elif self._dtype is torch.bfloat16: elif self._dtype is torch.bfloat16:
self.mixed_precision_mixin = BF16MixedPrecisionMixin() self.mixed_precision_mixin = BF16MixedPrecisionMixin()
def __del__(self):
self.remove_hooks()
@property @property
def dtype(self): def dtype(self):
return self._dtype return self._dtype
@ -246,7 +220,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device = "cpu" if self._cpu_offload else get_accelerator().get_current_device() device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()
for param in param_list: for param in param_list:
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size padding_size = (
self._bucket_store.zero_world_size - param.numel() % self._bucket_store.zero_world_size
) % self._bucket_store.zero_world_size
self._param_store.record_param_padding_size(param, padding_size) self._param_store.record_param_padding_size(param, padding_size)
with torch.no_grad(): with torch.no_grad():
@ -258,12 +234,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
else: else:
padding_param = param.data.view(-1) padding_param = param.data.view(-1)
if self.moe_extra_dp_pg is not None and is_moe_tensor(param): if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(param):
splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size) splited_params = padding_param.split(
splited_params = splited_params[self.moe_extra_dp_pg_rank] padding_param.numel() // self._bucket_store.moe_extra_dp_pg_size
)
splited_params = splited_params[self._bucket_store.moe_extra_dp_pg_rank]
else: else:
splited_params = padding_param.split(padding_param.numel() // self._world_size) splited_params = padding_param.split(padding_param.numel() // self._bucket_store.zero_world_size)
splited_params = splited_params[self._local_rank] splited_params = splited_params[self._bucket_store.zero_local_rank]
# use fp32 when master_weights is True # use fp32 when master_weights is True
if self._master_weights is True: if self._master_weights is True:
@ -280,10 +258,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# Backward Reduction Hook # # Backward Reduction Hook #
########################### ###########################
def _grad_handler(self, group_id, param): @staticmethod
def grad_handler(
param: nn.Parameter,
group_id: int,
bucket_store: BucketStore,
param_store: ParameterStore,
grad_store: GradientStore,
):
# if run with no_sync context, would not sync grad when backward # if run with no_sync context, would not sync grad when backward
if self.require_grad_sync: if grad_store.require_grad_sync:
self._add_to_bucket(param, group_id) LowLevelZeroOptimizer.add_to_bucket(param, group_id, bucket_store, param_store, grad_store)
def _attach_reduction_hook(self): def _attach_reduction_hook(self):
# we iterate over the working params # we iterate over the working params
@ -292,29 +277,36 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
param_group = self._working_param_groups[group_id] param_group = self._working_param_groups[group_id]
for param in param_group: for param in param_group:
if param.requires_grad: if param.requires_grad:
param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id)) param._grad_handle = param.register_post_accumulate_grad_hook(
partial(
LowLevelZeroOptimizer.grad_handler,
group_id=group_id,
bucket_store=self._bucket_store,
param_store=self._param_store,
grad_store=self._grad_store,
)
)
####################### #######################
# Reduction Functions # # Reduction Functions #
####################### #######################
@staticmethod
def _run_reduction(self): def run_reduction(bucket_store: BucketStore, grad_store: GradientStore):
if self._bucket_store.num_elements_in_bucket() > 0: if bucket_store.num_elements_in_bucket() > 0:
self._bucket_store.build_grad_in_bucket() bucket_store.build_grad_in_bucket()
if bucket_store.moe_extra_dp_pg is None:
if self.moe_extra_dp_pg is None: flat_grads = bucket_store.get_flatten_grad()
flat_grads = self._bucket_store.get_flatten_grad() flat_grads /= bucket_store.zero_world_size
flat_grads /= self._world_size
else: else:
# record moe and non moe param # record moe and non moe param
moe_list = [] moe_list = []
for param in self._bucket_store._param_list: for param in bucket_store._param_list:
moe_list.append(is_moe_tensor(param)) moe_list.append(is_moe_tensor(param))
# divide them into different groups # divide them into different groups
moe_grad_list = [] moe_grad_list = []
non_moe_grad_list = [] non_moe_grad_list = []
for grad_list in self._bucket_store._grad_in_bucket.values(): for grad_list in bucket_store._grad_in_bucket.values():
non_moe_cur_grad = [] non_moe_cur_grad = []
moe_cur_grad = [] moe_cur_grad = []
for i in range(len(grad_list)): for i in range(len(grad_list)):
@ -332,7 +324,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for grad_list in non_moe_grad_list: for grad_list in non_moe_grad_list:
non_moe_flat_grads.append(_flatten_dense_tensors(grad_list)) non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads) non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
non_moe_flat_grads /= self._world_size non_moe_flat_grads /= bucket_store.zero_world_size
if len(moe_grad_list) > 0: if len(moe_grad_list) > 0:
moe_flat_grads = [] moe_flat_grads = []
@ -341,12 +333,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
moe_flat_grads = _flatten_dense_tensors(moe_flat_grads) moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)
# ready to add other tensors to bucket # ready to add other tensors to bucket
self._bucket_store.reset_num_elements_in_bucket() bucket_store.reset_num_elements_in_bucket()
if self._overlap_communication: if bucket_store._overlap_communication:
stream = self._comm_stream stream = bucket_store.comm_stream
# in case of the memory being reused in the default stream # in case of the memory being reused in the default stream
if self.moe_extra_dp_pg is None: if bucket_store.moe_extra_dp_pg is None:
flat_grads.record_stream(stream) flat_grads.record_stream(stream)
else: else:
if len(non_moe_grad_list) > 0: if len(non_moe_grad_list) > 0:
@ -359,53 +351,63 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
stream = get_accelerator().current_stream() stream = get_accelerator().current_stream()
with get_accelerator().stream(stream): with get_accelerator().stream(stream):
group_id = self._bucket_store.current_group_id group_id = bucket_store.current_group_id
if self.moe_extra_dp_pg is None: if bucket_store.moe_extra_dp_pg is None:
grad_dtype = flat_grads.dtype grad_dtype = flat_grads.dtype
if self._communication_dtype is not None: if bucket_store._communication_dtype is not None:
flat_grads = flat_grads.to(self._communication_dtype) flat_grads = flat_grads.to(bucket_store._communication_dtype)
if not self._partition_grads: if not grad_store._partition_grads:
if self.moe_extra_dp_pg is None: if bucket_store.moe_extra_dp_pg is None:
dist.all_reduce(flat_grads, group=self.dp_pg) dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
if flat_grads.dtype != grad_dtype: if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype) flat_grads = flat_grads.to(grad_dtype)
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) flat_grads_per_rank = flat_grads.split(flat_grads.numel() // bucket_store.zero_world_size)
grad_in_bucket = self._bucket_store.get_grad() grad_in_bucket = bucket_store.get_grad()
self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id) LowLevelZeroOptimizer.update_unpartitoned_grad(
bucket_store, grad_store, grad_in_bucket.values(), flat_grads_per_rank, group_id
)
# sync extra zero group # sync extra zero group
else: else:
# sync non moe param in global dp group # sync non moe param in global dp group
if len(non_moe_grad_list) > 0: if len(non_moe_grad_list) > 0:
dist.all_reduce(non_moe_flat_grads, group=self.dp_pg) dist.all_reduce(non_moe_flat_grads, group=bucket_store.torch_pg)
flat_grads_per_rank = non_moe_flat_grads.split( flat_grads_per_rank = non_moe_flat_grads.split(
non_moe_flat_grads.numel() // self._world_size non_moe_flat_grads.numel() // bucket_store.zero_world_size
)
LowLevelZeroOptimizer.update_unpartitoned_grad(
bucket_store, grad_store, non_moe_grad_list, flat_grads_per_rank, group_id
) )
self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)
# sync moe param only in zero group # sync moe param only in zero group
if len(moe_grad_list) > 0: if len(moe_grad_list) > 0:
dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg) dist.all_reduce(moe_flat_grads, group=bucket_store.moe_extra_dp_pg)
flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size) flat_grads_per_rank = moe_flat_grads.split(
self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id) moe_flat_grads.numel() // bucket_store.zero_world_size
)
LowLevelZeroOptimizer.update_unpartitoned_grad(
bucket_store, grad_store, moe_grad_list, flat_grads_per_rank, group_id
)
else: else:
if self.moe_extra_dp_pg is None: if bucket_store.moe_extra_dp_pg is None:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.zero_world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0]) recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
if recieved_grad.dtype != grad_dtype: if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype) recieved_grad = recieved_grad.to(grad_dtype)
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank]
self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1) LowLevelZeroOptimizer.update_partitoned_grad(
bucket_store, grad_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1
)
else: else:
# categorize moe and non moe param # categorize moe and non moe param
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank]
moe_grad_in_bucket_current_rank = [] moe_grad_in_bucket_current_rank = []
non_moe_grad_in_bucket_current_rank = [] non_moe_grad_in_bucket_current_rank = []
for idx, grad in enumerate(grad_in_bucket_current_rank): for idx, grad in enumerate(grad_in_bucket_current_rank):
@ -416,11 +418,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if len(non_moe_grad_list) > 0: if len(non_moe_grad_list) > 0:
flat_grads_list = list( flat_grads_list = list(
non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size) non_moe_flat_grads.split(len(non_moe_flat_grads) // bucket_store.zero_world_size)
) )
recieved_grad = torch.zeros_like(flat_grads_list[0]) recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
self._update_partitoned_grad( LowLevelZeroOptimizer.update_partitoned_grad(
bucket_store,
grad_store,
non_moe_grad_in_bucket_current_rank, non_moe_grad_in_bucket_current_rank,
recieved_grad, recieved_grad,
group_id, group_id,
@ -429,35 +433,46 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if len(moe_grad_list) > 0: if len(moe_grad_list) > 0:
flat_grads_list = list( flat_grads_list = list(
moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size) moe_flat_grads.split(len(moe_flat_grads) // bucket_store.moe_extra_dp_pg_size)
) )
recieved_grad = torch.zeros_like(flat_grads_list[0]) recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter( dist.reduce_scatter(
recieved_grad, recieved_grad,
flat_grads_list, flat_grads_list,
group=self.moe_extra_dp_pg, group=bucket_store.moe_extra_dp_pg,
) )
param_slice = self._world_size // self.moe_extra_dp_pg_size param_slice = bucket_store.zero_world_size // bucket_store.moe_extra_dp_pg_size
recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
for split_recieved_grad in recieved_grad: for split_recieved_grad in recieved_grad:
split_recieved_grad = _unflatten_dense_tensors( split_recieved_grad = _unflatten_dense_tensors(
split_recieved_grad, moe_grad_in_bucket_current_rank split_recieved_grad, moe_grad_in_bucket_current_rank
) )
for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank): for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank):
param_id = self._bucket_store.get_param_id_of_grad(grad) param_id = bucket_store.get_param_id_of_grad(grad)
self._add_grad(real_grad, param_slice, group_id, param_id) LowLevelZeroOptimizer.add_grad(
grad_store, real_grad, param_slice, group_id, param_id
)
self._bucket_store.reset() bucket_store.reset()
def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None: @staticmethod
def update_unpartitoned_grad(
bucket_store: BucketStore,
grad_store: GradientStore,
origin_grad_list: List,
flat_grad_list: List,
group_id: int,
) -> None:
for rank, grad_list in enumerate(origin_grad_list): for rank, grad_list in enumerate(origin_grad_list):
sync_tensor(flat_grad_list[rank], grad_list) sync_tensor(flat_grad_list[rank], grad_list)
for grad in grad_list: for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad) param_id = bucket_store.get_param_id_of_grad(grad)
self._add_grad(grad, self._world_size, group_id, param_id, rank) LowLevelZeroOptimizer.add_grad(grad_store, grad, bucket_store.zero_world_size, group_id, param_id, rank)
def _update_partitoned_grad( @staticmethod
self, def update_partitoned_grad(
bucket_store: BucketStore,
grad_store: GradientStore,
origin_grad_list: List, origin_grad_list: List,
flat_grad: torch.Tensor, flat_grad: torch.Tensor,
group_id: int, group_id: int,
@ -465,23 +480,31 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
) -> None: ) -> None:
sync_tensor(flat_grad, origin_grad_list) sync_tensor(flat_grad, origin_grad_list)
for grad in origin_grad_list: for grad in origin_grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad) param_id = bucket_store.get_param_id_of_grad(grad)
self._add_grad(grad, partition_num, group_id, param_id) LowLevelZeroOptimizer.add_grad(grad_store, grad, partition_num, group_id, param_id)
def _add_grad( @staticmethod
self, def add_grad(
grad_store: GradientStore,
grad: torch.Tensor, grad: torch.Tensor,
partition_num: int, partition_num: int,
group_id: int, group_id: int,
param_id: int, param_id: int,
rank: int = 0, rank: int = 0,
) -> None: ) -> None:
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: if len(grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else: else:
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
def _add_to_bucket(self, param, group_id): @staticmethod
def add_to_bucket(
param: nn.Parameter,
group_id: int,
bucket_store: BucketStore,
param_store: ParameterStore,
grad_store: GradientStore,
):
param_size = param.numel() param_size = param.numel()
# check if the bucket is full # check if the bucket is full
@ -489,13 +512,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# or got a grad of param from another group # or got a grad of param from another group
# after reduction, the bucket will be empty # after reduction, the bucket will be empty
if ( if (
self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size bucket_store.num_elements_in_bucket() + param_size > bucket_store.reduce_bucket_size
or group_id != self._bucket_store.current_group_id or group_id != bucket_store.current_group_id
): ):
self._run_reduction() LowLevelZeroOptimizer.run_reduction(bucket_store, grad_store)
padding_size = self._param_store.get_param_padding_size(param) padding_size = param_store.get_param_padding_size(param)
self._bucket_store.add_param_grad(group_id, param, padding_size) bucket_store.add_param_grad(group_id, param, padding_size)
################################ ################################
# torch.optim.Optimizer methods # torch.optim.Optimizer methods
@ -503,7 +526,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def backward(self, loss, retain_graph=False): def backward(self, loss, retain_graph=False):
assert not ( assert not (
self._partition_grads and not self.require_grad_sync self._grad_store._partition_grads and not self._grad_store.require_grad_sync
), "ZeRO2(partition_grads) and no_sync are not compatible" ), "ZeRO2(partition_grads) and no_sync are not compatible"
if self.mixed_precision_mixin is not None: if self.mixed_precision_mixin is not None:
@ -511,31 +534,31 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
loss.backward(retain_graph=retain_graph) loss.backward(retain_graph=retain_graph)
if not self.require_grad_sync: if not self._grad_store.require_grad_sync:
return return
self._reduce_grad(self._partition_grads) self._reduce_grad(self._grad_store._partition_grads)
# clear reduced grads # clear reduced grads
if self._overlap_communication: if self._bucket_store._overlap_communication:
get_accelerator().synchronize() get_accelerator().synchronize()
self.zero_grad() self.zero_grad()
def backward_by_grad(self, tensor, grad): def backward_by_grad(self, tensor, grad):
assert not ( assert not (
self._partition_grads and not self.require_grad_sync self._grad_store._partition_grads and not self._grad_store.require_grad_sync
), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
if self.mixed_precision_mixin is not None: if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
torch.autograd.backward(tensor, grad) torch.autograd.backward(tensor, grad)
if not self.require_grad_sync: if not self._grad_store.require_grad_sync:
return return
self._reduce_grad(self._partition_grads) self._reduce_grad(self._grad_store._partition_grads)
# clear reduced grads # clear reduced grads
if self._overlap_communication: if self._bucket_store._overlap_communication:
get_accelerator().synchronize() get_accelerator().synchronize()
self.zero_grad() self.zero_grad()
@ -566,7 +589,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def step(self, closure=None): def step(self, closure=None):
assert closure is None, "closure is not supported by step()" assert closure is None, "closure is not supported by step()"
if not self.require_grad_sync: if not self._grad_store.require_grad_sync:
return return
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
@ -585,7 +608,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# and should not be updated # and should not be updated
real_working_params = dict() real_working_params = dict()
real_master_params = dict() real_master_params = dict()
grad_index = 0 if self._partition_grads else self._local_rank grad_index = 0 if self._grad_store._partition_grads else self._bucket_store.zero_local_rank
for group_id in range(self.num_param_groups): for group_id in range(self.num_param_groups):
master_params = self._master_param_groups_of_current_rank[group_id] master_params = self._master_param_groups_of_current_rank[group_id]
real_working_params[group_id] = [] real_working_params[group_id] = []
@ -598,14 +621,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
if len(grads) > 0: if len(grads) > 0:
# moe hybrid zero # moe hybrid zero
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
real_working_params[group_id].append(working_param) real_working_params[group_id].append(working_param)
if self._partition_grads: if self._grad_store._partition_grads:
grad = grads grad = grads
else: else:
param_slice = self._world_size // self.moe_extra_dp_pg_size param_slice = self._bucket_store.zero_world_size // self._bucket_store.moe_extra_dp_pg_size
grad = grads[ grad = grads[
self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice self._bucket_store.moe_extra_dp_pg_rank
* param_slice : (self._bucket_store.moe_extra_dp_pg_rank + 1)
* param_slice
] ]
grad = flatten(grad) grad = flatten(grad)
else: else:
@ -674,25 +699,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_working_param = self.optim.param_groups[group_id]["params"] master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param): for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx] working_param = real_working_params[group_id][idx]
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
all_splited_param = [ all_splited_param = [
torch.zeros(splited_param.shape, device=device, dtype=self._dtype) torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self.moe_extra_dp_pg_size) for _ in range(self._bucket_store.moe_extra_dp_pg_size)
] ]
dist.all_gather( dist.all_gather(
all_splited_param, all_splited_param,
splited_param.to(device).to(self._dtype), splited_param.to(device).to(self._dtype),
group=self.moe_extra_dp_pg, group=self._bucket_store.moe_extra_dp_pg,
) )
else: else:
all_splited_param = [ all_splited_param = [
torch.zeros(splited_param.shape, device=device, dtype=self._dtype) torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self._world_size) for _ in range(self._bucket_store.zero_world_size)
] ]
dist.all_gather( dist.all_gather(
all_splited_param, all_splited_param,
splited_param.to(device).to(self._dtype), splited_param.to(device).to(self._dtype),
group=self.dp_pg, group=self._bucket_store.torch_pg,
) )
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
@ -720,7 +745,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
device=get_accelerator().get_current_device(), device=get_accelerator().get_current_device(),
dtype=torch.float, dtype=torch.float,
) )
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self._bucket_store.torch_pg)
total_norm = total_norm_cuda.item() total_norm = total_norm_cuda.item()
else: else:
@ -738,7 +763,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
torch.distributed.all_reduce( torch.distributed.all_reduce(
total_norm_exponentiated_cuda, total_norm_exponentiated_cuda,
op=torch.distributed.ReduceOp.SUM, op=torch.distributed.ReduceOp.SUM,
group=self.dp_pg, group=self._bucket_store.torch_pg,
) )
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
@ -773,27 +798,33 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
param_group = self._working_param_groups[group_id] param_group = self._working_param_groups[group_id]
for param in param_group: for param in param_group:
if param.requires_grad and param.grad is not None: if param.requires_grad and param.grad is not None:
self._add_to_bucket(param, group_id) LowLevelZeroOptimizer.add_to_bucket(
param,
group_id,
self._bucket_store,
self._param_store,
self._grad_store,
)
self._run_reduction() LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store)
def _reduce_grad(self, partition_grad): def _reduce_grad(self, partition_grad):
# if not overlapping communication (no reduction hook is attached) when zero1 # if not overlapping communication (no reduction hook is attached) when zero1
# we need to manually reduce these gradients # we need to manually reduce these gradients
if not partition_grad and not self._overlap_communication: if not partition_grad and not self._bucket_store._overlap_communication:
self._sync_grad() self._sync_grad()
else: else:
self._run_reduction() LowLevelZeroOptimizer.run_reduction(self._bucket_store, self._grad_store)
# this context comes from pytorch DDP # this context comes from pytorch DDP
@contextmanager @contextmanager
def no_sync(self): def no_sync(self):
old_require_grad_sync = self.require_grad_sync old_require_grad_sync = self._grad_store.require_grad_sync
self.require_grad_sync = False self._grad_store.require_grad_sync = False
try: try:
yield yield
finally: finally:
self.require_grad_sync = old_require_grad_sync self._grad_store.require_grad_sync = old_require_grad_sync
############## ##############
# State Dict # # State Dict #
@ -833,16 +864,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in state.items(): for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step": if isinstance(v, torch.Tensor) and k != "step":
working_param = self._param_store.master_to_working_param[id(param)] working_param = self._param_store.master_to_working_param[id(param)]
if self.moe_extra_dp_pg is not None and is_moe_tensor(v): if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
gather_tensor = [ gather_tensor = [
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) torch.zeros(v.shape, device=device, dtype=v.dtype)
for _ in range(self._bucket_store.moe_extra_dp_pg_size)
] ]
dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg) dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg)
else: else:
gather_tensor = [ gather_tensor = [
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size) torch.zeros(v.shape, device=device, dtype=v.dtype)
for _ in range(self._bucket_store.zero_world_size)
] ]
dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg) dist.all_gather(gather_tensor, v.to(device), group=self._bucket_store.torch_pg)
param_state = ( param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
) )
@ -862,17 +895,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for param_idx, state in zero_state_dict["state"].items(): for param_idx, state in zero_state_dict["state"].items():
for k, v in state.items(): for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step": if isinstance(v, torch.Tensor) and k != "step":
padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size padding_size = (
self._bucket_store.zero_world_size - v.numel() % self._bucket_store.zero_world_size
) % self._bucket_store.zero_world_size
with torch.no_grad(): with torch.no_grad():
v = v.flatten() v = v.flatten()
if padding_size > 0: if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size]) v = torch.nn.functional.pad(v, [0, padding_size])
if self.moe_extra_dp_pg is not None and is_moe_tensor(v): if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
v_list = v.split(v.numel() // self.moe_extra_dp_pg_size) v_list = v.split(v.numel() // self._bucket_store.moe_extra_dp_pg_size)
zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone() zero_state_dict["state"][param_idx][k] = (
v_list[self._bucket_store.moe_extra_dp_pg_rank].detach().clone()
)
else: else:
v_list = v.split(v.numel() // self._world_size) v_list = v.split(v.numel() // self._bucket_store.zero_world_size)
zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone() zero_state_dict["state"][param_idx][k] = (
v_list[self._bucket_store.zero_local_rank].detach().clone()
)
self.optim.load_state_dict(zero_state_dict) self.optim.load_state_dict(zero_state_dict)
@ -904,16 +943,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in states.items(): for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step": if isinstance(v, torch.Tensor) and k != "step":
if self.moe_extra_dp_pg is not None and is_moe_tensor(v): if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(v):
state_tensor = [ state_tensor = [
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) torch.zeros(v.shape, device=device, dtype=v.dtype)
for _ in range(self._bucket_store.moe_extra_dp_pg_size)
] ]
dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg) dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.moe_extra_dp_pg)
else: else:
state_tensor = [ state_tensor = [
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size) torch.zeros(v.shape, device=device, dtype=v.dtype)
for _ in range(self._bucket_store.zero_world_size)
] ]
dist.all_gather(state_tensor, v.to(device), group=self.dp_pg) dist.all_gather(state_tensor, v.to(device), group=self._bucket_store.torch_pg)
state_tensor = ( state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
) )
@ -944,14 +985,30 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = p.data.view(-1) working_param = p.data.view(-1)
if padding_size > 0: if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size]) working_param = torch.nn.functional.pad(working_param, [0, padding_size])
if self.moe_extra_dp_pg is not None and is_moe_tensor(p): if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(p):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank]) master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else: else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) master_param.copy_(
working_param.chunk(self._bucket_store.zero_world_size)[self._bucket_store.zero_local_rank]
)
if hasattr(self, "master_moe_params"): if hasattr(self, "master_moe_params"):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.copy_(working_moe_param) master_moe_param.copy_(working_moe_param)
def remove_hooks(self) -> None:
"""remove the registered hooks
Args:
plugin (LowLevelZeroPlugin): the plugin to bound this method.
"""
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad:
assert hasattr(param, "_grad_handle")
param._grad_handle.remove()
delattr(param, "_grad_handle")
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param return self._param_store.working_to_master_param

View File

@ -80,7 +80,6 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
skipped_models.append(name) skipped_models.append(name)
continue continue
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
get_accelerator().empty_cache() get_accelerator().empty_cache()
if err is None: if err is None:

View File

@ -64,7 +64,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)] working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p)) grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = 0 if sharded_optimizer._partition_grads else sharded_optimizer._local_rank grad_index = (
0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
)
grad = grads[grad_index] grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)