# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
from contextlib import contextmanager
from functools import partial
from typing import Dict, Iterator, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor, inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
from torch.optim import Optimizer

import colossalai.utils.device as device_utils
from colossalai.amp.naive_amp.mixed_precision_mixin import (
    BF16MixedPrecisionMixin,
    FP16MixedPrecisionMixin,
    MixedPrecisionMixin,
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.tensor.moe_tensor.api import is_moe_tensor

# from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device

from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, ParameterStore


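# The FP16 mixin below plugs this optimizer's bookkeeping into the shared dynamic
# loss-scaling logic: a local overflow is detected by scanning the working gradients
# of every parameter group stored in the GradientStore.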
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
    def __init__(
        self,
        num_working_param_groups: int,
        grad_store: GradientStore,
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
    ) -> None:
        super().__init__(
            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
        )
        self.num_working_param_groups = num_working_param_groups
        self.grad_store = grad_store

    def check_local_overflow(self) -> bool:
        for group_id in range(self.num_working_param_groups):
            for avg_grad in self.grad_store.get_working_grads_by_group_id(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    return True
        return False


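# A minimal usage sketch (assuming torch.distributed is already initialized and
# `model`/`loss` are placeholders for the user's own training code):
#
#   optim = torch.optim.Adam(model.parameters(), lr=1e-3)
#   optim = LowLevelZeroOptimizer(optim, partition_grad=False, overlap_communication=True)
#   optim.backward(loss)
#   optim.step()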
class LowLevelZeroOptimizer(OptimizerWrapper):
    """Optimizer used for ZeRO-1 and ZeRO-2."""

    def __init__(
        self,
        optimizer: Optimizer,
        initial_scale: int = 2**16,  # grad scaler config
        min_scale: int = 1,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        hysteresis: int = 2,
        max_scale: int = 2**24,
        clip_grad_norm: float = 0.0,  # grad clipping
        verbose: bool = False,
        reduce_bucket_size: int = 1024 * 1024,  # communication
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = False,
        partition_grad: bool = False,  # stage 2 flag
        cpu_offload: bool = False,  # cpu offload
        dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
        forced_dtype: Optional[torch.dtype] = None,
        moe_extra_dp_process_group: Optional[ProcessGroup] = None,
        master_weights: bool = True,  # master weights
    ):
        super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
        self._dtype = self.optim.param_groups[0]["params"][0].dtype
        self._logger = get_dist_logger()
        self._verbose = verbose

        # stage 2
        self._partition_grads = partition_grad

        self._cpu_offload = cpu_offload

        # grad accumulation
        self.require_grad_sync = True

        # if process_group is none, will use the default one
        self.dp_pg = dp_process_group
        self._local_rank = dist.get_rank(group=self.dp_pg)
        self._world_size = dist.get_world_size(group=self.dp_pg)

        # extra dp
        # This group is used to sync moe params, dp_world_size = moe_duplicates * extra_dp_size.
        # Non-moe params are synced by the global dp pg, moe params are synced by the extra dp pg.
        # Moe param grads are split like non-moe params by the global dp pg, and the grads are merged in step.
        # Moe working and master params are split by the extra dp pg.
        self.moe_extra_dp_pg = moe_extra_dp_process_group
        if self.moe_extra_dp_pg is not None:
            self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
            self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)

        # working and master params for mixed precision training
        self._working_param_groups = dict()
        self._master_param_groups_of_current_rank = dict()

        # communication params
        self._overlap_communication = overlap_communication
        self._reduce_bucket_size = reduce_bucket_size
        self._communication_dtype = communication_dtype

        # gradient clipping
        self._clip_grad_norm = clip_grad_norm

        # master weights copy
        self._master_weights = master_weights

        if forced_dtype:
            for group in self.optim.param_groups:
                group_params = group["params"]
                for param in group_params:
                    param.data = param.data.to(forced_dtype)
            self._dtype = forced_dtype

        # check argument conflict
        self._sanity_checks()

        # ParameterStore will manage the tensor buffers used for zero
        # it will not manage the tensors used by mixed precision training
        self._param_store = ParameterStore(self.dp_pg)
        self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
        self._bucket_store = BucketStore(self.dp_pg)

        # moe params should not be stored in the working groups
        # because they use a different parallel strategy,
        # so we store them separately in param_groups
        # instead of working_groups
        moe_params = list()

        # iterate over the param groups in the optimizer,
        # partition these param groups for data parallel training
        # and add buffers to the parameter store for future access
        for group_id, param_group in enumerate(self.optim.param_groups):
            group_params = list()
            for param in param_group["params"]:
                if param.requires_grad:
                    if self.moe_extra_dp_pg is None:
                        # skip moe param
                        if is_moe_tensor(param):
                            moe_params.append(param)
                            continue
                    group_params.append(param)

            # add the working params to working_param_groups for bookkeeping
            self._working_param_groups[group_id] = group_params

            master_param_current_rank = self._create_master_param_current_rank(group_params)
            self._master_param_groups_of_current_rank[group_id] = master_param_current_rank

            # need to replace the params in the `params` field in the optimizer
            # so that when the optimizer calls step(), it only updates the tensors
            # managed by this data parallel rank
            param_group["params"] = master_param_current_rank

        # if there are moe params, store them in an additional group in the optimizer
        if len(moe_params) > 0:
            param_group = dict()
            for key, value in self.optim.param_groups[0].items():
                if key != "params":
                    param_group[key] = value
            param_group["params"] = moe_params
            self.optim.param_groups.append(param_group)

        # initialize the communication stream for
        # communication-computation overlapping
        if self._overlap_communication:
            self._comm_stream = device_utils.Stream()

        # the reduction hook is only used if overlapping communication
        # or stage 2 is used
        # if it is stage 1 without overlapping, no hook will be attached
        if self._overlap_communication or self._partition_grads:
            self._attach_reduction_hook()

        # initialize mixed precision mixin
        self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
        if self._dtype is torch.float16:
            self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(
                self.num_param_groups,
                self._grad_store,
                initial_scale=initial_scale,
                min_scale=min_scale,
                growth_factor=growth_factor,
                backoff_factor=backoff_factor,
                growth_interval=growth_interval,
                hysteresis=hysteresis,
                max_scale=max_scale,
            )
        elif self._dtype is torch.bfloat16:
            self.mixed_precision_mixin = BF16MixedPrecisionMixin()

    @property
    def dtype(self):
        return self._dtype

    @property
    def num_param_groups(self):
        return len(self._working_param_groups)

    def _sanity_checks(self):
        assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
        for param_group in self.optim.param_groups:
            group_params = param_group["params"]
            for param in group_params:
                assert (
                    param.dtype == self._dtype
                ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

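    # Padding example: with world_size = 4, a parameter of 10 elements gets
    # padding_size = (4 - 10 % 4) % 4 = 2, so the flattened tensor is padded to 12
    # elements and each rank keeps a contiguous slice of 3 elements as its master shard.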
    def _create_master_param_current_rank(self, param_list):
        # split each param evenly by world size
        params_current_rank = []
        device = "cpu" if self._cpu_offload else get_current_device()

        for param in param_list:
            padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
            self._param_store.record_param_padding_size(param, padding_size)

            with torch.no_grad():
                if padding_size > 0:
                    padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
                    # reset the working param's data pointer when there are no master weights
                    if not self._master_weights:
                        param.data = padding_param[: param.numel()].view(param.shape)
                else:
                    padding_param = param.data.view(-1)

                if self.moe_extra_dp_pg is not None and is_moe_tensor(param):
                    splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size)
                    splited_params = splited_params[self.moe_extra_dp_pg_rank]
                else:
                    splited_params = padding_param.split(padding_param.numel() // self._world_size)
                    splited_params = splited_params[self._local_rank]

                # use fp32 when master_weights is True
                if self._master_weights is True:
                    splited_param_current_rank = splited_params.detach().float().to(device)
                else:
                    splited_param_current_rank = splited_params

                params_current_rank.append(splited_param_current_rank)
                self._param_store.link_master_and_working_param(splited_param_current_rank, param)

        return params_current_rank

    ###########################
    # Backward Reduction Hook #
    ###########################

    def _grad_handler(self, param, group_id, grad):
        # if run within the no_sync context, do not sync the grad during backward
        if self.require_grad_sync:
            self._add_to_bucket(param, group_id)
        return grad

    def _attach_reduction_hook(self):
        # we iterate over the working params
        # on each param, we register a hook to its AccumulateGrad object
        for group_id in range(self.num_param_groups):
            param_group = self._working_param_groups[group_id]
            for param in param_group:
                if param.requires_grad:
                    param.register_hook(partial(self._grad_handler, param, group_id))

    #######################
    # Reduction Functions #
    #######################

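    # _run_reduction flushes the current bucket: gradients are flattened, averaged by
    # the world size, then either all-reduced across the dp group (ZeRO-1) or
    # reduce-scattered so that each rank only keeps its own partition (ZeRO-2).
    # When a moe process group is present, moe and non-moe gradients are reduced
    # over their respective groups.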
    def _run_reduction(self):
        if self._bucket_store.num_elements_in_bucket() > 0:
            self._bucket_store.build_grad_in_bucket()

            if self.moe_extra_dp_pg is None:
                flat_grads = self._bucket_store.get_flatten_grad()
                flat_grads /= self._world_size
            else:
                # record which params are moe and which are not
                moe_list = []
                for param in self._bucket_store._param_list:
                    moe_list.append(is_moe_tensor(param))

                # divide the grads into moe and non-moe groups
                moe_grad_list = []
                non_moe_grad_list = []
                for grad_list in self._bucket_store._grad_in_bucket.values():
                    non_moe_cur_grad = []
                    moe_cur_grad = []
                    for i in range(len(grad_list)):
                        if moe_list[i]:
                            moe_cur_grad.append(grad_list[i])
                        else:
                            non_moe_cur_grad.append(grad_list[i])
                    if len(moe_cur_grad) > 0:
                        moe_grad_list.append(moe_cur_grad)
                    if len(non_moe_cur_grad) > 0:
                        non_moe_grad_list.append(non_moe_cur_grad)

                if len(non_moe_grad_list) > 0:
                    non_moe_flat_grads = []
                    for grad_list in non_moe_grad_list:
                        non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
                    non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
                    non_moe_flat_grads /= self._world_size

                if len(moe_grad_list) > 0:
                    moe_flat_grads = []
                    for grad_list in moe_grad_list:
                        moe_flat_grads.append(_flatten_dense_tensors(grad_list))
                    moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)

            # ready to add other tensors to the bucket
            self._bucket_store.reset_num_elements_in_bucket()

            if self._overlap_communication:
                stream = self._comm_stream
                # in case the memory is reused in the default stream
                if self.moe_extra_dp_pg is None:
                    flat_grads.record_stream(stream)
                else:
                    if len(non_moe_grad_list) > 0:
                        non_moe_flat_grads.record_stream(stream)
                    if len(moe_grad_list) > 0:
                        moe_flat_grads.record_stream(stream)
                # wait for ops in the default stream to finish
                stream.wait_stream(device_utils.current_stream())
            else:
                stream = device_utils.current_stream()

            with device_utils.stream(stream):
                group_id = self._bucket_store.current_group_id

                if self.moe_extra_dp_pg is None:
                    grad_dtype = flat_grads.dtype
                    if self._communication_dtype is not None:
                        flat_grads = flat_grads.to(self._communication_dtype)

                if not self._partition_grads:
                    if self.moe_extra_dp_pg is None:
                        dist.all_reduce(flat_grads, group=self.dp_pg)
                        if flat_grads.dtype != grad_dtype:
                            flat_grads = flat_grads.to(grad_dtype)

                        flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
                        grad_in_bucket = self._bucket_store.get_grad()
                        self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)

                    # sync the extra zero group
                    else:
                        # sync non-moe params in the global dp group
                        if len(non_moe_grad_list) > 0:
                            dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
                            flat_grads_per_rank = non_moe_flat_grads.split(
                                non_moe_flat_grads.numel() // self._world_size
                            )
                            self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)

                        # sync moe params only in the zero group
                        if len(moe_grad_list) > 0:
                            dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg)
                            flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size)
                            self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id)

                else:
                    if self.moe_extra_dp_pg is None:
                        flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
                        received_grad = torch.zeros_like(flat_grads_list[0])
                        dist.reduce_scatter(received_grad, flat_grads_list, group=self.dp_pg)

                        if received_grad.dtype != grad_dtype:
                            received_grad = received_grad.to(grad_dtype)

                        grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
                        self._update_partitoned_grad(grad_in_bucket_current_rank, received_grad, group_id, 1)
                    else:
                        # categorize moe and non-moe params
                        grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
                        moe_grad_in_bucket_current_rank = []
                        non_moe_grad_in_bucket_current_rank = []
                        for idx, grad in enumerate(grad_in_bucket_current_rank):
                            if moe_list[idx]:
                                moe_grad_in_bucket_current_rank.append(grad)
                            else:
                                non_moe_grad_in_bucket_current_rank.append(grad)

                        if len(non_moe_grad_list) > 0:
                            flat_grads_list = list(
                                non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
                            )
                            received_grad = torch.zeros_like(flat_grads_list[0])
                            dist.reduce_scatter(received_grad, flat_grads_list, group=self.dp_pg)
                            self._update_partitoned_grad(
                                non_moe_grad_in_bucket_current_rank, received_grad, group_id, 1
                            )

                        if len(moe_grad_list) > 0:
                            flat_grads_list = list(
                                moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
                            )
                            received_grad = torch.zeros_like(flat_grads_list[0])
                            dist.reduce_scatter(received_grad, flat_grads_list, group=self.moe_extra_dp_pg)
                            param_slice = self._world_size // self.moe_extra_dp_pg_size
                            received_grad = list(received_grad.split(len(received_grad) // param_slice))
                            for split_received_grad in received_grad:
                                split_received_grad = _unflatten_dense_tensors(
                                    split_received_grad, moe_grad_in_bucket_current_rank
                                )
                                for real_grad, grad in zip(split_received_grad, moe_grad_in_bucket_current_rank):
                                    param_id = self._bucket_store.get_param_id_of_grad(grad)
                                    self._add_grad(real_grad, param_slice, group_id, param_id)

            self._bucket_store.reset()

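    # The two helpers below copy the reduced (flattened) gradients back into the
    # per-parameter tensors tracked by the BucketStore and register them in the
    # GradientStore: one keeps a copy per rank (ZeRO-1), the other only keeps the
    # partition owned by the current rank (ZeRO-2).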
    def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None:
        for rank, grad_list in enumerate(origin_grad_list):
            sync_tensor(flat_grad_list[rank], grad_list)
            for grad in grad_list:
                param_id = self._bucket_store.get_param_id_of_grad(grad)
                self._add_grad(grad, self._world_size, group_id, param_id, rank)

    def _update_partitoned_grad(
        self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
    ) -> None:
        sync_tensor(flat_grad, origin_grad_list)
        for grad in origin_grad_list:
            param_id = self._bucket_store.get_param_id_of_grad(grad)
            self._add_grad(grad, partition_num, group_id, param_id)

    def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
        if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
            self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
        else:
            self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)

    def _add_to_bucket(self, param, group_id):
        param_size = param.numel()

        # check if the bucket is full;
        # if full, or if the incoming grad belongs to a param from another group,
        # reduce the grads already in the bucket
        # after reduction, the bucket will be empty
        if (
            self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size
            or group_id != self._bucket_store.current_group_id
        ):
            self._run_reduction()

        padding_size = self._param_store.get_param_padding_size(param)
        self._bucket_store.add_param_grad(group_id, param, padding_size)

    ################################
    # torch.optim.Optimizer methods
    ################################

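    # backward() scales the loss through the mixed-precision mixin (if any), runs the
    # autograd pass, and then reduces the bucketed gradients across the dp group
    # unless it is called inside the no_sync() context.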
    def backward(self, loss, retain_graph=False):
        assert not (
            self._partition_grads and not self.require_grad_sync
        ), "ZeRO2(partition_grads) and no_sync are not compatible"

        if self.mixed_precision_mixin is not None:
            loss = self.mixed_precision_mixin.pre_backward(loss)

        loss.backward(retain_graph=retain_graph)

        if not self.require_grad_sync:
            return

        self._reduce_grad(self._partition_grads)

        # clear reduced grads
        if self._overlap_communication:
            device_utils.synchronize()

        self.zero_grad()

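    # backward_by_grad() follows the same flow as backward(), but starts the autograd
    # pass from a precomputed gradient w.r.t. `tensor` instead of a scalar loss.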
    def backward_by_grad(self, tensor, grad):
        assert not (
            self._partition_grads and not self.require_grad_sync
        ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

        if self.mixed_precision_mixin is not None:
            grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
        torch.autograd.backward(tensor, grad)

        if not self.require_grad_sync:
            return
        self._reduce_grad(self._partition_grads)

        # clear reduced grads
        if self._overlap_communication:
            device_utils.synchronize()

        self.zero_grad()

    def zero_grad(self, set_to_none=True):
        """
        Set parameter gradients to zero. If set_to_none = True, the gradient
        will be set to None to save memory.

        :param set_to_none: Whether to set the gradient to None. Default value is True.
        :type set_to_none: bool
        """
        if self.mixed_precision_mixin is not None:
            self.mixed_precision_mixin.pre_zero_grad()
        for _, param_group in self._working_param_groups.items():
            for param in param_group:
                if set_to_none:
                    param.grad = None
                else:
                    if param.grad is not None:
                        param.grad.detach_()
                        param.grad.zero_()

    ####################
    # Update Parameter #
    ####################

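    # step() skips the update entirely when the FP16 mixin reports an overflow;
    # otherwise it attaches the partitioned grads to the master params, computes the
    # gradient norm, unscales/clips, runs the inner optimizer, and finally all-gathers
    # the updated shards back into the working (low-precision) parameters.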
    def step(self, closure=None):
        assert closure is None, "closure is not supported by step()"
        if not self.require_grad_sync:
            return

        if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
            self._grad_store.reset_all_gradients()
            if self._verbose:
                self._logger.info("Found overflow. Skip step")
            self.zero_grad()
            return

        # record all grads for unscale and clip
        grad_partition_groups = []
        norm_groups = []

        # sometimes not all params are 'really' working
        # for instance, when layer drop is used, the dropped layer has no grad
        # and should not be updated
        real_working_params = dict()
        real_master_params = dict()
        grad_index = 0 if self._partition_grads else self._local_rank
        for group_id in range(self.num_param_groups):
            master_params = self._master_param_groups_of_current_rank[group_id]
            real_working_params[group_id] = []
            real_master_params[group_id] = []
            for splited_param in master_params:
                working_param = self._param_store.master_to_working_param[id(splited_param)]
                # if a working param requires grad but has no grad,
                # it is not 'really' working, e.g. a dropped layer;
                # otherwise the split grad should be attached to the split param
                grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
                if len(grads) > 0:
                    # moe hybrid zero
                    if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
                        real_working_params[group_id].append(working_param)
                        if self._partition_grads:
                            grad = grads
                        else:
                            param_slice = self._world_size // self.moe_extra_dp_pg_size
                            grad = grads[
                                self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice
                            ]
                        grad = flatten(grad)
                    else:
                        real_working_params[group_id].append(working_param)
                        grad = grads[grad_index]
                    # no need to copy to an fp32 grad if master_weights is False
                    if self._master_weights:
                        grad = grad.to(splited_param.dtype).to(splited_param.device)
                    splited_param.grad = grad
                    grad_partition_groups.append(grad)
                    real_master_params[group_id].append(splited_param)

            # compute norm
            working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
            norm_group = self._compute_grad_norm(gradients=working_grads)
            norm_groups.append(norm_group)

            self._grad_store.reset_grads_by_group_id(group_id)

            # update the params in the optimizer
            self.optim.param_groups[group_id]["params"] = real_master_params[group_id]

        # unscale and clip grads
        global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
        self._unscale_and_clip_grads(grad_partition_groups, global_norm)

        # TODO: we should store master params for ep
        if len(self.param_groups) > len(self._working_param_groups):
            for param in self.param_groups[-1]["params"]:
                param.data = param.data.to(torch.float32)
                param.grad = param.grad.to(torch.float32)

        # update the parameters
        self.optim.step()

        # release the moe grads
        if len(self.param_groups) > len(self._working_param_groups):
            for param in self.param_groups[-1]["params"]:
                param.grad = None
                param.data = param.data.to(self._dtype)

        # release the grads
        grad_partition_groups = []
        for group_id in range(self.num_param_groups):
            release_param_grad(self._master_param_groups_of_current_rank[group_id])

        # update the working partition owned by the current rank
        device = get_current_device()
        for group_id in range(self.num_param_groups):
            master_working_param = self.optim.param_groups[group_id]["params"]
            for idx, splited_param in enumerate(master_working_param):
                working_param = real_working_params[group_id][idx]
                if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
                    all_splited_param = [
                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
                        for _ in range(self.moe_extra_dp_pg_size)
                    ]
                    dist.all_gather(
                        all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
                    )
                else:
                    all_splited_param = [
                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
                        for _ in range(self._world_size)
                    ]
                    dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
            self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

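    # For a p-norm, each rank accumulates sum(|g|^p) over its local gradient
    # partitions, the partial sums are all-reduced (SUM) over the dp group, and the
    # result is raised to 1/p; for the inf-norm the per-rank maxima are all-reduced
    # with MAX instead.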
    def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
        r"""
        Compute and return the gradient norm for gradient clipping.

        Args:
            gradients (List[Tensor]): The gradients to compute the norm of
            norm_type (int, optional): type of the used p-norm. Can be ``'inf'`` for infinity norm. Defaults to 2.

        Returns:
            float: The total norm of the given gradients
        """

        if len(gradients) == 0:
            return 0.0

        norm_type = float(norm_type)
        if norm_type == inf:
            total_norm = max(grad.data.abs().max() for grad in gradients)
            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
            total_norm = total_norm_cuda.item()

        else:
            total_norm_exponentiated = 0.0
            for grad in gradients:
                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
                total_norm_exponentiated += grad_norm_exponentiated

            # Sum across all model parallel GPUs.
            total_norm_exponentiated_cuda = torch.tensor(
                [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
            )
            torch.distributed.all_reduce(
                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
            )
            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

        return total_norm

    #############################
    # Mixed Precision Utilities #
    #############################

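    # Unscaling and clipping are fused into a single multiplication: the divisor starts
    # at the loss scale reported by the mixed-precision mixin and is enlarged further
    # whenever the (already scaled) global norm exceeds the clipping threshold.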
    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
        # compute the combined scale factor for this group
        div_scale = 1.0
        if self.mixed_precision_mixin is not None:
            div_scale = self.mixed_precision_mixin.get_grad_div_scale()

        if self._clip_grad_norm > 0.0:
            # norm is in fact norm*scale
            clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
            if clip > 1:
                div_scale = clip * div_scale

        for grad in grad_groups_flat:
            grad.data.mul_(1.0 / div_scale)

    ############################
    # Gradient Synchronization #
    ############################

    # this method is used to sync gradients manually
    def _sync_grad(self):
        for group_id in range(self.num_param_groups):
            param_group = self._working_param_groups[group_id]
            for param in param_group:
                if param.requires_grad and param.grad is not None:
                    self._add_to_bucket(param, group_id)

        self._run_reduction()

    def _reduce_grad(self, partition_grad):
        # if communication is not overlapped (no reduction hook is attached) under zero1,
        # we need to manually reduce these gradients
        if not partition_grad and not self._overlap_communication:
            self._sync_grad()
        else:
            self._run_reduction()

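    # A minimal gradient-accumulation sketch with no_sync() (ZeRO-1 only; see the
    # assertion in backward()). `optimizer` and `loss` are placeholders:
    #
    #   with optimizer.no_sync():
    #       optimizer.backward(loss)   # grads accumulate locally, no reduction
    #   optimizer.backward(loss)       # last micro-batch: grads are reduced
    #   optimizer.step()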
    # this context manager comes from pytorch DDP
    @contextmanager
    def no_sync(self):
        old_require_grad_sync = self.require_grad_sync
        self.require_grad_sync = False
        try:
            yield
        finally:
            self.require_grad_sync = old_require_grad_sync

    ##############
    # State Dict #
    ##############

    def _pack_state(self, state: Dict) -> Dict:
        # comes from pytorch optimizer.state_dict()
        param_mappings = {}
        start_index = 0

        def pack_group(group):
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != "params"}
            param_mappings.update(
                {id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings}
            )
            packed["params"] = [param_mappings[id(p)] for p in group["params"]]
            start_index += len(packed["params"])
            return packed

        param_groups = [pack_group(g) for g in self.optim.param_groups]
        # Remap state to use order indices as keys
        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}

        return {"state": packed_state, "param_groups": param_groups}

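    # state_dict() rebuilds full optimizer-state tensors by all-gathering each rank's
    # shard over the dp group (or the moe group for moe tensors) and reshaping them to
    # the working parameter's shape, so the result can be loaded like a DDP state dict.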
    def state_dict(self) -> Dict:
        """Return a state_dict in the same form as DDP

        Returns:
            Dict: the pytorch form state_dict
        """
        zero_state = dict()
        device = get_current_device()
        for param, state in self.optim.state.items():
            zero_state[param] = copy.deepcopy(state)
            for k, v in state.items():
                if isinstance(v, torch.Tensor) and k != "step":
                    working_param = self._param_store.master_to_working_param[id(param)]
                    if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
                        gather_tensor = [
                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
                        ]
                        dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg)
                    else:
                        gather_tensor = [
                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
                        ]
                        dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
                    param_state = (
                        torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
                    )
                    zero_state[param][k] = param_state

        states_dict = self._pack_state(zero_state)

        return states_dict

    def load_state_dict(self, state_dict: Dict):
        """Load a state dict; the state_dict is required to be in the pytorch form

        Args:
            state_dict (dict): A pytorch form state_dict
        """
        zero_state_dict = copy.deepcopy(state_dict)
        for param_idx, state in zero_state_dict["state"].items():
            for k, v in state.items():
                if isinstance(v, torch.Tensor) and k != "step":
                    padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
                    with torch.no_grad():
                        v = v.flatten()
                        if padding_size > 0:
                            v = torch.nn.functional.pad(v, [0, padding_size])
                        if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
                            v_list = v.split(v.numel() // self.moe_extra_dp_pg_size)
                            zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone()
                        else:
                            v_list = v.split(v.numel() // self._world_size)
                            zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()

        self.optim.load_state_dict(zero_state_dict)

    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
        """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
        Only include the 'state' in state_dict.

        Args:
            max_shard_size (int, optional): max size of state shard (in MB). Defaults to 1024.

        Yields:
            Iterator[OrderedDict]: A generator of state dict shard
        """
        ret_block = dict()
        ret_block_size = 0

        device = get_current_device()
        local_states = self.optim.state_dict()["state"]
        for param_idx, states in local_states.items():
            current_block_size = 0
            current_block = copy.deepcopy(states)

            # find the working param of the current param_id
            for group_id, pg in self._master_param_groups_of_current_rank.items():
                if (group_id + 1) * len(pg) < param_idx:
                    continue
                master_param = pg[param_idx - (group_id) * len(pg)]
                working_param = self._param_store.master_to_working_param[id(master_param)]

            for k, v in states.items():
                if isinstance(v, torch.Tensor) and k != "step":
                    if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
                        state_tensor = [
                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
                        ]
                        dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg)
                    else:
                        state_tensor = [
                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
                        ]
                        dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
                    state_tensor = (
                        torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
                    )
                    current_block_size += state_tensor.numel()
                    current_block[k] = state_tensor

            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                yield ret_block, ret_block_size
                ret_block = dict()
                ret_block_size = 0

            ret_block[param_idx] = current_block
            ret_block_size += current_block_size

        yield ret_block, ret_block_size

    def update_master_params(self, model: nn.Module) -> None:
        """Update master params from working params

        Args:
            model (nn.Module): The model to update master params
        """
        for p in model.parameters():
            p_id = id(p)
            if p_id in self._param_store.working_to_master_param:
                master_param = self._param_store.working_to_master_param[p_id]
                padding_size = self._param_store.get_param_padding_size(p)
                working_param = p.data.view(-1)
                if padding_size > 0:
                    working_param = torch.nn.functional.pad(working_param, [0, padding_size])
                if self.moe_extra_dp_pg is not None and is_moe_tensor(p):
                    master_param.copy_(working_param.chunk(self.moe_extra_dp_pg_size)[self.moe_extra_dp_pg_rank])
                else:
                    master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])

    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
        return self._param_store.working_to_master_param

    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
        return self._param_store.master_to_working_param