From 6e51d296f07c0ad34d7f85cf9a70d4ceee15ede7 Mon Sep 17 00:00:00 2001 From: HELSON Date: Fri, 11 Nov 2022 09:26:40 +0800 Subject: [PATCH] [zero] migrate zero1&2 (#1878) * add zero1&2 optimizer * rename test ditectory * rename test files * change tolerance in test --- colossalai/zero/__init__.py | 6 +- colossalai/zero/sharded_optim/__init__.py | 3 +- .../sharded_optim/bookkeeping/__init__.py | 6 + .../sharded_optim/bookkeeping/base_store.py | 17 + .../sharded_optim/bookkeeping/bucket_store.py | 44 ++ .../bookkeeping/gradient_store.py | 66 ++ .../bookkeeping/parameter_store.py | 96 +++ .../bookkeeping/tensor_bucket.py | 53 ++ .../zero/sharded_optim/low_level_optim.py | 583 ++++++++++++++++++ .../test_zero/low_level_zero/test_zero1_2.py | 185 ++++++ 10 files changed, 1056 insertions(+), 3 deletions(-) create mode 100644 colossalai/zero/sharded_optim/bookkeeping/__init__.py create mode 100644 colossalai/zero/sharded_optim/bookkeeping/base_store.py create mode 100644 colossalai/zero/sharded_optim/bookkeeping/bucket_store.py create mode 100644 colossalai/zero/sharded_optim/bookkeeping/gradient_store.py create mode 100644 colossalai/zero/sharded_optim/bookkeeping/parameter_store.py create mode 100644 colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py create mode 100644 colossalai/zero/sharded_optim/low_level_optim.py create mode 100644 tests/test_zero/low_level_zero/test_zero1_2.py diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 0e320f912..3a896322f 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -2,9 +2,11 @@ from typing import Tuple import torch import torch.nn as nn + from colossalai.logging import get_dist_logger from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 -from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2 +from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2 + from .zero_optimizer import ZeroOptimizer @@ -36,4 +38,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model return zero_model, zero_optimizer -__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer'] +__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer'] diff --git a/colossalai/zero/sharded_optim/__init__.py b/colossalai/zero/sharded_optim/__init__.py index b71a70aef..30c26fb75 100644 --- a/colossalai/zero/sharded_optim/__init__.py +++ b/colossalai/zero/sharded_optim/__init__.py @@ -1,3 +1,4 @@ +from .low_level_optim import LowLevelZeroOptimizer from .sharded_optim_v2 import ShardedOptimizerV2 -__all__ = ['ShardedOptimizerV2'] +__all__ = ['ShardedOptimizerV2', 'LowLevelZeroOptimizer'] diff --git a/colossalai/zero/sharded_optim/bookkeeping/__init__.py b/colossalai/zero/sharded_optim/bookkeeping/__init__.py new file mode 100644 index 000000000..7bcacfabf --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/__init__.py @@ -0,0 +1,6 @@ +from .bucket_store import BucketStore +from .gradient_store import GradientStore +from .parameter_store import ParameterStore +from .tensor_bucket import TensorBucket + +__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket'] diff --git a/colossalai/zero/sharded_optim/bookkeeping/base_store.py b/colossalai/zero/sharded_optim/bookkeeping/base_store.py new file mode 100644 index 000000000..d4436acaa --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/base_store.py @@ -0,0 +1,17 @@ +from 
colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class BaseStore: + + def __init__(self, dp_parallel_mode=ParallelMode.DATA): + self._world_size = gpc.get_world_size(dp_parallel_mode) + self._local_rank = gpc.get_local_rank(dp_parallel_mode) + + @property + def world_size(self): + return self._world_size + + @property + def local_rank(self): + return self._local_rank diff --git a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py new file mode 100644 index 000000000..0f2b1bb88 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py @@ -0,0 +1,44 @@ +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + +from .base_store import BaseStore + + +class BucketStore(BaseStore): + + def __init__(self, dp_parallel_mode): + super().__init__(dp_parallel_mode) + self._grads = dict() + self._params = dict() + self._num_elements_in_bucket = dict() + + self.reset() + + def num_elements_in_bucket(self, reduce_rank: int = None): + return self._num_elements_in_bucket[reduce_rank] + + def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None): + self._num_elements_in_bucket[reduce_rank] += num_elements + + def add_grad(self, tensor, reduce_rank: int = None): + self._grads[reduce_rank].append(tensor) + + def add_param(self, tensor, reduce_rank: int = None): + self._params[reduce_rank].append(tensor) + + def reset(self): + keys = [None] + list(range(self._world_size)) + self._grads = {rank: [] for rank in keys} + self._params = {rank: [] for rank in keys} + self._num_elements_in_bucket = {rank: 0 for rank in keys} + + def reset_by_rank(self, reduce_rank=None): + self._grads[reduce_rank] = [] + self._params[reduce_rank] = [] + self._num_elements_in_bucket[reduce_rank] = 0 + + def get_grad(self, reduce_rank: int = None): + return self._grads[reduce_rank] + + def get_param(self, reduce_rank: int = None): + return self._params[reduce_rank] diff --git a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py new file mode 100644 index 000000000..8a9128a18 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py @@ -0,0 +1,66 @@ +from typing import List + +from torch import Tensor + +from .base_store import BaseStore + + +class GradientStore(BaseStore): + + def __init__(self, *args): + super().__init__(*args) + # bookkeeping data structures + self._averaged_gradients = dict() + + # for backward reduction hooks + self._grad_acc_objs = [] + + def add_accumulate_grad_object(self, obj): + """ + Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not + be attached successfully. + + :param obj: An object of :class:`AccumulateGrad` class + :type obj: :class:`AccumulateGrad` + """ + + self._grad_acc_objs.append(obj) + + def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]: + """ + Return average gradients of a parameter group + + :param group_id: The index of parameter group + :type group_id: int + + :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter. 
+ :rtype: List[torch.Tensor] + """ + + return self._averaged_gradients[group_id] + + def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None: + """ + Append an average gradient to the list of averaged gradients of a parameter group + + :param group_id: The index of a parameter group + :param tensor: A :class:`torch.Tensor` object + :type group_id: int + :type tensor: torch.Tensor + + """ + + if group_id in self._averaged_gradients: + self._averaged_gradients[group_id].append(tensor) + else: + self._averaged_gradients[group_id] = [tensor] + + def reset_average_gradients_by_group(self, group_id: int) -> None: + """ + Reset the bookkeeping data structure for averaged gradients to an empty list + + :param group_id: The index of a parameter group + :type group_id: int + """ + + self._averaged_gradients[group_id] = [] diff --git a/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py b/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py new file mode 100644 index 000000000..09ebaaf99 --- /dev/null +++ b/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py @@ -0,0 +1,96 @@ +from typing import List + +from torch import Tensor + +from .base_store import BaseStore + + +class ParameterStore(BaseStore): + + def __init__(self, dp_paralle_mode): + super().__init__(dp_paralle_mode) + # param partitioning data structures + self._fp16_param_to_rank = dict() + self._rank_groupid_to_fp16_param_list = dict() + self._rank_group_id_to_flat_fp16_param = dict() + + # param reduction data structures + self._is_param_reduced = dict() + self._reduced_param = [] + + def set_param_to_rank(self, tensor: Tensor, rank: int) -> None: + """ + Set the mapping between parameter to rank, each parameter should be owned by a rank. + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + :param rank: The rank of which the process is responsible for updating the parameter + :type rank: int + """ + + self._fp16_param_to_rank[tensor] = rank + + def get_param_rank(self, tensor: Tensor) -> int: + """ + Gives the rank which the parameter belongs to + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + """ + return self._fp16_param_to_rank[tensor] + + def belongs_to_current_rank(self, tensor) -> bool: + """ + Check whether a parameter is supposed to be updated by the process of the current rank + + :param tensor: A :class:`torch.Tensor` object + :type tensor: torch.Tensor + + :return: True if the parameter should be updated by the current rank. Otherwise false. 
+        :rtype: bool
+        """
+
+        tensor_rank = self._fp16_param_to_rank[tensor]
+        return tensor_rank == self._local_rank
+
+    def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
+        if rank not in self._rank_groupid_to_fp16_param_list:
+            self._rank_groupid_to_fp16_param_list[rank] = dict()
+
+        if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
+            self._rank_groupid_to_fp16_param_list[rank][group_id] = []
+
+        self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list)
+
+    def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
+        return self._rank_groupid_to_fp16_param_list[rank][group_id]
+
+    def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
+        if rank not in self._rank_group_id_to_flat_fp16_param:
+            self._rank_group_id_to_flat_fp16_param[rank] = dict()
+
+        self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor
+
+    def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
+        return self._rank_group_id_to_flat_fp16_param[rank][group_id]
+
+    def is_param_reduced(self, tensor):
+        return self._is_param_reduced[tensor]
+
+    def set_param_reduction_state(self, tensor, state):
+        self._is_param_reduced[tensor] = state
+
+    def get_param_reduction_states(self):
+        return self._is_param_reduced
+
+    def reset_previous_reduced_params(self):
+        self._reduced_param = []
+
+    def add_previous_reduced_param(self, tensor):
+        self._reduced_param.append(tensor)
+
+    def clear_grads_of_previous_reduced_params(self):
+        if len(self._reduced_param) > 0:
+            for param in self._reduced_param:
+                param.grad = None
+            self.reset_previous_reduced_params()
diff --git a/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py b/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py
new file mode 100644
index 000000000..b32816a04
--- /dev/null
+++ b/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py
@@ -0,0 +1,53 @@
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+
+class TensorBucket:
+
+    def __init__(self, size):
+        self._max_size = size
+        self._current_size = 0
+        self._bucket = []
+
+    @property
+    def max_size(self):
+        return self._max_size
+
+    @property
+    def current_size(self):
+        return self._current_size
+
+    def is_full_or_oversized(self):
+        return self._current_size >= self._max_size
+
+    def is_empty(self):
+        return len(self._bucket) == 0
+
+    def add_to_bucket(self, tensor, allow_oversize=False):
+        tensor_size = tensor.numel()
+
+        if not allow_oversize and self.will_exceed_max_size(tensor_size):
+            msg = f"The param bucket max size {self._max_size} is exceeded" \
+                + f" by tensor (size {tensor_size})"
+            raise RuntimeError(msg)
+
+        self._bucket.append(tensor)
+        self._current_size += tensor_size
+
+    def will_exceed_max_size(self, tensor_size):
+        expected_size = self._current_size + tensor_size
+        return expected_size > self._max_size
+
+    def get_bucket(self):
+        return self._bucket
+
+    def empty(self):
+        self._bucket = []
+        self._current_size = 0
+
+    def flatten(self):
+        return _flatten_dense_tensors(self._bucket)
+
+    def unflatten_and_copy(self, flat_tensor):
+        unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
+        for old, new in zip(self._bucket, unflattened_tensor_list):
+            old.copy_(new)
diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py
new file mode 100644
index 000000000..a945a8481
--- /dev/null
+++ b/colossalai/zero/sharded_optim/low_level_optim.py
@@ -0,0 +1,583 @@
+from functools import partial
+from itertools import groupby + +import torch +import torch.distributed as dist +from torch.optim import Optimizer + +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.utils.cuda import get_current_device + +from ._utils import ( + calculate_global_norm_from_list, + compute_norm, + flatten, + get_grad_accumulate_object, + has_inf_or_nan, + reduce_tensor, + release_param_grad, + split_half_float_double, + sync_param, +) +from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket + + +class LowLevelZeroOptimizer(ColossalaiOptimizer): + """Optimizer used for ZeRO-1 and ZeRO-2. + """ + + def __init__( + self, + optimizer: Optimizer, + + # grad scaler config + initial_scale=2**32, + min_scale=1, + growth_factor=2, + backoff_factor=0.5, + growth_interval=1000, + hysteresis=2, + max_scale: int = 2**32, + + # grad clipping + clip_grad_norm=2.0, + verbose=False, + + # communication + reduce_bucket_size=500000000, + communication_dtype=torch.float16, + overlap_communication=False, + + # stage 2 + partition_grad=False, + dp_parallel_mode=ParallelMode.DATA, + mp_parallel_mode=ParallelMode.MODEL, + + # cpu offload + cpu_offload=False): + + # TODO: add support for + # 1. fp16 master weights + # 2. contiguous gradients + # 3. cpu offload + # 4. support when some parameters requires_grad = False + + self._optimizer = optimizer + self._dtype = self._optimizer.param_groups[0]['params'][0].dtype + self._logger = get_dist_logger() + self._verbose = verbose + + # stage 2 + self._partition_grads = partition_grad + + # cpu_offload + self._cpu_offload = cpu_offload + + # get process groups + self._dp_parallel_mode = dp_parallel_mode + self._mp_parallel_mode = mp_parallel_mode + self._local_rank = gpc.get_local_rank(dp_parallel_mode) + self._world_size = gpc.get_world_size(dp_parallel_mode) + + self._dp_group = gpc.get_group(dp_parallel_mode) + if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1: + self._mp_group = gpc.get_group(mp_parallel_mode) + else: + self._mp_group = None + + # fp16 and fp32 params for mixed precision training + self._fp16_param_groups = dict() + self._fp32_flat_param_groups_of_current_rank = dict() + + # communication params + self._overlap_communication = overlap_communication + self._reduce_bucket_size = reduce_bucket_size + self._communication_dtype = communication_dtype + + # gradient scaler + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + verbose=verbose) + self._found_overflow = torch.FloatTensor([0]).to(get_current_device()) + + # gradient clipping + self._clip_grad_norm = clip_grad_norm + + # check argument conflict + self._sanity_checks() + + # ParameterStore will manage the tensor buffers used for zero + # it will not manage the tensors used by mixed precision training + self._param_store = ParameterStore(self._dp_parallel_mode) + self._grad_store = GradientStore(self._dp_parallel_mode) + self._bucket_store = BucketStore(self._dp_parallel_mode) + + # iterate over the param group in the optimizer + # partition these param groups for data parallel training + # and add buffers to parameter store for future 
access
+        for group_id, param_group in enumerate(self._optimizer.param_groups):
+            params = param_group['params']
+
+            # add the fp16 params to fp16_param_groups for bookkeeping
+            self._fp16_param_groups[group_id] = params
+
+            # assign parameters to ranks
+            # the params in the list are sorted
+            params_per_rank = self._partition_param_list(params)
+
+            # store the mapping between param to rank
+            # each param should belong to only one rank
+            for rank, params in enumerate(params_per_rank):
+                self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
+                for param in params:
+                    self._param_store.set_param_to_rank(param, rank)
+
+            # move to cpu to make room to create the flat tensor
+            # move_tensor(params, device='cpu')
+            for param in params:
+                param.data = param.data.cpu()
+
+            # flatten the reordered tensors
+            for rank in range(self._world_size):
+                tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
+                flat_tensor = flatten(tensor_list)
+                flat_tensor = flat_tensor.cuda()
+                self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
+
+            # sync parameters
+            for rank in range(self._world_size):
+                flat_tensor = self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id)
+                tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
+                sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
+
+            # create a copy of fp32 weights of the parameters for which this rank is responsible
+            fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
+            fp32_flat_current_rank = fp16_flat_current_rank.clone().float().detach()
+            device = 'cpu' if self._cpu_offload else get_current_device()
+            fp32_flat_current_rank = fp32_flat_current_rank.to(device)
+            fp32_flat_current_rank.requires_grad = True
+            self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
+
+            # need to replace the params in the `params` field in the optimizer
+            # so that when the optimizer calls step(), it only updates the tensors
+            # managed by this data parallel rank
+            param_group['params'] = [fp32_flat_current_rank]
+
+            # set reduction state
+            for param in self._fp16_param_groups[group_id]:
+                self._param_store.set_param_reduction_state(param, False)
+
+        # initialize communication stream for
+        # communication-computation overlapping
+        if self._overlap_communication:
+            self._comm_stream = torch.cuda.Stream()
+
+        # reduction hook is only used if overlapping communication
+        # or stage 2 is used
+        # if it is stage 1 without overlapping, no hook will be attached
+        if self._overlap_communication or self._partition_grads:
+            self._attach_reduction_hook()
+
+        self._initialize_optimizer_states()
+
+    @property
+    def loss_scale(self):
+        return self.grad_scaler.scale
+
+    @property
+    def num_param_groups(self):
+        return len(self._fp16_param_groups)
+
+    def _partition_param_list(self, param_list):
+        params_per_rank = [[] for _ in range(self._world_size)]
+        numel_per_rank = [0 for _ in range(self._world_size)]
+
+        # partition the parameters in a greedy fashion
+        sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
+        for param in sorted_params:
+            # allocate this parameter to the rank with
+            # the smallest numel for load balancing purpose
+            rank_to_go = numel_per_rank.index(min(numel_per_rank))
+            params_per_rank[rank_to_go].append(param)
+            numel_per_rank[rank_to_go] += param.numel()
+
+        if self._verbose:
+            self._logger.info(f'Number of elements on ranks: {numel_per_rank}',
+                              ranks=[0],
+                              parallel_mode=self._dp_parallel_mode)
+        return params_per_rank
+
+    def _initialize_optimizer_states(self):
+        # create a dummy zero tensor which has the same shape as that of the param
+        # set this dummy zero tensor as grad
+        for group_id in range(len(self._fp32_flat_param_groups_of_current_rank)):
+            fp32_partition_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+            fp32_partition_grad = torch.zeros_like(fp32_partition_param)
+            fp32_partition_param.grad = fp32_partition_grad
+
+        # update the parameter with zero gradients for initialization of optimizer states
+        self._optimizer.step()
+
+        # remove the grad of the parameter to save memory
+        for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
+            fp32_flat_tensor.grad = None
+
+    def _sanity_checks(self):
+        assert torch.cuda.is_available(), 'CUDA is required'
+        assert self._dtype == torch.float16, \
+            f'Parameters are expected to be of type torch.float16, but got {self._dtype}'
+
+    ###########################################################
+    # Backward Reduction Hook
+    ###########################################################
+
+    def _attach_reduction_hook(self):
+        # we iterate over the fp16 params
+        # on each param, we register a hook to its AccumulateGrad object
+        for group_id in range(self.num_param_groups):
+            param_group = self._fp16_param_groups[group_id]
+            for param in param_group:
+                if param.requires_grad:
+                    # determines the reduction destination rank
+                    # this is only valid for stage 2
+                    # dst_rank = None means using all-reduce
+                    # else using reduce
+                    if self._partition_grads:
+                        reduce_rank = self._param_store.get_param_rank(param)
+                    else:
+                        reduce_rank = None
+
+                    def _define_and_attach(param, reduce_rank):
+                        # get the AccumulateGrad object of the param itself
+                        accum_grad_obj = get_grad_accumulate_object(param)
+                        self._grad_store.add_accumulate_grad_object(accum_grad_obj)
+
+                        reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
+                                                 param=param,
+                                                 reduce_rank=reduce_rank)
+
+                        # define hook
+                        # NOT IMPORTANT BUT GOOD TO KNOW:
+                        # the args here are not grads, but allow_unreachable and accumulate_grad
+                        def reduce_grad_hook(*args):
+                            reduction_func()
+
+                        accum_grad_obj.register_hook(reduce_grad_hook)
+
+                    _define_and_attach(param, reduce_rank)
+
+    def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
+        param_size = param.numel()
+
+        # check if the bucket is full
+        # if full, will reduce the grads already in the bucket
+        # after reduction, the bucket will be empty
+        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
+            self._reduce_grads_in_bucket(reduce_rank)
+
+        # the param must not have been reduced already, to ensure correctness
+        is_param_reduced = self._param_store.is_param_reduced(param)
+        if is_param_reduced:
+            msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+                + 'duplicate reduction will lead to arithmetic incorrectness'
+            raise RuntimeError(msg)
+
+        # the param must have grad for reduction
+        assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
+
+        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
+        self._bucket_store.add_grad(param.grad, reduce_rank)
+        self._bucket_store.add_param(param, reduce_rank)
+
+    def _reduce_grads_in_bucket(self, reduce_rank=None):
+        # reduce grads
+        self._reduce_grads_by_rank(reduce_rank=reduce_rank,
+                                   grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
+                                   bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
+
+        # use communication stream if overlapping
+        # communication with computation
+        if self._overlap_communication:
+            stream = self._comm_stream
+        else:
+            stream = torch.cuda.current_stream()
+
+        with torch.cuda.stream(stream):
+            params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
+
+            for param in params_in_bucket:
+                # the is_param_reduced flag should be False, showing that
+                # this param has not been reduced before calling self._reduce_grads_by_rank
+                is_param_reduced = self._param_store.is_param_reduced(param)
+
+                if is_param_reduced:
+                    msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
+                        'duplicate reduction will lead to arithmetic incorrectness'
+                    raise RuntimeError(msg)
+
+                # update the flag
+                self._param_store.set_param_reduction_state(param, True)
+
+                # if partition grads = True
+                # we do not keep the gradient after reduction
+                if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
+                    if self._overlap_communication:
+                        # we need to keep this gradient for now as reduction may
+                        # not be completed yet since it is using a different cuda stream
+                        self._param_store.add_previous_reduced_param(param)
+                    else:
+                        param.grad = None
+
+        self._bucket_store.reset_by_rank(reduce_rank)
+
+    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
+        grad_buckets_by_dtype = split_half_float_double(grads)
+
+        for tensor_list in grad_buckets_by_dtype:
+            self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
+
+    ##############################
+    # Reduction Utility Function #
+    ##############################
+    def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
+        param_bucket = TensorBucket(size=bucket_size)
+
+        for tensor in tensor_list:
+            param_bucket.add_to_bucket(tensor, allow_oversize=True)
+
+            if param_bucket.is_full_or_oversized():
+                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+                param_bucket.empty()
+
+        if not param_bucket.is_empty():
+            self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+
+    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
+        if self._overlap_communication:
+            torch.cuda.synchronize()
+            self._param_store.clear_grads_of_previous_reduced_params()
+            stream = self._comm_stream
+        else:
+            stream = torch.cuda.current_stream()
+
+        with torch.cuda.stream(stream):
+            flat = bucket.flatten()
+            reduced_flat = reduce_tensor(tensor=flat,
+                                         dtype=self._communication_dtype,
+                                         dst_rank=reduce_rank,
+                                         parallel_mode=self._dp_parallel_mode)
+
+            # update the reduced tensor
+            if reduce_rank is None or reduce_rank == self._local_rank:
+                bucket.unflatten_and_copy(reduced_flat)
+
+    ################################
+    # torch.optim.Optimizer methods
+    ################################
+
+    def backward(self, loss, retain_graph=True):
+        loss = self.loss_scale * loss
+        loss.backward(retain_graph=retain_graph)
+
+    def zero_grad(self, set_to_none=True):
+        """
+        Set parameter gradients to zero. If set_to_none = True, the gradient
+        will be set to None to save memory.
+
+        :param set_to_none: Whether to set the gradient to None. Default value is True.
+ :type set_to_none: bool + """ + for group_id, param_group in self._fp16_param_groups.items(): + for param in param_group: + if set_to_none: + param.grad = None + else: + if param.grad is not None: + param.grad.detach() + param.grad.zero_() + + #################### + # Update Parameter # + #################### + + def step(self, closure=None): + assert closure is None, 'closure is not supported by step()' + + # check for overflow + found_inf = self._check_overflow() + self.grad_scaler.update(found_inf) + + # update loss scale if overflow occurs + if found_inf: + self._grad_store._averaged_gradients = dict() + self.zero_grad() + return + + # copy the grad of fp16 param to fp32 param + single_grad_partition_groups = [] + norm_groups = [] + + for group_id in range(self.num_param_groups): + # compute norm + norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id], + params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id, + rank=self._local_rank), + dp_group=self._dp_group, + mp_group=self._mp_group) + norm_groups.append(norm_group) + + # create flat gradient for the flat fp32 params + fp16_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id) + flat_fp16_avg_grads = flatten(fp16_avg_grads) + + dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype + flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype) + + param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape + assert param_shape == flat_fp32_avg_grads.shape, \ + f'fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}' + + single_grad_partition_groups.append(flat_fp32_avg_grads) + device = self._fp32_flat_param_groups_of_current_rank[group_id].device + self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device) + self._grad_store._averaged_gradients[group_id] = [] + self._grad_store._averaged_gradients[group_id] = [] + + # unscale and clip grads + global_norm = calculate_global_norm_from_list(norm_list=norm_groups) + self._unscale_and_clip_grads(single_grad_partition_groups, global_norm) + + # update the parameters + self._optimizer.step() + # release the fp32 grad + release_param_grad(self._fp32_flat_param_groups_of_current_rank.values()) + + # update fp16 partition updated by the current rank + for group_id in range(len(self._fp16_param_groups)): + fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id) + fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id].to(fp16_param.device) + fp16_param.data.copy_(fp32_param) + + # broadcast the updated model weights + handles = [] + for group_id in range(self.num_param_groups): + for rank in range(self._world_size): + fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id) + handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True) + handles.append(handle) + + for handle in handles: + handle.wait() + + ################## + # FP16 Utilities # + ################## + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(0.0) + + # check for overflow + for group_id in range(len(self._fp16_param_groups)): + for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id): + if avg_grad is not None and has_inf_or_nan(avg_grad): + self._found_overflow.fill_(1.0) + break + + # all-reduce across dp group + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, 
group=self._dp_group) + + # all-reduce over model parallel group + if self._mp_group: + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group) + + if self._found_overflow.item() > 0: + return True + else: + return False + + def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): + # compute combined scale factor for this group + combined_scale = self.loss_scale + + if self._clip_grad_norm > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm + if clip > 1: + combined_scale = clip * self.loss_scale + + for grad in grad_groups_flat: + grad.data.mul_(1. / combined_scale) + + ############################ + # Gradient Synchronization # + ############################ + + def sync_grad(self): + if not self._partition_grads: + self._reduce_grad_stage1() + else: + # TODO: support async comm in reduce + self._reduce_grad_stage2() + + # update param already reduced flag + reduction_states = self._param_store.get_param_reduction_states() + for tensor, state in reduction_states.items(): + reduction_states[tensor] = False + + # clear reduced grads + if self._overlap_communication: + torch.cuda.synchronize() + self._param_store.clear_grads_of_previous_reduced_params() + + # accumulate gradient + avg_gradients = self._grad_store._averaged_gradients + for group_id in range(self.num_param_groups): + param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id) + + if group_id not in avg_gradients: + avg_gradients[group_id] = [] + + param_idx = 0 + for param in param_group: + if param.grad is not None: + if len(avg_gradients[group_id]) == param_idx: + avg_gradients[group_id].append(param.grad) + else: + avg_gradients[group_id][param_idx].add_(param.grad) + param_idx += 1 + + # the gradients needed are stored in the avg_gradients buffer + # thus, can clear this + self.zero_grad() + + def _reduce_grad_stage1(self): + # if not overlapping communication (no reduction hook is attached) + # we need to manually reduce these gradients + if not self._overlap_communication: + for group_id in range(len(self._fp16_param_groups)): + param_group = self._fp16_param_groups[group_id] + for param in param_group: + if param.grad is not None: + self._reduce_and_remove_grads_by_bucket(param) + + # we need to reduce the gradients + # left in the communication bucket + self._reduce_grads_in_bucket() + + def _reduce_grad_stage2(self): + # when partition_grads is True, reduction hooks + # are attached in the __init__ function, so we + # only need to reduce the gradients + # left in the communication bucket + for reduce_rank in range(self._world_size): + self._reduce_grads_in_bucket(reduce_rank) diff --git a/tests/test_zero/low_level_zero/test_zero1_2.py b/tests/test_zero/low_level_zero/test_zero1_2.py new file mode 100644 index 000000000..8a510daaf --- /dev/null +++ b/tests/test_zero/low_level_zero/test_zero1_2.py @@ -0,0 +1,185 @@ +import copy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.utils import free_port +from colossalai.zero import LowLevelZeroOptimizer + + +def check_equal(a, b): + """ + This function checks if two tensors are equal within tolerance + """ + assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}' + + +def check_completely_equal(a, b): + """ + This function checks if two tensors are 
completely equal + """ + assert torch.all(a == b), f'a = {a}, b = {b}' + + +def check_sharded_param_consistency(): + """ + In this test, we want to test whether zero stage 1 and 2 + deliver the same numerical results despite different communication + pattern + + we use these prefixes to differentiate the zero stage + oss: partition optimizer states + pg: partition gradients and optimizer states + + """ + + # create layers + oss_linear1 = nn.Linear(128, 256) + oss_linear2 = nn.Linear(256, 512) + + # create model + oss_model = nn.Sequential(oss_linear1, oss_linear2) + pg_model = copy.deepcopy(oss_model) + + oss_model = oss_model.cuda().half() + pg_model = pg_model.cuda().half() + + # create optimizer + oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001) + pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001) + oss_optimizer = LowLevelZeroOptimizer(oss_optimizer, + overlap_communication=True, + initial_scale=1, + clip_grad_norm=0.0) + pg_optimizer = LowLevelZeroOptimizer(pg_optimizer, + overlap_communication=True, + partition_grad=True, + initial_scale=1, + clip_grad_norm=0.0) + + # create + input_data = torch.rand(32, 128).cuda().half() + + # forward + oss_output = oss_model(input_data) + pg_output = pg_model(input_data) + check_completely_equal(oss_output, pg_output) + + # backward + oss_optimizer.backward(oss_output.mean().float()) + pg_optimizer.backward(pg_output.mean().float()) + + # check grad + # as this param is small, the backward reduction + # will not be fired + oss_linear1_grad = oss_model[0].weight.grad + oss_linear2_grad = oss_model[1].weight.grad + pg_linear1_grad = pg_model[0].weight.grad + pg_linear2_grad = pg_model[1].weight.grad + check_completely_equal(oss_linear1_grad, pg_linear1_grad) + check_completely_equal(oss_linear2_grad, pg_linear2_grad) + + # step + oss_optimizer.sync_grad() + pg_optimizer.sync_grad() + + # step + oss_optimizer.step() + pg_optimizer.step() + + # check updated param + check_completely_equal(oss_model[0].weight, pg_model[0].weight) + check_completely_equal(oss_model[1].weight, pg_model[1].weight) + + +def check_sharded_optim_against_torch_ddp(): + """ + In this test, two pairs of model and optimizers are created. + 1. zero: use sharded optimizer and fp16 parameters + 2. torch: use torch DDP and fp32 parameters + + We feed these two sets of models with the same input and check if the + differences in model output and updated parameters are within tolerance. 
+ """ + + # create layer + zero_linear1 = nn.Linear(128, 256) + zero_linear2 = nn.Linear(256, 512) + + # create model + zero_model = nn.Sequential(zero_linear1, zero_linear2) + torch_model = copy.deepcopy(zero_model) + + zero_model = zero_model.cuda().half() + torch_model = DDP(torch_model.cuda()) + + # create optimizer + zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001) + + # we only test stage 1 here + # in `check_sharded_param_consistency.py`, we will test whether + # level 1 and 2 will produce exactly the same results + zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, + overlap_communication=True, + initial_scale=1, + clip_grad_norm=0.0) + + torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001) + + # create + input_data = torch.rand(32, 128).cuda() + + # zero-dp forward + zero_output = zero_model(input_data.half()) + + # torch-ddp forward + torch_output = torch_model(input_data) + check_equal(zero_output, torch_output) + + # zero-dp backward + zero_optimizer.backward(zero_output.mean().float()) + + # torch-ddp backward + torch_output.mean().backward() + + # check grad + zero_linear1_grad = zero_model[0].weight.grad + zero_linear2_grad = zero_model[1].weight.grad + torch_linear1_grad = torch_model.module[0].weight.grad + torch_linear2_grad = torch_model.module[1].weight.grad + check_equal(zero_linear1_grad, torch_linear1_grad) + check_equal(zero_linear2_grad, torch_linear2_grad) + + # zero-dp step + zero_optimizer.sync_grad() + zero_optimizer.step() + + # torch ddp step + torch_optimizer.step() + + # check updated param + check_equal(zero_model[0].weight, torch_model.module[0].weight) + check_equal(zero_model[1].weight, torch_model.module[1].weight) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + + check_sharded_optim_against_torch_ddp() + check_sharded_param_consistency() + + +@pytest.mark.dist +def test_sharded_optim(): + world_size = 2 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_sharded_optim()