mirror of https://github.com/hpcaitech/ColossalAI
[zero] refactor low level zero for shard evenly (#4030)
* refactor low level zero * fix zero2 and support cpu offload * avg gradient and modify unit test * refactor grad store, support layer drop * refactor bucket store, support grad accumulation * fix and update unit test of zero and ddp * compatible with tp, ga and unit test * fix memory leak and polish * add zero layer drop unittest * polish code * fix import err in unit test * support diffenert comm dtype, modify docstring style * polish code * test padding and fix * fix unit test of low level zero * fix pad recording in bucket store * support some models * polishpull/4359/head
parent
5187c96b7c
commit
c6ab96983a
|
@ -253,7 +253,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
|
|||
return total_norm
|
||||
|
||||
|
||||
def sync_param(flat_tensor, tensor_list):
|
||||
def sync_tensor(flat_tensor, tensor_list):
|
||||
"""
|
||||
Synchronize the flattened tensor and unflattened tensor list. When
|
||||
a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`,
|
||||
|
|
|
@ -1,3 +1,8 @@
|
|||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._utils import _flatten_dense_tensors
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from .base_store import BaseStore
|
||||
|
@ -7,35 +12,102 @@ class BucketStore(BaseStore):
|
|||
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
super().__init__(torch_pg)
|
||||
self._params = dict()
|
||||
self._num_elements_in_bucket = dict()
|
||||
|
||||
# init and reset
|
||||
self.current_group_id = 0
|
||||
# mapping gardient slices and parameter
|
||||
self.grad_to_param_mapping = dict()
|
||||
|
||||
self._param_list = []
|
||||
self._padding_size = []
|
||||
|
||||
self.reset()
|
||||
|
||||
def num_elements_in_bucket(self, reduce_rank: int = None):
|
||||
return self._num_elements_in_bucket[reduce_rank]
|
||||
def num_elements_in_bucket(self) -> int:
|
||||
"""Return the total number of elements in bucket
|
||||
|
||||
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
|
||||
self._num_elements_in_bucket[reduce_rank] += num_elements
|
||||
Returns:
|
||||
int: the total number of elements in bucket
|
||||
"""
|
||||
|
||||
def add_param(self, tensor, reduce_rank: int = None):
|
||||
self._params[reduce_rank].append(tensor)
|
||||
return self._num_elements_in_bucket
|
||||
|
||||
def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
|
||||
"""Add a param to bucket and record the padding size of a param for gradient padding
|
||||
|
||||
Args:
|
||||
group_id (int): The index of a parameter group
|
||||
param (Tensor): The parameter
|
||||
padding_size (int): The padding size of the parameter
|
||||
"""
|
||||
|
||||
self._param_list.append(param)
|
||||
self._padding_size.append(padding_size)
|
||||
self._num_elements_in_bucket += (param.numel() + padding_size)
|
||||
self.current_group_id = group_id
|
||||
|
||||
def build_grad_in_bucket(self):
|
||||
"""Orgnize parameters' gradient(padding and split), follows the paramters' splitting method
|
||||
|
||||
Data structure of self._grad_in_bucket:
|
||||
{
|
||||
rank0: [grad0_rank0, grad1_rank0, ...]
|
||||
rank1: [grad1_rank1, grad1_rank1, ...]
|
||||
}
|
||||
"""
|
||||
|
||||
for param, padding_size in zip(self._param_list, self._padding_size):
|
||||
with torch.no_grad():
|
||||
grad = param.grad.detach().flatten()
|
||||
if padding_size > 0:
|
||||
grad = torch.nn.functional.pad(grad, [0, padding_size])
|
||||
grad_list = grad.split(grad.numel() // self._world_size)
|
||||
for rank in range(self._world_size):
|
||||
grad_current_rank = grad_list[rank].detach()
|
||||
self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
|
||||
self._grad_in_bucket[rank].append(grad_current_rank)
|
||||
param.grad = None
|
||||
|
||||
def get_grad(self) -> Dict:
|
||||
"""Return the dictionary of gradients slices, of which the keys are ranks
|
||||
|
||||
Returns:
|
||||
Dict: The dictionary of gradients slices
|
||||
"""
|
||||
|
||||
return self._grad_in_bucket
|
||||
|
||||
def get_flatten_grad(self) -> Tensor:
|
||||
"""Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor:
|
||||
[grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....]
|
||||
|
||||
Returns:
|
||||
Tensor: the flattened gradients slices in the bucket
|
||||
"""
|
||||
|
||||
flat_grad = []
|
||||
for grad_list in self._grad_in_bucket.values():
|
||||
flat_grad.append(_flatten_dense_tensors(grad_list))
|
||||
flat_grad = _flatten_dense_tensors(flat_grad)
|
||||
return flat_grad
|
||||
|
||||
def get_param_id_of_grad(self, grad: Tensor) -> int:
|
||||
"""Return the id of a parameter which the gradient slice belongs to
|
||||
|
||||
Args:
|
||||
grad (Tensor): the gradient slice
|
||||
|
||||
Returns:
|
||||
int: the id of a parameter which the gradient slice belongs to
|
||||
"""
|
||||
|
||||
return self.grad_to_param_mapping[id(grad)]
|
||||
|
||||
def reset(self):
|
||||
keys = [None] + list(range(self._world_size))
|
||||
self._params = {rank: [] for rank in keys}
|
||||
self._num_elements_in_bucket = {rank: 0 for rank in keys}
|
||||
|
||||
def reset_by_rank(self, reduce_rank=None):
|
||||
self._params[reduce_rank] = []
|
||||
self._num_elements_in_bucket[reduce_rank] = 0
|
||||
|
||||
def get_grad(self, reduce_rank: int = None):
|
||||
param_list = self.get_param(reduce_rank)
|
||||
for param in param_list:
|
||||
# the param must have grad for reduction
|
||||
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
|
||||
return [param.grad for param in param_list]
|
||||
|
||||
def get_param(self, reduce_rank: int = None):
|
||||
return self._params[reduce_rank]
|
||||
self.grad_to_param_mapping = dict()
|
||||
self._num_elements_in_bucket = 0
|
||||
self._param_list = []
|
||||
self._padding_size = []
|
||||
self._grad_in_bucket = dict()
|
||||
for rank in range(self._world_size):
|
||||
self._grad_in_bucket[rank] = []
|
||||
|
|
|
@ -1,88 +1,92 @@
|
|||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
from torch._utils import _flatten_dense_tensors
|
||||
|
||||
from .base_store import BaseStore
|
||||
|
||||
|
||||
class GradientStore(BaseStore):
|
||||
|
||||
def __init__(self, *args):
|
||||
def __init__(self, *args, partition_grad: bool = False):
|
||||
super().__init__(*args)
|
||||
# bookkeeping data structures
|
||||
self._averaged_gradients = dict()
|
||||
|
||||
# for backward reduction hooks
|
||||
self._grad_acc_objs = []
|
||||
|
||||
def append_accumulate_grad_object(self, obj):
|
||||
"""
|
||||
Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
|
||||
be attached successfully.
|
||||
self._grads_of_params mapping the paramater and its gradient slices
|
||||
data structure:
|
||||
{
|
||||
group_id:{
|
||||
param_id: [grad_rank0, grad_rank1, ...]
|
||||
}
|
||||
}
|
||||
"""
|
||||
self._grads_of_params = dict()
|
||||
# for zero2, it's `param_id: [grad_local_rank]`
|
||||
self._working_index = 0 if partition_grad else self._local_rank
|
||||
|
||||
:param obj: An object of :class:`AccumulateGrad` class
|
||||
:type obj: :class:`AccumulateGrad`
|
||||
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
|
||||
"""Return list of gradient slices of a specific parameter
|
||||
|
||||
Args:
|
||||
group_id (int): The index of a parameter group
|
||||
param_id (int): The id of a parameter
|
||||
|
||||
Returns:
|
||||
List: the list of gradient slices of a parameter.
|
||||
"""
|
||||
|
||||
self._grad_acc_objs.append(obj)
|
||||
if group_id in self._grads_of_params:
|
||||
if param_id in self._grads_of_params[group_id]:
|
||||
return self._grads_of_params[group_id][param_id]
|
||||
# the param has no grad, for instance, in layer drop
|
||||
return []
|
||||
|
||||
def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
|
||||
"""
|
||||
Return average gradients of a parameter group
|
||||
|
||||
:param group_id: The index of parameter group
|
||||
:type group_id: int
|
||||
|
||||
:return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
|
||||
:rtype: List[torch.Tensor]
|
||||
"""
|
||||
if group_id not in self._averaged_gradients:
|
||||
self._averaged_gradients[group_id] = []
|
||||
|
||||
return self._averaged_gradients[group_id]
|
||||
|
||||
def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
|
||||
"""
|
||||
Append an average gradient to the list of averaged gradients of a parameter group
|
||||
|
||||
:param group_id: The index of a parameter group
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type group_id: int
|
||||
:type tensor: torch.Tensor
|
||||
def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int):
|
||||
"""Append a gradient slice to the parameter's gradient slice list
|
||||
|
||||
Args:
|
||||
grad (Tensor): The gradient slice to append to list
|
||||
group_id (int): The index of a parameter group
|
||||
param_id (int): The id of a parameter
|
||||
"""
|
||||
|
||||
if group_id in self._averaged_gradients:
|
||||
self._averaged_gradients[group_id].append(tensor)
|
||||
if group_id not in self._grads_of_params:
|
||||
self._grads_of_params[group_id] = dict()
|
||||
if param_id not in self._grads_of_params[group_id]:
|
||||
self._grads_of_params[group_id][param_id] = [grad]
|
||||
else:
|
||||
self._averaged_gradients[group_id] = [tensor]
|
||||
self._grads_of_params[group_id][param_id].append(grad)
|
||||
|
||||
def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
|
||||
"""
|
||||
Add an average gradient to the list of averaged gradients of a parameter group
|
||||
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
|
||||
"""For old gradient accumulation, not in use now.
|
||||
Add a gradient slice on an existing slice of the parameter's gradient
|
||||
|
||||
:param group_id: The index of a parameter group
|
||||
:param tensor_idx: The index of a tensor in the list of averaged gradients
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type group_id: int
|
||||
:type tensor_idx: int
|
||||
:type tensor: torch.Tensor
|
||||
|
||||
"""
|
||||
self._averaged_gradients[group_id][tensor_idx].add_(tensor)
|
||||
|
||||
def reset_average_gradients_by_group(self, group_id: int) -> None:
|
||||
"""
|
||||
Reset the bookkeeping data structure for averaged gradients to an empty list
|
||||
|
||||
:param group_id: The index of a parameter group
|
||||
:type group_id: int
|
||||
Args:
|
||||
grad (Tensor): The split gradient to append to list
|
||||
grad_idx (int): The index of the existing slice
|
||||
group_id (int): The index of a parameter group
|
||||
param_id (int): The id of a parameter
|
||||
"""
|
||||
|
||||
self._averaged_gradients[group_id] = []
|
||||
self._grads_of_params[group_id][param_id][grad_idx].add_(grad)
|
||||
|
||||
def reset_all_average_gradients(self) -> None:
|
||||
def get_working_grads_by_group_id(self, group_id: int) -> List:
|
||||
"""Return list of working gradient slices in the group
|
||||
|
||||
Args:
|
||||
group_id (int): The index of a parameter group
|
||||
|
||||
Returns:
|
||||
List: the list working gradient slices in the group
|
||||
"""
|
||||
Reset the bookkeeping data structure for averaged gradients to an empty list
|
||||
"""
|
||||
self._averaged_gradients = dict()
|
||||
|
||||
grad_list = []
|
||||
for param_grads in self._grads_of_params[group_id].values():
|
||||
grad_list.append(param_grads[self._working_index])
|
||||
|
||||
return grad_list
|
||||
|
||||
def reset_grads_by_group_id(self, group_id: int):
|
||||
self._grads_of_params[group_id] = dict()
|
||||
|
||||
def reset_all_gradients(self):
|
||||
self._grads_of_params = dict()
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
@ -10,88 +8,43 @@ class ParameterStore(BaseStore):
|
|||
|
||||
def __init__(self, torch_pg: ProcessGroup):
|
||||
super().__init__(torch_pg)
|
||||
# param partitioning data structures
|
||||
self._param_to_rank = dict()
|
||||
self._rank_group_id_to_param_list = dict()
|
||||
self._rank_group_id_to_flat_param = dict()
|
||||
|
||||
# param reduction data structures
|
||||
self._is_param_reduced = dict()
|
||||
self._reduced_param = []
|
||||
# record the padding size of each param
|
||||
self._padding_map = dict()
|
||||
|
||||
def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
|
||||
"""
|
||||
Set the mapping between parameter to rank, each parameter should be owned by a rank.
|
||||
# mapping working param and master param
|
||||
self.master_to_working_param = dict()
|
||||
self.working_to_master_param = dict()
|
||||
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type tensor: torch.Tensor
|
||||
:param rank: The rank of which the process is responsible for updating the parameter
|
||||
:type rank: int
|
||||
def record_param_padding_size(self, param: Tensor, padding_size: int):
|
||||
"""Record the padding size of a param
|
||||
|
||||
Args:
|
||||
param (Tensor): The parameter
|
||||
padding_size (int): The padding size of the parameter
|
||||
"""
|
||||
|
||||
self._param_to_rank[tensor] = rank
|
||||
self._padding_map[id(param)] = padding_size
|
||||
|
||||
def get_param_rank(self, tensor: Tensor) -> int:
|
||||
"""
|
||||
Gives the rank which the parameter belongs to
|
||||
def get_param_padding_size(self, param: Tensor) -> int:
|
||||
"""Return the padding size of the parameter
|
||||
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type tensor: torch.Tensor
|
||||
"""
|
||||
return self._param_to_rank[tensor]
|
||||
Args:
|
||||
param (Tensor): The parameter
|
||||
|
||||
def belongs_to_current_rank(self, tensor) -> bool:
|
||||
"""
|
||||
Check whether a parameter is supposed to be updated by the process of the current rank
|
||||
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type tensor: torch.Tensor
|
||||
|
||||
:return: True if the parameter should be updated by the current rank. Otherwise false.
|
||||
:rtype: bool
|
||||
Returns:
|
||||
int: the padding size of the parameter
|
||||
"""
|
||||
|
||||
tensor_rank = self._param_to_rank[tensor]
|
||||
return tensor_rank == self._local_rank
|
||||
return self._padding_map[id(param)]
|
||||
|
||||
def add_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
|
||||
if rank not in self._rank_group_id_to_param_list:
|
||||
self._rank_group_id_to_param_list[rank] = dict()
|
||||
def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
|
||||
"""Mapping master parameter and working parameter
|
||||
|
||||
if group_id not in self._rank_group_id_to_param_list[rank]:
|
||||
self._rank_group_id_to_param_list[rank][group_id] = []
|
||||
Args:
|
||||
master_param (Tensor): The parameter copy in optimizer
|
||||
working_param (Tensor): The parameter of the model
|
||||
"""
|
||||
|
||||
self._rank_group_id_to_param_list[rank][group_id].extend(tensor_list)
|
||||
|
||||
def get_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
|
||||
return self._rank_group_id_to_param_list[rank][group_id]
|
||||
|
||||
def add_flat_param_by_rank_group(self, rank, group_id, tensor) -> None:
|
||||
if rank not in self._rank_group_id_to_flat_param:
|
||||
self._rank_group_id_to_flat_param[rank] = dict()
|
||||
|
||||
self._rank_group_id_to_flat_param[rank][group_id] = tensor
|
||||
|
||||
def get_flat_param_by_rank_group(self, rank, group_id) -> Tensor:
|
||||
return self._rank_group_id_to_flat_param[rank][group_id]
|
||||
|
||||
def is_param_reduced(self, tensor):
|
||||
return self._is_param_reduced[tensor]
|
||||
|
||||
def set_param_reduction_state(self, tensor, state):
|
||||
self._is_param_reduced[tensor] = state
|
||||
|
||||
def get_param_reduction_states(self):
|
||||
return self._is_param_reduced
|
||||
|
||||
def reset_previous_reduced_params(self):
|
||||
self._reduced_param = []
|
||||
|
||||
def add_previous_reduced_param(self, tensor):
|
||||
self._reduced_param.append(tensor)
|
||||
|
||||
def clear_grads_of_previous_reduced_params(self):
|
||||
if len(self._reduced_param) > 0:
|
||||
for param in self._reduced_param:
|
||||
param.grad = None
|
||||
self.reset_previous_reduced_params()
|
||||
self.master_to_working_param[id(master_param)] = working_param
|
||||
self.working_to_master_param[id(working_param)] = master_param
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
|
@ -16,6 +17,7 @@ from colossalai.core import global_context as gpc
|
|||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.tensor import ColoParameter, ProcessGroup
|
||||
from colossalai.utils import conditional_context
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from ._utils import (
|
||||
|
@ -23,12 +25,10 @@ from ._utils import (
|
|||
compute_norm,
|
||||
flatten,
|
||||
has_inf_or_nan,
|
||||
reduce_tensor_dp_group,
|
||||
release_param_grad,
|
||||
split_by_dtype,
|
||||
sync_param,
|
||||
sync_tensor,
|
||||
)
|
||||
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
|
||||
from .bookkeeping import BucketStore, GradientStore, ParameterStore
|
||||
|
||||
|
||||
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
|
||||
|
@ -50,7 +50,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
|
|||
|
||||
def check_local_overflow(self) -> bool:
|
||||
for group_id in range(self.num_working_param_groups):
|
||||
for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
|
||||
for avg_grad in self.grad_store.get_working_grads_by_group_id(group_id):
|
||||
if avg_grad is not None and has_inf_or_nan(avg_grad):
|
||||
return True
|
||||
return False
|
||||
|
@ -77,14 +77,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
overlap_communication: bool = False,
|
||||
partition_grad: bool = False, # stage 2 flag
|
||||
cpu_offload: bool = False, # cpu offload
|
||||
grad_accumulate_interval: int = 1,
|
||||
forced_dtype: Optional[torch.dtype] = None):
|
||||
|
||||
# TODO: add support for
|
||||
# 1. fp16 master weights
|
||||
# 2. contiguous gradients
|
||||
# 3. cpu offload
|
||||
# 4. support when some parameters requires_grad = False
|
||||
# 5. support layer drop
|
||||
assert not (partition_grad and grad_accumulate_interval > 1), \
|
||||
"gradient accumulation is not compatible with ZeRO-2"
|
||||
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
|
||||
self._dtype = self.optim.param_groups[0]['params'][0].dtype
|
||||
self._logger = get_dist_logger()
|
||||
|
@ -95,6 +92,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
self._cpu_offload = cpu_offload
|
||||
|
||||
# grad accumulation
|
||||
self.require_grad_sync = True
|
||||
self._accumulate_intervel = grad_accumulate_interval
|
||||
self._accumulate_step = 0
|
||||
|
||||
colo_pg = self._search_colo_process_group()
|
||||
if isinstance(colo_pg, ProcessGroup):
|
||||
self._local_rank = colo_pg.dp_local_rank()
|
||||
|
@ -122,7 +124,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
# working and master params for mixed precision training
|
||||
self._working_param_groups = dict()
|
||||
self._master_flat_param_groups_of_current_rank = dict()
|
||||
self._master_param_groups_of_current_rank = dict()
|
||||
|
||||
# communication params
|
||||
self._overlap_communication = overlap_communication
|
||||
|
@ -145,7 +147,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
# ParameterStore will manage the tensor buffers used for zero
|
||||
# it will not manage the tensors used by mixed precision training
|
||||
self._param_store = ParameterStore(self._dp_torch_group)
|
||||
self._grad_store = GradientStore(self._dp_torch_group)
|
||||
self._grad_store = GradientStore(self._dp_torch_group, partition_grad=partition_grad)
|
||||
self._bucket_store = BucketStore(self._dp_torch_group)
|
||||
|
||||
# iterate over the param group in the optimizer
|
||||
|
@ -160,55 +162,17 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
# add the working params to working_param_groups for bookkeeping
|
||||
self._working_param_groups[group_id] = group_params
|
||||
|
||||
# assign parameters to ranks
|
||||
# the params in the list are sorted
|
||||
params_per_rank = self._partition_param_list(group_params)
|
||||
master_param_current_rank = self._create_master_param_current_rank(group_params)
|
||||
|
||||
# store the mapping between param to rank
|
||||
# each param should belong to only one rank
|
||||
for rank, params in enumerate(params_per_rank):
|
||||
self._param_store.add_param_list_by_rank_group(rank, group_id, params)
|
||||
for param in params:
|
||||
self._param_store.set_param_to_rank(param, rank)
|
||||
|
||||
# move to cpu to make room to create the flat tensor
|
||||
# move_tensor(params, device='cpu')
|
||||
for param in group_params:
|
||||
param.data = param.data.cpu()
|
||||
|
||||
# flatten the reordered tensors
|
||||
for rank in range(self._world_size):
|
||||
tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
|
||||
with torch.no_grad():
|
||||
flat_tensor = flatten(tensor_list)
|
||||
flat_tensor = flat_tensor.data.cuda()
|
||||
self._param_store.add_flat_param_by_rank_group(rank, group_id, flat_tensor)
|
||||
|
||||
# sync parameters
|
||||
for rank in range(self._world_size):
|
||||
flat_tensor = self._param_store.get_flat_param_by_rank_group(rank, group_id)
|
||||
tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
|
||||
sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
|
||||
|
||||
# create a copy of fp32 master weights of the parameters for which this rank is responsible
|
||||
working_flat_current_rank = self._param_store.get_flat_param_by_rank_group(self._local_rank, group_id)
|
||||
master_flat_current_rank = working_flat_current_rank.float()
|
||||
device = 'cpu' if self._cpu_offload else get_current_device()
|
||||
master_flat_current_rank = master_flat_current_rank.to(device)
|
||||
master_flat_current_rank.requires_grad = True
|
||||
self._master_flat_param_groups_of_current_rank[group_id] = master_flat_current_rank
|
||||
self._master_param_groups_of_current_rank[group_id] = master_param_current_rank
|
||||
|
||||
# need to replace the params in the `params` field in the optimizer
|
||||
# so that when the optimizer calls step(), it only updates the tensors
|
||||
# managed by this data parallel rank
|
||||
param_group['params'] = [master_flat_current_rank]
|
||||
param_group['params'] = master_param_current_rank
|
||||
|
||||
# set reduction state
|
||||
for param in self._working_param_groups[group_id]:
|
||||
self._param_store.set_param_reduction_state(param, False)
|
||||
|
||||
# initialize communication stream for
|
||||
# communication-computation overlapping
|
||||
# intialize communication stream for
|
||||
# communication-compuation overlapping
|
||||
if self._overlap_communication:
|
||||
self._comm_stream = torch.cuda.Stream()
|
||||
|
||||
|
@ -265,29 +229,36 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
|
||||
return colo_pg
|
||||
|
||||
def _partition_param_list(self, param_list):
|
||||
params_per_rank = [[] for _ in range(self._world_size)]
|
||||
numel_per_rank = [0 for _ in range(self._world_size)]
|
||||
def _create_master_param_current_rank(self, param_list):
|
||||
# split each param evenly by world size
|
||||
params_current_rank = []
|
||||
device = 'cpu' if self._cpu_offload else get_current_device()
|
||||
|
||||
# partition the parameters in a greedy fashion
|
||||
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
|
||||
for param in sorted_params:
|
||||
# allocate this parameter to the rank with
|
||||
# the smallest numel for load balancing purpose
|
||||
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
||||
params_per_rank[rank_to_go].append(param)
|
||||
numel_per_rank[rank_to_go] += param.numel()
|
||||
for param in reversed(param_list):
|
||||
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
|
||||
self._param_store.record_param_padding_size(param, padding_size)
|
||||
|
||||
if self._verbose:
|
||||
self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
|
||||
return params_per_rank
|
||||
with torch.no_grad():
|
||||
if padding_size > 0:
|
||||
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
|
||||
else:
|
||||
padding_param = param.data.view(-1)
|
||||
splited_params = padding_param.split(padding_param.numel() // self._world_size)
|
||||
|
||||
splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
|
||||
params_current_rank.append(splited_param_current_rank)
|
||||
self._param_store.link_master_and_working_param(splited_param_current_rank, param)
|
||||
|
||||
return params_current_rank
|
||||
|
||||
###########################
|
||||
# Backward Reduction Hook #
|
||||
###########################
|
||||
|
||||
def _grad_handler(self, param, grad, reduce_rank):
|
||||
self._add_to_reduction_bucket(param, reduce_rank)
|
||||
def _grad_handler(self, param, group_id, grad):
|
||||
# if run with no_sync context, would not sync grad when backward
|
||||
if self.require_grad_sync:
|
||||
self._add_to_bucket(param, group_id)
|
||||
return grad
|
||||
|
||||
def _attach_reduction_hook(self):
|
||||
|
@ -297,149 +268,96 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
param_group = self._working_param_groups[group_id]
|
||||
for param in param_group:
|
||||
if param.requires_grad:
|
||||
# determines the reduction destination rank
|
||||
# this is only valid for stage 2
|
||||
# dst_rank = None means using all-reduce
|
||||
# else using reduce
|
||||
if self._partition_grads:
|
||||
reduce_rank = self._param_store.get_param_rank(param)
|
||||
else:
|
||||
reduce_rank = None
|
||||
|
||||
param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
|
||||
|
||||
def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
flat = bucket.flatten()
|
||||
reduce_global_rank = None
|
||||
if reduce_rank is not None:
|
||||
reduce_global_rank = self._dp_global_ranks[reduce_rank]
|
||||
reduced_flat = reduce_tensor_dp_group(tensor=flat,
|
||||
dtype=self._communication_dtype,
|
||||
dst_local_rank=reduce_rank,
|
||||
dst_global_rank=reduce_global_rank,
|
||||
group=self._dp_torch_group)
|
||||
|
||||
# update the reduced tensor
|
||||
if reduce_rank is None or reduce_rank == self._local_rank:
|
||||
bucket.unflatten_and_copy(reduced_flat)
|
||||
|
||||
def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
|
||||
param_bucket = TensorBucket(size=bucket_size)
|
||||
|
||||
for tensor in tensor_list:
|
||||
param_bucket.add_to_bucket(tensor, allow_oversize=True)
|
||||
|
||||
if param_bucket.is_full_or_oversized():
|
||||
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
param_bucket.empty()
|
||||
|
||||
if not param_bucket.is_empty():
|
||||
self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
|
||||
def _reduce_grads(self, reduce_rank, grads, bucket_size):
|
||||
grad_buckets_by_dtype = split_by_dtype(grads)
|
||||
|
||||
for tensor_list in grad_buckets_by_dtype:
|
||||
self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
|
||||
bucket_size=bucket_size,
|
||||
reduce_rank=reduce_rank)
|
||||
param.register_hook(partial(self._grad_handler, param, group_id))
|
||||
|
||||
#######################
|
||||
# Reduction Functions #
|
||||
#######################
|
||||
|
||||
def _run_reduction(self, reduce_rank=None):
|
||||
# reduce grads
|
||||
self._reduce_grads(reduce_rank=reduce_rank,
|
||||
grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
|
||||
bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
|
||||
def _run_reduction(self):
|
||||
if self._bucket_store.num_elements_in_bucket() > 0:
|
||||
self._bucket_store.build_grad_in_bucket()
|
||||
flat_grads = self._bucket_store.get_flatten_grad()
|
||||
flat_grads /= self._world_size
|
||||
if self._overlap_communication:
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
# use communication stream if overlapping
|
||||
# communication with computation
|
||||
if self._overlap_communication:
|
||||
stream = self._comm_stream
|
||||
else:
|
||||
stream = torch.cuda.current_stream()
|
||||
with torch.cuda.stream(stream):
|
||||
group_id = self._bucket_store.current_group_id
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
|
||||
grad_dtype = flat_grads.dtype
|
||||
if self._communication_dtype is not None:
|
||||
flat_grads = flat_grads.to(self._communication_dtype)
|
||||
|
||||
for param in params_in_bucket:
|
||||
# the is_param_reduced flag should be False showing that
|
||||
# this param is not reduced before calling self._reduce_grads_by_rank
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
if not self._partition_grads:
|
||||
dist.all_reduce(flat_grads, group=self._dp_torch_group)
|
||||
if flat_grads.dtype != grad_dtype:
|
||||
flat_grads = flat_grads.to(grad_dtype)
|
||||
|
||||
if is_param_reduced:
|
||||
msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
|
||||
'duplicate reduction will lead to arithmetic incorrectness'
|
||||
raise RuntimeError(msg)
|
||||
flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
|
||||
grad_in_bucket = self._bucket_store.get_grad()
|
||||
|
||||
# update the flag
|
||||
self._param_store.set_param_reduction_state(param, True)
|
||||
for rank, grad_list in grad_in_bucket.items():
|
||||
sync_tensor(flat_grads_per_rank[rank], grad_list)
|
||||
for grad in grad_list:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
|
||||
# if partition grads = True
|
||||
# we do not keep the gradient after reduction
|
||||
if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
|
||||
if self._overlap_communication:
|
||||
# we need to keep this gradient for now as reduction may
|
||||
# be completed yet since it is using a different cuda stream
|
||||
self._param_store.add_previous_reduced_param(param)
|
||||
else:
|
||||
param.grad = None
|
||||
else:
|
||||
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
|
||||
recieved_grad = torch.zeros_like(flat_grads_list[0])
|
||||
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self._dp_torch_group)
|
||||
|
||||
self._bucket_store.reset_by_rank(reduce_rank)
|
||||
if recieved_grad.dtype != grad_dtype:
|
||||
recieved_grad = recieved_grad.to(grad_dtype)
|
||||
|
||||
def _add_to_reduction_bucket(self, param, reduce_rank=None):
|
||||
grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
|
||||
sync_tensor(recieved_grad, grad_in_bucket_current_rank)
|
||||
for grad in grad_in_bucket_current_rank:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
|
||||
self._bucket_store.reset()
|
||||
|
||||
def _add_to_bucket(self, param, group_id):
|
||||
param_size = param.numel()
|
||||
|
||||
# check if the bucket is full
|
||||
# if full, will reduce the grads already in the bucket
|
||||
# or got a grad of param from another group
|
||||
# after reduction, the bucket will be empty
|
||||
if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||
self._run_reduction(reduce_rank)
|
||||
if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
|
||||
group_id != self._bucket_store.current_group_id:
|
||||
self._run_reduction()
|
||||
|
||||
# the param must not be reduced to ensure correctness
|
||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||
if is_param_reduced:
|
||||
msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
|
||||
+ 'duplicate reduction will lead to arithmetic incorrectness'
|
||||
raise RuntimeError(msg)
|
||||
|
||||
self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
|
||||
self._bucket_store.add_param(param, reduce_rank)
|
||||
padding_size = self._param_store.get_param_padding_size(param)
|
||||
self._bucket_store.add_param_grad(group_id, param, padding_size)
|
||||
|
||||
################################
|
||||
# torch.optim.Optimizer methods
|
||||
################################
|
||||
|
||||
def backward(self, loss, retain_graph=False, sync_grad=True):
|
||||
def backward(self, loss, retain_graph=False):
|
||||
if self.mixed_precision_mixin is not None:
|
||||
loss = self.mixed_precision_mixin.pre_backward(loss)
|
||||
loss.backward(retain_graph=retain_graph)
|
||||
|
||||
# finish gradient reduction
|
||||
if not self._partition_grads:
|
||||
self._reduce_grad_stage1()
|
||||
else:
|
||||
# TODO: support async comm in reduce
|
||||
self._reduce_grad_stage2()
|
||||
self._accumulate_step += 1
|
||||
no_sync = self._accumulate_step < self._accumulate_intervel
|
||||
with conditional_context(self.no_sync(), enable=no_sync):
|
||||
loss.backward(retain_graph=retain_graph)
|
||||
|
||||
if no_sync:
|
||||
return
|
||||
|
||||
self._reduce_grad(self._partition_grads)
|
||||
|
||||
# clear reduced grads
|
||||
if self._overlap_communication:
|
||||
torch.cuda.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
|
||||
# gradient synchronization
|
||||
if sync_grad:
|
||||
self._sync_grad()
|
||||
self.zero_grad()
|
||||
|
||||
def zero_grad(self, set_to_none=True):
|
||||
"""
|
||||
|
@ -467,68 +385,86 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
def step(self, closure=None):
|
||||
assert closure is None, 'closure is not supported by step()'
|
||||
|
||||
if not self._accumulate_step == self._accumulate_intervel:
|
||||
return
|
||||
|
||||
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
|
||||
self._grad_store.reset_all_average_gradients()
|
||||
self._grad_store.reset_all_gradients()
|
||||
if self._verbose:
|
||||
self._logger.info(f'Found overflow. Skip step')
|
||||
self.zero_grad()
|
||||
self._accumulate_step -= 1
|
||||
return
|
||||
|
||||
# copy the grad of working param to master param
|
||||
single_grad_partition_groups = []
|
||||
# record all grads for unscale and clip
|
||||
grad_partition_groups = []
|
||||
norm_groups = []
|
||||
|
||||
# sometimes not all params are 'really' working
|
||||
# for instance, when layer drop, the dropped layer has no grad
|
||||
# and should not be updated
|
||||
real_working_params = dict()
|
||||
real_master_params = dict()
|
||||
|
||||
grad_index = 0 if self._partition_grads else self._local_rank
|
||||
|
||||
for group_id in range(self.num_param_groups):
|
||||
master_params = self._master_param_groups_of_current_rank[group_id]
|
||||
real_working_params[group_id] = []
|
||||
real_master_params[group_id] = []
|
||||
for splited_param in master_params:
|
||||
working_param = self._param_store.master_to_working_param[id(splited_param)]
|
||||
# if a working param requires grad and has no grad
|
||||
# it is not 'really' working, e.g. the droped layer
|
||||
# else the splited grad should be attached to the splited param
|
||||
grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
|
||||
if len(grads) > 0:
|
||||
real_working_params[group_id].append(working_param)
|
||||
grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
|
||||
splited_param.grad = grad
|
||||
grad_partition_groups.append(grad)
|
||||
real_master_params[group_id].append(splited_param)
|
||||
|
||||
# compute norm
|
||||
norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id),
|
||||
params=self._param_store.get_params_by_rank_group(group_id=group_id,
|
||||
rank=self._local_rank),
|
||||
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
|
||||
norm_group = compute_norm(gradients=working_grads,
|
||||
params=real_working_params[group_id],
|
||||
dp_group=self._dp_torch_group,
|
||||
mp_group=self._mp_torch_group)
|
||||
norm_groups.append(norm_group)
|
||||
|
||||
# create flat gradient for the flat fp32 master params
|
||||
working_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
|
||||
flat_working_avg_grads = flatten(working_avg_grads)
|
||||
self._grad_store.reset_grads_by_group_id(group_id)
|
||||
|
||||
dtype = self._master_flat_param_groups_of_current_rank[group_id].dtype
|
||||
flat_master_avg_grads = flat_working_avg_grads.to(dtype)
|
||||
|
||||
param_shape = self._master_flat_param_groups_of_current_rank[group_id].shape
|
||||
assert param_shape == flat_master_avg_grads.shape, \
|
||||
f'fp32 param and grad have different shape {param_shape} vs {flat_master_avg_grads.shape}'
|
||||
|
||||
single_grad_partition_groups.append(flat_master_avg_grads)
|
||||
device = self._master_flat_param_groups_of_current_rank[group_id].device
|
||||
self._master_flat_param_groups_of_current_rank[group_id].grad = flat_master_avg_grads.to(device)
|
||||
self._grad_store.reset_average_gradients_by_group(group_id)
|
||||
# update the params in the optimizer
|
||||
self.optim.param_groups[group_id]['params'] = real_master_params[group_id]
|
||||
|
||||
# unscale and clip grads
|
||||
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
|
||||
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
|
||||
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
|
||||
|
||||
# update the parameters
|
||||
self.optim.step()
|
||||
# release the master grad
|
||||
release_param_grad(self._master_flat_param_groups_of_current_rank.values())
|
||||
|
||||
# release the grad
|
||||
grad_partition_groups = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
release_param_grad(self._master_param_groups_of_current_rank[group_id])
|
||||
|
||||
# update working partition updated by the current rank
|
||||
for group_id in range(len(self._working_param_groups)):
|
||||
working_param = self._param_store.get_flat_param_by_rank_group(rank=self._local_rank, group_id=group_id)
|
||||
master_param = self._master_flat_param_groups_of_current_rank[group_id]
|
||||
working_param.data.copy_(master_param)
|
||||
|
||||
# broadcast the updated model weights
|
||||
handles = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
for index in range(self._world_size):
|
||||
rank = self._dp_global_ranks[index]
|
||||
working_param = self._param_store.get_flat_param_by_rank_group(rank=index, group_id=group_id)
|
||||
handle = dist.broadcast(working_param, src=rank, group=self._dp_torch_group, async_op=True)
|
||||
handles.append(handle)
|
||||
master_working_param = self.optim.param_groups[group_id]['params']
|
||||
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
for idx, splited_param in enumerate(master_working_param):
|
||||
full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
|
||||
dist.all_gather(full_master_param, splited_param.cuda(), group=self._dp_torch_group)
|
||||
working_param = real_working_params[group_id][idx]
|
||||
full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
|
||||
working_param.data.copy_(full_master_param)
|
||||
|
||||
self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
|
||||
|
||||
# reset accumulate step
|
||||
self._accumulate_step = 0
|
||||
|
||||
#############################
|
||||
# Mixed Precision Utilities #
|
||||
|
@ -553,49 +489,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
# Gradient Synchronization #
|
||||
############################
|
||||
|
||||
def _sync_grad(self):
|
||||
# update param already reduced flag
|
||||
reduction_states = self._param_store.get_param_reduction_states()
|
||||
for tensor, _ in reduction_states.items():
|
||||
reduction_states[tensor] = False
|
||||
|
||||
# accumulate gradient
|
||||
for group_id in range(self.num_param_groups):
|
||||
param_group = self._param_store.get_params_by_rank_group(self._local_rank, group_id)
|
||||
|
||||
avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)
|
||||
|
||||
param_idx = 0
|
||||
for param in param_group:
|
||||
if param.grad is not None:
|
||||
if len(avg_gradients_group) == param_idx:
|
||||
self._grad_store.append_average_gradient_by_group(group_id, param.grad)
|
||||
else:
|
||||
self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
|
||||
param_idx += 1
|
||||
|
||||
# the gradients needed are stored in the avg_gradients buffer
|
||||
# thus, can clear this
|
||||
self.zero_grad()
|
||||
|
||||
def _reduce_grad_stage1(self):
|
||||
# if not overlapping communication (no reduction hook is attached)
|
||||
def _reduce_grad(self, partition_grad):
|
||||
# if not overlapping communication (no reduction hook is attached) when zero1
|
||||
# we need to manually reduce these gradients
|
||||
if not self._overlap_communication:
|
||||
if not partition_grad and not self._overlap_communication:
|
||||
for group_id in range(len(self._working_param_groups)):
|
||||
param_group = self._working_param_groups[group_id]
|
||||
for param in param_group:
|
||||
if param.grad is not None:
|
||||
self._add_to_reduction_bucket(param)
|
||||
self._add_to_bucket(param, group_id)
|
||||
|
||||
# we need to reduce the gradients
|
||||
# left in the communication bucket
|
||||
# run reduction
|
||||
self._run_reduction()
|
||||
|
||||
def _reduce_grad_stage2(self):
|
||||
# when partition_grads is True, reduction hooks
|
||||
# are attached in the __init__ function, so we
|
||||
# only need to reduce the gradients
|
||||
# left in the communication bucket
|
||||
for reduce_rank in range(self._world_size):
|
||||
self._run_reduction(reduce_rank)
|
||||
# this context comes from pytorch DDP
|
||||
@contextmanager
|
||||
def no_sync(self):
|
||||
old_require_grad_sync = self.require_grad_sync
|
||||
self.require_grad_sync = False
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.require_grad_sync = old_require_grad_sync
|
||||
|
|
|
@ -11,14 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
# These models are not compatible with AMP
|
||||
_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
|
||||
_AMP_ERR_MODELS = ['timm_convit', 'deepfm_interactionarch']
|
||||
# These models have no parameters
|
||||
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch']
|
||||
# These models will get stuck
|
||||
_STUCK_MODELS = [
|
||||
'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
|
||||
'transformers_bert_for_pretraining', 'transformers_gpt_double_heads'
|
||||
]
|
||||
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']
|
||||
|
||||
|
||||
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
|
||||
|
@ -58,7 +53,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
|
|||
"""
|
||||
passed_models = []
|
||||
failed_info = {} # (model_name, error) pair
|
||||
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
|
||||
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS
|
||||
skipped_models = []
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
|
||||
|
|
|
@ -39,37 +39,37 @@ def exam_zero_1_2_grad_acc():
|
|||
overlap_communication=True,
|
||||
initial_scale=32,
|
||||
clip_grad_norm=1.0,
|
||||
grad_accumulate_interval=2,
|
||||
verbose=True)
|
||||
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
|
||||
overlap_communication=True,
|
||||
partition_grad=True,
|
||||
initial_scale=32,
|
||||
clip_grad_norm=1.0)
|
||||
clip_grad_norm=1.0,
|
||||
grad_accumulate_interval=2)
|
||||
# create data
|
||||
seed_all(2021 + local_rank)
|
||||
input_data1 = torch.randn(32, 128).cuda()
|
||||
input_data2 = torch.randn(32, 128).cuda()
|
||||
|
||||
def fwd_bwd_func(number, cur_data):
|
||||
def fwd_bwd_func(number, cur_data, check_flag):
|
||||
# zero-dp forward
|
||||
zero1_output = zero1_model(cur_data)
|
||||
zero2_output = zero2_model(cur_data)
|
||||
assert torch.equal(zero1_output, zero2_output)
|
||||
|
||||
# zero-dp backward
|
||||
zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
|
||||
zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)
|
||||
zero1_optimizer.backward(zero1_output.sum().float())
|
||||
zero2_optimizer.backward(zero2_output.sum().float())
|
||||
|
||||
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
|
||||
if z2p.grad is not None:
|
||||
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
|
||||
assert torch.equal(z1p.grad, z2p.grad)
|
||||
if check_flag:
|
||||
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
|
||||
if z2p.grad is not None:
|
||||
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
|
||||
assert torch.equal(z1p.grad, z2p.grad)
|
||||
|
||||
zero1_optimizer._sync_grad()
|
||||
zero2_optimizer._sync_grad()
|
||||
|
||||
fwd_bwd_func(0, input_data1)
|
||||
fwd_bwd_func(1, input_data2)
|
||||
fwd_bwd_func(0, input_data1, True)
|
||||
fwd_bwd_func(1, input_data2, False)
|
||||
|
||||
# step
|
||||
zero1_optimizer.step()
|
||||
|
@ -101,7 +101,8 @@ def exam_zero_1_grad_acc():
|
|||
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
|
||||
overlap_communication=False,
|
||||
reduce_bucket_size=262144,
|
||||
clip_grad_norm=1.0)
|
||||
clip_grad_norm=1.0,
|
||||
grad_accumulate_interval=2)
|
||||
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
|
||||
|
||||
|
@ -115,13 +116,19 @@ def exam_zero_1_grad_acc():
|
|||
zero_output = zero_model(cur_data)
|
||||
|
||||
# torch-ddp forward
|
||||
torch_output = torch_model(cur_data)
|
||||
assert torch.equal(zero_output, torch_output)
|
||||
|
||||
# zero-dp backward
|
||||
zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)
|
||||
zero_optimizer.backward(zero_output.sum().float())
|
||||
# torch-ddp backward
|
||||
torch_output.sum().backward()
|
||||
if number < 1:
|
||||
with torch_model.no_sync():
|
||||
torch_output = torch_model(cur_data)
|
||||
assert torch.equal(zero_output, torch_output)
|
||||
torch_output.sum().backward()
|
||||
else:
|
||||
torch_output = torch_model(cur_data)
|
||||
assert torch.equal(zero_output, torch_output)
|
||||
torch_output.sum().backward()
|
||||
|
||||
if check_flag:
|
||||
# check grad
|
||||
|
@ -129,8 +136,6 @@ def exam_zero_1_grad_acc():
|
|||
# print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
|
||||
assert torch.equal(p.grad, z1p.grad)
|
||||
|
||||
zero_optimizer._sync_grad()
|
||||
|
||||
fwd_bwd_func(0, input_data1, True)
|
||||
fwd_bwd_func(1, input_data2, False)
|
||||
|
||||
|
@ -148,7 +153,8 @@ def run_dist(rank, world_size, port):
|
|||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
|
||||
exam_zero_1_grad_acc()
|
||||
exam_zero_1_2_grad_acc()
|
||||
# gradient accumulation is not compatible with ZeRO-2
|
||||
# exam_zero_1_2_grad_acc()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
|
@ -2,6 +2,7 @@ import copy
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
@ -16,8 +17,9 @@ class MlpModel(nn.Module):
|
|||
|
||||
def __init__(self):
|
||||
super(MlpModel, self).__init__()
|
||||
self.linear1 = nn.Linear(128, 256)
|
||||
self.linear2 = nn.Linear(256, 512)
|
||||
self.linear1 = nn.Linear(123, 253)
|
||||
self.linear_drop = nn.Linear(253, 253)
|
||||
self.linear2 = nn.Linear(253, 512)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
|
@ -41,6 +43,16 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
|||
assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def split_ddp_grad(grad, world_size):
|
||||
with torch.no_grad():
|
||||
grad = grad.clone().detach().flatten()
|
||||
padding_size = (world_size - grad.numel() % world_size) % world_size
|
||||
if padding_size > 0:
|
||||
grad = torch.nn.functional.pad(grad, [0, padding_size])
|
||||
splited_grad = grad.split(grad.numel() // world_size)
|
||||
return splited_grad
|
||||
|
||||
|
||||
def exam_zero_1_2():
|
||||
"""
|
||||
In this test, we want to test whether zero stage 1 and 2
|
||||
|
@ -72,23 +84,21 @@ def exam_zero_1_2():
|
|||
initial_scale=128)
|
||||
# create data
|
||||
seed_all(2001 + local_rank)
|
||||
input_data = torch.randn(32, 128).cuda()
|
||||
input_data = torch.randn(32, 123).cuda()
|
||||
|
||||
zero1_output = zero1_model(input_data)
|
||||
zero2_output = zero2_model(input_data)
|
||||
assert torch.equal(zero1_output, zero2_output)
|
||||
|
||||
# zero-dp backward
|
||||
zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
|
||||
zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)
|
||||
zero1_optimizer.backward(zero1_output.mean().float())
|
||||
zero2_optimizer.backward(zero2_output.mean().float())
|
||||
|
||||
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
|
||||
if z2p.grad is not None:
|
||||
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
|
||||
assert torch.equal(z1p.grad, z2p.grad)
|
||||
|
||||
zero1_optimizer._sync_grad()
|
||||
zero2_optimizer._sync_grad()
|
||||
# check grad
|
||||
z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
|
||||
z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
|
||||
for z1g, z2g in zip(z1g_list, z2g_list):
|
||||
assert torch.equal(z1g, z2g)
|
||||
|
||||
# step
|
||||
zero1_optimizer.step()
|
||||
|
@ -100,7 +110,7 @@ def exam_zero_1_2():
|
|||
|
||||
|
||||
@parameterize('dtype', [torch.float16, torch.bfloat16])
|
||||
def exam_zero_1_torch_ddp(dtype: torch.dtype):
|
||||
def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype):
|
||||
"""
|
||||
In this test, two pairs of model and optimizers are created.
|
||||
1. zero: use sharded optimizer and fp16 parameters
|
||||
|
@ -116,7 +126,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
|
|||
torch_model = MlpModel().cuda()
|
||||
zero_model = copy.deepcopy(torch_model).to(dtype)
|
||||
|
||||
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda()
|
||||
torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
|
||||
|
||||
# create optimizer
|
||||
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
|
||||
|
@ -133,7 +143,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
|
|||
|
||||
seed_all(1453 + local_rank)
|
||||
# create
|
||||
input_data = torch.rand(32, 128).cuda()
|
||||
input_data = torch.rand(32, 123).cuda()
|
||||
|
||||
# zero-dp forward
|
||||
zero_output = zero_model(input_data.to(dtype))
|
||||
|
@ -143,17 +153,20 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
|
|||
loose_close(zero_output, torch_output, dtype=dtype)
|
||||
|
||||
# zero-dp backward
|
||||
zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
|
||||
zero_optimizer.backward(zero_output.mean().float())
|
||||
|
||||
# torch-ddp backward
|
||||
torch_output.mean().backward()
|
||||
|
||||
# check grad
|
||||
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
loose_close(p.grad, z1p.grad, dtype=dtype)
|
||||
if p.grad is not None:
|
||||
zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
|
||||
torch_grad_list = split_ddp_grad(p.grad, world_size)
|
||||
for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
|
||||
loose_close(zero_grad, torch_grad, dtype=dtype)
|
||||
|
||||
# zero-dp step
|
||||
zero_optimizer._sync_grad()
|
||||
zero_optimizer.step()
|
||||
|
||||
# torch ddp step
|
||||
|
@ -161,14 +174,13 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
|
|||
|
||||
# check updated param
|
||||
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
# print(n, torch.max(torch.abs(p.data - z1p.data)))
|
||||
loose_close(p.data, z1p.data, dtype=dtype)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
|
||||
exam_zero_1_torch_ddp()
|
||||
exam_zero_1_torch_ddp(world_size=world_size)
|
||||
exam_zero_1_2()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue