[zero] refactor low level zero for shard evenly (#4030)

* refactor low level zero

* fix zero2 and support cpu offload

* avg gradient and modify unit test

* refactor grad store, support layer drop

* refactor bucket store, support grad accumulation

* fix and update unit test of zero and ddp

* compatible with tp, ga and unit test

* fix memory leak and polish

* add zero layer drop unittest

* polish code

* fix import err in unit test

* support different comm dtype, modify docstring style

* polish code

* test padding and fix

* fix unit test of low level zero

* fix pad recording in bucket store

* support some models

* polish
pull/4359/head
LuGY 2023-06-30 15:30:50 +08:00 committed by Hongxin Liu
parent 5187c96b7c
commit c6ab96983a
8 changed files with 424 additions and 470 deletions

View File

@ -253,7 +253,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
return total_norm
def sync_param(flat_tensor, tensor_list):
def sync_tensor(flat_tensor, tensor_list):
"""
Synchronize the flattened tensor and unflattened tensor list. When
a list of tensors is flattened with `torch._utils._flatten_dense_tensors`,
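As a rough illustration of what this helper does, here is a minimal standalone sketch (the function name `sync_tensor_sketch` is hypothetical, not the library API):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def sync_tensor_sketch(flat_tensor, tensor_list):
    # recover views with the original shapes from the flat buffer
    unflat = _unflatten_dense_tensors(flat_tensor, tensor_list)
    for t, u in zip(tensor_list, unflat):
        # point each original tensor at its slice of the flat buffer
        t.data = u.data

params = [torch.randn(2, 3), torch.randn(4)]
flat = _flatten_dense_tensors(params)
sync_tensor_sketch(flat, params)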

View File

@ -1,3 +1,8 @@
from typing import Dict
import torch
from torch import Tensor
from torch._utils import _flatten_dense_tensors
from torch.distributed import ProcessGroup
from .base_store import BaseStore
@ -7,35 +12,102 @@ class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
self._params = dict()
self._num_elements_in_bucket = dict()
# init and reset
self.current_group_id = 0
# mapping gradient slices and parameters
self.grad_to_param_mapping = dict()
self._param_list = []
self._padding_size = []
self.reset()
def num_elements_in_bucket(self, reduce_rank: int = None):
return self._num_elements_in_bucket[reduce_rank]
def num_elements_in_bucket(self) -> int:
"""Return the total number of elements in bucket
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
self._num_elements_in_bucket[reduce_rank] += num_elements
Returns:
int: the total number of elements in the bucket
"""
def add_param(self, tensor, reduce_rank: int = None):
self._params[reduce_rank].append(tensor)
return self._num_elements_in_bucket
def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
"""Add a param to bucket and record the padding size of a param for gradient padding
Args:
group_id (int): The index of a parameter group
param (Tensor): The parameter
padding_size (int): The padding size of the parameter
"""
self._param_list.append(param)
self._padding_size.append(padding_size)
self._num_elements_in_bucket += (param.numel() + padding_size)
self.current_group_id = group_id
def build_grad_in_bucket(self):
"""Orgnize parameters' gradient(padding and split), follows the paramters' splitting method
Data structure of self._grad_in_bucket:
{
rank0: [grad0_rank0, grad1_rank0, ...]
rank1: [grad0_rank1, grad1_rank1, ...]
}
"""
for param, padding_size in zip(self._param_list, self._padding_size):
with torch.no_grad():
grad = param.grad.detach().flatten()
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
grad_list = grad.split(grad.numel() // self._world_size)
for rank in range(self._world_size):
grad_current_rank = grad_list[rank].detach()
self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
self._grad_in_bucket[rank].append(grad_current_rank)
param.grad = None
def get_grad(self) -> Dict:
"""Return the dictionary of gradients slices, of which the keys are ranks
Returns:
Dict: The dictionary of gradient slices
"""
return self._grad_in_bucket
def get_flatten_grad(self) -> Tensor:
"""Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor:
[grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....]
Returns:
Tensor: the flattened gradients slices in the bucket
"""
flat_grad = []
for grad_list in self._grad_in_bucket.values():
flat_grad.append(_flatten_dense_tensors(grad_list))
flat_grad = _flatten_dense_tensors(flat_grad)
return flat_grad
def get_param_id_of_grad(self, grad: Tensor) -> int:
"""Return the id of a parameter which the gradient slice belongs to
Args:
grad (Tensor): the gradient slice
Returns:
int: the id of a parameter which the gradient slice belongs to
"""
return self.grad_to_param_mapping[id(grad)]
def reset(self):
keys = [None] + list(range(self._world_size))
self._params = {rank: [] for rank in keys}
self._num_elements_in_bucket = {rank: 0 for rank in keys}
def reset_by_rank(self, reduce_rank=None):
self._params[reduce_rank] = []
self._num_elements_in_bucket[reduce_rank] = 0
def get_grad(self, reduce_rank: int = None):
param_list = self.get_param(reduce_rank)
for param in param_list:
# the param must have grad for reduction
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
return [param.grad for param in param_list]
def get_param(self, reduce_rank: int = None):
return self._params[reduce_rank]
self.grad_to_param_mapping = dict()
self._num_elements_in_bucket = 0
self._param_list = []
self._padding_size = []
self._grad_in_bucket = dict()
for rank in range(self._world_size):
self._grad_in_bucket[rank] = []
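The pad-then-split scheme used by build_grad_in_bucket above can be sketched as follows; the helper name and sizes are illustrative only, not part of the library:

import torch

def shard_grad_evenly(grad: torch.Tensor, world_size: int):
    flat = grad.detach().flatten()
    padding_size = (world_size - flat.numel() % world_size) % world_size
    if padding_size > 0:
        flat = torch.nn.functional.pad(flat, [0, padding_size])
    # every rank receives a slice of identical length
    return flat.split(flat.numel() // world_size)

shards = shard_grad_evenly(torch.randn(10), world_size=4)  # four slices of 3 elements each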

View File

@ -1,88 +1,92 @@
from typing import List
from torch import Tensor
from torch._utils import _flatten_dense_tensors
from .base_store import BaseStore
class GradientStore(BaseStore):
def __init__(self, *args):
def __init__(self, *args, partition_grad: bool = False):
super().__init__(*args)
# bookkeeping data structures
self._averaged_gradients = dict()
# for backward reduction hooks
self._grad_acc_objs = []
def append_accumulate_grad_object(self, obj):
"""
Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
be attached successfully.
self._grads_of_params maps each parameter to its gradient slices
data structure:
{
group_id:{
param_id: [grad_rank0, grad_rank1, ...]
}
}
"""
self._grads_of_params = dict()
# for zero2, it's `param_id: [grad_local_rank]`
self._working_index = 0 if partition_grad else self._local_rank
:param obj: An object of :class:`AccumulateGrad` class
:type obj: :class:`AccumulateGrad`
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
"""Return list of gradient slices of a specific parameter
Args:
group_id (int): The index of a parameter group
param_id (int): The id of a parameter
Returns:
List: the list of gradient slices of a parameter.
"""
self._grad_acc_objs.append(obj)
if group_id in self._grads_of_params:
if param_id in self._grads_of_params[group_id]:
return self._grads_of_params[group_id][param_id]
# the param has no grad, for instance, in layer drop
return []
def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
"""
Return average gradients of a parameter group
:param group_id: The index of parameter group
:type group_id: int
:return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
:rtype: List[torch.Tensor]
"""
if group_id not in self._averaged_gradients:
self._averaged_gradients[group_id] = []
return self._averaged_gradients[group_id]
def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
"""
Append an average gradient to the list of averaged gradients of a parameter group
:param group_id: The index of a parameter group
:param tensor: A :class:`torch.Tensor` object
:type group_id: int
:type tensor: torch.Tensor
def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int):
"""Append a gradient slice to the parameter's gradient slice list
Args:
grad (Tensor): The gradient slice to append to list
group_id (int): The index of a parameter group
param_id (int): The id of a parameter
"""
if group_id in self._averaged_gradients:
self._averaged_gradients[group_id].append(tensor)
if group_id not in self._grads_of_params:
self._grads_of_params[group_id] = dict()
if param_id not in self._grads_of_params[group_id]:
self._grads_of_params[group_id][param_id] = [grad]
else:
self._averaged_gradients[group_id] = [tensor]
self._grads_of_params[group_id][param_id].append(grad)
def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
"""
Add an average gradient to the list of averaged gradients of a parameter group
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
"""For old gradient accumulation, not in use now.
Add a gradient slice on an existing slice of the parameter's gradient
:param group_id: The index of a parameter group
:param tensor_idx: The index of a tensor in the list of averaged gradients
:param tensor: A :class:`torch.Tensor` object
:type group_id: int
:type tensor_idx: int
:type tensor: torch.Tensor
"""
self._averaged_gradients[group_id][tensor_idx].add_(tensor)
def reset_average_gradients_by_group(self, group_id: int) -> None:
"""
Reset the bookkeeping data structure for averaged gradients to an empty list
:param group_id: The index of a parameter group
:type group_id: int
Args:
grad (Tensor): The split gradient to append to list
grad_idx (int): The index of the existing slice
group_id (int): The index of a parameter group
param_id (int): The id of a parameter
"""
self._averaged_gradients[group_id] = []
self._grads_of_params[group_id][param_id][grad_idx].add_(grad)
def reset_all_average_gradients(self) -> None:
def get_working_grads_by_group_id(self, group_id: int) -> List:
"""Return list of working gradient slices in the group
Args:
group_id (int): The index of a parameter group
Returns:
List: the list of working gradient slices in the group
"""
Reset the bookkeeping data structure for averaged gradients to an empty list
"""
self._averaged_gradients = dict()
grad_list = []
for param_grads in self._grads_of_params[group_id].values():
grad_list.append(param_grads[self._working_index])
return grad_list
def reset_grads_by_group_id(self, group_id: int):
self._grads_of_params[group_id] = dict()
def reset_all_gradients(self):
self._grads_of_params = dict()
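To make the {group_id: {param_id: [grad_rank0, grad_rank1, ...]}} layout concrete, a small illustrative-only snippet (the param id and tensor sizes are made up):

import torch

grad_rank0, grad_rank1 = torch.zeros(3), torch.zeros(3)   # two slices of one parameter's gradient
grads_of_params = {0: {140234: [grad_rank0, grad_rank1]}}  # group 0, one fake param id

working_index = 0  # ZeRO-2 only keeps the local slice, so the working slice sits at index 0
working_grads = [slices[working_index] for slices in grads_of_params[0].values()]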

View File

@ -1,5 +1,3 @@
from typing import List
from torch import Tensor
from torch.distributed import ProcessGroup
@ -10,88 +8,43 @@ class ParameterStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
# param partitioning data structures
self._param_to_rank = dict()
self._rank_group_id_to_param_list = dict()
self._rank_group_id_to_flat_param = dict()
# param reduction data structures
self._is_param_reduced = dict()
self._reduced_param = []
# record the padding size of each param
self._padding_map = dict()
def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
"""
Set the mapping between parameter to rank, each parameter should be owned by a rank.
# mapping between working params and master params
self.master_to_working_param = dict()
self.working_to_master_param = dict()
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
:param rank: The rank of which the process is responsible for updating the parameter
:type rank: int
def record_param_padding_size(self, param: Tensor, padding_size: int):
"""Record the padding size of a param
Args:
param (Tensor): The parameter
padding_size (int): The padding size of the parameter
"""
self._param_to_rank[tensor] = rank
self._padding_map[id(param)] = padding_size
def get_param_rank(self, tensor: Tensor) -> int:
"""
Gives the rank which the parameter belongs to
def get_param_padding_size(self, param: Tensor) -> int:
"""Return the padding size of the parameter
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
"""
return self._param_to_rank[tensor]
Args:
param (Tensor): The parameter
def belongs_to_current_rank(self, tensor) -> bool:
"""
Check whether a parameter is supposed to be updated by the process of the current rank
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
:return: True if the parameter should be updated by the current rank. Otherwise false.
:rtype: bool
Returns:
int: the padding size of the parameter
"""
tensor_rank = self._param_to_rank[tensor]
return tensor_rank == self._local_rank
return self._padding_map[id(param)]
def add_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
if rank not in self._rank_group_id_to_param_list:
self._rank_group_id_to_param_list[rank] = dict()
def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
"""Mapping master parameter and working parameter
if group_id not in self._rank_group_id_to_param_list[rank]:
self._rank_group_id_to_param_list[rank][group_id] = []
Args:
master_param (Tensor): The parameter copy in optimizer
working_param (Tensor): The parameter of the model
"""
self._rank_group_id_to_param_list[rank][group_id].extend(tensor_list)
def get_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
return self._rank_group_id_to_param_list[rank][group_id]
def add_flat_param_by_rank_group(self, rank, group_id, tensor) -> None:
if rank not in self._rank_group_id_to_flat_param:
self._rank_group_id_to_flat_param[rank] = dict()
self._rank_group_id_to_flat_param[rank][group_id] = tensor
def get_flat_param_by_rank_group(self, rank, group_id) -> Tensor:
return self._rank_group_id_to_flat_param[rank][group_id]
def is_param_reduced(self, tensor):
return self._is_param_reduced[tensor]
def set_param_reduction_state(self, tensor, state):
self._is_param_reduced[tensor] = state
def get_param_reduction_states(self):
return self._is_param_reduced
def reset_previous_reduced_params(self):
self._reduced_param = []
def add_previous_reduced_param(self, tensor):
self._reduced_param.append(tensor)
def clear_grads_of_previous_reduced_params(self):
if len(self._reduced_param) > 0:
for param in self._reduced_param:
param.grad = None
self.reset_previous_reduced_params()
self.master_to_working_param[id(master_param)] = working_param
self.working_to_master_param[id(working_param)] = master_param
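A small sketch of the id()-keyed bookkeeping described above, with illustrative names only:

import torch

working_param = torch.nn.Parameter(torch.randn(7))
padding_map = {id(working_param): 1}                       # record a padding size of 1 for this param

master_param = working_param.detach().float().clone()      # fp32 copy kept by the optimizer
master_to_working = {id(master_param): working_param}
working_to_master = {id(working_param): master_param}

assert master_to_working[id(master_param)] is working_param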

View File

@ -1,4 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
from contextlib import contextmanager
from functools import partial
from typing import Optional
@ -16,6 +17,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils import conditional_context
from colossalai.utils.cuda import get_current_device
from ._utils import (
@ -23,12 +25,10 @@ from ._utils import (
compute_norm,
flatten,
has_inf_or_nan,
reduce_tensor_dp_group,
release_param_grad,
split_by_dtype,
sync_param,
sync_tensor,
)
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
from .bookkeeping import BucketStore, GradientStore, ParameterStore
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@ -50,7 +50,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def check_local_overflow(self) -> bool:
for group_id in range(self.num_working_param_groups):
for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
for avg_grad in self.grad_store.get_working_grads_by_group_id(group_id):
if avg_grad is not None and has_inf_or_nan(avg_grad):
return True
return False
@ -77,14 +77,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
grad_accumulate_interval: int = 1,
forced_dtype: Optional[torch.dtype] = None):
# TODO: add support for
# 1. fp16 master weights
# 2. contiguous gradients
# 3. cpu offload
# 4. support when some parameters requires_grad = False
# 5. support layer drop
assert not (partition_grad and grad_accumulate_interval > 1), \
"gradient accumulation is not compatible with ZeRO-2"
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
@ -95,6 +92,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self._cpu_offload = cpu_offload
# grad accumulation
self.require_grad_sync = True
self._accumulate_intervel = grad_accumulate_interval
self._accumulate_step = 0
colo_pg = self._search_colo_process_group()
if isinstance(colo_pg, ProcessGroup):
self._local_rank = colo_pg.dp_local_rank()
@ -122,7 +124,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# working and master params for mixed precision training
self._working_param_groups = dict()
self._master_flat_param_groups_of_current_rank = dict()
self._master_param_groups_of_current_rank = dict()
# communication params
self._overlap_communication = overlap_communication
@ -145,7 +147,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(self._dp_torch_group)
self._grad_store = GradientStore(self._dp_torch_group)
self._grad_store = GradientStore(self._dp_torch_group, partition_grad=partition_grad)
self._bucket_store = BucketStore(self._dp_torch_group)
# iterate over the param group in the optimizer
@ -160,55 +162,17 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# add the working params to working_param_groups for bookkeeping
self._working_param_groups[group_id] = group_params
# assign parameters to ranks
# the params in the list are sorted
params_per_rank = self._partition_param_list(group_params)
master_param_current_rank = self._create_master_param_current_rank(group_params)
# store the mapping between param to rank
# each param should belong to only one rank
for rank, params in enumerate(params_per_rank):
self._param_store.add_param_list_by_rank_group(rank, group_id, params)
for param in params:
self._param_store.set_param_to_rank(param, rank)
# move to cpu to make room to create the flat tensor
# move_tensor(params, device='cpu')
for param in group_params:
param.data = param.data.cpu()
# flatten the reordered tensors
for rank in range(self._world_size):
tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
with torch.no_grad():
flat_tensor = flatten(tensor_list)
flat_tensor = flat_tensor.data.cuda()
self._param_store.add_flat_param_by_rank_group(rank, group_id, flat_tensor)
# sync parameters
for rank in range(self._world_size):
flat_tensor = self._param_store.get_flat_param_by_rank_group(rank, group_id)
tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
# create a copy of fp32 master weights of the parameters for which this rank is responsible
working_flat_current_rank = self._param_store.get_flat_param_by_rank_group(self._local_rank, group_id)
master_flat_current_rank = working_flat_current_rank.float()
device = 'cpu' if self._cpu_offload else get_current_device()
master_flat_current_rank = master_flat_current_rank.to(device)
master_flat_current_rank.requires_grad = True
self._master_flat_param_groups_of_current_rank[group_id] = master_flat_current_rank
self._master_param_groups_of_current_rank[group_id] = master_param_current_rank
# need to replace the params in the `params` field in the optimizer
# so that when the optimizer calls step(), it only updates the tensors
# managed by this data parallel rank
param_group['params'] = [master_flat_current_rank]
param_group['params'] = master_param_current_rank
# set reduction state
for param in self._working_param_groups[group_id]:
self._param_store.set_param_reduction_state(param, False)
# initialize communication stream for
# communication-computation overlapping
# initialize communication stream for
# communication-computation overlapping
if self._overlap_communication:
self._comm_stream = torch.cuda.Stream()
@ -265,29 +229,36 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
return colo_pg
def _partition_param_list(self, param_list):
params_per_rank = [[] for _ in range(self._world_size)]
numel_per_rank = [0 for _ in range(self._world_size)]
def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size
params_current_rank = []
device = 'cpu' if self._cpu_offload else get_current_device()
# partition the parameters in a greedy fashion
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
for param in sorted_params:
# allocate this parameter to the rank with
# the smallest numel for load balancing purpose
rank_to_go = numel_per_rank.index(min(numel_per_rank))
params_per_rank[rank_to_go].append(param)
numel_per_rank[rank_to_go] += param.numel()
for param in reversed(param_list):
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
self._param_store.record_param_padding_size(param, padding_size)
if self._verbose:
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
return params_per_rank
with torch.no_grad():
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
else:
padding_param = param.data.view(-1)
splited_params = padding_param.split(padding_param.numel() // self._world_size)
splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
params_current_rank.append(splited_param_current_rank)
self._param_store.link_master_and_working_param(splited_param_current_rank, param)
return params_current_rank
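The per-parameter sharding above boils down to a short sketch (the helper name, local_rank and world_size values are assumptions for illustration):

import torch

def make_master_shard(param: torch.Tensor, local_rank: int, world_size: int) -> torch.Tensor:
    flat = param.data.view(-1)
    padding_size = (world_size - flat.numel() % world_size) % world_size
    if padding_size > 0:
        flat = torch.nn.functional.pad(flat, [0, padding_size])
    shards = flat.split(flat.numel() // world_size)
    # each rank keeps an fp32 copy of its own slice as the master parameter
    return shards[local_rank].detach().float()

shard = make_master_shard(torch.randn(10), local_rank=1, world_size=4)  # 3 elements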
###########################
# Backward Reduction Hook #
###########################
def _grad_handler(self, param, grad, reduce_rank):
self._add_to_reduction_bucket(param, reduce_rank)
def _grad_handler(self, param, group_id, grad):
# if run within the no_sync context, grads are not synced during backward
if self.require_grad_sync:
self._add_to_bucket(param, group_id)
return grad
def _attach_reduction_hook(self):
@ -297,149 +268,96 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad:
# determines the reduction destination rank
# this is only valid for stage 2
# dst_rank = None means using all-reduce
# else using reduce
if self._partition_grads:
reduce_rank = self._param_store.get_param_rank(param)
else:
reduce_rank = None
param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
flat = bucket.flatten()
reduce_global_rank = None
if reduce_rank is not None:
reduce_global_rank = self._dp_global_ranks[reduce_rank]
reduced_flat = reduce_tensor_dp_group(tensor=flat,
dtype=self._communication_dtype,
dst_local_rank=reduce_rank,
dst_global_rank=reduce_global_rank,
group=self._dp_torch_group)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._local_rank:
bucket.unflatten_and_copy(reduced_flat)
def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
param_bucket = TensorBucket(size=bucket_size)
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
if param_bucket.is_full_or_oversized():
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()
if not param_bucket.is_empty():
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
def _reduce_grads(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_by_dtype(grads)
for tensor_list in grad_buckets_by_dtype:
self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
bucket_size=bucket_size,
reduce_rank=reduce_rank)
param.register_hook(partial(self._grad_handler, param, group_id))
#######################
# Reduction Functions #
#######################
def _run_reduction(self, reduce_rank=None):
# reduce grads
self._reduce_grads(reduce_rank=reduce_rank,
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
def _run_reduction(self):
if self._bucket_store.num_elements_in_bucket() > 0:
self._bucket_store.build_grad_in_bucket()
flat_grads = self._bucket_store.get_flatten_grad()
flat_grads /= self._world_size
if self._overlap_communication:
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
# use communication stream if overlapping
# communication with computation
if self._overlap_communication:
stream = self._comm_stream
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
group_id = self._bucket_store.current_group_id
with torch.cuda.stream(stream):
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
grad_dtype = flat_grads.dtype
if self._communication_dtype is not None:
flat_grads = flat_grads.to(self._communication_dtype)
for param in params_in_bucket:
# the is_param_reduced flag should be False showing that
# this param is not reduced before calling self._reduce_grads_by_rank
is_param_reduced = self._param_store.is_param_reduced(param)
if not self._partition_grads:
dist.all_reduce(flat_grads, group=self._dp_torch_group)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
grad_in_bucket = self._bucket_store.get_grad()
# update the flag
self._param_store.set_param_reduction_state(param, True)
for rank, grad_list in grad_in_bucket.items():
sync_tensor(flat_grads_per_rank[rank], grad_list)
for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
# if partition grads = True
# we do not keep the gradient after reduction
if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
if self._overlap_communication:
# we need to keep this gradient for now as reduction may
# be completed yet since it is using a different cuda stream
self._param_store.add_previous_reduced_param(param)
else:
param.grad = None
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self._dp_torch_group)
self._bucket_store.reset_by_rank(reduce_rank)
if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
def _add_to_reduction_bucket(self, param, reduce_rank=None):
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
sync_tensor(recieved_grad, grad_in_bucket_current_rank)
for grad in grad_in_bucket_current_rank:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
self._bucket_store.reset()
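For the ZeRO-2 branch above, a hedged sketch of the reduce-scatter step; it assumes torch.distributed is already initialized and `group` is the data-parallel process group:

import torch
import torch.distributed as dist

def reduce_scatter_flat_grads(flat_grads: torch.Tensor, group) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    chunks = list(flat_grads.split(flat_grads.numel() // world_size))
    received = torch.zeros_like(chunks[0])
    # each rank ends up with the sum (over ranks) of its own slice only
    dist.reduce_scatter(received, chunks, group=group)
    return received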
def _add_to_bucket(self, param, group_id):
param_size = param.numel()
# check if the bucket is full
# if full, or if the incoming grad belongs to a different param group,
# the grads already in the bucket are reduced first
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._run_reduction(reduce_rank)
if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
group_id != self._bucket_store.current_group_id:
self._run_reduction()
# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+ 'duplicate reduction will lead to arithmetic incorrectness'
raise RuntimeError(msg)
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
self._bucket_store.add_param(param, reduce_rank)
padding_size = self._param_store.get_param_padding_size(param)
self._bucket_store.add_param_grad(group_id, param, padding_size)
################################
# torch.optim.Optimizer methods
################################
def backward(self, loss, retain_graph=False, sync_grad=True):
def backward(self, loss, retain_graph=False):
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
loss.backward(retain_graph=retain_graph)
# finish gradient reduction
if not self._partition_grads:
self._reduce_grad_stage1()
else:
# TODO: support async comm in reduce
self._reduce_grad_stage2()
self._accumulate_step += 1
no_sync = self._accumulate_step < self._accumulate_intervel
with conditional_context(self.no_sync(), enable=no_sync):
loss.backward(retain_graph=retain_graph)
if no_sync:
return
self._reduce_grad(self._partition_grads)
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
# gradient synchronization
if sync_grad:
self._sync_grad()
self.zero_grad()
def zero_grad(self, set_to_none=True):
"""
@ -467,68 +385,86 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
if not self._accumulate_step == self._accumulate_intervel:
return
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
self._grad_store.reset_all_average_gradients()
self._grad_store.reset_all_gradients()
if self._verbose:
self._logger.info(f'Found overflow. Skip step')
self.zero_grad()
self._accumulate_step -= 1
return
# copy the grad of working param to master param
single_grad_partition_groups = []
# record all grads for unscale and clip
grad_partition_groups = []
norm_groups = []
# sometimes not all params are 'really' working
# for instance, with layer drop, the dropped layer has no grad
# and should not be updated
real_working_params = dict()
real_master_params = dict()
grad_index = 0 if self._partition_grads else self._local_rank
for group_id in range(self.num_param_groups):
master_params = self._master_param_groups_of_current_rank[group_id]
real_working_params[group_id] = []
real_master_params[group_id] = []
for splited_param in master_params:
working_param = self._param_store.master_to_working_param[id(splited_param)]
# if a working param requires grad but has no grad
# it is not 'really' working, e.g. the dropped layer
# otherwise the split grad should be attached to the split param
grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
if len(grads) > 0:
real_working_params[group_id].append(working_param)
grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
splited_param.grad = grad
grad_partition_groups.append(grad)
real_master_params[group_id].append(splited_param)
# compute norm
norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id),
params=self._param_store.get_params_by_rank_group(group_id=group_id,
rank=self._local_rank),
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
norm_group = compute_norm(gradients=working_grads,
params=real_working_params[group_id],
dp_group=self._dp_torch_group,
mp_group=self._mp_torch_group)
norm_groups.append(norm_group)
# create flat gradient for the flat fp32 master params
working_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
flat_working_avg_grads = flatten(working_avg_grads)
self._grad_store.reset_grads_by_group_id(group_id)
dtype = self._master_flat_param_groups_of_current_rank[group_id].dtype
flat_master_avg_grads = flat_working_avg_grads.to(dtype)
param_shape = self._master_flat_param_groups_of_current_rank[group_id].shape
assert param_shape == flat_master_avg_grads.shape, \
f'fp32 param and grad have different shape {param_shape} vs {flat_master_avg_grads.shape}'
single_grad_partition_groups.append(flat_master_avg_grads)
device = self._master_flat_param_groups_of_current_rank[group_id].device
self._master_flat_param_groups_of_current_rank[group_id].grad = flat_master_avg_grads.to(device)
self._grad_store.reset_average_gradients_by_group(group_id)
# update the params in the optimizer
self.optim.param_groups[group_id]['params'] = real_master_params[group_id]
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
# update the parameters
self.optim.step()
# release the master grad
release_param_grad(self._master_flat_param_groups_of_current_rank.values())
# release the grad
grad_partition_groups = []
for group_id in range(self.num_param_groups):
release_param_grad(self._master_param_groups_of_current_rank[group_id])
# update the working params with the partition updated by the current rank
for group_id in range(len(self._working_param_groups)):
working_param = self._param_store.get_flat_param_by_rank_group(rank=self._local_rank, group_id=group_id)
master_param = self._master_flat_param_groups_of_current_rank[group_id]
working_param.data.copy_(master_param)
# broadcast the updated model weights
handles = []
for group_id in range(self.num_param_groups):
for index in range(self._world_size):
rank = self._dp_global_ranks[index]
working_param = self._param_store.get_flat_param_by_rank_group(rank=index, group_id=group_id)
handle = dist.broadcast(working_param, src=rank, group=self._dp_torch_group, async_op=True)
handles.append(handle)
master_working_param = self.optim.param_groups[group_id]['params']
for handle in handles:
handle.wait()
for idx, splited_param in enumerate(master_working_param):
full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
dist.all_gather(full_master_param, splited_param.cuda(), group=self._dp_torch_group)
working_param = real_working_params[group_id][idx]
full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
working_param.data.copy_(full_master_param)
self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
# reset accumulate step
self._accumulate_step = 0
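A minimal sketch of the weight synchronization at the end of step(): all-gather the fp32 shards, drop the padding and copy back into the working parameter. It assumes an initialized process group `group`; the helper name is illustrative:

import torch
import torch.distributed as dist

def gather_into_working_param(master_shard: torch.Tensor, working_param: torch.Tensor, group) -> None:
    world_size = dist.get_world_size(group)
    buffer = [torch.zeros_like(master_shard) for _ in range(world_size)]
    dist.all_gather(buffer, master_shard, group=group)
    # concatenate the shards, cut off the padding and restore the original shape
    full = torch.cat(buffer)[:working_param.numel()].reshape_as(working_param)
    working_param.data.copy_(full)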
#############################
# Mixed Precision Utilities #
@ -553,49 +489,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# Gradient Synchronization #
############################
def _sync_grad(self):
# update param already reduced flag
reduction_states = self._param_store.get_param_reduction_states()
for tensor, _ in reduction_states.items():
reduction_states[tensor] = False
# accumulate gradient
for group_id in range(self.num_param_groups):
param_group = self._param_store.get_params_by_rank_group(self._local_rank, group_id)
avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)
param_idx = 0
for param in param_group:
if param.grad is not None:
if len(avg_gradients_group) == param_idx:
self._grad_store.append_average_gradient_by_group(group_id, param.grad)
else:
self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
param_idx += 1
# the gradients needed are stored in the avg_gradients buffer
# thus, can clear this
self.zero_grad()
def _reduce_grad_stage1(self):
# if not overlapping communication (no reduction hook is attached)
def _reduce_grad(self, partition_grad):
# if not overlapping communication (no reduction hook is attached) under ZeRO-1
# we need to manually reduce these gradients
if not self._overlap_communication:
if not partition_grad and not self._overlap_communication:
for group_id in range(len(self._working_param_groups)):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.grad is not None:
self._add_to_reduction_bucket(param)
self._add_to_bucket(param, group_id)
# we need to reduce the gradients
# left in the communication bucket
# run reduction
self._run_reduction()
def _reduce_grad_stage2(self):
# when partition_grads is True, reduction hooks
# are attached in the __init__ function, so we
# only need to reduce the gradients
# left in the communication bucket
for reduce_rank in range(self._world_size):
self._run_reduction(reduce_rank)
# this context comes from pytorch DDP
@contextmanager
def no_sync(self):
old_require_grad_sync = self.require_grad_sync
self.require_grad_sync = False
try:
yield
finally:
self.require_grad_sync = old_require_grad_sync

View File

@ -11,14 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
# These models are not compatible with AMP
_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
_AMP_ERR_MODELS = ['timm_convit', 'deepfm_interactionarch']
# These models have no parameters
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch']
# These models will get stuck
_STUCK_MODELS = [
'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
'transformers_bert_for_pretraining', 'transformers_gpt_double_heads'
]
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
@ -58,7 +53,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
"""
passed_models = []
failed_info = {} # (model_name, error) pair
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS
skipped_models = []
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():

View File

@ -39,37 +39,37 @@ def exam_zero_1_2_grad_acc():
overlap_communication=True,
initial_scale=32,
clip_grad_norm=1.0,
grad_accumulate_interval=2,
verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=32,
clip_grad_norm=1.0)
clip_grad_norm=1.0,
grad_accumulate_interval=2)
# create data
seed_all(2021 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
input_data2 = torch.randn(32, 128).cuda()
def fwd_bwd_func(number, cur_data):
def fwd_bwd_func(number, cur_data, check_flag):
# zero-dp forward
zero1_output = zero1_model(cur_data)
zero2_output = zero2_model(cur_data)
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)
zero1_optimizer.backward(zero1_output.sum().float())
zero2_optimizer.backward(zero2_output.sum().float())
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
if z2p.grad is not None:
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
assert torch.equal(z1p.grad, z2p.grad)
if check_flag:
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
if z2p.grad is not None:
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
assert torch.equal(z1p.grad, z2p.grad)
zero1_optimizer._sync_grad()
zero2_optimizer._sync_grad()
fwd_bwd_func(0, input_data1)
fwd_bwd_func(1, input_data2)
fwd_bwd_func(0, input_data1, True)
fwd_bwd_func(1, input_data2, False)
# step
zero1_optimizer.step()
@ -101,7 +101,8 @@ def exam_zero_1_grad_acc():
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=False,
reduce_bucket_size=262144,
clip_grad_norm=1.0)
clip_grad_norm=1.0,
grad_accumulate_interval=2)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@ -115,13 +116,19 @@ def exam_zero_1_grad_acc():
zero_output = zero_model(cur_data)
# torch-ddp forward
torch_output = torch_model(cur_data)
assert torch.equal(zero_output, torch_output)
# zero-dp backward
zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)
zero_optimizer.backward(zero_output.sum().float())
# torch-ddp backward
torch_output.sum().backward()
if number < 1:
with torch_model.no_sync():
torch_output = torch_model(cur_data)
assert torch.equal(zero_output, torch_output)
torch_output.sum().backward()
else:
torch_output = torch_model(cur_data)
assert torch.equal(zero_output, torch_output)
torch_output.sum().backward()
if check_flag:
# check grad
@ -129,8 +136,6 @@ def exam_zero_1_grad_acc():
# print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
assert torch.equal(p.grad, z1p.grad)
zero_optimizer._sync_grad()
fwd_bwd_func(0, input_data1, True)
fwd_bwd_func(1, input_data2, False)
@ -148,7 +153,8 @@ def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_1_grad_acc()
exam_zero_1_2_grad_acc()
# gradient accumulation is not compatible with ZeRO-2
# exam_zero_1_2_grad_acc()
@pytest.mark.dist

View File

@ -2,6 +2,7 @@ import copy
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
@ -16,8 +17,9 @@ class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
self.linear1 = nn.Linear(123, 253)
self.linear_drop = nn.Linear(253, 253)
self.linear2 = nn.Linear(253, 512)
def forward(self, x):
x = self.linear1(x)
@ -41,6 +43,16 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
assert_close(a, b, rtol=rtol, atol=atol)
def split_ddp_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def exam_zero_1_2():
"""
In this test, we want to test whether zero stage 1 and 2
@ -72,23 +84,21 @@ def exam_zero_1_2():
initial_scale=128)
# create data
seed_all(2001 + local_rank)
input_data = torch.randn(32, 128).cuda()
input_data = torch.randn(32, 123).cuda()
zero1_output = zero1_model(input_data)
zero2_output = zero2_model(input_data)
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)
zero1_optimizer.backward(zero1_output.mean().float())
zero2_optimizer.backward(zero2_output.mean().float())
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
if z2p.grad is not None:
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
assert torch.equal(z1p.grad, z2p.grad)
zero1_optimizer._sync_grad()
zero2_optimizer._sync_grad()
# check grad
z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
for z1g, z2g in zip(z1g_list, z2g_list):
assert torch.equal(z1g, z2g)
# step
zero1_optimizer.step()
@ -100,7 +110,7 @@ def exam_zero_1_2():
@parameterize('dtype', [torch.float16, torch.bfloat16])
def exam_zero_1_torch_ddp(dtype: torch.dtype):
def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype):
"""
In this test, two pairs of model and optimizers are created.
1. zero: use sharded optimizer and fp16 parameters
@ -116,7 +126,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
torch_model = MlpModel().cuda()
zero_model = copy.deepcopy(torch_model).to(dtype)
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda()
torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
# create optimizer
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
@ -133,7 +143,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
seed_all(1453 + local_rank)
# create
input_data = torch.rand(32, 128).cuda()
input_data = torch.rand(32, 123).cuda()
# zero-dp forward
zero_output = zero_model(input_data.to(dtype))
@ -143,17 +153,20 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
loose_close(zero_output, torch_output, dtype=dtype)
# zero-dp backward
zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
zero_optimizer.backward(zero_output.mean().float())
# torch-ddp backward
torch_output.mean().backward()
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
loose_close(p.grad, z1p.grad, dtype=dtype)
if p.grad is not None:
zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
torch_grad_list = split_ddp_grad(p.grad, world_size)
for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
loose_close(zero_grad, torch_grad, dtype=dtype)
# zero-dp step
zero_optimizer._sync_grad()
zero_optimizer.step()
# torch ddp step
@ -161,14 +174,13 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# print(n, torch.max(torch.abs(p.data - z1p.data)))
loose_close(p.data, z1p.data, dtype=dtype)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_1_torch_ddp()
exam_zero_1_torch_ddp(world_size=world_size)
exam_zero_1_2()