mirror of https://github.com/hpcaitech/ColossalAI
[zero] migrate zero1&2 (#1878)
* add zero1&2 optimizer
* rename test directory
* rename test files
* change tolerance in test
pull/1893/head
parent cc55ff0aa4
commit 6e51d296f0
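The new LowLevelZeroOptimizer below implements both stages: ZeRO-1 shards only the optimizer states across the data-parallel group, while ZeRO-2 additionally shards the gradients, selected via its partition_grad flag. A minimal construction sketch, assuming colossalai.launch has already initialized the distributed context and that `model` and the learning rate are placeholders:

import torch
from colossalai.zero import LowLevelZeroOptimizer

optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# partition_grad=False -> ZeRO-1 (shard optimizer states only)
# partition_grad=True  -> ZeRO-2 (shard gradients as well)
optim = LowLevelZeroOptimizer(optim, overlap_communication=True, partition_grad=True)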
@@ -2,9 +2,11 @@ from typing import Tuple

import torch
import torch.nn as nn

from colossalai.logging import get_dist_logger
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
-from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
+from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2

from .zero_optimizer import ZeroOptimizer

@@ -36,4 +38,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model

    return zero_model, zero_optimizer

-__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
+__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
@@ -1,3 +1,4 @@
+from .low_level_optim import LowLevelZeroOptimizer
from .sharded_optim_v2 import ShardedOptimizerV2

-__all__ = ['ShardedOptimizerV2']
+__all__ = ['ShardedOptimizerV2', 'LowLevelZeroOptimizer']
@@ -0,0 +1,6 @@
from .bucket_store import BucketStore
from .gradient_store import GradientStore
from .parameter_store import ParameterStore
from .tensor_bucket import TensorBucket

__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']
@@ -0,0 +1,17 @@
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


class BaseStore:

    def __init__(self, dp_parallel_mode=ParallelMode.DATA):
        self._world_size = gpc.get_world_size(dp_parallel_mode)
        self._local_rank = gpc.get_local_rank(dp_parallel_mode)

    @property
    def world_size(self):
        return self._world_size

    @property
    def local_rank(self):
        return self._local_rank
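All bookkeeping stores below derive from BaseStore, which caches the data-parallel world size and local rank from the global context at construction time. A minimal sketch of what a subclass inherits, assuming colossalai.launch has already run:

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

store = BaseStore()    # defaults to ParallelMode.DATA
assert store.world_size == gpc.get_world_size(ParallelMode.DATA)
assert store.local_rank == gpc.get_local_rank(ParallelMode.DATA)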
@@ -0,0 +1,44 @@
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

from .base_store import BaseStore


class BucketStore(BaseStore):

    def __init__(self, dp_parallel_mode):
        super().__init__(dp_parallel_mode)
        self._grads = dict()
        self._params = dict()
        self._num_elements_in_bucket = dict()

        self.reset()

    def num_elements_in_bucket(self, reduce_rank: int = None):
        return self._num_elements_in_bucket[reduce_rank]

    def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
        self._num_elements_in_bucket[reduce_rank] += num_elements

    def add_grad(self, tensor, reduce_rank: int = None):
        self._grads[reduce_rank].append(tensor)

    def add_param(self, tensor, reduce_rank: int = None):
        self._params[reduce_rank].append(tensor)

    def reset(self):
        keys = [None] + list(range(self._world_size))
        self._grads = {rank: [] for rank in keys}
        self._params = {rank: [] for rank in keys}
        self._num_elements_in_bucket = {rank: 0 for rank in keys}

    def reset_by_rank(self, reduce_rank=None):
        self._grads[reduce_rank] = []
        self._params[reduce_rank] = []
        self._num_elements_in_bucket[reduce_rank] = 0

    def get_grad(self, reduce_rank: int = None):
        return self._grads[reduce_rank]

    def get_param(self, reduce_rank: int = None):
        return self._params[reduce_rank]
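BucketStore keys every buffer by reduce_rank: the key None holds tensors destined for an all-reduce (stage 1), while an integer key holds tensors to be reduced onto that rank (stage 2). A hedged usage sketch, assuming an initialized distributed context and a hypothetical gradient tensor g; the threshold mirrors the optimizer's reduce_bucket_size default:

import torch
from colossalai.context import ParallelMode

bucket = BucketStore(ParallelMode.DATA)
g = torch.randn(1024)                        # hypothetical gradient
bucket.add_grad(g)                           # reduce_rank=None -> all-reduce bucket
bucket.add_num_elements_in_bucket(g.numel())
if bucket.num_elements_in_bucket() > 500000000:
    ...                                      # reduce the bucket, then:
    bucket.reset_by_rank(None)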
@@ -0,0 +1,66 @@
from typing import List

from torch import Tensor

from .base_store import BaseStore


class GradientStore(BaseStore):

    def __init__(self, *args):
        super().__init__(*args)
        # bookkeeping data structures
        self._averaged_gradients = dict()

        # for backward reduction hooks
        self._grad_acc_objs = []

    def add_accumulate_grad_object(self, obj):
        """
        Keep :class:`AccumulateGrad` objects. If these objects are not kept, their
        reduction hooks may not be attached successfully.

        :param obj: An object of :class:`AccumulateGrad` class
        :type obj: :class:`AccumulateGrad`
        """
        self._grad_acc_objs.append(obj)

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        """
        Return the averaged gradients of a parameter group.

        :param group_id: The index of the parameter group
        :type group_id: int

        :return: The list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
        :rtype: List[torch.Tensor]
        """
        return self._averaged_gradients[group_id]

    def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        """
        Append an averaged gradient to the list of averaged gradients of a parameter group.

        :param group_id: The index of a parameter group
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor: torch.Tensor
        """
        if group_id in self._averaged_gradients:
            self._averaged_gradients[group_id].append(tensor)
        else:
            self._averaged_gradients[group_id] = [tensor]

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        """
        Reset the bookkeeping data structure for averaged gradients to an empty list.

        :param group_id: The index of a parameter group
        :type group_id: int
        """
        self._averaged_gradients[group_id] = []
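GradientStore accumulates the averaged (post-reduction) gradients per parameter group; step() later flattens exactly this list. A toy round trip, assuming an initialized distributed context and a hypothetical gradient g:

import torch
from colossalai.context import ParallelMode

store = GradientStore(ParallelMode.DATA)
g = torch.ones(4)                              # hypothetical averaged gradient
store.add_average_gradient_by_group(0, g)      # creates the list for group 0
store.add_average_gradient_by_group(0, 2 * g)  # appends to the same group
assert len(store.get_averaged_gradients_by_group(0)) == 2
store.reset_average_gradients_by_group(0)      # emptied again after step()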
@@ -0,0 +1,96 @@
from typing import List

from torch import Tensor

from .base_store import BaseStore


class ParameterStore(BaseStore):

    def __init__(self, dp_parallel_mode):
        super().__init__(dp_parallel_mode)
        # param partitioning data structures
        self._fp16_param_to_rank = dict()
        self._rank_groupid_to_fp16_param_list = dict()
        self._rank_group_id_to_flat_fp16_param = dict()

        # param reduction data structures
        self._is_param_reduced = dict()
        self._reduced_param = []

    def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
        """
        Set the mapping between a parameter and a rank; each parameter should be owned by exactly one rank.

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor
        :param rank: The rank of the process responsible for updating the parameter
        :type rank: int
        """
        self._fp16_param_to_rank[tensor] = rank

    def get_param_rank(self, tensor: Tensor) -> int:
        """
        Return the rank to which the parameter belongs.

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor
        """
        return self._fp16_param_to_rank[tensor]

    def belongs_to_current_rank(self, tensor) -> bool:
        """
        Check whether a parameter is supposed to be updated by the process of the current rank.

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor

        :return: True if the parameter should be updated by the current rank, otherwise False.
        :rtype: bool
        """
        tensor_rank = self._fp16_param_to_rank[tensor]
        return tensor_rank == self._local_rank

    def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
        if rank not in self._rank_groupid_to_fp16_param_list:
            self._rank_groupid_to_fp16_param_list[rank] = dict()

        if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
            self._rank_groupid_to_fp16_param_list[rank][group_id] = []

        self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list)

    def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
        return self._rank_groupid_to_fp16_param_list[rank][group_id]

    def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
        if rank not in self._rank_group_id_to_flat_fp16_param:
            self._rank_group_id_to_flat_fp16_param[rank] = dict()

        self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor

    def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
        return self._rank_group_id_to_flat_fp16_param[rank][group_id]

    def is_param_reduced(self, tensor):
        return self._is_param_reduced[tensor]

    def set_param_reduction_state(self, tensor, state):
        self._is_param_reduced[tensor] = state

    def get_param_reduction_states(self):
        return self._is_param_reduced

    def reset_previous_reduced_params(self):
        self._reduced_param = []

    def add_previous_reduced_param(self, tensor):
        self._reduced_param.append(tensor)

    def clear_grads_of_previous_reduced_params(self):
        if len(self._reduced_param) > 0:
            for param in self._reduced_param:
                param.grad = None
            self.reset_previous_reduced_params()
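ParameterStore records which rank owns each fp16 parameter and whether its gradient has already been reduced in the current iteration. A sketch of the ownership and reduction-state bookkeeping, assuming an initialized distributed context; p is a hypothetical parameter on a 2-rank group:

import torch
from colossalai.context import ParallelMode

store = ParameterStore(ParallelMode.DATA)
p = torch.nn.Parameter(torch.randn(8))
store.set_param_to_rank(p, 1)              # rank 1 owns and updates p
owns_p = store.belongs_to_current_rank(p)  # True only on rank 1
store.set_param_reduction_state(p, True)   # guards against duplicate reduction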
@@ -0,0 +1,53 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


class TensorBucket:

    def __init__(self, size):
        self._max_size = size
        self._current_size = 0
        self._bucket = []

    @property
    def max_size(self):
        return self._max_size

    @property
    def current_size(self):
        return self._current_size

    def is_full_or_oversized(self):
        return self._current_size >= self._max_size

    def is_empty(self):
        return len(self._bucket) == 0

    def add_to_bucket(self, tensor, allow_oversize=False):
        tensor_size = tensor.numel()

        if not allow_oversize and self.will_exceed_max_size(tensor_size):
            msg = f"The param bucket max size {self._max_size} is exceeded" \
                + f" by tensor (size {tensor_size})"
            raise RuntimeError(msg)

        self._bucket.append(tensor)
        self._current_size += tensor_size

    def will_exceed_max_size(self, tensor_size):
        expected_size = self._current_size + tensor_size
        return expected_size > self._max_size

    def get_bucket(self):
        return self._bucket

    def empty(self):
        self._bucket = []
        self._current_size = 0

    def flatten(self):
        return _flatten_dense_tensors(self._bucket)

    def unflatten_and_copy(self, flat_tensor):
        unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
        for old, new in zip(self._bucket, unflattened_tensor_list):
            old.copy_(new)
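TensorBucket groups same-dtype gradients so they can be reduced with a single communication call: flatten() produces one contiguous tensor, and unflatten_and_copy() scatters the reduced result back into the original tensors. A self-contained round trip (no distributed context needed), a sketch with toy tensors:

import torch

bucket = TensorBucket(size=64)
a, b = torch.ones(10), torch.full((6,), 4.0)
bucket.add_to_bucket(a)
bucket.add_to_bucket(b)
flat = bucket.flatten()           # one tensor of shape (16,)
flat.mul_(0.5)                    # stand-in for an averaged reduce result
bucket.unflatten_and_copy(flat)   # a now holds 0.5s, b holds 2.0s
bucket.empty()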
@@ -0,0 +1,583 @@
from functools import partial
from itertools import groupby

import torch
import torch.distributed as dist
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.cuda import get_current_device

from ._utils import (
    calculate_global_norm_from_list,
    compute_norm,
    flatten,
    get_grad_accumulate_object,
    has_inf_or_nan,
    reduce_tensor,
    release_param_grad,
    split_half_float_double,
    sync_param,
)
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket


class LowLevelZeroOptimizer(ColossalaiOptimizer):
    """Optimizer used for ZeRO-1 and ZeRO-2."""

    def __init__(
            self,
            optimizer: Optimizer,

            # grad scaler config
            initial_scale=2**32,
            min_scale=1,
            growth_factor=2,
            backoff_factor=0.5,
            growth_interval=1000,
            hysteresis=2,
            max_scale: int = 2**32,

            # grad clipping
            clip_grad_norm=2.0,
            verbose=False,

            # communication
            reduce_bucket_size=500000000,
            communication_dtype=torch.float16,
            overlap_communication=False,

            # stage 2
            partition_grad=False,
            dp_parallel_mode=ParallelMode.DATA,
            mp_parallel_mode=ParallelMode.MODEL,

            # cpu offload
            cpu_offload=False):

        # TODO: add support for
        # 1. fp16 master weights
        # 2. contiguous gradients
        # 3. cpu offload
        # 4. the case where some parameters have requires_grad = False

        self._optimizer = optimizer
        self._dtype = self._optimizer.param_groups[0]['params'][0].dtype
        self._logger = get_dist_logger()
        self._verbose = verbose

        # stage 2
        self._partition_grads = partition_grad

        # cpu offload
        self._cpu_offload = cpu_offload

        # get process groups
        self._dp_parallel_mode = dp_parallel_mode
        self._mp_parallel_mode = mp_parallel_mode
        self._local_rank = gpc.get_local_rank(dp_parallel_mode)
        self._world_size = gpc.get_world_size(dp_parallel_mode)

        self._dp_group = gpc.get_group(dp_parallel_mode)
        if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
            self._mp_group = gpc.get_group(mp_parallel_mode)
        else:
            self._mp_group = None

        # fp16 and fp32 params for mixed precision training
        self._fp16_param_groups = dict()
        self._fp32_flat_param_groups_of_current_rank = dict()

        # communication params
        self._overlap_communication = overlap_communication
        self._reduce_bucket_size = reduce_bucket_size
        self._communication_dtype = communication_dtype

        # gradient scaler
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale,
                                             verbose=verbose)
        self._found_overflow = torch.FloatTensor([0]).to(get_current_device())

        # gradient clipping
        self._clip_grad_norm = clip_grad_norm

        # check argument conflicts
        self._sanity_checks()

        # ParameterStore will manage the tensor buffers used for zero
        # it will not manage the tensors used by mixed precision training
        self._param_store = ParameterStore(self._dp_parallel_mode)
        self._grad_store = GradientStore(self._dp_parallel_mode)
        self._bucket_store = BucketStore(self._dp_parallel_mode)

        # iterate over the param groups in the optimizer,
        # partition these param groups for data parallel training
        # and add buffers to the parameter store for future access
        for group_id, param_group in enumerate(self._optimizer.param_groups):
            params = param_group['params']

            # add the fp16 params to fp16_param_groups for bookkeeping
            self._fp16_param_groups[group_id] = params

            # assign parameters to ranks
            # the params in the list are sorted
            params_per_rank = self._partition_param_list(params)

            # store the mapping between param and rank
            # each param should belong to only one rank
            for rank, params in enumerate(params_per_rank):
                self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
                for param in params:
                    self._param_store.set_param_to_rank(param, rank)

            # move to cpu to make room to create the flat tensor
            # move_tensor(params, device='cpu')
            for param in params:
                param.data = param.data.cpu()

            # flatten the reordered tensors
            for rank in range(self._world_size):
                tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
                flat_tensor = flatten(tensor_list)
                flat_tensor = flat_tensor.cuda()
                self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)

            # sync parameters
            for rank in range(self._world_size):
                flat_tensor = self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id)
                tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
                sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)

            # create a copy of the fp32 weights of the parameters for which this rank is responsible
            fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
            fp32_flat_current_rank = fp16_flat_current_rank.clone().float().detach()
            device = 'cpu' if self._cpu_offload else get_current_device()
            fp32_flat_current_rank = fp32_flat_current_rank.to(device)
            fp32_flat_current_rank.requires_grad = True
            self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank

            # need to replace the params in the `params` field in the optimizer
            # so that when the optimizer calls step(), it only updates the tensors
            # managed by this data parallel rank
            param_group['params'] = [fp32_flat_current_rank]

            # set reduction state
            for param in self._fp16_param_groups[group_id]:
                self._param_store.set_param_reduction_state(param, False)

        # initialize communication stream for
        # communication-computation overlapping
        if self._overlap_communication:
            self._comm_stream = torch.cuda.Stream()

        # reduction hook is only used if overlapping communication
        # or stage 2 is used
        # if it is stage 1 without overlapping, no hook will be attached
        if self._overlap_communication or self._partition_grads:
            self._attach_reduction_hook()

        self._initialize_optimizer_states()

    @property
    def loss_scale(self):
        return self.grad_scaler.scale

    @property
    def num_param_groups(self):
        return len(self._fp16_param_groups)

    def _partition_param_list(self, param_list):
        params_per_rank = [[] for _ in range(self._world_size)]
        numel_per_rank = [0 for _ in range(self._world_size)]

        # partition the parameters in a greedy fashion
        sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
        for param in sorted_params:
            # allocate this parameter to the rank with
            # the smallest numel for load balancing purposes
            rank_to_go = numel_per_rank.index(min(numel_per_rank))
            params_per_rank[rank_to_go].append(param)
            numel_per_rank[rank_to_go] += param.numel()

        if self._verbose:
            self._logger.info(f'Number of elements on ranks: {numel_per_rank}',
                              ranks=[0],
                              parallel_mode=self._dp_parallel_mode)
        return params_per_rank

    def _initialize_optimizer_states(self):
        # create a dummy zero tensor which has the same shape as that of the param
        # set this dummy zero tensor as the grad
        for group_id in range(len(self._fp32_flat_param_groups_of_current_rank)):
            fp32_partition_param = self._fp32_flat_param_groups_of_current_rank[group_id]
            fp32_partition_grad = torch.zeros_like(fp32_partition_param)
            fp32_partition_param.grad = fp32_partition_grad

        # update the parameters with zero gradients for initialization of optimizer states
        self._optimizer.step()

        # remove the grad of the parameter to save memory
        for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
            fp32_flat_tensor.grad = None

    def _sanity_checks(self):
        assert torch.cuda.is_available(), 'CUDA is required'
        assert self._dtype == torch.float16, \
            f'Parameters are expected to be of type torch.float16, but got {self._dtype}'

    ###########################################################
    # Backward Reduction Hook
    ###########################################################

    def _attach_reduction_hook(self):
        # we iterate over the fp16 params
        # on each param, we register a hook to its AccumulateGrad object
        for group_id in range(self.num_param_groups):
            param_group = self._fp16_param_groups[group_id]
            for param in param_group:
                if param.requires_grad:
                    # determine the reduction destination rank
                    # this is only valid for stage 2
                    # dst_rank = None means using all-reduce
                    # else using reduce
                    if self._partition_grads:
                        reduce_rank = self._param_store.get_param_rank(param)
                    else:
                        reduce_rank = None

                    def _define_and_attach(param, reduce_rank):
                        # get the AccumulateGrad object of the param itself
                        accum_grad_obj = get_grad_accumulate_object(param)
                        self._grad_store.add_accumulate_grad_object(accum_grad_obj)

                        reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
                                                 param=param,
                                                 reduce_rank=reduce_rank)

                        # define hook
                        # NOT IMPORTANT BUT GOOD TO KNOW:
                        # args here is not grad, but allow_unreachable and accumulate_grad
                        def reduce_grad_hook(*args):
                            reduction_func()

                        accum_grad_obj.register_hook(reduce_grad_hook)

                    _define_and_attach(param, reduce_rank)

    def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
        param_size = param.numel()

        # check if the bucket is full
        # if full, will reduce the grads already in the bucket
        # after reduction, the bucket will be empty
        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
            self._reduce_grads_in_bucket(reduce_rank)

        # the param must not have been reduced before, to ensure correctness
        is_param_reduced = self._param_store.is_param_reduced(param)
        if is_param_reduced:
            msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
                + 'duplicate reduction will lead to arithmetic incorrectness'
            raise RuntimeError(msg)

        # the param must have a grad for reduction
        assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'

        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
        self._bucket_store.add_grad(param.grad, reduce_rank)
        self._bucket_store.add_param(param, reduce_rank)

    def _reduce_grads_in_bucket(self, reduce_rank=None):
        # reduce grads
        self._reduce_grads_by_rank(reduce_rank=reduce_rank,
                                   grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
                                   bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))

        # use communication stream if overlapping
        # communication with computation
        if self._overlap_communication:
            stream = self._comm_stream
        else:
            stream = torch.cuda.current_stream()

        with torch.cuda.stream(stream):
            params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)

            for param in params_in_bucket:
                # the is_param_reduced flag should be False, showing that
                # this param was not reduced before calling self._reduce_grads_by_rank
                is_param_reduced = self._param_store.is_param_reduced(param)

                if is_param_reduced:
                    msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
                        'duplicate reduction will lead to arithmetic incorrectness'
                    raise RuntimeError(msg)

                # update the flag
                self._param_store.set_param_reduction_state(param, True)

                # if partition_grads = True
                # we do not keep the gradient after reduction
                if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
                    if self._overlap_communication:
                        # we need to keep this gradient for now as the reduction may
                        # not be completed yet, since it is using a different cuda stream
                        self._param_store.add_previous_reduced_param(param)
                    else:
                        param.grad = None

        self._bucket_store.reset_by_rank(reduce_rank)

    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
        grad_buckets_by_dtype = split_half_float_double(grads)

        for tensor_list in grad_buckets_by_dtype:
            self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)

    ##############################
    # Reduction Utility Function #
    ##############################
    def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
        param_bucket = TensorBucket(size=bucket_size)

        for tensor in tensor_list:
            param_bucket.add_to_bucket(tensor, allow_oversize=True)

            if param_bucket.is_full_or_oversized():
                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
                param_bucket.empty()

        if not param_bucket.is_empty():
            self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)

    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
        if self._overlap_communication:
            torch.cuda.synchronize()
            self._param_store.clear_grads_of_previous_reduced_params()
            stream = self._comm_stream
        else:
            stream = torch.cuda.current_stream()

        with torch.cuda.stream(stream):
            flat = bucket.flatten()
            reduced_flat = reduce_tensor(tensor=flat,
                                         dtype=self._communication_dtype,
                                         dst_rank=reduce_rank,
                                         parallel_mode=self._dp_parallel_mode)

            # update the reduced tensor
            if reduce_rank is None or reduce_rank == self._local_rank:
                bucket.unflatten_and_copy(reduced_flat)

    ################################
    # torch.optim.Optimizer methods
    ################################

    def backward(self, loss, retain_graph=True):
        loss = self.loss_scale * loss
        loss.backward(retain_graph=retain_graph)

    def zero_grad(self, set_to_none=True):
        """
        Set parameter gradients to zero. If set_to_none = True, the gradient
        will be set to None to save memory.

        :param set_to_none: Whether to set the gradient to None. Default value is True.
        :type set_to_none: bool
        """
        for group_id, param_group in self._fp16_param_groups.items():
            for param in param_group:
                if set_to_none:
                    param.grad = None
                else:
                    if param.grad is not None:
                        param.grad.detach_()
                        param.grad.zero_()

    ####################
    # Update Parameter #
    ####################

    def step(self, closure=None):
        assert closure is None, 'closure is not supported by step()'

        # check for overflow
        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)

        # update loss scale if overflow occurs
        if found_inf:
            self._grad_store._averaged_gradients = dict()
            self.zero_grad()
            return

        # copy the grad of the fp16 param to the fp32 param
        single_grad_partition_groups = []
        norm_groups = []

        for group_id in range(self.num_param_groups):
            # compute norm
            norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
                                      params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
                                                                                             rank=self._local_rank),
                                      dp_group=self._dp_group,
                                      mp_group=self._mp_group)
            norm_groups.append(norm_group)

            # create flat gradients for the flat fp32 params
            fp16_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
            flat_fp16_avg_grads = flatten(fp16_avg_grads)

            dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
            flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)

            param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
            assert param_shape == flat_fp32_avg_grads.shape, \
                f'fp32 param and grad have different shapes: {param_shape} vs {flat_fp32_avg_grads.shape}'

            single_grad_partition_groups.append(flat_fp32_avg_grads)
            device = self._fp32_flat_param_groups_of_current_rank[group_id].device
            self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
            self._grad_store._averaged_gradients[group_id] = []

        # unscale and clip grads
        global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
        self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)

        # update the parameters
        self._optimizer.step()
        # release the fp32 grads
        release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())

        # update the fp16 partition updated by the current rank
        for group_id in range(len(self._fp16_param_groups)):
            fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id)
            fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id].to(fp16_param.device)
            fp16_param.data.copy_(fp32_param)

        # broadcast the updated model weights
        handles = []
        for group_id in range(self.num_param_groups):
            for rank in range(self._world_size):
                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
                handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
                handles.append(handle)

        for handle in handles:
            handle.wait()

    ##################
    # FP16 Utilities #
    ##################

    def _check_overflow(self):
        # clear previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow
        for group_id in range(len(self._fp16_param_groups)):
            for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    self._found_overflow.fill_(1.0)
                    break

        # all-reduce across dp group
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)

        # all-reduce over model parallel group
        if self._mp_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)

        if self._found_overflow.item() > 0:
            return True
        else:
            return False

    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
        # compute the combined scale factor for this group
        combined_scale = self.loss_scale

        if self._clip_grad_norm > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
            if clip > 1:
                combined_scale = clip * self.loss_scale

        for grad in grad_groups_flat:
            grad.data.mul_(1. / combined_scale)

    ############################
    # Gradient Synchronization #
    ############################

    def sync_grad(self):
        if not self._partition_grads:
            self._reduce_grad_stage1()
        else:
            # TODO: support async comm in reduce
            self._reduce_grad_stage2()

        # update the param-already-reduced flags
        reduction_states = self._param_store.get_param_reduction_states()
        for tensor, state in reduction_states.items():
            reduction_states[tensor] = False

        # clear reduced grads
        if self._overlap_communication:
            torch.cuda.synchronize()
            self._param_store.clear_grads_of_previous_reduced_params()

        # accumulate gradients
        avg_gradients = self._grad_store._averaged_gradients
        for group_id in range(self.num_param_groups):
            param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)

            if group_id not in avg_gradients:
                avg_gradients[group_id] = []

            param_idx = 0
            for param in param_group:
                if param.grad is not None:
                    if len(avg_gradients[group_id]) == param_idx:
                        avg_gradients[group_id].append(param.grad)
                    else:
                        avg_gradients[group_id][param_idx].add_(param.grad)
                    param_idx += 1

        # the gradients needed are stored in the avg_gradients buffer
        # thus, we can clear the grads on the params
        self.zero_grad()

    def _reduce_grad_stage1(self):
        # if not overlapping communication (no reduction hook is attached)
        # we need to manually reduce these gradients
        if not self._overlap_communication:
            for group_id in range(len(self._fp16_param_groups)):
                param_group = self._fp16_param_groups[group_id]
                for param in param_group:
                    if param.grad is not None:
                        self._reduce_and_remove_grads_by_bucket(param)

        # we need to reduce the gradients
        # left in the communication bucket
        self._reduce_grads_in_bucket()

    def _reduce_grad_stage2(self):
        # when partition_grads is True, reduction hooks
        # are attached in the __init__ function, so we
        # only need to reduce the gradients
        # left in the communication bucket
        for reduce_rank in range(self._world_size):
            self._reduce_grads_in_bucket(reduce_rank)
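The intended per-iteration flow, mirrored by the test below: backward() scales the loss before autograd runs, sync_grad() moves the reduced gradients into the GradientStore buffer, and step() unscales, clips, updates the fp32 master partition, and broadcasts the new fp16 weights. A hedged sketch, assuming model, data, and optim (a LowLevelZeroOptimizer) already exist:

output = model(data)          # fp16 forward
loss = output.mean().float()  # upcast before scaling
optim.backward(loss)          # applies the dynamic loss scale
optim.sync_grad()             # reduce + accumulate averaged gradients
optim.step()                  # unscale, clip, update, broadcast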
@@ -0,0 +1,185 @@
import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer


def check_equal(a, b):
    """
    This function checks if two tensors are equal within tolerance.
    """
    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'


def check_completely_equal(a, b):
    """
    This function checks if two tensors are completely equal.
    """
    assert torch.all(a == b), f'a = {a}, b = {b}'


def check_sharded_param_consistency():
    """
    In this test, we want to test whether zero stage 1 and stage 2
    deliver the same numerical results despite their different communication
    patterns.

    We use these prefixes to differentiate the zero stages:
    oss: partition optimizer states
    pg: partition gradients and optimizer states
    """

    # create layers
    oss_linear1 = nn.Linear(128, 256)
    oss_linear2 = nn.Linear(256, 512)

    # create models
    oss_model = nn.Sequential(oss_linear1, oss_linear2)
    pg_model = copy.deepcopy(oss_model)

    oss_model = oss_model.cuda().half()
    pg_model = pg_model.cuda().half()

    # create optimizers
    oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
    pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
    oss_optimizer = LowLevelZeroOptimizer(oss_optimizer,
                                          overlap_communication=True,
                                          initial_scale=1,
                                          clip_grad_norm=0.0)
    pg_optimizer = LowLevelZeroOptimizer(pg_optimizer,
                                         overlap_communication=True,
                                         partition_grad=True,
                                         initial_scale=1,
                                         clip_grad_norm=0.0)

    # create input data
    input_data = torch.rand(32, 128).cuda().half()

    # forward
    oss_output = oss_model(input_data)
    pg_output = pg_model(input_data)
    check_completely_equal(oss_output, pg_output)

    # backward
    oss_optimizer.backward(oss_output.mean().float())
    pg_optimizer.backward(pg_output.mean().float())

    # check grad
    # as these params are small, the backward reduction
    # will not be fired
    oss_linear1_grad = oss_model[0].weight.grad
    oss_linear2_grad = oss_model[1].weight.grad
    pg_linear1_grad = pg_model[0].weight.grad
    pg_linear2_grad = pg_model[1].weight.grad
    check_completely_equal(oss_linear1_grad, pg_linear1_grad)
    check_completely_equal(oss_linear2_grad, pg_linear2_grad)

    # sync grad
    oss_optimizer.sync_grad()
    pg_optimizer.sync_grad()

    # step
    oss_optimizer.step()
    pg_optimizer.step()

    # check updated params
    check_completely_equal(oss_model[0].weight, pg_model[0].weight)
    check_completely_equal(oss_model[1].weight, pg_model[1].weight)


def check_sharded_optim_against_torch_ddp():
    """
    In this test, two pairs of models and optimizers are created.
    1. zero: use sharded optimizer and fp16 parameters
    2. torch: use torch DDP and fp32 parameters

    We feed these two sets of models with the same input and check if the
    differences in model outputs and updated parameters are within tolerance.
    """

    # create layers
    zero_linear1 = nn.Linear(128, 256)
    zero_linear2 = nn.Linear(256, 512)

    # create models
    zero_model = nn.Sequential(zero_linear1, zero_linear2)
    torch_model = copy.deepcopy(zero_model)

    zero_model = zero_model.cuda().half()
    torch_model = DDP(torch_model.cuda())

    # create optimizer
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)

    # we only test stage 1 here
    # in `check_sharded_param_consistency.py`, we will test whether
    # stage 1 and 2 produce exactly the same results
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=True,
                                           initial_scale=1,
                                           clip_grad_norm=0.0)

    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

    # create input data
    input_data = torch.rand(32, 128).cuda()

    # zero-dp forward
    zero_output = zero_model(input_data.half())

    # torch-ddp forward
    torch_output = torch_model(input_data)
    check_equal(zero_output, torch_output)

    # zero-dp backward
    zero_optimizer.backward(zero_output.mean().float())

    # torch-ddp backward
    torch_output.mean().backward()

    # check grad
    zero_linear1_grad = zero_model[0].weight.grad
    zero_linear2_grad = zero_model[1].weight.grad
    torch_linear1_grad = torch_model.module[0].weight.grad
    torch_linear2_grad = torch_model.module[1].weight.grad
    check_equal(zero_linear1_grad, torch_linear1_grad)
    check_equal(zero_linear2_grad, torch_linear2_grad)

    # zero-dp step
    zero_optimizer.sync_grad()
    zero_optimizer.step()

    # torch-ddp step
    torch_optimizer.step()

    # check updated params
    check_equal(zero_model[0].weight, torch_model.module[0].weight)
    check_equal(zero_model[1].weight, torch_model.module[1].weight)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    check_sharded_optim_against_torch_ddp()
    check_sharded_param_consistency()


@pytest.mark.dist
def test_sharded_optim():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim()