[zero] refactor low level zero to shard evenly (#4030)

* refactor low level zero

* fix zero2 and support cpu offload

* avg gradient and modify unit test

* refactor grad store, support layer drop

* refactor bucket store, support grad accumulation

* fix and update unit test of zero and ddp

* compatible with tp, ga and unit test

* fix memory leak and polish

* add zero layer drop unittest

* polish code

* fix import err in unit test

* support different comm dtype, modify docstring style

* polish code

* test padding and fix

* fix unit test of low level zero

* fix pad recording in bucket store

* support some models

* polish
pull/4359/head
LuGY 2023-06-30 15:30:50 +08:00 committed by Hongxin Liu
parent 5187c96b7c
commit c6ab96983a
8 changed files with 424 additions and 470 deletions
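
For reference, a minimal usage sketch of the refactored optimizer (not part of this commit). The import path and the launcher call are assumptions based on the test files below; the new grad_accumulate_interval argument is the one introduced in this diff, and ZeRO-2 (partition_grad=True) is not allowed to accumulate.

# run with torchrun; distributed init is assumed to be handled by colossalai
import torch
import torch.nn as nn

import colossalai
from colossalai.zero import LowLevelZeroOptimizer    # assumed import path

colossalai.launch_from_torch(config=dict())

model = nn.Linear(123, 512).cuda().half()
optimizer = LowLevelZeroOptimizer(torch.optim.SGD(model.parameters(), lr=1e-3),
                                  overlap_communication=True,
                                  partition_grad=False,            # ZeRO-1
                                  grad_accumulate_interval=2,      # new in this commit
                                  initial_scale=128,
                                  clip_grad_norm=1.0)

for micro_step in range(2):
    out = model(torch.randn(32, 123).cuda().half())
    # backward() counts accumulation steps internally and only reduces
    # gradient buckets on the last micro step of the interval
    optimizer.backward(out.sum().float())

optimizer.step()    # only performs a real update once the interval is reached
optimizer.zero_grad()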

View File

@@ -253,7 +253,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
    return total_norm


-def sync_param(flat_tensor, tensor_list):
+def sync_tensor(flat_tensor, tensor_list):
    """
    Synchronize the flattened tensor and unflattened tensor list. When
    a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`,

View File

@@ -1,3 +1,8 @@
+from typing import Dict
+
+import torch
+from torch import Tensor
+from torch._utils import _flatten_dense_tensors
 from torch.distributed import ProcessGroup

 from .base_store import BaseStore
@@ -7,35 +12,102 @@ class BucketStore(BaseStore):

    def __init__(self, torch_pg: ProcessGroup):
        super().__init__(torch_pg)
-        self._params = dict()
-        self._num_elements_in_bucket = dict()
+
+        # init and reset
+        self.current_group_id = 0
+        # mapping gradient slices and parameters
+        self.grad_to_param_mapping = dict()
+
+        self._param_list = []
+        self._padding_size = []

        self.reset()

-    def num_elements_in_bucket(self, reduce_rank: int = None):
-        return self._num_elements_in_bucket[reduce_rank]
-
-    def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
-        self._num_elements_in_bucket[reduce_rank] += num_elements
-
-    def add_param(self, tensor, reduce_rank: int = None):
-        self._params[reduce_rank].append(tensor)
+    def num_elements_in_bucket(self) -> int:
+        """Return the total number of elements in the bucket
+
+        Returns:
+            int: the total number of elements in the bucket
+        """
+
+        return self._num_elements_in_bucket
+
+    def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
+        """Add a param to the bucket and record the padding size of the param for gradient padding
+
+        Args:
+            group_id (int): The index of a parameter group
+            param (Tensor): The parameter
+            padding_size (int): The padding size of the parameter
+        """
+
+        self._param_list.append(param)
+        self._padding_size.append(padding_size)
+        self._num_elements_in_bucket += (param.numel() + padding_size)
+        self.current_group_id = group_id
+
+    def build_grad_in_bucket(self):
+        """Organize the parameters' gradients (pad and split), following the parameters' splitting method
+
+        Data structure of self._grad_in_bucket:
+        {
+         rank0: [grad0_rank0, grad1_rank0, ...]
+         rank1: [grad0_rank1, grad1_rank1, ...]
+        }
+        """
+
+        for param, padding_size in zip(self._param_list, self._padding_size):
+            with torch.no_grad():
+                grad = param.grad.detach().flatten()
+                if padding_size > 0:
+                    grad = torch.nn.functional.pad(grad, [0, padding_size])
+                grad_list = grad.split(grad.numel() // self._world_size)
+                for rank in range(self._world_size):
+                    grad_current_rank = grad_list[rank].detach()
+                    self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
+                    self._grad_in_bucket[rank].append(grad_current_rank)
+            param.grad = None
+
+    def get_grad(self) -> Dict:
+        """Return the dictionary of gradient slices, of which the keys are ranks
+
+        Returns:
+            Dict: The dictionary of gradient slices
+        """
+
+        return self._grad_in_bucket
+
+    def get_flatten_grad(self) -> Tensor:
+        """Return the flattened gradient slices in the bucket; the data organization of the flattened tensor:
+        [grad0_rank0, grad1_rank0, ..., grad0_rank1, grad1_rank1, ...]
+
+        Returns:
+            Tensor: the flattened gradient slices in the bucket
+        """
+
+        flat_grad = []
+        for grad_list in self._grad_in_bucket.values():
+            flat_grad.append(_flatten_dense_tensors(grad_list))
+        flat_grad = _flatten_dense_tensors(flat_grad)
+        return flat_grad
+
+    def get_param_id_of_grad(self, grad: Tensor) -> int:
+        """Return the id of the parameter which the gradient slice belongs to
+
+        Args:
+            grad (Tensor): the gradient slice
+
+        Returns:
+            int: the id of the parameter which the gradient slice belongs to
+        """
+
+        return self.grad_to_param_mapping[id(grad)]

    def reset(self):
-        keys = [None] + list(range(self._world_size))
-        self._params = {rank: [] for rank in keys}
-        self._num_elements_in_bucket = {rank: 0 for rank in keys}
-
-    def reset_by_rank(self, reduce_rank=None):
-        self._params[reduce_rank] = []
-        self._num_elements_in_bucket[reduce_rank] = 0
-
-    def get_grad(self, reduce_rank: int = None):
-        param_list = self.get_param(reduce_rank)
-        for param in param_list:
-            # the param must have grad for reduction
-            assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
-        return [param.grad for param in param_list]
-
-    def get_param(self, reduce_rank: int = None):
-        return self._params[reduce_rank]
+        self.grad_to_param_mapping = dict()
+        self._num_elements_in_bucket = 0
+        self._param_list = []
+        self._padding_size = []
+        self._grad_in_bucket = dict()
+        for rank in range(self._world_size):
+            self._grad_in_bucket[rank] = []
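
To illustrate the pad-and-split scheme that build_grad_in_bucket applies to each gradient, here is a small stand-alone sketch (not part of the commit), assuming a hypothetical world size of 4 and a 10-element gradient.

import torch
from torch._utils import _flatten_dense_tensors

world_size = 4
grad = torch.arange(10, dtype=torch.float32)    # stands in for param.grad.flatten()

padding_size = (world_size - grad.numel() % world_size) % world_size    # 2 here
if padding_size > 0:
    grad = torch.nn.functional.pad(grad, [0, padding_size])

grad_slices = grad.split(grad.numel() // world_size)    # 4 slices of 3 elements
assert len(grad_slices) == world_size

# get_flatten_grad concatenates rank by rank (all of rank0's slices, then rank1's, ...);
# with a single parameter this reduces to concatenating the slices in rank order
flat = _flatten_dense_tensors(list(grad_slices))
print([s.tolist() for s in grad_slices])
print(flat.shape)    # torch.Size([12])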

View File

@@ -1,88 +1,92 @@
 from typing import List

 from torch import Tensor
+from torch._utils import _flatten_dense_tensors

 from .base_store import BaseStore


 class GradientStore(BaseStore):

-    def __init__(self, *args):
+    def __init__(self, *args, partition_grad: bool = False):
        super().__init__(*args)
-        # bookkeeping data structures
-        self._averaged_gradients = dict()
-
-        # for backward reduction hooks
-        self._grad_acc_objs = []
-
-    def append_accumulate_grad_object(self, obj):
-        """
-        Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
-        be attached successfully.
-
-        :param obj: An object of :class:`AccumulateGrad` class
-        :type obj: :class:`AccumulateGrad`
-        """
-
-        self._grad_acc_objs.append(obj)
-
-    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
-        """
-        Return average gradients of a parameter group
-
-        :param group_id: The index of parameter group
-        :type group_id: int
-
-        :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
-        :rtype: List[torch.Tensor]
-        """
-        if group_id not in self._averaged_gradients:
-            self._averaged_gradients[group_id] = []
-
-        return self._averaged_gradients[group_id]
-
-    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
-        """
-        Append an average gradient to the list of averaged gradients of a parameter group
-
-        :param group_id: The index of a parameter group
-        :param tensor: A :class:`torch.Tensor` object
-        :type group_id: int
-        :type tensor: torch.Tensor
-        """
-        if group_id in self._averaged_gradients:
-            self._averaged_gradients[group_id].append(tensor)
-        else:
-            self._averaged_gradients[group_id] = [tensor]
-
-    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
-        """
-        Add an average gradient to the list of averaged gradients of a parameter group
-
-        :param group_id: The index of a parameter group
-        :param tensor_idx: The index of a tensor in the list of averaged gradients
-        :param tensor: A :class:`torch.Tensor` object
-        :type group_id: int
-        :type tensor_idx: int
-        :type tensor: torch.Tensor
-        """
-        self._averaged_gradients[group_id][tensor_idx].add_(tensor)
-
-    def reset_average_gradients_by_group(self, group_id: int) -> None:
-        """
-        Reset the bookkeeping data structure for averaged gradients to an empty list
-
-        :param group_id: The index of a parameter group
-        :type group_id: int
-        """
-        self._averaged_gradients[group_id] = []
-
-    def reset_all_average_gradients(self) -> None:
-        """
-        Reset the bookkeeping data structure for averaged gradients to an empty list
-        """
-        self._averaged_gradients = dict()
+        """
+        self._grads_of_params: mapping from a parameter to its gradient slices
+        data structure:
+        {
+         group_id: {
+            param_id: [grad_rank0, grad_rank1, ...]
+          }
+        }
+        """
+        self._grads_of_params = dict()
+        # for zero2, it's `param_id: [grad_local_rank]`
+        self._working_index = 0 if partition_grad else self._local_rank
+
+    def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
+        """Return the list of gradient slices of a specific parameter
+
+        Args:
+            group_id (int): The index of a parameter group
+            param_id (int): The id of a parameter
+
+        Returns:
+            List: the list of gradient slices of a parameter.
+        """
+
+        if group_id in self._grads_of_params:
+            if param_id in self._grads_of_params[group_id]:
+                return self._grads_of_params[group_id][param_id]
+
+        # the param has no grad, for instance, in layer drop
+        return []
+
+    def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int):
+        """Append a gradient slice to the parameter's gradient slice list
+
+        Args:
+            grad (Tensor): The gradient slice to append to the list
+            group_id (int): The index of a parameter group
+            param_id (int): The id of a parameter
+        """
+
+        if group_id not in self._grads_of_params:
+            self._grads_of_params[group_id] = dict()
+        if param_id not in self._grads_of_params[group_id]:
+            self._grads_of_params[group_id][param_id] = [grad]
+        else:
+            self._grads_of_params[group_id][param_id].append(grad)
+
+    def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
+        """For old gradient accumulation, not in use now.
+        Add a gradient slice onto an existing slice of the parameter's gradient
+
+        Args:
+            grad (Tensor): The split gradient to add to the list
+            grad_idx (int): The index of the existing slice
+            group_id (int): The index of a parameter group
+            param_id (int): The id of a parameter
+        """
+
+        self._grads_of_params[group_id][param_id][grad_idx].add_(grad)
+
+    def get_working_grads_by_group_id(self, group_id: int) -> List:
+        """Return the list of working gradient slices in the group
+
+        Args:
+            group_id (int): The index of a parameter group
+
+        Returns:
+            List: the list of working gradient slices in the group
+        """
+
+        grad_list = []
+        for param_grads in self._grads_of_params[group_id].values():
+            grad_list.append(param_grads[self._working_index])
+
+        return grad_list
+
+    def reset_grads_by_group_id(self, group_id: int):
+        self._grads_of_params[group_id] = dict()
+
+    def reset_all_gradients(self):
+        self._grads_of_params = dict()
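
To make the nested bookkeeping structure concrete, a small sketch (not part of the commit) with made-up parameter ids and a world size of 2. For ZeRO-1 every rank keeps all world_size slices of each gradient and the "working" slice is the one at index local_rank; for ZeRO-2 (partition_grad) only the local slice is kept, so the working index is 0.

import torch

local_rank, partition_grad = 1, False
grads_of_params = {
    0: {                                              # group_id
        140230016: [torch.zeros(3), torch.ones(3)],   # param_id -> [grad_rank0, grad_rank1]
        140230128: [torch.zeros(5), torch.ones(5)],
    }
}
working_index = 0 if partition_grad else local_rank

working_grads = [slices[working_index] for slices in grads_of_params[0].values()]
print(len(working_grads))    # 2, one working slice per parameter in group 0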

View File

@@ -1,5 +1,3 @@
-from typing import List
-
 from torch import Tensor
 from torch.distributed import ProcessGroup

 from .base_store import BaseStore
@@ -10,88 +8,43 @@ class ParameterStore(BaseStore):

    def __init__(self, torch_pg: ProcessGroup):
        super().__init__(torch_pg)
-        # param partitioning data structures
-        self._param_to_rank = dict()
-        self._rank_group_id_to_param_list = dict()
-        self._rank_group_id_to_flat_param = dict()
-
-        # param reduction data structures
-        self._is_param_reduced = dict()
-        self._reduced_param = []
-
-    def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
-        """
-        Set the mapping between parameter to rank, each parameter should be owned by a rank.
-
-        :param tensor: A :class:`torch.Tensor` object
-        :type tensor: torch.Tensor
-        :param rank: The rank of which the process is responsible for updating the parameter
-        :type rank: int
-        """
-        self._param_to_rank[tensor] = rank
-
-    def get_param_rank(self, tensor: Tensor) -> int:
-        """
-        Gives the rank which the parameter belongs to
-
-        :param tensor: A :class:`torch.Tensor` object
-        :type tensor: torch.Tensor
-        """
-        return self._param_to_rank[tensor]
-
-    def belongs_to_current_rank(self, tensor) -> bool:
-        """
-        Check whether a parameter is supposed to be updated by the process of the current rank
-
-        :param tensor: A :class:`torch.Tensor` object
-        :type tensor: torch.Tensor
-
-        :return: True if the parameter should be updated by the current rank. Otherwise false.
-        :rtype: bool
-        """
-        tensor_rank = self._param_to_rank[tensor]
-        return tensor_rank == self._local_rank
-
-    def add_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
-        if rank not in self._rank_group_id_to_param_list:
-            self._rank_group_id_to_param_list[rank] = dict()
-
-        if group_id not in self._rank_group_id_to_param_list[rank]:
-            self._rank_group_id_to_param_list[rank][group_id] = []
-
-        self._rank_group_id_to_param_list[rank][group_id].extend(tensor_list)
-
-    def get_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
-        return self._rank_group_id_to_param_list[rank][group_id]
-
-    def add_flat_param_by_rank_group(self, rank, group_id, tensor) -> None:
-        if rank not in self._rank_group_id_to_flat_param:
-            self._rank_group_id_to_flat_param[rank] = dict()
-
-        self._rank_group_id_to_flat_param[rank][group_id] = tensor
-
-    def get_flat_param_by_rank_group(self, rank, group_id) -> Tensor:
-        return self._rank_group_id_to_flat_param[rank][group_id]
-
-    def is_param_reduced(self, tensor):
-        return self._is_param_reduced[tensor]
-
-    def set_param_reduction_state(self, tensor, state):
-        self._is_param_reduced[tensor] = state
-
-    def get_param_reduction_states(self):
-        return self._is_param_reduced
-
-    def reset_previous_reduced_params(self):
-        self._reduced_param = []
-
-    def add_previous_reduced_param(self, tensor):
-        self._reduced_param.append(tensor)
-
-    def clear_grads_of_previous_reduced_params(self):
-        if len(self._reduced_param) > 0:
-            for param in self._reduced_param:
-                param.grad = None
-            self.reset_previous_reduced_params()
+
+        # record the padding size of each param
+        self._padding_map = dict()
+
+        # mapping between working params and master params
+        self.master_to_working_param = dict()
+        self.working_to_master_param = dict()
+
+    def record_param_padding_size(self, param: Tensor, padding_size: int):
+        """Record the padding size of a param
+
+        Args:
+            param (Tensor): The parameter
+            padding_size (int): The padding size of the parameter
+        """
+
+        self._padding_map[id(param)] = padding_size
+
+    def get_param_padding_size(self, param: Tensor) -> int:
+        """Return the padding size of the parameter
+
+        Args:
+            param (Tensor): The parameter
+
+        Returns:
+            int: the padding size of the parameter
+        """
+
+        return self._padding_map[id(param)]
+
+    def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
+        """Map a master parameter to its working parameter and back
+
+        Args:
+            master_param (Tensor): The parameter copy in the optimizer
+            working_param (Tensor): The parameter of the model
+        """
+
+        self.master_to_working_param[id(master_param)] = working_param
+        self.working_to_master_param[id(working_param)] = master_param
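
A short sketch (not part of the commit) of why the mapping is keyed by id(): the sharded fp32 master tensor and the full fp16 working tensor are different objects with different shapes, so identity is the only stable key. The sizes below assume a world size of 2 and a 123-element working parameter (padding 1, shard size 62).

import torch

working_param = torch.randn(123, dtype=torch.float16)    # model parameter (flattened view)
master_param = working_param.float()[:62].clone()         # this rank's padded fp32 shard

master_to_working = {id(master_param): working_param}
working_to_master = {id(working_param): master_param}

assert master_to_working[id(master_param)] is working_param
assert working_to_master[id(working_param)] is master_param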

View File

@@ -1,4 +1,5 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
+from contextlib import contextmanager
 from functools import partial
 from typing import Optional
@@ -16,6 +17,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor import ColoParameter, ProcessGroup
+from colossalai.utils import conditional_context
 from colossalai.utils.cuda import get_current_device

 from ._utils import (
@@ -23,12 +25,10 @@ from ._utils import (
    compute_norm,
    flatten,
    has_inf_or_nan,
-    reduce_tensor_dp_group,
    release_param_grad,
-    split_by_dtype,
-    sync_param,
+    sync_tensor,
 )

-from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
+from .bookkeeping import BucketStore, GradientStore, ParameterStore


 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -50,7 +50,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

    def check_local_overflow(self) -> bool:
        for group_id in range(self.num_working_param_groups):
-            for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
+            for avg_grad in self.grad_store.get_working_grads_by_group_id(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    return True
        return False
@@ -77,14 +77,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                 overlap_communication: bool = False,
                 partition_grad: bool = False,    # stage 2 flag
                 cpu_offload: bool = False,    # cpu offload
+                 grad_accumulate_interval: int = 1,
                 forced_dtype: Optional[torch.dtype] = None):

-        # TODO: add support for
-        # 1. fp16 master weights
-        # 2. contiguous gradients
-        # 3. cpu offload
-        # 4. support when some parameters requires_grad = False
-        # 5. support layer drop
+        assert not (partition_grad and grad_accumulate_interval > 1), \
+            "gradient accumulation is not compatible with ZeRO-2"
        super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
        self._dtype = self.optim.param_groups[0]['params'][0].dtype
        self._logger = get_dist_logger()
@@ -95,6 +92,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        self._cpu_offload = cpu_offload

+        # grad accumulation
+        self.require_grad_sync = True
+        self._accumulate_intervel = grad_accumulate_interval
+        self._accumulate_step = 0
+
        colo_pg = self._search_colo_process_group()
        if isinstance(colo_pg, ProcessGroup):
            self._local_rank = colo_pg.dp_local_rank()
@@ -122,7 +124,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

        # working and master params for mixed precision training
        self._working_param_groups = dict()
-        self._master_flat_param_groups_of_current_rank = dict()
+        self._master_param_groups_of_current_rank = dict()

        # communication params
        self._overlap_communication = overlap_communication
@@ -145,7 +147,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        # ParameterStore will manage the tensor buffers used for zero
        # it will not manage the tensors used by mixed precision training
        self._param_store = ParameterStore(self._dp_torch_group)
-        self._grad_store = GradientStore(self._dp_torch_group)
+        self._grad_store = GradientStore(self._dp_torch_group, partition_grad=partition_grad)
        self._bucket_store = BucketStore(self._dp_torch_group)

        # iterate over the param group in the optimizer
@@ -160,55 +162,17 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
            # add the working params to working_param_groups for bookkeeping
            self._working_param_groups[group_id] = group_params

-            # assign parameters to ranks
-            # the params in the list are sorted
-            params_per_rank = self._partition_param_list(group_params)
-
-            # store the mapping between param to rank
-            # each param should belong to only one rank
-            for rank, params in enumerate(params_per_rank):
-                self._param_store.add_param_list_by_rank_group(rank, group_id, params)
-                for param in params:
-                    self._param_store.set_param_to_rank(param, rank)
-
-            # move to cpu to make room to create the flat tensor
-            # move_tensor(params, device='cpu')
-            for param in group_params:
-                param.data = param.data.cpu()
-
-            # flatten the reordered tensors
-            for rank in range(self._world_size):
-                tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
-                with torch.no_grad():
-                    flat_tensor = flatten(tensor_list)
-                flat_tensor = flat_tensor.data.cuda()
-                self._param_store.add_flat_param_by_rank_group(rank, group_id, flat_tensor)
-
-            # sync parameters
-            for rank in range(self._world_size):
-                flat_tensor = self._param_store.get_flat_param_by_rank_group(rank, group_id)
-                tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
-                sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
-
-            # create a copy of fp32 master weights of the parameters for which this rank is responsible
-            working_flat_current_rank = self._param_store.get_flat_param_by_rank_group(self._local_rank, group_id)
-            master_flat_current_rank = working_flat_current_rank.float()
-            device = 'cpu' if self._cpu_offload else get_current_device()
-            master_flat_current_rank = master_flat_current_rank.to(device)
-            master_flat_current_rank.requires_grad = True
-            self._master_flat_param_groups_of_current_rank[group_id] = master_flat_current_rank
+            master_param_current_rank = self._create_master_param_current_rank(group_params)
+
+            self._master_param_groups_of_current_rank[group_id] = master_param_current_rank

            # need to replace the params in the `params` field in the optimizer
            # so that when the optimizer calls step(), it only updates the tensors
            # managed by this data parallel rank
-            param_group['params'] = [master_flat_current_rank]
-
-            # set reduction state
-            for param in self._working_param_groups[group_id]:
-                self._param_store.set_param_reduction_state(param, False)
+            param_group['params'] = master_param_current_rank

        # initialize communication stream for
        # communication-computation overlapping
        if self._overlap_communication:
            self._comm_stream = torch.cuda.Stream()
@@ -265,29 +229,36 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
        return colo_pg

-    def _partition_param_list(self, param_list):
-        params_per_rank = [[] for _ in range(self._world_size)]
-        numel_per_rank = [0 for _ in range(self._world_size)]
-
-        # partition the parameters in a greedy fashion
-        sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
-        for param in sorted_params:
-            # allocate this parameter to the rank with
-            # the smallest numel for load balancing purpose
-            rank_to_go = numel_per_rank.index(min(numel_per_rank))
-            params_per_rank[rank_to_go].append(param)
-            numel_per_rank[rank_to_go] += param.numel()
-
-        if self._verbose:
-            self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
-        return params_per_rank
+    def _create_master_param_current_rank(self, param_list):
+        # split each param evenly by world size
+        params_current_rank = []
+        device = 'cpu' if self._cpu_offload else get_current_device()
+
+        for param in reversed(param_list):
+            padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
+            self._param_store.record_param_padding_size(param, padding_size)
+
+            with torch.no_grad():
+                if padding_size > 0:
+                    padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+                else:
+                    padding_param = param.data.view(-1)
+                splited_params = padding_param.split(padding_param.numel() // self._world_size)
+
+                splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+                params_current_rank.append(splited_param_current_rank)
+                self._param_store.link_master_and_working_param(splited_param_current_rank, param)
+
+        return params_current_rank

    ###########################
    # Backward Reduction Hook #
    ###########################

-    def _grad_handler(self, param, grad, reduce_rank):
-        self._add_to_reduction_bucket(param, reduce_rank)
+    def _grad_handler(self, param, group_id, grad):
+        # if run with no_sync context, would not sync grad when backward
+        if self.require_grad_sync:
+            self._add_to_bucket(param, group_id)

        return grad

    def _attach_reduction_hook(self):
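
To illustrate the padding arithmetic used by _create_master_param_current_rank, a worked sketch (not part of the commit) for the 123x253 linear layer from the updated unit test, with an assumed world size of 4 and local rank 1.

import torch

world_size, local_rank = 4, 1
param = torch.nn.Linear(123, 253).weight.data    # 31119 elements

padding_size = (world_size - param.numel() % world_size) % world_size    # 31119 % 4 = 3 -> pad by 1
padded = torch.nn.functional.pad(param.view(-1), [0, padding_size])
shards = padded.split(padded.numel() // world_size)

assert padded.numel() == 31120 and all(s.numel() == 7780 for s in shards)
master_shard = shards[local_rank].detach().float()    # fp32 copy owned by this rank
print(master_shard.shape)    # torch.Size([7780])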
@@ -297,149 +268,96 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
            param_group = self._working_param_groups[group_id]
            for param in param_group:
                if param.requires_grad:
-                    # determines the reduction destination rank
-                    # this is only valid for stage 2
-                    # dst_rank = None means using all-reduce
-                    # else using reduce
-                    if self._partition_grads:
-                        reduce_rank = self._param_store.get_param_rank(param)
-                    else:
-                        reduce_rank = None
-
-                    param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
-
-    def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
-        if self._overlap_communication:
-            torch.cuda.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
-            stream = self._comm_stream
-        else:
-            stream = torch.cuda.current_stream()
-
-        with torch.cuda.stream(stream):
-            flat = bucket.flatten()
-            reduce_global_rank = None
-            if reduce_rank is not None:
-                reduce_global_rank = self._dp_global_ranks[reduce_rank]
-            reduced_flat = reduce_tensor_dp_group(tensor=flat,
-                                                  dtype=self._communication_dtype,
-                                                  dst_local_rank=reduce_rank,
-                                                  dst_global_rank=reduce_global_rank,
-                                                  group=self._dp_torch_group)
-
-            # update the reduced tensor
-            if reduce_rank is None or reduce_rank == self._local_rank:
-                bucket.unflatten_and_copy(reduced_flat)
-
-    def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
-        param_bucket = TensorBucket(size=bucket_size)
-
-        for tensor in tensor_list:
-            param_bucket.add_to_bucket(tensor, allow_oversize=True)
-            if param_bucket.is_full_or_oversized():
-                self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
-                param_bucket.empty()
-
-        if not param_bucket.is_empty():
-            self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
-
-    def _reduce_grads(self, reduce_rank, grads, bucket_size):
-        grad_buckets_by_dtype = split_by_dtype(grads)
-
-        for tensor_list in grad_buckets_by_dtype:
-            self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
-                                                    bucket_size=bucket_size,
-                                                    reduce_rank=reduce_rank)
+                    param.register_hook(partial(self._grad_handler, param, group_id))

    #######################
    # Reduction Functions #
    #######################

-    def _run_reduction(self, reduce_rank=None):
-        # reduce grads
-        self._reduce_grads(reduce_rank=reduce_rank,
-                           grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
-                           bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
-
-        # use communication stream if overlapping
-        # communication with computation
-        if self._overlap_communication:
-            stream = self._comm_stream
-        else:
-            stream = torch.cuda.current_stream()
-
-        with torch.cuda.stream(stream):
-            params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
-
-            for param in params_in_bucket:
-                # the is_param_reduced flag should be False showing that
-                # this param is not reduced before calling self._reduce_grads_by_rank
-                is_param_reduced = self._param_store.is_param_reduced(param)
-                if is_param_reduced:
-                    msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
-                        'duplicate reduction will lead to arithmetic incorrectness'
-                    raise RuntimeError(msg)
-
-                # update the flag
-                self._param_store.set_param_reduction_state(param, True)
-
-                # if partition grads = True
-                # we do not keep the gradient after reduction
-                if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
-                    if self._overlap_communication:
-                        # we need to keep this gradient for now as reduction may
-                        # be completed yet since it is using a different cuda stream
-                        self._param_store.add_previous_reduced_param(param)
-                    else:
-                        param.grad = None
-
-        self._bucket_store.reset_by_rank(reduce_rank)
-
-    def _add_to_reduction_bucket(self, param, reduce_rank=None):
+    def _run_reduction(self):
+        if self._bucket_store.num_elements_in_bucket() > 0:
+            self._bucket_store.build_grad_in_bucket()
+            flat_grads = self._bucket_store.get_flatten_grad()
+            flat_grads /= self._world_size
+            if self._overlap_communication:
+                stream = self._comm_stream
+            else:
+                stream = torch.cuda.current_stream()
+            with torch.cuda.stream(stream):
+                group_id = self._bucket_store.current_group_id
+
+                grad_dtype = flat_grads.dtype
+                if self._communication_dtype is not None:
+                    flat_grads = flat_grads.to(self._communication_dtype)
+
+                if not self._partition_grads:
+                    dist.all_reduce(flat_grads, group=self._dp_torch_group)
+                    if flat_grads.dtype != grad_dtype:
+                        flat_grads = flat_grads.to(grad_dtype)
+
+                    flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
+                    grad_in_bucket = self._bucket_store.get_grad()
+
+                    for rank, grad_list in grad_in_bucket.items():
+                        sync_tensor(flat_grads_per_rank[rank], grad_list)
+                        for grad in grad_list:
+                            param_id = self._bucket_store.get_param_id_of_grad(grad)
+                            self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+                else:
+                    flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
+                    recieved_grad = torch.zeros_like(flat_grads_list[0])
+                    dist.reduce_scatter(recieved_grad, flat_grads_list, group=self._dp_torch_group)
+
+                    if recieved_grad.dtype != grad_dtype:
+                        recieved_grad = recieved_grad.to(grad_dtype)
+
+                    grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
+                    sync_tensor(recieved_grad, grad_in_bucket_current_rank)
+                    for grad in grad_in_bucket_current_rank:
+                        param_id = self._bucket_store.get_param_id_of_grad(grad)
+                        self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+
+            self._bucket_store.reset()
+
+    def _add_to_bucket(self, param, group_id):
        param_size = param.numel()

        # check if the bucket is full
        # if full, will reduce the grads already in the bucket
+        # or got a grad of param from another group
        # after reduction, the bucket will be empty
-        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._run_reduction(reduce_rank)
-
-        # the param must not be reduced to ensure correctness
-        is_param_reduced = self._param_store.is_param_reduced(param)
-        if is_param_reduced:
-            msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
-                + 'duplicate reduction will lead to arithmetic incorrectness'
-            raise RuntimeError(msg)
-
-        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
-        self._bucket_store.add_param(param, reduce_rank)
+        if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
+                group_id != self._bucket_store.current_group_id:
+            self._run_reduction()
+
+        padding_size = self._param_store.get_param_padding_size(param)
+        self._bucket_store.add_param_grad(group_id, param, padding_size)

    ################################
    # torch.optim.Optimizer methods
    ################################

-    def backward(self, loss, retain_graph=False, sync_grad=True):
+    def backward(self, loss, retain_graph=False):
        if self.mixed_precision_mixin is not None:
            loss = self.mixed_precision_mixin.pre_backward(loss)

-        loss.backward(retain_graph=retain_graph)
-
-        # finish gradient reduction
-        if not self._partition_grads:
-            self._reduce_grad_stage1()
-        else:
-            # TODO: support async comm in reduce
-            self._reduce_grad_stage2()
+        self._accumulate_step += 1
+        no_sync = self._accumulate_step < self._accumulate_intervel
+        with conditional_context(self.no_sync(), enable=no_sync):
+            loss.backward(retain_graph=retain_graph)
+
+        if no_sync:
+            return
+
+        self._reduce_grad(self._partition_grads)

        # clear reduced grads
        if self._overlap_communication:
            torch.cuda.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
-
-        # gradient synchronization
-        if sync_grad:
-            self._sync_grad()
+
+        self.zero_grad()
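
A stand-alone sketch (not part of the commit) of the two communication patterns _run_reduction chooses between, written with plain torch.distributed calls on a fake flattened bucket. It assumes the default process group is already initialized and that world_size divides the bucket length, which the padding above guarantees.

import torch
import torch.distributed as dist

def reduce_bucket(flat_grads: torch.Tensor, partition_grad: bool, world_size: int) -> torch.Tensor:
    flat_grads = flat_grads / world_size    # average before communicating
    if not partition_grad:
        # ZeRO-1: every rank ends up with the full averaged bucket and copies
        # its per-rank slices back into the gradient lists with sync_tensor
        dist.all_reduce(flat_grads)
        return flat_grads
    else:
        # ZeRO-2: each rank only receives the slice it owns
        chunks = list(flat_grads.split(flat_grads.numel() // world_size))
        received = torch.zeros_like(chunks[0])
        dist.reduce_scatter(received, chunks)
        return received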
    def zero_grad(self, set_to_none=True):
        """
@@ -467,68 +385,86 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    def step(self, closure=None):
        assert closure is None, 'closure is not supported by step()'

+        if not self._accumulate_step == self._accumulate_intervel:
+            return
+
        if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
-            self._grad_store.reset_all_average_gradients()
+            self._grad_store.reset_all_gradients()
            if self._verbose:
                self._logger.info(f'Found overflow. Skip step')
            self.zero_grad()
+            self._accumulate_step -= 1
            return

-        # copy the grad of working param to master param
-        single_grad_partition_groups = []
+        # record all grads for unscale and clip
+        grad_partition_groups = []
        norm_groups = []

+        # sometimes not all params are 'really' working
+        # for instance, when layer drop, the dropped layer has no grad
+        # and should not be updated
+        real_working_params = dict()
+        real_master_params = dict()
+        grad_index = 0 if self._partition_grads else self._local_rank
+
        for group_id in range(self.num_param_groups):
+            master_params = self._master_param_groups_of_current_rank[group_id]
+            real_working_params[group_id] = []
+            real_master_params[group_id] = []
+            for splited_param in master_params:
+                working_param = self._param_store.master_to_working_param[id(splited_param)]
+                # if a working param requires grad and has no grad
+                # it is not 'really' working, e.g. the dropped layer
+                # else the splited grad should be attached to the splited param
+                grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
+                if len(grads) > 0:
+                    real_working_params[group_id].append(working_param)
+                    grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+                    splited_param.grad = grad
+                    grad_partition_groups.append(grad)
+                    real_master_params[group_id].append(splited_param)
+
            # compute norm
-            norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id),
-                                      params=self._param_store.get_params_by_rank_group(group_id=group_id,
-                                                                                        rank=self._local_rank),
+            working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
+            norm_group = compute_norm(gradients=working_grads,
+                                      params=real_working_params[group_id],
                                      dp_group=self._dp_torch_group,
                                      mp_group=self._mp_torch_group)
            norm_groups.append(norm_group)

-            # create flat gradient for the flat fp32 master params
-            working_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
-            flat_working_avg_grads = flatten(working_avg_grads)
-
-            dtype = self._master_flat_param_groups_of_current_rank[group_id].dtype
-            flat_master_avg_grads = flat_working_avg_grads.to(dtype)
-
-            param_shape = self._master_flat_param_groups_of_current_rank[group_id].shape
-            assert param_shape == flat_master_avg_grads.shape, \
-                f'fp32 param and grad have different shape {param_shape} vs {flat_master_avg_grads.shape}'
-
-            single_grad_partition_groups.append(flat_master_avg_grads)
-            device = self._master_flat_param_groups_of_current_rank[group_id].device
-            self._master_flat_param_groups_of_current_rank[group_id].grad = flat_master_avg_grads.to(device)
-            self._grad_store.reset_average_gradients_by_group(group_id)
+            self._grad_store.reset_grads_by_group_id(group_id)
+
+            # update the params in the optimizer
+            self.optim.param_groups[group_id]['params'] = real_master_params[group_id]

        # unscale and clip grads
        global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
-        self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
+        self._unscale_and_clip_grads(grad_partition_groups, global_norm)

        # update the parameters
        self.optim.step()
-        # release the master grad
-        release_param_grad(self._master_flat_param_groups_of_current_rank.values())
+
+        # release the grad
+        grad_partition_groups = []
+        for group_id in range(self.num_param_groups):
+            release_param_grad(self._master_param_groups_of_current_rank[group_id])

        # update working partition updated by the current rank
-        for group_id in range(len(self._working_param_groups)):
-            working_param = self._param_store.get_flat_param_by_rank_group(rank=self._local_rank, group_id=group_id)
-            master_param = self._master_flat_param_groups_of_current_rank[group_id]
-            working_param.data.copy_(master_param)
-
-        # broadcast the updated model weights
-        handles = []
        for group_id in range(self.num_param_groups):
-            for index in range(self._world_size):
-                rank = self._dp_global_ranks[index]
-                working_param = self._param_store.get_flat_param_by_rank_group(rank=index, group_id=group_id)
-                handle = dist.broadcast(working_param, src=rank, group=self._dp_torch_group, async_op=True)
-                handles.append(handle)
-
-        for handle in handles:
-            handle.wait()
+            master_working_param = self.optim.param_groups[group_id]['params']
+
+            for idx, splited_param in enumerate(master_working_param):
+                full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
+                dist.all_gather(full_master_param, splited_param.cuda(), group=self._dp_torch_group)
+                working_param = real_working_params[group_id][idx]
+                full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
+                working_param.data.copy_(full_master_param)
+
+            self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
+
+        # reset accumulate step
+        self._accumulate_step = 0
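
To show the parameter reconstruction pattern used in step() in isolation, a small sketch (not part of the commit) that gathers every rank's updated master shard and writes it back into the working parameter. It assumes an initialized default process group; in the optimizer the flatten() helper wraps _flatten_dense_tensors.

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors

def gather_working_param(master_shard: torch.Tensor, working_param: torch.Tensor, world_size: int):
    # collect every rank's updated shard of this parameter
    shards = [torch.zeros_like(master_shard) for _ in range(world_size)]
    dist.all_gather(shards, master_shard)
    # concatenate, drop the padding added at shard time, restore the original shape
    full = _flatten_dense_tensors(shards)[:working_param.numel()].reshape_as(working_param)
    working_param.data.copy_(full)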
    #############################
    # Mixed Precision Utilities #
@@ -553,49 +489,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    ############################
    # Gradient Synchronization #
    ############################

-    def _sync_grad(self):
-        # update param already reduced flag
-        reduction_states = self._param_store.get_param_reduction_states()
-        for tensor, _ in reduction_states.items():
-            reduction_states[tensor] = False
-
-        # accumulate gradient
-        for group_id in range(self.num_param_groups):
-            param_group = self._param_store.get_params_by_rank_group(self._local_rank, group_id)
-
-            avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)
-
-            param_idx = 0
-            for param in param_group:
-                if param.grad is not None:
-                    if len(avg_gradients_group) == param_idx:
-                        self._grad_store.append_average_gradient_by_group(group_id, param.grad)
-                    else:
-                        self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
-                    param_idx += 1
-
-        # the gradients needed are stored in the avg_gradients buffer
-        # thus, can clear this
-        self.zero_grad()
-
-    def _reduce_grad_stage1(self):
-        # if not overlapping communication (no reduction hook is attached)
+    def _reduce_grad(self, partition_grad):
+        # if not overlapping communication (no reduction hook is attached) when zero1
        # we need to manually reduce these gradients
-        if not self._overlap_communication:
+        if not partition_grad and not self._overlap_communication:
            for group_id in range(len(self._working_param_groups)):
                param_group = self._working_param_groups[group_id]
                for param in param_group:
                    if param.grad is not None:
-                        self._add_to_reduction_bucket(param)
+                        self._add_to_bucket(param, group_id)

-        # we need to reduce the gradients
-        # left in the communication bucket
+        # run reduction
        self._run_reduction()

-    def _reduce_grad_stage2(self):
-        # when partition_grads is True, reduction hooks
-        # are attached in the __init__ function, so we
-        # only need to reduce the gradients
-        # left in the communication bucket
-        for reduce_rank in range(self._world_size):
-            self._run_reduction(reduce_rank)
+    # this context comes from pytorch DDP
+    @contextmanager
+    def no_sync(self):
+        old_require_grad_sync = self.require_grad_sync
+        self.require_grad_sync = False
+        try:
+            yield
+        finally:
+            self.require_grad_sync = old_require_grad_sync
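
The no_sync() context mirrors the one in torch DDP: a flag toggle guarded by try/finally so the sync state is restored even on error. A stripped-down, self-contained stand-in (not the colossalai implementation) shows the behaviour the reduction hook relies on.

from contextlib import contextmanager

class _SyncFlagHolder:
    def __init__(self):
        self.require_grad_sync = True

    @contextmanager
    def no_sync(self):
        old = self.require_grad_sync
        self.require_grad_sync = False
        try:
            yield
        finally:
            self.require_grad_sync = old

holder = _SyncFlagHolder()
with holder.no_sync():
    assert holder.require_grad_sync is False    # hooks would skip _add_to_bucket here
assert holder.require_grad_sync is True         # restored after the context exits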

View File

@@ -11,14 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo

# These models are not compatible with AMP
-_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
+_AMP_ERR_MODELS = ['timm_convit', 'deepfm_interactionarch']
# These models have no parameters
-_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch']
-# These models will get stuck
-_STUCK_MODELS = [
-    'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
-    'transformers_bert_for_pretraining', 'transformers_gpt_double_heads'
-]
+_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']


def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
@@ -58,7 +53,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
    """
    passed_models = []
    failed_info = {}    # (model_name, error) pair
-    ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
+    ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS
    skipped_models = []

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():

View File

@@ -39,37 +39,37 @@ def exam_zero_1_2_grad_acc():
                                            overlap_communication=True,
                                            initial_scale=32,
                                            clip_grad_norm=1.0,
+                                            grad_accumulate_interval=2,
                                            verbose=True)
    zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                            overlap_communication=True,
                                            partition_grad=True,
                                            initial_scale=32,
-                                            clip_grad_norm=1.0)
+                                            clip_grad_norm=1.0,
+                                            grad_accumulate_interval=2)
    # create data
    seed_all(2021 + local_rank)
    input_data1 = torch.randn(32, 128).cuda()
    input_data2 = torch.randn(32, 128).cuda()

-    def fwd_bwd_func(number, cur_data):
+    def fwd_bwd_func(number, cur_data, check_flag):
        # zero-dp forward
        zero1_output = zero1_model(cur_data)
        zero2_output = zero2_model(cur_data)
        assert torch.equal(zero1_output, zero2_output)

        # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
-        zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)
-
-        for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
-            if z2p.grad is not None:
-                # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
-                assert torch.equal(z1p.grad, z2p.grad)
-
-        zero1_optimizer._sync_grad()
-        zero2_optimizer._sync_grad()
-
-    fwd_bwd_func(0, input_data1)
-    fwd_bwd_func(1, input_data2)
+        zero1_optimizer.backward(zero1_output.sum().float())
+        zero2_optimizer.backward(zero2_output.sum().float())
+
+        if check_flag:
+            for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
+                if z2p.grad is not None:
+                    # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
+                    assert torch.equal(z1p.grad, z2p.grad)
+
+    fwd_bwd_func(0, input_data1, True)
+    fwd_bwd_func(1, input_data2, False)

    # step
    zero1_optimizer.step()
@@ -101,7 +101,8 @@ def exam_zero_1_grad_acc():
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=False,
                                           reduce_bucket_size=262144,
-                                           clip_grad_norm=1.0)
+                                           clip_grad_norm=1.0,
+                                           grad_accumulate_interval=2)

    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@@ -115,12 +116,18 @@ def exam_zero_1_grad_acc():
        zero_output = zero_model(cur_data)

        # torch-ddp forward
-        torch_output = torch_model(cur_data)
-        assert torch.equal(zero_output, torch_output)

        # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)
+        zero_optimizer.backward(zero_output.sum().float())

        # torch-ddp backward
-        torch_output.sum().backward()
+        if number < 1:
+            with torch_model.no_sync():
+                torch_output = torch_model(cur_data)
+                assert torch.equal(zero_output, torch_output)
+                torch_output.sum().backward()
+        else:
+            torch_output = torch_model(cur_data)
+            assert torch.equal(zero_output, torch_output)
+            torch_output.sum().backward()

        if check_flag:
@@ -129,8 +136,6 @@ def exam_zero_1_grad_acc():
                # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
                assert torch.equal(p.grad, z1p.grad)

-        zero_optimizer._sync_grad()
-
    fwd_bwd_func(0, input_data1, True)
    fwd_bwd_func(1, input_data2, False)
@@ -148,7 +153,8 @@ def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    exam_zero_1_grad_acc()
-    exam_zero_1_2_grad_acc()
+    # gradient accumulation is not compatible with ZeRO-2
+    # exam_zero_1_2_grad_acc()


@pytest.mark.dist

View File

@@ -2,6 +2,7 @@ import copy

import pytest
import torch
+import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
@@ -16,8 +17,9 @@ class MlpModel(nn.Module):

    def __init__(self):
        super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(128, 256)
-        self.linear2 = nn.Linear(256, 512)
+        self.linear1 = nn.Linear(123, 253)
+        self.linear_drop = nn.Linear(253, 253)
+        self.linear2 = nn.Linear(253, 512)

    def forward(self, x):
        x = self.linear1(x)
@@ -41,6 +43,16 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
    assert_close(a, b, rtol=rtol, atol=atol)


+def split_ddp_grad(grad, world_size):
+    with torch.no_grad():
+        grad = grad.clone().detach().flatten()
+        padding_size = (world_size - grad.numel() % world_size) % world_size
+        if padding_size > 0:
+            grad = torch.nn.functional.pad(grad, [0, padding_size])
+        splited_grad = grad.split(grad.numel() // world_size)
+    return splited_grad
+
+
def exam_zero_1_2():
    """
    In this test, we want to test whether zero stage 1 and 2
@@ -72,23 +84,21 @@ def exam_zero_1_2():
                                            initial_scale=128)
    # create data
    seed_all(2001 + local_rank)
-    input_data = torch.randn(32, 128).cuda()
+    input_data = torch.randn(32, 123).cuda()

    zero1_output = zero1_model(input_data)
    zero2_output = zero2_model(input_data)
    assert torch.equal(zero1_output, zero2_output)

    # zero-dp backward
-    zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
-    zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)
-
-    for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
-        if z2p.grad is not None:
-            # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
-            assert torch.equal(z1p.grad, z2p.grad)
-
-    zero1_optimizer._sync_grad()
-    zero2_optimizer._sync_grad()
+    zero1_optimizer.backward(zero1_output.mean().float())
+    zero2_optimizer.backward(zero2_output.mean().float())
+
+    # check grad
+    z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
+    z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
+    for z1g, z2g in zip(z1g_list, z2g_list):
+        assert torch.equal(z1g, z2g)

    # step
    zero1_optimizer.step()
@@ -100,7 +110,7 @@ def exam_zero_1_2():

@parameterize('dtype', [torch.float16, torch.bfloat16])
-def exam_zero_1_torch_ddp(dtype: torch.dtype):
+def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype):
    """
    In this test, two pairs of model and optimizers are created.
    1. zero: use sharded optimizer and fp16 parameters
@@ -116,7 +126,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):

    torch_model = MlpModel().cuda()
    zero_model = copy.deepcopy(torch_model).to(dtype)
-    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda()
+    torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()

    # create optimizer
    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
@@ -133,7 +143,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
    seed_all(1453 + local_rank)
    # create
-    input_data = torch.rand(32, 128).cuda()
+    input_data = torch.rand(32, 123).cuda()

    # zero-dp forward
    zero_output = zero_model(input_data.to(dtype))
@@ -143,17 +153,20 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
    loose_close(zero_output, torch_output, dtype=dtype)

    # zero-dp backward
-    zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
+    zero_optimizer.backward(zero_output.mean().float())

    # torch-ddp backward
    torch_output.mean().backward()

    # check grad
    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        loose_close(p.grad, z1p.grad, dtype=dtype)
+        if p.grad is not None:
+            zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
+            torch_grad_list = split_ddp_grad(p.grad, world_size)
+            for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
+                loose_close(zero_grad, torch_grad, dtype=dtype)

    # zero-dp step
-    zero_optimizer._sync_grad()
    zero_optimizer.step()

    # torch ddp step
@@ -161,14 +174,13 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):

    # check updated param
    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        # print(n, torch.max(torch.abs(p.data - z1p.data)))
        loose_close(p.data, z1p.data, dtype=dtype)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
-    exam_zero_1_torch_ddp()
+    exam_zero_1_torch_ddp(world_size=world_size)
    exam_zero_1_2()